package auth

import (
	"context"
	"fmt"
	"os"
	"testing"
	"time"

	"github.com/break/junhong_cmp_fiber/pkg/config"
	"github.com/redis/go-redis/v9"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// setupTestRedis returns a Redis client bound to logical database 15 for use
// by a single test.
//
// The address and password come from config.Load(); when loading fails the
// helper falls back to "localhost:6379" with no password. If the server does
// not answer a PING the calling test is skipped (and the client is closed so
// the connection pool is not leaked). Otherwise the database is flushed
// before the client is handed out, and flushed again and closed when the test
// finishes.
func setupTestRedis(t *testing.T) *redis.Client {
	t.Helper()

	const testDB = 15

	addr, password := "localhost:6379", ""
	cfg, err := config.Load()
	if err != nil {
		t.Logf("配置加载失败,使用回退配置: %v", err)
	} else {
		t.Logf("成功加载配置,Redis 地址: %s:%d", cfg.Redis.Address, cfg.Redis.Port)
		addr = fmt.Sprintf("%s:%d", cfg.Redis.Address, cfg.Redis.Port)
		password = cfg.Redis.Password
	}

	client := redis.NewClient(&redis.Options{
		Addr:     addr,
		Password: password,
		DB:       testDB,
	})

	ctx := context.Background()
	if err := client.Ping(ctx).Err(); err != nil {
		// Release the pool before skipping; Skipf does not return.
		_ = client.Close()
		t.Skipf("Redis 未运行(地址: %s),跳过测试: %v", addr, err)
	}

	// Register cleanup before the initial flush so the client is closed even
	// when that flush fails the test.
	t.Cleanup(func() {
		if err := client.FlushDB(ctx).Err(); err != nil {
			t.Logf("清理 Redis 测试库失败: %v", err)
		}
		if err := client.Close(); err != nil {
			t.Logf("关闭 Redis 客户端失败: %v", err)
		}
	})

	// A failed flush would leave stale keys behind, so stop here rather than
	// run assertions against a dirty database.
	require.NoError(t, client.FlushDB(ctx).Err(), "清空 Redis 测试库失败")

	return client
}

// init supplies default values for CONFIG_ENV and CONFIG_PATH when they are
// unset or empty, before any test in this package calls config.Load().
func init() {
	// os.Setenv only fails for malformed keys/values; both are constants
	// here, so the result is deliberately ignored.
	if os.Getenv("CONFIG_ENV") == "" {
		_ = os.Setenv("CONFIG_ENV", "dev")
	}
	if os.Getenv("CONFIG_PATH") == "" {
		_ = os.Setenv("CONFIG_PATH", "../../configs/config.dev.yaml")
	}
}

// TestTokenManager_GenerateTokenPair checks that GenerateTokenPair returns two
// non-empty 36-character tokens and that both are present in Redis under the
// "auth:token:" and "auth:refresh:" key prefixes.
func TestTokenManager_GenerateTokenPair(t *testing.T) {
	rdb := setupTestRedis(t)
	tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
	ctx := context.Background()

	t.Run("成功生成 token 对", func(t *testing.T) {
		tokenInfo := &TokenInfo{
			UserID:       1,
			UserType:     1,
			ShopID:       10,
			EnterpriseID: 0,
			Username:     "testuser",
			Device:       "web",
			IP:           "127.0.0.1",
		}

		accessToken, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
		require.NoError(t, err)

		assert.NotEmpty(t, accessToken)
		assert.NotEmpty(t, refreshToken)
		assert.Len(t, accessToken, 36)
		assert.Len(t, refreshToken, 36)
	})

	t.Run("生成的 token 存储在 Redis 中", func(t *testing.T) {
		tokenInfo := &TokenInfo{
			UserID:   2,
			UserType: 2,
			Username: "admin",
		}

		accessToken, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
		require.NoError(t, err)

		// Both keys must exist exactly once.
		for _, key := range []string{
			"auth:token:" + accessToken,
			"auth:refresh:" + refreshToken,
		} {
			exists, err := rdb.Exists(ctx, key).Result()
			require.NoError(t, err)
			assert.Equal(t, int64(1), exists)
		}
	})
}

// TestTokenManager_ValidateAccessToken checks that a freshly issued access
// token validates and yields the stored user data, and that unknown or empty
// tokens are rejected with an error and a nil result.
func TestTokenManager_ValidateAccessToken(t *testing.T) {
	rdb := setupTestRedis(t)
	tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
	ctx := context.Background()

	tokenInfo := &TokenInfo{
		UserID:       1,
		UserType:     1,
		ShopID:       10,
		EnterpriseID: 0,
		Username:     "testuser",
		Device:       "web",
		IP:           "127.0.0.1",
	}

	accessToken, _, err := tm.GenerateTokenPair(ctx, tokenInfo)
	require.NoError(t, err)

	t.Run("验证有效的 access token", func(t *testing.T) {
		info, err := tm.ValidateAccessToken(ctx, accessToken)
		require.NoError(t, err)
		require.NotNil(t, info)

		assert.Equal(t, uint(1), info.UserID)
		assert.Equal(t, 1, info.UserType)
		assert.Equal(t, uint(10), info.ShopID)
		assert.Equal(t, "testuser", info.Username)
	})

	t.Run("验证无效的 token", func(t *testing.T) {
		info, err := tm.ValidateAccessToken(ctx, "invalid-token")
		// require (not assert): err.Error() below would panic on a nil error.
		require.Error(t, err)
		assert.Nil(t, info)
		assert.Contains(t, err.Error(), "无效或过期")
	})

	t.Run("验证空 token", func(t *testing.T) {
		info, err := tm.ValidateAccessToken(ctx, "")
		assert.Error(t, err)
		assert.Nil(t, info)
	})
}

// TestTokenManager_ValidateRefreshToken checks that a freshly issued refresh
// token validates and yields the stored user data, and that an unknown
// refresh token is rejected.
func TestTokenManager_ValidateRefreshToken(t *testing.T) {
	rdb := setupTestRedis(t)
	tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
	ctx := context.Background()

	tokenInfo := &TokenInfo{
		UserID:   1,
		UserType: 1,
		Username: "testuser",
	}

	_, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
	require.NoError(t, err)

	t.Run("验证有效的 refresh token", func(t *testing.T) {
		info, err := tm.ValidateRefreshToken(ctx, refreshToken)
		require.NoError(t, err)
		require.NotNil(t, info)

		assert.Equal(t, uint(1), info.UserID)
		assert.Equal(t, "testuser", info.Username)
	})

	t.Run("验证无效的 refresh token", func(t *testing.T) {
		info, err := tm.ValidateRefreshToken(ctx, "invalid-refresh-token")
		assert.Error(t, err)
		assert.Nil(t, info)
	})
}

// TestTokenManager_RefreshAccessToken checks that a valid refresh token
// produces a new access token that differs from the original and validates,
// and that an unknown refresh token yields an error and an empty token.
func TestTokenManager_RefreshAccessToken(t *testing.T) {
	rdb := setupTestRedis(t)
	tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
	ctx := context.Background()

	tokenInfo := &TokenInfo{
		UserID:   1,
		UserType: 1,
		Username: "testuser",
		Device:   "web",
		IP:       "127.0.0.1",
	}

	oldAccessToken, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
	require.NoError(t, err)

	t.Run("成功刷新 access token", func(t *testing.T) {
		newAccessToken, err := tm.RefreshAccessToken(ctx, refreshToken)
		require.NoError(t, err)
		assert.NotEmpty(t, newAccessToken)
		assert.NotEqual(t, oldAccessToken, newAccessToken)

		info, err := tm.ValidateAccessToken(ctx, newAccessToken)
		require.NoError(t, err)
		// Guard the dereference below.
		require.NotNil(t, info)
		assert.Equal(t, uint(1), info.UserID)
	})

	t.Run("使用无效的 refresh token", func(t *testing.T) {
		newAccessToken, err := tm.RefreshAccessToken(ctx, "invalid-refresh-token")
		assert.Error(t, err)
		assert.Empty(t, newAccessToken)
	})
}

// TestTokenManager_RevokeToken checks that RevokeToken invalidates an access
// token and a refresh token, and that revoking an unknown token returns no
// error.
func TestTokenManager_RevokeToken(t *testing.T) {
	rdb := setupTestRedis(t)
	tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
	ctx := context.Background()

	tokenInfo := &TokenInfo{
		UserID:   1,
		UserType: 1,
		Username: "testuser",
	}

	accessToken, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
	require.NoError(t, err)

	t.Run("成功撤销 access token", func(t *testing.T) {
		err := tm.RevokeToken(ctx, accessToken)
		require.NoError(t, err)

		info, err := tm.ValidateAccessToken(ctx, accessToken)
		assert.Error(t, err)
		assert.Nil(t, info)
	})

	t.Run("成功撤销 refresh token", func(t *testing.T) {
		err := tm.RevokeToken(ctx, refreshToken)
		require.NoError(t, err)

		info, err := tm.ValidateRefreshToken(ctx, refreshToken)
		assert.Error(t, err)
		assert.Nil(t, info)
	})

	t.Run("撤销不存在的 token 不报错", func(t *testing.T) {
		err := tm.RevokeToken(ctx, "non-existent-token")
		assert.NoError(t, err)
	})
}

// TestTokenManager_RevokeAllUserTokens checks that RevokeAllUserTokens
// invalidates every access and refresh token issued to a user, and that
// calling it for a user with no tokens returns no error.
func TestTokenManager_RevokeAllUserTokens(t *testing.T) {
	rdb := setupTestRedis(t)
	tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
	ctx := context.Background()

	tokenInfo := &TokenInfo{
		UserID:   1,
		UserType: 1,
		Username: "testuser",
	}

	accessToken1, refreshToken1, err := tm.GenerateTokenPair(ctx, tokenInfo)
	require.NoError(t, err)

	accessToken2, refreshToken2, err := tm.GenerateTokenPair(ctx, tokenInfo)
	require.NoError(t, err)

	t.Run("成功撤销用户所有 token", func(t *testing.T) {
		err := tm.RevokeAllUserTokens(ctx, 1)
		require.NoError(t, err)

		_, err = tm.ValidateAccessToken(ctx, accessToken1)
		assert.Error(t, err)

		_, err = tm.ValidateAccessToken(ctx, accessToken2)
		assert.Error(t, err)

		_, err = tm.ValidateRefreshToken(ctx, refreshToken1)
		assert.Error(t, err)

		_, err = tm.ValidateRefreshToken(ctx, refreshToken2)
		assert.Error(t, err)
	})

	t.Run("撤销不存在用户的 token 不报错", func(t *testing.T) {
		err := tm.RevokeAllUserTokens(ctx, 9999)
		assert.NoError(t, err)
	})
}

// TestTokenManager_TokenExpiration checks that access and refresh tokens stop
// validating once their TTLs (1s and 2s here) have elapsed.
//
// Each subtest waits until an absolute deadline derived from the time the
// pair was issued, so it is correct whether it runs after its sibling or on
// its own via -run; when both run in order the total wait is unchanged.
func TestTokenManager_TokenExpiration(t *testing.T) {
	const (
		accessTTL  = 1 * time.Second
		refreshTTL = 2 * time.Second
		// Slack added on top of the TTL before checking for expiry.
		margin = 1 * time.Second
	)

	rdb := setupTestRedis(t)
	tm := NewTokenManager(rdb, accessTTL, refreshTTL)
	ctx := context.Background()

	tokenInfo := &TokenInfo{
		UserID:   1,
		UserType: 1,
		Username: "testuser",
	}

	accessToken, refreshToken, err := tm.GenerateTokenPair(ctx, tokenInfo)
	require.NoError(t, err)
	// Taken after generation so the deadlines are never earlier than the
	// real expiry instants.
	issuedAt := time.Now()

	// waitUntilExpired blocks until ttl+margin has passed since issuedAt;
	// time.Sleep returns immediately for non-positive durations.
	waitUntilExpired := func(ttl time.Duration) {
		time.Sleep(time.Until(issuedAt.Add(ttl + margin)))
	}

	t.Run("Access token 过期后无法验证", func(t *testing.T) {
		waitUntilExpired(accessTTL)

		info, err := tm.ValidateAccessToken(ctx, accessToken)
		assert.Error(t, err)
		assert.Nil(t, info)
	})

	t.Run("Refresh token 过期后无法验证", func(t *testing.T) {
		waitUntilExpired(refreshTTL)

		info, err := tm.ValidateRefreshToken(ctx, refreshToken)
		assert.Error(t, err)
		assert.Nil(t, info)
	})
}

// TestTokenManager_ConcurrentAccess checks that GenerateTokenPair succeeds
// when called from several goroutines at once against the same manager.
func TestTokenManager_ConcurrentAccess(t *testing.T) {
	rdb := setupTestRedis(t)
	tm := NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
	ctx := context.Background()

	t.Run("并发生成 token", func(t *testing.T) {
		const workers = 10

		// Buffered to the worker count so no goroutine blocks on send; the
		// results are asserted on the test goroutine.
		errs := make(chan error, workers)
		for i := 0; i < workers; i++ {
			go func(id int) {
				tokenInfo := &TokenInfo{
					UserID:   uint(id),
					UserType: 1,
					Username: "user",
				}
				_, _, err := tm.GenerateTokenPair(ctx, tokenInfo)
				errs <- err
			}(i)
		}

		for i := 0; i < workers; i++ {
			assert.NoError(t, <-errs)
		}
	})
}