package testutils import ( "context" "fmt" "testing" "time" "github.com/redis/go-redis/v9" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/logger" "github.com/break/junhong_cmp_fiber/internal/model" ) // SetupTestDB 设置测试数据库和 Redis(使用事务) func SetupTestDB(t *testing.T) (*gorm.DB, *redis.Client) { t.Helper() // 连接测试数据库(使用远程数据库) dsn := "host=cxd.whcxd.cn port=16159 user=erp_pgsql password=erp_2025 dbname=junhong_cmp_test sslmode=disable TimeZone=Asia/Shanghai" db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), }) if err != nil { t.Skipf("跳过测试:无法连接测试数据库: %v", err) } // 自动迁移测试表 err = db.AutoMigrate( &model.Account{}, &model.Role{}, &model.Permission{}, &model.AccountRole{}, &model.RolePermission{}, &model.Shop{}, &model.Enterprise{}, &model.PersonalCustomer{}, ) if err != nil { t.Fatalf("数据库迁移失败: %v", err) } txDB := db.Begin() if txDB.Error != nil { t.Fatalf("开启事务失败: %v", txDB.Error) } redisClient := redis.NewClient(&redis.Options{ Addr: "cxd.whcxd.cn:16299", Password: "cpNbWtAaqgo1YJmbMp3h", DB: 15, }) ctx := context.Background() if err := redisClient.Ping(ctx).Err(); err != nil { t.Skipf("跳过测试:无法连接 Redis: %v", err) } testPrefix := fmt.Sprintf("test:%s:", t.Name()) keys, _ := redisClient.Keys(ctx, testPrefix+"*").Result() if len(keys) > 0 { redisClient.Del(ctx, keys...) } return txDB, redisClient } // TeardownTestDB 清理测试数据库(回滚事务) func TeardownTestDB(t *testing.T, db *gorm.DB, redisClient *redis.Client) { t.Helper() ctx := context.Background() testPrefix := fmt.Sprintf("test:%s:", t.Name()) keys, _ := redisClient.Keys(ctx, testPrefix+"*").Result() if len(keys) > 0 { redisClient.Del(ctx, keys...) } db.Rollback() _ = redisClient.Close() } // GenerateUsername 生成测试用户名 func GenerateUsername(prefix string, index int) string { return fmt.Sprintf("%s_%d", prefix, index) } // GeneratePhone 生成测试手机号 func GeneratePhone(prefix string, index int) string { return fmt.Sprintf("%s%08d", prefix, index) } // Now 返回当前时间 func Now() time.Time { return time.Now() } // Since 返回从指定时间到现在的持续时间 func Since(t time.Time) time.Duration { return time.Since(t) }