// Package integration contains integration tests that exercise the postgres
// store against a real PostgreSQL instance started in a container.
package integration

import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"testing"
	"time"

	"github.com/break/junhong_cmp_fiber/internal/model"
	"github.com/break/junhong_cmp_fiber/internal/store/postgres"
	"github.com/break/junhong_cmp_fiber/pkg/constants"
	"github.com/golang-migrate/migrate/v4"
	_ "github.com/golang-migrate/migrate/v4/database/postgres"
	_ "github.com/golang-migrate/migrate/v4/source/file"
	_ "github.com/lib/pq"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/testcontainers/testcontainers-go"
	testcontainers_postgres "github.com/testcontainers/testcontainers-go/modules/postgres"
	"github.com/testcontainers/testcontainers-go/wait"
	"go.uber.org/zap"
	postgresDriver "gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)

// TestMain is the entry point for this test package. It runs all tests and
// exits with the resulting status code.
func TestMain(m *testing.M) {
	code := m.Run()
	os.Exit(code)
}

// setupTestDB starts a PostgreSQL container, applies the migration scripts and
// returns a store connected to it together with a cleanup function.
//
// The cleanup function rolls the migrations back, closes the migrator and the
// database handle, and terminates the container. It is idempotent and is also
// registered with t.Cleanup as soon as the container is running, so resources
// are released even when a later setup step aborts the test before the
// function returns. Callers may still `defer cleanup()` as before.
func setupTestDB(t *testing.T) (*postgres.Store, func()) {
	t.Helper()

	ctx := context.Background()

	// Start the PostgreSQL container.
	postgresContainer, err := testcontainers_postgres.RunContainer(ctx,
		testcontainers.WithImage("postgres:14-alpine"),
		testcontainers_postgres.WithDatabase("testdb"),
		testcontainers_postgres.WithUsername("postgres"),
		testcontainers_postgres.WithPassword("password"),
		testcontainers.WithWaitStrategy(
			wait.ForLog("database system is ready to accept connections").
				WithOccurrence(2).
				WithStartupTimeout(30*time.Second),
		),
	)
	require.NoError(t, err, "启动 PostgreSQL 容器失败")

	var (
		migrator *migrate.Migrate
		db       *gorm.DB
		cleaned  bool
	)

	// cleanup releases whatever has been acquired so far. The nil checks let
	// it run safely after a partially completed setup; the cleaned flag makes
	// the second invocation (deferred call + t.Cleanup) a no-op.
	cleanup := func() {
		if cleaned {
			return
		}
		cleaned = true

		if migrator != nil {
			// Roll the migrations back to clean up the data.
			if err := migrator.Down(); err != nil && err != migrate.ErrNoChange {
				t.Logf("清理迁移失败: %v", err)
			}
			if srcErr, dbErr := migrator.Close(); srcErr != nil || dbErr != nil {
				t.Logf("关闭迁移实例失败: source=%v, database=%v", srcErr, dbErr)
			}
		}

		if db != nil {
			sqlDB, err := db.DB()
			if err != nil {
				t.Logf("获取底层数据库连接失败: %v", err)
			} else if sqlDB != nil {
				if err := sqlDB.Close(); err != nil {
					t.Logf("关闭数据库连接失败: %v", err)
				}
			}
		}

		if err := postgresContainer.Terminate(ctx); err != nil {
			t.Logf("终止容器失败: %v", err)
		}
	}
	t.Cleanup(cleanup)

	// Obtain the connection string.
	connStr, err := postgresContainer.ConnectionString(ctx, "sslmode=disable")
	require.NoError(t, err, "获取数据库连接字符串失败")

	// Apply the database migrations.
	migrationsPath := getMigrationsPath(t)
	migrator, err = migrate.New(
		fmt.Sprintf("file://%s", migrationsPath),
		connStr,
	)
	require.NoError(t, err, "创建迁移实例失败")

	// Run the up migrations.
	err = migrator.Up()
	require.NoError(t, err, "执行数据库迁移失败")

	// Connect to the database.
	db, err = gorm.Open(postgresDriver.Open(connStr), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	require.NoError(t, err, "连接数据库失败")

	// Use a no-op logger for tests.
	testLogger := zap.NewNop()

	store := postgres.NewStore(db, testLogger)

	return store, cleanup
}

// getMigrationsPath returns the path of the migrations directory, resolved as
// "../../migrations" relative to the current working directory. The test is
// aborted if the working directory cannot be determined or the directory does
// not exist.
func getMigrationsPath(t *testing.T) string {
	t.Helper()

	wd, err := os.Getwd()
	require.NoError(t, err, "获取工作目录失败")

	// Walk up from the test directory to the project root.
	migrationsPath := filepath.Join(wd, "..", "..", "migrations")

	// Verify that the migrations directory exists.
	_, err = os.Stat(migrationsPath)
	require.NoErrorf(t, err, "迁移目录不存在: %s", migrationsPath)

	return migrationsPath
}

// TestUserCRUD tests the user CRUD operations of the store.
//
// All subtests share one database, so each subtest creates its own rows with
// distinct usernames and e-mail addresses.
func TestUserCRUD(t *testing.T) {
	store, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()

	t.Run("创建用户", func(t *testing.T) {
		user := &model.User{
			Username: "testuser",
			Email:    "test@example.com",
			Password: "hashedpassword",
			Status:   constants.UserStatusActive,
		}

		err := store.User.Create(ctx, user)
		assert.NoError(t, err)
		assert.NotZero(t, user.ID)
		assert.NotZero(t, user.CreatedAt)
		assert.NotZero(t, user.UpdatedAt)
	})

	t.Run("根据ID查询用户", func(t *testing.T) {
		// Create a test user.
		user := &model.User{
			Username: "queryuser",
			Email:    "query@example.com",
			Password: "hashedpassword",
			Status:   constants.UserStatusActive,
		}
		err := store.User.Create(ctx, user)
		require.NoError(t, err)

		// Query the user. require stops the subtest on error so the result
		// is never dereferenced when the lookup failed.
		found, err := store.User.GetByID(ctx, user.ID)
		require.NoError(t, err)
		assert.Equal(t, user.Username, found.Username)
		assert.Equal(t, user.Email, found.Email)
		assert.Equal(t, constants.UserStatusActive, found.Status)
	})

	t.Run("根据用户名查询用户", func(t *testing.T) {
		// Create a test user.
		user := &model.User{
			Username: "findbyname",
			Email:    "findbyname@example.com",
			Password: "hashedpassword",
			Status:   constants.UserStatusActive,
		}
		err := store.User.Create(ctx, user)
		require.NoError(t, err)

		// Look the user up by username.
		found, err := store.User.GetByUsername(ctx, "findbyname")
		require.NoError(t, err)
		assert.Equal(t, user.ID, found.ID)
		assert.Equal(t, user.Email, found.Email)
	})

	t.Run("更新用户", func(t *testing.T) {
		// Create a test user.
		user := &model.User{
			Username: "updateuser",
			Email:    "update@example.com",
			Password: "hashedpassword",
			Status:   constants.UserStatusActive,
		}
		err := store.User.Create(ctx, user)
		require.NoError(t, err)

		// Update the user.
		user.Email = "newemail@example.com"
		user.Status = constants.UserStatusInactive
		err = store.User.Update(ctx, user)
		assert.NoError(t, err)

		// Verify the update.
		found, err := store.User.GetByID(ctx, user.ID)
		require.NoError(t, err)
		assert.Equal(t, "newemail@example.com", found.Email)
		assert.Equal(t, constants.UserStatusInactive, found.Status)
	})

	t.Run("列表查询用户", func(t *testing.T) {
		// Create several test users.
		for i := 1; i <= 5; i++ {
			user := &model.User{
				Username: fmt.Sprintf("listuser%d", i),
				Email:    fmt.Sprintf("list%d@example.com", i),
				Password: "hashedpassword",
				Status:   constants.UserStatusActive,
			}
			err := store.User.Create(ctx, user)
			require.NoError(t, err)
		}

		// List users.
		// NOTE(review): the arguments are presumably (page, pageSize) — confirm
		// against the store; if so, len(users) could be asserted to equal 3.
		users, total, err := store.User.List(ctx, 1, 3)
		assert.NoError(t, err)
		assert.GreaterOrEqual(t, len(users), 3)
		assert.GreaterOrEqual(t, total, int64(5))
	})

	t.Run("软删除用户", func(t *testing.T) {
		// Create a test user.
		user := &model.User{
			Username: "deleteuser",
			Email:    "delete@example.com",
			Password: "hashedpassword",
			Status:   constants.UserStatusActive,
		}
		err := store.User.Create(ctx, user)
		require.NoError(t, err)

		// Soft delete.
		err = store.User.Delete(ctx, user.ID)
		assert.NoError(t, err)

		// Verify the deletion: the lookup must report "record not found".
		// ErrorIs also matches when the store wraps the sentinel error.
		_, err = store.User.GetByID(ctx, user.ID)
		assert.ErrorIs(t, err, gorm.ErrRecordNotFound)
	})
}

// TestOrderCRUD tests the order CRUD operations of the store.
//
// A single user is created up front and every order in the subtests is
// attached to it.
func TestOrderCRUD(t *testing.T) {
	store, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()

	// Create the user that owns the test orders.
	user := &model.User{
		Username: "orderuser",
		Email:    "orderuser@example.com",
		Password: "hashedpassword",
		Status:   constants.UserStatusActive,
	}
	err := store.User.Create(ctx, user)
	require.NoError(t, err)

	t.Run("创建订单", func(t *testing.T) {
		order := &model.Order{
			OrderID: "ORD-001",
			UserID:  user.ID,
			Amount:  10000,
			Status:  constants.OrderStatusPending,
			Remark:  "测试订单",
		}

		err := store.Order.Create(ctx, order)
		assert.NoError(t, err)
		assert.NotZero(t, order.ID)
		assert.NotZero(t, order.CreatedAt)
	})

	t.Run("根据ID查询订单", func(t *testing.T) {
		// Create a test order.
		order := &model.Order{
			OrderID: "ORD-002",
			UserID:  user.ID,
			Amount:  20000,
			Status:  constants.OrderStatusPending,
		}
		err := store.Order.Create(ctx, order)
		require.NoError(t, err)

		// Query the order; stop on error so found is never nil below.
		found, err := store.Order.GetByID(ctx, order.ID)
		require.NoError(t, err)
		assert.Equal(t, order.OrderID, found.OrderID)
		assert.Equal(t, order.Amount, found.Amount)
	})

	t.Run("根据订单号查询", func(t *testing.T) {
		// Create a test order.
		order := &model.Order{
			OrderID: "ORD-003",
			UserID:  user.ID,
			Amount:  30000,
			Status:  constants.OrderStatusPending,
		}
		err := store.Order.Create(ctx, order)
		require.NoError(t, err)

		// Look the order up by its order number.
		found, err := store.Order.GetByOrderID(ctx, "ORD-003")
		require.NoError(t, err)
		assert.Equal(t, order.ID, found.ID)
	})

	t.Run("根据用户ID列表查询", func(t *testing.T) {
		// Create several orders.
		for i := 1; i <= 3; i++ {
			order := &model.Order{
				OrderID: fmt.Sprintf("ORD-USER-%d", i),
				UserID:  user.ID,
				Amount:  int64(i * 10000),
				Status:  constants.OrderStatusPending,
			}
			err := store.Order.Create(ctx, order)
			require.NoError(t, err)
		}

		// List the user's orders.
		orders, total, err := store.Order.ListByUserID(ctx, user.ID, 1, 10)
		assert.NoError(t, err)
		assert.GreaterOrEqual(t, len(orders), 3)
		assert.GreaterOrEqual(t, total, int64(3))
	})

	t.Run("更新订单状态", func(t *testing.T) {
		// Create a test order.
		order := &model.Order{
			OrderID: "ORD-UPDATE",
			UserID:  user.ID,
			Amount:  50000,
			Status:  constants.OrderStatusPending,
		}
		err := store.Order.Create(ctx, order)
		require.NoError(t, err)

		// Update the status and the payment time.
		now := time.Now()
		order.Status = constants.OrderStatusPaid
		order.PaidAt = &now
		err = store.Order.Update(ctx, order)
		assert.NoError(t, err)

		// Verify the update.
		found, err := store.Order.GetByID(ctx, order.ID)
		require.NoError(t, err)
		assert.Equal(t, constants.OrderStatusPaid, found.Status)
		assert.NotNil(t, found.PaidAt)
	})

	t.Run("软删除订单", func(t *testing.T) {
		// Create a test order.
		order := &model.Order{
			OrderID: "ORD-DELETE",
			UserID:  user.ID,
			Amount:  60000,
			Status:  constants.OrderStatusPending,
		}
		err := store.Order.Create(ctx, order)
		require.NoError(t, err)

		// Soft delete.
		err = store.Order.Delete(ctx, order.ID)
		assert.NoError(t, err)

		// Verify the deletion: the lookup must report "record not found".
		_, err = store.Order.GetByID(ctx, order.ID)
		assert.ErrorIs(t, err, gorm.ErrRecordNotFound)
	})
}

// TestTransaction tests the store's transaction support: a callback that
// returns nil must persist its writes, and a callback that returns an error
// must leave no rows behind.
func TestTransaction(t *testing.T) {
	store, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()

	t.Run("事务提交", func(t *testing.T) {
		var userID uint
		var orderID uint

		err := store.Transaction(ctx, func(tx *postgres.Store) error {
			// Create the user.
			user := &model.User{
				Username: "txuser",
				Email:    "txuser@example.com",
				Password: "hashedpassword",
				Status:   constants.UserStatusActive,
			}
			if err := tx.User.Create(ctx, user); err != nil {
				return err
			}
			userID = user.ID

			// Create the order.
			order := &model.Order{
				OrderID: "ORD-TX-001",
				UserID:  user.ID,
				Amount:  10000,
				Status:  constants.OrderStatusPending,
			}
			if err := tx.Order.Create(ctx, order); err != nil {
				return err
			}
			orderID = order.ID

			return nil
		})
		// Without a committed transaction the lookups below are meaningless.
		require.NoError(t, err)

		// Verify that both the user and the order were created.
		user, err := store.User.GetByID(ctx, userID)
		require.NoError(t, err)
		assert.Equal(t, "txuser", user.Username)

		order, err := store.Order.GetByID(ctx, orderID)
		require.NoError(t, err)
		assert.Equal(t, "ORD-TX-001", order.OrderID)
	})

	t.Run("事务回滚", func(t *testing.T) {
		var userID uint

		err := store.Transaction(ctx, func(tx *postgres.Store) error {
			// Create the user.
			user := &model.User{
				Username: "rollbackuser",
				Email:    "rollback@example.com",
				Password: "hashedpassword",
				Status:   constants.UserStatusActive,
			}
			if err := tx.User.Create(ctx, user); err != nil {
				return err
			}
			userID = user.ID

			// Simulate a failure to trigger the rollback.
			return fmt.Errorf("模拟错误")
		})
		// EqualError fails cleanly (instead of panicking on err.Error())
		// when the transaction unexpectedly returns nil.
		assert.EqualError(t, err, "模拟错误")

		// Verify that the user was not persisted (rolled back).
		_, err = store.User.GetByID(ctx, userID)
		assert.ErrorIs(t, err, gorm.ErrRecordNotFound)
	})
}

// TestConcurrentOperations tests that the store handles concurrent writes.
func TestConcurrentOperations(t *testing.T) {
	store, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()

	t.Run("并发创建用户", func(t *testing.T) {
		concurrency := 10
		// Buffered to the number of goroutines so no sender ever blocks.
		errChan := make(chan error, concurrency)

		for i := 0; i < concurrency; i++ {
			go func(index int) {
				user := &model.User{
					Username: fmt.Sprintf("concurrent%d", index),
					Email:    fmt.Sprintf("concurrent%d@example.com", index),
					Password: "hashedpassword",
					Status:   constants.UserStatusActive,
				}
				errChan <- store.User.Create(ctx, user)
			}(i)
		}

		// Collect exactly one result per goroutine; log every failure so a
		// mismatch in the final count can be diagnosed.
		successCount := 0
		for i := 0; i < concurrency; i++ {
			if err := <-errChan; err != nil {
				t.Logf("并发创建用户失败: %v", err)
				continue
			}
			successCount++
		}

		assert.Equal(t, concurrency, successCount, "所有并发创建应该成功")
	})
}