Files
junhong_cmp_fiber/pkg/gorm/callback.go
huang 2570269c8d feat(wallet,tag): 钱包和标签系统多租户改造
核心变更:
- 钱包表:删除 user_id,添加 resource_type/resource_id(绑定资源而非用户)
- 标签表:添加 enterprise_id/shop_id(实现三级隔离:全局/企业/店铺)
- GORM Callback:自动数据权限过滤
- 迁移脚本:可重复执行,已验证回滚功能

钱包归属重构原因:
- 旧设计:钱包绑定用户账号,个人客户卡/设备转手后新用户无法使用余额
- 新设计:钱包绑定资源(卡/设备/店铺),余额随资源流转

标签三级隔离:
- 平台全局标签:所有用户可见
- 企业标签:仅该企业可见(企业内唯一)
- 店铺标签:该店铺及下级可见(店铺内唯一)

测试覆盖:
- 9 个单元测试验证标签多租户过滤(全部通过)
- 迁移和回滚功能测试通过(测试环境)
- OpenSpec 验证通过

变更 ID: fix-wallet-tag-multi-tenant
迁移版本: 000008
参考: openspec/changes/archive/2026-01-13-fix-wallet-tag-multi-tenant/
2026-01-13 16:52:37 +08:00

284 lines
8.2 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package gorm
import (
"context"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/logger"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"go.uber.org/zap"
"gorm.io/gorm"
"gorm.io/gorm/schema"
)
// contextKey 用于 context value 的 key 类型
type contextKey string
// SkipDataPermissionKey 跳过数据权限过滤的 context key
const SkipDataPermissionKey contextKey = "skip_data_permission"
// SkipDataPermission 返回跳过数据权限过滤的 Context
// 用于需要查询所有数据的场景(如管理后台统计、系统任务等)
//
// 使用示例:
//
// ctx = gorm.SkipDataPermission(ctx)
// db.WithContext(ctx).Find(&accounts)
func SkipDataPermission(ctx context.Context) context.Context {
return context.WithValue(ctx, SkipDataPermissionKey, true)
}
// ShopStoreInterface 店铺 Store 接口
// 用于 Callback 获取下级店铺 ID避免循环依赖
type ShopStoreInterface interface {
GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error)
}
// RegisterDataPermissionCallback 注册 GORM 数据权限过滤 Callback
//
// 自动化数据权限过滤规则:
// 1. 超级管理员跳过过滤,可以查看所有数据
// 2. 平台用户跳过过滤,可以查看所有数据
// 3. 代理用户只能查看自己店铺及下级店铺的数据(基于 shop_id 字段)
// 4. 企业用户只能查看自己企业的数据(基于 enterprise_id 字段)
// 5. 个人客户只能查看自己的数据(基于 creator 字段或 customer_id 字段)
// 6. 通过 SkipDataPermission(ctx) 可以绕过权限过滤
//
// 注意:
// - Callback 根据表的字段自动选择过滤策略
// - 必须在初始化 Store 之前注册
//
// 参数:
// - db: GORM DB 实例
// - shopStore: 店铺 Store用于查询下级店铺 ID
//
// 返回:
// - error: 注册错误
func RegisterDataPermissionCallback(db *gorm.DB, shopStore ShopStoreInterface) error {
// 注册查询前的 Callback
err := db.Callback().Query().Before("gorm:query").Register("data_permission:query", func(tx *gorm.DB) {
ctx := tx.Statement.Context
if ctx == nil {
return
}
// 1. 检查是否跳过数据权限过滤
if skip, ok := ctx.Value(SkipDataPermissionKey).(bool); ok && skip {
return
}
// 2. 获取用户类型
userType := middleware.GetUserTypeFromContext(ctx)
// 3. 超级管理员和平台用户跳过过滤,可以查看所有数据
if userType == constants.UserTypeSuperAdmin || userType == constants.UserTypePlatform {
return
}
// 4. 获取当前用户信息
userID := middleware.GetUserIDFromContext(ctx)
if userID == 0 {
// 未登录用户返回空结果
logger.GetAppLogger().Warn("数据权限过滤:未获取到用户 ID")
tx.Where("1 = 0")
return
}
shopID := middleware.GetShopIDFromContext(ctx)
// 5. 根据用户类型和表结构应用不同的过滤规则
schema := tx.Statement.Schema
if schema == nil {
return
}
// 5.1 代理用户:基于店铺层级过滤
if userType == constants.UserTypeAgent {
tableName := schema.Table
// 特殊处理:标签表和资源标签表(包含全局标签)
if tableName == "tb_tag" || tableName == "tb_resource_tag" {
if shopID == 0 {
// 没有 shop_id只能看全局标签
tx.Where("enterprise_id IS NULL AND shop_id IS NULL")
return
}
// 查询该店铺及下级店铺的 ID
subordinateShopIDs, err := shopStore.GetSubordinateShopIDs(ctx, shopID)
if err != nil {
logger.GetAppLogger().Error("数据权限过滤:获取下级店铺 ID 失败",
zap.Uint("shop_id", shopID),
zap.Error(err))
subordinateShopIDs = []uint{shopID}
}
// 过滤:店铺标签(自己店铺及下级店铺)或全局标签
tx.Where("shop_id IN ? OR (enterprise_id IS NULL AND shop_id IS NULL)", subordinateShopIDs)
return
}
if !hasShopIDField(schema) {
// 表没有 shop_id 字段,无法过滤
return
}
if shopID == 0 {
// 代理用户没有 shop_id只能看自己创建的数据
if hasCreatorField(schema) {
tx.Where("creator = ?", userID)
} else {
tx.Where("1 = 0")
}
return
}
// 查询该店铺及下级店铺的 ID
subordinateShopIDs, err := shopStore.GetSubordinateShopIDs(ctx, shopID)
if err != nil {
logger.GetAppLogger().Error("数据权限过滤:获取下级店铺 ID 失败",
zap.Uint("shop_id", shopID),
zap.Error(err))
// 降级为只能看自己店铺的数据
subordinateShopIDs = []uint{shopID}
}
// 过滤shop_id IN (自己店铺及下级店铺)
tx.Where("shop_id IN ?", subordinateShopIDs)
return
}
// 5.2 企业用户:基于 enterprise_id 过滤
if userType == constants.UserTypeEnterprise {
enterpriseID := middleware.GetEnterpriseIDFromContext(ctx)
tableName := schema.Table
// 特殊处理:标签表和资源标签表(包含全局标签)
if tableName == "tb_tag" || tableName == "tb_resource_tag" {
if enterpriseID != 0 {
// 过滤:企业标签或全局标签
tx.Where("enterprise_id = ? OR (enterprise_id IS NULL AND shop_id IS NULL)", enterpriseID)
} else {
// 没有 enterprise_id只能看全局标签
tx.Where("enterprise_id IS NULL AND shop_id IS NULL")
}
return
}
if hasEnterpriseIDField(schema) {
if enterpriseID != 0 {
tx.Where("enterprise_id = ?", enterpriseID)
} else {
// 企业用户没有 enterprise_id返回空结果
tx.Where("1 = 0")
}
return
}
// 如果表没有 enterprise_id 字段,但有 creator 字段,则只能看自己创建的数据
if hasCreatorField(schema) {
tx.Where("creator = ?", userID)
return
}
// 无法过滤,返回空结果
tx.Where("1 = 0")
return
}
// 5.3 个人客户:只能看自己的数据
if userType == constants.UserTypePersonalCustomer {
customerID := middleware.GetCustomerIDFromContext(ctx)
tableName := schema.Table
// 特殊处理:标签表和资源标签表(只能看全局标签)
if tableName == "tb_tag" || tableName == "tb_resource_tag" {
tx.Where("enterprise_id IS NULL AND shop_id IS NULL")
return
}
// 优先使用 customer_id 字段
if hasCustomerIDField(schema) {
if customerID != 0 {
tx.Where("customer_id = ?", customerID)
} else {
// 个人客户没有 customer_id返回空结果
tx.Where("1 = 0")
}
return
}
// 降级为使用 creator 字段
if hasCreatorField(schema) {
tx.Where("creator = ?", userID)
return
}
// 无法过滤,返回空结果
tx.Where("1 = 0")
return
}
// 6. 默认:未知用户类型,返回空结果
logger.GetAppLogger().Warn("数据权限过滤:未知用户类型",
zap.Uint("user_id", userID),
zap.Int("user_type", userType))
tx.Where("1 = 0")
})
return err
}
// RegisterSetCreatorUpdaterCallback 注册 GORM 创建数据时创建人更新人 Callback
func RegisterSetCreatorUpdaterCallback(db *gorm.DB) error {
err := db.Callback().Create().Before("gorm:create").Register("set_creator_updater", func(tx *gorm.DB) {
ctx := tx.Statement.Context
if userID, ok := tx.Statement.Context.Value(constants.ContextKeyUserID).(uint); ok {
if f := tx.Statement.Schema; f != nil {
if c, ok := f.FieldsByName["Creator"]; ok {
_ = c.Set(ctx, tx.Statement.ReflectValue, userID)
}
if u, ok := f.FieldsByName["Updater"]; ok {
_ = u.Set(ctx, tx.Statement.ReflectValue, userID)
}
}
}
})
return err
}
// hasCreatorField 检查 Schema 是否包含 creator 字段
func hasCreatorField(s *schema.Schema) bool {
if s == nil {
return false
}
_, ok := s.FieldsByDBName["creator"]
return ok
}
// hasShopIDField 检查 Schema 是否包含 shop_id 字段
func hasShopIDField(s *schema.Schema) bool {
if s == nil {
return false
}
_, ok := s.FieldsByDBName["shop_id"]
return ok
}
// hasEnterpriseIDField 检查 Schema 是否包含 enterprise_id 字段
func hasEnterpriseIDField(s *schema.Schema) bool {
if s == nil {
return false
}
_, ok := s.FieldsByDBName["enterprise_id"]
return ok
}
// hasCustomerIDField 检查 Schema 是否包含 customer_id 字段
func hasCustomerIDField(s *schema.Schema) bool {
if s == nil {
return false
}
_, ok := s.FieldsByDBName["customer_id"]
return ok
}