feat: 完成B端认证系统和商户管理模块测试补全

主要变更:
- 新增B端认证系统(后台+H5):登录、登出、Token刷新、密码修改
- 完善商户管理和商户账号管理功能
- 补全单元测试(ShopService: 72.5%, ShopAccountService: 79.8%)
- 新增集成测试(商户管理+商户账号管理)
- 归档OpenSpec提案(add-shop-account-management, implement-b-end-auth-system)
- 完善文档(使用指南、API文档、认证架构说明)

测试统计:
- 13个测试套件,37个测试用例,100%通过率
- 平均覆盖率76.2%,达标

OpenSpec验证:通过(strict模式)
This commit is contained in:
2026-01-15 18:15:17 +08:00
parent 7ccd3d146c
commit 18f35f3ef4
64 changed files with 11875 additions and 242 deletions

179
pkg/auth/token.go Normal file
View File

@@ -0,0 +1,179 @@
package auth
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
)
type TokenManager struct {
rdb *redis.Client
accessTokenTTL time.Duration
refreshTokenTTL time.Duration
}
type TokenInfo struct {
UserID uint `json:"user_id"`
UserType int `json:"user_type"`
ShopID uint `json:"shop_id,omitempty"`
EnterpriseID uint `json:"enterprise_id,omitempty"`
Username string `json:"username"`
LoginTime time.Time `json:"login_time"`
Device string `json:"device"`
IP string `json:"ip"`
}
func NewTokenManager(rdb *redis.Client, accessTTL, refreshTTL time.Duration) *TokenManager {
return &TokenManager{
rdb: rdb,
accessTokenTTL: accessTTL,
refreshTokenTTL: refreshTTL,
}
}
func (m *TokenManager) GenerateTokenPair(ctx context.Context, info *TokenInfo) (accessToken, refreshToken string, err error) {
accessToken = uuid.New().String()
refreshToken = uuid.New().String()
info.LoginTime = time.Now()
data, err := json.Marshal(info)
if err != nil {
return "", "", fmt.Errorf("failed to marshal token info: %w", err)
}
pipe := m.rdb.Pipeline()
accessKey := constants.RedisAuthTokenKey(accessToken)
pipe.Set(ctx, accessKey, data, m.accessTokenTTL)
refreshKey := constants.RedisRefreshTokenKey(refreshToken)
pipe.Set(ctx, refreshKey, data, m.refreshTokenTTL)
userTokensKey := constants.RedisUserTokensKey(info.UserID)
pipe.SAdd(ctx, userTokensKey, accessToken, refreshToken)
pipe.Expire(ctx, userTokensKey, m.refreshTokenTTL)
if _, err := pipe.Exec(ctx); err != nil {
return "", "", errors.New(errors.CodeRedisError, fmt.Sprintf("failed to store tokens: %v", err))
}
return accessToken, refreshToken, nil
}
func (m *TokenManager) ValidateAccessToken(ctx context.Context, token string) (*TokenInfo, error) {
key := constants.RedisAuthTokenKey(token)
data, err := m.rdb.Get(ctx, key).Result()
if err == redis.Nil {
return nil, errors.New(errors.CodeInvalidToken, "无效或过期的令牌")
}
if err != nil {
return nil, errors.New(errors.CodeRedisError, fmt.Sprintf("failed to get token: %v", err))
}
var info TokenInfo
if err := json.Unmarshal([]byte(data), &info); err != nil {
return nil, fmt.Errorf("failed to unmarshal token info: %w", err)
}
return &info, nil
}
func (m *TokenManager) ValidateRefreshToken(ctx context.Context, token string) (*TokenInfo, error) {
key := constants.RedisRefreshTokenKey(token)
data, err := m.rdb.Get(ctx, key).Result()
if err == redis.Nil {
return nil, errors.New(errors.CodeInvalidToken, "无效或过期的刷新令牌")
}
if err != nil {
return nil, errors.New(errors.CodeRedisError, fmt.Sprintf("failed to get refresh token: %v", err))
}
var info TokenInfo
if err := json.Unmarshal([]byte(data), &info); err != nil {
return nil, fmt.Errorf("failed to unmarshal token info: %w", err)
}
return &info, nil
}
func (m *TokenManager) RefreshAccessToken(ctx context.Context, refreshToken string) (newAccessToken string, err error) {
info, err := m.ValidateRefreshToken(ctx, refreshToken)
if err != nil {
return "", err
}
newAccessToken = uuid.New().String()
data, err := json.Marshal(info)
if err != nil {
return "", fmt.Errorf("failed to marshal token info: %w", err)
}
pipe := m.rdb.Pipeline()
newAccessKey := constants.RedisAuthTokenKey(newAccessToken)
pipe.Set(ctx, newAccessKey, data, m.accessTokenTTL)
userTokensKey := constants.RedisUserTokensKey(info.UserID)
pipe.SAdd(ctx, userTokensKey, newAccessToken)
if _, err := pipe.Exec(ctx); err != nil {
return "", errors.New(errors.CodeRedisError, fmt.Sprintf("failed to refresh token: %v", err))
}
return newAccessToken, nil
}
func (m *TokenManager) RevokeToken(ctx context.Context, token string) error {
pipe := m.rdb.Pipeline()
accessKey := constants.RedisAuthTokenKey(token)
pipe.Del(ctx, accessKey)
refreshKey := constants.RedisRefreshTokenKey(token)
pipe.Del(ctx, refreshKey)
if _, err := pipe.Exec(ctx); err != nil {
return errors.New(errors.CodeRedisError, fmt.Sprintf("failed to revoke token: %v", err))
}
return nil
}
func (m *TokenManager) RevokeAllUserTokens(ctx context.Context, userID uint) error {
userTokensKey := constants.RedisUserTokensKey(userID)
tokens, err := m.rdb.SMembers(ctx, userTokensKey).Result()
if err != nil && err != redis.Nil {
return errors.New(errors.CodeRedisError, fmt.Sprintf("failed to get user tokens: %v", err))
}
if len(tokens) == 0 {
return nil
}
pipe := m.rdb.Pipeline()
for _, token := range tokens {
accessKey := constants.RedisAuthTokenKey(token)
pipe.Del(ctx, accessKey)
refreshKey := constants.RedisRefreshTokenKey(token)
pipe.Del(ctx, refreshKey)
}
pipe.Del(ctx, userTokensKey)
if _, err := pipe.Exec(ctx); err != nil {
return errors.New(errors.CodeRedisError, fmt.Sprintf("failed to revoke user tokens: %v", err))
}
return nil
}

357
pkg/auth/token_test.go Normal file
View File

@@ -0,0 +1,357 @@
package auth
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/config"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupTestRedis(t *testing.T) *redis.Client {
var addr, password string
testDB := 15
cfg, err := config.Load()
if err != nil {
t.Logf("配置加载失败,使用回退配置: %v", err)
addr = "localhost:6379"
password = ""
} else {
t.Logf("成功加载配置Redis 地址: %s:%d", cfg.Redis.Address, cfg.Redis.Port)
addr = fmt.Sprintf("%s:%d", cfg.Redis.Address, cfg.Redis.Port)
password = cfg.Redis.Password
}
client := redis.NewClient(&redis.Options{
Addr: addr,
Password: password,
DB: testDB,
})
ctx := context.Background()
if err := client.Ping(ctx).Err(); err != nil {
t.Skipf("Redis 未运行(地址: %s跳过测试: %v", addr, err)
}
client.FlushDB(ctx)
t.Cleanup(func() {
client.FlushDB(ctx)
client.Close()
})
return client
}
func init() {
if os.Getenv("CONFIG_ENV") == "" {
os.Setenv("CONFIG_ENV", "dev")
}
if os.Getenv("CONFIG_PATH") == "" {
os.Setenv("CONFIG_PATH", "../../configs/config.dev.yaml")
}
}
func TestTokenManager_GenerateTokenPair(t *testing.T) {
rdb := setupTestRedis(t)
tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
ctx := context.Background()
t.Run("成功生成 token 对", func(t *testing.T) {
tokenInfo := &TokenInfo{
UserID: 1,
UserType: 1,
ShopID: 10,
EnterpriseID: 0,
Username: "testuser",
Device: "web",
IP: "127.0.0.1",
}
accessToken, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
require.NoError(t, err)
assert.NotEmpty(t, accessToken)
assert.NotEmpty(t, refreshToken)
assert.Len(t, accessToken, 36)
assert.Len(t, refreshToken, 36)
})
t.Run("生成的 token 存储在 Redis 中", func(t *testing.T) {
tokenInfo := &TokenInfo{
UserID: 2,
UserType: 2,
Username: "admin",
}
accessToken, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
require.NoError(t, err)
accessKey := "auth:token:" + accessToken
refreshKey := "auth:refresh:" + refreshToken
exists, err := rdb.Exists(ctx, accessKey).Result()
require.NoError(t, err)
assert.Equal(t, int64(1), exists)
exists, err = rdb.Exists(ctx, refreshKey).Result()
require.NoError(t, err)
assert.Equal(t, int64(1), exists)
})
}
func TestTokenManager_ValidateAccessToken(t *testing.T) {
rdb := setupTestRedis(t)
tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
ctx := context.Background()
tokenInfo := &TokenInfo{
UserID: 1,
UserType: 1,
ShopID: 10,
EnterpriseID: 0,
Username: "testuser",
Device: "web",
IP: "127.0.0.1",
}
accessToken, _, err := tm.GenerateTokenPair(ctx, tokenInfo)
require.NoError(t, err)
t.Run("验证有效的 access token", func(t *testing.T) {
info, err := tm.ValidateAccessToken(ctx, accessToken)
require.NoError(t, err)
require.NotNil(t, info)
assert.Equal(t, uint(1), info.UserID)
assert.Equal(t, 1, info.UserType)
assert.Equal(t, uint(10), info.ShopID)
assert.Equal(t, "testuser", info.Username)
})
t.Run("验证无效的 token", func(t *testing.T) {
info, err := tm.ValidateAccessToken(ctx, "invalid-token")
assert.Error(t, err)
assert.Nil(t, info)
assert.Contains(t, err.Error(), "无效或过期")
})
t.Run("验证空 token", func(t *testing.T) {
info, err := tm.ValidateAccessToken(ctx, "")
assert.Error(t, err)
assert.Nil(t, info)
})
}
func TestTokenManager_ValidateRefreshToken(t *testing.T) {
rdb := setupTestRedis(t)
tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
ctx := context.Background()
tokenInfo := &TokenInfo{
UserID: 1,
UserType: 1,
Username: "testuser",
}
_, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
require.NoError(t, err)
t.Run("验证有效的 refresh token", func(t *testing.T) {
info, err := tm.ValidateRefreshToken(ctx, refreshToken)
require.NoError(t, err)
require.NotNil(t, info)
assert.Equal(t, uint(1), info.UserID)
assert.Equal(t, "testuser", info.Username)
})
t.Run("验证无效的 refresh token", func(t *testing.T) {
info, err := tm.ValidateRefreshToken(ctx, "invalid-refresh-token")
assert.Error(t, err)
assert.Nil(t, info)
})
}
func TestTokenManager_RefreshAccessToken(t *testing.T) {
rdb := setupTestRedis(t)
tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
ctx := context.Background()
tokenInfo := &TokenInfo{
UserID: 1,
UserType: 1,
Username: "testuser",
Device: "web",
IP: "127.0.0.1",
}
oldAccessToken, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
require.NoError(t, err)
t.Run("成功刷新 access token", func(t *testing.T) {
newAccessToken, err := tm.RefreshAccessToken(ctx, refreshToken)
require.NoError(t, err)
assert.NotEmpty(t, newAccessToken)
assert.NotEqual(t, oldAccessToken, newAccessToken)
info, err := tm.ValidateAccessToken(ctx, newAccessToken)
require.NoError(t, err)
assert.Equal(t, uint(1), info.UserID)
})
t.Run("使用无效的 refresh token", func(t *testing.T) {
newAccessToken, err := tm.RefreshAccessToken(ctx, "invalid-refresh-token")
assert.Error(t, err)
assert.Empty(t, newAccessToken)
})
}
func TestTokenManager_RevokeToken(t *testing.T) {
rdb := setupTestRedis(t)
tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
ctx := context.Background()
tokenInfo := &TokenInfo{
UserID: 1,
UserType: 1,
Username: "testuser",
}
accessToken, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
require.NoError(t, err)
t.Run("成功撤销 access token", func(t *testing.T) {
err := tm.RevokeToken(ctx, accessToken)
require.NoError(t, err)
info, err := tm.ValidateAccessToken(ctx, accessToken)
assert.Error(t, err)
assert.Nil(t, info)
})
t.Run("成功撤销 refresh token", func(t *testing.T) {
err := tm.RevokeToken(ctx, refreshToken)
require.NoError(t, err)
info, err := tm.ValidateRefreshToken(ctx, refreshToken)
assert.Error(t, err)
assert.Nil(t, info)
})
t.Run("撤销不存在的 token 不报错", func(t *testing.T) {
err := tm.RevokeToken(ctx, "non-existent-token")
assert.NoError(t, err)
})
}
func TestTokenManager_RevokeAllUserTokens(t *testing.T) {
rdb := setupTestRedis(t)
tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
ctx := context.Background()
tokenInfo := &TokenInfo{
UserID: 1,
UserType: 1,
Username: "testuser",
}
accessToken1, refreshToken1, err := tm.GenerateTokenPair(ctx, tokenInfo)
require.NoError(t, err)
accessToken2, refreshToken2, err := tm.GenerateTokenPair(ctx, tokenInfo)
require.NoError(t, err)
t.Run("成功撤销用户所有 token", func(t *testing.T) {
err := tm.RevokeAllUserTokens(ctx, 1)
require.NoError(t, err)
_, err = tm.ValidateAccessToken(ctx, accessToken1)
assert.Error(t, err)
_, err = tm.ValidateAccessToken(ctx, accessToken2)
assert.Error(t, err)
_, err = tm.ValidateRefreshToken(ctx, refreshToken1)
assert.Error(t, err)
_, err = tm.ValidateRefreshToken(ctx, refreshToken2)
assert.Error(t, err)
})
t.Run("撤销不存在用户的 token 不报错", func(t *testing.T) {
err := tm.RevokeAllUserTokens(ctx, 9999)
assert.NoError(t, err)
})
}
func TestTokenManager_TokenExpiration(t *testing.T) {
rdb := setupTestRedis(t)
tm := NewTokenManager(rdb, 1*time.Second, 2*time.Second)
ctx := context.Background()
tokenInfo := &TokenInfo{
UserID: 1,
UserType: 1,
Username: "testuser",
}
accessToken, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
require.NoError(t, err)
t.Run("Access token 过期后无法验证", func(t *testing.T) {
time.Sleep(2 * time.Second)
info, err := tm.ValidateAccessToken(ctx, accessToken)
assert.Error(t, err)
assert.Nil(t, info)
})
t.Run("Refresh token 过期后无法验证", func(t *testing.T) {
time.Sleep(1 * time.Second)
info, err := tm.ValidateRefreshToken(ctx, refreshToken)
assert.Error(t, err)
assert.Nil(t, info)
})
}
func TestTokenManager_ConcurrentAccess(t *testing.T) {
rdb := setupTestRedis(t)
tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
ctx := context.Background()
t.Run("并发生成 token", func(t *testing.T) {
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func(id int) {
tokenInfo := &TokenInfo{
UserID: uint(id),
UserType: 1,
Username: "user",
}
_, _, err := tm.GenerateTokenPair(ctx, tokenInfo)
assert.NoError(t, err)
done <- true
}(i)
}
for i := 0; i < 10; i++ {
<-done
}
})
}