refactor: align framework cleanup with new bootstrap flow

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
This commit is contained in:
2025-11-19 12:47:25 +08:00
parent 39d14ec093
commit d66323487b
67 changed files with 3020 additions and 3992 deletions

View File

@@ -21,6 +21,7 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/break/junhong_cmp_fiber/internal/bootstrap"
"github.com/break/junhong_cmp_fiber/internal/handler"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/routes"
@@ -116,8 +117,8 @@ func setupTestEnv(t *testing.T) *testEnv {
})
// 注册路由
services := &routes.Services{
AccountHandler: accountHandler,
services := &bootstrap.Handlers{
Account: accountHandler,
}
routes.RegisterRoutes(app, services)

View File

@@ -19,6 +19,7 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/break/junhong_cmp_fiber/internal/bootstrap"
"github.com/break/junhong_cmp_fiber/internal/handler"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/routes"
@@ -126,10 +127,10 @@ func setupRegressionTestEnv(t *testing.T) *regressionTestEnv {
})
// 注册所有路由
services := &routes.Services{
AccountHandler: accountHandler,
RoleHandler: roleHandler,
PermissionHandler: permHandler,
services := &bootstrap.Handlers{
Account: accountHandler,
Role: roleHandler,
Permission: permHandler,
}
routes.RegisterRoutes(app, services)

View File

@@ -7,10 +7,10 @@ import (
"testing"
"time"
"github.com/break/junhong_cmp_fiber/internal/middleware"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/logger"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/break/junhong_cmp_fiber/pkg/response"
"github.com/break/junhong_cmp_fiber/pkg/validator"
"github.com/gofiber/fiber/v2"
@@ -52,7 +52,16 @@ func setupAuthTestApp(t *testing.T, rdb *redis.Client) *fiber.App {
// Add authentication middleware
tokenValidator := validator.NewTokenValidator(rdb, logger.GetAppLogger())
app.Use(middleware.KeyAuth(tokenValidator, logger.GetAppLogger()))
app.Use(middleware.Auth(middleware.AuthConfig{
TokenValidator: func(token string) (uint, int, uint, error) {
_, err := tokenValidator.Validate(token)
if err != nil {
return 0, 0, 0, err
}
// 测试中简化处理userID 设为 1userType 设为普通用户
return 1, 0, 0, nil
},
}))
// Add protected test routes
app.Get("/api/v1/test", func(c *fiber.Ctx) error {
@@ -342,14 +351,23 @@ func TestKeyAuthMiddleware_UserIDPropagation(t *testing.T) {
// Add authentication middleware
tokenValidator := validator.NewTokenValidator(rdb, logger.GetAppLogger())
app.Use(middleware.KeyAuth(tokenValidator, logger.GetAppLogger()))
app.Use(middleware.Auth(middleware.AuthConfig{
TokenValidator: func(token string) (uint, int, uint, error) {
_, err := tokenValidator.Validate(token)
if err != nil {
return 0, 0, 0, err
}
// 测试中简化处理userID 设为 1userType 设为普通用户
return 1, 0, 0, nil
},
}))
// Add test route that checks user ID
var capturedUserID string
var capturedUserID uint
app.Get("/api/v1/check-user", func(c *fiber.Ctx) error {
userID, ok := c.Locals(constants.ContextKeyUserID).(string)
userID, ok := c.Locals(constants.ContextKeyUserID).(uint)
if !ok {
return response.Error(c, 500, errors.CodeInternalError, "User ID not found in context")
return errors.New(errors.CodeInternalError, "User ID not found in context")
}
capturedUserID = userID
return response.Success(c, fiber.Map{

View File

@@ -1,325 +0,0 @@
package integration
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
)
// TestDataPermission_HierarchicalFiltering 测试层级数据权限过滤
func TestDataPermission_HierarchicalFiltering(t *testing.T) {
store, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
db := store.DB()
// 创建层级结构: A -> B -> C
accountA := &model.Account{
Username: "user_a",
Phone: "13800000001",
Password: "hashed_password",
UserType: constants.UserTypePlatform,
Status: constants.StatusEnabled,
Creator: 1,
Updater: 1,
}
require.NoError(t, db.Create(accountA).Error)
shopID := uint(100)
accountA.ShopID = &shopID
require.NoError(t, db.Save(accountA).Error)
accountB := &model.Account{
Username: "user_b",
Phone: "13800000002",
Password: "hashed_password",
UserType: constants.UserTypeAgent,
ParentID: &accountA.ID,
ShopID: &shopID,
Status: constants.StatusEnabled,
Creator: 1,
Updater: 1,
}
require.NoError(t, db.Create(accountB).Error)
accountC := &model.Account{
Username: "user_c",
Phone: "13800000003",
Password: "hashed_password",
UserType: constants.UserTypeEnterprise,
ParentID: &accountB.ID,
ShopID: &shopID,
Status: constants.StatusEnabled,
Creator: 1,
Updater: 1,
}
require.NoError(t, db.Create(accountC).Error)
// 创建测试数据表
type TestData struct {
ID uint `gorm:"primarykey"`
Name string
OwnerID uint
ShopID uint
}
require.NoError(t, db.AutoMigrate(&TestData{}))
// 插入测试数据
testData := []TestData{
{Name: "data_a", OwnerID: accountA.ID, ShopID: 100},
{Name: "data_b", OwnerID: accountB.ID, ShopID: 100},
{Name: "data_c", OwnerID: accountC.ID, ShopID: 100},
{Name: "data_other", OwnerID: 999, ShopID: 100},
}
require.NoError(t, db.Create(&testData).Error)
// 创建 AccountStore 用于递归查询
accountStore := postgres.NewAccountStore(db, nil) // Redis 可选
t.Run("A 用户可以看到 A、B、C 的数据", func(t *testing.T) {
ctxWithA := middleware.SetUserContext(ctx, accountA.ID, constants.UserTypePlatform, 100)
var results []TestData
err := db.WithContext(ctxWithA).
Scopes(postgres.DataPermissionScope(ctxWithA, accountStore)).
Find(&results).Error
require.NoError(t, err)
assert.Len(t, results, 3)
})
t.Run("B 用户可以看到 B、C 的数据", func(t *testing.T) {
ctxWithB := middleware.SetUserContext(ctx, accountB.ID, constants.UserTypeAgent, 100)
var results []TestData
err := db.WithContext(ctxWithB).
Scopes(postgres.DataPermissionScope(ctxWithB, accountStore)).
Find(&results).Error
require.NoError(t, err)
assert.Len(t, results, 2)
})
t.Run("C 用户只能看到自己的数据", func(t *testing.T) {
ctxWithC := middleware.SetUserContext(ctx, accountC.ID, constants.UserTypeEnterprise, 100)
var results []TestData
err := db.WithContext(ctxWithC).
Scopes(postgres.DataPermissionScope(ctxWithC, accountStore)).
Find(&results).Error
require.NoError(t, err)
assert.Len(t, results, 1)
assert.Equal(t, "data_c", results[0].Name)
})
}
// TestDataPermission_WithoutDataFilter 测试 WithoutDataFilter 选项
func TestDataPermission_WithoutDataFilter(t *testing.T) {
store, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
db := store.DB()
// 创建测试账号
shopID := uint(100)
accountA := &model.Account{
Username: "user_a",
Phone: "13800000001",
Password: "hashed_password",
UserType: constants.UserTypePlatform,
ShopID: &shopID,
Status: constants.StatusEnabled,
Creator: 1,
Updater: 1,
}
require.NoError(t, db.Create(accountA).Error)
// 创建测试数据表
type TestData struct {
ID uint `gorm:"primarykey"`
Name string
OwnerID uint
ShopID uint
}
require.NoError(t, db.AutoMigrate(&TestData{}))
// 插入测试数据
testData := []TestData{
{Name: "data_a", OwnerID: accountA.ID, ShopID: 100},
{Name: "data_b", OwnerID: 999, ShopID: 100},
{Name: "data_c", OwnerID: 888, ShopID: 200},
}
require.NoError(t, db.Create(&testData).Error)
// 创建 AccountStore
accountStore := postgres.NewAccountStore(db, nil)
t.Run("正常查询应该过滤数据", func(t *testing.T) {
ctxWithA := middleware.SetUserContext(ctx, accountA.ID, constants.UserTypePlatform, 100)
var results []TestData
err := db.WithContext(ctxWithA).
Scopes(postgres.DataPermissionScope(ctxWithA, accountStore)).
Find(&results).Error
require.NoError(t, err)
assert.Len(t, results, 1)
})
t.Run("不带用户上下文时返回空数据", func(t *testing.T) {
var results []TestData
err := db.WithContext(ctx).
Scopes(postgres.DataPermissionScope(ctx, accountStore)).
Find(&results).Error
require.NoError(t, err)
assert.Len(t, results, 0)
})
}
// TestDataPermission_CrossShopIsolation 测试跨店铺数据隔离
func TestDataPermission_CrossShopIsolation(t *testing.T) {
store, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
db := store.DB()
// 创建两个不同店铺的账号
shopID100 := uint(100)
accountA := &model.Account{
Username: "user_shop100",
Phone: "13800000001",
Password: "hashed_password",
UserType: constants.UserTypePlatform,
ShopID: &shopID100,
Status: constants.StatusEnabled,
Creator: 1,
Updater: 1,
}
require.NoError(t, db.Create(accountA).Error)
shopID200 := uint(200)
accountB := &model.Account{
Username: "user_shop200",
Phone: "13800000002",
Password: "hashed_password",
UserType: constants.UserTypePlatform,
ShopID: &shopID200,
Status: constants.StatusEnabled,
Creator: 1,
Updater: 1,
}
require.NoError(t, db.Create(accountB).Error)
// 创建测试数据表
type TestData struct {
ID uint `gorm:"primarykey"`
Name string
OwnerID uint
ShopID uint
}
require.NoError(t, db.AutoMigrate(&TestData{}))
// 插入不同店铺的数据
testData := []TestData{
{Name: "data_shop100_1", OwnerID: accountA.ID, ShopID: 100},
{Name: "data_shop100_2", OwnerID: accountA.ID, ShopID: 100},
{Name: "data_shop200_1", OwnerID: accountB.ID, ShopID: 200},
{Name: "data_shop200_2", OwnerID: accountB.ID, ShopID: 200},
}
require.NoError(t, db.Create(&testData).Error)
// 创建 AccountStore
accountStore := postgres.NewAccountStore(db, nil)
t.Run("店铺 100 用户只能看到店铺 100 的数据", func(t *testing.T) {
ctxWithA := middleware.SetUserContext(ctx, accountA.ID, constants.UserTypePlatform, 100)
var results []TestData
err := db.WithContext(ctxWithA).
Scopes(postgres.DataPermissionScope(ctxWithA, accountStore)).
Find(&results).Error
require.NoError(t, err)
assert.Len(t, results, 2)
for _, r := range results {
assert.Equal(t, uint(100), r.ShopID)
}
})
t.Run("店铺 200 用户只能看到店铺 200 的数据", func(t *testing.T) {
ctxWithB := middleware.SetUserContext(ctx, accountB.ID, constants.UserTypePlatform, 200)
var results []TestData
err := db.WithContext(ctxWithB).
Scopes(postgres.DataPermissionScope(ctxWithB, accountStore)).
Find(&results).Error
require.NoError(t, err)
assert.Len(t, results, 2)
for _, r := range results {
assert.Equal(t, uint(200), r.ShopID)
}
})
t.Run("店铺 100 用户看不到店铺 200 的数据", func(t *testing.T) {
ctxWithA := middleware.SetUserContext(ctx, accountA.ID, constants.UserTypePlatform, 100)
var results []TestData
err := db.WithContext(ctxWithA).
Scopes(postgres.DataPermissionScope(ctxWithA, accountStore)).
Where("shop_id = ?", 200). // 尝试查询店铺 200 的数据
Find(&results).Error
require.NoError(t, err)
assert.Len(t, results, 0, "不应该看到其他店铺的数据")
})
}
// TestDataPermission_RootUserBypass 测试 root 用户跳过数据权限过滤
func TestDataPermission_RootUserBypass(t *testing.T) {
store, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
db := store.DB()
// 创建 root 用户
rootUser := &model.Account{
Username: "root_user",
Phone: "13800000000",
Password: "hashed_password",
UserType: constants.UserTypeRoot,
Status: constants.StatusEnabled,
Creator: 1,
Updater: 1,
}
require.NoError(t, db.Create(rootUser).Error)
// 创建测试数据表
type TestData struct {
ID uint `gorm:"primarykey"`
Name string
OwnerID uint
ShopID uint
}
require.NoError(t, db.AutoMigrate(&TestData{}))
// 插入不同店铺、不同用户的数据
testData := []TestData{
{Name: "data_1", OwnerID: 1, ShopID: 100},
{Name: "data_2", OwnerID: 2, ShopID: 200},
{Name: "data_3", OwnerID: 3, ShopID: 300},
{Name: "data_4", OwnerID: 4, ShopID: 400},
}
require.NoError(t, db.Create(&testData).Error)
// 创建 AccountStore
accountStore := postgres.NewAccountStore(db, nil)
t.Run("root 用户可以看到所有数据", func(t *testing.T) {
ctxWithRoot := middleware.SetUserContext(ctx, rootUser.ID, constants.UserTypeRoot, 100)
var results []TestData
err := db.WithContext(ctxWithRoot).
Scopes(postgres.DataPermissionScope(ctxWithRoot, accountStore)).
Find(&results).Error
require.NoError(t, err)
assert.Len(t, results, 4, "root 用户应该看到所有数据")
})
}

View File

@@ -1,489 +0,0 @@
package integration
import (
"context"
"fmt"
"os"
"path/filepath"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/postgres"
_ "github.com/golang-migrate/migrate/v4/source/file"
_ "github.com/lib/pq"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
testcontainers_postgres "github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
"go.uber.org/zap"
postgresDriver "gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// TestMain 设置测试环境
func TestMain(m *testing.M) {
code := m.Run()
os.Exit(code)
}
// setupTestDB 启动 PostgreSQL 容器并使用迁移脚本初始化数据库
func setupTestDB(t *testing.T) (*postgres.Store, func()) {
ctx := context.Background()
// 启动 PostgreSQL 容器
postgresContainer, err := testcontainers_postgres.RunContainer(ctx,
testcontainers.WithImage("postgres:14-alpine"),
testcontainers_postgres.WithDatabase("testdb"),
testcontainers_postgres.WithUsername("postgres"),
testcontainers_postgres.WithPassword("password"),
testcontainers.WithWaitStrategy(
wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(30*time.Second),
),
)
require.NoError(t, err, "启动 PostgreSQL 容器失败")
// 获取连接字符串
connStr, err := postgresContainer.ConnectionString(ctx, "sslmode=disable")
require.NoError(t, err, "获取数据库连接字符串失败")
// 应用数据库迁移
migrationsPath := getMigrationsPath(t)
m, err := migrate.New(
fmt.Sprintf("file://%s", migrationsPath),
connStr,
)
require.NoError(t, err, "创建迁移实例失败")
// 执行向上迁移
err = m.Up()
require.NoError(t, err, "执行数据库迁移失败")
// 连接数据库
db, err := gorm.Open(postgresDriver.Open(connStr), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
require.NoError(t, err, "连接数据库失败")
// 创建测试 logger
testLogger := zap.NewNop()
store := postgres.NewStore(db, testLogger)
// 返回清理函数
cleanup := func() {
// 执行向下迁移清理数据
if err := m.Down(); err != nil && err != migrate.ErrNoChange {
t.Logf("清理迁移失败: %v", err)
}
_, _ = m.Close()
sqlDB, _ := db.DB()
if sqlDB != nil {
_ = sqlDB.Close()
}
if err := postgresContainer.Terminate(ctx); err != nil {
t.Logf("终止容器失败: %v", err)
}
}
return store, cleanup
}
// getMigrationsPath 获取迁移文件路径
func getMigrationsPath(t *testing.T) string {
// 获取项目根目录
wd, err := os.Getwd()
require.NoError(t, err, "获取工作目录失败")
// 从测试目录向上找到项目根目录
migrationsPath := filepath.Join(wd, "..", "..", "migrations")
// 验证迁移目录存在
_, err = os.Stat(migrationsPath)
require.NoError(t, err, fmt.Sprintf("迁移目录不存在: %s", migrationsPath))
return migrationsPath
}
// TestUserCRUD 测试用户 CRUD 操作
func TestUserCRUD(t *testing.T) {
store, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
t.Run("创建用户", func(t *testing.T) {
user := &model.User{
Username: "testuser",
Email: "test@example.com",
Password: "hashedpassword",
Status: constants.UserStatusActive,
}
err := store.User.Create(ctx, user)
assert.NoError(t, err)
assert.NotZero(t, user.ID)
assert.NotZero(t, user.CreatedAt)
assert.NotZero(t, user.UpdatedAt)
})
t.Run("根据ID查询用户", func(t *testing.T) {
// 创建测试用户
user := &model.User{
Username: "queryuser",
Email: "query@example.com",
Password: "hashedpassword",
Status: constants.UserStatusActive,
}
err := store.User.Create(ctx, user)
require.NoError(t, err)
// 查询用户
found, err := store.User.GetByID(ctx, user.ID)
assert.NoError(t, err)
assert.Equal(t, user.Username, found.Username)
assert.Equal(t, user.Email, found.Email)
assert.Equal(t, constants.UserStatusActive, found.Status)
})
t.Run("根据用户名查询用户", func(t *testing.T) {
// 创建测试用户
user := &model.User{
Username: "findbyname",
Email: "findbyname@example.com",
Password: "hashedpassword",
Status: constants.UserStatusActive,
}
err := store.User.Create(ctx, user)
require.NoError(t, err)
// 根据用户名查询
found, err := store.User.GetByUsername(ctx, "findbyname")
assert.NoError(t, err)
assert.Equal(t, user.ID, found.ID)
assert.Equal(t, user.Email, found.Email)
})
t.Run("更新用户", func(t *testing.T) {
// 创建测试用户
user := &model.User{
Username: "updateuser",
Email: "update@example.com",
Password: "hashedpassword",
Status: constants.UserStatusActive,
}
err := store.User.Create(ctx, user)
require.NoError(t, err)
// 更新用户
user.Email = "newemail@example.com"
user.Status = constants.UserStatusInactive
err = store.User.Update(ctx, user)
assert.NoError(t, err)
// 验证更新
found, err := store.User.GetByID(ctx, user.ID)
assert.NoError(t, err)
assert.Equal(t, "newemail@example.com", found.Email)
assert.Equal(t, constants.UserStatusInactive, found.Status)
})
t.Run("列表查询用户", func(t *testing.T) {
// 创建多个测试用户
for i := 1; i <= 5; i++ {
user := &model.User{
Username: fmt.Sprintf("listuser%d", i),
Email: fmt.Sprintf("list%d@example.com", i),
Password: "hashedpassword",
Status: constants.UserStatusActive,
}
err := store.User.Create(ctx, user)
require.NoError(t, err)
}
// 列表查询
users, total, err := store.User.List(ctx, 1, 3)
assert.NoError(t, err)
assert.GreaterOrEqual(t, len(users), 3)
assert.GreaterOrEqual(t, total, int64(5))
})
t.Run("软删除用户", func(t *testing.T) {
// 创建测试用户
user := &model.User{
Username: "deleteuser",
Email: "delete@example.com",
Password: "hashedpassword",
Status: constants.UserStatusActive,
}
err := store.User.Create(ctx, user)
require.NoError(t, err)
// 软删除
err = store.User.Delete(ctx, user.ID)
assert.NoError(t, err)
// 验证已删除(查询应该找不到)
_, err = store.User.GetByID(ctx, user.ID)
assert.Error(t, err)
assert.Equal(t, gorm.ErrRecordNotFound, err)
})
}
// TestOrderCRUD 测试订单 CRUD 操作
func TestOrderCRUD(t *testing.T) {
store, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
// 创建测试用户
user := &model.User{
Username: "orderuser",
Email: "orderuser@example.com",
Password: "hashedpassword",
Status: constants.UserStatusActive,
}
err := store.User.Create(ctx, user)
require.NoError(t, err)
t.Run("创建订单", func(t *testing.T) {
order := &model.Order{
OrderID: "ORD-001",
UserID: user.ID,
Amount: 10000,
Status: constants.OrderStatusPending,
Remark: "测试订单",
}
err := store.Order.Create(ctx, order)
assert.NoError(t, err)
assert.NotZero(t, order.ID)
assert.NotZero(t, order.CreatedAt)
})
t.Run("根据ID查询订单", func(t *testing.T) {
// 创建测试订单
order := &model.Order{
OrderID: "ORD-002",
UserID: user.ID,
Amount: 20000,
Status: constants.OrderStatusPending,
}
err := store.Order.Create(ctx, order)
require.NoError(t, err)
// 查询订单
found, err := store.Order.GetByID(ctx, order.ID)
assert.NoError(t, err)
assert.Equal(t, order.OrderID, found.OrderID)
assert.Equal(t, order.Amount, found.Amount)
})
t.Run("根据订单号查询", func(t *testing.T) {
// 创建测试订单
order := &model.Order{
OrderID: "ORD-003",
UserID: user.ID,
Amount: 30000,
Status: constants.OrderStatusPending,
}
err := store.Order.Create(ctx, order)
require.NoError(t, err)
// 根据订单号查询
found, err := store.Order.GetByOrderID(ctx, "ORD-003")
assert.NoError(t, err)
assert.Equal(t, order.ID, found.ID)
})
t.Run("根据用户ID列表查询", func(t *testing.T) {
// 创建多个订单
for i := 1; i <= 3; i++ {
order := &model.Order{
OrderID: fmt.Sprintf("ORD-USER-%d", i),
UserID: user.ID,
Amount: int64(i * 10000),
Status: constants.OrderStatusPending,
}
err := store.Order.Create(ctx, order)
require.NoError(t, err)
}
// 列表查询
orders, total, err := store.Order.ListByUserID(ctx, user.ID, 1, 10)
assert.NoError(t, err)
assert.GreaterOrEqual(t, len(orders), 3)
assert.GreaterOrEqual(t, total, int64(3))
})
t.Run("更新订单状态", func(t *testing.T) {
// 创建测试订单
order := &model.Order{
OrderID: "ORD-UPDATE",
UserID: user.ID,
Amount: 50000,
Status: constants.OrderStatusPending,
}
err := store.Order.Create(ctx, order)
require.NoError(t, err)
// 更新状态
now := time.Now()
order.Status = constants.OrderStatusPaid
order.PaidAt = &now
err = store.Order.Update(ctx, order)
assert.NoError(t, err)
// 验证更新
found, err := store.Order.GetByID(ctx, order.ID)
assert.NoError(t, err)
assert.Equal(t, constants.OrderStatusPaid, found.Status)
assert.NotNil(t, found.PaidAt)
})
t.Run("软删除订单", func(t *testing.T) {
// 创建测试订单
order := &model.Order{
OrderID: "ORD-DELETE",
UserID: user.ID,
Amount: 60000,
Status: constants.OrderStatusPending,
}
err := store.Order.Create(ctx, order)
require.NoError(t, err)
// 软删除
err = store.Order.Delete(ctx, order.ID)
assert.NoError(t, err)
// 验证已删除
_, err = store.Order.GetByID(ctx, order.ID)
assert.Error(t, err)
assert.Equal(t, gorm.ErrRecordNotFound, err)
})
}
// TestTransaction 测试事务功能
func TestTransaction(t *testing.T) {
store, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
t.Run("事务提交", func(t *testing.T) {
var userID uint
var orderID uint
err := store.Transaction(ctx, func(tx *postgres.Store) error {
// 创建用户
user := &model.User{
Username: "txuser",
Email: "txuser@example.com",
Password: "hashedpassword",
Status: constants.UserStatusActive,
}
if err := tx.User.Create(ctx, user); err != nil {
return err
}
userID = user.ID
// 创建订单
order := &model.Order{
OrderID: "ORD-TX-001",
UserID: user.ID,
Amount: 10000,
Status: constants.OrderStatusPending,
}
if err := tx.Order.Create(ctx, order); err != nil {
return err
}
orderID = order.ID
return nil
})
assert.NoError(t, err)
// 验证用户和订单都已创建
user, err := store.User.GetByID(ctx, userID)
assert.NoError(t, err)
assert.Equal(t, "txuser", user.Username)
order, err := store.Order.GetByID(ctx, orderID)
assert.NoError(t, err)
assert.Equal(t, "ORD-TX-001", order.OrderID)
})
t.Run("事务回滚", func(t *testing.T) {
var userID uint
err := store.Transaction(ctx, func(tx *postgres.Store) error {
// 创建用户
user := &model.User{
Username: "rollbackuser",
Email: "rollback@example.com",
Password: "hashedpassword",
Status: constants.UserStatusActive,
}
if err := tx.User.Create(ctx, user); err != nil {
return err
}
userID = user.ID
// 模拟错误,触发回滚
return fmt.Errorf("模拟错误")
})
assert.Error(t, err)
assert.Equal(t, "模拟错误", err.Error())
// 验证用户未创建(已回滚)
_, err = store.User.GetByID(ctx, userID)
assert.Error(t, err)
assert.Equal(t, gorm.ErrRecordNotFound, err)
})
}
// TestConcurrentOperations 测试并发操作
func TestConcurrentOperations(t *testing.T) {
store, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
t.Run("并发创建用户", func(t *testing.T) {
concurrency := 10
errChan := make(chan error, concurrency)
for i := 0; i < concurrency; i++ {
go func(index int) {
user := &model.User{
Username: fmt.Sprintf("concurrent%d", index),
Email: fmt.Sprintf("concurrent%d@example.com", index),
Password: "hashedpassword",
Status: constants.UserStatusActive,
}
errChan <- store.User.Create(ctx, user)
}(i)
}
// 收集结果
successCount := 0
for i := 0; i < concurrency; i++ {
err := <-errChan
if err == nil {
successCount++
}
}
assert.Equal(t, concurrency, successCount, "所有并发创建应该成功")
})
}

View File

@@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/break/junhong_cmp_fiber/tests/testutils"
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/postgres"
_ "github.com/golang-migrate/migrate/v4/source/file"
@@ -49,7 +50,7 @@ func TestMigration_UpAndDown(t *testing.T) {
require.NoError(t, err, "获取数据库连接字符串失败")
// 应用数据库迁移
migrationsPath := getMigrationsPath(t)
migrationsPath := testutils.GetMigrationsPath()
m, err := migrate.New(
fmt.Sprintf("file://%s", migrationsPath),
connStr,
@@ -135,7 +136,7 @@ func TestMigration_UpAndDown(t *testing.T) {
// TestMigration_NoForeignKeys 验证迁移脚本不包含外键约束
func TestMigration_NoForeignKeys(t *testing.T) {
// 获取迁移目录
migrationsPath := getMigrationsPath(t)
migrationsPath := testutils.GetMigrationsPath()
// 读取所有迁移文件
files, err := filepath.Glob(filepath.Join(migrationsPath, "*.up.sql"))
@@ -187,7 +188,7 @@ func TestMigration_SoftDeleteSupport(t *testing.T) {
require.NoError(t, err, "获取数据库连接字符串失败")
// 应用迁移
migrationsPath := getMigrationsPath(t)
migrationsPath := testutils.GetMigrationsPath()
m, err := migrate.New(
fmt.Sprintf("file://%s", migrationsPath),
connStr,

View File

@@ -19,6 +19,7 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/break/junhong_cmp_fiber/internal/bootstrap"
"github.com/break/junhong_cmp_fiber/internal/handler"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/routes"
@@ -90,8 +91,8 @@ func setupPermTestEnv(t *testing.T) *permTestEnv {
})
// 注册路由
services := &routes.Services{
PermissionHandler: permHandler,
services := &bootstrap.Handlers{
Permission: permHandler,
}
routes.RegisterRoutes(app, services)

View File

@@ -21,6 +21,7 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/break/junhong_cmp_fiber/internal/bootstrap"
"github.com/break/junhong_cmp_fiber/internal/handler"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/routes"
@@ -116,8 +117,8 @@ func setupRoleTestEnv(t *testing.T) *roleTestEnv {
})
// 注册路由
services := &routes.Services{
RoleHandler: roleHandler,
services := &bootstrap.Handlers{
Role: roleHandler,
}
routes.RegisterRoutes(app, services)