refactor: align framework cleanup with new bootstrap flow
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
This commit is contained in:
@@ -21,6 +21,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/bootstrap"
|
||||
"github.com/break/junhong_cmp_fiber/internal/handler"
|
||||
"github.com/break/junhong_cmp_fiber/internal/model"
|
||||
"github.com/break/junhong_cmp_fiber/internal/routes"
|
||||
@@ -116,8 +117,8 @@ func setupTestEnv(t *testing.T) *testEnv {
|
||||
})
|
||||
|
||||
// 注册路由
|
||||
services := &routes.Services{
|
||||
AccountHandler: accountHandler,
|
||||
services := &bootstrap.Handlers{
|
||||
Account: accountHandler,
|
||||
}
|
||||
routes.RegisterRoutes(app, services)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/bootstrap"
|
||||
"github.com/break/junhong_cmp_fiber/internal/handler"
|
||||
"github.com/break/junhong_cmp_fiber/internal/model"
|
||||
"github.com/break/junhong_cmp_fiber/internal/routes"
|
||||
@@ -126,10 +127,10 @@ func setupRegressionTestEnv(t *testing.T) *regressionTestEnv {
|
||||
})
|
||||
|
||||
// 注册所有路由
|
||||
services := &routes.Services{
|
||||
AccountHandler: accountHandler,
|
||||
RoleHandler: roleHandler,
|
||||
PermissionHandler: permHandler,
|
||||
services := &bootstrap.Handlers{
|
||||
Account: accountHandler,
|
||||
Role: roleHandler,
|
||||
Permission: permHandler,
|
||||
}
|
||||
routes.RegisterRoutes(app, services)
|
||||
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/middleware"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/errors"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/logger"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/middleware"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/response"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/validator"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
@@ -52,7 +52,16 @@ func setupAuthTestApp(t *testing.T, rdb *redis.Client) *fiber.App {
|
||||
|
||||
// Add authentication middleware
|
||||
tokenValidator := validator.NewTokenValidator(rdb, logger.GetAppLogger())
|
||||
app.Use(middleware.KeyAuth(tokenValidator, logger.GetAppLogger()))
|
||||
app.Use(middleware.Auth(middleware.AuthConfig{
|
||||
TokenValidator: func(token string) (uint, int, uint, error) {
|
||||
_, err := tokenValidator.Validate(token)
|
||||
if err != nil {
|
||||
return 0, 0, 0, err
|
||||
}
|
||||
// 测试中简化处理:userID 设为 1,userType 设为普通用户
|
||||
return 1, 0, 0, nil
|
||||
},
|
||||
}))
|
||||
|
||||
// Add protected test routes
|
||||
app.Get("/api/v1/test", func(c *fiber.Ctx) error {
|
||||
@@ -342,14 +351,23 @@ func TestKeyAuthMiddleware_UserIDPropagation(t *testing.T) {
|
||||
|
||||
// Add authentication middleware
|
||||
tokenValidator := validator.NewTokenValidator(rdb, logger.GetAppLogger())
|
||||
app.Use(middleware.KeyAuth(tokenValidator, logger.GetAppLogger()))
|
||||
app.Use(middleware.Auth(middleware.AuthConfig{
|
||||
TokenValidator: func(token string) (uint, int, uint, error) {
|
||||
_, err := tokenValidator.Validate(token)
|
||||
if err != nil {
|
||||
return 0, 0, 0, err
|
||||
}
|
||||
// 测试中简化处理:userID 设为 1,userType 设为普通用户
|
||||
return 1, 0, 0, nil
|
||||
},
|
||||
}))
|
||||
|
||||
// Add test route that checks user ID
|
||||
var capturedUserID string
|
||||
var capturedUserID uint
|
||||
app.Get("/api/v1/check-user", func(c *fiber.Ctx) error {
|
||||
userID, ok := c.Locals(constants.ContextKeyUserID).(string)
|
||||
userID, ok := c.Locals(constants.ContextKeyUserID).(uint)
|
||||
if !ok {
|
||||
return response.Error(c, 500, errors.CodeInternalError, "User ID not found in context")
|
||||
return errors.New(errors.CodeInternalError, "User ID not found in context")
|
||||
}
|
||||
capturedUserID = userID
|
||||
return response.Success(c, fiber.Map{
|
||||
|
||||
@@ -1,325 +0,0 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/model"
|
||||
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/middleware"
|
||||
)
|
||||
|
||||
// TestDataPermission_HierarchicalFiltering 测试层级数据权限过滤
|
||||
func TestDataPermission_HierarchicalFiltering(t *testing.T) {
|
||||
store, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
db := store.DB()
|
||||
|
||||
// 创建层级结构: A -> B -> C
|
||||
accountA := &model.Account{
|
||||
Username: "user_a",
|
||||
Phone: "13800000001",
|
||||
Password: "hashed_password",
|
||||
UserType: constants.UserTypePlatform,
|
||||
Status: constants.StatusEnabled,
|
||||
Creator: 1,
|
||||
Updater: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(accountA).Error)
|
||||
|
||||
shopID := uint(100)
|
||||
accountA.ShopID = &shopID
|
||||
require.NoError(t, db.Save(accountA).Error)
|
||||
|
||||
accountB := &model.Account{
|
||||
Username: "user_b",
|
||||
Phone: "13800000002",
|
||||
Password: "hashed_password",
|
||||
UserType: constants.UserTypeAgent,
|
||||
ParentID: &accountA.ID,
|
||||
ShopID: &shopID,
|
||||
Status: constants.StatusEnabled,
|
||||
Creator: 1,
|
||||
Updater: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(accountB).Error)
|
||||
|
||||
accountC := &model.Account{
|
||||
Username: "user_c",
|
||||
Phone: "13800000003",
|
||||
Password: "hashed_password",
|
||||
UserType: constants.UserTypeEnterprise,
|
||||
ParentID: &accountB.ID,
|
||||
ShopID: &shopID,
|
||||
Status: constants.StatusEnabled,
|
||||
Creator: 1,
|
||||
Updater: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(accountC).Error)
|
||||
|
||||
// 创建测试数据表
|
||||
type TestData struct {
|
||||
ID uint `gorm:"primarykey"`
|
||||
Name string
|
||||
OwnerID uint
|
||||
ShopID uint
|
||||
}
|
||||
require.NoError(t, db.AutoMigrate(&TestData{}))
|
||||
|
||||
// 插入测试数据
|
||||
testData := []TestData{
|
||||
{Name: "data_a", OwnerID: accountA.ID, ShopID: 100},
|
||||
{Name: "data_b", OwnerID: accountB.ID, ShopID: 100},
|
||||
{Name: "data_c", OwnerID: accountC.ID, ShopID: 100},
|
||||
{Name: "data_other", OwnerID: 999, ShopID: 100},
|
||||
}
|
||||
require.NoError(t, db.Create(&testData).Error)
|
||||
|
||||
// 创建 AccountStore 用于递归查询
|
||||
accountStore := postgres.NewAccountStore(db, nil) // Redis 可选
|
||||
|
||||
t.Run("A 用户可以看到 A、B、C 的数据", func(t *testing.T) {
|
||||
ctxWithA := middleware.SetUserContext(ctx, accountA.ID, constants.UserTypePlatform, 100)
|
||||
var results []TestData
|
||||
err := db.WithContext(ctxWithA).
|
||||
Scopes(postgres.DataPermissionScope(ctxWithA, accountStore)).
|
||||
Find(&results).Error
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 3)
|
||||
})
|
||||
|
||||
t.Run("B 用户可以看到 B、C 的数据", func(t *testing.T) {
|
||||
ctxWithB := middleware.SetUserContext(ctx, accountB.ID, constants.UserTypeAgent, 100)
|
||||
var results []TestData
|
||||
err := db.WithContext(ctxWithB).
|
||||
Scopes(postgres.DataPermissionScope(ctxWithB, accountStore)).
|
||||
Find(&results).Error
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
})
|
||||
|
||||
t.Run("C 用户只能看到自己的数据", func(t *testing.T) {
|
||||
ctxWithC := middleware.SetUserContext(ctx, accountC.ID, constants.UserTypeEnterprise, 100)
|
||||
var results []TestData
|
||||
err := db.WithContext(ctxWithC).
|
||||
Scopes(postgres.DataPermissionScope(ctxWithC, accountStore)).
|
||||
Find(&results).Error
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, "data_c", results[0].Name)
|
||||
})
|
||||
}
|
||||
|
||||
// TestDataPermission_WithoutDataFilter 测试 WithoutDataFilter 选项
|
||||
func TestDataPermission_WithoutDataFilter(t *testing.T) {
|
||||
store, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
db := store.DB()
|
||||
|
||||
// 创建测试账号
|
||||
shopID := uint(100)
|
||||
accountA := &model.Account{
|
||||
Username: "user_a",
|
||||
Phone: "13800000001",
|
||||
Password: "hashed_password",
|
||||
UserType: constants.UserTypePlatform,
|
||||
ShopID: &shopID,
|
||||
Status: constants.StatusEnabled,
|
||||
Creator: 1,
|
||||
Updater: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(accountA).Error)
|
||||
|
||||
// 创建测试数据表
|
||||
type TestData struct {
|
||||
ID uint `gorm:"primarykey"`
|
||||
Name string
|
||||
OwnerID uint
|
||||
ShopID uint
|
||||
}
|
||||
require.NoError(t, db.AutoMigrate(&TestData{}))
|
||||
|
||||
// 插入测试数据
|
||||
testData := []TestData{
|
||||
{Name: "data_a", OwnerID: accountA.ID, ShopID: 100},
|
||||
{Name: "data_b", OwnerID: 999, ShopID: 100},
|
||||
{Name: "data_c", OwnerID: 888, ShopID: 200},
|
||||
}
|
||||
require.NoError(t, db.Create(&testData).Error)
|
||||
|
||||
// 创建 AccountStore
|
||||
accountStore := postgres.NewAccountStore(db, nil)
|
||||
|
||||
t.Run("正常查询应该过滤数据", func(t *testing.T) {
|
||||
ctxWithA := middleware.SetUserContext(ctx, accountA.ID, constants.UserTypePlatform, 100)
|
||||
var results []TestData
|
||||
err := db.WithContext(ctxWithA).
|
||||
Scopes(postgres.DataPermissionScope(ctxWithA, accountStore)).
|
||||
Find(&results).Error
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
})
|
||||
|
||||
t.Run("不带用户上下文时返回空数据", func(t *testing.T) {
|
||||
var results []TestData
|
||||
err := db.WithContext(ctx).
|
||||
Scopes(postgres.DataPermissionScope(ctx, accountStore)).
|
||||
Find(&results).Error
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 0)
|
||||
})
|
||||
}
|
||||
|
||||
// TestDataPermission_CrossShopIsolation 测试跨店铺数据隔离
|
||||
func TestDataPermission_CrossShopIsolation(t *testing.T) {
|
||||
store, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
db := store.DB()
|
||||
|
||||
// 创建两个不同店铺的账号
|
||||
shopID100 := uint(100)
|
||||
accountA := &model.Account{
|
||||
Username: "user_shop100",
|
||||
Phone: "13800000001",
|
||||
Password: "hashed_password",
|
||||
UserType: constants.UserTypePlatform,
|
||||
ShopID: &shopID100,
|
||||
Status: constants.StatusEnabled,
|
||||
Creator: 1,
|
||||
Updater: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(accountA).Error)
|
||||
|
||||
shopID200 := uint(200)
|
||||
accountB := &model.Account{
|
||||
Username: "user_shop200",
|
||||
Phone: "13800000002",
|
||||
Password: "hashed_password",
|
||||
UserType: constants.UserTypePlatform,
|
||||
ShopID: &shopID200,
|
||||
Status: constants.StatusEnabled,
|
||||
Creator: 1,
|
||||
Updater: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(accountB).Error)
|
||||
|
||||
// 创建测试数据表
|
||||
type TestData struct {
|
||||
ID uint `gorm:"primarykey"`
|
||||
Name string
|
||||
OwnerID uint
|
||||
ShopID uint
|
||||
}
|
||||
require.NoError(t, db.AutoMigrate(&TestData{}))
|
||||
|
||||
// 插入不同店铺的数据
|
||||
testData := []TestData{
|
||||
{Name: "data_shop100_1", OwnerID: accountA.ID, ShopID: 100},
|
||||
{Name: "data_shop100_2", OwnerID: accountA.ID, ShopID: 100},
|
||||
{Name: "data_shop200_1", OwnerID: accountB.ID, ShopID: 200},
|
||||
{Name: "data_shop200_2", OwnerID: accountB.ID, ShopID: 200},
|
||||
}
|
||||
require.NoError(t, db.Create(&testData).Error)
|
||||
|
||||
// 创建 AccountStore
|
||||
accountStore := postgres.NewAccountStore(db, nil)
|
||||
|
||||
t.Run("店铺 100 用户只能看到店铺 100 的数据", func(t *testing.T) {
|
||||
ctxWithA := middleware.SetUserContext(ctx, accountA.ID, constants.UserTypePlatform, 100)
|
||||
var results []TestData
|
||||
err := db.WithContext(ctxWithA).
|
||||
Scopes(postgres.DataPermissionScope(ctxWithA, accountStore)).
|
||||
Find(&results).Error
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
for _, r := range results {
|
||||
assert.Equal(t, uint(100), r.ShopID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("店铺 200 用户只能看到店铺 200 的数据", func(t *testing.T) {
|
||||
ctxWithB := middleware.SetUserContext(ctx, accountB.ID, constants.UserTypePlatform, 200)
|
||||
var results []TestData
|
||||
err := db.WithContext(ctxWithB).
|
||||
Scopes(postgres.DataPermissionScope(ctxWithB, accountStore)).
|
||||
Find(&results).Error
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
for _, r := range results {
|
||||
assert.Equal(t, uint(200), r.ShopID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("店铺 100 用户看不到店铺 200 的数据", func(t *testing.T) {
|
||||
ctxWithA := middleware.SetUserContext(ctx, accountA.ID, constants.UserTypePlatform, 100)
|
||||
var results []TestData
|
||||
err := db.WithContext(ctxWithA).
|
||||
Scopes(postgres.DataPermissionScope(ctxWithA, accountStore)).
|
||||
Where("shop_id = ?", 200). // 尝试查询店铺 200 的数据
|
||||
Find(&results).Error
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 0, "不应该看到其他店铺的数据")
|
||||
})
|
||||
}
|
||||
|
||||
// TestDataPermission_RootUserBypass 测试 root 用户跳过数据权限过滤
|
||||
func TestDataPermission_RootUserBypass(t *testing.T) {
|
||||
store, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
db := store.DB()
|
||||
|
||||
// 创建 root 用户
|
||||
rootUser := &model.Account{
|
||||
Username: "root_user",
|
||||
Phone: "13800000000",
|
||||
Password: "hashed_password",
|
||||
UserType: constants.UserTypeRoot,
|
||||
Status: constants.StatusEnabled,
|
||||
Creator: 1,
|
||||
Updater: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(rootUser).Error)
|
||||
|
||||
// 创建测试数据表
|
||||
type TestData struct {
|
||||
ID uint `gorm:"primarykey"`
|
||||
Name string
|
||||
OwnerID uint
|
||||
ShopID uint
|
||||
}
|
||||
require.NoError(t, db.AutoMigrate(&TestData{}))
|
||||
|
||||
// 插入不同店铺、不同用户的数据
|
||||
testData := []TestData{
|
||||
{Name: "data_1", OwnerID: 1, ShopID: 100},
|
||||
{Name: "data_2", OwnerID: 2, ShopID: 200},
|
||||
{Name: "data_3", OwnerID: 3, ShopID: 300},
|
||||
{Name: "data_4", OwnerID: 4, ShopID: 400},
|
||||
}
|
||||
require.NoError(t, db.Create(&testData).Error)
|
||||
|
||||
// 创建 AccountStore
|
||||
accountStore := postgres.NewAccountStore(db, nil)
|
||||
|
||||
t.Run("root 用户可以看到所有数据", func(t *testing.T) {
|
||||
ctxWithRoot := middleware.SetUserContext(ctx, rootUser.ID, constants.UserTypeRoot, 100)
|
||||
var results []TestData
|
||||
err := db.WithContext(ctxWithRoot).
|
||||
Scopes(postgres.DataPermissionScope(ctxWithRoot, accountStore)).
|
||||
Find(&results).Error
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 4, "root 用户应该看到所有数据")
|
||||
})
|
||||
}
|
||||
@@ -1,489 +0,0 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/model"
|
||||
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
_ "github.com/golang-migrate/migrate/v4/database/postgres"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
testcontainers_postgres "github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.uber.org/zap"
|
||||
postgresDriver "gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// TestMain 设置测试环境
|
||||
func TestMain(m *testing.M) {
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// setupTestDB 启动 PostgreSQL 容器并使用迁移脚本初始化数据库
|
||||
func setupTestDB(t *testing.T) (*postgres.Store, func()) {
|
||||
ctx := context.Background()
|
||||
|
||||
// 启动 PostgreSQL 容器
|
||||
postgresContainer, err := testcontainers_postgres.RunContainer(ctx,
|
||||
testcontainers.WithImage("postgres:14-alpine"),
|
||||
testcontainers_postgres.WithDatabase("testdb"),
|
||||
testcontainers_postgres.WithUsername("postgres"),
|
||||
testcontainers_postgres.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForLog("database system is ready to accept connections").
|
||||
WithOccurrence(2).
|
||||
WithStartupTimeout(30*time.Second),
|
||||
),
|
||||
)
|
||||
require.NoError(t, err, "启动 PostgreSQL 容器失败")
|
||||
|
||||
// 获取连接字符串
|
||||
connStr, err := postgresContainer.ConnectionString(ctx, "sslmode=disable")
|
||||
require.NoError(t, err, "获取数据库连接字符串失败")
|
||||
|
||||
// 应用数据库迁移
|
||||
migrationsPath := getMigrationsPath(t)
|
||||
m, err := migrate.New(
|
||||
fmt.Sprintf("file://%s", migrationsPath),
|
||||
connStr,
|
||||
)
|
||||
require.NoError(t, err, "创建迁移实例失败")
|
||||
|
||||
// 执行向上迁移
|
||||
err = m.Up()
|
||||
require.NoError(t, err, "执行数据库迁移失败")
|
||||
|
||||
// 连接数据库
|
||||
db, err := gorm.Open(postgresDriver.Open(connStr), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
require.NoError(t, err, "连接数据库失败")
|
||||
|
||||
// 创建测试 logger
|
||||
testLogger := zap.NewNop()
|
||||
store := postgres.NewStore(db, testLogger)
|
||||
|
||||
// 返回清理函数
|
||||
cleanup := func() {
|
||||
// 执行向下迁移清理数据
|
||||
if err := m.Down(); err != nil && err != migrate.ErrNoChange {
|
||||
t.Logf("清理迁移失败: %v", err)
|
||||
}
|
||||
_, _ = m.Close()
|
||||
|
||||
sqlDB, _ := db.DB()
|
||||
if sqlDB != nil {
|
||||
_ = sqlDB.Close()
|
||||
}
|
||||
if err := postgresContainer.Terminate(ctx); err != nil {
|
||||
t.Logf("终止容器失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return store, cleanup
|
||||
}
|
||||
|
||||
// getMigrationsPath 获取迁移文件路径
|
||||
func getMigrationsPath(t *testing.T) string {
|
||||
// 获取项目根目录
|
||||
wd, err := os.Getwd()
|
||||
require.NoError(t, err, "获取工作目录失败")
|
||||
|
||||
// 从测试目录向上找到项目根目录
|
||||
migrationsPath := filepath.Join(wd, "..", "..", "migrations")
|
||||
|
||||
// 验证迁移目录存在
|
||||
_, err = os.Stat(migrationsPath)
|
||||
require.NoError(t, err, fmt.Sprintf("迁移目录不存在: %s", migrationsPath))
|
||||
|
||||
return migrationsPath
|
||||
}
|
||||
|
||||
// TestUserCRUD 测试用户 CRUD 操作
|
||||
func TestUserCRUD(t *testing.T) {
|
||||
store, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("创建用户", func(t *testing.T) {
|
||||
user := &model.User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
|
||||
err := store.User.Create(ctx, user)
|
||||
assert.NoError(t, err)
|
||||
assert.NotZero(t, user.ID)
|
||||
assert.NotZero(t, user.CreatedAt)
|
||||
assert.NotZero(t, user.UpdatedAt)
|
||||
})
|
||||
|
||||
t.Run("根据ID查询用户", func(t *testing.T) {
|
||||
// 创建测试用户
|
||||
user := &model.User{
|
||||
Username: "queryuser",
|
||||
Email: "query@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
err := store.User.Create(ctx, user)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 查询用户
|
||||
found, err := store.User.GetByID(ctx, user.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, user.Username, found.Username)
|
||||
assert.Equal(t, user.Email, found.Email)
|
||||
assert.Equal(t, constants.UserStatusActive, found.Status)
|
||||
})
|
||||
|
||||
t.Run("根据用户名查询用户", func(t *testing.T) {
|
||||
// 创建测试用户
|
||||
user := &model.User{
|
||||
Username: "findbyname",
|
||||
Email: "findbyname@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
err := store.User.Create(ctx, user)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 根据用户名查询
|
||||
found, err := store.User.GetByUsername(ctx, "findbyname")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, user.ID, found.ID)
|
||||
assert.Equal(t, user.Email, found.Email)
|
||||
})
|
||||
|
||||
t.Run("更新用户", func(t *testing.T) {
|
||||
// 创建测试用户
|
||||
user := &model.User{
|
||||
Username: "updateuser",
|
||||
Email: "update@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
err := store.User.Create(ctx, user)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 更新用户
|
||||
user.Email = "newemail@example.com"
|
||||
user.Status = constants.UserStatusInactive
|
||||
err = store.User.Update(ctx, user)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证更新
|
||||
found, err := store.User.GetByID(ctx, user.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "newemail@example.com", found.Email)
|
||||
assert.Equal(t, constants.UserStatusInactive, found.Status)
|
||||
})
|
||||
|
||||
t.Run("列表查询用户", func(t *testing.T) {
|
||||
// 创建多个测试用户
|
||||
for i := 1; i <= 5; i++ {
|
||||
user := &model.User{
|
||||
Username: fmt.Sprintf("listuser%d", i),
|
||||
Email: fmt.Sprintf("list%d@example.com", i),
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
err := store.User.Create(ctx, user)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// 列表查询
|
||||
users, total, err := store.User.List(ctx, 1, 3)
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(users), 3)
|
||||
assert.GreaterOrEqual(t, total, int64(5))
|
||||
})
|
||||
|
||||
t.Run("软删除用户", func(t *testing.T) {
|
||||
// 创建测试用户
|
||||
user := &model.User{
|
||||
Username: "deleteuser",
|
||||
Email: "delete@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
err := store.User.Create(ctx, user)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 软删除
|
||||
err = store.User.Delete(ctx, user.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证已删除(查询应该找不到)
|
||||
_, err = store.User.GetByID(ctx, user.ID)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, gorm.ErrRecordNotFound, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestOrderCRUD 测试订单 CRUD 操作
|
||||
func TestOrderCRUD(t *testing.T) {
|
||||
store, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建测试用户
|
||||
user := &model.User{
|
||||
Username: "orderuser",
|
||||
Email: "orderuser@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
err := store.User.Create(ctx, user)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("创建订单", func(t *testing.T) {
|
||||
order := &model.Order{
|
||||
OrderID: "ORD-001",
|
||||
UserID: user.ID,
|
||||
Amount: 10000,
|
||||
Status: constants.OrderStatusPending,
|
||||
Remark: "测试订单",
|
||||
}
|
||||
|
||||
err := store.Order.Create(ctx, order)
|
||||
assert.NoError(t, err)
|
||||
assert.NotZero(t, order.ID)
|
||||
assert.NotZero(t, order.CreatedAt)
|
||||
})
|
||||
|
||||
t.Run("根据ID查询订单", func(t *testing.T) {
|
||||
// 创建测试订单
|
||||
order := &model.Order{
|
||||
OrderID: "ORD-002",
|
||||
UserID: user.ID,
|
||||
Amount: 20000,
|
||||
Status: constants.OrderStatusPending,
|
||||
}
|
||||
err := store.Order.Create(ctx, order)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 查询订单
|
||||
found, err := store.Order.GetByID(ctx, order.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, order.OrderID, found.OrderID)
|
||||
assert.Equal(t, order.Amount, found.Amount)
|
||||
})
|
||||
|
||||
t.Run("根据订单号查询", func(t *testing.T) {
|
||||
// 创建测试订单
|
||||
order := &model.Order{
|
||||
OrderID: "ORD-003",
|
||||
UserID: user.ID,
|
||||
Amount: 30000,
|
||||
Status: constants.OrderStatusPending,
|
||||
}
|
||||
err := store.Order.Create(ctx, order)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 根据订单号查询
|
||||
found, err := store.Order.GetByOrderID(ctx, "ORD-003")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, order.ID, found.ID)
|
||||
})
|
||||
|
||||
t.Run("根据用户ID列表查询", func(t *testing.T) {
|
||||
// 创建多个订单
|
||||
for i := 1; i <= 3; i++ {
|
||||
order := &model.Order{
|
||||
OrderID: fmt.Sprintf("ORD-USER-%d", i),
|
||||
UserID: user.ID,
|
||||
Amount: int64(i * 10000),
|
||||
Status: constants.OrderStatusPending,
|
||||
}
|
||||
err := store.Order.Create(ctx, order)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// 列表查询
|
||||
orders, total, err := store.Order.ListByUserID(ctx, user.ID, 1, 10)
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(orders), 3)
|
||||
assert.GreaterOrEqual(t, total, int64(3))
|
||||
})
|
||||
|
||||
t.Run("更新订单状态", func(t *testing.T) {
|
||||
// 创建测试订单
|
||||
order := &model.Order{
|
||||
OrderID: "ORD-UPDATE",
|
||||
UserID: user.ID,
|
||||
Amount: 50000,
|
||||
Status: constants.OrderStatusPending,
|
||||
}
|
||||
err := store.Order.Create(ctx, order)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 更新状态
|
||||
now := time.Now()
|
||||
order.Status = constants.OrderStatusPaid
|
||||
order.PaidAt = &now
|
||||
err = store.Order.Update(ctx, order)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证更新
|
||||
found, err := store.Order.GetByID(ctx, order.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, constants.OrderStatusPaid, found.Status)
|
||||
assert.NotNil(t, found.PaidAt)
|
||||
})
|
||||
|
||||
t.Run("软删除订单", func(t *testing.T) {
|
||||
// 创建测试订单
|
||||
order := &model.Order{
|
||||
OrderID: "ORD-DELETE",
|
||||
UserID: user.ID,
|
||||
Amount: 60000,
|
||||
Status: constants.OrderStatusPending,
|
||||
}
|
||||
err := store.Order.Create(ctx, order)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 软删除
|
||||
err = store.Order.Delete(ctx, order.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证已删除
|
||||
_, err = store.Order.GetByID(ctx, order.ID)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, gorm.ErrRecordNotFound, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestTransaction 测试事务功能
|
||||
func TestTransaction(t *testing.T) {
|
||||
store, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("事务提交", func(t *testing.T) {
|
||||
var userID uint
|
||||
var orderID uint
|
||||
|
||||
err := store.Transaction(ctx, func(tx *postgres.Store) error {
|
||||
// 创建用户
|
||||
user := &model.User{
|
||||
Username: "txuser",
|
||||
Email: "txuser@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
if err := tx.User.Create(ctx, user); err != nil {
|
||||
return err
|
||||
}
|
||||
userID = user.ID
|
||||
|
||||
// 创建订单
|
||||
order := &model.Order{
|
||||
OrderID: "ORD-TX-001",
|
||||
UserID: user.ID,
|
||||
Amount: 10000,
|
||||
Status: constants.OrderStatusPending,
|
||||
}
|
||||
if err := tx.Order.Create(ctx, order); err != nil {
|
||||
return err
|
||||
}
|
||||
orderID = order.ID
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证用户和订单都已创建
|
||||
user, err := store.User.GetByID(ctx, userID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "txuser", user.Username)
|
||||
|
||||
order, err := store.Order.GetByID(ctx, orderID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "ORD-TX-001", order.OrderID)
|
||||
})
|
||||
|
||||
t.Run("事务回滚", func(t *testing.T) {
|
||||
var userID uint
|
||||
|
||||
err := store.Transaction(ctx, func(tx *postgres.Store) error {
|
||||
// 创建用户
|
||||
user := &model.User{
|
||||
Username: "rollbackuser",
|
||||
Email: "rollback@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
if err := tx.User.Create(ctx, user); err != nil {
|
||||
return err
|
||||
}
|
||||
userID = user.ID
|
||||
|
||||
// 模拟错误,触发回滚
|
||||
return fmt.Errorf("模拟错误")
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "模拟错误", err.Error())
|
||||
|
||||
// 验证用户未创建(已回滚)
|
||||
_, err = store.User.GetByID(ctx, userID)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, gorm.ErrRecordNotFound, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestConcurrentOperations 测试并发操作
|
||||
func TestConcurrentOperations(t *testing.T) {
|
||||
store, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("并发创建用户", func(t *testing.T) {
|
||||
concurrency := 10
|
||||
errChan := make(chan error, concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func(index int) {
|
||||
user := &model.User{
|
||||
Username: fmt.Sprintf("concurrent%d", index),
|
||||
Email: fmt.Sprintf("concurrent%d@example.com", index),
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
errChan <- store.User.Create(ctx, user)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 收集结果
|
||||
successCount := 0
|
||||
for i := 0; i < concurrency; i++ {
|
||||
err := <-errChan
|
||||
if err == nil {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, concurrency, successCount, "所有并发创建应该成功")
|
||||
})
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/tests/testutils"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
_ "github.com/golang-migrate/migrate/v4/database/postgres"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
@@ -49,7 +50,7 @@ func TestMigration_UpAndDown(t *testing.T) {
|
||||
require.NoError(t, err, "获取数据库连接字符串失败")
|
||||
|
||||
// 应用数据库迁移
|
||||
migrationsPath := getMigrationsPath(t)
|
||||
migrationsPath := testutils.GetMigrationsPath()
|
||||
m, err := migrate.New(
|
||||
fmt.Sprintf("file://%s", migrationsPath),
|
||||
connStr,
|
||||
@@ -135,7 +136,7 @@ func TestMigration_UpAndDown(t *testing.T) {
|
||||
// TestMigration_NoForeignKeys 验证迁移脚本不包含外键约束
|
||||
func TestMigration_NoForeignKeys(t *testing.T) {
|
||||
// 获取迁移目录
|
||||
migrationsPath := getMigrationsPath(t)
|
||||
migrationsPath := testutils.GetMigrationsPath()
|
||||
|
||||
// 读取所有迁移文件
|
||||
files, err := filepath.Glob(filepath.Join(migrationsPath, "*.up.sql"))
|
||||
@@ -187,7 +188,7 @@ func TestMigration_SoftDeleteSupport(t *testing.T) {
|
||||
require.NoError(t, err, "获取数据库连接字符串失败")
|
||||
|
||||
// 应用迁移
|
||||
migrationsPath := getMigrationsPath(t)
|
||||
migrationsPath := testutils.GetMigrationsPath()
|
||||
m, err := migrate.New(
|
||||
fmt.Sprintf("file://%s", migrationsPath),
|
||||
connStr,
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/bootstrap"
|
||||
"github.com/break/junhong_cmp_fiber/internal/handler"
|
||||
"github.com/break/junhong_cmp_fiber/internal/model"
|
||||
"github.com/break/junhong_cmp_fiber/internal/routes"
|
||||
@@ -90,8 +91,8 @@ func setupPermTestEnv(t *testing.T) *permTestEnv {
|
||||
})
|
||||
|
||||
// 注册路由
|
||||
services := &routes.Services{
|
||||
PermissionHandler: permHandler,
|
||||
services := &bootstrap.Handlers{
|
||||
Permission: permHandler,
|
||||
}
|
||||
routes.RegisterRoutes(app, services)
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/bootstrap"
|
||||
"github.com/break/junhong_cmp_fiber/internal/handler"
|
||||
"github.com/break/junhong_cmp_fiber/internal/model"
|
||||
"github.com/break/junhong_cmp_fiber/internal/routes"
|
||||
@@ -116,8 +117,8 @@ func setupRoleTestEnv(t *testing.T) *roleTestEnv {
|
||||
})
|
||||
|
||||
// 注册路由
|
||||
services := &routes.Services{
|
||||
RoleHandler: roleHandler,
|
||||
services := &bootstrap.Handlers{
|
||||
Role: roleHandler,
|
||||
}
|
||||
routes.RegisterRoutes(app, services)
|
||||
|
||||
|
||||
39
tests/testutils/helpers.go
Normal file
39
tests/testutils/helpers.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SetupTestDBWithStore 设置测试数据库并返回 AccountStore 和 cleanup 函数
|
||||
// 用于需要 store 接口的集成测试
|
||||
func SetupTestDBWithStore(t *testing.T) (*gorm.DB, func()) {
|
||||
t.Helper()
|
||||
|
||||
db, redisClient := SetupTestDB(t)
|
||||
|
||||
cleanup := func() {
|
||||
TeardownTestDB(t, db, redisClient)
|
||||
}
|
||||
|
||||
return db, cleanup
|
||||
}
|
||||
|
||||
// GetMigrationsPath 获取数据库迁移文件的路径
|
||||
// 返回项目根目录下的 migrations 目录路径
|
||||
func GetMigrationsPath() string {
|
||||
// 获取当前文件路径
|
||||
_, filename, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
panic("无法获取当前文件路径")
|
||||
}
|
||||
|
||||
// 从 tests/testutils/helpers.go 向上两级到项目根目录
|
||||
projectRoot := filepath.Join(filepath.Dir(filename), "..", "..")
|
||||
migrationsPath := filepath.Join(projectRoot, "migrations")
|
||||
|
||||
return migrationsPath
|
||||
}
|
||||
@@ -1,303 +0,0 @@
|
||||
package unit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/model"
|
||||
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/middleware"
|
||||
"github.com/break/junhong_cmp_fiber/tests/testutils"
|
||||
)
|
||||
|
||||
// TestDataPermissionScope_RootUser 测试 root 用户跳过数据权限过滤
|
||||
func TestDataPermissionScope_RootUser(t *testing.T) {
|
||||
db, redisClient := testutils.SetupTestDB(t)
|
||||
defer testutils.TeardownTestDB(t, db, redisClient)
|
||||
|
||||
accountStore := postgres.NewAccountStore(db, redisClient)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建 root 用户
|
||||
rootUser := &model.Account{
|
||||
Username: "root_user",
|
||||
Phone: "13800000000",
|
||||
Password: "hashed_password",
|
||||
UserType: constants.UserTypeRoot,
|
||||
Status: constants.StatusEnabled,
|
||||
Creator: 1,
|
||||
Updater: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(rootUser).Error)
|
||||
|
||||
// 创建测试数据表(模拟业务表)
|
||||
type TestData struct {
|
||||
ID uint `gorm:"primarykey"`
|
||||
Name string
|
||||
OwnerID uint
|
||||
ShopID uint
|
||||
}
|
||||
require.NoError(t, db.AutoMigrate(&TestData{}))
|
||||
|
||||
// 插入测试数据(不同的 owner_id 和 shop_id)
|
||||
testData := []TestData{
|
||||
{Name: "data1", OwnerID: 1, ShopID: 100},
|
||||
{Name: "data2", OwnerID: 2, ShopID: 200},
|
||||
{Name: "data3", OwnerID: 3, ShopID: 300},
|
||||
}
|
||||
require.NoError(t, db.Create(&testData).Error)
|
||||
|
||||
// 设置 root 用户上下文
|
||||
ctxWithRoot := middleware.SetUserContext(ctx, rootUser.ID, constants.UserTypeRoot, 100)
|
||||
|
||||
// 查询(应该返回所有数据,不过滤)
|
||||
var results []TestData
|
||||
err := db.WithContext(ctxWithRoot).
|
||||
Scopes(postgres.DataPermissionScope(ctxWithRoot, accountStore)).
|
||||
Find(&results).Error
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 3, "root 用户应该看到所有数据")
|
||||
}
|
||||
|
||||
// TestDataPermissionScope_NormalUser 测试普通用户数据权限过滤
|
||||
func TestDataPermissionScope_NormalUser(t *testing.T) {
|
||||
db, redisClient := testutils.SetupTestDB(t)
|
||||
defer testutils.TeardownTestDB(t, db, redisClient)
|
||||
|
||||
accountStore := postgres.NewAccountStore(db, redisClient)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建账号层级: A -> B
|
||||
accountA := &model.Account{
|
||||
Username: "user_a",
|
||||
Phone: "13800000001",
|
||||
Password: "hashed_password",
|
||||
UserType: constants.UserTypePlatform,
|
||||
Status: constants.StatusEnabled,
|
||||
Creator: 1,
|
||||
Updater: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(accountA).Error)
|
||||
|
||||
shopIDA := uint(100)
|
||||
accountA.ShopID = &shopIDA
|
||||
require.NoError(t, db.Save(accountA).Error)
|
||||
|
||||
accountB := &model.Account{
|
||||
Username: "user_b",
|
||||
Phone: "13800000002",
|
||||
Password: "hashed_password",
|
||||
UserType: constants.UserTypeAgent,
|
||||
ParentID: &accountA.ID,
|
||||
Status: constants.StatusEnabled,
|
||||
Creator: 1,
|
||||
Updater: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(accountB).Error)
|
||||
|
||||
shopIDB := uint(100)
|
||||
accountB.ShopID = &shopIDB
|
||||
require.NoError(t, db.Save(accountB).Error)
|
||||
|
||||
// 创建测试数据表
|
||||
type TestData struct {
|
||||
ID uint `gorm:"primarykey"`
|
||||
Name string
|
||||
OwnerID uint
|
||||
ShopID uint
|
||||
}
|
||||
require.NoError(t, db.AutoMigrate(&TestData{}))
|
||||
|
||||
// 插入测试数据
|
||||
testData := []TestData{
|
||||
{Name: "data_a", OwnerID: accountA.ID, ShopID: 100}, // A 的数据
|
||||
{Name: "data_b", OwnerID: accountB.ID, ShopID: 100}, // B 的数据
|
||||
{Name: "data_c", OwnerID: 999, ShopID: 100}, // 其他用户数据(同店铺)
|
||||
{Name: "data_d", OwnerID: accountA.ID, ShopID: 200}, // A 的数据(不同店铺)
|
||||
}
|
||||
require.NoError(t, db.Create(&testData).Error)
|
||||
|
||||
// A 登录查询(应该看到 A 和 B 的数据,同店铺)
|
||||
ctxWithA := middleware.SetUserContext(ctx, accountA.ID, constants.UserTypePlatform, 100)
|
||||
var resultsA []TestData
|
||||
err := db.WithContext(ctxWithA).
|
||||
Scopes(postgres.DataPermissionScope(ctxWithA, accountStore)).
|
||||
Find(&resultsA).Error
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resultsA, 2, "A 应该看到自己和下级 B 的数据")
|
||||
|
||||
// B 登录查询(只能看到自己的数据)
|
||||
ctxWithB := middleware.SetUserContext(ctx, accountB.ID, constants.UserTypeAgent, 100)
|
||||
var resultsB []TestData
|
||||
err = db.WithContext(ctxWithB).
|
||||
Scopes(postgres.DataPermissionScope(ctxWithB, accountStore)).
|
||||
Find(&resultsB).Error
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resultsB, 1, "B 只能看到自己的数据")
|
||||
assert.Equal(t, "data_b", resultsB[0].Name)
|
||||
}
|
||||
|
||||
// TestDataPermissionScope_ShopIsolation 测试店铺隔离
|
||||
func TestDataPermissionScope_ShopIsolation(t *testing.T) {
|
||||
db, redisClient := testutils.SetupTestDB(t)
|
||||
defer testutils.TeardownTestDB(t, db, redisClient)
|
||||
|
||||
accountStore := postgres.NewAccountStore(db, redisClient)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建两个账号(同一层级,不同店铺)
|
||||
shopID100 := uint(100)
|
||||
accountA := &model.Account{
|
||||
Username: "user_a",
|
||||
Phone: "13800000001",
|
||||
Password: "hashed_password",
|
||||
UserType: constants.UserTypePlatform,
|
||||
ShopID: &shopID100,
|
||||
Status: constants.StatusEnabled,
|
||||
Creator: 1,
|
||||
Updater: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(accountA).Error)
|
||||
|
||||
shopID200 := uint(200)
|
||||
accountB := &model.Account{
|
||||
Username: "user_b",
|
||||
Phone: "13800000002",
|
||||
Password: "hashed_password",
|
||||
UserType: constants.UserTypePlatform,
|
||||
ShopID: &shopID200,
|
||||
Status: constants.StatusEnabled,
|
||||
Creator: 1,
|
||||
Updater: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(accountB).Error)
|
||||
|
||||
// 创建测试数据表
|
||||
type TestData struct {
|
||||
ID uint `gorm:"primarykey"`
|
||||
Name string
|
||||
OwnerID uint
|
||||
ShopID uint
|
||||
}
|
||||
require.NoError(t, db.AutoMigrate(&TestData{}))
|
||||
|
||||
// 插入测试数据
|
||||
testData := []TestData{
|
||||
{Name: "data_shop100", OwnerID: accountA.ID, ShopID: 100},
|
||||
{Name: "data_shop200", OwnerID: accountB.ID, ShopID: 200},
|
||||
}
|
||||
require.NoError(t, db.Create(&testData).Error)
|
||||
|
||||
// A 登录查询(只能看到店铺 100 的数据)
|
||||
ctxWithA := middleware.SetUserContext(ctx, accountA.ID, constants.UserTypePlatform, 100)
|
||||
var resultsA []TestData
|
||||
err := db.WithContext(ctxWithA).
|
||||
Scopes(postgres.DataPermissionScope(ctxWithA, accountStore)).
|
||||
Find(&resultsA).Error
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resultsA, 1, "A 只能看到店铺 100 的数据")
|
||||
assert.Equal(t, "data_shop100", resultsA[0].Name)
|
||||
|
||||
// B 登录查询(只能看到店铺 200 的数据)
|
||||
ctxWithB := middleware.SetUserContext(ctx, accountB.ID, constants.UserTypePlatform, 200)
|
||||
var resultsB []TestData
|
||||
err = db.WithContext(ctxWithB).
|
||||
Scopes(postgres.DataPermissionScope(ctxWithB, accountStore)).
|
||||
Find(&resultsB).Error
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resultsB, 1, "B 只能看到店铺 200 的数据")
|
||||
assert.Equal(t, "data_shop200", resultsB[0].Name)
|
||||
}
|
||||
|
||||
// TestDataPermissionScope_NoUserContext 测试无用户上下文时不过滤
|
||||
func TestDataPermissionScope_NoUserContext(t *testing.T) {
|
||||
db, redisClient := testutils.SetupTestDB(t)
|
||||
defer testutils.TeardownTestDB(t, db, redisClient)
|
||||
|
||||
accountStore := postgres.NewAccountStore(db, redisClient)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建测试数据表
|
||||
type TestData struct {
|
||||
ID uint `gorm:"primarykey"`
|
||||
Name string
|
||||
OwnerID uint
|
||||
ShopID uint
|
||||
}
|
||||
require.NoError(t, db.AutoMigrate(&TestData{}))
|
||||
|
||||
// 插入测试数据
|
||||
testData := []TestData{
|
||||
{Name: "data1", OwnerID: 1, ShopID: 100},
|
||||
{Name: "data2", OwnerID: 2, ShopID: 200},
|
||||
}
|
||||
require.NoError(t, db.Create(&testData).Error)
|
||||
|
||||
// 使用没有用户信息的上下文查询(不过滤,可能是系统任务)
|
||||
var results []TestData
|
||||
err := db.WithContext(ctx).
|
||||
Scopes(postgres.DataPermissionScope(ctx, accountStore)).
|
||||
Find(&results).Error
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 0, "无用户上下文时应该返回空数据(根据 scopes.go 的实现)")
|
||||
}
|
||||
|
||||
// TestDataPermissionScope_ErrorHandling 测试查询下级 ID 失败时的降级策略
|
||||
func TestDataPermissionScope_ErrorHandling(t *testing.T) {
|
||||
db, redisClient := testutils.SetupTestDB(t)
|
||||
defer testutils.TeardownTestDB(t, db, redisClient)
|
||||
|
||||
accountStore := postgres.NewAccountStore(db, redisClient)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建测试账号
|
||||
accountA := &model.Account{
|
||||
Username: "user_a",
|
||||
Phone: "13800000001",
|
||||
Password: "hashed_password",
|
||||
UserType: constants.UserTypePlatform,
|
||||
Status: constants.StatusEnabled,
|
||||
Creator: 1,
|
||||
Updater: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(accountA).Error)
|
||||
|
||||
shopIDA := uint(100)
|
||||
accountA.ShopID = &shopIDA
|
||||
require.NoError(t, db.Save(accountA).Error)
|
||||
|
||||
// 创建测试数据表
|
||||
type TestData struct {
|
||||
ID uint `gorm:"primarykey"`
|
||||
Name string
|
||||
OwnerID uint
|
||||
ShopID uint
|
||||
}
|
||||
require.NoError(t, db.AutoMigrate(&TestData{}))
|
||||
|
||||
// 插入测试数据
|
||||
testData := []TestData{
|
||||
{Name: "data_a", OwnerID: accountA.ID, ShopID: 100},
|
||||
{Name: "data_b", OwnerID: 999, ShopID: 100},
|
||||
}
|
||||
require.NoError(t, db.Create(&testData).Error)
|
||||
|
||||
// 关闭 Redis 连接以模拟错误(递归查询失败)
|
||||
redisClient.Close()
|
||||
|
||||
// 使用 A 的上下文查询(降级策略:只返回自己的数据)
|
||||
ctxWithA := middleware.SetUserContext(ctx, accountA.ID, constants.UserTypePlatform, 100)
|
||||
var resultsA []TestData
|
||||
err := db.WithContext(ctxWithA).
|
||||
Scopes(postgres.DataPermissionScope(ctxWithA, accountStore)).
|
||||
Find(&resultsA).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// 降级策略应该只返回自己的数据
|
||||
assert.Len(t, resultsA, 1, "查询下级 ID 失败时,应该降级为只返回自己的数据")
|
||||
assert.Equal(t, "data_a", resultsA[0].Name)
|
||||
}
|
||||
@@ -1,502 +0,0 @@
|
||||
package unit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/model"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestUserValidation 测试用户模型验证
|
||||
func TestUserValidation(t *testing.T) {
|
||||
validate := validator.New()
|
||||
|
||||
t.Run("有效的创建用户请求", func(t *testing.T) {
|
||||
req := &model.CreateUserRequest{
|
||||
Username: "validuser",
|
||||
Email: "valid@example.com",
|
||||
Password: "password123",
|
||||
}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("用户名太短", func(t *testing.T) {
|
||||
req := &model.CreateUserRequest{
|
||||
Username: "ab", // 少于 3 个字符
|
||||
Email: "valid@example.com",
|
||||
Password: "password123",
|
||||
}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("用户名太长", func(t *testing.T) {
|
||||
req := &model.CreateUserRequest{
|
||||
Username: "a123456789012345678901234567890123456789012345678901", // 超过 50 个字符
|
||||
Email: "valid@example.com",
|
||||
Password: "password123",
|
||||
}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("无效的邮箱格式", func(t *testing.T) {
|
||||
req := &model.CreateUserRequest{
|
||||
Username: "validuser",
|
||||
Email: "invalid-email",
|
||||
Password: "password123",
|
||||
}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("密码太短", func(t *testing.T) {
|
||||
req := &model.CreateUserRequest{
|
||||
Username: "validuser",
|
||||
Email: "valid@example.com",
|
||||
Password: "short", // 少于 8 个字符
|
||||
}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("缺少必填字段", func(t *testing.T) {
|
||||
req := &model.CreateUserRequest{
|
||||
Username: "validuser",
|
||||
// 缺少 Email 和 Password
|
||||
}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestUserUpdateValidation 测试用户更新验证
|
||||
func TestUserUpdateValidation(t *testing.T) {
|
||||
validate := validator.New()
|
||||
|
||||
t.Run("有效的更新请求", func(t *testing.T) {
|
||||
email := "newemail@example.com"
|
||||
status := constants.UserStatusActive
|
||||
req := &model.UpdateUserRequest{
|
||||
Email: &email,
|
||||
Status: &status,
|
||||
}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("无效的邮箱格式", func(t *testing.T) {
|
||||
email := "invalid-email"
|
||||
req := &model.UpdateUserRequest{
|
||||
Email: &email,
|
||||
}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("无效的状态值", func(t *testing.T) {
|
||||
status := "invalid_status"
|
||||
req := &model.UpdateUserRequest{
|
||||
Status: &status,
|
||||
}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("空更新请求", func(t *testing.T) {
|
||||
req := &model.UpdateUserRequest{}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.NoError(t, err) // 空更新请求应该是有效的
|
||||
})
|
||||
}
|
||||
|
||||
// TestOrderValidation 测试订单模型验证
|
||||
func TestOrderValidation(t *testing.T) {
|
||||
validate := validator.New()
|
||||
|
||||
t.Run("有效的创建订单请求", func(t *testing.T) {
|
||||
req := &model.CreateOrderRequest{
|
||||
OrderID: "ORD-2025-001",
|
||||
UserID: 1,
|
||||
Amount: 10000,
|
||||
Remark: "测试订单",
|
||||
}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("订单号太短", func(t *testing.T) {
|
||||
req := &model.CreateOrderRequest{
|
||||
OrderID: "ORD-123", // 少于 10 个字符
|
||||
UserID: 1,
|
||||
Amount: 10000,
|
||||
}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("订单号太长", func(t *testing.T) {
|
||||
req := &model.CreateOrderRequest{
|
||||
OrderID: "ORD-12345678901234567890123456789012345678901234567890", // 超过 50 个字符
|
||||
UserID: 1,
|
||||
Amount: 10000,
|
||||
}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("用户ID无效", func(t *testing.T) {
|
||||
req := &model.CreateOrderRequest{
|
||||
OrderID: "ORD-2025-001",
|
||||
UserID: 0, // 用户ID必须大于0
|
||||
Amount: 10000,
|
||||
}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("金额为负数", func(t *testing.T) {
|
||||
req := &model.CreateOrderRequest{
|
||||
OrderID: "ORD-2025-001",
|
||||
UserID: 1,
|
||||
Amount: -1000, // 金额不能为负数
|
||||
}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("缺少必填字段", func(t *testing.T) {
|
||||
req := &model.CreateOrderRequest{
|
||||
OrderID: "ORD-2025-001",
|
||||
// 缺少 UserID 和 Amount
|
||||
}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestOrderUpdateValidation 测试订单更新验证
|
||||
func TestOrderUpdateValidation(t *testing.T) {
|
||||
validate := validator.New()
|
||||
|
||||
t.Run("有效的更新请求", func(t *testing.T) {
|
||||
status := constants.OrderStatusPaid
|
||||
remark := "已支付"
|
||||
req := &model.UpdateOrderRequest{
|
||||
Status: &status,
|
||||
Remark: &remark,
|
||||
}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("无效的状态值", func(t *testing.T) {
|
||||
status := "invalid_status"
|
||||
req := &model.UpdateOrderRequest{
|
||||
Status: &status,
|
||||
}
|
||||
|
||||
err := validate.Struct(req)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestUserModel 测试用户模型
|
||||
func TestUserModel(t *testing.T) {
|
||||
t.Run("创建用户模型", func(t *testing.T) {
|
||||
user := &model.User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
|
||||
assert.Equal(t, "testuser", user.Username)
|
||||
assert.Equal(t, "test@example.com", user.Email)
|
||||
assert.Equal(t, constants.UserStatusActive, user.Status)
|
||||
})
|
||||
|
||||
t.Run("用户表名", func(t *testing.T) {
|
||||
user := &model.User{}
|
||||
assert.Equal(t, "tb_user", user.TableName())
|
||||
})
|
||||
|
||||
t.Run("软删除字段", func(t *testing.T) {
|
||||
user := &model.User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
|
||||
// DeletedAt 应该是 nil (未删除)
|
||||
assert.True(t, user.DeletedAt.Time.IsZero())
|
||||
})
|
||||
|
||||
t.Run("LastLoginAt 可选字段", func(t *testing.T) {
|
||||
user := &model.User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
|
||||
assert.Nil(t, user.LastLoginAt)
|
||||
|
||||
// 设置登录时间
|
||||
now := time.Now()
|
||||
user.LastLoginAt = &now
|
||||
assert.NotNil(t, user.LastLoginAt)
|
||||
assert.Equal(t, now, *user.LastLoginAt)
|
||||
})
|
||||
}
|
||||
|
||||
// TestOrderModel 测试订单模型
|
||||
func TestOrderModel(t *testing.T) {
|
||||
t.Run("创建订单模型", func(t *testing.T) {
|
||||
order := &model.Order{
|
||||
OrderID: "ORD-2025-001",
|
||||
UserID: 1,
|
||||
Amount: 10000,
|
||||
Status: constants.OrderStatusPending,
|
||||
Remark: "测试订单",
|
||||
}
|
||||
|
||||
assert.Equal(t, "ORD-2025-001", order.OrderID)
|
||||
assert.Equal(t, uint(1), order.UserID)
|
||||
assert.Equal(t, int64(10000), order.Amount)
|
||||
assert.Equal(t, constants.OrderStatusPending, order.Status)
|
||||
})
|
||||
|
||||
t.Run("订单表名", func(t *testing.T) {
|
||||
order := &model.Order{}
|
||||
assert.Equal(t, "tb_order", order.TableName())
|
||||
})
|
||||
|
||||
t.Run("可选时间字段", func(t *testing.T) {
|
||||
order := &model.Order{
|
||||
OrderID: "ORD-2025-001",
|
||||
UserID: 1,
|
||||
Amount: 10000,
|
||||
Status: constants.OrderStatusPending,
|
||||
}
|
||||
|
||||
assert.Nil(t, order.PaidAt)
|
||||
assert.Nil(t, order.CompletedAt)
|
||||
|
||||
// 设置支付时间
|
||||
now := time.Now()
|
||||
order.PaidAt = &now
|
||||
assert.NotNil(t, order.PaidAt)
|
||||
assert.Equal(t, now, *order.PaidAt)
|
||||
})
|
||||
}
|
||||
|
||||
// TestBaseModel 测试基础模型
|
||||
func TestBaseModel(t *testing.T) {
|
||||
t.Run("BaseModel 字段", func(t *testing.T) {
|
||||
user := &model.User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
|
||||
// ID 应该是 0 (未保存)
|
||||
assert.Zero(t, user.ID)
|
||||
|
||||
// 时间戳应该是零值
|
||||
assert.True(t, user.CreatedAt.IsZero())
|
||||
assert.True(t, user.UpdatedAt.IsZero())
|
||||
})
|
||||
}
|
||||
|
||||
// TestUserStatusConstants 测试用户状态常量
|
||||
func TestUserStatusConstants(t *testing.T) {
|
||||
t.Run("用户状态常量定义", func(t *testing.T) {
|
||||
assert.Equal(t, "active", constants.UserStatusActive)
|
||||
assert.Equal(t, "inactive", constants.UserStatusInactive)
|
||||
assert.Equal(t, "suspended", constants.UserStatusSuspended)
|
||||
})
|
||||
|
||||
t.Run("用户状态验证", func(t *testing.T) {
|
||||
validStatuses := []string{
|
||||
constants.UserStatusActive,
|
||||
constants.UserStatusInactive,
|
||||
constants.UserStatusSuspended,
|
||||
}
|
||||
|
||||
for _, status := range validStatuses {
|
||||
user := &model.User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: status,
|
||||
}
|
||||
assert.Contains(t, validStatuses, user.Status)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestOrderStatusConstants 测试订单状态常量
|
||||
func TestOrderStatusConstants(t *testing.T) {
|
||||
t.Run("订单状态常量定义", func(t *testing.T) {
|
||||
assert.Equal(t, "pending", constants.OrderStatusPending)
|
||||
assert.Equal(t, "paid", constants.OrderStatusPaid)
|
||||
assert.Equal(t, "processing", constants.OrderStatusProcessing)
|
||||
assert.Equal(t, "completed", constants.OrderStatusCompleted)
|
||||
assert.Equal(t, "cancelled", constants.OrderStatusCancelled)
|
||||
})
|
||||
|
||||
t.Run("订单状态流转", func(t *testing.T) {
|
||||
order := &model.Order{
|
||||
OrderID: "ORD-2025-001",
|
||||
UserID: 1,
|
||||
Amount: 10000,
|
||||
Status: constants.OrderStatusPending,
|
||||
}
|
||||
|
||||
// 订单状态流转:pending -> paid -> processing -> completed
|
||||
assert.Equal(t, constants.OrderStatusPending, order.Status)
|
||||
|
||||
order.Status = constants.OrderStatusPaid
|
||||
assert.Equal(t, constants.OrderStatusPaid, order.Status)
|
||||
|
||||
order.Status = constants.OrderStatusProcessing
|
||||
assert.Equal(t, constants.OrderStatusProcessing, order.Status)
|
||||
|
||||
order.Status = constants.OrderStatusCompleted
|
||||
assert.Equal(t, constants.OrderStatusCompleted, order.Status)
|
||||
})
|
||||
|
||||
t.Run("订单取消", func(t *testing.T) {
|
||||
order := &model.Order{
|
||||
OrderID: "ORD-2025-002",
|
||||
UserID: 1,
|
||||
Amount: 10000,
|
||||
Status: constants.OrderStatusPending,
|
||||
}
|
||||
|
||||
// 从任何状态都可以取消
|
||||
order.Status = constants.OrderStatusCancelled
|
||||
assert.Equal(t, constants.OrderStatusCancelled, order.Status)
|
||||
})
|
||||
}
|
||||
|
||||
// TestUserResponse 测试用户响应模型
|
||||
func TestUserResponse(t *testing.T) {
|
||||
t.Run("创建用户响应", func(t *testing.T) {
|
||||
now := time.Now()
|
||||
resp := &model.UserResponse{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Status: constants.UserStatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
assert.Equal(t, uint(1), resp.ID)
|
||||
assert.Equal(t, "testuser", resp.Username)
|
||||
assert.Equal(t, "test@example.com", resp.Email)
|
||||
assert.Equal(t, constants.UserStatusActive, resp.Status)
|
||||
})
|
||||
|
||||
t.Run("用户响应不包含密码", func(t *testing.T) {
|
||||
// UserResponse 结构体不应该包含 Password 字段
|
||||
resp := &model.UserResponse{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
|
||||
// 验证结构体大小合理 (不包含密码字段)
|
||||
assert.NotNil(t, resp)
|
||||
})
|
||||
}
|
||||
|
||||
// TestListResponse 测试列表响应模型
|
||||
func TestListResponse(t *testing.T) {
|
||||
t.Run("用户列表响应", func(t *testing.T) {
|
||||
users := []model.UserResponse{
|
||||
{ID: 1, Username: "user1", Email: "user1@example.com", Status: constants.UserStatusActive},
|
||||
{ID: 2, Username: "user2", Email: "user2@example.com", Status: constants.UserStatusActive},
|
||||
}
|
||||
|
||||
resp := &model.ListUsersResponse{
|
||||
Users: users,
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
Total: 100,
|
||||
TotalPages: 5,
|
||||
}
|
||||
|
||||
assert.Equal(t, 2, len(resp.Users))
|
||||
assert.Equal(t, 1, resp.Page)
|
||||
assert.Equal(t, 20, resp.PageSize)
|
||||
assert.Equal(t, int64(100), resp.Total)
|
||||
assert.Equal(t, 5, resp.TotalPages)
|
||||
})
|
||||
|
||||
t.Run("订单列表响应", func(t *testing.T) {
|
||||
orders := []model.OrderResponse{
|
||||
{ID: 1, OrderID: "ORD-001", UserID: 1, Amount: 10000, Status: constants.OrderStatusPending},
|
||||
{ID: 2, OrderID: "ORD-002", UserID: 1, Amount: 20000, Status: constants.OrderStatusPaid},
|
||||
}
|
||||
|
||||
resp := &model.ListOrdersResponse{
|
||||
Orders: orders,
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
Total: 50,
|
||||
TotalPages: 3,
|
||||
}
|
||||
|
||||
assert.Equal(t, 2, len(resp.Orders))
|
||||
assert.Equal(t, 1, resp.Page)
|
||||
assert.Equal(t, 20, resp.PageSize)
|
||||
assert.Equal(t, int64(50), resp.Total)
|
||||
assert.Equal(t, 3, resp.TotalPages)
|
||||
})
|
||||
}
|
||||
|
||||
// TestFieldTags 测试字段标签
|
||||
func TestFieldTags(t *testing.T) {
|
||||
t.Run("User GORM 标签", func(t *testing.T) {
|
||||
user := &model.User{}
|
||||
|
||||
// 验证 TableName 方法存在
|
||||
tableName := user.TableName()
|
||||
assert.Equal(t, "tb_user", tableName)
|
||||
})
|
||||
|
||||
t.Run("Order GORM 标签", func(t *testing.T) {
|
||||
order := &model.Order{}
|
||||
|
||||
// 验证 TableName 方法存在
|
||||
tableName := order.TableName()
|
||||
assert.Equal(t, "tb_order", tableName)
|
||||
})
|
||||
}
|
||||
@@ -1,550 +0,0 @@
|
||||
package unit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/model"
|
||||
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// setupTestStore 创建内存数据库用于单元测试
|
||||
func setupTestStore(t *testing.T) (*postgres.Store, func()) {
|
||||
// 使用 SQLite 内存数据库进行单元测试
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
require.NoError(t, err, "创建内存数据库失败")
|
||||
|
||||
// 自动迁移
|
||||
err = db.AutoMigrate(&model.User{}, &model.Order{})
|
||||
require.NoError(t, err, "数据库迁移失败")
|
||||
|
||||
// 创建测试 logger
|
||||
testLogger := zap.NewNop()
|
||||
store := postgres.NewStore(db, testLogger)
|
||||
|
||||
cleanup := func() {
|
||||
sqlDB, _ := db.DB()
|
||||
if sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return store, cleanup
|
||||
}
|
||||
|
||||
// TestUserStore 测试用户 Store 层
|
||||
func TestUserStore(t *testing.T) {
|
||||
store, cleanup := setupTestStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("创建用户成功", func(t *testing.T) {
|
||||
user := &model.User{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
|
||||
err := store.User.Create(ctx, user)
|
||||
assert.NoError(t, err)
|
||||
assert.NotZero(t, user.ID)
|
||||
assert.False(t, user.CreatedAt.IsZero())
|
||||
assert.False(t, user.UpdatedAt.IsZero())
|
||||
})
|
||||
|
||||
t.Run("创建重复用户名失败", func(t *testing.T) {
|
||||
user1 := &model.User{
|
||||
Username: "duplicate",
|
||||
Email: "user1@example.com",
|
||||
Password: "password1",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
err := store.User.Create(ctx, user1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 尝试创建相同用户名
|
||||
user2 := &model.User{
|
||||
Username: "duplicate",
|
||||
Email: "user2@example.com",
|
||||
Password: "password2",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
err = store.User.Create(ctx, user2)
|
||||
assert.Error(t, err, "应该返回唯一约束错误")
|
||||
})
|
||||
|
||||
t.Run("根据ID查询用户", func(t *testing.T) {
|
||||
user := &model.User{
|
||||
Username: "findbyid",
|
||||
Email: "findbyid@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
err := store.User.Create(ctx, user)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := store.User.GetByID(ctx, user.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, user.Username, found.Username)
|
||||
assert.Equal(t, user.Email, found.Email)
|
||||
})
|
||||
|
||||
t.Run("查询不存在的用户", func(t *testing.T) {
|
||||
_, err := store.User.GetByID(ctx, 99999)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, gorm.ErrRecordNotFound, err)
|
||||
})
|
||||
|
||||
t.Run("根据用户名查询用户", func(t *testing.T) {
|
||||
user := &model.User{
|
||||
Username: "findbyname",
|
||||
Email: "findbyname@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
err := store.User.Create(ctx, user)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := store.User.GetByUsername(ctx, "findbyname")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, user.ID, found.ID)
|
||||
})
|
||||
|
||||
t.Run("更新用户", func(t *testing.T) {
|
||||
user := &model.User{
|
||||
Username: "updatetest",
|
||||
Email: "update@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
err := store.User.Create(ctx, user)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 更新用户
|
||||
user.Email = "newemail@example.com"
|
||||
user.Status = constants.UserStatusInactive
|
||||
err = store.User.Update(ctx, user)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证更新
|
||||
found, err := store.User.GetByID(ctx, user.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "newemail@example.com", found.Email)
|
||||
assert.Equal(t, constants.UserStatusInactive, found.Status)
|
||||
})
|
||||
|
||||
t.Run("软删除用户", func(t *testing.T) {
|
||||
user := &model.User{
|
||||
Username: "deletetest",
|
||||
Email: "delete@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
err := store.User.Create(ctx, user)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 软删除
|
||||
err = store.User.Delete(ctx, user.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证已删除
|
||||
_, err = store.User.GetByID(ctx, user.ID)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, gorm.ErrRecordNotFound, err)
|
||||
})
|
||||
|
||||
t.Run("分页列表查询", func(t *testing.T) {
|
||||
// 创建10个用户
|
||||
for i := 1; i <= 10; i++ {
|
||||
user := &model.User{
|
||||
Username: "listuser" + string(rune('0'+i)),
|
||||
Email: "list" + string(rune('0'+i)) + "@example.com",
|
||||
Password: "password",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
err := store.User.Create(ctx, user)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// 第一页
|
||||
users, total, err := store.User.List(ctx, 1, 5)
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(users), 5)
|
||||
assert.GreaterOrEqual(t, total, int64(10))
|
||||
|
||||
// 第二页
|
||||
users2, total2, err := store.User.List(ctx, 2, 5)
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(users2), 5)
|
||||
assert.Equal(t, total, total2)
|
||||
|
||||
// 验证不同页的数据不同
|
||||
if len(users) > 0 && len(users2) > 0 {
|
||||
assert.NotEqual(t, users[0].ID, users2[0].ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestOrderStore 测试订单 Store 层
|
||||
func TestOrderStore(t *testing.T) {
|
||||
store, cleanup := setupTestStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建测试用户
|
||||
user := &model.User{
|
||||
Username: "orderuser",
|
||||
Email: "orderuser@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
err := store.User.Create(ctx, user)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("创建订单成功", func(t *testing.T) {
|
||||
order := &model.Order{
|
||||
OrderID: "ORD-TEST-001",
|
||||
UserID: user.ID,
|
||||
Amount: 10000,
|
||||
Status: constants.OrderStatusPending,
|
||||
Remark: "测试订单",
|
||||
}
|
||||
|
||||
err := store.Order.Create(ctx, order)
|
||||
assert.NoError(t, err)
|
||||
assert.NotZero(t, order.ID)
|
||||
assert.False(t, order.CreatedAt.IsZero())
|
||||
})
|
||||
|
||||
t.Run("创建重复订单号失败", func(t *testing.T) {
|
||||
order1 := &model.Order{
|
||||
OrderID: "ORD-DUP-001",
|
||||
UserID: user.ID,
|
||||
Amount: 10000,
|
||||
Status: constants.OrderStatusPending,
|
||||
}
|
||||
err := store.Order.Create(ctx, order1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 尝试创建相同订单号
|
||||
order2 := &model.Order{
|
||||
OrderID: "ORD-DUP-001",
|
||||
UserID: user.ID,
|
||||
Amount: 20000,
|
||||
Status: constants.OrderStatusPending,
|
||||
}
|
||||
err = store.Order.Create(ctx, order2)
|
||||
assert.Error(t, err, "应该返回唯一约束错误")
|
||||
})
|
||||
|
||||
t.Run("根据ID查询订单", func(t *testing.T) {
|
||||
order := &model.Order{
|
||||
OrderID: "ORD-FIND-001",
|
||||
UserID: user.ID,
|
||||
Amount: 20000,
|
||||
Status: constants.OrderStatusPending,
|
||||
}
|
||||
err := store.Order.Create(ctx, order)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := store.Order.GetByID(ctx, order.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, order.OrderID, found.OrderID)
|
||||
assert.Equal(t, order.Amount, found.Amount)
|
||||
})
|
||||
|
||||
t.Run("根据订单号查询", func(t *testing.T) {
|
||||
order := &model.Order{
|
||||
OrderID: "ORD-FIND-002",
|
||||
UserID: user.ID,
|
||||
Amount: 30000,
|
||||
Status: constants.OrderStatusPending,
|
||||
}
|
||||
err := store.Order.Create(ctx, order)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := store.Order.GetByOrderID(ctx, "ORD-FIND-002")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, order.ID, found.ID)
|
||||
})
|
||||
|
||||
t.Run("根据用户ID列表查询", func(t *testing.T) {
|
||||
// 创建多个订单
|
||||
for i := 1; i <= 5; i++ {
|
||||
order := &model.Order{
|
||||
OrderID: "ORD-LIST-" + string(rune('0'+i)),
|
||||
UserID: user.ID,
|
||||
Amount: int64(i * 10000),
|
||||
Status: constants.OrderStatusPending,
|
||||
}
|
||||
err := store.Order.Create(ctx, order)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
orders, total, err := store.Order.ListByUserID(ctx, user.ID, 1, 10)
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(orders), 5)
|
||||
assert.GreaterOrEqual(t, total, int64(5))
|
||||
})
|
||||
|
||||
t.Run("更新订单状态", func(t *testing.T) {
|
||||
order := &model.Order{
|
||||
OrderID: "ORD-UPDATE-001",
|
||||
UserID: user.ID,
|
||||
Amount: 50000,
|
||||
Status: constants.OrderStatusPending,
|
||||
}
|
||||
err := store.Order.Create(ctx, order)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 更新状态
|
||||
now := time.Now()
|
||||
order.Status = constants.OrderStatusPaid
|
||||
order.PaidAt = &now
|
||||
err = store.Order.Update(ctx, order)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证更新
|
||||
found, err := store.Order.GetByID(ctx, order.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, constants.OrderStatusPaid, found.Status)
|
||||
assert.NotNil(t, found.PaidAt)
|
||||
})
|
||||
|
||||
t.Run("软删除订单", func(t *testing.T) {
|
||||
order := &model.Order{
|
||||
OrderID: "ORD-DELETE-001",
|
||||
UserID: user.ID,
|
||||
Amount: 60000,
|
||||
Status: constants.OrderStatusPending,
|
||||
}
|
||||
err := store.Order.Create(ctx, order)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 软删除
|
||||
err = store.Order.Delete(ctx, order.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证已删除
|
||||
_, err = store.Order.GetByID(ctx, order.ID)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, gorm.ErrRecordNotFound, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestStoreTransaction 测试事务功能
|
||||
func TestStoreTransaction(t *testing.T) {
|
||||
store, cleanup := setupTestStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("事务提交成功", func(t *testing.T) {
|
||||
var userID uint
|
||||
var orderID uint
|
||||
|
||||
err := store.Transaction(ctx, func(tx *postgres.Store) error {
|
||||
// 创建用户
|
||||
user := &model.User{
|
||||
Username: "txuser1",
|
||||
Email: "txuser1@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
if err := tx.User.Create(ctx, user); err != nil {
|
||||
return err
|
||||
}
|
||||
userID = user.ID
|
||||
|
||||
// 创建订单
|
||||
order := &model.Order{
|
||||
OrderID: "ORD-TX-001",
|
||||
UserID: user.ID,
|
||||
Amount: 10000,
|
||||
Status: constants.OrderStatusPending,
|
||||
}
|
||||
if err := tx.Order.Create(ctx, order); err != nil {
|
||||
return err
|
||||
}
|
||||
orderID = order.ID
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证用户和订单都已创建
|
||||
user, err := store.User.GetByID(ctx, userID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "txuser1", user.Username)
|
||||
|
||||
order, err := store.Order.GetByID(ctx, orderID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "ORD-TX-001", order.OrderID)
|
||||
})
|
||||
|
||||
t.Run("事务回滚", func(t *testing.T) {
|
||||
var userID uint
|
||||
|
||||
err := store.Transaction(ctx, func(tx *postgres.Store) error {
|
||||
// 创建用户
|
||||
user := &model.User{
|
||||
Username: "rollbackuser",
|
||||
Email: "rollback@example.com",
|
||||
Password: "hashedpassword",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
if err := tx.User.Create(ctx, user); err != nil {
|
||||
return err
|
||||
}
|
||||
userID = user.ID
|
||||
|
||||
// 模拟错误,触发回滚
|
||||
return errors.New("模拟错误")
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "模拟错误", err.Error())
|
||||
|
||||
// 验证用户未创建(已回滚)
|
||||
_, err = store.User.GetByID(ctx, userID)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, gorm.ErrRecordNotFound, err)
|
||||
})
|
||||
|
||||
t.Run("嵌套事务回滚", func(t *testing.T) {
|
||||
var user1ID, user2ID uint
|
||||
|
||||
err := store.Transaction(ctx, func(tx1 *postgres.Store) error {
|
||||
// 外层事务:创建第一个用户
|
||||
user1 := &model.User{
|
||||
Username: "nested1",
|
||||
Email: "nested1@example.com",
|
||||
Password: "password",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
if err := tx1.User.Create(ctx, user1); err != nil {
|
||||
return err
|
||||
}
|
||||
user1ID = user1.ID
|
||||
|
||||
// 内层事务:创建第二个用户并失败
|
||||
err := tx1.Transaction(ctx, func(tx2 *postgres.Store) error {
|
||||
user2 := &model.User{
|
||||
Username: "nested2",
|
||||
Email: "nested2@example.com",
|
||||
Password: "password",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
if err := tx2.User.Create(ctx, user2); err != nil {
|
||||
return err
|
||||
}
|
||||
user2ID = user2.ID
|
||||
|
||||
// 内层事务失败
|
||||
return errors.New("内层事务失败")
|
||||
})
|
||||
|
||||
// 内层事务失败导致外层事务也失败
|
||||
return err
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
|
||||
// 验证两个用户都未创建
|
||||
_, err = store.User.GetByID(ctx, user1ID)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = store.User.GetByID(ctx, user2ID)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestConcurrentAccess 测试并发访问
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
store, cleanup := setupTestStore(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("并发创建用户", func(t *testing.T) {
|
||||
concurrency := 20
|
||||
errChan := make(chan error, concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func(index int) {
|
||||
user := &model.User{
|
||||
Username: "concurrent" + string(rune('A'+index)),
|
||||
Email: "concurrent" + string(rune('A'+index)) + "@example.com",
|
||||
Password: "password",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
errChan <- store.User.Create(ctx, user)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 收集结果
|
||||
successCount := 0
|
||||
for i := 0; i < concurrency; i++ {
|
||||
err := <-errChan
|
||||
if err == nil {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, concurrency, successCount, "所有并发创建应该成功")
|
||||
})
|
||||
|
||||
t.Run("并发读写同一用户", func(t *testing.T) {
|
||||
// 创建测试用户
|
||||
user := &model.User{
|
||||
Username: "rwuser",
|
||||
Email: "rwuser@example.com",
|
||||
Password: "password",
|
||||
Status: constants.UserStatusActive,
|
||||
}
|
||||
err := store.User.Create(ctx, user)
|
||||
require.NoError(t, err)
|
||||
|
||||
concurrency := 10
|
||||
done := make(chan bool, concurrency*2)
|
||||
|
||||
// 并发读
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func() {
|
||||
_, err := store.User.GetByID(ctx, user.ID)
|
||||
assert.NoError(t, err)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// 并发写
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func(index int) {
|
||||
user.Status = constants.UserStatusActive
|
||||
err := store.User.Update(ctx, user)
|
||||
assert.NoError(t, err)
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 等待所有操作完成
|
||||
for i := 0; i < concurrency*2; i++ {
|
||||
<-done
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user