package auth

import (
	"context"
	"encoding/json"
	"fmt"
	"time"

	"github.com/break/junhong_cmp_fiber/pkg/constants"
	"github.com/break/junhong_cmp_fiber/pkg/errors"
	"github.com/google/uuid"
	"github.com/redis/go-redis/v9"
)

// TokenManager issues, validates, refreshes and revokes opaque UUID tokens
// whose payload (a JSON-encoded TokenInfo) is stored in Redis.
type TokenManager struct {
	rdb             *redis.Client
	accessTokenTTL  time.Duration
	refreshTokenTTL time.Duration
}

// TokenInfo is the payload stored in Redis under every access and refresh
// token key.
type TokenInfo struct {
	UserID       uint      `json:"user_id"`
	UserType     int       `json:"user_type"`
	ShopID       uint      `json:"shop_id,omitempty"`
	EnterpriseID uint      `json:"enterprise_id,omitempty"`
	Username     string    `json:"username"`
	LoginTime    time.Time `json:"login_time"`
	Device       string    `json:"device"`
	IP           string    `json:"ip"`
}

// NewTokenManager returns a TokenManager that stores tokens through rdb and
// applies accessTTL / refreshTTL as the Redis expiry of access and refresh
// tokens respectively.
func NewTokenManager(rdb *redis.Client, accessTTL, refreshTTL time.Duration) *TokenManager {
	return &TokenManager{
		rdb:             rdb,
		accessTokenTTL:  accessTTL,
		refreshTokenTTL: refreshTTL,
	}
}

// userTokensTTL returns the expiry to apply to the per-user token set: the
// longer of the two token TTLs, so the set does not expire before a token
// that was just recorded in it.
func (m *TokenManager) userTokensTTL() time.Duration {
	if m.accessTokenTTL > m.refreshTokenTTL {
		return m.accessTokenTTL
	}
	return m.refreshTokenTTL
}

// GenerateTokenPair creates a new access/refresh token pair for info, stores
// the JSON-encoded info under both token keys and records both tokens in the
// user's token set.
//
// info.LoginTime is overwritten with the current time before it is stored, so
// the caller's struct is modified. A nil info yields an error instead of a
// panic. Redis failures are reported with errors.CodeRedisError.
func (m *TokenManager) GenerateTokenPair(ctx context.Context, info *TokenInfo) (accessToken, refreshToken string, err error) {
	if info == nil {
		return "", "", fmt.Errorf("token info must not be nil")
	}

	info.LoginTime = time.Now()
	data, err := json.Marshal(info)
	if err != nil {
		return "", "", fmt.Errorf("failed to marshal token info: %w", err)
	}

	accessToken = uuid.New().String()
	refreshToken = uuid.New().String()
	userTokensKey := constants.RedisUserTokensKey(info.UserID)

	// All writes go out in a single pipeline round trip.
	pipe := m.rdb.Pipeline()
	pipe.Set(ctx, constants.RedisAuthTokenKey(accessToken), data, m.accessTokenTTL)
	pipe.Set(ctx, constants.RedisRefreshTokenKey(refreshToken), data, m.refreshTokenTTL)
	pipe.SAdd(ctx, userTokensKey, accessToken, refreshToken)
	pipe.Expire(ctx, userTokensKey, m.userTokensTTL())
	if _, err := pipe.Exec(ctx); err != nil {
		return "", "", errors.New(errors.CodeRedisError, fmt.Sprintf("failed to store tokens: %v", err))
	}

	return accessToken, refreshToken, nil
}

// loadTokenInfo reads and decodes the TokenInfo stored under key.
//
// what names the kind of token in the Redis error message and invalidMsg is
// the message returned with errors.CodeInvalidToken when the key is absent.
func (m *TokenManager) loadTokenInfo(ctx context.Context, key, what, invalidMsg string) (*TokenInfo, error) {
	data, err := m.rdb.Get(ctx, key).Result()
	if err == redis.Nil {
		// The key does not exist: the token was never issued, was revoked,
		// or has expired.
		return nil, errors.New(errors.CodeInvalidToken, invalidMsg)
	}
	if err != nil {
		return nil, errors.New(errors.CodeRedisError, fmt.Sprintf("failed to get %s: %v", what, err))
	}

	var info TokenInfo
	if err := json.Unmarshal([]byte(data), &info); err != nil {
		return nil, fmt.Errorf("failed to unmarshal token info: %w", err)
	}
	return &info, nil
}

// ValidateAccessToken looks up token as an access token and returns the
// TokenInfo stored with it. It returns an errors.CodeInvalidToken error when
// no such key exists and an errors.CodeRedisError error on other Redis
// failures.
func (m *TokenManager) ValidateAccessToken(ctx context.Context, token string) (*TokenInfo, error) {
	return m.loadTokenInfo(ctx, constants.RedisAuthTokenKey(token), "token", "无效或过期的令牌")
}

// ValidateRefreshToken looks up token as a refresh token and returns the
// TokenInfo stored with it. It returns an errors.CodeInvalidToken error when
// no such key exists and an errors.CodeRedisError error on other Redis
// failures.
func (m *TokenManager) ValidateRefreshToken(ctx context.Context, token string) (*TokenInfo, error) {
	return m.loadTokenInfo(ctx, constants.RedisRefreshTokenKey(token), "refresh token", "无效或过期的刷新令牌")
}

// RefreshAccessToken validates refreshToken and, if it is valid, issues a new
// access token carrying the same TokenInfo. The refresh token itself is left
// untouched.
//
// The new access token is recorded in the user's token set and the set's
// expiry is re-armed so the set outlives the token just added; otherwise
// RevokeAllUserTokens could miss a token minted late in the refresh window.
func (m *TokenManager) RefreshAccessToken(ctx context.Context, refreshToken string) (newAccessToken string, err error) {
	info, err := m.ValidateRefreshToken(ctx, refreshToken)
	if err != nil {
		return "", err
	}

	data, err := json.Marshal(info)
	if err != nil {
		return "", fmt.Errorf("failed to marshal token info: %w", err)
	}

	newAccessToken = uuid.New().String()
	userTokensKey := constants.RedisUserTokensKey(info.UserID)

	pipe := m.rdb.Pipeline()
	pipe.Set(ctx, constants.RedisAuthTokenKey(newAccessToken), data, m.accessTokenTTL)
	pipe.SAdd(ctx, userTokensKey, newAccessToken)
	// Without this the set keeps its old expiry (or none at all if SAdd
	// had to recreate it).
	pipe.Expire(ctx, userTokensKey, m.userTokensTTL())
	if _, err := pipe.Exec(ctx); err != nil {
		return "", errors.New(errors.CodeRedisError, fmt.Sprintf("failed to refresh token: %v", err))
	}

	return newAccessToken, nil
}

// RevokeToken deletes token from Redis. The caller does not have to say
// whether token is an access or a refresh token: both candidate keys are
// deleted. The token is not removed from the per-user token set here.
func (m *TokenManager) RevokeToken(ctx context.Context, token string) error {
	pipe := m.rdb.Pipeline()
	pipe.Del(ctx, constants.RedisAuthTokenKey(token))
	pipe.Del(ctx, constants.RedisRefreshTokenKey(token))
	if _, err := pipe.Exec(ctx); err != nil {
		return errors.New(errors.CodeRedisError, fmt.Sprintf("failed to revoke token: %v", err))
	}
	return nil
}

// RevokeAllUserTokens deletes every token recorded in the token set of userID
// and then the set itself. It is a no-op when the set is empty or missing.
func (m *TokenManager) RevokeAllUserTokens(ctx context.Context, userID uint) error {
	userTokensKey := constants.RedisUserTokensKey(userID)

	tokens, err := m.rdb.SMembers(ctx, userTokensKey).Result()
	if err != nil && err != redis.Nil {
		return errors.New(errors.CodeRedisError, fmt.Sprintf("failed to get user tokens: %v", err))
	}
	if len(tokens) == 0 {
		return nil
	}

	pipe := m.rdb.Pipeline()
	for _, token := range tokens {
		// The set mixes access and refresh tokens, so delete both candidate
		// keys for every member.
		pipe.Del(ctx, constants.RedisAuthTokenKey(token))
		pipe.Del(ctx, constants.RedisRefreshTokenKey(token))
	}
	pipe.Del(ctx, userTokensKey)
	if _, err := pipe.Exec(ctx); err != nil {
		return errors.New(errors.CodeRedisError, fmt.Sprintf("failed to revoke user tokens: %v", err))
	}
	return nil
}