package testutil

import (
	"context"
	"errors"
	"fmt"
	"math/rand"
	"sync/atomic"
	"testing"
	"time"

	"github.com/break/junhong_cmp_fiber/internal/model"
	"github.com/break/junhong_cmp_fiber/pkg/auth"
	"github.com/break/junhong_cmp_fiber/pkg/constants"
	"github.com/redis/go-redis/v9"
	"github.com/stretchr/testify/require"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)

// Counters backing the unique-value generators in this file. After init they
// are only modified via sync/atomic, so the generators may be called from
// parallel tests.
var (
	// phoneCounter supplies the last four digits of GenerateUniquePhone.
	phoneCounter uint64
	// usernameCounter is appended to the prefix in GenerateUniqueUsername.
	usernameCounter uint64
	// shopCodeCounter makes the names and codes produced by CreateTestShop unique.
	shopCodeCounter uint64
)

func init() {
	// Start the phone and username counters at random offsets to reduce
	// collisions between separate test runs against the same database.
	// A dedicated time-seeded source is used instead of the deprecated global
	// rand.Seed (which is a no-op by default since Go 1.24).
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	phoneCounter = uint64(rng.Intn(10000))
	usernameCounter = uint64(rng.Intn(100000))
}

// GenerateUniquePhone returns an 11-digit test phone number of the form
// "139" + 4 digits derived from the current time + the low 4 digits of an
// atomically incremented counter. Exported for use by tests.
func GenerateUniquePhone() string {
	counter := atomic.AddUint64(&phoneCounter, 1)
	timestamp := time.Now().UnixNano() % 10000
	return fmt.Sprintf("139%04d%04d", timestamp, counter%10000)
}

// CreateTestAccount inserts an account with the given username and
// bcrypt-hashed password plus a unique phone number, and returns it.
// Creator/Updater are set to 1 and Status to 1. The test fails on any error.
//
// userType (per the original author's note): 1 = super admin, 2 = platform
// user, 3 = agent account, 4 = enterprise account.
func CreateTestAccount(t *testing.T, db *gorm.DB, username, password string, userType int, shopID, enterpriseID *uint) *model.Account {
	t.Helper()

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	require.NoError(t, err)

	phone := GenerateUniquePhone()
	account := &model.Account{
		BaseModel: model.BaseModel{
			Creator: 1,
			Updater: 1,
		},
		Username:     username,
		Phone:        phone,
		Password:     string(hashedPassword),
		UserType:     userType,
		ShopID:       shopID,
		EnterpriseID: enterpriseID,
		Status:       1,
	}

	err = db.Create(account).Error
	require.NoError(t, err)

	return account
}

// GenerateTestToken issues an access/refresh token pair for account through
// an auth.TokenManager built on rdb. Nil ShopID/EnterpriseID on the account
// are passed as 0, and the IP is fixed to "127.0.0.1". The test fails if
// token generation returns an error.
func GenerateTestToken(t *testing.T, rdb *redis.Client, account *model.Account, device string) (accessToken, refreshToken string) {
	t.Helper()

	ctx := context.Background()

	var shopID, enterpriseID uint
	if account.ShopID != nil {
		shopID = *account.ShopID
	}
	if account.EnterpriseID != nil {
		enterpriseID = *account.EnterpriseID
	}

	tokenInfo := &auth.TokenInfo{
		UserID:       account.ID,
		UserType:     account.UserType,
		ShopID:       shopID,
		EnterpriseID: enterpriseID,
		Username:     account.Username,
		Device:       device,
		IP:           "127.0.0.1",
	}

	// NOTE(review): the two durations are presumably the access and refresh
	// token TTLs (24h and 7 days) — confirm against auth.NewTokenManager.
	tokenManager := auth.NewTokenManager(rdb, 24*time.Hour, 7*24*time.Hour)
	accessToken, refreshToken, err := tokenManager.GenerateTokenPair(ctx, tokenInfo)
	require.NoError(t, err)

	return accessToken, refreshToken
}

// GenerateUniqueUsername returns prefix + "_" + an atomically incremented
// counter value. Exported for use by tests.
func GenerateUniqueUsername(prefix string) string {
	counter := atomic.AddUint64(&usernameCounter, 1)
	return fmt.Sprintf("%s_%d", prefix, counter)
}

// CreateSuperAdmin returns the first existing super-admin account, or creates
// one (password "password123") if none exists. Any lookup error other than
// "record not found" fails the test instead of being silently ignored.
func CreateSuperAdmin(t *testing.T, db *gorm.DB) *model.Account {
	t.Helper()

	var existing model.Account
	err := db.Where("user_type = ?", constants.UserTypeSuperAdmin).First(&existing).Error
	if err == nil {
		return &existing
	}
	// Only a missing row means "create one"; a connection or schema error
	// would otherwise surface later as a confusing failure in the create.
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		require.NoError(t, err, "looking up existing super admin")
	}

	return CreateTestAccount(t, db, GenerateUniqueUsername("superadmin"), "password123", constants.UserTypeSuperAdmin, nil, nil)
}

// CreatePlatformUser creates a platform-user account with password "password123".
func CreatePlatformUser(t *testing.T, db *gorm.DB) *model.Account {
	t.Helper()
	return CreateTestAccount(t, db, GenerateUniqueUsername("platformuser"), "password123", constants.UserTypePlatform, nil, nil)
}

// CreateAgentUser creates an agent account bound to shopID, with password "password123".
func CreateAgentUser(t *testing.T, db *gorm.DB, shopID uint) *model.Account {
	t.Helper()
	return CreateTestAccount(t, db, GenerateUniqueUsername("agentuser"), "password123", constants.UserTypeAgent, &shopID, nil)
}

// CreateEnterpriseUser creates an enterprise account bound to enterpriseID,
// with password "password123".
func CreateEnterpriseUser(t *testing.T, db *gorm.DB, enterpriseID uint) *model.Account {
	t.Helper()
	return CreateTestAccount(t, db, GenerateUniqueUsername("enterpriseuser"), "password123", constants.UserTypeEnterprise, nil, &enterpriseID)
}

// CreateTestShop inserts a shop and returns it. The given name and code are
// suffixed with a counter (and, for the code, a time-derived number) so that
// repeated calls do not collide. Creator/Updater are set to 1 and Status to 1;
// parentID is stored only when non-nil. The test fails on any insert error.
func CreateTestShop(t *testing.T, db *gorm.DB, name, code string, level int, parentID *uint) *model.Shop {
	t.Helper()

	counter := atomic.AddUint64(&shopCodeCounter, 1)
	uniqueCode := fmt.Sprintf("%s_%d_%d", code, time.Now().UnixNano()%10000, counter)
	uniqueName := fmt.Sprintf("%s_%d", name, counter)

	shop := &model.Shop{
		BaseModel: model.BaseModel{
			Creator: 1,
			Updater: 1,
		},
		ShopName: uniqueName,
		ShopCode: uniqueCode,
		Level:    level,
		Status:   1,
	}
	if parentID != nil {
		shop.ParentID = parentID
	}

	err := db.Create(shop).Error
	require.NoError(t, err)

	return shop
}

// SetupAuthMiddleware returns a predicate for integration tests that reports
// whether a token validates via tokenManager.ValidateAccessToken and, when
// allowedUserTypes is non-empty, whether the token's user type is in that
// list. An empty allowedUserTypes accepts any valid token.
func SetupAuthMiddleware(t *testing.T, tokenManager *auth.TokenManager, allowedUserTypes []int) func(token string) bool {
	t.Helper()

	return func(token string) bool {
		ctx := context.Background()
		tokenInfo, err := tokenManager.ValidateAccessToken(ctx, token)
		if err != nil {
			return false
		}

		if len(allowedUserTypes) == 0 {
			return true
		}
		for _, userType := range allowedUserTypes {
			if tokenInfo.UserType == userType {
				return true
			}
		}
		return false
	}
}