152 lines
4.4 KiB
Go
152 lines
4.4 KiB
Go
package gorm
|
|
|
|
import (
|
|
"context"
|
|
|
|
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
|
"github.com/break/junhong_cmp_fiber/pkg/logger"
|
|
"github.com/break/junhong_cmp_fiber/pkg/middleware"
|
|
"go.uber.org/zap"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/schema"
|
|
)
|
|
|
|
// contextKey 用于 context value 的 key 类型
|
|
type contextKey string
|
|
|
|
// SkipDataPermissionKey 跳过数据权限过滤的 context key
|
|
const SkipDataPermissionKey contextKey = "skip_data_permission"
|
|
|
|
// SkipDataPermission 返回跳过数据权限过滤的 Context
|
|
// 用于需要查询所有数据的场景(如管理后台统计、系统任务等)
|
|
//
|
|
// 使用示例:
|
|
//
|
|
// ctx = gorm.SkipDataPermission(ctx)
|
|
// db.WithContext(ctx).Find(&accounts)
|
|
func SkipDataPermission(ctx context.Context) context.Context {
|
|
return context.WithValue(ctx, SkipDataPermissionKey, true)
|
|
}
|
|
|
|
// AccountStoreInterface 账号 Store 接口
|
|
// 用于 Callback 获取下级 ID,避免循环依赖
|
|
type AccountStoreInterface interface {
|
|
GetSubordinateIDs(ctx context.Context, accountID uint) ([]uint, error)
|
|
}
|
|
|
|
// RegisterDataPermissionCallback 注册 GORM 数据权限过滤 Callback
|
|
//
|
|
// 自动化数据权限过滤规则:
|
|
// 1. root 用户跳过过滤,可以查看所有数据
|
|
// 2. 普通用户只能查看自己和下级的数据(通过递归查询下级 ID)
|
|
// 3. 同时限制 shop_id 相同(如果配置了 shop_id)
|
|
// 4. 通过 SkipDataPermission(ctx) 可以绕过权限过滤
|
|
//
|
|
// 注意:
|
|
// - Callback 只对包含 creator 字段的表生效
|
|
// - 必须在初始化 Store 之前注册
|
|
//
|
|
// 参数:
|
|
// - db: GORM DB 实例
|
|
// - accountStore: 账号 Store,用于查询下级 ID
|
|
//
|
|
// 返回:
|
|
// - error: 注册错误
|
|
func RegisterDataPermissionCallback(db *gorm.DB, accountStore AccountStoreInterface) error {
|
|
// 注册查询前的 Callback
|
|
err := db.Callback().Query().Before("gorm:query").Register("data_permission:query", func(tx *gorm.DB) {
|
|
ctx := tx.Statement.Context
|
|
if ctx == nil {
|
|
return
|
|
}
|
|
|
|
// 1. 检查是否跳过数据权限过滤
|
|
if skip, ok := ctx.Value(SkipDataPermissionKey).(bool); ok && skip {
|
|
return
|
|
}
|
|
|
|
// 2. 检查是否为 root 用户,root 用户跳过过滤
|
|
if middleware.IsRootUser(ctx) {
|
|
return
|
|
}
|
|
|
|
// 3. 检查表是否有 creator 字段(只对有 creator 字段的表生效)
|
|
if !hasCreatorField(tx.Statement.Schema) {
|
|
return
|
|
}
|
|
|
|
// 4. 获取当前用户 ID
|
|
userID := middleware.GetUserIDFromContext(ctx)
|
|
if userID == 0 {
|
|
// 未登录用户返回空结果
|
|
logger.GetAppLogger().Warn("数据权限过滤:未获取到用户 ID")
|
|
tx.Where("1 = 0")
|
|
return
|
|
}
|
|
|
|
// 5. 获取当前用户及所有下级的 ID
|
|
subordinateIDs, err := accountStore.GetSubordinateIDs(ctx, userID)
|
|
if err != nil {
|
|
// 查询失败时,降级为只能看自己的数据
|
|
logger.GetAppLogger().Error("数据权限过滤:获取下级 ID 失败",
|
|
zap.Uint("user_id", userID),
|
|
zap.Error(err))
|
|
subordinateIDs = []uint{userID}
|
|
}
|
|
|
|
if len(subordinateIDs) == 0 {
|
|
subordinateIDs = []uint{userID}
|
|
}
|
|
|
|
// 6. 获取当前用户的 shop_id
|
|
shopID := middleware.GetShopIDFromContext(ctx)
|
|
|
|
// 7. 应用数据权限过滤条件
|
|
// creator IN (用户自己及所有下级) AND shop_id = 当前用户 shop_id
|
|
if shopID != 0 && hasShopIDField(tx.Statement.Schema) {
|
|
// 同时过滤 creator 和 shop_id
|
|
tx.Where("creator IN ? AND shop_id = ?", subordinateIDs, shopID)
|
|
} else {
|
|
// 只根据 creator 过滤
|
|
tx.Where("creator IN ?", subordinateIDs)
|
|
}
|
|
})
|
|
return err
|
|
}
|
|
|
|
// RegisterSetCreatorUpdaterCallback 注册 GORM 创建数据时创建人更新人 Callback
|
|
func RegisterSetCreatorUpdaterCallback(db *gorm.DB) error {
|
|
err := db.Callback().Create().Before("gorm:create").Register("set_creator_updater", func(tx *gorm.DB) {
|
|
ctx := tx.Statement.Context
|
|
if userID, ok := tx.Statement.Context.Value(constants.ContextKeyUserID).(uint); ok {
|
|
if f := tx.Statement.Schema; f != nil {
|
|
if c, ok := f.FieldsByName["Creator"]; ok {
|
|
_ = c.Set(ctx, tx.Statement.ReflectValue, userID)
|
|
}
|
|
if u, ok := f.FieldsByName["Updater"]; ok {
|
|
_ = u.Set(ctx, tx.Statement.ReflectValue, userID)
|
|
}
|
|
}
|
|
}
|
|
})
|
|
return err
|
|
}
|
|
|
|
// hasCreatorField 检查 Schema 是否包含 creator 字段
|
|
func hasCreatorField(s *schema.Schema) bool {
|
|
if s == nil {
|
|
return false
|
|
}
|
|
_, ok := s.FieldsByDBName["creator"]
|
|
return ok
|
|
}
|
|
|
|
// hasShopIDField 检查 Schema 是否包含 shop_id 字段
|
|
func hasShopIDField(s *schema.Schema) bool {
|
|
if s == nil {
|
|
return false
|
|
}
|
|
_, ok := s.FieldsByDBName["shop_id"]
|
|
return ok
|
|
}
|