package gorm import ( "context" "github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/logger" "github.com/break/junhong_cmp_fiber/pkg/middleware" "go.uber.org/zap" "gorm.io/gorm" "gorm.io/gorm/schema" ) // contextKey 用于 context value 的 key 类型 type contextKey string // SkipDataPermissionKey 跳过数据权限过滤的 context key const SkipDataPermissionKey contextKey = "skip_data_permission" // SkipDataPermission 返回跳过数据权限过滤的 Context // 用于需要查询所有数据的场景(如管理后台统计、系统任务等) // // 使用示例: // // ctx = gorm.SkipDataPermission(ctx) // db.WithContext(ctx).Find(&accounts) func SkipDataPermission(ctx context.Context) context.Context { return context.WithValue(ctx, SkipDataPermissionKey, true) } // AccountStoreInterface 账号 Store 接口 // 用于 Callback 获取下级 ID,避免循环依赖 type AccountStoreInterface interface { GetSubordinateIDs(ctx context.Context, accountID uint) ([]uint, error) } // RegisterDataPermissionCallback 注册 GORM 数据权限过滤 Callback // // 自动化数据权限过滤规则: // 1. root 用户跳过过滤,可以查看所有数据 // 2. 普通用户只能查看自己和下级的数据(通过递归查询下级 ID) // 3. 同时限制 shop_id 相同(如果配置了 shop_id) // 4. 通过 SkipDataPermission(ctx) 可以绕过权限过滤 // // 注意: // - Callback 只对包含 creator 字段的表生效 // - 必须在初始化 Store 之前注册 // // 参数: // - db: GORM DB 实例 // - accountStore: 账号 Store,用于查询下级 ID // // 返回: // - error: 注册错误 func RegisterDataPermissionCallback(db *gorm.DB, accountStore AccountStoreInterface) error { // 注册查询前的 Callback err := db.Callback().Query().Before("gorm:query").Register("data_permission:query", func(tx *gorm.DB) { ctx := tx.Statement.Context if ctx == nil { return } // 1. 检查是否跳过数据权限过滤 if skip, ok := ctx.Value(SkipDataPermissionKey).(bool); ok && skip { return } // 2. 检查是否为 root 用户,root 用户跳过过滤 if middleware.IsRootUser(ctx) { return } // 3. 检查表是否有 creator 字段(只对有 creator 字段的表生效) if !hasCreatorField(tx.Statement.Schema) { return } // 4. 获取当前用户 ID userID := middleware.GetUserIDFromContext(ctx) if userID == 0 { // 未登录用户返回空结果 logger.GetAppLogger().Warn("数据权限过滤:未获取到用户 ID") tx.Where("1 = 0") return } // 5. 获取当前用户及所有下级的 ID subordinateIDs, err := accountStore.GetSubordinateIDs(ctx, userID) if err != nil { // 查询失败时,降级为只能看自己的数据 logger.GetAppLogger().Error("数据权限过滤:获取下级 ID 失败", zap.Uint("user_id", userID), zap.Error(err)) subordinateIDs = []uint{userID} } if len(subordinateIDs) == 0 { subordinateIDs = []uint{userID} } // 6. 获取当前用户的 shop_id shopID := middleware.GetShopIDFromContext(ctx) // 7. 应用数据权限过滤条件 // creator IN (用户自己及所有下级) AND shop_id = 当前用户 shop_id if shopID != 0 && hasShopIDField(tx.Statement.Schema) { // 同时过滤 creator 和 shop_id tx.Where("creator IN ? AND shop_id = ?", subordinateIDs, shopID) } else { // 只根据 creator 过滤 tx.Where("creator IN ?", subordinateIDs) } }) return err } // RegisterSetCreatorUpdaterCallback 注册 GORM 创建数据时创建人更新人 Callback func RegisterSetCreatorUpdaterCallback(db *gorm.DB) error { err := db.Callback().Create().Before("gorm:create").Register("set_creator_updater", func(tx *gorm.DB) { ctx := tx.Statement.Context if userID, ok := tx.Statement.Context.Value(constants.ContextKeyUserID).(uint); ok { if f := tx.Statement.Schema; f != nil { if c, ok := f.FieldsByName["Creator"]; ok { _ = c.Set(ctx, tx.Statement.ReflectValue, userID) } if u, ok := f.FieldsByName["Updater"]; ok { _ = u.Set(ctx, tx.Statement.ReflectValue, userID) } } } }) return err } // hasCreatorField 检查 Schema 是否包含 creator 字段 func hasCreatorField(s *schema.Schema) bool { if s == nil { return false } _, ok := s.FieldsByDBName["creator"] return ok } // hasShopIDField 检查 Schema 是否包含 shop_id 字段 func hasShopIDField(s *schema.Schema) bool { if s == nil { return false } _, ok := s.FieldsByDBName["shop_id"] return ok }