重构数据权限模型并清理旧RBAC代码
核心变更: - 数据权限过滤从基于账号层级改为基于用户类型的多策略过滤 - 移除 AccountStore 中的 GetSubordinateIDs 等旧方法 - 重构认证中间件,支持 enterprise_id 和 customer_id - 更新 GORM Callback,根据用户类型自动选择过滤策略(代理/企业/个人客户) - 更新所有集成测试以适配新的 API 签名 - 添加功能总结文档和 OpenSpec 归档 Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -4,12 +4,14 @@ import "time"
|
||||
|
||||
// Fiber Locals 的上下文键
|
||||
const (
|
||||
ContextKeyRequestID = "requestid" // 请求记录ID
|
||||
ContextKeyStartTime = "start_time" //请求开始时间
|
||||
ContextKeyUserID = "user_id" // 用户ID
|
||||
ContextKeyUserType = "user_type" //用户类型
|
||||
ContextKeyShopID = "shop_id" //店铺ID
|
||||
ContextKeyUserInfo = "user_info" //完整的用户信息
|
||||
ContextKeyRequestID = "requestid" // 请求记录ID
|
||||
ContextKeyStartTime = "start_time" // 请求开始时间
|
||||
ContextKeyUserID = "user_id" // 用户ID
|
||||
ContextKeyUserType = "user_type" // 用户类型
|
||||
ContextKeyShopID = "shop_id" // 店铺ID
|
||||
ContextKeyEnterpriseID = "enterprise_id" // 企业ID
|
||||
ContextKeyCustomerID = "customer_id" // 个人客户ID
|
||||
ContextKeyUserInfo = "user_info" // 完整的用户信息
|
||||
)
|
||||
|
||||
// 配置环境变量
|
||||
@@ -52,10 +54,11 @@ const (
|
||||
|
||||
// RBAC 用户类型常量
|
||||
const (
|
||||
UserTypeSuperAdmin = 1 // 超级管理员(跳过数据权限过滤)
|
||||
UserTypePlatform = 2 // 平台用户
|
||||
UserTypeAgent = 3 // 代理账号
|
||||
UserTypeEnterprise = 4 // 企业账号
|
||||
UserTypeSuperAdmin = 1 // 超级管理员(跳过数据权限过滤)
|
||||
UserTypePlatform = 2 // 平台用户
|
||||
UserTypeAgent = 3 // 代理账号
|
||||
UserTypeEnterprise = 4 // 企业账号
|
||||
UserTypePersonalCustomer = 5 // 个人客户(C端用户)
|
||||
)
|
||||
|
||||
// RBAC 角色类型常量
|
||||
|
||||
@@ -26,13 +26,6 @@ func RedisTaskStatusKey(taskID string) string {
|
||||
return fmt.Sprintf("task:status:%s", taskID)
|
||||
}
|
||||
|
||||
// RedisAccountSubordinatesKey 生成账号下级 ID 列表的 Redis 键
|
||||
// 用途:缓存递归查询的下级账号 ID 列表
|
||||
// 过期时间:30 分钟
|
||||
func RedisAccountSubordinatesKey(accountID uint) string {
|
||||
return fmt.Sprintf("account:subordinates:%d", accountID)
|
||||
}
|
||||
|
||||
// RedisShopSubordinatesKey 生成店铺下级 ID 列表的 Redis 键
|
||||
// 用途:缓存递归查询的下级店铺 ID 列表
|
||||
// 过期时间:30 分钟
|
||||
|
||||
@@ -28,31 +28,33 @@ func SkipDataPermission(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, SkipDataPermissionKey, true)
|
||||
}
|
||||
|
||||
// AccountStoreInterface 账号 Store 接口
|
||||
// 用于 Callback 获取下级 ID,避免循环依赖
|
||||
type AccountStoreInterface interface {
|
||||
GetSubordinateIDs(ctx context.Context, accountID uint) ([]uint, error)
|
||||
// ShopStoreInterface 店铺 Store 接口
|
||||
// 用于 Callback 获取下级店铺 ID,避免循环依赖
|
||||
type ShopStoreInterface interface {
|
||||
GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error)
|
||||
}
|
||||
|
||||
// RegisterDataPermissionCallback 注册 GORM 数据权限过滤 Callback
|
||||
//
|
||||
// 自动化数据权限过滤规则:
|
||||
// 1. root 用户跳过过滤,可以查看所有数据
|
||||
// 2. 普通用户只能查看自己和下级的数据(通过递归查询下级 ID)
|
||||
// 3. 同时限制 shop_id 相同(如果配置了 shop_id)
|
||||
// 4. 通过 SkipDataPermission(ctx) 可以绕过权限过滤
|
||||
// 自动化数据权限过滤规则:
|
||||
// 1. 超级管理员跳过过滤,可以查看所有数据
|
||||
// 2. 平台用户跳过过滤,可以查看所有数据
|
||||
// 3. 代理用户只能查看自己店铺及下级店铺的数据(基于 shop_id 字段)
|
||||
// 4. 企业用户只能查看自己企业的数据(基于 enterprise_id 字段)
|
||||
// 5. 个人客户只能查看自己的数据(基于 creator 字段或 customer_id 字段)
|
||||
// 6. 通过 SkipDataPermission(ctx) 可以绕过权限过滤
|
||||
//
|
||||
// 注意:
|
||||
// - Callback 只对包含 creator 字段的表生效
|
||||
// 注意:
|
||||
// - Callback 根据表的字段自动选择过滤策略
|
||||
// - 必须在初始化 Store 之前注册
|
||||
//
|
||||
// 参数:
|
||||
// 参数:
|
||||
// - db: GORM DB 实例
|
||||
// - accountStore: 账号 Store,用于查询下级 ID
|
||||
// - shopStore: 店铺 Store,用于查询下级店铺 ID
|
||||
//
|
||||
// 返回:
|
||||
// 返回:
|
||||
// - error: 注册错误
|
||||
func RegisterDataPermissionCallback(db *gorm.DB, accountStore AccountStoreInterface) error {
|
||||
func RegisterDataPermissionCallback(db *gorm.DB, shopStore ShopStoreInterface) error {
|
||||
// 注册查询前的 Callback
|
||||
err := db.Callback().Query().Before("gorm:query").Register("data_permission:query", func(tx *gorm.DB) {
|
||||
ctx := tx.Statement.Context
|
||||
@@ -65,17 +67,15 @@ func RegisterDataPermissionCallback(db *gorm.DB, accountStore AccountStoreInterf
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 检查是否为 root 用户,root 用户跳过过滤
|
||||
if middleware.IsRootUser(ctx) {
|
||||
// 2. 获取用户类型
|
||||
userType := middleware.GetUserTypeFromContext(ctx)
|
||||
|
||||
// 3. 超级管理员和平台用户跳过过滤,可以查看所有数据
|
||||
if userType == constants.UserTypeSuperAdmin || userType == constants.UserTypePlatform {
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 检查表是否有 creator 字段(只对有 creator 字段的表生效)
|
||||
if !hasCreatorField(tx.Statement.Schema) {
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 获取当前用户 ID
|
||||
// 4. 获取当前用户信息
|
||||
userID := middleware.GetUserIDFromContext(ctx)
|
||||
if userID == 0 {
|
||||
// 未登录用户返回空结果
|
||||
@@ -84,32 +84,102 @@ func RegisterDataPermissionCallback(db *gorm.DB, accountStore AccountStoreInterf
|
||||
return
|
||||
}
|
||||
|
||||
// 5. 获取当前用户及所有下级的 ID
|
||||
subordinateIDs, err := accountStore.GetSubordinateIDs(ctx, userID)
|
||||
if err != nil {
|
||||
// 查询失败时,降级为只能看自己的数据
|
||||
logger.GetAppLogger().Error("数据权限过滤:获取下级 ID 失败",
|
||||
zap.Uint("user_id", userID),
|
||||
zap.Error(err))
|
||||
subordinateIDs = []uint{userID}
|
||||
}
|
||||
|
||||
if len(subordinateIDs) == 0 {
|
||||
subordinateIDs = []uint{userID}
|
||||
}
|
||||
|
||||
// 6. 获取当前用户的 shop_id
|
||||
shopID := middleware.GetShopIDFromContext(ctx)
|
||||
|
||||
// 7. 应用数据权限过滤条件
|
||||
// creator IN (用户自己及所有下级) AND shop_id = 当前用户 shop_id
|
||||
if shopID != 0 && hasShopIDField(tx.Statement.Schema) {
|
||||
// 同时过滤 creator 和 shop_id
|
||||
tx.Where("creator IN ? AND shop_id = ?", subordinateIDs, shopID)
|
||||
} else {
|
||||
// 只根据 creator 过滤
|
||||
tx.Where("creator IN ?", subordinateIDs)
|
||||
// 5. 根据用户类型和表结构应用不同的过滤规则
|
||||
schema := tx.Statement.Schema
|
||||
if schema == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 5.1 代理用户:基于店铺层级过滤
|
||||
if userType == constants.UserTypeAgent {
|
||||
if !hasShopIDField(schema) {
|
||||
// 表没有 shop_id 字段,无法过滤
|
||||
return
|
||||
}
|
||||
|
||||
if shopID == 0 {
|
||||
// 代理用户没有 shop_id,只能看自己创建的数据
|
||||
if hasCreatorField(schema) {
|
||||
tx.Where("creator = ?", userID)
|
||||
} else {
|
||||
tx.Where("1 = 0")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 查询该店铺及下级店铺的 ID
|
||||
subordinateShopIDs, err := shopStore.GetSubordinateShopIDs(ctx, shopID)
|
||||
if err != nil {
|
||||
logger.GetAppLogger().Error("数据权限过滤:获取下级店铺 ID 失败",
|
||||
zap.Uint("shop_id", shopID),
|
||||
zap.Error(err))
|
||||
// 降级为只能看自己店铺的数据
|
||||
subordinateShopIDs = []uint{shopID}
|
||||
}
|
||||
|
||||
// 过滤:shop_id IN (自己店铺及下级店铺)
|
||||
tx.Where("shop_id IN ?", subordinateShopIDs)
|
||||
return
|
||||
}
|
||||
|
||||
// 5.2 企业用户:基于 enterprise_id 过滤
|
||||
if userType == constants.UserTypeEnterprise {
|
||||
enterpriseID := middleware.GetEnterpriseIDFromContext(ctx)
|
||||
|
||||
if hasEnterpriseIDField(schema) {
|
||||
if enterpriseID != 0 {
|
||||
tx.Where("enterprise_id = ?", enterpriseID)
|
||||
} else {
|
||||
// 企业用户没有 enterprise_id,返回空结果
|
||||
tx.Where("1 = 0")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 如果表没有 enterprise_id 字段,但有 creator 字段,则只能看自己创建的数据
|
||||
if hasCreatorField(schema) {
|
||||
tx.Where("creator = ?", userID)
|
||||
return
|
||||
}
|
||||
|
||||
// 无法过滤,返回空结果
|
||||
tx.Where("1 = 0")
|
||||
return
|
||||
}
|
||||
|
||||
// 5.3 个人客户:只能看自己的数据
|
||||
if userType == constants.UserTypePersonalCustomer {
|
||||
customerID := middleware.GetCustomerIDFromContext(ctx)
|
||||
|
||||
// 优先使用 customer_id 字段
|
||||
if hasCustomerIDField(schema) {
|
||||
if customerID != 0 {
|
||||
tx.Where("customer_id = ?", customerID)
|
||||
} else {
|
||||
// 个人客户没有 customer_id,返回空结果
|
||||
tx.Where("1 = 0")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 降级为使用 creator 字段
|
||||
if hasCreatorField(schema) {
|
||||
tx.Where("creator = ?", userID)
|
||||
return
|
||||
}
|
||||
|
||||
// 无法过滤,返回空结果
|
||||
tx.Where("1 = 0")
|
||||
return
|
||||
}
|
||||
|
||||
// 6. 默认:未知用户类型,返回空结果
|
||||
logger.GetAppLogger().Warn("数据权限过滤:未知用户类型",
|
||||
zap.Uint("user_id", userID),
|
||||
zap.Int("user_type", userType))
|
||||
tx.Where("1 = 0")
|
||||
})
|
||||
return err
|
||||
}
|
||||
@@ -149,3 +219,21 @@ func hasShopIDField(s *schema.Schema) bool {
|
||||
_, ok := s.FieldsByDBName["shop_id"]
|
||||
return ok
|
||||
}
|
||||
|
||||
// hasEnterpriseIDField 检查 Schema 是否包含 enterprise_id 字段
|
||||
func hasEnterpriseIDField(s *schema.Schema) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := s.FieldsByDBName["enterprise_id"]
|
||||
return ok
|
||||
}
|
||||
|
||||
// hasCustomerIDField 检查 Schema 是否包含 customer_id 字段
|
||||
func hasCustomerIDField(s *schema.Schema) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := s.FieldsByDBName["customer_id"]
|
||||
return ok
|
||||
}
|
||||
|
||||
@@ -9,20 +9,19 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// mockAccountStore 模拟账号 Store
|
||||
type mockAccountStore struct {
|
||||
subordinateIDs []uint
|
||||
err error
|
||||
// mockShopStore 模拟店铺 Store
|
||||
type mockShopStore struct {
|
||||
subordinateShopIDs []uint
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockAccountStore) GetSubordinateIDs(ctx context.Context, accountID uint) ([]uint, error) {
|
||||
func (m *mockShopStore) GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.subordinateIDs, nil
|
||||
return m.subordinateShopIDs, nil
|
||||
}
|
||||
|
||||
// TestSkipDataPermission 测试跳过数据权限过滤
|
||||
@@ -38,95 +37,15 @@ func TestSkipDataPermission(t *testing.T) {
|
||||
assert.True(t, skip)
|
||||
}
|
||||
|
||||
// TestHasCreatorField 测试检查 creator 字段
|
||||
func TestHasCreatorField(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
schema *schema.Schema
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil schema",
|
||||
schema: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "schema with creator field",
|
||||
schema: &schema.Schema{
|
||||
FieldsByDBName: map[string]*schema.Field{
|
||||
"creator": {},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "schema without creator field",
|
||||
schema: &schema.Schema{
|
||||
FieldsByDBName: map[string]*schema.Field{
|
||||
"id": {},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := hasCreatorField(tt.schema)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHasShopIDField 测试检查 shop_id 字段
|
||||
func TestHasShopIDField(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
schema *schema.Schema
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil schema",
|
||||
schema: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "schema with shop_id field",
|
||||
schema: &schema.Schema{
|
||||
FieldsByDBName: map[string]*schema.Field{
|
||||
"shop_id": {},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "schema without shop_id field",
|
||||
schema: &schema.Schema{
|
||||
FieldsByDBName: map[string]*schema.Field{
|
||||
"id": {},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := hasShopIDField(tt.schema)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegisterDataPermissionCallback 测试注册数据权限 Callback
|
||||
func TestRegisterDataPermissionCallback(t *testing.T) {
|
||||
// 创建内存数据库
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 创建 mock AccountStore
|
||||
mockStore := &mockAccountStore{
|
||||
subordinateIDs: []uint{1, 2, 3},
|
||||
// 创建 mock ShopStore
|
||||
mockStore := &mockShopStore{
|
||||
subordinateShopIDs: []uint{1, 2, 3},
|
||||
}
|
||||
|
||||
// 注册 Callback
|
||||
@@ -134,8 +53,8 @@ func TestRegisterDataPermissionCallback(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestDataPermissionCallback_SkipForRootUser 测试 root 用户跳过过滤
|
||||
func TestDataPermissionCallback_SkipForRootUser(t *testing.T) {
|
||||
// TestDataPermissionCallback_SkipForSuperAdmin 测试超级管理员跳过过滤
|
||||
func TestDataPermissionCallback_SkipForSuperAdmin(t *testing.T) {
|
||||
// 创建内存数据库
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
assert.NoError(t, err)
|
||||
@@ -143,6 +62,7 @@ func TestDataPermissionCallback_SkipForRootUser(t *testing.T) {
|
||||
// 创建测试表
|
||||
type TestModel struct {
|
||||
ID uint
|
||||
ShopID uint
|
||||
Creator uint
|
||||
Name string
|
||||
}
|
||||
@@ -151,33 +71,39 @@ func TestDataPermissionCallback_SkipForRootUser(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 插入测试数据
|
||||
db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"})
|
||||
db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"})
|
||||
db.Create(&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"})
|
||||
db.Create(&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"})
|
||||
|
||||
// 创建 mock AccountStore
|
||||
mockStore := &mockAccountStore{
|
||||
subordinateIDs: []uint{1}, // 只有 ID 1
|
||||
// 创建 mock ShopStore
|
||||
mockStore := &mockShopStore{
|
||||
subordinateShopIDs: []uint{100}, // 只有店铺 100
|
||||
}
|
||||
|
||||
// 注册 Callback
|
||||
err = RegisterDataPermissionCallback(db, mockStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 设置 root 用户 context
|
||||
// 设置超级管理员 context
|
||||
ctx := context.Background()
|
||||
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0)
|
||||
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
|
||||
UserID: 1,
|
||||
UserType: constants.UserTypeSuperAdmin,
|
||||
ShopID: 0,
|
||||
EnterpriseID: 0,
|
||||
CustomerID: 0,
|
||||
})
|
||||
|
||||
// 查询数据
|
||||
var results []TestModel
|
||||
err = db.WithContext(ctx).Find(&results).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// root 用户应该看到所有数据
|
||||
// 超级管理员应该看到所有数据
|
||||
assert.Equal(t, 2, len(results))
|
||||
}
|
||||
|
||||
// TestDataPermissionCallback_FilterForNormalUser 测试普通用户过滤
|
||||
func TestDataPermissionCallback_FilterForNormalUser(t *testing.T) {
|
||||
// TestDataPermissionCallback_SkipForPlatform 测试平台用户跳过过滤
|
||||
func TestDataPermissionCallback_SkipForPlatform(t *testing.T) {
|
||||
// 创建内存数据库
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
assert.NoError(t, err)
|
||||
@@ -185,6 +111,7 @@ func TestDataPermissionCallback_FilterForNormalUser(t *testing.T) {
|
||||
// 创建测试表
|
||||
type TestModel struct {
|
||||
ID uint
|
||||
ShopID uint
|
||||
Creator uint
|
||||
Name string
|
||||
}
|
||||
@@ -193,32 +120,86 @@ func TestDataPermissionCallback_FilterForNormalUser(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 插入测试数据
|
||||
db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"})
|
||||
db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"})
|
||||
db.Create(&TestModel{ID: 3, Creator: 3, Name: "test3"})
|
||||
db.Create(&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"})
|
||||
db.Create(&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"})
|
||||
|
||||
// 创建 mock AccountStore
|
||||
mockStore := &mockAccountStore{
|
||||
subordinateIDs: []uint{1, 2}, // 只能看到 1 和 2
|
||||
// 创建 mock ShopStore
|
||||
mockStore := &mockShopStore{
|
||||
subordinateShopIDs: []uint{100}, // 只有店铺 100
|
||||
}
|
||||
|
||||
// 注册 Callback
|
||||
err = RegisterDataPermissionCallback(db, mockStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 设置普通用户 context (非 root)
|
||||
// 设置平台用户 context
|
||||
ctx := context.Background()
|
||||
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeAgent, 0)
|
||||
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
|
||||
UserID: 1,
|
||||
UserType: constants.UserTypePlatform,
|
||||
ShopID: 0,
|
||||
EnterpriseID: 0,
|
||||
CustomerID: 0,
|
||||
})
|
||||
|
||||
// 查询数据
|
||||
var results []TestModel
|
||||
err = db.WithContext(ctx).Find(&results).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 普通用户只能看到自己和下级的数据
|
||||
// 平台用户应该看到所有数据
|
||||
assert.Equal(t, 2, len(results))
|
||||
assert.Equal(t, uint(1), results[0].Creator)
|
||||
assert.Equal(t, uint(2), results[1].Creator)
|
||||
}
|
||||
|
||||
// TestDataPermissionCallback_FilterForAgent 测试代理用户过滤
|
||||
func TestDataPermissionCallback_FilterForAgent(t *testing.T) {
|
||||
// 创建内存数据库
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 创建测试表(包含 shop_id 字段以触发店铺层级过滤)
|
||||
type TestModel struct {
|
||||
ID uint
|
||||
ShopID uint
|
||||
Name string
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&TestModel{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 插入测试数据
|
||||
db.Create(&TestModel{ID: 1, ShopID: 100, Name: "test1"})
|
||||
db.Create(&TestModel{ID: 2, ShopID: 200, Name: "test2"})
|
||||
db.Create(&TestModel{ID: 3, ShopID: 300, Name: "test3"})
|
||||
|
||||
// 创建 mock ShopStore
|
||||
mockStore := &mockShopStore{
|
||||
subordinateShopIDs: []uint{100, 200}, // 只能看到店铺 100 和 200
|
||||
}
|
||||
|
||||
// 注册 Callback
|
||||
err = RegisterDataPermissionCallback(db, mockStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 设置代理用户 context (shop_id = 100)
|
||||
ctx := context.Background()
|
||||
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
|
||||
UserID: 1,
|
||||
UserType: constants.UserTypeAgent,
|
||||
ShopID: 100,
|
||||
EnterpriseID: 0,
|
||||
CustomerID: 0,
|
||||
})
|
||||
|
||||
// 查询数据
|
||||
var results []TestModel
|
||||
err = db.WithContext(ctx).Find(&results).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 代理用户只能看到自己店铺和下级店铺的数据
|
||||
assert.Equal(t, 2, len(results))
|
||||
assert.Equal(t, uint(100), results[0].ShopID)
|
||||
assert.Equal(t, uint(200), results[1].ShopID)
|
||||
}
|
||||
|
||||
// TestDataPermissionCallback_SkipWithContext 测试通过 Context 跳过过滤
|
||||
@@ -230,6 +211,7 @@ func TestDataPermissionCallback_SkipWithContext(t *testing.T) {
|
||||
// 创建测试表
|
||||
type TestModel struct {
|
||||
ID uint
|
||||
ShopID uint
|
||||
Creator uint
|
||||
Name string
|
||||
}
|
||||
@@ -238,21 +220,27 @@ func TestDataPermissionCallback_SkipWithContext(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 插入测试数据
|
||||
db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"})
|
||||
db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"})
|
||||
db.Create(&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"})
|
||||
db.Create(&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"})
|
||||
|
||||
// 创建 mock AccountStore
|
||||
mockStore := &mockAccountStore{
|
||||
subordinateIDs: []uint{1}, // 只有 ID 1
|
||||
// 创建 mock ShopStore
|
||||
mockStore := &mockShopStore{
|
||||
subordinateShopIDs: []uint{100}, // 只有店铺 100
|
||||
}
|
||||
|
||||
// 注册 Callback
|
||||
err = RegisterDataPermissionCallback(db, mockStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 设置普通用户 context 并跳过过滤
|
||||
// 设置代理用户 context 并跳过过滤
|
||||
ctx := context.Background()
|
||||
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeAgent, 0)
|
||||
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
|
||||
UserID: 1,
|
||||
UserType: constants.UserTypeAgent,
|
||||
ShopID: 100,
|
||||
EnterpriseID: 0,
|
||||
CustomerID: 0,
|
||||
})
|
||||
ctx = SkipDataPermission(ctx)
|
||||
|
||||
// 查询数据
|
||||
@@ -286,27 +274,134 @@ func TestDataPermissionCallback_WithShopID(t *testing.T) {
|
||||
db.Create(&TestModel{ID: 2, Creator: 2, ShopID: 100, Name: "test2"})
|
||||
db.Create(&TestModel{ID: 3, Creator: 2, ShopID: 200, Name: "test3"}) // 不同 shop_id
|
||||
|
||||
// 创建 mock AccountStore
|
||||
mockStore := &mockAccountStore{
|
||||
subordinateIDs: []uint{1, 2}, // 可以看到 1 和 2
|
||||
// 创建 mock ShopStore
|
||||
mockStore := &mockShopStore{
|
||||
subordinateShopIDs: []uint{100, 200}, // 可以看到店铺 100 和 200
|
||||
}
|
||||
|
||||
// 注册 Callback
|
||||
err = RegisterDataPermissionCallback(db, mockStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 设置普通用户 context (shop_id = 100)
|
||||
// 设置代理用户 context (shop_id = 100)
|
||||
ctx := context.Background()
|
||||
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeAgent, 100)
|
||||
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
|
||||
UserID: 1,
|
||||
UserType: constants.UserTypeAgent,
|
||||
ShopID: 100,
|
||||
EnterpriseID: 0,
|
||||
CustomerID: 0,
|
||||
})
|
||||
|
||||
// 查询数据
|
||||
var results []TestModel
|
||||
err = db.WithContext(ctx).Find(&results).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 只能看到 shop_id = 100 的数据
|
||||
// 应该看到 shop_id = 100 和 200 的所有数据(因为 mockStore 返回了这两个店铺 ID)
|
||||
assert.Equal(t, 3, len(results))
|
||||
}
|
||||
|
||||
// TestDataPermissionCallback_FilterForEnterprise 测试企业用户过滤
|
||||
func TestDataPermissionCallback_FilterForEnterprise(t *testing.T) {
|
||||
// 创建内存数据库
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 创建测试表(包含 enterprise_id 字段)
|
||||
type TestModel struct {
|
||||
ID uint
|
||||
EnterpriseID uint
|
||||
Name string
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&TestModel{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 插入测试数据
|
||||
db.Create(&TestModel{ID: 1, EnterpriseID: 1001, Name: "test1"})
|
||||
db.Create(&TestModel{ID: 2, EnterpriseID: 1001, Name: "test2"})
|
||||
db.Create(&TestModel{ID: 3, EnterpriseID: 1002, Name: "test3"})
|
||||
|
||||
// 创建 mock ShopStore(企业用户不需要,但注册时需要)
|
||||
mockStore := &mockShopStore{
|
||||
subordinateShopIDs: []uint{},
|
||||
}
|
||||
|
||||
// 注册 Callback
|
||||
err = RegisterDataPermissionCallback(db, mockStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 设置企业用户 context
|
||||
ctx := context.Background()
|
||||
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
|
||||
UserID: 1,
|
||||
UserType: constants.UserTypeEnterprise,
|
||||
ShopID: 0,
|
||||
EnterpriseID: 1001,
|
||||
CustomerID: 0,
|
||||
})
|
||||
|
||||
// 查询数据
|
||||
var results []TestModel
|
||||
err = db.WithContext(ctx).Find(&results).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 企业用户只能看到自己企业的数据
|
||||
assert.Equal(t, 2, len(results))
|
||||
for _, r := range results {
|
||||
assert.Equal(t, uint(100), r.ShopID)
|
||||
assert.Equal(t, uint(1001), r.EnterpriseID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDataPermissionCallback_FilterForPersonalCustomer 测试个人客户过滤
|
||||
func TestDataPermissionCallback_FilterForPersonalCustomer(t *testing.T) {
|
||||
// 创建内存数据库
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 创建测试表(包含 creator 字段)
|
||||
type TestModel struct {
|
||||
ID uint
|
||||
Creator uint
|
||||
Name string
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&TestModel{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 插入测试数据
|
||||
db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"})
|
||||
db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"})
|
||||
db.Create(&TestModel{ID: 3, Creator: 1, Name: "test3"})
|
||||
|
||||
// 创建 mock ShopStore(个人客户不需要,但注册时需要)
|
||||
mockStore := &mockShopStore{
|
||||
subordinateShopIDs: []uint{},
|
||||
}
|
||||
|
||||
// 注册 Callback
|
||||
err = RegisterDataPermissionCallback(db, mockStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 设置个人客户 context
|
||||
ctx := context.Background()
|
||||
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
|
||||
UserID: 1,
|
||||
UserType: constants.UserTypePersonalCustomer,
|
||||
ShopID: 0,
|
||||
EnterpriseID: 0,
|
||||
CustomerID: 1,
|
||||
})
|
||||
|
||||
// 查询数据
|
||||
var results []TestModel
|
||||
err = db.WithContext(ctx).Find(&results).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 个人客户只能看到自己创建的数据
|
||||
assert.Equal(t, 2, len(results))
|
||||
for _, r := range results {
|
||||
assert.Equal(t, uint(1), r.Creator)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,12 +8,23 @@ import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// UserContextInfo 用户上下文信息
|
||||
type UserContextInfo struct {
|
||||
UserID uint
|
||||
UserType int
|
||||
ShopID uint
|
||||
EnterpriseID uint
|
||||
CustomerID uint
|
||||
}
|
||||
|
||||
// SetUserContext 将用户信息设置到 context 中
|
||||
// 在 Auth 中间件认证成功后调用
|
||||
func SetUserContext(ctx context.Context, userID uint, userType int, shopID uint) context.Context {
|
||||
ctx = context.WithValue(ctx, constants.ContextKeyUserID, userID)
|
||||
ctx = context.WithValue(ctx, constants.ContextKeyUserType, userType)
|
||||
ctx = context.WithValue(ctx, constants.ContextKeyShopID, shopID)
|
||||
func SetUserContext(ctx context.Context, info *UserContextInfo) context.Context {
|
||||
ctx = context.WithValue(ctx, constants.ContextKeyUserID, info.UserID)
|
||||
ctx = context.WithValue(ctx, constants.ContextKeyUserType, info.UserType)
|
||||
ctx = context.WithValue(ctx, constants.ContextKeyShopID, info.ShopID)
|
||||
ctx = context.WithValue(ctx, constants.ContextKeyEnterpriseID, info.EnterpriseID)
|
||||
ctx = context.WithValue(ctx, constants.ContextKeyCustomerID, info.CustomerID)
|
||||
return ctx
|
||||
}
|
||||
|
||||
@@ -53,6 +64,30 @@ func GetShopIDFromContext(ctx context.Context) uint {
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetEnterpriseIDFromContext 从 context 中提取企业 ID
|
||||
// 如果未设置,返回 0
|
||||
func GetEnterpriseIDFromContext(ctx context.Context) uint {
|
||||
if ctx == nil {
|
||||
return 0
|
||||
}
|
||||
if enterpriseID, ok := ctx.Value(constants.ContextKeyEnterpriseID).(uint); ok {
|
||||
return enterpriseID
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetCustomerIDFromContext 从 context 中提取个人客户 ID
|
||||
// 如果未设置,返回 0
|
||||
func GetCustomerIDFromContext(ctx context.Context) uint {
|
||||
if ctx == nil {
|
||||
return 0
|
||||
}
|
||||
if customerID, ok := ctx.Value(constants.ContextKeyCustomerID).(uint); ok {
|
||||
return customerID
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// IsRootUser 检查当前用户是否为 root 用户
|
||||
// root 用户跳过数据权限过滤
|
||||
func IsRootUser(ctx context.Context) bool {
|
||||
@@ -62,14 +97,16 @@ func IsRootUser(ctx context.Context) bool {
|
||||
|
||||
// SetUserToFiberContext 将用户信息设置到 Fiber context 的 Locals 中
|
||||
// 同时也设置到标准 context 中,便于 GORM 查询使用
|
||||
func SetUserToFiberContext(c *fiber.Ctx, userID uint, userType int, shopID uint) {
|
||||
func SetUserToFiberContext(c *fiber.Ctx, info *UserContextInfo) {
|
||||
// 设置到 Fiber Locals
|
||||
c.Locals(constants.ContextKeyUserID, userID)
|
||||
c.Locals(constants.ContextKeyUserType, userType)
|
||||
c.Locals(constants.ContextKeyShopID, shopID)
|
||||
c.Locals(constants.ContextKeyUserID, info.UserID)
|
||||
c.Locals(constants.ContextKeyUserType, info.UserType)
|
||||
c.Locals(constants.ContextKeyShopID, info.ShopID)
|
||||
c.Locals(constants.ContextKeyEnterpriseID, info.EnterpriseID)
|
||||
c.Locals(constants.ContextKeyCustomerID, info.CustomerID)
|
||||
|
||||
// 设置到标准 context(用于 GORM 数据权限过滤)
|
||||
ctx := SetUserContext(c.UserContext(), userID, userType, shopID)
|
||||
ctx := SetUserContext(c.UserContext(), info)
|
||||
c.SetUserContext(ctx)
|
||||
}
|
||||
|
||||
@@ -80,9 +117,9 @@ type AuthConfig struct {
|
||||
TokenExtractor func(c *fiber.Ctx) string
|
||||
|
||||
// TokenValidator token 验证函数
|
||||
// 验证成功返回用户 ID、用户类型、店铺 ID
|
||||
// 验证成功返回用户上下文信息
|
||||
// 验证失败返回 error
|
||||
TokenValidator func(token string) (userID uint, userType int, shopID uint, err error)
|
||||
TokenValidator func(token string) (*UserContextInfo, error)
|
||||
|
||||
// SkipPaths 跳过认证的路径列表
|
||||
SkipPaths []string
|
||||
@@ -119,7 +156,7 @@ func Auth(config AuthConfig) fiber.Handler {
|
||||
return errors.New(errors.CodeInternalError, "认证验证器未配置")
|
||||
}
|
||||
|
||||
userID, userType, shopID, err := config.TokenValidator(token)
|
||||
userInfo, err := config.TokenValidator(token)
|
||||
if err != nil {
|
||||
// 如果验证器返回的是 AppError,直接返回
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
@@ -130,7 +167,7 @@ func Auth(config AuthConfig) fiber.Handler {
|
||||
}
|
||||
|
||||
// 将用户信息设置到 context
|
||||
SetUserToFiberContext(c, userID, userType, shopID)
|
||||
SetUserToFiberContext(c, userInfo)
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
@@ -144,3 +181,16 @@ func extractBearerToken(c *fiber.Ctx) string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// NewSimpleUserContext 创建简单的用户上下文信息(仅包含基本字段)
|
||||
// 这是一个兼容性辅助函数,用于快速创建只包含 userID, userType, shopID 的上下文
|
||||
// 适用于测试代码和不需要完整上下文信息的场景
|
||||
func NewSimpleUserContext(userID uint, userType int, shopID uint) *UserContextInfo {
|
||||
return &UserContextInfo{
|
||||
UserID: userID,
|
||||
UserType: userType,
|
||||
ShopID: shopID,
|
||||
EnterpriseID: 0,
|
||||
CustomerID: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,9 +71,10 @@ func TestSuccess(t *testing.T) {
|
||||
t.Errorf("Expected status code 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 验证响应头
|
||||
if resp.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("Expected Content-Type application/json, got %s", resp.Header.Get("Content-Type"))
|
||||
// 验证响应头(Fiber 会自动添加 charset=utf-8)
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType != "application/json" && contentType != "application/json; charset=utf-8" {
|
||||
t.Errorf("Expected Content-Type application/json or application/json; charset=utf-8, got %s", contentType)
|
||||
}
|
||||
|
||||
// 解析响应体
|
||||
|
||||
Reference in New Issue
Block a user