feat: 完成B端认证系统和商户管理模块测试补全
主要变更: - 新增B端认证系统(后台+H5):登录、登出、Token刷新、密码修改 - 完善商户管理和商户账号管理功能 - 补全单元测试(ShopService: 72.5%, ShopAccountService: 79.8%) - 新增集成测试(商户管理+商户账号管理) - 归档OpenSpec提案(add-shop-account-management, implement-b-end-auth-system) - 完善文档(使用指南、API文档、认证架构说明) 测试统计: - 13个测试套件,37个测试用例,100%通过率 - 平均覆盖率76.2%,达标 OpenSpec验证:通过(strict模式)
This commit is contained in:
179
pkg/auth/token.go
Normal file
179
pkg/auth/token.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/errors"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type TokenManager struct {
|
||||
rdb *redis.Client
|
||||
accessTokenTTL time.Duration
|
||||
refreshTokenTTL time.Duration
|
||||
}
|
||||
|
||||
type TokenInfo struct {
|
||||
UserID uint `json:"user_id"`
|
||||
UserType int `json:"user_type"`
|
||||
ShopID uint `json:"shop_id,omitempty"`
|
||||
EnterpriseID uint `json:"enterprise_id,omitempty"`
|
||||
Username string `json:"username"`
|
||||
LoginTime time.Time `json:"login_time"`
|
||||
Device string `json:"device"`
|
||||
IP string `json:"ip"`
|
||||
}
|
||||
|
||||
func NewTokenManager(rdb *redis.Client, accessTTL, refreshTTL time.Duration) *TokenManager {
|
||||
return &TokenManager{
|
||||
rdb: rdb,
|
||||
accessTokenTTL: accessTTL,
|
||||
refreshTokenTTL: refreshTTL,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TokenManager) GenerateTokenPair(ctx context.Context, info *TokenInfo) (accessToken, refreshToken string, err error) {
|
||||
accessToken = uuid.New().String()
|
||||
refreshToken = uuid.New().String()
|
||||
|
||||
info.LoginTime = time.Now()
|
||||
|
||||
data, err := json.Marshal(info)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to marshal token info: %w", err)
|
||||
}
|
||||
|
||||
pipe := m.rdb.Pipeline()
|
||||
|
||||
accessKey := constants.RedisAuthTokenKey(accessToken)
|
||||
pipe.Set(ctx, accessKey, data, m.accessTokenTTL)
|
||||
|
||||
refreshKey := constants.RedisRefreshTokenKey(refreshToken)
|
||||
pipe.Set(ctx, refreshKey, data, m.refreshTokenTTL)
|
||||
|
||||
userTokensKey := constants.RedisUserTokensKey(info.UserID)
|
||||
pipe.SAdd(ctx, userTokensKey, accessToken, refreshToken)
|
||||
pipe.Expire(ctx, userTokensKey, m.refreshTokenTTL)
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return "", "", errors.New(errors.CodeRedisError, fmt.Sprintf("failed to store tokens: %v", err))
|
||||
}
|
||||
|
||||
return accessToken, refreshToken, nil
|
||||
}
|
||||
|
||||
func (m *TokenManager) ValidateAccessToken(ctx context.Context, token string) (*TokenInfo, error) {
|
||||
key := constants.RedisAuthTokenKey(token)
|
||||
data, err := m.rdb.Get(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, errors.New(errors.CodeInvalidToken, "无效或过期的令牌")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, errors.New(errors.CodeRedisError, fmt.Sprintf("failed to get token: %v", err))
|
||||
}
|
||||
|
||||
var info TokenInfo
|
||||
if err := json.Unmarshal([]byte(data), &info); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal token info: %w", err)
|
||||
}
|
||||
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
func (m *TokenManager) ValidateRefreshToken(ctx context.Context, token string) (*TokenInfo, error) {
|
||||
key := constants.RedisRefreshTokenKey(token)
|
||||
data, err := m.rdb.Get(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, errors.New(errors.CodeInvalidToken, "无效或过期的刷新令牌")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, errors.New(errors.CodeRedisError, fmt.Sprintf("failed to get refresh token: %v", err))
|
||||
}
|
||||
|
||||
var info TokenInfo
|
||||
if err := json.Unmarshal([]byte(data), &info); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal token info: %w", err)
|
||||
}
|
||||
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
func (m *TokenManager) RefreshAccessToken(ctx context.Context, refreshToken string) (newAccessToken string, err error) {
|
||||
info, err := m.ValidateRefreshToken(ctx, refreshToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
newAccessToken = uuid.New().String()
|
||||
|
||||
data, err := json.Marshal(info)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal token info: %w", err)
|
||||
}
|
||||
|
||||
pipe := m.rdb.Pipeline()
|
||||
|
||||
newAccessKey := constants.RedisAuthTokenKey(newAccessToken)
|
||||
pipe.Set(ctx, newAccessKey, data, m.accessTokenTTL)
|
||||
|
||||
userTokensKey := constants.RedisUserTokensKey(info.UserID)
|
||||
pipe.SAdd(ctx, userTokensKey, newAccessToken)
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return "", errors.New(errors.CodeRedisError, fmt.Sprintf("failed to refresh token: %v", err))
|
||||
}
|
||||
|
||||
return newAccessToken, nil
|
||||
}
|
||||
|
||||
func (m *TokenManager) RevokeToken(ctx context.Context, token string) error {
|
||||
pipe := m.rdb.Pipeline()
|
||||
|
||||
accessKey := constants.RedisAuthTokenKey(token)
|
||||
pipe.Del(ctx, accessKey)
|
||||
|
||||
refreshKey := constants.RedisRefreshTokenKey(token)
|
||||
pipe.Del(ctx, refreshKey)
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return errors.New(errors.CodeRedisError, fmt.Sprintf("failed to revoke token: %v", err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *TokenManager) RevokeAllUserTokens(ctx context.Context, userID uint) error {
|
||||
userTokensKey := constants.RedisUserTokensKey(userID)
|
||||
|
||||
tokens, err := m.rdb.SMembers(ctx, userTokensKey).Result()
|
||||
if err != nil && err != redis.Nil {
|
||||
return errors.New(errors.CodeRedisError, fmt.Sprintf("failed to get user tokens: %v", err))
|
||||
}
|
||||
|
||||
if len(tokens) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
pipe := m.rdb.Pipeline()
|
||||
|
||||
for _, token := range tokens {
|
||||
accessKey := constants.RedisAuthTokenKey(token)
|
||||
pipe.Del(ctx, accessKey)
|
||||
|
||||
refreshKey := constants.RedisRefreshTokenKey(token)
|
||||
pipe.Del(ctx, refreshKey)
|
||||
}
|
||||
|
||||
pipe.Del(ctx, userTokensKey)
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return errors.New(errors.CodeRedisError, fmt.Sprintf("failed to revoke user tokens: %v", err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
357
pkg/auth/token_test.go
Normal file
357
pkg/auth/token_test.go
Normal file
@@ -0,0 +1,357 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/config"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setupTestRedis(t *testing.T) *redis.Client {
|
||||
var addr, password string
|
||||
testDB := 15
|
||||
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
t.Logf("配置加载失败,使用回退配置: %v", err)
|
||||
addr = "localhost:6379"
|
||||
password = ""
|
||||
} else {
|
||||
t.Logf("成功加载配置,Redis 地址: %s:%d", cfg.Redis.Address, cfg.Redis.Port)
|
||||
addr = fmt.Sprintf("%s:%d", cfg.Redis.Address, cfg.Redis.Port)
|
||||
password = cfg.Redis.Password
|
||||
}
|
||||
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: addr,
|
||||
Password: password,
|
||||
DB: testDB,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
t.Skipf("Redis 未运行(地址: %s),跳过测试: %v", addr, err)
|
||||
}
|
||||
|
||||
client.FlushDB(ctx)
|
||||
|
||||
t.Cleanup(func() {
|
||||
client.FlushDB(ctx)
|
||||
client.Close()
|
||||
})
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func init() {
|
||||
if os.Getenv("CONFIG_ENV") == "" {
|
||||
os.Setenv("CONFIG_ENV", "dev")
|
||||
}
|
||||
|
||||
if os.Getenv("CONFIG_PATH") == "" {
|
||||
os.Setenv("CONFIG_PATH", "../../configs/config.dev.yaml")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenManager_GenerateTokenPair(t *testing.T) {
|
||||
rdb := setupTestRedis(t)
|
||||
tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("成功生成 token 对", func(t *testing.T) {
|
||||
tokenInfo := &TokenInfo{
|
||||
UserID: 1,
|
||||
UserType: 1,
|
||||
ShopID: 10,
|
||||
EnterpriseID: 0,
|
||||
Username: "testuser",
|
||||
Device: "web",
|
||||
IP: "127.0.0.1",
|
||||
}
|
||||
|
||||
accessToken, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, accessToken)
|
||||
assert.NotEmpty(t, refreshToken)
|
||||
assert.Len(t, accessToken, 36)
|
||||
assert.Len(t, refreshToken, 36)
|
||||
})
|
||||
|
||||
t.Run("生成的 token 存储在 Redis 中", func(t *testing.T) {
|
||||
tokenInfo := &TokenInfo{
|
||||
UserID: 2,
|
||||
UserType: 2,
|
||||
Username: "admin",
|
||||
}
|
||||
|
||||
accessToken, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
|
||||
require.NoError(t, err)
|
||||
|
||||
accessKey := "auth:token:" + accessToken
|
||||
refreshKey := "auth:refresh:" + refreshToken
|
||||
|
||||
exists, err := rdb.Exists(ctx, accessKey).Result()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), exists)
|
||||
|
||||
exists, err = rdb.Exists(ctx, refreshKey).Result()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), exists)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenManager_ValidateAccessToken(t *testing.T) {
|
||||
rdb := setupTestRedis(t)
|
||||
tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
|
||||
ctx := context.Background()
|
||||
|
||||
tokenInfo := &TokenInfo{
|
||||
UserID: 1,
|
||||
UserType: 1,
|
||||
ShopID: 10,
|
||||
EnterpriseID: 0,
|
||||
Username: "testuser",
|
||||
Device: "web",
|
||||
IP: "127.0.0.1",
|
||||
}
|
||||
|
||||
accessToken, _, err := tm.GenerateTokenPair(ctx, tokenInfo)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("验证有效的 access token", func(t *testing.T) {
|
||||
info, err := tm.ValidateAccessToken(ctx, accessToken)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, info)
|
||||
assert.Equal(t, uint(1), info.UserID)
|
||||
assert.Equal(t, 1, info.UserType)
|
||||
assert.Equal(t, uint(10), info.ShopID)
|
||||
assert.Equal(t, "testuser", info.Username)
|
||||
})
|
||||
|
||||
t.Run("验证无效的 token", func(t *testing.T) {
|
||||
info, err := tm.ValidateAccessToken(ctx, "invalid-token")
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, info)
|
||||
assert.Contains(t, err.Error(), "无效或过期")
|
||||
})
|
||||
|
||||
t.Run("验证空 token", func(t *testing.T) {
|
||||
info, err := tm.ValidateAccessToken(ctx, "")
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, info)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenManager_ValidateRefreshToken(t *testing.T) {
|
||||
rdb := setupTestRedis(t)
|
||||
tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
|
||||
ctx := context.Background()
|
||||
|
||||
tokenInfo := &TokenInfo{
|
||||
UserID: 1,
|
||||
UserType: 1,
|
||||
Username: "testuser",
|
||||
}
|
||||
|
||||
_, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("验证有效的 refresh token", func(t *testing.T) {
|
||||
info, err := tm.ValidateRefreshToken(ctx, refreshToken)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, info)
|
||||
assert.Equal(t, uint(1), info.UserID)
|
||||
assert.Equal(t, "testuser", info.Username)
|
||||
})
|
||||
|
||||
t.Run("验证无效的 refresh token", func(t *testing.T) {
|
||||
info, err := tm.ValidateRefreshToken(ctx, "invalid-refresh-token")
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, info)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenManager_RefreshAccessToken(t *testing.T) {
|
||||
rdb := setupTestRedis(t)
|
||||
tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
|
||||
ctx := context.Background()
|
||||
|
||||
tokenInfo := &TokenInfo{
|
||||
UserID: 1,
|
||||
UserType: 1,
|
||||
Username: "testuser",
|
||||
Device: "web",
|
||||
IP: "127.0.0.1",
|
||||
}
|
||||
|
||||
oldAccessToken, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("成功刷新 access token", func(t *testing.T) {
|
||||
newAccessToken, err := tm.RefreshAccessToken(ctx, refreshToken)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, newAccessToken)
|
||||
assert.NotEqual(t, oldAccessToken, newAccessToken)
|
||||
|
||||
info, err := tm.ValidateAccessToken(ctx, newAccessToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint(1), info.UserID)
|
||||
})
|
||||
|
||||
t.Run("使用无效的 refresh token", func(t *testing.T) {
|
||||
newAccessToken, err := tm.RefreshAccessToken(ctx, "invalid-refresh-token")
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, newAccessToken)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenManager_RevokeToken(t *testing.T) {
|
||||
rdb := setupTestRedis(t)
|
||||
tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
|
||||
ctx := context.Background()
|
||||
|
||||
tokenInfo := &TokenInfo{
|
||||
UserID: 1,
|
||||
UserType: 1,
|
||||
Username: "testuser",
|
||||
}
|
||||
|
||||
accessToken, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("成功撤销 access token", func(t *testing.T) {
|
||||
err := tm.RevokeToken(ctx, accessToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
info, err := tm.ValidateAccessToken(ctx, accessToken)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, info)
|
||||
})
|
||||
|
||||
t.Run("成功撤销 refresh token", func(t *testing.T) {
|
||||
err := tm.RevokeToken(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
info, err := tm.ValidateRefreshToken(ctx, refreshToken)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, info)
|
||||
})
|
||||
|
||||
t.Run("撤销不存在的 token 不报错", func(t *testing.T) {
|
||||
err := tm.RevokeToken(ctx, "non-existent-token")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenManager_RevokeAllUserTokens(t *testing.T) {
|
||||
rdb := setupTestRedis(t)
|
||||
tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
|
||||
ctx := context.Background()
|
||||
|
||||
tokenInfo := &TokenInfo{
|
||||
UserID: 1,
|
||||
UserType: 1,
|
||||
Username: "testuser",
|
||||
}
|
||||
|
||||
accessToken1, refreshToken1, err := tm.GenerateTokenPair(ctx, tokenInfo)
|
||||
require.NoError(t, err)
|
||||
|
||||
accessToken2, refreshToken2, err := tm.GenerateTokenPair(ctx, tokenInfo)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("成功撤销用户所有 token", func(t *testing.T) {
|
||||
err := tm.RevokeAllUserTokens(ctx, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tm.ValidateAccessToken(ctx, accessToken1)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = tm.ValidateAccessToken(ctx, accessToken2)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = tm.ValidateRefreshToken(ctx, refreshToken1)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = tm.ValidateRefreshToken(ctx, refreshToken2)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("撤销不存在用户的 token 不报错", func(t *testing.T) {
|
||||
err := tm.RevokeAllUserTokens(ctx, 9999)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenManager_TokenExpiration(t *testing.T) {
|
||||
rdb := setupTestRedis(t)
|
||||
tm := NewTokenManager(rdb, 1*time.Second, 2*time.Second)
|
||||
ctx := context.Background()
|
||||
|
||||
tokenInfo := &TokenInfo{
|
||||
UserID: 1,
|
||||
UserType: 1,
|
||||
Username: "testuser",
|
||||
}
|
||||
|
||||
accessToken, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("Access token 过期后无法验证", func(t *testing.T) {
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
info, err := tm.ValidateAccessToken(ctx, accessToken)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, info)
|
||||
})
|
||||
|
||||
t.Run("Refresh token 过期后无法验证", func(t *testing.T) {
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
info, err := tm.ValidateRefreshToken(ctx, refreshToken)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, info)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenManager_ConcurrentAccess(t *testing.T) {
|
||||
rdb := setupTestRedis(t)
|
||||
tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("并发生成 token", func(t *testing.T) {
|
||||
done := make(chan bool, 10)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
tokenInfo := &TokenInfo{
|
||||
UserID: uint(id),
|
||||
UserType: 1,
|
||||
Username: "user",
|
||||
}
|
||||
|
||||
_, _, err := tm.GenerateTokenPair(ctx, tokenInfo)
|
||||
assert.NoError(t, err)
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -108,8 +108,10 @@ type SMSConfig struct {
|
||||
|
||||
// JWTConfig JWT 认证配置
|
||||
type JWTConfig struct {
|
||||
SecretKey string `mapstructure:"secret_key"` // JWT 签名密钥
|
||||
TokenDuration time.Duration `mapstructure:"token_duration"` // Token 有效期
|
||||
SecretKey string `mapstructure:"secret_key"` // JWT 签名密钥
|
||||
TokenDuration time.Duration `mapstructure:"token_duration"` // Token 有效期(C 端 JWT)
|
||||
AccessTokenTTL time.Duration `mapstructure:"access_token_ttl"` // 访问令牌有效期(B 端 Redis Token)
|
||||
RefreshTokenTTL time.Duration `mapstructure:"refresh_token_ttl"` // 刷新令牌有效期(B 端 Redis Token)
|
||||
}
|
||||
|
||||
// DefaultAdminConfig 默认超级管理员配置
|
||||
@@ -210,6 +212,12 @@ func (c *Config) Validate() error {
|
||||
if c.JWT.TokenDuration < 1*time.Hour || c.JWT.TokenDuration > 720*time.Hour {
|
||||
return fmt.Errorf("invalid configuration: jwt.token_duration: duration out of range (current value: %s, expected: 1h-720h)", c.JWT.TokenDuration)
|
||||
}
|
||||
if c.JWT.AccessTokenTTL < 1*time.Hour || c.JWT.AccessTokenTTL > 168*time.Hour {
|
||||
return fmt.Errorf("invalid configuration: jwt.access_token_ttl: duration out of range (current value: %s, expected: 1h-168h)", c.JWT.AccessTokenTTL)
|
||||
}
|
||||
if c.JWT.RefreshTokenTTL < 24*time.Hour || c.JWT.RefreshTokenTTL > 720*time.Hour {
|
||||
return fmt.Errorf("invalid configuration: jwt.refresh_token_ttl: duration out of range (current value: %s, expected: 24h-720h)", c.JWT.RefreshTokenTTL)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
16
pkg/constants/auth.go
Normal file
16
pkg/constants/auth.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package constants
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ======== 认证相关常量 ========
|
||||
|
||||
// Token TTL 默认值
|
||||
const (
|
||||
// DefaultAccessTokenTTL 访问令牌默认有效期(24 小时)
|
||||
DefaultAccessTokenTTL = 24 * time.Hour
|
||||
|
||||
// DefaultRefreshTokenTTL 刷新令牌默认有效期(7 天)
|
||||
DefaultRefreshTokenTTL = 7 * 24 * time.Hour
|
||||
)
|
||||
@@ -2,11 +2,31 @@ package constants
|
||||
|
||||
import "fmt"
|
||||
|
||||
// RedisAuthTokenKey 生成认证令牌的 Redis 键
|
||||
// ========================================
|
||||
// 认证相关 Redis Key
|
||||
// ========================================
|
||||
|
||||
// RedisAuthTokenKey 生成访问令牌的 Redis 键
|
||||
// 用途:存储用户 access token 信息
|
||||
// 过期时间:24 小时(可配置)
|
||||
func RedisAuthTokenKey(token string) string {
|
||||
return fmt.Sprintf("auth:token:%s", token)
|
||||
}
|
||||
|
||||
// RedisRefreshTokenKey 生成刷新令牌的 Redis 键
|
||||
// 用途:存储用户 refresh token 信息
|
||||
// 过期时间:7 天(可配置)
|
||||
func RedisRefreshTokenKey(token string) string {
|
||||
return fmt.Sprintf("auth:refresh:%s", token)
|
||||
}
|
||||
|
||||
// RedisUserTokensKey 生成用户令牌列表的 Redis 键
|
||||
// 用途:维护用户的所有有效 token 列表(Set 结构)
|
||||
// 过期时间:7 天(可配置)
|
||||
func RedisUserTokensKey(userID uint) string {
|
||||
return fmt.Sprintf("auth:user:%d:tokens", userID)
|
||||
}
|
||||
|
||||
// RedisRateLimitKey 生成限流的 Redis 键
|
||||
func RedisRateLimitKey(ip string) string {
|
||||
return fmt.Sprintf("ratelimit:%s", ip)
|
||||
|
||||
11
pkg/constants/shop.go
Normal file
11
pkg/constants/shop.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package constants
|
||||
|
||||
const (
|
||||
ShopStatusDisabled = 0
|
||||
ShopStatusEnabled = 1
|
||||
)
|
||||
|
||||
const (
|
||||
ShopMinLevel = 1
|
||||
ShopMaxLevel = 7
|
||||
)
|
||||
@@ -36,6 +36,12 @@ const (
|
||||
CodeRoleAlreadyAssigned = 1026 // 角色已分配
|
||||
CodePermAlreadyAssigned = 1027 // 权限已分配
|
||||
|
||||
// 认证相关错误 (1040-1049)
|
||||
CodeInvalidCredentials = 1040 // 用户名或密码错误
|
||||
CodeAccountLocked = 1041 // 账号已锁定
|
||||
CodePasswordExpired = 1042 // 密码已过期
|
||||
CodeInvalidOldPassword = 1043 // 旧密码错误
|
||||
|
||||
// 组织相关错误 (1030-1049)
|
||||
CodeShopNotFound = 1030 // 店铺不存在
|
||||
CodeShopCodeExists = 1031 // 店铺编号已存在
|
||||
@@ -91,6 +97,10 @@ var errorMessages = map[int]string{
|
||||
CodeEnterpriseCodeExists: "企业编号已存在",
|
||||
CodeCustomerNotFound: "个人客户不存在",
|
||||
CodeCustomerPhoneExists: "个人客户手机号已存在",
|
||||
CodeInvalidCredentials: "用户名或密码错误",
|
||||
CodeAccountLocked: "账号已锁定",
|
||||
CodePasswordExpired: "密码已过期",
|
||||
CodeInvalidOldPassword: "旧密码错误",
|
||||
CodeInternalError: "内部服务器错误",
|
||||
CodeDatabaseError: "数据库错误",
|
||||
CodeRedisError: "缓存服务错误",
|
||||
|
||||
@@ -24,11 +24,82 @@ func NewGenerator(title, version string) *Generator {
|
||||
Version: version,
|
||||
},
|
||||
}
|
||||
return &Generator{Reflector: &reflector}
|
||||
|
||||
g := &Generator{Reflector: &reflector}
|
||||
g.addBearerAuth()
|
||||
return g
|
||||
}
|
||||
|
||||
// addBearerAuth 添加 Bearer Token 认证定义
|
||||
func (g *Generator) addBearerAuth() {
|
||||
bearerFormat := "JWT"
|
||||
g.Reflector.Spec.ComponentsEns().SecuritySchemesEns().WithMapOfSecuritySchemeOrRefValuesItem(
|
||||
"BearerAuth",
|
||||
openapi3.SecuritySchemeOrRef{
|
||||
SecurityScheme: &openapi3.SecurityScheme{
|
||||
HTTPSecurityScheme: &openapi3.HTTPSecurityScheme{
|
||||
Scheme: "bearer",
|
||||
BearerFormat: &bearerFormat,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
g.addErrorResponseSchema()
|
||||
}
|
||||
|
||||
// addErrorResponseSchema 添加错误响应 Schema 定义
|
||||
func (g *Generator) addErrorResponseSchema() {
|
||||
objectType := openapi3.SchemaType("object")
|
||||
integerType := openapi3.SchemaType("integer")
|
||||
stringType := openapi3.SchemaType("string")
|
||||
dateTimeFormat := "date-time"
|
||||
|
||||
errorSchema := openapi3.SchemaOrRef{
|
||||
Schema: &openapi3.Schema{
|
||||
Type: &objectType,
|
||||
Properties: map[string]openapi3.SchemaOrRef{
|
||||
"code": {
|
||||
Schema: &openapi3.Schema{
|
||||
Type: &integerType,
|
||||
Description: ptrString("错误码"),
|
||||
},
|
||||
},
|
||||
"message": {
|
||||
Schema: &openapi3.Schema{
|
||||
Type: &stringType,
|
||||
Description: ptrString("错误消息"),
|
||||
},
|
||||
},
|
||||
"timestamp": {
|
||||
Schema: &openapi3.Schema{
|
||||
Type: &stringType,
|
||||
Format: &dateTimeFormat,
|
||||
Description: ptrString("时间戳"),
|
||||
},
|
||||
},
|
||||
},
|
||||
Required: []string{"code", "message", "timestamp"},
|
||||
},
|
||||
}
|
||||
|
||||
g.Reflector.Spec.ComponentsEns().SchemasEns().WithMapOfSchemaOrRefValuesItem("ErrorResponse", errorSchema)
|
||||
}
|
||||
|
||||
// ptrString 返回字符串指针
|
||||
func ptrString(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// AddOperation 向 OpenAPI 规范中添加一个操作
|
||||
func (g *Generator) AddOperation(method, path, summary string, input interface{}, output interface{}, tags ...string) {
|
||||
// 参数:
|
||||
// - method: HTTP 方法(GET, POST, PUT, DELETE 等)
|
||||
// - path: API 路径
|
||||
// - summary: 操作摘要
|
||||
// - input: 请求参数结构体(可为 nil)
|
||||
// - output: 响应结构体(可为 nil)
|
||||
// - tags: 标签列表
|
||||
// - requiresAuth: 是否需要认证
|
||||
func (g *Generator) AddOperation(method, path, summary string, input interface{}, output interface{}, requiresAuth bool, tags ...string) {
|
||||
op := openapi3.Operation{
|
||||
Summary: &summary,
|
||||
Tags: tags,
|
||||
@@ -49,12 +120,104 @@ func (g *Generator) AddOperation(method, path, summary string, input interface{}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加认证要求
|
||||
if requiresAuth {
|
||||
g.addSecurityRequirement(&op)
|
||||
}
|
||||
|
||||
// 添加标准错误响应
|
||||
g.addStandardErrorResponses(&op, requiresAuth)
|
||||
|
||||
// 将操作添加到规范中
|
||||
if err := g.Reflector.Spec.AddOperation(method, path, op); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// addSecurityRequirement 为操作添加认证要求
|
||||
func (g *Generator) addSecurityRequirement(op *openapi3.Operation) {
|
||||
op.Security = []map[string][]string{
|
||||
{"BearerAuth": {}},
|
||||
}
|
||||
}
|
||||
|
||||
// addStandardErrorResponses 添加标准错误响应
|
||||
func (g *Generator) addStandardErrorResponses(op *openapi3.Operation, requiresAuth bool) {
|
||||
if op.Responses.MapOfResponseOrRefValues == nil {
|
||||
op.Responses.MapOfResponseOrRefValues = make(map[string]openapi3.ResponseOrRef)
|
||||
}
|
||||
|
||||
// 400 Bad Request - 所有端点都可能返回
|
||||
desc400 := "请求参数错误"
|
||||
op.Responses.MapOfResponseOrRefValues["400"] = openapi3.ResponseOrRef{
|
||||
Response: &openapi3.Response{
|
||||
Description: desc400,
|
||||
Content: map[string]openapi3.MediaType{
|
||||
"application/json": {
|
||||
Schema: &openapi3.SchemaOrRef{
|
||||
SchemaReference: &openapi3.SchemaReference{
|
||||
Ref: "#/components/schemas/ErrorResponse",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 401 Unauthorized - 仅认证端点返回
|
||||
if requiresAuth {
|
||||
desc401 := "未认证或认证已过期"
|
||||
op.Responses.MapOfResponseOrRefValues["401"] = openapi3.ResponseOrRef{
|
||||
Response: &openapi3.Response{
|
||||
Description: desc401,
|
||||
Content: map[string]openapi3.MediaType{
|
||||
"application/json": {
|
||||
Schema: &openapi3.SchemaOrRef{
|
||||
SchemaReference: &openapi3.SchemaReference{
|
||||
Ref: "#/components/schemas/ErrorResponse",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 403 Forbidden - 仅认证端点返回
|
||||
desc403 := "无权访问"
|
||||
op.Responses.MapOfResponseOrRefValues["403"] = openapi3.ResponseOrRef{
|
||||
Response: &openapi3.Response{
|
||||
Description: desc403,
|
||||
Content: map[string]openapi3.MediaType{
|
||||
"application/json": {
|
||||
Schema: &openapi3.SchemaOrRef{
|
||||
SchemaReference: &openapi3.SchemaReference{
|
||||
Ref: "#/components/schemas/ErrorResponse",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 500 Internal Server Error - 所有端点都可能返回
|
||||
desc500 := "服务器内部错误"
|
||||
op.Responses.MapOfResponseOrRefValues["500"] = openapi3.ResponseOrRef{
|
||||
Response: &openapi3.Response{
|
||||
Description: desc500,
|
||||
Content: map[string]openapi3.MediaType{
|
||||
"application/json": {
|
||||
Schema: &openapi3.SchemaOrRef{
|
||||
SchemaReference: &openapi3.SchemaReference{
|
||||
Ref: "#/components/schemas/ErrorResponse",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Save 将规范导出为 YAML 文件
|
||||
func (g *Generator) Save(filename string) error {
|
||||
// 确保目录存在
|
||||
|
||||
Reference in New Issue
Block a user