实现角色权限体系重构

本次提交完成了角色权限体系的重构,主要包括:

1. 数据库迁移
   - 添加 tb_permission.platform 字段(all/web/h5)
   - 更新 tb_role.role_type 注释(1=平台角色,2=客户角色)

2. GORM 模型更新
   - Permission 模型添加 Platform 字段
   - Role 模型更新 RoleType 注释

3. 常量定义
   - 新增角色类型常量(RoleTypePlatform, RoleTypeCustomer)
   - 新增权限端口常量(PlatformAll, PlatformWeb, PlatformH5)
   - 添加角色类型与用户类型匹配规则函数

4. Store 层实现
   - Permission Store 支持按 platform 过滤
   - Account Role Store 添加 CountByAccountID 方法

5. Service 层实现
   - 角色分配支持类型匹配校验
   - 角色分配支持数量限制(超级管理员0个,平台用户无限制,代理/企业1个)
   - Permission Service 支持 platform 过滤

6. 权限校验中间件
   - 实现 RequirePermission、RequireAnyPermission、RequireAllPermissions
   - 支持 platform 字段过滤
   - 支持跳过超级管理员检查

7. 测试用例
   - 角色类型匹配规则单元测试
   - 角色分配数量限制单元测试
   - 权限 platform 过滤单元测试
   - 权限校验中间件集成测试(占位)

8. 代码清理
   - 删除过时的 subordinate 测试文件
   - 移除 Account.ParentID 相关引用
   - 更新 DTO 验证规则

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-10 09:51:52 +08:00
parent a36e4a79c0
commit 1b9080e3ab
31 changed files with 1767 additions and 607 deletions

View File

@@ -97,7 +97,7 @@ func TestAccountRoleAssociation_AssignRoles(t *testing.T) {
// 创建测试角色
role := &model.Role{
RoleName: "单角色测试",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
db.Create(role)
@@ -127,7 +127,7 @@ func TestAccountRoleAssociation_AssignRoles(t *testing.T) {
for i := 0; i < 3; i++ {
roles[i] = &model.Role{
RoleName: "多角色测试_" + string(rune('A'+i)),
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
db.Create(roles[i])
@@ -154,7 +154,7 @@ func TestAccountRoleAssociation_AssignRoles(t *testing.T) {
// 创建并分配角色
role := &model.Role{
RoleName: "获取角色列表测试",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
db.Create(role)
@@ -183,7 +183,7 @@ func TestAccountRoleAssociation_AssignRoles(t *testing.T) {
// 创建并分配角色
role := &model.Role{
RoleName: "移除角色测试",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
db.Create(role)
@@ -216,7 +216,7 @@ func TestAccountRoleAssociation_AssignRoles(t *testing.T) {
// 创建测试角色
role := &model.Role{
RoleName: "重复分配测试",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
db.Create(role)
@@ -238,7 +238,7 @@ func TestAccountRoleAssociation_AssignRoles(t *testing.T) {
t.Run("账号不存在时分配角色失败", func(t *testing.T) {
role := &model.Role{
RoleName: "账号不存在测试",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
db.Create(role)
@@ -322,7 +322,7 @@ func TestAccountRoleAssociation_SoftDelete(t *testing.T) {
role := &model.Role{
RoleName: "恢复角色测试",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
db.Create(role)

View File

@@ -187,7 +187,6 @@ func TestAccountAPI_Create(t *testing.T) {
Phone: "13800000001",
Password: "Password123",
UserType: constants.UserTypePlatform,
ParentID: &rootAccount.ID,
}
jsonBody, _ := json.Marshal(reqBody)
@@ -216,7 +215,6 @@ func TestAccountAPI_Create(t *testing.T) {
Phone: "13800000002",
Password: "hashedpassword",
UserType: constants.UserTypePlatform,
ParentID: &rootAccount.ID,
Status: constants.StatusEnabled,
}
createTestAccount(t, env.db, existingAccount)
@@ -227,7 +225,6 @@ func TestAccountAPI_Create(t *testing.T) {
Phone: "13800000003",
Password: "Password123",
UserType: constants.UserTypePlatform,
ParentID: &rootAccount.ID,
}
jsonBody, _ := json.Marshal(reqBody)
@@ -476,7 +473,7 @@ func TestAccountAPI_AssignRoles(t *testing.T) {
// 创建测试角色
testRole := &model.Role{
RoleName: "测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
env.db.Create(testRole)
@@ -527,7 +524,7 @@ func TestAccountAPI_GetRoles(t *testing.T) {
// 创建并分配角色
testRole := &model.Role{
RoleName: "获取角色测试",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
env.db.Create(testRole)
@@ -580,7 +577,7 @@ func TestAccountAPI_RemoveRole(t *testing.T) {
// 创建并分配角色
testRole := &model.Role{
RoleName: "移除角色测试",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
env.db.Create(testRole)

View File

@@ -230,7 +230,7 @@ func TestAPIRegression_RouteModularization(t *testing.T) {
// 创建测试数据
role := &model.Role{
RoleName: "回归测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
env.db.Create(role)

View File

@@ -0,0 +1,130 @@
package integration
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/break/junhong_cmp_fiber/pkg/constants"
)
// MockPermissionChecker 模拟权限检查器
type MockPermissionChecker struct {
permissions map[uint]map[string]bool // userID -> permCode -> hasPermission
}
func NewMockPermissionChecker() *MockPermissionChecker {
return &MockPermissionChecker{
permissions: make(map[uint]map[string]bool),
}
}
func (m *MockPermissionChecker) GrantPermission(userID uint, permCode string) {
if m.permissions[userID] == nil {
m.permissions[userID] = make(map[string]bool)
}
m.permissions[userID][permCode] = true
}
func (m *MockPermissionChecker) CheckPermission(ctx context.Context, userID uint, permCode string, platform string) (bool, error) {
if m.permissions[userID] == nil {
return false, nil
}
return m.permissions[userID][permCode], nil
}
// TestPermissionMiddleware_RequirePermission 测试权限校验中间件(单个权限)
// TODO: 完整实现需要启动 Fiber 应用并模拟 HTTP 请求
func TestPermissionMiddleware_RequirePermission(t *testing.T) {
t.Skip("TODO: 需要完整的 Fiber 集成测试环境")
// 占位测试:验证 PermissionChecker 接口可以被 mock
checker := NewMockPermissionChecker()
checker.GrantPermission(1, "user:read")
ctx := context.Background()
hasPermission, err := checker.CheckPermission(ctx, 1, "user:read", constants.PlatformAll)
assert.NoError(t, err)
assert.True(t, hasPermission)
hasPermission, err = checker.CheckPermission(ctx, 1, "user:write", constants.PlatformAll)
assert.NoError(t, err)
assert.False(t, hasPermission)
}
// TestPermissionMiddleware_RequireAnyPermission 测试权限校验中间件(多个权限任一)
func TestPermissionMiddleware_RequireAnyPermission(t *testing.T) {
t.Skip("TODO: 需要完整的 Fiber 集成测试环境")
}
// TestPermissionMiddleware_RequireAllPermissions 测试权限校验中间件(多个权限全部)
func TestPermissionMiddleware_RequireAllPermissions(t *testing.T) {
t.Skip("TODO: 需要完整的 Fiber 集成测试环境")
}
// TestPermissionMiddleware_SkipSuperAdmin 测试超级管理员跳过权限检查
func TestPermissionMiddleware_SkipSuperAdmin(t *testing.T) {
t.Skip("TODO: 需要完整的 Fiber 集成测试环境")
}
// TestPermissionMiddleware_PlatformFiltering 测试按 platform 过滤权限
func TestPermissionMiddleware_PlatformFiltering(t *testing.T) {
t.Skip("TODO: 需要完整的 Fiber 集成测试环境")
// 测试场景:
// 1. Web 端请求需要 Web 权限
// 2. H5 端请求需要 H5 权限
// 3. all 权限在所有端口都有效
}
// TestPermissionMiddleware_Unauthorized 测试未认证用户访问受保护路由
func TestPermissionMiddleware_Unauthorized(t *testing.T) {
t.Skip("TODO: 需要完整的 Fiber 集成测试环境")
}
// 集成测试实现指南:
//
// 完整的集成测试应该:
// 1. 启动 Fiber 应用
// 2. 注册受权限保护的路由:
// - 使用 middleware.RequirePermission("user:read", config)
// - 使用 middleware.RequireAnyPermission([]string{"user:read", "user:write"}, config)
// - 使用 middleware.RequireAllPermissions([]string{"user:read", "user:write"}, config)
// 3. 模拟不同用户的 HTTP 请求
// 4. 验证权限检查结果200 OK 或 403 Forbidden
//
// 示例代码结构:
//
// func TestPermissionMiddleware_Integration(t *testing.T) {
// // 1. 初始化数据库和 Redis
// db, redisClient := testutils.SetupTestDB(t)
// defer testutils.TeardownTestDB(t, db, redisClient)
//
// // 2. 创建测试数据(用户、角色、权限)
// // ...
//
// // 3. 初始化 Service 和 Middleware
// permissionService := permission.New(permissionStore)
// config := middleware.PermissionConfig{
// PermissionChecker: permissionService,
// Platform: constants.PlatformWeb,
// SkipSuperAdmin: true,
// }
//
// // 4. 创建 Fiber 应用并注册路由
// app := fiber.New()
// app.Get("/protected",
// middleware.RequirePermission("user:read", config),
// func(c *fiber.Ctx) error {
// return c.JSON(fiber.Map{"message": "success"})
// },
// )
//
// // 5. 模拟请求并验证响应
// req := httptest.NewRequest("GET", "/protected", nil)
// // 设置认证信息...
// resp, err := app.Test(req)
// require.NoError(t, err)
// assert.Equal(t, fiber.StatusOK, resp.StatusCode)
// }

View File

@@ -70,7 +70,7 @@ func TestRolePermissionAssociation_AssignPermissions(t *testing.T) {
// 创建测试角色
role := &model.Role{
RoleName: "单权限测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
db.Create(role)
@@ -96,7 +96,7 @@ func TestRolePermissionAssociation_AssignPermissions(t *testing.T) {
// 创建测试角色
role := &model.Role{
RoleName: "多权限测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
db.Create(role)
@@ -124,7 +124,7 @@ func TestRolePermissionAssociation_AssignPermissions(t *testing.T) {
// 创建测试角色
role := &model.Role{
RoleName: "获取权限列表测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
db.Create(role)
@@ -152,7 +152,7 @@ func TestRolePermissionAssociation_AssignPermissions(t *testing.T) {
// 创建测试角色
role := &model.Role{
RoleName: "移除权限测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
db.Create(role)
@@ -184,7 +184,7 @@ func TestRolePermissionAssociation_AssignPermissions(t *testing.T) {
// 创建测试角色
role := &model.Role{
RoleName: "重复权限测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
db.Create(role)
@@ -228,7 +228,7 @@ func TestRolePermissionAssociation_AssignPermissions(t *testing.T) {
t.Run("权限不存在时分配失败", func(t *testing.T) {
role := &model.Role{
RoleName: "权限不存在测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
db.Create(role)
@@ -276,7 +276,7 @@ func TestRolePermissionAssociation_SoftDelete(t *testing.T) {
// 创建测试数据
role := &model.Role{
RoleName: "恢复权限测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
db.Create(role)
@@ -312,7 +312,7 @@ func TestRolePermissionAssociation_SoftDelete(t *testing.T) {
// 创建测试角色
role := &model.Role{
RoleName: "批量权限测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
db.Create(role)
@@ -383,7 +383,7 @@ func TestRolePermissionAssociation_Cascade(t *testing.T) {
// 创建角色和权限
role := &model.Role{
RoleName: "级联测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
db.Create(role)

View File

@@ -167,7 +167,7 @@ func TestRoleAPI_Create(t *testing.T) {
reqBody := model.CreateRoleRequest{
RoleName: "测试角色",
RoleDesc: "这是一个测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
}
jsonBody, _ := json.Marshal(reqBody)
@@ -224,7 +224,7 @@ func TestRoleAPI_Get(t *testing.T) {
// 创建测试角色
testRole := &model.Role{
RoleName: "获取测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
env.db.Create(testRole)
@@ -269,7 +269,7 @@ func TestRoleAPI_Update(t *testing.T) {
// 创建测试角色
testRole := &model.Role{
RoleName: "更新测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
env.db.Create(testRole)
@@ -312,7 +312,7 @@ func TestRoleAPI_Delete(t *testing.T) {
// 创建测试角色
testRole := &model.Role{
RoleName: "删除测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
env.db.Create(testRole)
@@ -347,7 +347,7 @@ func TestRoleAPI_List(t *testing.T) {
for i := 1; i <= 5; i++ {
role := &model.Role{
RoleName: fmt.Sprintf("列表测试角色_%d", i),
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
env.db.Create(role)
@@ -382,7 +382,7 @@ func TestRoleAPI_AssignPermissions(t *testing.T) {
// 创建测试角色
testRole := &model.Role{
RoleName: "权限分配测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
env.db.Create(testRole)
@@ -432,7 +432,7 @@ func TestRoleAPI_GetPermissions(t *testing.T) {
// 创建测试角色
testRole := &model.Role{
RoleName: "获取权限测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
env.db.Create(testRole)
@@ -482,7 +482,7 @@ func TestRoleAPI_RemovePermission(t *testing.T) {
// 创建测试角色
testRole := &model.Role{
RoleName: "移除权限测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
env.db.Create(testRole)

View File

@@ -37,32 +37,7 @@ func TestAccountModel_Create(t *testing.T) {
assert.NotZero(t, account.UpdatedAt)
})
t.Run("创建带 parent_id 的账号", func(t *testing.T) {
// 先创建父账号
parent := &model.Account{
Username: "parent_user",
Phone: "13800000002",
Password: "hashed_password",
UserType: constants.UserTypePlatform,
Status: constants.StatusEnabled,
}
err := store.Create(ctx, parent)
require.NoError(t, err)
// 创建子账号
child := &model.Account{
Username: "child_user",
Phone: "13800000003",
Password: "hashed_password",
UserType: constants.UserTypeAgent,
ParentID: &parent.ID,
Status: constants.StatusEnabled,
}
err = store.Create(ctx, child)
require.NoError(t, err)
assert.NotZero(t, child.ID)
assert.Equal(t, parent.ID, *child.ParentID)
})
// 注意parent_id 字段已被移除,层级关系通过 shop_id 和 enterprise_id 维护
t.Run("创建带 shop_id 的账号", func(t *testing.T) {
shopID := uint(100)

View File

@@ -0,0 +1,209 @@
package unit
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/service/permission"
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/break/junhong_cmp_fiber/tests/testutils"
)
// TestPermissionPlatformFilter_List 测试权限列表按 platform 过滤
func TestPermissionPlatformFilter_List(t *testing.T) {
db, redisClient := testutils.SetupTestDB(t)
defer testutils.TeardownTestDB(t, db, redisClient)
permissionStore := postgres.NewPermissionStore(db)
service := permission.New(permissionStore)
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0)
// 创建不同 platform 的权限
permissions := []*model.Permission{
{PermName: "全端菜单", PermCode: "menu:all", PermType: constants.PermissionTypeMenu, Platform: constants.PlatformAll, Status: constants.StatusEnabled},
{PermName: "Web菜单", PermCode: "menu:web", PermType: constants.PermissionTypeMenu, Platform: constants.PlatformWeb, Status: constants.StatusEnabled},
{PermName: "H5菜单", PermCode: "menu:h5", PermType: constants.PermissionTypeMenu, Platform: constants.PlatformH5, Status: constants.StatusEnabled},
{PermName: "Web按钮", PermCode: "button:web", PermType: constants.PermissionTypeButton, Platform: constants.PlatformWeb, Status: constants.StatusEnabled},
{PermName: "H5按钮", PermCode: "button:h5", PermType: constants.PermissionTypeButton, Platform: constants.PlatformH5, Status: constants.StatusEnabled},
}
for _, perm := range permissions {
require.NoError(t, db.Create(perm).Error)
}
// 测试查询全部权限(不过滤)
t.Run("查询全部权限", func(t *testing.T) {
req := &model.PermissionListRequest{
Page: 1,
PageSize: 10,
}
perms, total, err := service.List(ctx, req)
require.NoError(t, err)
assert.Equal(t, int64(5), total)
assert.Len(t, perms, 5)
})
// 测试只查询 all 权限
t.Run("只查询all端口权限", func(t *testing.T) {
req := &model.PermissionListRequest{
Page: 1,
PageSize: 10,
Platform: constants.PlatformAll,
}
perms, total, err := service.List(ctx, req)
require.NoError(t, err)
assert.Equal(t, int64(1), total)
assert.Len(t, perms, 1)
assert.Equal(t, "全端菜单", perms[0].PermName)
})
// 测试只查询 web 权限
t.Run("只查询web端口权限", func(t *testing.T) {
req := &model.PermissionListRequest{
Page: 1,
PageSize: 10,
Platform: constants.PlatformWeb,
}
perms, total, err := service.List(ctx, req)
require.NoError(t, err)
assert.Equal(t, int64(2), total)
assert.Len(t, perms, 2)
// 验证都是 web 端口的权限
for _, perm := range perms {
assert.Equal(t, constants.PlatformWeb, perm.Platform)
}
})
// 测试只查询 h5 权限
t.Run("只查询h5端口权限", func(t *testing.T) {
req := &model.PermissionListRequest{
Page: 1,
PageSize: 10,
Platform: constants.PlatformH5,
}
perms, total, err := service.List(ctx, req)
require.NoError(t, err)
assert.Equal(t, int64(2), total)
assert.Len(t, perms, 2)
// 验证都是 h5 端口的权限
for _, perm := range perms {
assert.Equal(t, constants.PlatformH5, perm.Platform)
}
})
}
// TestPermissionPlatformFilter_CreateWithDefaultPlatform 测试创建权限时默认 platform 为 all
func TestPermissionPlatformFilter_CreateWithDefaultPlatform(t *testing.T) {
db, redisClient := testutils.SetupTestDB(t)
defer testutils.TeardownTestDB(t, db, redisClient)
permissionStore := postgres.NewPermissionStore(db)
service := permission.New(permissionStore)
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0)
// 创建权限时不指定 platform
req := &model.CreatePermissionRequest{
PermName: "测试权限",
PermCode: "test:permission",
PermType: constants.PermissionTypeMenu,
// Platform 字段为空
}
perm, err := service.Create(ctx, req)
require.NoError(t, err)
assert.Equal(t, constants.PlatformAll, perm.Platform, "未指定 platform 时应默认为 all")
}
// TestPermissionPlatformFilter_CreateWithSpecificPlatform 测试创建权限时指定 platform
func TestPermissionPlatformFilter_CreateWithSpecificPlatform(t *testing.T) {
db, redisClient := testutils.SetupTestDB(t)
defer testutils.TeardownTestDB(t, db, redisClient)
permissionStore := postgres.NewPermissionStore(db)
service := permission.New(permissionStore)
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0)
tests := []struct {
name string
platform string
expected string
}{
{name: "指定为all", platform: constants.PlatformAll, expected: constants.PlatformAll},
{name: "指定为web", platform: constants.PlatformWeb, expected: constants.PlatformWeb},
{name: "指定为h5", platform: constants.PlatformH5, expected: constants.PlatformH5},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := &model.CreatePermissionRequest{
PermName: "测试权限_" + tt.platform,
PermCode: "test:" + tt.platform,
PermType: constants.PermissionTypeMenu,
Platform: tt.platform,
}
perm, err := service.Create(ctx, req)
require.NoError(t, err)
assert.Equal(t, tt.expected, perm.Platform)
})
}
}
// TestPermissionPlatformFilter_Tree 测试权限树包含 platform 字段
func TestPermissionPlatformFilter_Tree(t *testing.T) {
db, redisClient := testutils.SetupTestDB(t)
defer testutils.TeardownTestDB(t, db, redisClient)
permissionStore := postgres.NewPermissionStore(db)
service := permission.New(permissionStore)
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0)
// 创建层级权限
parent := &model.Permission{
PermName: "系统管理",
PermCode: "system:manage",
PermType: constants.PermissionTypeMenu,
Platform: constants.PlatformWeb,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(parent).Error)
child := &model.Permission{
PermName: "用户管理",
PermCode: "user:manage",
PermType: constants.PermissionTypeMenu,
Platform: constants.PlatformWeb,
ParentID: &parent.ID,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(child).Error)
// 获取权限树
tree, err := service.GetTree(ctx)
require.NoError(t, err)
require.Len(t, tree, 1)
// 验证父节点
root := tree[0]
assert.Equal(t, "系统管理", root.PermName)
assert.Equal(t, constants.PlatformWeb, root.Platform)
// 验证子节点
require.Len(t, root.Children, 1)
childNode := root.Children[0]
assert.Equal(t, "用户管理", childNode.PermName)
assert.Equal(t, constants.PlatformWeb, childNode.Platform)
}

View File

@@ -0,0 +1,179 @@
package unit
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/service/account"
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/break/junhong_cmp_fiber/tests/testutils"
)
// TestRoleAssignmentLimit_PlatformUser 测试平台用户可以分配多个角色(无限制)
func TestRoleAssignmentLimit_PlatformUser(t *testing.T) {
db, redisClient := testutils.SetupTestDB(t)
defer testutils.TeardownTestDB(t, db, redisClient)
accountStore := postgres.NewAccountStore(db, redisClient)
roleStore := postgres.NewRoleStore(db)
accountRoleStore := postgres.NewAccountRoleStore(db)
service := account.New(accountStore, roleStore, accountRoleStore)
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0)
// 创建平台用户
platformUser := &model.Account{
Username: "platform_user",
Phone: "13800000001",
Password: "hashedpassword",
UserType: constants.UserTypePlatform,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(platformUser).Error)
// 创建 3 个平台角色
roles := []*model.Role{
{RoleName: "运营", RoleType: constants.RoleTypePlatform, Status: constants.StatusEnabled},
{RoleName: "客服", RoleType: constants.RoleTypePlatform, Status: constants.StatusEnabled},
{RoleName: "财务", RoleType: constants.RoleTypePlatform, Status: constants.StatusEnabled},
}
for _, role := range roles {
require.NoError(t, db.Create(role).Error)
}
// 为平台用户分配 3 个角色(应该成功,因为平台用户无限制)
roleIDs := []uint{roles[0].ID, roles[1].ID, roles[2].ID}
ars, err := service.AssignRoles(ctx, platformUser.ID, roleIDs)
require.NoError(t, err)
assert.Len(t, ars, 3)
}
// TestRoleAssignmentLimit_AgentUser 测试代理账号只能分配一个角色
func TestRoleAssignmentLimit_AgentUser(t *testing.T) {
db, redisClient := testutils.SetupTestDB(t)
defer testutils.TeardownTestDB(t, db, redisClient)
accountStore := postgres.NewAccountStore(db, redisClient)
roleStore := postgres.NewRoleStore(db)
accountRoleStore := postgres.NewAccountRoleStore(db)
service := account.New(accountStore, roleStore, accountRoleStore)
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0)
// 创建代理账号
agentAccount := &model.Account{
Username: "agent_user",
Phone: "13800000002",
Password: "hashedpassword",
UserType: constants.UserTypeAgent,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(agentAccount).Error)
// 创建 2 个客户角色
roles := []*model.Role{
{RoleName: "一级代理", RoleType: constants.RoleTypeCustomer, Status: constants.StatusEnabled},
{RoleName: "二级代理", RoleType: constants.RoleTypeCustomer, Status: constants.StatusEnabled},
}
for _, role := range roles {
require.NoError(t, db.Create(role).Error)
}
// 先分配第一个角色(应该成功)
ars, err := service.AssignRoles(ctx, agentAccount.ID, []uint{roles[0].ID})
require.NoError(t, err)
assert.Len(t, ars, 1)
// 尝试分配第二个角色(应该失败,超过数量限制)
_, err = service.AssignRoles(ctx, agentAccount.ID, []uint{roles[1].ID})
require.Error(t, err)
assert.Contains(t, err.Error(), "最多只能分配 1 个角色")
}
// TestRoleAssignmentLimit_EnterpriseUser 测试企业账号只能分配一个角色
func TestRoleAssignmentLimit_EnterpriseUser(t *testing.T) {
db, redisClient := testutils.SetupTestDB(t)
defer testutils.TeardownTestDB(t, db, redisClient)
accountStore := postgres.NewAccountStore(db, redisClient)
roleStore := postgres.NewRoleStore(db)
accountRoleStore := postgres.NewAccountRoleStore(db)
service := account.New(accountStore, roleStore, accountRoleStore)
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0)
// 创建企业账号
enterpriseAccount := &model.Account{
Username: "enterprise_user",
Phone: "13800000003",
Password: "hashedpassword",
UserType: constants.UserTypeEnterprise,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(enterpriseAccount).Error)
// 创建 2 个客户角色
roles := []*model.Role{
{RoleName: "企业普通", RoleType: constants.RoleTypeCustomer, Status: constants.StatusEnabled},
{RoleName: "企业高级", RoleType: constants.RoleTypeCustomer, Status: constants.StatusEnabled},
}
for _, role := range roles {
require.NoError(t, db.Create(role).Error)
}
// 先分配第一个角色(应该成功)
ars, err := service.AssignRoles(ctx, enterpriseAccount.ID, []uint{roles[0].ID})
require.NoError(t, err)
assert.Len(t, ars, 1)
// 尝试分配第二个角色(应该失败,超过数量限制)
_, err = service.AssignRoles(ctx, enterpriseAccount.ID, []uint{roles[1].ID})
require.Error(t, err)
assert.Contains(t, err.Error(), "最多只能分配 1 个角色")
}
// TestRoleAssignmentLimit_SuperAdmin 测试超级管理员不允许分配角色
func TestRoleAssignmentLimit_SuperAdmin(t *testing.T) {
db, redisClient := testutils.SetupTestDB(t)
defer testutils.TeardownTestDB(t, db, redisClient)
accountStore := postgres.NewAccountStore(db, redisClient)
roleStore := postgres.NewRoleStore(db)
accountRoleStore := postgres.NewAccountRoleStore(db)
service := account.New(accountStore, roleStore, accountRoleStore)
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0)
// 创建超级管理员
superAdmin := &model.Account{
Username: "superadmin",
Phone: "13800000004",
Password: "hashedpassword",
UserType: constants.UserTypeSuperAdmin,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(superAdmin).Error)
// 创建一个平台角色
role := &model.Role{
RoleName: "测试角色",
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(role).Error)
// 尝试为超级管理员分配角色(应该失败)
_, err := service.AssignRoles(ctx, superAdmin.ID, []uint{role.ID})
require.Error(t, err)
assert.Contains(t, err.Error(), "不需要分配角色")
}

View File

@@ -0,0 +1,111 @@
package unit
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/break/junhong_cmp_fiber/pkg/constants"
)
// TestIsRoleTypeMatchUserType 测试角色类型与用户类型匹配规则
func TestIsRoleTypeMatchUserType(t *testing.T) {
tests := []struct {
name string
roleType int
userType int
expected bool
}{
{
name: "超级管理员不需要角色",
roleType: constants.RoleTypePlatform,
userType: constants.UserTypeSuperAdmin,
expected: false,
},
{
name: "平台用户匹配平台角色",
roleType: constants.RoleTypePlatform,
userType: constants.UserTypePlatform,
expected: true,
},
{
name: "平台用户不匹配客户角色",
roleType: constants.RoleTypeCustomer,
userType: constants.UserTypePlatform,
expected: false,
},
{
name: "代理账号匹配客户角色",
roleType: constants.RoleTypeCustomer,
userType: constants.UserTypeAgent,
expected: true,
},
{
name: "代理账号不匹配平台角色",
roleType: constants.RoleTypePlatform,
userType: constants.UserTypeAgent,
expected: false,
},
{
name: "企业账号匹配客户角色",
roleType: constants.RoleTypeCustomer,
userType: constants.UserTypeEnterprise,
expected: true,
},
{
name: "企业账号不匹配平台角色",
roleType: constants.RoleTypePlatform,
userType: constants.UserTypeEnterprise,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := constants.IsRoleTypeMatchUserType(tt.roleType, tt.userType)
assert.Equal(t, tt.expected, result)
})
}
}
// TestGetMaxRolesForUserType 测试用户类型的最大角色数量限制
func TestGetMaxRolesForUserType(t *testing.T) {
tests := []struct {
name string
userType int
expected int
}{
{
name: "超级管理员不需要角色",
userType: constants.UserTypeSuperAdmin,
expected: 0,
},
{
name: "平台用户无角色数量限制",
userType: constants.UserTypePlatform,
expected: -1, // -1 表示无限制
},
{
name: "代理账号最多一个角色",
userType: constants.UserTypeAgent,
expected: 1,
},
{
name: "企业账号最多一个角色",
userType: constants.UserTypeEnterprise,
expected: 1,
},
{
name: "未知用户类型不允许角色",
userType: 999,
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := constants.GetMaxRolesForUserType(tt.userType)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -78,7 +78,7 @@ func TestRoleSoftDelete(t *testing.T) {
role := &model.Role{
RoleName: "test_role",
RoleDesc: "测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
err := roleStore.Create(ctx, role)
@@ -169,7 +169,7 @@ func TestAccountRoleSoftDelete(t *testing.T) {
role := &model.Role{
RoleName: "ar_role",
RoleDesc: "测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
err = roleStore.Create(ctx, role)
@@ -228,7 +228,7 @@ func TestRolePermissionSoftDelete(t *testing.T) {
role := &model.Role{
RoleName: "rp_role",
RoleDesc: "测试角色",
RoleType: constants.RoleTypeSuper,
RoleType: constants.RoleTypePlatform,
Status: constants.StatusEnabled,
}
err := roleStore.Create(ctx, role)

View File

@@ -1,276 +0,0 @@
package unit
import (
"context"
"testing"
"time"
"github.com/bytedance/sonic"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/tests/testutils"
)
// TestGetSubordinateIDs_CacheHit 测试 Redis 缓存命中
func TestGetSubordinateIDs_CacheHit(t *testing.T) {
db, redisClient := testutils.SetupTestDB(t)
defer testutils.TeardownTestDB(t, db, redisClient)
store := postgres.NewAccountStore(db, redisClient)
ctx := context.Background()
// 创建测试账号
accountA := &model.Account{
Username: "user_a",
Phone: "13800000001",
Password: "hashed_password",
UserType: constants.UserTypePlatform,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountA).Error)
accountB := &model.Account{
Username: "user_b",
Phone: "13800000002",
Password: "hashed_password",
UserType: constants.UserTypeAgent,
ParentID: &accountA.ID,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountB).Error)
// 第一次查询(缓存未命中,会写入缓存)
ids1, err := store.GetSubordinateIDs(ctx, accountA.ID)
require.NoError(t, err)
assert.Len(t, ids1, 2)
// 验证缓存已写入
cacheKey := constants.RedisAccountSubordinatesKey(accountA.ID)
cached, err := redisClient.Get(ctx, cacheKey).Result()
require.NoError(t, err)
var cachedIDs []uint
require.NoError(t, sonic.Unmarshal([]byte(cached), &cachedIDs))
assert.Equal(t, ids1, cachedIDs)
// 第二次查询(缓存命中,不查询数据库)
ids2, err := store.GetSubordinateIDs(ctx, accountA.ID)
require.NoError(t, err)
assert.Equal(t, ids1, ids2)
}
// TestGetSubordinateIDs_CacheExpiry 测试缓存过期
func TestGetSubordinateIDs_CacheExpiry(t *testing.T) {
if testing.Short() {
t.Skip("跳过缓存过期测试")
}
db, redisClient := testutils.SetupTestDB(t)
defer testutils.TeardownTestDB(t, db, redisClient)
store := postgres.NewAccountStore(db, redisClient)
ctx := context.Background()
// 创建测试账号
accountA := &model.Account{
Username: "user_a",
Phone: "13800000001",
Password: "hashed_password",
UserType: constants.UserTypePlatform,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountA).Error)
// 第一次查询(写入缓存)
ids1, err := store.GetSubordinateIDs(ctx, accountA.ID)
require.NoError(t, err)
// 验证缓存 TTL应该是 30 分钟)
cacheKey := constants.RedisAccountSubordinatesKey(accountA.ID)
ttl, err := redisClient.TTL(ctx, cacheKey).Result()
require.NoError(t, err)
assert.Greater(t, ttl, 29*time.Minute)
assert.LessOrEqual(t, ttl, 30*time.Minute)
// 模拟缓存过期(手动删除)
require.NoError(t, redisClient.Del(ctx, cacheKey).Err())
// 再次查询(缓存未命中,重新查询数据库)
ids2, err := store.GetSubordinateIDs(ctx, accountA.ID)
require.NoError(t, err)
assert.Equal(t, ids1, ids2)
// 验证缓存已重新写入
cached, err := redisClient.Get(ctx, cacheKey).Result()
require.NoError(t, err)
var cachedIDs []uint
require.NoError(t, sonic.Unmarshal([]byte(cached), &cachedIDs))
assert.Equal(t, ids2, cachedIDs)
}
// TestClearSubordinatesCache 测试清除指定账号的缓存
func TestClearSubordinatesCache(t *testing.T) {
db, redisClient := testutils.SetupTestDB(t)
defer testutils.TeardownTestDB(t, db, redisClient)
store := postgres.NewAccountStore(db, redisClient)
ctx := context.Background()
// 创建测试账号
accountA := &model.Account{
Username: "user_a",
Phone: "13800000001",
Password: "hashed_password",
UserType: constants.UserTypePlatform,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountA).Error)
// 查询以写入缓存
_, err := store.GetSubordinateIDs(ctx, accountA.ID)
require.NoError(t, err)
// 验证缓存存在
cacheKey := constants.RedisAccountSubordinatesKey(accountA.ID)
exists, err := redisClient.Exists(ctx, cacheKey).Result()
require.NoError(t, err)
assert.Equal(t, int64(1), exists)
// 清除缓存
err = store.ClearSubordinatesCache(ctx, accountA.ID)
require.NoError(t, err)
// 验证缓存已删除
exists, err = redisClient.Exists(ctx, cacheKey).Result()
require.NoError(t, err)
assert.Equal(t, int64(0), exists)
}
// TestClearSubordinatesCacheForParents 测试递归清除上级缓存
func TestClearSubordinatesCacheForParents(t *testing.T) {
db, redisClient := testutils.SetupTestDB(t)
defer testutils.TeardownTestDB(t, db, redisClient)
store := postgres.NewAccountStore(db, redisClient)
ctx := context.Background()
// 创建层级结构: A -> B -> C
accountA := &model.Account{
Username: "user_a",
Phone: "13800000001",
Password: "hashed_password",
UserType: constants.UserTypePlatform,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountA).Error)
accountB := &model.Account{
Username: "user_b",
Phone: "13800000002",
Password: "hashed_password",
UserType: constants.UserTypeAgent,
ParentID: &accountA.ID,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountB).Error)
accountC := &model.Account{
Username: "user_c",
Phone: "13800000003",
Password: "hashed_password",
UserType: constants.UserTypeAgent,
ParentID: &accountB.ID,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountC).Error)
// 查询所有账号以写入缓存
_, err := store.GetSubordinateIDs(ctx, accountA.ID)
require.NoError(t, err)
_, err = store.GetSubordinateIDs(ctx, accountB.ID)
require.NoError(t, err)
_, err = store.GetSubordinateIDs(ctx, accountC.ID)
require.NoError(t, err)
// 验证所有缓存存在
cacheKeyA := constants.RedisAccountSubordinatesKey(accountA.ID)
cacheKeyB := constants.RedisAccountSubordinatesKey(accountB.ID)
cacheKeyC := constants.RedisAccountSubordinatesKey(accountC.ID)
exists, _ := redisClient.Exists(ctx, cacheKeyA).Result()
assert.Equal(t, int64(1), exists)
exists, _ = redisClient.Exists(ctx, cacheKeyB).Result()
assert.Equal(t, int64(1), exists)
exists, _ = redisClient.Exists(ctx, cacheKeyC).Result()
assert.Equal(t, int64(1), exists)
// 清除 C 的缓存(应该递归清除 B 和 A 的缓存)
err = store.ClearSubordinatesCacheForParents(ctx, accountC.ID)
require.NoError(t, err)
// 验证所有上级缓存已删除
exists, _ = redisClient.Exists(ctx, cacheKeyA).Result()
assert.Equal(t, int64(0), exists, "A 的缓存应该被清除")
exists, _ = redisClient.Exists(ctx, cacheKeyB).Result()
assert.Equal(t, int64(0), exists, "B 的缓存应该被清除")
exists, _ = redisClient.Exists(ctx, cacheKeyC).Result()
assert.Equal(t, int64(0), exists, "C 的缓存应该被清除")
}
// TestCacheInvalidationOnCreate 测试创建账号时清除父账号缓存
func TestCacheInvalidationOnCreate(t *testing.T) {
db, redisClient := testutils.SetupTestDB(t)
defer testutils.TeardownTestDB(t, db, redisClient)
store := postgres.NewAccountStore(db, redisClient)
ctx := context.Background()
// 创建父账号
accountA := &model.Account{
Username: "user_a",
Phone: "13800000001",
Password: "hashed_password",
UserType: constants.UserTypePlatform,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountA).Error)
// 查询 A 的下级(只有自己),写入缓存
ids1, err := store.GetSubordinateIDs(ctx, accountA.ID)
require.NoError(t, err)
assert.Len(t, ids1, 1)
// 验证缓存存在
cacheKey := constants.RedisAccountSubordinatesKey(accountA.ID)
exists, _ := redisClient.Exists(ctx, cacheKey).Result()
assert.Equal(t, int64(1), exists)
// 创建子账号 B应该清除 A 的缓存)
accountB := &model.Account{
Username: "user_b",
Phone: "13800000002",
Password: "hashed_password",
UserType: constants.UserTypeAgent,
ParentID: &accountA.ID,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountB).Error)
// 注意:缓存清除逻辑在 Service 层,这里模拟清除
err = store.ClearSubordinatesCacheForParents(ctx, accountA.ID)
require.NoError(t, err)
// 验证缓存已清除
exists, _ = redisClient.Exists(ctx, cacheKey).Result()
assert.Equal(t, int64(0), exists, "创建子账号后,父账号的缓存应该被清除")
// 再次查询(缓存未命中,重新查询数据库,应该包含 B
ids2, err := store.GetSubordinateIDs(ctx, accountA.ID)
require.NoError(t, err)
assert.Len(t, ids2, 2, "应该包含 A 和 B")
assert.Contains(t, ids2, accountA.ID)
assert.Contains(t, ids2, accountB.ID)
}

View File

@@ -1,252 +0,0 @@
package unit
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/tests/testutils"
)
// TestGetSubordinateIDs_SingleLevel 测试单层下级查询
func TestGetSubordinateIDs_SingleLevel(t *testing.T) {
db, redisClient := testutils.SetupTestDB(t)
defer testutils.TeardownTestDB(t, db, redisClient)
store := postgres.NewAccountStore(db, redisClient)
ctx := context.Background()
// 创建层级结构: A -> B, C
accountA := &model.Account{
Username: "user_a",
Phone: "13800000001",
Password: "hashed_password",
UserType: constants.UserTypePlatform,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountA).Error)
accountB := &model.Account{
Username: "user_b",
Phone: "13800000002",
Password: "hashed_password",
UserType: constants.UserTypeAgent,
ParentID: &accountA.ID,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountB).Error)
accountC := &model.Account{
Username: "user_c",
Phone: "13800000003",
Password: "hashed_password",
UserType: constants.UserTypeAgent,
ParentID: &accountA.ID,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountC).Error)
// 查询 A 的所有下级(应该包含 A, B, C
ids, err := store.GetSubordinateIDs(ctx, accountA.ID)
require.NoError(t, err)
assert.Len(t, ids, 3)
assert.Contains(t, ids, accountA.ID)
assert.Contains(t, ids, accountB.ID)
assert.Contains(t, ids, accountC.ID)
}
// TestGetSubordinateIDs_MultiLevel 测试多层递归查询
func TestGetSubordinateIDs_MultiLevel(t *testing.T) {
db, redisClient := testutils.SetupTestDB(t)
defer testutils.TeardownTestDB(t, db, redisClient)
store := postgres.NewAccountStore(db, redisClient)
ctx := context.Background()
// 创建层级结构: A -> B -> C -> D -> E (5层)
accountA := &model.Account{
Username: "user_a",
Phone: "13800000001",
Password: "hashed_password",
UserType: constants.UserTypePlatform,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountA).Error)
accountB := &model.Account{
Username: "user_b",
Phone: "13800000002",
Password: "hashed_password",
UserType: constants.UserTypeAgent,
ParentID: &accountA.ID,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountB).Error)
accountC := &model.Account{
Username: "user_c",
Phone: "13800000003",
Password: "hashed_password",
UserType: constants.UserTypeAgent,
ParentID: &accountB.ID,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountC).Error)
accountD := &model.Account{
Username: "user_d",
Phone: "13800000004",
Password: "hashed_password",
UserType: constants.UserTypeEnterprise,
ParentID: &accountC.ID,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountD).Error)
accountE := &model.Account{
Username: "user_e",
Phone: "13800000005",
Password: "hashed_password",
UserType: constants.UserTypeEnterprise,
ParentID: &accountD.ID,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountE).Error)
// 查询 A 的所有下级(应该包含所有 5 个账号)
ids, err := store.GetSubordinateIDs(ctx, accountA.ID)
require.NoError(t, err)
assert.Len(t, ids, 5)
// 查询 B 的所有下级(应该包含 B, C, D, E
ids, err = store.GetSubordinateIDs(ctx, accountB.ID)
require.NoError(t, err)
assert.Len(t, ids, 4)
// 查询 E 的所有下级(只有自己)
ids, err = store.GetSubordinateIDs(ctx, accountE.ID)
require.NoError(t, err)
assert.Len(t, ids, 1)
assert.Equal(t, accountE.ID, ids[0])
}
// TestGetSubordinateIDs_WithSoftDeleted 测试包含软删除账号的递归查询
func TestGetSubordinateIDs_WithSoftDeleted(t *testing.T) {
db, redisClient := testutils.SetupTestDB(t)
defer testutils.TeardownTestDB(t, db, redisClient)
store := postgres.NewAccountStore(db, redisClient)
ctx := context.Background()
// 创建层级结构: A -> B -> C
accountA := &model.Account{
Username: "user_a",
Phone: "13800000001",
Password: "hashed_password",
UserType: constants.UserTypePlatform,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountA).Error)
accountB := &model.Account{
Username: "user_b",
Phone: "13800000002",
Password: "hashed_password",
UserType: constants.UserTypeAgent,
ParentID: &accountA.ID,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountB).Error)
accountC := &model.Account{
Username: "user_c",
Phone: "13800000003",
Password: "hashed_password",
UserType: constants.UserTypeAgent,
ParentID: &accountB.ID,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountC).Error)
// 软删除 B
require.NoError(t, db.Delete(accountB).Error)
// 查询 A 的所有下级(应该仍然包含 B 和 C因为递归查询包含软删除账号
ids, err := store.GetSubordinateIDs(ctx, accountA.ID)
require.NoError(t, err)
assert.Len(t, ids, 3)
assert.Contains(t, ids, accountB.ID)
assert.Contains(t, ids, accountC.ID)
}
// TestGetSubordinateIDs_Performance 测试递归查询性能
func TestGetSubordinateIDs_Performance(t *testing.T) {
if testing.Short() {
t.Skip("跳过性能测试")
}
db, redisClient := testutils.SetupTestDB(t)
defer testutils.TeardownTestDB(t, db, redisClient)
store := postgres.NewAccountStore(db, redisClient)
ctx := context.Background()
// 创建 5 层层级,每层 3 个分支(共 121 个账号)
// 层级 1: 1 个账号
accountA := &model.Account{
Username: "user_root",
Phone: "13800000000",
Password: "hashed_password",
UserType: constants.UserTypeSuperAdmin,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(accountA).Error)
// 层级 2: 3 个账号
var level2IDs []uint
for i := 1; i <= 3; i++ {
acc := &model.Account{
Username: testutils.GenerateUsername("level2", i),
Phone: testutils.GeneratePhone("138", i),
Password: "hashed_password",
UserType: constants.UserTypePlatform,
ParentID: &accountA.ID,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(acc).Error)
level2IDs = append(level2IDs, acc.ID)
}
// 层级 3: 9 个账号
var level3IDs []uint
for _, parentID := range level2IDs {
for i := 1; i <= 3; i++ {
acc := &model.Account{
Username: testutils.GenerateUsername("level3", int(parentID)*10+i),
Phone: testutils.GeneratePhone("139", int(parentID)*10+i),
Password: "hashed_password",
UserType: constants.UserTypeAgent,
ParentID: &parentID,
Status: constants.StatusEnabled,
}
require.NoError(t, db.Create(acc).Error)
level3IDs = append(level3IDs, acc.ID)
}
}
// 测试查询性能(应该 < 50ms
start := testutils.Now()
ids, err := store.GetSubordinateIDs(ctx, accountA.ID)
duration := testutils.Since(start)
require.NoError(t, err)
assert.GreaterOrEqual(t, len(ids), 13) // 至少包含 1 + 3 + 9 个账号
// 验证性能要求
assert.Less(t, duration.Milliseconds(), int64(50), "递归查询应在 50ms 内完成")
}