package gorm import ( "context" "github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/logger" "github.com/break/junhong_cmp_fiber/pkg/middleware" "go.uber.org/zap" "gorm.io/gorm" "gorm.io/gorm/schema" ) // contextKey 用于 context value 的 key 类型 type contextKey string // SkipDataPermissionKey 跳过数据权限过滤的 context key const SkipDataPermissionKey contextKey = "skip_data_permission" // SkipDataPermission 返回跳过数据权限过滤的 Context // 用于需要查询所有数据的场景(如管理后台统计、系统任务等) // // 使用示例: // // ctx = gorm.SkipDataPermission(ctx) // db.WithContext(ctx).Find(&accounts) func SkipDataPermission(ctx context.Context) context.Context { return context.WithValue(ctx, SkipDataPermissionKey, true) } // ShopStoreInterface 店铺 Store 接口 // 用于 Callback 获取下级店铺 ID,避免循环依赖 type ShopStoreInterface interface { GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error) } // RegisterDataPermissionCallback 注册 GORM 数据权限过滤 Callback // // 自动化数据权限过滤规则: // 1. 超级管理员跳过过滤,可以查看所有数据 // 2. 平台用户跳过过滤,可以查看所有数据 // 3. 代理用户只能查看自己店铺及下级店铺的数据(基于 shop_id 字段) // 4. 企业用户只能查看自己企业的数据(基于 enterprise_id 字段) // 5. 个人客户只能查看自己的数据(基于 creator 字段或 customer_id 字段) // 6. 通过 SkipDataPermission(ctx) 可以绕过权限过滤 // // 注意: // - Callback 根据表的字段自动选择过滤策略 // - 必须在初始化 Store 之前注册 // // 参数: // - db: GORM DB 实例 // - shopStore: 店铺 Store,用于查询下级店铺 ID // // 返回: // - error: 注册错误 func RegisterDataPermissionCallback(db *gorm.DB, shopStore ShopStoreInterface) error { // 注册查询前的 Callback err := db.Callback().Query().Before("gorm:query").Register("data_permission:query", func(tx *gorm.DB) { ctx := tx.Statement.Context if ctx == nil { return } // 1. 检查是否跳过数据权限过滤 if skip, ok := ctx.Value(SkipDataPermissionKey).(bool); ok && skip { return } // 2. 获取用户类型 userType := middleware.GetUserTypeFromContext(ctx) // 3. 超级管理员和平台用户跳过过滤,可以查看所有数据 if userType == constants.UserTypeSuperAdmin || userType == constants.UserTypePlatform { return } // 4. 获取当前用户信息 userID := middleware.GetUserIDFromContext(ctx) if userID == 0 { // 未登录用户返回空结果 logger.GetAppLogger().Warn("数据权限过滤:未获取到用户 ID") tx.Where("1 = 0") return } shopID := middleware.GetShopIDFromContext(ctx) // 5. 根据用户类型和表结构应用不同的过滤规则 schema := tx.Statement.Schema if schema == nil { return } // 5.1 代理用户:基于店铺层级过滤 if userType == constants.UserTypeAgent { tableName := schema.Table // 特殊处理:标签表和资源标签表(包含全局标签) if tableName == "tb_tag" || tableName == "tb_resource_tag" { if shopID == 0 { // 没有 shop_id,只能看全局标签 tx.Where("enterprise_id IS NULL AND shop_id IS NULL") return } // 查询该店铺及下级店铺的 ID subordinateShopIDs, err := shopStore.GetSubordinateShopIDs(ctx, shopID) if err != nil { logger.GetAppLogger().Error("数据权限过滤:获取下级店铺 ID 失败", zap.Uint("shop_id", shopID), zap.Error(err)) subordinateShopIDs = []uint{shopID} } // 过滤:店铺标签(自己店铺及下级店铺)或全局标签 tx.Where("shop_id IN ? OR (enterprise_id IS NULL AND shop_id IS NULL)", subordinateShopIDs) return } if !hasShopIDField(schema) { // 表没有 shop_id 字段,无法过滤 return } if shopID == 0 { // 代理用户没有 shop_id,只能看自己创建的数据 if hasCreatorField(schema) { tx.Where("creator = ?", userID) } else { tx.Where("1 = 0") } return } // 查询该店铺及下级店铺的 ID subordinateShopIDs, err := shopStore.GetSubordinateShopIDs(ctx, shopID) if err != nil { logger.GetAppLogger().Error("数据权限过滤:获取下级店铺 ID 失败", zap.Uint("shop_id", shopID), zap.Error(err)) // 降级为只能看自己店铺的数据 subordinateShopIDs = []uint{shopID} } // 过滤:shop_id IN (自己店铺及下级店铺) tx.Where("shop_id IN ?", subordinateShopIDs) return } // 5.2 企业用户:基于 enterprise_id 过滤 if userType == constants.UserTypeEnterprise { enterpriseID := middleware.GetEnterpriseIDFromContext(ctx) tableName := schema.Table // 特殊处理:标签表和资源标签表(包含全局标签) if tableName == "tb_tag" || tableName == "tb_resource_tag" { if enterpriseID != 0 { // 过滤:企业标签或全局标签 tx.Where("enterprise_id = ? OR (enterprise_id IS NULL AND shop_id IS NULL)", enterpriseID) } else { // 没有 enterprise_id,只能看全局标签 tx.Where("enterprise_id IS NULL AND shop_id IS NULL") } return } if hasEnterpriseIDField(schema) { if enterpriseID != 0 { tx.Where("enterprise_id = ?", enterpriseID) } else { // 企业用户没有 enterprise_id,返回空结果 tx.Where("1 = 0") } return } // 如果表没有 enterprise_id 字段,但有 creator 字段,则只能看自己创建的数据 if hasCreatorField(schema) { tx.Where("creator = ?", userID) return } // 无法过滤,返回空结果 tx.Where("1 = 0") return } // 5.3 个人客户:只能看自己的数据 if userType == constants.UserTypePersonalCustomer { customerID := middleware.GetCustomerIDFromContext(ctx) tableName := schema.Table // 特殊处理:标签表和资源标签表(只能看全局标签) if tableName == "tb_tag" || tableName == "tb_resource_tag" { tx.Where("enterprise_id IS NULL AND shop_id IS NULL") return } // 优先使用 customer_id 字段 if hasCustomerIDField(schema) { if customerID != 0 { tx.Where("customer_id = ?", customerID) } else { // 个人客户没有 customer_id,返回空结果 tx.Where("1 = 0") } return } // 降级为使用 creator 字段 if hasCreatorField(schema) { tx.Where("creator = ?", userID) return } // 无法过滤,返回空结果 tx.Where("1 = 0") return } // 6. 默认:未知用户类型,返回空结果 logger.GetAppLogger().Warn("数据权限过滤:未知用户类型", zap.Uint("user_id", userID), zap.Int("user_type", userType)) tx.Where("1 = 0") }) return err } // RegisterSetCreatorUpdaterCallback 注册 GORM 创建数据时创建人更新人 Callback func RegisterSetCreatorUpdaterCallback(db *gorm.DB) error { err := db.Callback().Create().Before("gorm:create").Register("set_creator_updater", func(tx *gorm.DB) { ctx := tx.Statement.Context if userID, ok := tx.Statement.Context.Value(constants.ContextKeyUserID).(uint); ok { if f := tx.Statement.Schema; f != nil { if c, ok := f.FieldsByName["Creator"]; ok { _ = c.Set(ctx, tx.Statement.ReflectValue, userID) } if u, ok := f.FieldsByName["Updater"]; ok { _ = u.Set(ctx, tx.Statement.ReflectValue, userID) } } } }) return err } // hasCreatorField 检查 Schema 是否包含 creator 字段 func hasCreatorField(s *schema.Schema) bool { if s == nil { return false } _, ok := s.FieldsByDBName["creator"] return ok } // hasShopIDField 检查 Schema 是否包含 shop_id 字段 func hasShopIDField(s *schema.Schema) bool { if s == nil { return false } _, ok := s.FieldsByDBName["shop_id"] return ok } // hasEnterpriseIDField 检查 Schema 是否包含 enterprise_id 字段 func hasEnterpriseIDField(s *schema.Schema) bool { if s == nil { return false } _, ok := s.FieldsByDBName["enterprise_id"] return ok } // hasCustomerIDField 检查 Schema 是否包含 customer_id 字段 func hasCustomerIDField(s *schema.Schema) bool { if s == nil { return false } _, ok := s.FieldsByDBName["customer_id"] return ok }