refactor: 数据权限过滤从 GORM Callback 改为 Store 层显式调用
All checks were successful
构建并部署到测试环境(无 SSH) / build-and-deploy (push) Successful in 7m2s

- 移除 RegisterDataPermissionCallback 和 SkipDataPermission 机制
- 在 Auth 中间件预计算 SubordinateShopIDs 并注入 Context
- 新增 ApplyShopFilter/ApplyEnterpriseFilter/ApplyOwnerShopFilter 等 Helper 函数
- 所有 Store 层查询方法显式调用数据权限过滤函数
- 权限检查函数 CanManageShop/CanManageEnterprise 改为从 Context 获取数据

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-02-26 16:38:52 +08:00
parent 4ba1f5b99d
commit 03a0960c4d
46 changed files with 1573 additions and 705 deletions

View File

@@ -6,7 +6,6 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/pkg/config" "github.com/break/junhong_cmp_fiber/pkg/config"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
pkgGorm "github.com/break/junhong_cmp_fiber/pkg/gorm"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -15,7 +14,6 @@ func initDefaultAdmin(deps *Dependencies, services *services) error {
cfg := config.Get() cfg := config.Get()
ctx := context.Background() ctx := context.Background()
ctx = pkgGorm.SkipDataPermission(ctx)
var count int64 var count int64
if err := deps.DB.WithContext(ctx).Model(&model.Account{}).Where("user_type = ?", constants.UserTypeSuperAdmin).Count(&count).Error; err != nil { if err := deps.DB.WithContext(ctx).Model(&model.Account{}).Where("user_type = ?", constants.UserTypeSuperAdmin).Count(&count).Error; err != nil {

View File

@@ -45,8 +45,8 @@ func Bootstrap(deps *Dependencies) (*BootstrapResult, error) {
deps.Logger.Error("初始化默认超级管理员失败", zap.Error(err)) deps.Logger.Error("初始化默认超级管理员失败", zap.Error(err))
} }
// 5. 初始化 Middleware 层 // 5. 初始化 Middleware 层(传入 ShopStore 以支持预计算下级店铺 ID
middlewares := initMiddlewares(deps) middlewares := initMiddlewares(deps, stores)
// 6. 初始化 Handler 层 // 6. 初始化 Handler 层
handlers := initHandlers(services, deps) handlers := initHandlers(services, deps)
@@ -59,17 +59,12 @@ func Bootstrap(deps *Dependencies) (*BootstrapResult, error) {
// registerGORMCallbacks 注册 GORM Callbacks // registerGORMCallbacks 注册 GORM Callbacks
func registerGORMCallbacks(deps *Dependencies, stores *stores) error { func registerGORMCallbacks(deps *Dependencies, stores *stores) error {
// 注册数据权限过滤 Callback使用 ShopStore 来查询下级店铺 ID
if err := pkgGorm.RegisterDataPermissionCallback(deps.DB, stores.Shop); err != nil {
return err
}
// 注册自动添加创建&更新人 Callback // 注册自动添加创建&更新人 Callback
if err := pkgGorm.RegisterSetCreatorUpdaterCallback(deps.DB); err != nil { if err := pkgGorm.RegisterSetCreatorUpdaterCallback(deps.DB); err != nil {
return err return err
} }
// TODO: 在此添加其他 GORM Callbacks // 数据权限过滤已移至 Store 层显式调用 ApplyXxxFilter 函数
return nil return nil
} }

View File

@@ -14,7 +14,7 @@ import (
) )
// initMiddlewares 初始化所有中间件 // initMiddlewares 初始化所有中间件
func initMiddlewares(deps *Dependencies) *Middlewares { func initMiddlewares(deps *Dependencies, stores *stores) *Middlewares {
// 获取全局配置 // 获取全局配置
cfg := config.Get() cfg := config.Get()
@@ -29,11 +29,11 @@ func initMiddlewares(deps *Dependencies) *Middlewares {
refreshTTL := time.Duration(cfg.JWT.RefreshTokenTTL) * time.Second refreshTTL := time.Duration(cfg.JWT.RefreshTokenTTL) * time.Second
tokenManager := pkgauth.NewTokenManager(deps.Redis, accessTTL, refreshTTL) tokenManager := pkgauth.NewTokenManager(deps.Redis, accessTTL, refreshTTL)
// 创建后台认证中间件 // 创建后台认证中间件(传入 ShopStore 以支持预计算下级店铺 ID
adminAuthMiddleware := createAdminAuthMiddleware(tokenManager) adminAuthMiddleware := createAdminAuthMiddleware(tokenManager, stores.Shop)
// 创建H5认证中间件 // 创建H5认证中间件(传入 ShopStore 以支持预计算下级店铺 ID
h5AuthMiddleware := createH5AuthMiddleware(tokenManager) h5AuthMiddleware := createH5AuthMiddleware(tokenManager, stores.Shop)
return &Middlewares{ return &Middlewares{
PersonalAuth: personalAuthMiddleware, PersonalAuth: personalAuthMiddleware,
@@ -42,7 +42,7 @@ func initMiddlewares(deps *Dependencies) *Middlewares {
} }
} }
func createAdminAuthMiddleware(tokenManager *pkgauth.TokenManager) fiber.Handler { func createAdminAuthMiddleware(tokenManager *pkgauth.TokenManager, shopStore pkgmiddleware.AuthShopStoreInterface) fiber.Handler {
return pkgmiddleware.Auth(pkgmiddleware.AuthConfig{ return pkgmiddleware.Auth(pkgmiddleware.AuthConfig{
TokenValidator: func(token string) (*pkgmiddleware.UserContextInfo, error) { TokenValidator: func(token string) (*pkgmiddleware.UserContextInfo, error) {
tokenInfo, err := tokenManager.ValidateAccessToken(context.Background(), token) tokenInfo, err := tokenManager.ValidateAccessToken(context.Background(), token)
@@ -65,10 +65,11 @@ func createAdminAuthMiddleware(tokenManager *pkgauth.TokenManager) fiber.Handler
}, nil }, nil
}, },
SkipPaths: []string{"/api/admin/login", "/api/admin/refresh-token"}, SkipPaths: []string{"/api/admin/login", "/api/admin/refresh-token"},
ShopStore: shopStore,
}) })
} }
func createH5AuthMiddleware(tokenManager *pkgauth.TokenManager) fiber.Handler { func createH5AuthMiddleware(tokenManager *pkgauth.TokenManager, shopStore pkgmiddleware.AuthShopStoreInterface) fiber.Handler {
return pkgmiddleware.Auth(pkgmiddleware.AuthConfig{ return pkgmiddleware.Auth(pkgmiddleware.AuthConfig{
TokenValidator: func(token string) (*pkgmiddleware.UserContextInfo, error) { TokenValidator: func(token string) (*pkgmiddleware.UserContextInfo, error) {
tokenInfo, err := tokenManager.ValidateAccessToken(context.Background(), token) tokenInfo, err := tokenManager.ValidateAccessToken(context.Background(), token)
@@ -90,5 +91,6 @@ func createH5AuthMiddleware(tokenManager *pkgauth.TokenManager) fiber.Handler {
}, nil }, nil
}, },
SkipPaths: []string{"/api/h5/login", "/api/h5/refresh-token"}, SkipPaths: []string{"/api/h5/login", "/api/h5/refresh-token"},
ShopStore: shopStore,
}) })
} }

View File

@@ -147,6 +147,6 @@ func initServices(s *stores, deps *Dependencies) *services {
PollingMonitoring: pollingSvc.NewMonitoringService(deps.Redis), PollingMonitoring: pollingSvc.NewMonitoringService(deps.Redis),
PollingAlert: pollingSvc.NewAlertService(s.PollingAlertRule, s.PollingAlertHistory, deps.Redis, deps.Logger), PollingAlert: pollingSvc.NewAlertService(s.PollingAlertRule, s.PollingAlertHistory, deps.Redis, deps.Logger),
PollingCleanup: pollingSvc.NewCleanupService(s.DataCleanupConfig, s.DataCleanupLog, deps.Logger), PollingCleanup: pollingSvc.NewCleanupService(s.DataCleanupConfig, s.DataCleanupLog, deps.Logger),
PollingManualTrigger: pollingSvc.NewManualTriggerService(s.PollingManualTriggerLog, s.IotCard, s.Shop, deps.Redis, deps.Logger), PollingManualTrigger: pollingSvc.NewManualTriggerService(s.PollingManualTriggerLog, s.IotCard, deps.Redis, deps.Logger),
} }
} }

View File

@@ -17,13 +17,18 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
// ShopStoreInterface 店铺存储接口(仅用于获取店铺信息)
type ShopStoreInterface interface {
GetByIDs(ctx context.Context, ids []uint) ([]*model.Shop, error)
}
// Service 账号业务服务 // Service 账号业务服务
type Service struct { type Service struct {
accountStore *postgres.AccountStore accountStore *postgres.AccountStore
roleStore *postgres.RoleStore roleStore *postgres.RoleStore
accountRoleStore *postgres.AccountRoleStore accountRoleStore *postgres.AccountRoleStore
shopRoleStore *postgres.ShopRoleStore shopRoleStore *postgres.ShopRoleStore
shopStore middleware.ShopStoreInterface shopStore ShopStoreInterface
enterpriseStore middleware.EnterpriseStoreInterface enterpriseStore middleware.EnterpriseStoreInterface
auditService AuditServiceInterface auditService AuditServiceInterface
} }
@@ -38,7 +43,7 @@ func New(
roleStore *postgres.RoleStore, roleStore *postgres.RoleStore,
accountRoleStore *postgres.AccountRoleStore, accountRoleStore *postgres.AccountRoleStore,
shopRoleStore *postgres.ShopRoleStore, shopRoleStore *postgres.ShopRoleStore,
shopStore middleware.ShopStoreInterface, shopStore ShopStoreInterface,
enterpriseStore middleware.EnterpriseStoreInterface, enterpriseStore middleware.EnterpriseStoreInterface,
auditService AuditServiceInterface, auditService AuditServiceInterface,
) *Service { ) *Service {
@@ -79,13 +84,13 @@ func (s *Service) Create(ctx context.Context, req *dto.CreateAccountRequest) (*m
} }
if req.UserType == constants.UserTypeAgent && req.ShopID != nil { if req.UserType == constants.UserTypeAgent && req.ShopID != nil {
if err := middleware.CanManageShop(ctx, *req.ShopID, s.shopStore); err != nil { if err := middleware.CanManageShop(ctx, *req.ShopID); err != nil {
return nil, err return nil, err
} }
} }
if req.UserType == constants.UserTypeEnterprise && req.EnterpriseID != nil { if req.UserType == constants.UserTypeEnterprise && req.EnterpriseID != nil {
if err := middleware.CanManageEnterprise(ctx, *req.EnterpriseID, s.enterpriseStore, s.shopStore); err != nil { if err := middleware.CanManageEnterprise(ctx, *req.EnterpriseID, s.enterpriseStore); err != nil {
return nil, err return nil, err
} }
} }
@@ -190,7 +195,7 @@ func (s *Service) Update(ctx context.Context, id uint, req *dto.UpdateAccountReq
if account.ShopID == nil { if account.ShopID == nil {
return nil, errors.New(errors.CodeForbidden, "无权限操作该账号") return nil, errors.New(errors.CodeForbidden, "无权限操作该账号")
} }
if err := middleware.CanManageShop(ctx, *account.ShopID, s.shopStore); err != nil { if err := middleware.CanManageShop(ctx, *account.ShopID); err != nil {
return nil, errors.New(errors.CodeForbidden, "无权限操作该资源或资源不存在") return nil, errors.New(errors.CodeForbidden, "无权限操作该资源或资源不存在")
} }
} }
@@ -291,7 +296,7 @@ func (s *Service) Delete(ctx context.Context, id uint) error {
if account.ShopID == nil { if account.ShopID == nil {
return errors.New(errors.CodeForbidden, "无权限操作该账号") return errors.New(errors.CodeForbidden, "无权限操作该账号")
} }
if err := middleware.CanManageShop(ctx, *account.ShopID, s.shopStore); err != nil { if err := middleware.CanManageShop(ctx, *account.ShopID); err != nil {
return errors.New(errors.CodeForbidden, "无权限操作该资源或资源不存在") return errors.New(errors.CodeForbidden, "无权限操作该资源或资源不存在")
} }
} }
@@ -407,7 +412,7 @@ func (s *Service) AssignRoles(ctx context.Context, accountID uint, roleIDs []uin
if account.ShopID == nil { if account.ShopID == nil {
return nil, errors.New(errors.CodeForbidden, "无权限操作该账号") return nil, errors.New(errors.CodeForbidden, "无权限操作该账号")
} }
if err := middleware.CanManageShop(ctx, *account.ShopID, s.shopStore); err != nil { if err := middleware.CanManageShop(ctx, *account.ShopID); err != nil {
return nil, errors.New(errors.CodeForbidden, "无权限操作该资源或资源不存在") return nil, errors.New(errors.CodeForbidden, "无权限操作该资源或资源不存在")
} }
} }
@@ -558,7 +563,7 @@ func (s *Service) RemoveRole(ctx context.Context, accountID, roleID uint) error
if account.ShopID == nil { if account.ShopID == nil {
return errors.New(errors.CodeForbidden, "无权限操作该账号") return errors.New(errors.CodeForbidden, "无权限操作该账号")
} }
if err := middleware.CanManageShop(ctx, *account.ShopID, s.shopStore); err != nil { if err := middleware.CanManageShop(ctx, *account.ShopID); err != nil {
return errors.New(errors.CodeForbidden, "无权限操作该资源或资源不存在") return errors.New(errors.CodeForbidden, "无权限操作该资源或资源不存在")
} }
} }

View File

@@ -10,7 +10,6 @@ import (
"github.com/break/junhong_cmp_fiber/pkg/auth" "github.com/break/junhong_cmp_fiber/pkg/auth"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors" "github.com/break/junhong_cmp_fiber/pkg/errors"
pkgGorm "github.com/break/junhong_cmp_fiber/pkg/gorm"
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm" "gorm.io/gorm"
@@ -47,8 +46,6 @@ func New(
} }
func (s *Service) Login(ctx context.Context, req *dto.LoginRequest, clientIP string) (*dto.LoginResponse, error) { func (s *Service) Login(ctx context.Context, req *dto.LoginRequest, clientIP string) (*dto.LoginResponse, error) {
ctx = pkgGorm.SkipDataPermission(ctx)
account, err := s.accountStore.GetByUsernameOrPhone(ctx, req.Username) account, err := s.accountStore.GetByUsernameOrPhone(ctx, req.Username)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {

View File

@@ -9,7 +9,6 @@ import (
"github.com/break/junhong_cmp_fiber/internal/store/postgres" "github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors" "github.com/break/junhong_cmp_fiber/pkg/errors"
pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm"
"github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/break/junhong_cmp_fiber/pkg/middleware"
"go.uber.org/zap" "go.uber.org/zap"
"gorm.io/gorm" "gorm.io/gorm"
@@ -426,10 +425,8 @@ func (s *Service) ListDevicesForEnterprise(ctx context.Context, req *dto.Enterpr
authMap[auth.DeviceID] = auth authMap[auth.DeviceID] = auth
} }
skipCtx := pkggorm.SkipDataPermission(ctx)
var devices []model.Device var devices []model.Device
query := s.db.WithContext(skipCtx).Where("id IN ?", deviceIDs) query := s.db.WithContext(ctx).Where("id IN ?", deviceIDs)
if req.DeviceNo != "" { if req.DeviceNo != "" {
query = query.Where("device_no LIKE ?", "%"+req.DeviceNo+"%") query = query.Where("device_no LIKE ?", "%"+req.DeviceNo+"%")
} }
@@ -438,7 +435,7 @@ func (s *Service) ListDevicesForEnterprise(ctx context.Context, req *dto.Enterpr
} }
var bindings []model.DeviceSimBinding var bindings []model.DeviceSimBinding
if err := s.db.WithContext(skipCtx). if err := s.db.WithContext(ctx).
Where("device_id IN ? AND bind_status = 1", deviceIDs). Where("device_id IN ? AND bind_status = 1", deviceIDs).
Find(&bindings).Error; err != nil { Find(&bindings).Error; err != nil {
return nil, errors.Wrap(errors.CodeInternalError, err, "查询设备绑定卡失败") return nil, errors.Wrap(errors.CodeInternalError, err, "查询设备绑定卡失败")
@@ -480,15 +477,14 @@ func (s *Service) GetDeviceDetail(ctx context.Context, deviceID uint) (*dto.Ente
return nil, errors.New(errors.CodeDeviceNotAuthorized, "设备未授权给此企业") return nil, errors.New(errors.CodeDeviceNotAuthorized, "设备未授权给此企业")
} }
skipCtx := pkggorm.SkipDataPermission(ctx)
var device model.Device var device model.Device
if err := s.db.WithContext(skipCtx).Where("id = ?", deviceID).First(&device).Error; err != nil { if err := s.db.WithContext(ctx).Where("id = ?", deviceID).First(&device).Error; err != nil {
return nil, errors.Wrap(errors.CodeInternalError, err, "查询设备信息失败") return nil, errors.Wrap(errors.CodeInternalError, err, "查询设备信息失败")
} }
var bindings []model.DeviceSimBinding var bindings []model.DeviceSimBinding
if err := s.db.WithContext(skipCtx). if err := s.db.WithContext(ctx).
Where("device_id = ? AND bind_status = 1", deviceID). Where("device_id = ? AND bind_status = 1", deviceID).
Find(&bindings).Error; err != nil { Find(&bindings).Error; err != nil {
return nil, errors.Wrap(errors.CodeInternalError, err, "查询设备绑定卡失败") return nil, errors.Wrap(errors.CodeInternalError, err, "查询设备绑定卡失败")
@@ -502,7 +498,7 @@ func (s *Service) GetDeviceDetail(ctx context.Context, deviceID uint) (*dto.Ente
var cards []model.IotCard var cards []model.IotCard
cardInfos := make([]dto.DeviceCardInfo, 0) cardInfos := make([]dto.DeviceCardInfo, 0)
if len(cardIDs) > 0 { if len(cardIDs) > 0 {
if err := s.db.WithContext(skipCtx).Where("id IN ?", cardIDs).Find(&cards).Error; err != nil { if err := s.db.WithContext(ctx).Where("id IN ?", cardIDs).Find(&cards).Error; err != nil {
return nil, errors.Wrap(errors.CodeInternalError, err, "查询卡信息失败") return nil, errors.Wrap(errors.CodeInternalError, err, "查询卡信息失败")
} }
@@ -514,7 +510,7 @@ func (s *Service) GetDeviceDetail(ctx context.Context, deviceID uint) (*dto.Ente
var carriers []model.Carrier var carriers []model.Carrier
carrierMap := make(map[uint]string) carrierMap := make(map[uint]string)
if len(carrierIDs) > 0 { if len(carrierIDs) > 0 {
if err := s.db.WithContext(skipCtx).Where("id IN ?", carrierIDs).Find(&carriers).Error; err == nil { if err := s.db.WithContext(ctx).Where("id IN ?", carrierIDs).Find(&carriers).Error; err == nil {
for _, carrier := range carriers { for _, carrier := range carriers {
carrierMap[carrier.ID] = carrier.CarrierName carrierMap[carrier.ID] = carrier.CarrierName
} }
@@ -551,8 +547,7 @@ func (s *Service) SuspendCard(ctx context.Context, deviceID, cardID uint, req *d
return nil, err return nil, err
} }
skipCtx := pkggorm.SkipDataPermission(ctx) if err := s.db.WithContext(ctx).Model(&model.IotCard{}).
if err := s.db.WithContext(skipCtx).Model(&model.IotCard{}).
Where("id = ?", cardID). Where("id = ?", cardID).
Update("network_status", 0).Error; err != nil { Update("network_status", 0).Error; err != nil {
return nil, errors.Wrap(errors.CodeInternalError, err, "停机操作失败") return nil, errors.Wrap(errors.CodeInternalError, err, "停机操作失败")
@@ -569,8 +564,7 @@ func (s *Service) ResumeCard(ctx context.Context, deviceID, cardID uint, req *dt
return nil, err return nil, err
} }
skipCtx := pkggorm.SkipDataPermission(ctx) if err := s.db.WithContext(ctx).Model(&model.IotCard{}).
if err := s.db.WithContext(skipCtx).Model(&model.IotCard{}).
Where("id = ?", cardID). Where("id = ?", cardID).
Update("network_status", 1).Error; err != nil { Update("network_status", 1).Error; err != nil {
return nil, errors.Wrap(errors.CodeInternalError, err, "复机操作失败") return nil, errors.Wrap(errors.CodeInternalError, err, "复机操作失败")
@@ -593,17 +587,16 @@ func (s *Service) validateCardOperation(ctx context.Context, deviceID, cardID ui
return errors.New(errors.CodeDeviceNotAuthorized, "设备未授权给此企业") return errors.New(errors.CodeDeviceNotAuthorized, "设备未授权给此企业")
} }
skipCtx := pkggorm.SkipDataPermission(ctx)
var binding model.DeviceSimBinding var binding model.DeviceSimBinding
if err := s.db.WithContext(skipCtx). if err := s.db.WithContext(ctx).
Where("device_id = ? AND iot_card_id = ? AND bind_status = 1", deviceID, cardID). Where("device_id = ? AND iot_card_id = ? AND bind_status = 1", deviceID, cardID).
First(&binding).Error; err != nil { First(&binding).Error; err != nil {
return errors.New(errors.CodeForbidden, "卡不属于该设备") return errors.New(errors.CodeForbidden, "卡不属于该设备")
} }
var cardAuth model.EnterpriseCardAuthorization var cardAuth model.EnterpriseCardAuthorization
if err := s.db.WithContext(skipCtx). if err := s.db.WithContext(ctx).
Where("enterprise_id = ? AND card_id = ? AND device_auth_id IS NOT NULL AND revoked_at IS NULL", enterpriseID, cardID). Where("enterprise_id = ? AND card_id = ? AND device_auth_id IS NOT NULL AND revoked_at IS NULL", enterpriseID, cardID).
First(&cardAuth).Error; err != nil { First(&cardAuth).Error; err != nil {
return errors.New(errors.CodeForbidden, "无权操作此卡") return errors.New(errors.CodeForbidden, "无权操作此卡")

View File

@@ -19,7 +19,6 @@ import (
type ManualTriggerService struct { type ManualTriggerService struct {
logStore *postgres.PollingManualTriggerLogStore logStore *postgres.PollingManualTriggerLogStore
iotCardStore *postgres.IotCardStore iotCardStore *postgres.IotCardStore
shopStore middleware.ShopStoreInterface
redis *redis.Client redis *redis.Client
logger *zap.Logger logger *zap.Logger
} }
@@ -28,14 +27,12 @@ type ManualTriggerService struct {
func NewManualTriggerService( func NewManualTriggerService(
logStore *postgres.PollingManualTriggerLogStore, logStore *postgres.PollingManualTriggerLogStore,
iotCardStore *postgres.IotCardStore, iotCardStore *postgres.IotCardStore,
shopStore middleware.ShopStoreInterface,
redis *redis.Client, redis *redis.Client,
logger *zap.Logger, logger *zap.Logger,
) *ManualTriggerService { ) *ManualTriggerService {
return &ManualTriggerService{ return &ManualTriggerService{
logStore: logStore, logStore: logStore,
iotCardStore: iotCardStore, iotCardStore: iotCardStore,
shopStore: shopStore,
redis: redis, redis: redis,
logger: logger, logger: logger,
} }
@@ -386,7 +383,7 @@ func (s *ManualTriggerService) canManageCard(ctx context.Context, cardID uint) e
} }
// 检查代理是否有权管理该店铺 // 检查代理是否有权管理该店铺
return middleware.CanManageShop(ctx, *card.ShopID, s.shopStore) return middleware.CanManageShop(ctx, *card.ShopID)
} }
// canManageCards 检查用户是否有权管理多张卡 // canManageCards 检查用户是否有权管理多张卡
@@ -403,18 +400,13 @@ func (s *ManualTriggerService) canManageCards(ctx context.Context, cardIDs []uin
return errors.New(errors.CodeForbidden, "企业账号无权限手动触发轮询") return errors.New(errors.CodeForbidden, "企业账号无权限手动触发轮询")
} }
// 代理账号只能管理自己店铺及下级店铺的卡 // 从 Context 获取预计算的下级店铺 ID 列表
currentShopID := middleware.GetShopIDFromContext(ctx) subordinateIDs := middleware.GetSubordinateShopIDs(ctx)
if currentShopID == 0 { if subordinateIDs == nil {
// 平台用户/超管不受限制,但这里不应该进入(前面已经检查过用户类型)
return errors.New(errors.CodeForbidden, "无权限操作") return errors.New(errors.CodeForbidden, "无权限操作")
} }
// 获取下级店铺ID列表
subordinateIDs, err := s.shopStore.GetSubordinateShopIDs(ctx, currentShopID)
if err != nil {
return errors.Wrap(errors.CodeInternalError, err, "查询下级店铺失败")
}
// 构建可管理的店铺ID集合 // 构建可管理的店铺ID集合
allowedShopIDs := make(map[uint]bool) allowedShopIDs := make(map[uint]bool)
for _, id := range subordinateIDs { for _, id := range subordinateIDs {
@@ -462,7 +454,7 @@ func (s *ManualTriggerService) applyShopPermissionFilter(ctx context.Context, fi
// 如果用户指定了 ShopID验证是否在可管理范围内 // 如果用户指定了 ShopID验证是否在可管理范围内
if filter.ShopID != nil { if filter.ShopID != nil {
if err := middleware.CanManageShop(ctx, *filter.ShopID, s.shopStore); err != nil { if err := middleware.CanManageShop(ctx, *filter.ShopID); err != nil {
return err return err
} }
// 已指定有效的 ShopID无需修改 // 已指定有效的 ShopID无需修改

View File

@@ -11,7 +11,7 @@ import (
) )
func (s *Service) AssignRolesToShop(ctx context.Context, shopID uint, roleIDs []uint) ([]*model.ShopRole, error) { func (s *Service) AssignRolesToShop(ctx context.Context, shopID uint, roleIDs []uint) ([]*model.ShopRole, error) {
if err := middleware.CanManageShop(ctx, shopID, s.shopStore); err != nil { if err := middleware.CanManageShop(ctx, shopID); err != nil {
return nil, err return nil, err
} }
@@ -70,7 +70,7 @@ func (s *Service) AssignRolesToShop(ctx context.Context, shopID uint, roleIDs []
} }
func (s *Service) GetShopRoles(ctx context.Context, shopID uint) (*dto.ShopRolesResponse, error) { func (s *Service) GetShopRoles(ctx context.Context, shopID uint) (*dto.ShopRolesResponse, error) {
if err := middleware.CanManageShop(ctx, shopID, s.shopStore); err != nil { if err := middleware.CanManageShop(ctx, shopID); err != nil {
return nil, err return nil, err
} }
@@ -128,7 +128,7 @@ func (s *Service) GetShopRoles(ctx context.Context, shopID uint) (*dto.ShopRoles
} }
func (s *Service) DeleteShopRole(ctx context.Context, shopID, roleID uint) error { func (s *Service) DeleteShopRole(ctx context.Context, shopID, roleID uint) error {
if err := middleware.CanManageShop(ctx, shopID, s.shopStore); err != nil { if err := middleware.CanManageShop(ctx, shopID); err != nil {
return err return err
} }

View File

@@ -10,7 +10,6 @@ import (
"github.com/break/junhong_cmp_fiber/internal/store/postgres" "github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors" "github.com/break/junhong_cmp_fiber/pkg/errors"
pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm"
"github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/break/junhong_cmp_fiber/pkg/middleware"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -71,9 +70,8 @@ func (s *Service) Create(ctx context.Context, req *dto.CreateShopSeriesAllocatio
return nil, errors.Wrap(errors.CodeInternalError, err, "获取套餐系列失败") return nil, errors.Wrap(errors.CodeInternalError, err, "获取套餐系列失败")
} }
// 检查是否已存在分配(跳过数据权限过滤,避免误判) // 检查是否已存在分配
skipCtx := pkggorm.SkipDataPermission(ctx) exists, err := s.seriesAllocationStore.ExistsByShopAndSeries(ctx, req.ShopID, req.SeriesID)
exists, err := s.seriesAllocationStore.ExistsByShopAndSeries(skipCtx, req.ShopID, req.SeriesID)
if err != nil { if err != nil {
return nil, errors.Wrap(errors.CodeInternalError, err, "检查分配记录失败") return nil, errors.Wrap(errors.CodeInternalError, err, "检查分配记录失败")
} }
@@ -84,7 +82,7 @@ func (s *Service) Create(ctx context.Context, req *dto.CreateShopSeriesAllocatio
// 代理用户:检查自己是否有该系列的分配权限,且金额不能超过上级给的上限 // 代理用户:检查自己是否有该系列的分配权限,且金额不能超过上级给的上限
// 平台用户:无上限限制,可自由设定金额 // 平台用户:无上限限制,可自由设定金额
if userType == constants.UserTypeAgent { if userType == constants.UserTypeAgent {
allocatorAllocation, err := s.seriesAllocationStore.GetByShopAndSeries(skipCtx, allocatorShopID, req.SeriesID) allocatorAllocation, err := s.seriesAllocationStore.GetByShopAndSeries(ctx, allocatorShopID, req.SeriesID)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
return nil, errors.New(errors.CodeForbidden, "您没有该套餐系列的分配权限") return nil, errors.New(errors.CodeForbidden, "您没有该套餐系列的分配权限")
@@ -239,8 +237,7 @@ func (s *Service) Update(ctx context.Context, id uint, req *dto.UpdateShopSeries
} }
func (s *Service) Delete(ctx context.Context, id uint) error { func (s *Service) Delete(ctx context.Context, id uint) error {
skipCtx := pkggorm.SkipDataPermission(ctx) _, err := s.seriesAllocationStore.GetByID(ctx, id)
_, err := s.seriesAllocationStore.GetByID(skipCtx, id)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
return errors.New(errors.CodeNotFound, "分配记录不存在") return errors.New(errors.CodeNotFound, "分配记录不存在")
@@ -248,7 +245,7 @@ func (s *Service) Delete(ctx context.Context, id uint) error {
return errors.Wrap(errors.CodeInternalError, err, "获取分配记录失败") return errors.Wrap(errors.CodeInternalError, err, "获取分配记录失败")
} }
count, err := s.packageAllocationStore.CountBySeriesAllocationID(skipCtx, id) count, err := s.packageAllocationStore.CountBySeriesAllocationID(ctx, id)
if err != nil { if err != nil {
return errors.Wrap(errors.CodeInternalError, err, "检查关联套餐分配失败") return errors.Wrap(errors.CodeInternalError, err, "检查关联套餐分配失败")
} }
@@ -256,7 +253,7 @@ func (s *Service) Delete(ctx context.Context, id uint) error {
return errors.New(errors.CodeInvalidParam, "存在关联的套餐分配,无法删除") return errors.New(errors.CodeInvalidParam, "存在关联的套餐分配,无法删除")
} }
if err := s.seriesAllocationStore.Delete(skipCtx, id); err != nil { if err := s.seriesAllocationStore.Delete(ctx, id); err != nil {
return errors.Wrap(errors.CodeInternalError, err, "删除分配失败") return errors.Wrap(errors.CodeInternalError, err, "删除分配失败")
} }

View File

@@ -3,9 +3,9 @@ package postgres
import ( import (
"context" "context"
"github.com/break/junhong_cmp_fiber/internal/store"
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -32,7 +32,12 @@ func (s *AccountStore) Create(ctx context.Context, account *model.Account) error
// GetByID 根据 ID 获取账号 // GetByID 根据 ID 获取账号
func (s *AccountStore) GetByID(ctx context.Context, id uint) (*model.Account, error) { func (s *AccountStore) GetByID(ctx context.Context, id uint) (*model.Account, error) {
var account model.Account var account model.Account
if err := s.db.WithContext(ctx).First(&account, id).Error; err != nil { query := s.db.WithContext(ctx).Where("id = ?", id)
// 根据当前用户类型应用数据权限过滤
// 代理用户:过滤 shop_id企业用户过滤 enterprise_id
query = middleware.ApplyShopFilter(ctx, query)
query = middleware.ApplyEnterpriseFilter(ctx, query)
if err := query.First(&account).Error; err != nil {
return nil, err return nil, err
} }
return &account, nil return &account, nil
@@ -68,7 +73,10 @@ func (s *AccountStore) GetByUsernameOrPhone(ctx context.Context, identifier stri
// GetByShopID 根据店铺 ID 查询账号列表 // GetByShopID 根据店铺 ID 查询账号列表
func (s *AccountStore) GetByShopID(ctx context.Context, shopID uint) ([]*model.Account, error) { func (s *AccountStore) GetByShopID(ctx context.Context, shopID uint) ([]*model.Account, error) {
var accounts []*model.Account var accounts []*model.Account
if err := s.db.WithContext(ctx).Where("shop_id = ?", shopID).Find(&accounts).Error; err != nil { query := s.db.WithContext(ctx).Where("shop_id = ?", shopID)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.Find(&accounts).Error; err != nil {
return nil, err return nil, err
} }
return accounts, nil return accounts, nil
@@ -77,7 +85,10 @@ func (s *AccountStore) GetByShopID(ctx context.Context, shopID uint) ([]*model.A
// GetByEnterpriseID 根据企业 ID 查询账号列表 // GetByEnterpriseID 根据企业 ID 查询账号列表
func (s *AccountStore) GetByEnterpriseID(ctx context.Context, enterpriseID uint) ([]*model.Account, error) { func (s *AccountStore) GetByEnterpriseID(ctx context.Context, enterpriseID uint) ([]*model.Account, error) {
var accounts []*model.Account var accounts []*model.Account
if err := s.db.WithContext(ctx).Where("enterprise_id = ?", enterpriseID).Find(&accounts).Error; err != nil { query := s.db.WithContext(ctx).Where("enterprise_id = ?", enterpriseID)
// 应用企业数据权限过滤
query = middleware.ApplyEnterpriseFilter(ctx, query)
if err := query.Find(&accounts).Error; err != nil {
return nil, err return nil, err
} }
return accounts, nil return accounts, nil
@@ -99,6 +110,10 @@ func (s *AccountStore) List(ctx context.Context, opts *store.QueryOptions, filte
var total int64 var total int64
query := s.db.WithContext(ctx).Model(&model.Account{}) query := s.db.WithContext(ctx).Model(&model.Account{})
// 根据当前用户类型应用数据权限过滤
// 代理用户:过滤 shop_id企业用户过滤 enterprise_id
query = middleware.ApplyShopFilter(ctx, query)
query = middleware.ApplyEnterpriseFilter(ctx, query)
// 应用过滤条件 // 应用过滤条件
if username, ok := filters["username"].(string); ok && username != "" { if username, ok := filters["username"].(string); ok && username != "" {
@@ -229,7 +244,11 @@ func (s *AccountStore) GetByIDs(ctx context.Context, ids []uint) ([]*model.Accou
return []*model.Account{}, nil return []*model.Account{}, nil
} }
var accounts []*model.Account var accounts []*model.Account
if err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&accounts).Error; err != nil { query := s.db.WithContext(ctx).Where("id IN ?", ids)
// 根据当前用户类型应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
query = middleware.ApplyEnterpriseFilter(ctx, query)
if err := query.Find(&accounts).Error; err != nil {
return nil, err return nil, err
} }
return accounts, nil return accounts, nil
@@ -240,9 +259,11 @@ func (s *AccountStore) GetPrimaryAccountsByShopIDs(ctx context.Context, shopIDs
return []*model.Account{}, nil return []*model.Account{}, nil
} }
var accounts []*model.Account var accounts []*model.Account
if err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Where("shop_id IN ? AND is_primary = ?", shopIDs, true). Where("shop_id IN ? AND is_primary = ?", shopIDs, true)
Find(&accounts).Error; err != nil { // 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.Find(&accounts).Error; err != nil {
return nil, err return nil, err
} }
return accounts, nil return accounts, nil
@@ -254,6 +275,8 @@ func (s *AccountStore) ListByShopID(ctx context.Context, shopID uint, opts *stor
var total int64 var total int64
query := s.db.WithContext(ctx).Model(&model.Account{}).Where("shop_id = ?", shopID) query := s.db.WithContext(ctx).Model(&model.Account{}).Where("shop_id = ?", shopID)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if username, ok := filters["username"].(string); ok && username != "" { if username, ok := filters["username"].(string); ok && username != "" {
query = query.Where("username LIKE ?", "%"+username+"%") query = query.Where("username LIKE ?", "%"+username+"%")

View File

@@ -6,6 +6,7 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -41,9 +42,11 @@ func (s *AgentWalletStore) GetByShopIDAndType(ctx context.Context, shopID uint,
// 注意:这里简化处理,实际项目中可以缓存完整的钱包信息 // 注意:这里简化处理,实际项目中可以缓存完整的钱包信息
var wallet model.AgentWallet var wallet model.AgentWallet
err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Where("shop_id = ? AND wallet_type = ?", shopID, walletType). Where("shop_id = ? AND wallet_type = ?", shopID, walletType)
First(&wallet).Error // 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
err := query.First(&wallet).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -58,7 +61,10 @@ func (s *AgentWalletStore) GetByShopIDAndType(ctx context.Context, shopID uint,
// GetByID 根据钱包 ID 查询 // GetByID 根据钱包 ID 查询
func (s *AgentWalletStore) GetByID(ctx context.Context, id uint) (*model.AgentWallet, error) { func (s *AgentWalletStore) GetByID(ctx context.Context, id uint) (*model.AgentWallet, error) {
var wallet model.AgentWallet var wallet model.AgentWallet
if err := s.db.WithContext(ctx).First(&wallet, id).Error; err != nil { query := s.db.WithContext(ctx).Where("id = ?", id)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.First(&wallet).Error; err != nil {
return nil, err return nil, err
} }
return &wallet, nil return &wallet, nil
@@ -209,9 +215,11 @@ func (s *AgentWalletStore) GetShopCommissionSummaryBatch(ctx context.Context, sh
} }
var wallets []model.AgentWallet var wallets []model.AgentWallet
err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Where("shop_id IN ? AND wallet_type = ?", shopIDs, constants.AgentWalletTypeCommission). Where("shop_id IN ? AND wallet_type = ?", shopIDs, constants.AgentWalletTypeCommission)
Find(&wallets).Error // 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
err := query.Find(&wallets).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -30,9 +31,11 @@ func (s *AgentWalletTransactionStore) CreateWithTx(ctx context.Context, tx *gorm
// ListByShopID 按店铺查询交易记录(支持分页) // ListByShopID 按店铺查询交易记录(支持分页)
func (s *AgentWalletTransactionStore) ListByShopID(ctx context.Context, shopID uint, offset, limit int) ([]*model.AgentWalletTransaction, error) { func (s *AgentWalletTransactionStore) ListByShopID(ctx context.Context, shopID uint, offset, limit int) ([]*model.AgentWalletTransaction, error) {
var transactions []*model.AgentWalletTransaction var transactions []*model.AgentWalletTransaction
err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Where("shop_id = ?", shopID). Where("shop_id = ?", shopID)
Order("created_at DESC"). // 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
err := query.Order("created_at DESC").
Offset(offset). Offset(offset).
Limit(limit). Limit(limit).
Find(&transactions).Error Find(&transactions).Error
@@ -45,19 +48,23 @@ func (s *AgentWalletTransactionStore) ListByShopID(ctx context.Context, shopID u
// CountByShopID 统计店铺的交易记录数量 // CountByShopID 统计店铺的交易记录数量
func (s *AgentWalletTransactionStore) CountByShopID(ctx context.Context, shopID uint) (int64, error) { func (s *AgentWalletTransactionStore) CountByShopID(ctx context.Context, shopID uint) (int64, error) {
var count int64 var count int64
err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Model(&model.AgentWalletTransaction{}). Model(&model.AgentWalletTransaction{}).
Where("shop_id = ?", shopID). Where("shop_id = ?", shopID)
Count(&count).Error // 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
err := query.Count(&count).Error
return count, err return count, err
} }
// ListByWalletID 按钱包查询交易记录(支持分页) // ListByWalletID 按钱包查询交易记录(支持分页)
func (s *AgentWalletTransactionStore) ListByWalletID(ctx context.Context, walletID uint, offset, limit int) ([]*model.AgentWalletTransaction, error) { func (s *AgentWalletTransactionStore) ListByWalletID(ctx context.Context, walletID uint, offset, limit int) ([]*model.AgentWalletTransaction, error) {
var transactions []*model.AgentWalletTransaction var transactions []*model.AgentWalletTransaction
err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Where("agent_wallet_id = ?", walletID). Where("agent_wallet_id = ?", walletID)
Order("created_at DESC"). // 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
err := query.Order("created_at DESC").
Offset(offset). Offset(offset).
Limit(limit). Limit(limit).
Find(&transactions).Error Find(&transactions).Error
@@ -70,9 +77,11 @@ func (s *AgentWalletTransactionStore) ListByWalletID(ctx context.Context, wallet
// GetByReference 根据关联业务查询交易记录 // GetByReference 根据关联业务查询交易记录
func (s *AgentWalletTransactionStore) GetByReference(ctx context.Context, referenceType string, referenceID uint) (*model.AgentWalletTransaction, error) { func (s *AgentWalletTransactionStore) GetByReference(ctx context.Context, referenceType string, referenceID uint) (*model.AgentWalletTransaction, error) {
var transaction model.AgentWalletTransaction var transaction model.AgentWalletTransaction
err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Where("reference_type = ? AND reference_id = ?", referenceType, referenceID). Where("reference_type = ? AND reference_id = ?", referenceType, referenceID)
First(&transaction).Error // 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
err := query.First(&transaction).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -6,6 +6,7 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -27,9 +28,11 @@ func NewCardWalletStore(db *gorm.DB, redis *redis.Client) *CardWalletStore {
// GetByResourceTypeAndID 根据资源类型和 ID 查询钱包 // GetByResourceTypeAndID 根据资源类型和 ID 查询钱包
func (s *CardWalletStore) GetByResourceTypeAndID(ctx context.Context, resourceType string, resourceID uint) (*model.CardWallet, error) { func (s *CardWalletStore) GetByResourceTypeAndID(ctx context.Context, resourceType string, resourceID uint) (*model.CardWallet, error) {
var wallet model.CardWallet var wallet model.CardWallet
err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Where("resource_type = ? AND resource_id = ?", resourceType, resourceID). Where("resource_type = ? AND resource_id = ?", resourceType, resourceID)
First(&wallet).Error // 应用数据权限过滤(使用 shop_id_tag 字段)
query = middleware.ApplyShopTagFilter(ctx, query)
err := query.First(&wallet).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -39,7 +42,10 @@ func (s *CardWalletStore) GetByResourceTypeAndID(ctx context.Context, resourceTy
// GetByID 根据钱包 ID 查询 // GetByID 根据钱包 ID 查询
func (s *CardWalletStore) GetByID(ctx context.Context, id uint) (*model.CardWallet, error) { func (s *CardWalletStore) GetByID(ctx context.Context, id uint) (*model.CardWallet, error) {
var wallet model.CardWallet var wallet model.CardWallet
if err := s.db.WithContext(ctx).First(&wallet, id).Error; err != nil { query := s.db.WithContext(ctx).Where("id = ?", id)
// 应用数据权限过滤(使用 shop_id_tag 字段)
query = middleware.ApplyShopTagFilter(ctx, query)
if err := query.First(&wallet).Error; err != nil {
return nil, err return nil, err
} }
return &wallet, nil return &wallet, nil

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -30,9 +31,11 @@ func (s *CardWalletTransactionStore) CreateWithTx(ctx context.Context, tx *gorm.
// ListByResourceID 按资源查询交易记录(支持分页) // ListByResourceID 按资源查询交易记录(支持分页)
func (s *CardWalletTransactionStore) ListByResourceID(ctx context.Context, resourceType string, resourceID uint, offset, limit int) ([]*model.CardWalletTransaction, error) { func (s *CardWalletTransactionStore) ListByResourceID(ctx context.Context, resourceType string, resourceID uint, offset, limit int) ([]*model.CardWalletTransaction, error) {
var transactions []*model.CardWalletTransaction var transactions []*model.CardWalletTransaction
err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Where("resource_type = ? AND resource_id = ?", resourceType, resourceID). Where("resource_type = ? AND resource_id = ?", resourceType, resourceID)
Order("created_at DESC"). // 应用数据权限过滤(使用 shop_id_tag 字段)
query = middleware.ApplyShopTagFilter(ctx, query)
err := query.Order("created_at DESC").
Offset(offset). Offset(offset).
Limit(limit). Limit(limit).
Find(&transactions).Error Find(&transactions).Error
@@ -45,19 +48,23 @@ func (s *CardWalletTransactionStore) ListByResourceID(ctx context.Context, resou
// CountByResourceID 统计资源的交易记录数量 // CountByResourceID 统计资源的交易记录数量
func (s *CardWalletTransactionStore) CountByResourceID(ctx context.Context, resourceType string, resourceID uint) (int64, error) { func (s *CardWalletTransactionStore) CountByResourceID(ctx context.Context, resourceType string, resourceID uint) (int64, error) {
var count int64 var count int64
err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Model(&model.CardWalletTransaction{}). Model(&model.CardWalletTransaction{}).
Where("resource_type = ? AND resource_id = ?", resourceType, resourceID). Where("resource_type = ? AND resource_id = ?", resourceType, resourceID)
Count(&count).Error // 应用数据权限过滤(使用 shop_id_tag 字段)
query = middleware.ApplyShopTagFilter(ctx, query)
err := query.Count(&count).Error
return count, err return count, err
} }
// ListByWalletID 按钱包查询交易记录(支持分页) // ListByWalletID 按钱包查询交易记录(支持分页)
func (s *CardWalletTransactionStore) ListByWalletID(ctx context.Context, walletID uint, offset, limit int) ([]*model.CardWalletTransaction, error) { func (s *CardWalletTransactionStore) ListByWalletID(ctx context.Context, walletID uint, offset, limit int) ([]*model.CardWalletTransaction, error) {
var transactions []*model.CardWalletTransaction var transactions []*model.CardWalletTransaction
err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Where("card_wallet_id = ?", walletID). Where("card_wallet_id = ?", walletID)
Order("created_at DESC"). // 应用数据权限过滤(使用 shop_id_tag 字段)
query = middleware.ApplyShopTagFilter(ctx, query)
err := query.Order("created_at DESC").
Offset(offset). Offset(offset).
Limit(limit). Limit(limit).
Find(&transactions).Error Find(&transactions).Error
@@ -70,9 +77,11 @@ func (s *CardWalletTransactionStore) ListByWalletID(ctx context.Context, walletI
// GetByReference 根据关联业务查询交易记录 // GetByReference 根据关联业务查询交易记录
func (s *CardWalletTransactionStore) GetByReference(ctx context.Context, referenceType string, referenceID uint) (*model.CardWalletTransaction, error) { func (s *CardWalletTransactionStore) GetByReference(ctx context.Context, referenceType string, referenceID uint) (*model.CardWalletTransaction, error) {
var transaction model.CardWalletTransaction var transaction model.CardWalletTransaction
err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Where("reference_type = ? AND reference_id = ?", referenceType, referenceID). Where("reference_type = ? AND reference_id = ?", referenceType, referenceID)
First(&transaction).Error // 应用数据权限过滤(使用 shop_id_tag 字段)
query = middleware.ApplyShopTagFilter(ctx, query)
err := query.First(&transaction).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -6,6 +6,7 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store" "github.com/break/junhong_cmp_fiber/internal/store"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -28,7 +29,10 @@ func (s *CommissionRecordStore) Create(ctx context.Context, record *model.Commis
func (s *CommissionRecordStore) GetByID(ctx context.Context, id uint) (*model.CommissionRecord, error) { func (s *CommissionRecordStore) GetByID(ctx context.Context, id uint) (*model.CommissionRecord, error) {
var record model.CommissionRecord var record model.CommissionRecord
if err := s.db.WithContext(ctx).First(&record, id).Error; err != nil { query := s.db.WithContext(ctx).Where("id = ?", id)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.First(&record).Error; err != nil {
return nil, err return nil, err
} }
return &record, nil return &record, nil
@@ -50,6 +54,8 @@ func (s *CommissionRecordStore) ListByShopID(ctx context.Context, opts *store.Qu
var total int64 var total int64
query := s.db.WithContext(ctx).Model(&model.CommissionRecord{}) query := s.db.WithContext(ctx).Model(&model.CommissionRecord{})
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if filters != nil { if filters != nil {
if filters.ShopID > 0 { if filters.ShopID > 0 {
@@ -107,6 +113,8 @@ type CommissionStats struct {
func (s *CommissionRecordStore) GetStats(ctx context.Context, filters *CommissionRecordListFilters) (*CommissionStats, error) { func (s *CommissionRecordStore) GetStats(ctx context.Context, filters *CommissionRecordListFilters) (*CommissionStats, error) {
query := s.db.WithContext(ctx).Model(&model.CommissionRecord{}). query := s.db.WithContext(ctx).Model(&model.CommissionRecord{}).
Where("status = ?", model.CommissionStatusReleased) Where("status = ?", model.CommissionStatusReleased)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if filters != nil { if filters != nil {
if filters.ShopID > 0 { if filters.ShopID > 0 {
@@ -151,6 +159,8 @@ func (s *CommissionRecordStore) GetDailyStats(ctx context.Context, filters *Comm
query := s.db.WithContext(ctx).Model(&model.CommissionRecord{}). query := s.db.WithContext(ctx).Model(&model.CommissionRecord{}).
Where("status = ?", model.CommissionStatusReleased) Where("status = ?", model.CommissionStatusReleased)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if filters != nil { if filters != nil {
if filters.ShopID > 0 { if filters.ShopID > 0 {

View File

@@ -7,6 +7,7 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store" "github.com/break/junhong_cmp_fiber/internal/store"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -29,7 +30,10 @@ func (s *CommissionWithdrawalRequestStore) Create(ctx context.Context, req *mode
func (s *CommissionWithdrawalRequestStore) GetByID(ctx context.Context, id uint) (*model.CommissionWithdrawalRequest, error) { func (s *CommissionWithdrawalRequestStore) GetByID(ctx context.Context, id uint) (*model.CommissionWithdrawalRequest, error) {
var req model.CommissionWithdrawalRequest var req model.CommissionWithdrawalRequest
if err := s.db.WithContext(ctx).First(&req, id).Error; err != nil { query := s.db.WithContext(ctx).Where("id = ?", id)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.First(&req).Error; err != nil {
return nil, err return nil, err
} }
return &req, nil return &req, nil
@@ -52,6 +56,8 @@ func (s *CommissionWithdrawalRequestStore) ListByShopID(ctx context.Context, opt
var total int64 var total int64
query := s.db.WithContext(ctx).Model(&model.CommissionWithdrawalRequest{}) query := s.db.WithContext(ctx).Model(&model.CommissionWithdrawalRequest{})
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if filters != nil { if filters != nil {
if filters.ShopID > 0 { if filters.ShopID > 0 {
@@ -146,6 +152,8 @@ func (s *CommissionWithdrawalRequestStore) List(ctx context.Context, opts *store
var total int64 var total int64
query := s.db.WithContext(ctx).Model(&model.CommissionWithdrawalRequest{}) query := s.db.WithContext(ctx).Model(&model.CommissionWithdrawalRequest{})
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if filters != nil { if filters != nil {
if filters.WithdrawalNo != "" { if filters.WithdrawalNo != "" {

View File

@@ -7,6 +7,7 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store" "github.com/break/junhong_cmp_fiber/internal/store"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -36,7 +37,10 @@ func (s *DeviceStore) CreateBatch(ctx context.Context, devices []*model.Device)
func (s *DeviceStore) GetByID(ctx context.Context, id uint) (*model.Device, error) { func (s *DeviceStore) GetByID(ctx context.Context, id uint) (*model.Device, error) {
var device model.Device var device model.Device
if err := s.db.WithContext(ctx).First(&device, id).Error; err != nil { query := s.db.WithContext(ctx).Where("id = ?", id)
// 应用数据权限过滤NULL shop_id 对代理用户不可见)
query = middleware.ApplyShopFilter(ctx, query)
if err := query.First(&device).Error; err != nil {
return nil, err return nil, err
} }
return &device, nil return &device, nil
@@ -44,7 +48,10 @@ func (s *DeviceStore) GetByID(ctx context.Context, id uint) (*model.Device, erro
func (s *DeviceStore) GetByDeviceNo(ctx context.Context, deviceNo string) (*model.Device, error) { func (s *DeviceStore) GetByDeviceNo(ctx context.Context, deviceNo string) (*model.Device, error) {
var device model.Device var device model.Device
if err := s.db.WithContext(ctx).Where("device_no = ?", deviceNo).First(&device).Error; err != nil { query := s.db.WithContext(ctx).Where("device_no = ?", deviceNo)
// 应用数据权限过滤NULL shop_id 对代理用户不可见)
query = middleware.ApplyShopFilter(ctx, query)
if err := query.First(&device).Error; err != nil {
return nil, err return nil, err
} }
return &device, nil return &device, nil
@@ -55,7 +62,10 @@ func (s *DeviceStore) GetByIDs(ctx context.Context, ids []uint) ([]*model.Device
if len(ids) == 0 { if len(ids) == 0 {
return devices, nil return devices, nil
} }
if err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&devices).Error; err != nil { query := s.db.WithContext(ctx).Where("id IN ?", ids)
// 应用数据权限过滤NULL shop_id 对代理用户不可见)
query = middleware.ApplyShopFilter(ctx, query)
if err := query.Find(&devices).Error; err != nil {
return nil, err return nil, err
} }
return devices, nil return devices, nil
@@ -74,6 +84,8 @@ func (s *DeviceStore) List(ctx context.Context, opts *store.QueryOptions, filter
var total int64 var total int64
query := s.db.WithContext(ctx).Model(&model.Device{}) query := s.db.WithContext(ctx).Model(&model.Device{})
// 应用数据权限过滤NULL shop_id 对代理用户不可见)
query = middleware.ApplyShopFilter(ctx, query)
if deviceNo, ok := filters["device_no"].(string); ok && deviceNo != "" { if deviceNo, ok := filters["device_no"].(string); ok && deviceNo != "" {
query = query.Where("device_no LIKE ?", "%"+deviceNo+"%") query = query.Where("device_no LIKE ?", "%"+deviceNo+"%")
@@ -179,7 +191,10 @@ func (s *DeviceStore) GetByDeviceNos(ctx context.Context, deviceNos []string) ([
if len(deviceNos) == 0 { if len(deviceNos) == 0 {
return devices, nil return devices, nil
} }
if err := s.db.WithContext(ctx).Where("device_no IN ?", deviceNos).Find(&devices).Error; err != nil { query := s.db.WithContext(ctx).Where("device_no IN ?", deviceNos)
// 应用数据权限过滤NULL shop_id 对代理用户不可见)
query = middleware.ApplyShopFilter(ctx, query)
if err := query.Find(&devices).Error; err != nil {
return nil, err return nil, err
} }
return devices, nil return devices, nil
@@ -198,7 +213,10 @@ func (s *DeviceStore) BatchUpdateSeriesID(ctx context.Context, deviceIDs []uint,
// ListBySeriesID 根据套餐系列ID查询设备列表 // ListBySeriesID 根据套餐系列ID查询设备列表
func (s *DeviceStore) ListBySeriesID(ctx context.Context, seriesID uint) ([]*model.Device, error) { func (s *DeviceStore) ListBySeriesID(ctx context.Context, seriesID uint) ([]*model.Device, error) {
var devices []*model.Device var devices []*model.Device
if err := s.db.WithContext(ctx).Where("series_id = ?", seriesID).Find(&devices).Error; err != nil { query := s.db.WithContext(ctx).Where("series_id = ?", seriesID)
// 应用数据权限过滤NULL shop_id 对代理用户不可见)
query = middleware.ApplyShopFilter(ctx, query)
if err := query.Find(&devices).Error; err != nil {
return nil, err return nil, err
} }
return devices, nil return devices, nil

View File

@@ -6,7 +6,6 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
pkgGorm "github.com/break/junhong_cmp_fiber/pkg/gorm"
"github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
@@ -50,9 +49,11 @@ func (s *EnterpriseCardAuthorizationStore) RevokeAuthorizations(ctx context.Cont
func (s *EnterpriseCardAuthorizationStore) GetByEnterpriseAndCard(ctx context.Context, enterpriseID, cardID uint) (*model.EnterpriseCardAuthorization, error) { func (s *EnterpriseCardAuthorizationStore) GetByEnterpriseAndCard(ctx context.Context, enterpriseID, cardID uint) (*model.EnterpriseCardAuthorization, error) {
var auth model.EnterpriseCardAuthorization var auth model.EnterpriseCardAuthorization
err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Where("enterprise_id = ? AND card_id = ?", enterpriseID, cardID). Where("enterprise_id = ? AND card_id = ?", enterpriseID, cardID)
First(&auth).Error // 应用数据权限过滤
query = s.applyEnterpriseAuthFilter(ctx, query)
err := query.First(&auth).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -62,6 +63,8 @@ func (s *EnterpriseCardAuthorizationStore) GetByEnterpriseAndCard(ctx context.Co
func (s *EnterpriseCardAuthorizationStore) ListByEnterprise(ctx context.Context, enterpriseID uint, includeRevoked bool) ([]*model.EnterpriseCardAuthorization, error) { func (s *EnterpriseCardAuthorizationStore) ListByEnterprise(ctx context.Context, enterpriseID uint, includeRevoked bool) ([]*model.EnterpriseCardAuthorization, error) {
var auths []*model.EnterpriseCardAuthorization var auths []*model.EnterpriseCardAuthorization
query := s.db.WithContext(ctx).Where("enterprise_id = ?", enterpriseID) query := s.db.WithContext(ctx).Where("enterprise_id = ?", enterpriseID)
// 应用数据权限过滤
query = s.applyEnterpriseAuthFilter(ctx, query)
if !includeRevoked { if !includeRevoked {
query = query.Where("revoked_at IS NULL") query = query.Where("revoked_at IS NULL")
} }
@@ -77,6 +80,8 @@ func (s *EnterpriseCardAuthorizationStore) ListByCards(ctx context.Context, card
} }
var auths []*model.EnterpriseCardAuthorization var auths []*model.EnterpriseCardAuthorization
query := s.db.WithContext(ctx).Where("card_id IN ?", cardIDs) query := s.db.WithContext(ctx).Where("card_id IN ?", cardIDs)
// 应用数据权限过滤
query = s.applyEnterpriseAuthFilter(ctx, query)
if !includeRevoked { if !includeRevoked {
query = query.Where("revoked_at IS NULL") query = query.Where("revoked_at IS NULL")
} }
@@ -88,17 +93,21 @@ func (s *EnterpriseCardAuthorizationStore) ListByCards(ctx context.Context, card
func (s *EnterpriseCardAuthorizationStore) GetActiveAuthorizedCardIDs(ctx context.Context, enterpriseID uint) ([]uint, error) { func (s *EnterpriseCardAuthorizationStore) GetActiveAuthorizedCardIDs(ctx context.Context, enterpriseID uint) ([]uint, error) {
var cardIDs []uint var cardIDs []uint
err := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}). query := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}).
Where("enterprise_id = ? AND revoked_at IS NULL", enterpriseID). Where("enterprise_id = ? AND revoked_at IS NULL", enterpriseID)
Pluck("card_id", &cardIDs).Error // 应用数据权限过滤
query = s.applyEnterpriseAuthFilter(ctx, query)
err := query.Pluck("card_id", &cardIDs).Error
return cardIDs, err return cardIDs, err
} }
func (s *EnterpriseCardAuthorizationStore) CheckAuthorizationExists(ctx context.Context, enterpriseID, cardID uint) (bool, error) { func (s *EnterpriseCardAuthorizationStore) CheckAuthorizationExists(ctx context.Context, enterpriseID, cardID uint) (bool, error) {
var count int64 var count int64
err := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}). query := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}).
Where("enterprise_id = ? AND card_id = ? AND revoked_at IS NULL", enterpriseID, cardID). Where("enterprise_id = ? AND card_id = ? AND revoked_at IS NULL", enterpriseID, cardID)
Count(&count).Error // 应用数据权限过滤
query = s.applyEnterpriseAuthFilter(ctx, query)
err := query.Count(&count).Error
return count > 0, err return count > 0, err
} }
@@ -115,6 +124,8 @@ type AuthorizationListOptions struct {
func (s *EnterpriseCardAuthorizationStore) ListWithOptions(ctx context.Context, opts AuthorizationListOptions) ([]*model.EnterpriseCardAuthorization, int64, error) { func (s *EnterpriseCardAuthorizationStore) ListWithOptions(ctx context.Context, opts AuthorizationListOptions) ([]*model.EnterpriseCardAuthorization, int64, error) {
var auths []*model.EnterpriseCardAuthorization var auths []*model.EnterpriseCardAuthorization
query := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}) query := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{})
// 应用数据权限过滤
query = s.applyEnterpriseAuthFilter(ctx, query)
if opts.EnterpriseID != nil { if opts.EnterpriseID != nil {
query = query.Where("enterprise_id = ?", *opts.EnterpriseID) query = query.Where("enterprise_id = ?", *opts.EnterpriseID)
@@ -154,9 +165,11 @@ func (s *EnterpriseCardAuthorizationStore) GetActiveAuthsByCardIDs(ctx context.C
return make(map[uint]bool), nil return make(map[uint]bool), nil
} }
var authCardIDs []uint var authCardIDs []uint
err := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}). query := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}).
Where("enterprise_id = ? AND card_id IN ? AND revoked_at IS NULL", enterpriseID, cardIDs). Where("enterprise_id = ? AND card_id IN ? AND revoked_at IS NULL", enterpriseID, cardIDs)
Pluck("card_id", &authCardIDs).Error // 应用数据权限过滤
query = s.applyEnterpriseAuthFilter(ctx, query)
err := query.Pluck("card_id", &authCardIDs).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -186,9 +199,11 @@ func (s *EnterpriseCardAuthorizationStore) BatchUpdateStatus(ctx context.Context
// ListCardIDsByEnterprise 获取企业的有效授权卡ID列表 // ListCardIDsByEnterprise 获取企业的有效授权卡ID列表
func (s *EnterpriseCardAuthorizationStore) ListCardIDsByEnterprise(ctx context.Context, enterpriseID uint) ([]uint, error) { func (s *EnterpriseCardAuthorizationStore) ListCardIDsByEnterprise(ctx context.Context, enterpriseID uint) ([]uint, error) {
var cardIDs []uint var cardIDs []uint
err := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}). query := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}).
Where("enterprise_id = ? AND revoked_at IS NULL", enterpriseID). Where("enterprise_id = ? AND revoked_at IS NULL", enterpriseID)
Pluck("card_id", &cardIDs).Error // 应用数据权限过滤
query = s.applyEnterpriseAuthFilter(ctx, query)
err := query.Pluck("card_id", &cardIDs).Error
return cardIDs, err return cardIDs, err
} }
@@ -233,31 +248,28 @@ func (s *EnterpriseCardAuthorizationStore) ListWithJoin(ctx context.Context, opt
args := []interface{}{} args := []interface{}{}
// 数据权限过滤(原生 SQL 需要手动处理) // 数据权限过滤(原生 SQL 需要手动处理)
// 检查是否跳过数据权限过滤 userType := middleware.GetUserTypeFromContext(ctx)
if skip, ok := ctx.Value(pkgGorm.SkipDataPermissionKey).(bool); !ok || !skip { // 超级管理员和平台用户跳过过滤
userType := middleware.GetUserTypeFromContext(ctx) if userType != constants.UserTypeSuperAdmin && userType != constants.UserTypePlatform {
// 超级管理员和平台用户跳过过滤 if userType == constants.UserTypeAgent {
if userType != constants.UserTypeSuperAdmin && userType != constants.UserTypePlatform { // 代理用户:只能看到自己及下级店铺所拥有企业的授权记录
if userType == constants.UserTypeAgent { shopIDs := middleware.GetSubordinateShopIDs(ctx)
shopID := middleware.GetShopIDFromContext(ctx) if len(shopIDs) == 0 {
if shopID == 0 { // 代理用户没有下级店铺信息,返回空结果
// 代理用户没有 shop_id返回空结果
return []AuthorizationWithJoin{}, 0, nil
}
// 只能看到自己店铺下企业的授权记录(不包含下级店铺)
baseQuery += " AND a.enterprise_id IN (SELECT id FROM tb_enterprise WHERE owner_shop_id = ? AND deleted_at IS NULL)"
args = append(args, shopID)
} else if userType == constants.UserTypeEnterprise {
enterpriseID := middleware.GetEnterpriseIDFromContext(ctx)
if enterpriseID == 0 {
return []AuthorizationWithJoin{}, 0, nil
}
baseQuery += " AND a.enterprise_id = ?"
args = append(args, enterpriseID)
} else {
// 其他用户类型(个人客户等)不应访问授权记录
return []AuthorizationWithJoin{}, 0, nil return []AuthorizationWithJoin{}, 0, nil
} }
baseQuery += " AND a.enterprise_id IN (SELECT id FROM tb_enterprise WHERE owner_shop_id IN (?) AND deleted_at IS NULL)"
args = append(args, shopIDs)
} else if userType == constants.UserTypeEnterprise {
enterpriseID := middleware.GetEnterpriseIDFromContext(ctx)
if enterpriseID == 0 {
return []AuthorizationWithJoin{}, 0, nil
}
baseQuery += " AND a.enterprise_id = ?"
args = append(args, enterpriseID)
} else {
// 其他用户类型(个人客户等)不应访问授权记录
return []AuthorizationWithJoin{}, 0, nil
} }
} }
@@ -338,26 +350,25 @@ func (s *EnterpriseCardAuthorizationStore) GetByIDWithJoin(ctx context.Context,
args := []interface{}{id} args := []interface{}{id}
// 数据权限过滤(原生 SQL 需要手动处理) // 数据权限过滤(原生 SQL 需要手动处理)
if skip, ok := ctx.Value(pkgGorm.SkipDataPermissionKey).(bool); !ok || !skip { userType := middleware.GetUserTypeFromContext(ctx)
userType := middleware.GetUserTypeFromContext(ctx) if userType != constants.UserTypeSuperAdmin && userType != constants.UserTypePlatform {
if userType != constants.UserTypeSuperAdmin && userType != constants.UserTypePlatform { if userType == constants.UserTypeAgent {
if userType == constants.UserTypeAgent { // 代理用户:只能看到自己及下级店铺所拥有企业的授权记录
shopID := middleware.GetShopIDFromContext(ctx) shopIDs := middleware.GetSubordinateShopIDs(ctx)
if shopID == 0 { if len(shopIDs) == 0 {
return nil, gorm.ErrRecordNotFound
}
baseSQL += " AND a.enterprise_id IN (SELECT id FROM tb_enterprise WHERE owner_shop_id = ? AND deleted_at IS NULL)"
args = append(args, shopID)
} else if userType == constants.UserTypeEnterprise {
enterpriseID := middleware.GetEnterpriseIDFromContext(ctx)
if enterpriseID == 0 {
return nil, gorm.ErrRecordNotFound
}
baseSQL += " AND a.enterprise_id = ?"
args = append(args, enterpriseID)
} else {
return nil, gorm.ErrRecordNotFound return nil, gorm.ErrRecordNotFound
} }
baseSQL += " AND a.enterprise_id IN (SELECT id FROM tb_enterprise WHERE owner_shop_id IN (?) AND deleted_at IS NULL)"
args = append(args, shopIDs)
} else if userType == constants.UserTypeEnterprise {
enterpriseID := middleware.GetEnterpriseIDFromContext(ctx)
if enterpriseID == 0 {
return nil, gorm.ErrRecordNotFound
}
baseSQL += " AND a.enterprise_id = ?"
args = append(args, enterpriseID)
} else {
return nil, gorm.ErrRecordNotFound
} }
} }
@@ -401,7 +412,10 @@ func (s *EnterpriseCardAuthorizationStore) UpdateRemarkWithConstraint(ctx contex
func (s *EnterpriseCardAuthorizationStore) GetByID(ctx context.Context, id uint) (*model.EnterpriseCardAuthorization, error) { func (s *EnterpriseCardAuthorizationStore) GetByID(ctx context.Context, id uint) (*model.EnterpriseCardAuthorization, error) {
var auth model.EnterpriseCardAuthorization var auth model.EnterpriseCardAuthorization
err := s.db.WithContext(ctx).Where("id = ?", id).First(&auth).Error query := s.db.WithContext(ctx).Where("id = ?", id)
// 应用数据权限过滤
query = s.applyEnterpriseAuthFilter(ctx, query)
err := query.First(&auth).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -417,3 +431,23 @@ func (s *EnterpriseCardAuthorizationStore) RevokeByDeviceAuthID(ctx context.Cont
"revoked_at": now, "revoked_at": now,
}).Error }).Error
} }
// applyEnterpriseAuthFilter 应用企业卡授权表的数据权限过滤
// 企业用户:只能看到自己企业的授权记录
// 代理用户:只能看到自己及下级店铺所拥有企业的授权记录
// 平台/超管:不过滤
func (s *EnterpriseCardAuthorizationStore) applyEnterpriseAuthFilter(ctx context.Context, query *gorm.DB) *gorm.DB {
// 企业用户过滤
query = middleware.ApplyEnterpriseFilter(ctx, query)
// 代理用户:通过企业的 owner_shop_id 过滤
userType := middleware.GetUserTypeFromContext(ctx)
if userType == constants.UserTypeAgent {
shopIDs := middleware.GetSubordinateShopIDs(ctx)
if shopIDs != nil {
query = query.Where("enterprise_id IN (SELECT id FROM tb_enterprise WHERE owner_shop_id IN ? AND deleted_at IS NULL)", shopIDs)
}
}
return query
}

View File

@@ -5,6 +5,7 @@ import (
"time" "time"
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -45,7 +46,10 @@ func (s *EnterpriseDeviceAuthorizationStore) BatchCreate(ctx context.Context, au
func (s *EnterpriseDeviceAuthorizationStore) GetByID(ctx context.Context, id uint) (*model.EnterpriseDeviceAuthorization, error) { func (s *EnterpriseDeviceAuthorizationStore) GetByID(ctx context.Context, id uint) (*model.EnterpriseDeviceAuthorization, error) {
var auth model.EnterpriseDeviceAuthorization var auth model.EnterpriseDeviceAuthorization
err := s.db.WithContext(ctx).Where("id = ?", id).First(&auth).Error query := s.db.WithContext(ctx).Where("id = ?", id)
// 应用企业数据权限过滤
query = middleware.ApplyEnterpriseFilter(ctx, query)
err := query.First(&auth).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -54,9 +58,11 @@ func (s *EnterpriseDeviceAuthorizationStore) GetByID(ctx context.Context, id uin
func (s *EnterpriseDeviceAuthorizationStore) GetByDeviceID(ctx context.Context, deviceID uint) (*model.EnterpriseDeviceAuthorization, error) { func (s *EnterpriseDeviceAuthorizationStore) GetByDeviceID(ctx context.Context, deviceID uint) (*model.EnterpriseDeviceAuthorization, error) {
var auth model.EnterpriseDeviceAuthorization var auth model.EnterpriseDeviceAuthorization
err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Where("device_id = ? AND revoked_at IS NULL", deviceID). Where("device_id = ? AND revoked_at IS NULL", deviceID)
First(&auth).Error // 应用企业数据权限过滤
query = middleware.ApplyEnterpriseFilter(ctx, query)
err := query.First(&auth).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -66,6 +72,8 @@ func (s *EnterpriseDeviceAuthorizationStore) GetByDeviceID(ctx context.Context,
func (s *EnterpriseDeviceAuthorizationStore) GetByEnterpriseID(ctx context.Context, enterpriseID uint, includeRevoked bool) ([]*model.EnterpriseDeviceAuthorization, error) { func (s *EnterpriseDeviceAuthorizationStore) GetByEnterpriseID(ctx context.Context, enterpriseID uint, includeRevoked bool) ([]*model.EnterpriseDeviceAuthorization, error) {
var auths []*model.EnterpriseDeviceAuthorization var auths []*model.EnterpriseDeviceAuthorization
query := s.db.WithContext(ctx).Where("enterprise_id = ?", enterpriseID) query := s.db.WithContext(ctx).Where("enterprise_id = ?", enterpriseID)
// 应用企业数据权限过滤
query = middleware.ApplyEnterpriseFilter(ctx, query)
if !includeRevoked { if !includeRevoked {
query = query.Where("revoked_at IS NULL") query = query.Where("revoked_at IS NULL")
} }
@@ -87,6 +95,8 @@ func (s *EnterpriseDeviceAuthorizationStore) ListByEnterprise(ctx context.Contex
var total int64 var total int64
query := s.db.WithContext(ctx).Model(&model.EnterpriseDeviceAuthorization{}) query := s.db.WithContext(ctx).Model(&model.EnterpriseDeviceAuthorization{})
// 应用企业数据权限过滤
query = middleware.ApplyEnterpriseFilter(ctx, query)
if opts.EnterpriseID != nil { if opts.EnterpriseID != nil {
query = query.Where("enterprise_id = ?", *opts.EnterpriseID) query = query.Where("enterprise_id = ?", *opts.EnterpriseID)
@@ -134,10 +144,12 @@ func (s *EnterpriseDeviceAuthorizationStore) GetActiveAuthsByDeviceIDs(ctx conte
} }
var auths []model.EnterpriseDeviceAuthorization var auths []model.EnterpriseDeviceAuthorization
err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Select("device_id"). Select("device_id").
Where("enterprise_id = ? AND device_id IN ? AND revoked_at IS NULL", enterpriseID, deviceIDs). Where("enterprise_id = ? AND device_id IN ? AND revoked_at IS NULL", enterpriseID, deviceIDs)
Find(&auths).Error // 应用企业数据权限过滤
query = middleware.ApplyEnterpriseFilter(ctx, query)
err := query.Find(&auths).Error
if err != nil { if err != nil {
return nil, err return nil, err
@@ -152,9 +164,11 @@ func (s *EnterpriseDeviceAuthorizationStore) GetActiveAuthsByDeviceIDs(ctx conte
func (s *EnterpriseDeviceAuthorizationStore) ListDeviceIDsByEnterprise(ctx context.Context, enterpriseID uint) ([]uint, error) { func (s *EnterpriseDeviceAuthorizationStore) ListDeviceIDsByEnterprise(ctx context.Context, enterpriseID uint) ([]uint, error) {
var deviceIDs []uint var deviceIDs []uint
err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Model(&model.EnterpriseDeviceAuthorization{}). Model(&model.EnterpriseDeviceAuthorization{}).
Where("enterprise_id = ? AND revoked_at IS NULL", enterpriseID). Where("enterprise_id = ? AND revoked_at IS NULL", enterpriseID)
Pluck("device_id", &deviceIDs).Error // 应用企业数据权限过滤
query = middleware.ApplyEnterpriseFilter(ctx, query)
err := query.Pluck("device_id", &deviceIDs).Error
return deviceIDs, err return deviceIDs, err
} }

View File

@@ -6,6 +6,7 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store" "github.com/break/junhong_cmp_fiber/internal/store"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -32,7 +33,10 @@ func (s *EnterpriseStore) Create(ctx context.Context, enterprise *model.Enterpri
// GetByID 根据 ID 获取企业 // GetByID 根据 ID 获取企业
func (s *EnterpriseStore) GetByID(ctx context.Context, id uint) (*model.Enterprise, error) { func (s *EnterpriseStore) GetByID(ctx context.Context, id uint) (*model.Enterprise, error) {
var enterprise model.Enterprise var enterprise model.Enterprise
if err := s.db.WithContext(ctx).First(&enterprise, id).Error; err != nil { query := s.db.WithContext(ctx).Where("id = ?", id)
// 应用归属店铺数据权限过滤
query = middleware.ApplyOwnerShopFilter(ctx, query)
if err := query.First(&enterprise).Error; err != nil {
return nil, err return nil, err
} }
return &enterprise, nil return &enterprise, nil
@@ -41,7 +45,10 @@ func (s *EnterpriseStore) GetByID(ctx context.Context, id uint) (*model.Enterpri
// GetByCode 根据企业编号获取企业 // GetByCode 根据企业编号获取企业
func (s *EnterpriseStore) GetByCode(ctx context.Context, code string) (*model.Enterprise, error) { func (s *EnterpriseStore) GetByCode(ctx context.Context, code string) (*model.Enterprise, error) {
var enterprise model.Enterprise var enterprise model.Enterprise
if err := s.db.WithContext(ctx).Where("enterprise_code = ?", code).First(&enterprise).Error; err != nil { query := s.db.WithContext(ctx).Where("enterprise_code = ?", code)
// 应用归属店铺数据权限过滤
query = middleware.ApplyOwnerShopFilter(ctx, query)
if err := query.First(&enterprise).Error; err != nil {
return nil, err return nil, err
} }
return &enterprise, nil return &enterprise, nil
@@ -63,6 +70,8 @@ func (s *EnterpriseStore) List(ctx context.Context, opts *store.QueryOptions, fi
var total int64 var total int64
query := s.db.WithContext(ctx).Model(&model.Enterprise{}) query := s.db.WithContext(ctx).Model(&model.Enterprise{})
// 应用归属店铺数据权限过滤
query = middleware.ApplyOwnerShopFilter(ctx, query)
// 应用过滤条件 // 应用过滤条件
if enterpriseName, ok := filters["enterprise_name"].(string); ok && enterpriseName != "" { if enterpriseName, ok := filters["enterprise_name"].(string); ok && enterpriseName != "" {
@@ -111,7 +120,10 @@ func (s *EnterpriseStore) List(ctx context.Context, opts *store.QueryOptions, fi
// GetByOwnerShopID 根据归属店铺 ID 查询企业列表 // GetByOwnerShopID 根据归属店铺 ID 查询企业列表
func (s *EnterpriseStore) GetByOwnerShopID(ctx context.Context, ownerShopID uint) ([]*model.Enterprise, error) { func (s *EnterpriseStore) GetByOwnerShopID(ctx context.Context, ownerShopID uint) ([]*model.Enterprise, error) {
var enterprises []*model.Enterprise var enterprises []*model.Enterprise
if err := s.db.WithContext(ctx).Where("owner_shop_id = ?", ownerShopID).Find(&enterprises).Error; err != nil { query := s.db.WithContext(ctx).Where("owner_shop_id = ?", ownerShopID)
// 应用归属店铺数据权限过滤
query = middleware.ApplyOwnerShopFilter(ctx, query)
if err := query.Find(&enterprises).Error; err != nil {
return nil, err return nil, err
} }
return enterprises, nil return enterprises, nil
@@ -120,7 +132,10 @@ func (s *EnterpriseStore) GetByOwnerShopID(ctx context.Context, ownerShopID uint
// GetPlatformEnterprises 获取平台直属企业列表owner_shop_id 为 NULL // GetPlatformEnterprises 获取平台直属企业列表owner_shop_id 为 NULL
func (s *EnterpriseStore) GetPlatformEnterprises(ctx context.Context) ([]*model.Enterprise, error) { func (s *EnterpriseStore) GetPlatformEnterprises(ctx context.Context) ([]*model.Enterprise, error) {
var enterprises []*model.Enterprise var enterprises []*model.Enterprise
if err := s.db.WithContext(ctx).Where("owner_shop_id IS NULL").Find(&enterprises).Error; err != nil { query := s.db.WithContext(ctx).Where("owner_shop_id IS NULL")
// 应用归属店铺数据权限过滤(代理用户无法看到平台直属企业)
query = middleware.ApplyOwnerShopFilter(ctx, query)
if err := query.Find(&enterprises).Error; err != nil {
return nil, err return nil, err
} }
return enterprises, nil return enterprises, nil
@@ -132,7 +147,10 @@ func (s *EnterpriseStore) GetByIDs(ctx context.Context, ids []uint) ([]*model.En
return []*model.Enterprise{}, nil return []*model.Enterprise{}, nil
} }
var enterprises []*model.Enterprise var enterprises []*model.Enterprise
if err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&enterprises).Error; err != nil { query := s.db.WithContext(ctx).Where("id IN ?", ids)
// 应用归属店铺数据权限过滤
query = middleware.ApplyOwnerShopFilter(ctx, query)
if err := query.Find(&enterprises).Error; err != nil {
return nil, err return nil, err
} }
return enterprises, nil return enterprises, nil

View File

@@ -8,6 +8,7 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store" "github.com/break/junhong_cmp_fiber/internal/store"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -30,7 +31,10 @@ func (s *IotCardImportTaskStore) Create(ctx context.Context, task *model.IotCard
func (s *IotCardImportTaskStore) GetByID(ctx context.Context, id uint) (*model.IotCardImportTask, error) { func (s *IotCardImportTaskStore) GetByID(ctx context.Context, id uint) (*model.IotCardImportTask, error) {
var task model.IotCardImportTask var task model.IotCardImportTask
if err := s.db.WithContext(ctx).First(&task, id).Error; err != nil { query := s.db.WithContext(ctx).Where("id = ?", id)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.First(&task).Error; err != nil {
return nil, err return nil, err
} }
return &task, nil return &task, nil
@@ -38,7 +42,10 @@ func (s *IotCardImportTaskStore) GetByID(ctx context.Context, id uint) (*model.I
func (s *IotCardImportTaskStore) GetByTaskNo(ctx context.Context, taskNo string) (*model.IotCardImportTask, error) { func (s *IotCardImportTaskStore) GetByTaskNo(ctx context.Context, taskNo string) (*model.IotCardImportTask, error) {
var task model.IotCardImportTask var task model.IotCardImportTask
if err := s.db.WithContext(ctx).Where("task_no = ?", taskNo).First(&task).Error; err != nil { query := s.db.WithContext(ctx).Where("task_no = ?", taskNo)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.First(&task).Error; err != nil {
return nil, err return nil, err
} }
return &task, nil return &task, nil
@@ -82,6 +89,8 @@ func (s *IotCardImportTaskStore) List(ctx context.Context, opts *store.QueryOpti
var total int64 var total int64
query := s.db.WithContext(ctx).Model(&model.IotCardImportTask{}) query := s.db.WithContext(ctx).Model(&model.IotCardImportTask{})
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if status, ok := filters["status"].(int); ok && status > 0 { if status, ok := filters["status"].(int); ok && status > 0 {
query = query.Where("status = ?", status) query = query.Where("status = ?", status)

View File

@@ -11,7 +11,6 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store" "github.com/break/junhong_cmp_fiber/internal/store"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm"
"github.com/break/junhong_cmp_fiber/pkg/logger" "github.com/break/junhong_cmp_fiber/pkg/logger"
"github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
@@ -46,7 +45,10 @@ func (s *IotCardStore) CreateBatch(ctx context.Context, cards []*model.IotCard)
func (s *IotCardStore) GetByID(ctx context.Context, id uint) (*model.IotCard, error) { func (s *IotCardStore) GetByID(ctx context.Context, id uint) (*model.IotCard, error) {
var card model.IotCard var card model.IotCard
if err := s.db.WithContext(ctx).First(&card, id).Error; err != nil { query := s.db.WithContext(ctx).Where("id = ?", id)
// 应用数据权限过滤NULL shop_id 对代理用户不可见)
query = middleware.ApplyShopFilter(ctx, query)
if err := query.First(&card).Error; err != nil {
return nil, err return nil, err
} }
return &card, nil return &card, nil
@@ -54,7 +56,10 @@ func (s *IotCardStore) GetByID(ctx context.Context, id uint) (*model.IotCard, er
func (s *IotCardStore) GetByICCID(ctx context.Context, iccid string) (*model.IotCard, error) { func (s *IotCardStore) GetByICCID(ctx context.Context, iccid string) (*model.IotCard, error) {
var card model.IotCard var card model.IotCard
if err := s.db.WithContext(ctx).Where("iccid = ?", iccid).First(&card).Error; err != nil { query := s.db.WithContext(ctx).Where("iccid = ?", iccid)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.First(&card).Error; err != nil {
return nil, err return nil, err
} }
return &card, nil return &card, nil
@@ -65,7 +70,10 @@ func (s *IotCardStore) GetByIDs(ctx context.Context, ids []uint) ([]*model.IotCa
return []*model.IotCard{}, nil return []*model.IotCard{}, nil
} }
var cards []*model.IotCard var cards []*model.IotCard
if err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&cards).Error; err != nil { query := s.db.WithContext(ctx).Where("id IN ?", ids)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.Find(&cards).Error; err != nil {
return nil, err return nil, err
} }
return cards, nil return cards, nil
@@ -111,13 +119,15 @@ func (s *IotCardStore) List(ctx context.Context, opts *store.QueryOptions, filte
var total int64 var total int64
query := s.db.WithContext(ctx).Model(&model.IotCard{}) query := s.db.WithContext(ctx).Model(&model.IotCard{})
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
// 企业用户特殊处理:只能看到授权给自己的卡 // 企业用户特殊处理:只能看到授权给自己的卡
// 子查询跳过数据权限过滤,权限已由外层查询的 GORM callback 保证 // 子查询跳过数据权限过滤,权限已由外层查询的 GORM callback 保证
skipCtx := pkggorm.SkipDataPermission(ctx) // 子查询无需数据权限过滤(在不同表上执行)
if enterpriseID, ok := filters["authorized_enterprise_id"].(uint); ok && enterpriseID > 0 { if enterpriseID, ok := filters["authorized_enterprise_id"].(uint); ok && enterpriseID > 0 {
query = query.Where("id IN (?)", query = query.Where("id IN (?)",
s.db.WithContext(skipCtx).Table("tb_enterprise_card_authorization"). s.db.WithContext(ctx).Table("tb_enterprise_card_authorization").
Select("card_id"). Select("card_id").
Where("enterprise_id = ? AND revoked_at IS NULL AND deleted_at IS NULL", enterpriseID)) Where("enterprise_id = ? AND revoked_at IS NULL AND deleted_at IS NULL", enterpriseID))
} }
@@ -143,7 +153,7 @@ func (s *IotCardStore) List(ctx context.Context, opts *store.QueryOptions, filte
} }
if packageID, ok := filters["package_id"].(uint); ok && packageID > 0 { if packageID, ok := filters["package_id"].(uint); ok && packageID > 0 {
query = query.Where("id IN (?)", query = query.Where("id IN (?)",
s.db.WithContext(skipCtx).Table("tb_package_usage"). s.db.WithContext(ctx).Table("tb_package_usage").
Select("iot_card_id"). Select("iot_card_id").
Where("package_id = ? AND deleted_at IS NULL", packageID)) Where("package_id = ? AND deleted_at IS NULL", packageID))
} }
@@ -249,6 +259,8 @@ func (s *IotCardStore) listStandaloneTwoPhase(ctx context.Context, opts *store.Q
query := s.db.WithContext(ctx).Model(&model.IotCard{}). query := s.db.WithContext(ctx).Model(&model.IotCard{}).
Where("is_standalone = true") Where("is_standalone = true")
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
query = s.applyStandaloneFilters(ctx, query, filters) query = s.applyStandaloneFilters(ctx, query, filters)
if cachedTotal, ok := s.getCachedCount(ctx, "standalone", filters); ok { if cachedTotal, ok := s.getCachedCount(ctx, "standalone", filters); ok {
@@ -309,6 +321,8 @@ func (s *IotCardStore) listStandaloneDefault(ctx context.Context, opts *store.Qu
query := s.db.WithContext(ctx).Model(&model.IotCard{}). query := s.db.WithContext(ctx).Model(&model.IotCard{}).
Where("is_standalone = true") Where("is_standalone = true")
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
query = s.applyStandaloneFilters(ctx, query, filters) query = s.applyStandaloneFilters(ctx, query, filters)
if cachedTotal, ok := s.getCachedCount(ctx, "standalone", filters); ok { if cachedTotal, ok := s.getCachedCount(ctx, "standalone", filters); ok {
@@ -339,7 +353,7 @@ func (s *IotCardStore) listStandaloneDefault(ctx context.Context, opts *store.Qu
// 将 shop_id IN (...) 拆分为 per-shop 独立查询,每个查询走 Index Scan // 将 shop_id IN (...) 拆分为 per-shop 独立查询,每个查询走 Index Scan
// 然后在应用层归并排序,避免 PG 对多值 IN + ORDER BY 选择全表扫描 // 然后在应用层归并排序,避免 PG 对多值 IN + ORDER BY 选择全表扫描
func (s *IotCardStore) listStandaloneParallel(ctx context.Context, opts *store.QueryOptions, filters map[string]any, shopIDs []uint) ([]*model.IotCard, int64, error) { func (s *IotCardStore) listStandaloneParallel(ctx context.Context, opts *store.QueryOptions, filters map[string]any, shopIDs []uint) ([]*model.IotCard, int64, error) {
skipCtx := pkggorm.SkipDataPermission(ctx) // 子查询无需数据权限过滤(在不同表上执行)
fetchLimit := (opts.Page-1)*opts.PageSize + opts.PageSize fetchLimit := (opts.Page-1)*opts.PageSize + opts.PageSize
@@ -366,9 +380,9 @@ func (s *IotCardStore) listStandaloneParallel(ctx context.Context, opts *store.Q
go func(idx int, sid uint) { go func(idx int, sid uint) {
defer wg.Done() defer wg.Done()
q := s.db.WithContext(skipCtx).Model(&model.IotCard{}). q := s.db.WithContext(ctx).Model(&model.IotCard{}).
Where("is_standalone = true AND deleted_at IS NULL AND shop_id = ?", sid) Where("is_standalone = true AND deleted_at IS NULL AND shop_id = ?", sid)
q = s.applyStandaloneFilters(skipCtx, q, filters) q = s.applyStandaloneFilters(ctx, q, filters)
var cards []*model.IotCard var cards []*model.IotCard
if err := q.Select(standaloneListColumns). if err := q.Select(standaloneListColumns).
@@ -381,9 +395,9 @@ func (s *IotCardStore) listStandaloneParallel(ctx context.Context, opts *store.Q
var count int64 var count int64
if !hasCachedTotal { if !hasCachedTotal {
countQ := s.db.WithContext(skipCtx).Model(&model.IotCard{}). countQ := s.db.WithContext(ctx).Model(&model.IotCard{}).
Where("is_standalone = true AND deleted_at IS NULL AND shop_id = ?", sid) Where("is_standalone = true AND deleted_at IS NULL AND shop_id = ?", sid)
countQ = s.applyStandaloneFilters(skipCtx, countQ, filters) countQ = s.applyStandaloneFilters(ctx, countQ, filters)
if err := countQ.Count(&count).Error; err != nil { if err := countQ.Count(&count).Error; err != nil {
results[idx] = shopResult{err: err} results[idx] = shopResult{err: err}
return return
@@ -455,7 +469,7 @@ type cardIDWithTime struct {
// 归并排序后取目标页的 20 个 ID // 归并排序后取目标页的 20 个 ID
// Phase 2: SELECT 完整列 WHERE id IN (20 IDs)PK 精确回表) // Phase 2: SELECT 完整列 WHERE id IN (20 IDs)PK 精确回表)
func (s *IotCardStore) listStandaloneParallelTwoPhase(ctx context.Context, opts *store.QueryOptions, filters map[string]any, shopIDs []uint) ([]*model.IotCard, int64, error) { func (s *IotCardStore) listStandaloneParallelTwoPhase(ctx context.Context, opts *store.QueryOptions, filters map[string]any, shopIDs []uint) ([]*model.IotCard, int64, error) {
skipCtx := pkggorm.SkipDataPermission(ctx) // 子查询无需数据权限过滤(在不同表上执行)
fetchLimit := (opts.Page-1)*opts.PageSize + opts.PageSize fetchLimit := (opts.Page-1)*opts.PageSize + opts.PageSize
@@ -476,9 +490,9 @@ func (s *IotCardStore) listStandaloneParallelTwoPhase(ctx context.Context, opts
go func(idx int, sid uint) { go func(idx int, sid uint) {
defer wg.Done() defer wg.Done()
q := s.db.WithContext(skipCtx).Model(&model.IotCard{}). q := s.db.WithContext(ctx).Model(&model.IotCard{}).
Where("is_standalone = true AND deleted_at IS NULL AND shop_id = ?", sid) Where("is_standalone = true AND deleted_at IS NULL AND shop_id = ?", sid)
q = s.applyStandaloneFilters(skipCtx, q, filters) q = s.applyStandaloneFilters(ctx, q, filters)
var ids []cardIDWithTime var ids []cardIDWithTime
if err := q.Select("id, created_at"). if err := q.Select("id, created_at").
@@ -491,9 +505,9 @@ func (s *IotCardStore) listStandaloneParallelTwoPhase(ctx context.Context, opts
var count int64 var count int64
if !hasCachedTotal { if !hasCachedTotal {
countQ := s.db.WithContext(skipCtx).Model(&model.IotCard{}). countQ := s.db.WithContext(ctx).Model(&model.IotCard{}).
Where("is_standalone = true AND deleted_at IS NULL AND shop_id = ?", sid) Where("is_standalone = true AND deleted_at IS NULL AND shop_id = ?", sid)
countQ = s.applyStandaloneFilters(skipCtx, countQ, filters) countQ = s.applyStandaloneFilters(ctx, countQ, filters)
if err := countQ.Count(&count).Error; err != nil { if err := countQ.Count(&count).Error; err != nil {
results[idx] = shopResult{err: err} results[idx] = shopResult{err: err}
return return
@@ -553,7 +567,7 @@ func (s *IotCardStore) listStandaloneParallelTwoPhase(ctx context.Context, opts
// Phase 2: 用 ID 精确回表获取完整数据PK Index Scan仅 20 行) // Phase 2: 用 ID 精确回表获取完整数据PK Index Scan仅 20 行)
var cards []*model.IotCard var cards []*model.IotCard
if err := s.db.WithContext(skipCtx).Model(&model.IotCard{}). if err := s.db.WithContext(ctx).Model(&model.IotCard{}).
Select(standaloneListColumns). Select(standaloneListColumns).
Where("id IN ?", pageIDs). Where("id IN ?", pageIDs).
Find(&cards).Error; err != nil { Find(&cards).Error; err != nil {
@@ -584,7 +598,7 @@ func (s *IotCardStore) listStandaloneParallelTwoPhase(ctx context.Context, opts
// 注意:不包含 is_standalone、shop_id、deleted_at 条件(由调用方控制) // 注意:不包含 is_standalone、shop_id、deleted_at 条件(由调用方控制)
// 也不包含 subordinate_shop_ids仅用于路由选择不作为查询条件 // 也不包含 subordinate_shop_ids仅用于路由选择不作为查询条件
func (s *IotCardStore) applyStandaloneFilters(ctx context.Context, query *gorm.DB, filters map[string]any) *gorm.DB { func (s *IotCardStore) applyStandaloneFilters(ctx context.Context, query *gorm.DB, filters map[string]any) *gorm.DB {
skipCtx := pkggorm.SkipDataPermission(ctx) // 子查询无需数据权限过滤(在不同表上执行)
if status, ok := filters["status"].(int); ok && status > 0 { if status, ok := filters["status"].(int); ok && status > 0 {
query = query.Where("status = ?", status) query = query.Where("status = ?", status)
@@ -607,7 +621,7 @@ func (s *IotCardStore) applyStandaloneFilters(ctx context.Context, query *gorm.D
} }
if packageID, ok := filters["package_id"].(uint); ok && packageID > 0 { if packageID, ok := filters["package_id"].(uint); ok && packageID > 0 {
query = query.Where("id IN (?)", query = query.Where("id IN (?)",
s.db.WithContext(skipCtx).Table("tb_package_usage"). s.db.WithContext(ctx).Table("tb_package_usage").
Select("iot_card_id"). Select("iot_card_id").
Where("package_id = ? AND deleted_at IS NULL", packageID)) Where("package_id = ? AND deleted_at IS NULL", packageID))
} }
@@ -627,12 +641,12 @@ func (s *IotCardStore) applyStandaloneFilters(ctx context.Context, query *gorm.D
if isReplaced, ok := filters["is_replaced"].(bool); ok { if isReplaced, ok := filters["is_replaced"].(bool); ok {
if isReplaced { if isReplaced {
query = query.Where("id IN (?)", query = query.Where("id IN (?)",
s.db.WithContext(skipCtx).Table("tb_card_replacement_record"). s.db.WithContext(ctx).Table("tb_card_replacement_record").
Select("old_iot_card_id"). Select("old_iot_card_id").
Where("deleted_at IS NULL")) Where("deleted_at IS NULL"))
} else { } else {
query = query.Where("id NOT IN (?)", query = query.Where("id NOT IN (?)",
s.db.WithContext(skipCtx).Table("tb_card_replacement_record"). s.db.WithContext(ctx).Table("tb_card_replacement_record").
Select("old_iot_card_id"). Select("old_iot_card_id").
Where("deleted_at IS NULL")) Where("deleted_at IS NULL"))
} }
@@ -649,7 +663,10 @@ func (s *IotCardStore) GetByICCIDs(ctx context.Context, iccids []string) ([]*mod
return nil, nil return nil, nil
} }
var cards []*model.IotCard var cards []*model.IotCard
if err := s.db.WithContext(ctx).Where("iccid IN ?", iccids).Find(&cards).Error; err != nil { query := s.db.WithContext(ctx).Where("iccid IN ?", iccids)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.Find(&cards).Error; err != nil {
return nil, err return nil, err
} }
return cards, nil return cards, nil
@@ -659,6 +676,8 @@ func (s *IotCardStore) GetStandaloneByICCIDRange(ctx context.Context, iccidStart
query := s.db.WithContext(ctx).Model(&model.IotCard{}). query := s.db.WithContext(ctx).Model(&model.IotCard{}).
Where("is_standalone = true"). Where("is_standalone = true").
Where("iccid >= ? AND iccid <= ?", iccidStart, iccidEnd) Where("iccid >= ? AND iccid <= ?", iccidStart, iccidEnd)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if shopID == nil { if shopID == nil {
query = query.Where("shop_id IS NULL") query = query.Where("shop_id IS NULL")
@@ -676,11 +695,13 @@ func (s *IotCardStore) GetStandaloneByICCIDRange(ctx context.Context, iccidStart
// GetDistributedStandaloneByICCIDRange 根据号段范围查询已分配给店铺的单卡(用于回收) // GetDistributedStandaloneByICCIDRange 根据号段范围查询已分配给店铺的单卡(用于回收)
func (s *IotCardStore) GetDistributedStandaloneByICCIDRange(ctx context.Context, iccidStart, iccidEnd string) ([]*model.IotCard, error) { func (s *IotCardStore) GetDistributedStandaloneByICCIDRange(ctx context.Context, iccidStart, iccidEnd string) ([]*model.IotCard, error) {
var cards []*model.IotCard var cards []*model.IotCard
if err := s.db.WithContext(ctx).Model(&model.IotCard{}). query := s.db.WithContext(ctx).Model(&model.IotCard{}).
Where("is_standalone = true"). Where("is_standalone = true").
Where("shop_id IS NOT NULL"). Where("shop_id IS NOT NULL").
Where("iccid >= ? AND iccid <= ?", iccidStart, iccidEnd). Where("iccid >= ? AND iccid <= ?", iccidStart, iccidEnd)
Find(&cards).Error; err != nil { // 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.Find(&cards).Error; err != nil {
return nil, err return nil, err
} }
return cards, nil return cards, nil
@@ -689,6 +710,8 @@ func (s *IotCardStore) GetDistributedStandaloneByICCIDRange(ctx context.Context,
func (s *IotCardStore) GetStandaloneByFilters(ctx context.Context, filters map[string]any, shopID *uint) ([]*model.IotCard, error) { func (s *IotCardStore) GetStandaloneByFilters(ctx context.Context, filters map[string]any, shopID *uint) ([]*model.IotCard, error) {
query := s.db.WithContext(ctx).Model(&model.IotCard{}). query := s.db.WithContext(ctx).Model(&model.IotCard{}).
Where("is_standalone = true") Where("is_standalone = true")
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if shopID == nil { if shopID == nil {
query = query.Where("shop_id IS NULL") query = query.Where("shop_id IS NULL")
@@ -718,6 +741,8 @@ func (s *IotCardStore) GetDistributedStandaloneByFilters(ctx context.Context, fi
query := s.db.WithContext(ctx).Model(&model.IotCard{}). query := s.db.WithContext(ctx).Model(&model.IotCard{}).
Where("is_standalone = true"). Where("is_standalone = true").
Where("shop_id IS NOT NULL") Where("shop_id IS NOT NULL")
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if carrierID, ok := filters["carrier_id"].(uint); ok && carrierID > 0 { if carrierID, ok := filters["carrier_id"].(uint); ok && carrierID > 0 {
query = query.Where("carrier_id = ?", carrierID) query = query.Where("carrier_id = ?", carrierID)
@@ -764,10 +789,10 @@ func (s *IotCardStore) GetByIDsWithEnterpriseFilter(ctx context.Context, cardIDs
query := s.db.WithContext(ctx).Model(&model.IotCard{}) query := s.db.WithContext(ctx).Model(&model.IotCard{})
if enterpriseID != nil && *enterpriseID > 0 { if enterpriseID != nil && *enterpriseID > 0 {
skipCtx := pkggorm.SkipDataPermission(ctx) // 子查询无需数据权限过滤(在不同表上执行)
query = query.Where("id IN (?) AND id IN (?)", query = query.Where("id IN (?) AND id IN (?)",
cardIDs, cardIDs,
s.db.WithContext(skipCtx).Table("tb_enterprise_card_authorization"). s.db.WithContext(ctx).Table("tb_enterprise_card_authorization").
Select("card_id"). Select("card_id").
Where("enterprise_id = ? AND revoked_at IS NULL AND deleted_at IS NULL", *enterpriseID)) Where("enterprise_id = ? AND revoked_at IS NULL AND deleted_at IS NULL", *enterpriseID))
} else { } else {
@@ -796,7 +821,10 @@ func (s *IotCardStore) BatchUpdateSeriesID(ctx context.Context, cardIDs []uint,
// 用于查询某个套餐系列下的所有卡 // 用于查询某个套餐系列下的所有卡
func (s *IotCardStore) ListBySeriesID(ctx context.Context, seriesID uint) ([]*model.IotCard, error) { func (s *IotCardStore) ListBySeriesID(ctx context.Context, seriesID uint) ([]*model.IotCard, error) {
var cards []*model.IotCard var cards []*model.IotCard
if err := s.db.WithContext(ctx).Where("series_id = ?", seriesID).Find(&cards).Error; err != nil { query := s.db.WithContext(ctx).Where("series_id = ?", seriesID)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.Find(&cards).Error; err != nil {
return nil, err return nil, err
} }
return cards, nil return cards, nil

View File

@@ -8,6 +8,7 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store" "github.com/break/junhong_cmp_fiber/internal/store"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -43,7 +44,10 @@ func (s *OrderStore) Create(ctx context.Context, order *model.Order, items []*mo
func (s *OrderStore) GetByID(ctx context.Context, id uint) (*model.Order, error) { func (s *OrderStore) GetByID(ctx context.Context, id uint) (*model.Order, error) {
var order model.Order var order model.Order
if err := s.db.WithContext(ctx).First(&order, id).Error; err != nil { query := s.db.WithContext(ctx).Where("id = ?", id)
// 应用数据权限过滤(使用 seller_shop_id 字段)
query = middleware.ApplySellerShopFilter(ctx, query)
if err := query.First(&order).Error; err != nil {
return nil, err return nil, err
} }
return &order, nil return &order, nil
@@ -51,7 +55,10 @@ func (s *OrderStore) GetByID(ctx context.Context, id uint) (*model.Order, error)
func (s *OrderStore) GetByIDWithItems(ctx context.Context, id uint) (*model.Order, []*model.OrderItem, error) { func (s *OrderStore) GetByIDWithItems(ctx context.Context, id uint) (*model.Order, []*model.OrderItem, error) {
var order model.Order var order model.Order
if err := s.db.WithContext(ctx).First(&order, id).Error; err != nil { query := s.db.WithContext(ctx).Where("id = ?", id)
// 应用数据权限过滤(使用 seller_shop_id 字段)
query = middleware.ApplySellerShopFilter(ctx, query)
if err := query.First(&order).Error; err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -65,7 +72,10 @@ func (s *OrderStore) GetByIDWithItems(ctx context.Context, id uint) (*model.Orde
func (s *OrderStore) GetByOrderNo(ctx context.Context, orderNo string) (*model.Order, error) { func (s *OrderStore) GetByOrderNo(ctx context.Context, orderNo string) (*model.Order, error) {
var order model.Order var order model.Order
if err := s.db.WithContext(ctx).Where("order_no = ?", orderNo).First(&order).Error; err != nil { query := s.db.WithContext(ctx).Where("order_no = ?", orderNo)
// 应用数据权限过滤(使用 seller_shop_id 字段)
query = middleware.ApplySellerShopFilter(ctx, query)
if err := query.First(&order).Error; err != nil {
return nil, err return nil, err
} }
return &order, nil return &order, nil
@@ -80,6 +90,8 @@ func (s *OrderStore) List(ctx context.Context, opts *store.QueryOptions, filters
var total int64 var total int64
query := s.db.WithContext(ctx).Model(&model.Order{}) query := s.db.WithContext(ctx).Model(&model.Order{})
// 应用数据权限过滤(使用 seller_shop_id 字段)
query = middleware.ApplySellerShopFilter(ctx, query)
if v, ok := filters["payment_status"]; ok { if v, ok := filters["payment_status"]; ok {
query = query.Where("payment_status = ?", v) query = query.Where("payment_status = ?", v)

View File

@@ -5,6 +5,7 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store" "github.com/break/junhong_cmp_fiber/internal/store"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -22,7 +23,10 @@ func (s *ShopPackageAllocationStore) Create(ctx context.Context, allocation *mod
func (s *ShopPackageAllocationStore) GetByID(ctx context.Context, id uint) (*model.ShopPackageAllocation, error) { func (s *ShopPackageAllocationStore) GetByID(ctx context.Context, id uint) (*model.ShopPackageAllocation, error) {
var allocation model.ShopPackageAllocation var allocation model.ShopPackageAllocation
if err := s.db.WithContext(ctx).First(&allocation, id).Error; err != nil { query := s.db.WithContext(ctx).Where("id = ?", id)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.First(&allocation).Error; err != nil {
return nil, err return nil, err
} }
return &allocation, nil return &allocation, nil
@@ -30,7 +34,10 @@ func (s *ShopPackageAllocationStore) GetByID(ctx context.Context, id uint) (*mod
func (s *ShopPackageAllocationStore) GetByShopAndPackage(ctx context.Context, shopID, packageID uint) (*model.ShopPackageAllocation, error) { func (s *ShopPackageAllocationStore) GetByShopAndPackage(ctx context.Context, shopID, packageID uint) (*model.ShopPackageAllocation, error) {
var allocation model.ShopPackageAllocation var allocation model.ShopPackageAllocation
if err := s.db.WithContext(ctx).Where("shop_id = ? AND package_id = ?", shopID, packageID).First(&allocation).Error; err != nil { query := s.db.WithContext(ctx).Where("shop_id = ? AND package_id = ?", shopID, packageID)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.First(&allocation).Error; err != nil {
return nil, err return nil, err
} }
return &allocation, nil return &allocation, nil
@@ -49,6 +56,8 @@ func (s *ShopPackageAllocationStore) List(ctx context.Context, opts *store.Query
var total int64 var total int64
query := s.db.WithContext(ctx).Model(&model.ShopPackageAllocation{}) query := s.db.WithContext(ctx).Model(&model.ShopPackageAllocation{})
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if shopID, ok := filters["shop_id"].(uint); ok && shopID > 0 { if shopID, ok := filters["shop_id"].(uint); ok && shopID > 0 {
query = query.Where("shop_id = ?", shopID) query = query.Where("shop_id = ?", shopID)
@@ -99,7 +108,10 @@ func (s *ShopPackageAllocationStore) UpdateStatus(ctx context.Context, id uint,
func (s *ShopPackageAllocationStore) GetByShopID(ctx context.Context, shopID uint) ([]*model.ShopPackageAllocation, error) { func (s *ShopPackageAllocationStore) GetByShopID(ctx context.Context, shopID uint) ([]*model.ShopPackageAllocation, error) {
var allocations []*model.ShopPackageAllocation var allocations []*model.ShopPackageAllocation
if err := s.db.WithContext(ctx).Where("shop_id = ? AND status = 1", shopID).Find(&allocations).Error; err != nil { query := s.db.WithContext(ctx).Where("shop_id = ? AND status = 1", shopID)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.Find(&allocations).Error; err != nil {
return nil, err return nil, err
} }
return allocations, nil return allocations, nil
@@ -107,9 +119,11 @@ func (s *ShopPackageAllocationStore) GetByShopID(ctx context.Context, shopID uin
func (s *ShopPackageAllocationStore) GetByShopAndPackages(ctx context.Context, shopID uint, packageIDs []uint) ([]*model.ShopPackageAllocation, error) { func (s *ShopPackageAllocationStore) GetByShopAndPackages(ctx context.Context, shopID uint, packageIDs []uint) ([]*model.ShopPackageAllocation, error) {
var allocations []*model.ShopPackageAllocation var allocations []*model.ShopPackageAllocation
if err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Where("shop_id = ? AND package_id IN ? AND status = 1", shopID, packageIDs). Where("shop_id = ? AND package_id IN ? AND status = 1", shopID, packageIDs)
Find(&allocations).Error; err != nil { // 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.Find(&allocations).Error; err != nil {
return nil, err return nil, err
} }
return allocations, nil return allocations, nil
@@ -117,9 +131,11 @@ func (s *ShopPackageAllocationStore) GetByShopAndPackages(ctx context.Context, s
func (s *ShopPackageAllocationStore) GetBySeriesAllocationID(ctx context.Context, seriesAllocationID uint) ([]*model.ShopPackageAllocation, error) { func (s *ShopPackageAllocationStore) GetBySeriesAllocationID(ctx context.Context, seriesAllocationID uint) ([]*model.ShopPackageAllocation, error) {
var allocations []*model.ShopPackageAllocation var allocations []*model.ShopPackageAllocation
if err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Where("series_allocation_id = ? AND status = 1", seriesAllocationID). Where("series_allocation_id = ? AND status = 1", seriesAllocationID)
Find(&allocations).Error; err != nil { // 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.Find(&allocations).Error; err != nil {
return nil, err return nil, err
} }
return allocations, nil return allocations, nil

View File

@@ -5,6 +5,7 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -62,9 +63,10 @@ func (s *ShopRoleStore) DeleteByShopID(ctx context.Context, shopID uint) error {
func (s *ShopRoleStore) GetByShopID(ctx context.Context, shopID uint) ([]*model.ShopRole, error) { func (s *ShopRoleStore) GetByShopID(ctx context.Context, shopID uint) ([]*model.ShopRole, error) {
var srs []*model.ShopRole var srs []*model.ShopRole
if err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).Where("shop_id = ?", shopID)
Where("shop_id = ?", shopID). // 应用数据权限过滤
Find(&srs).Error; err != nil { query = middleware.ApplyShopFilter(ctx, query)
if err := query.Find(&srs).Error; err != nil {
return nil, err return nil, err
} }
return srs, nil return srs, nil
@@ -72,10 +74,12 @@ func (s *ShopRoleStore) GetByShopID(ctx context.Context, shopID uint) ([]*model.
func (s *ShopRoleStore) GetRoleIDsByShopID(ctx context.Context, shopID uint) ([]uint, error) { func (s *ShopRoleStore) GetRoleIDsByShopID(ctx context.Context, shopID uint) ([]uint, error) {
var roleIDs []uint var roleIDs []uint
if err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Model(&model.ShopRole{}). Model(&model.ShopRole{}).
Where("shop_id = ?", shopID). Where("shop_id = ?", shopID)
Pluck("role_id", &roleIDs).Error; err != nil { // 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.Pluck("role_id", &roleIDs).Error; err != nil {
return nil, err return nil, err
} }
return roleIDs, nil return roleIDs, nil

View File

@@ -5,6 +5,7 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store" "github.com/break/junhong_cmp_fiber/internal/store"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -22,7 +23,10 @@ func (s *ShopSeriesAllocationStore) Create(ctx context.Context, allocation *mode
func (s *ShopSeriesAllocationStore) GetByID(ctx context.Context, id uint) (*model.ShopSeriesAllocation, error) { func (s *ShopSeriesAllocationStore) GetByID(ctx context.Context, id uint) (*model.ShopSeriesAllocation, error) {
var allocation model.ShopSeriesAllocation var allocation model.ShopSeriesAllocation
if err := s.db.WithContext(ctx).First(&allocation, id).Error; err != nil { query := s.db.WithContext(ctx).Where("id = ?", id)
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.First(&allocation).Error; err != nil {
return nil, err return nil, err
} }
return &allocation, nil return &allocation, nil
@@ -30,9 +34,11 @@ func (s *ShopSeriesAllocationStore) GetByID(ctx context.Context, id uint) (*mode
func (s *ShopSeriesAllocationStore) GetByShopAndSeries(ctx context.Context, shopID, seriesID uint) (*model.ShopSeriesAllocation, error) { func (s *ShopSeriesAllocationStore) GetByShopAndSeries(ctx context.Context, shopID, seriesID uint) (*model.ShopSeriesAllocation, error) {
var allocation model.ShopSeriesAllocation var allocation model.ShopSeriesAllocation
if err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Where("shop_id = ? AND series_id = ?", shopID, seriesID). Where("shop_id = ? AND series_id = ?", shopID, seriesID)
First(&allocation).Error; err != nil { // 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.First(&allocation).Error; err != nil {
return nil, err return nil, err
} }
return &allocation, nil return &allocation, nil
@@ -51,6 +57,8 @@ func (s *ShopSeriesAllocationStore) List(ctx context.Context, opts *store.QueryO
var total int64 var total int64
query := s.db.WithContext(ctx).Model(&model.ShopSeriesAllocation{}) query := s.db.WithContext(ctx).Model(&model.ShopSeriesAllocation{})
// 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if shopID, ok := filters["shop_id"].(uint); ok && shopID > 0 { if shopID, ok := filters["shop_id"].(uint); ok && shopID > 0 {
query = query.Where("shop_id = ?", shopID) query = query.Where("shop_id = ?", shopID)
@@ -100,9 +108,11 @@ func (s *ShopSeriesAllocationStore) UpdateStatus(ctx context.Context, id uint, s
func (s *ShopSeriesAllocationStore) GetByShopID(ctx context.Context, shopID uint) ([]*model.ShopSeriesAllocation, error) { func (s *ShopSeriesAllocationStore) GetByShopID(ctx context.Context, shopID uint) ([]*model.ShopSeriesAllocation, error) {
var allocations []*model.ShopSeriesAllocation var allocations []*model.ShopSeriesAllocation
if err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Where("shop_id = ? AND status = 1", shopID). Where("shop_id = ? AND status = 1", shopID)
Find(&allocations).Error; err != nil { // 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.Find(&allocations).Error; err != nil {
return nil, err return nil, err
} }
return allocations, nil return allocations, nil
@@ -132,9 +142,11 @@ func (s *ShopSeriesAllocationStore) ExistsByShopAndSeries(ctx context.Context, s
func (s *ShopSeriesAllocationStore) GetByAllocatorShopID(ctx context.Context, allocatorShopID uint) ([]*model.ShopSeriesAllocation, error) { func (s *ShopSeriesAllocationStore) GetByAllocatorShopID(ctx context.Context, allocatorShopID uint) ([]*model.ShopSeriesAllocation, error) {
var allocations []*model.ShopSeriesAllocation var allocations []*model.ShopSeriesAllocation
if err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Where("allocator_shop_id = ? AND status = 1", allocatorShopID). Where("allocator_shop_id = ? AND status = 1", allocatorShopID)
Find(&allocations).Error; err != nil { // 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.Find(&allocations).Error; err != nil {
return nil, err return nil, err
} }
return allocations, nil return allocations, nil
@@ -145,10 +157,12 @@ func (s *ShopSeriesAllocationStore) GetIDsByShopIDsAndSeries(ctx context.Context
return nil, nil return nil, nil
} }
var ids []uint var ids []uint
if err := s.db.WithContext(ctx). query := s.db.WithContext(ctx).
Model(&model.ShopSeriesAllocation{}). Model(&model.ShopSeriesAllocation{}).
Where("shop_id IN ? AND series_id = ? AND status = 1", shopIDs, seriesID). Where("shop_id IN ? AND series_id = ? AND status = 1", shopIDs, seriesID)
Pluck("id", &ids).Error; err != nil { // 应用数据权限过滤
query = middleware.ApplyShopFilter(ctx, query)
if err := query.Pluck("id", &ids).Error; err != nil {
return nil, err return nil, err
} }
return ids, nil return ids, nil

View File

@@ -9,7 +9,6 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"github.com/break/junhong_cmp_fiber/internal/service/commission_calculation" "github.com/break/junhong_cmp_fiber/internal/service/commission_calculation"
pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm"
) )
const ( const (
@@ -39,8 +38,6 @@ func NewCommissionCalculationHandler(
} }
func (h *CommissionCalculationHandler) HandleCommissionCalculation(ctx context.Context, task *asynq.Task) error { func (h *CommissionCalculationHandler) HandleCommissionCalculation(ctx context.Context, task *asynq.Task) error {
ctx = pkggorm.SkipDataPermission(ctx)
var payload CommissionCalculationPayload var payload CommissionCalculationPayload
if err := sonic.Unmarshal(task.Payload(), &payload); err != nil { if err := sonic.Unmarshal(task.Payload(), &payload); err != nil {
h.logger.Error("解析佣金计算任务载荷失败", h.logger.Error("解析佣金计算任务载荷失败",

View File

@@ -12,7 +12,6 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store/postgres" "github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm"
) )
type CommissionStatsArchiveHandler struct { type CommissionStatsArchiveHandler struct {
@@ -37,8 +36,6 @@ func NewCommissionStatsArchiveHandler(
} }
func (h *CommissionStatsArchiveHandler) HandleCommissionStatsArchive(ctx context.Context, task *asynq.Task) error { func (h *CommissionStatsArchiveHandler) HandleCommissionStatsArchive(ctx context.Context, task *asynq.Task) error {
ctx = pkggorm.SkipDataPermission(ctx)
now := time.Now() now := time.Now()
lastMonthStart := now.AddDate(0, -1, 0) lastMonthStart := now.AddDate(0, -1, 0)
lastMonthStart = time.Date(lastMonthStart.Year(), lastMonthStart.Month(), 1, 0, 0, 0, 0, time.UTC) lastMonthStart = time.Date(lastMonthStart.Year(), lastMonthStart.Month(), 1, 0, 0, 0, 0, time.UTC)

View File

@@ -14,7 +14,6 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store/postgres" "github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm"
) )
type CommissionStatsSyncHandler struct { type CommissionStatsSyncHandler struct {
@@ -39,8 +38,6 @@ func NewCommissionStatsSyncHandler(
} }
func (h *CommissionStatsSyncHandler) HandleCommissionStatsSync(ctx context.Context, task *asynq.Task) error { func (h *CommissionStatsSyncHandler) HandleCommissionStatsSync(ctx context.Context, task *asynq.Task) error {
ctx = pkggorm.SkipDataPermission(ctx)
lockKey := constants.RedisCommissionStatsLockKey() lockKey := constants.RedisCommissionStatsLockKey()
locked, err := h.redis.SetNX(ctx, lockKey, "1", 5*time.Minute).Result() locked, err := h.redis.SetNX(ctx, lockKey, "1", 5*time.Minute).Result()
if err != nil { if err != nil {

View File

@@ -11,7 +11,6 @@ import (
"github.com/break/junhong_cmp_fiber/internal/store/postgres" "github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm"
) )
type CommissionStatsUpdatePayload struct { type CommissionStatsUpdatePayload struct {
@@ -42,8 +41,6 @@ func NewCommissionStatsUpdateHandler(
} }
func (h *CommissionStatsUpdateHandler) HandleCommissionStatsUpdate(ctx context.Context, task *asynq.Task) error { func (h *CommissionStatsUpdateHandler) HandleCommissionStatsUpdate(ctx context.Context, task *asynq.Task) error {
ctx = pkggorm.SkipDataPermission(ctx)
var payload CommissionStatsUpdatePayload var payload CommissionStatsUpdatePayload
if err := sonic.Unmarshal(task.Payload(), &payload); err != nil { if err := sonic.Unmarshal(task.Payload(), &payload); err != nil {
h.logger.Error("解析统计更新任务载荷失败", h.logger.Error("解析统计更新任务载荷失败",

View File

@@ -17,7 +17,6 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store/postgres" "github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm"
"github.com/break/junhong_cmp_fiber/pkg/storage" "github.com/break/junhong_cmp_fiber/pkg/storage"
"github.com/break/junhong_cmp_fiber/pkg/utils" "github.com/break/junhong_cmp_fiber/pkg/utils"
) )
@@ -62,8 +61,6 @@ func NewDeviceImportHandler(
} }
func (h *DeviceImportHandler) HandleDeviceImport(ctx context.Context, task *asynq.Task) error { func (h *DeviceImportHandler) HandleDeviceImport(ctx context.Context, task *asynq.Task) error {
ctx = pkggorm.SkipDataPermission(ctx)
var payload DeviceImportPayload var payload DeviceImportPayload
if err := sonic.Unmarshal(task.Payload(), &payload); err != nil { if err := sonic.Unmarshal(task.Payload(), &payload); err != nil {
h.logger.Error("解析设备导入任务载荷失败", h.logger.Error("解析设备导入任务载荷失败",

View File

@@ -17,7 +17,6 @@ import (
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/internal/store/postgres" "github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm"
"github.com/break/junhong_cmp_fiber/pkg/storage" "github.com/break/junhong_cmp_fiber/pkg/storage"
"github.com/break/junhong_cmp_fiber/pkg/utils" "github.com/break/junhong_cmp_fiber/pkg/utils"
"github.com/break/junhong_cmp_fiber/pkg/validator" "github.com/break/junhong_cmp_fiber/pkg/validator"
@@ -72,8 +71,6 @@ func NewIotCardImportHandler(
} }
func (h *IotCardImportHandler) HandleIotCardImport(ctx context.Context, task *asynq.Task) error { func (h *IotCardImportHandler) HandleIotCardImport(ctx context.Context, task *asynq.Task) error {
ctx = pkggorm.SkipDataPermission(ctx)
var payload IotCardImportPayload var payload IotCardImportPayload
if err := sonic.Unmarshal(task.Payload(), &payload); err != nil { if err := sonic.Unmarshal(task.Payload(), &payload); err != nil {
h.logger.Error("解析 IoT 卡导入任务载荷失败", h.logger.Error("解析 IoT 卡导入任务载荷失败",

View File

@@ -0,0 +1,2 @@
schema: spec-driven
created: 2026-02-26

View File

@@ -0,0 +1,330 @@
# Design: refactor-data-permission-filter
## Context
当前系统使用 GORM Callback 在查询前自动注入数据权限过滤条件。该机制存在以下问题:
1. **隐式行为**开发者不知道查询被加了什么条件SQL 调试困难
2. **跳过率高**20+ 处使用 `SkipDataPermission`,说明自动过滤不适用于大量场景
3. **特殊处理多**`tb_shop``tb_tag``tb_enterprise_card_authorization` 等需要特殊逻辑
4. **重复查询**:同一请求中 Service 层 `CanManageShop` 和 Callback 都调用 `GetSubordinateShopIDs`
5. **原生 SQL 失效**Callback 无法处理原生 SQL需要在 Store 里手动重复写过滤逻辑
**涉及的 Store 统计:**
- 需要改动16 个 Store、85+ 个查询方法
- 无需改动32 个 Store系统全局表、关联表等
## Goals / Non-Goals
**Goals:**
- 数据权限过滤行为显式可控,便于调试
- 单次请求内下级店铺 ID 只查询一次(中间件预计算)
- 保持现有的数据隔离行为不变(代理看自己+下级,企业看自己,平台看全部)
- NULL `shop_id` 记录对代理用户不可见(保持现有行为)
**Non-Goals:**
- 不改变数据权限的业务规则
- 不修改 Redis 缓存策略(仍保持 30 分钟过期)
- 不处理个人客户的数据权限(保持 `customer_id` / `creator` 字段的现有逻辑)
## Decisions
### Decision 1: 中间件预计算 vs 懒加载
**选择**:中间件预计算
**方案对比:**
| 方案 | 优点 | 缺点 |
|------|------|------|
| 中间件预计算 | 单次请求只查一次;代码简单 | 所有请求都计算,即使不需要 |
| 懒加载(首次使用时计算) | 按需计算 | 需要加锁防止并发重复计算;代码复杂 |
**理由**
1. 绝大多数 API 都需要数据权限过滤,"不需要"是少数情况
2. `GetSubordinateShopIDs` 有 Redis 缓存,命中率高,预计算开销小
3. 代码更简单,不需要处理并发问题
### Decision 2: Helper 函数设计
**选择**:多个专用函数而非通用函数
```go
// 专用函数(选择)
func ApplyShopFilter(ctx context.Context, query *gorm.DB) *gorm.DB
func ApplyEnterpriseFilter(ctx context.Context, query *gorm.DB) *gorm.DB
func ApplyOwnerShopFilter(ctx context.Context, query *gorm.DB) *gorm.DB
// 通用函数(否决)
func ApplyDataPermission(ctx context.Context, query *gorm.DB, field string) *gorm.DB
```
**理由**
1. 字段名固定(`shop_id``enterprise_id``owner_shop_id`),无需参数化
2. 专用函数调用更清晰IDE 自动补全友好
3. 不同字段的过滤逻辑可能有细微差异,专用函数更灵活
### Decision 3: UserContextInfo 扩展 vs 新增 DataScope 结构体
**选择**:扩展 UserContextInfo
```go
// 扩展现有结构体(选择)
type UserContextInfo struct {
UserID uint
UserType int
ShopID uint
EnterpriseID uint
CustomerID uint
SubordinateShopIDs []uint // 新增
}
// 新增结构体(否决)
type DataScope struct {
SubordinateShopIDs []uint
}
```
**理由**
1. 减少 Context 中的 key 数量
2. 用户信息和权限范围本来就是强关联的
3. 现有代码已经大量使用 `UserContextInfo`,扩展更自然
### Decision 4: AuthConfig 传入 ShopStore vs 全局注册
**选择**AuthConfig 传入
```go
// AuthConfig 传入(选择)
type AuthConfig struct {
TokenExtractor func(c *fiber.Ctx) string
TokenValidator func(token string) (*UserContextInfo, error)
SkipPaths []string
ShopStore ShopStoreInterface // 新增
}
// 全局注册(否决)
var globalShopStore ShopStoreInterface
func RegisterShopStore(s ShopStoreInterface) { ... }
```
**理由**
1. 显式依赖注入,便于测试
2. 避免全局变量,符合 Go 最佳实践
3. 与现有 `TokenValidator` 传入方式一致
### Decision 5: nil vs 空切片表示"不限制"
**选择**nil 表示不限制
```go
// nil 表示不限制(选择)
if shopIDs := GetSubordinateShopIDs(ctx); shopIDs != nil {
query = query.Where("shop_id IN ?", shopIDs)
}
// 空切片表示不限制(否决)
if len(shopIDs) > 0 {
query = query.Where("shop_id IN ?", shopIDs)
}
```
**理由**
1. 语义更清晰nil = 未设置/不限制,`[]uint{}` = 设置了但列表为空
2. 空切片 `WHERE shop_id IN ()` 在 SQL 中是无效语法
3. 便于区分"平台用户不限制"和"代理用户但店铺列表为空(异常情况)"
## Implementation
### 文件结构
```
pkg/middleware/
├── auth.go # 扩展 UserContextInfo修改 Auth 中间件
├── data_scope.go # 新增Helper 函数
└── permission_helper.go # 修改CanManageShop 等函数签名
pkg/gorm/
└── callback.go # 移除 RegisterDataPermissionCallback
```
### 核心代码结构
**1. UserContextInfo 扩展auth.go**
```go
type UserContextInfo struct {
UserID uint
UserType int
ShopID uint
EnterpriseID uint
CustomerID uint
SubordinateShopIDs []uint // 新增代理用户的下级店铺ID列表nil表示不限制
}
```
**2. Auth 中间件改造auth.go**
```go
func Auth(config AuthConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
// ... 现有 token 验证逻辑 ...
// 新增:预计算 SubordinateShopIDs
if config.ShopStore != nil &&
userInfo.UserType == constants.UserTypeAgent &&
userInfo.ShopID > 0 {
shopIDs, err := config.ShopStore.GetSubordinateShopIDs(c.UserContext(), userInfo.ShopID)
if err != nil {
// 降级处理
shopIDs = []uint{userInfo.ShopID}
logger.Warn("获取下级店铺失败,降级为只包含自己", zap.Error(err))
}
userInfo.SubordinateShopIDs = shopIDs
}
SetUserToFiberContext(c, userInfo)
return c.Next()
}
}
```
**3. Helper 函数data_scope.go**
```go
// GetSubordinateShopIDs 获取当前用户可管理的店铺ID列表
// 返回 nil 表示不受限制(平台用户/超管)
func GetSubordinateShopIDs(ctx context.Context) []uint {
if ctx == nil {
return nil
}
if ids, ok := ctx.Value(constants.ContextKeySubordinateShopIDs).([]uint); ok {
return ids
}
return nil
}
// ApplyShopFilter 应用店铺数据权限过滤
// 平台用户/超管:不添加条件
// 代理用户WHERE shop_id IN (subordinateShopIDs)
func ApplyShopFilter(ctx context.Context, query *gorm.DB) *gorm.DB {
shopIDs := GetSubordinateShopIDs(ctx)
if shopIDs == nil {
return query
}
return query.Where("shop_id IN ?", shopIDs)
}
// ApplyEnterpriseFilter 应用企业数据权限过滤
// 非企业用户:不添加条件
// 企业用户WHERE enterprise_id = ?
func ApplyEnterpriseFilter(ctx context.Context, query *gorm.DB) *gorm.DB {
userType := GetUserTypeFromContext(ctx)
if userType != constants.UserTypeEnterprise {
return query
}
enterpriseID := GetEnterpriseIDFromContext(ctx)
if enterpriseID == 0 {
return query.Where("1 = 0") // 企业用户但无企业ID返回空
}
return query.Where("enterprise_id = ?", enterpriseID)
}
// ApplyOwnerShopFilter 应用归属店铺数据权限过滤
// 用于 Enterprise 等使用 owner_shop_id 的表
func ApplyOwnerShopFilter(ctx context.Context, query *gorm.DB) *gorm.DB {
shopIDs := GetSubordinateShopIDs(ctx)
if shopIDs == nil {
return query
}
return query.Where("owner_shop_id IN ?", shopIDs)
}
```
**4. Store 层调用示例**
```go
// 改造前
func (s *DeviceStore) List(ctx context.Context, opts *QueryOptions) ([]*model.Device, error) {
query := s.db.WithContext(ctx).Model(&model.Device{})
// ... GORM Callback 自动添加 WHERE shop_id IN (...) ...
return ...
}
// 改造后
func (s *DeviceStore) List(ctx context.Context, opts *QueryOptions) ([]*model.Device, error) {
query := s.db.WithContext(ctx).Model(&model.Device{})
query = middleware.ApplyShopFilter(ctx, query) // 显式调用
// ...
return ...
}
```
## Risks / Trade-offs
### Risk 1: Store 层遗漏过滤调用
**风险**:改造过程中可能遗漏某些查询方法,导致数据泄露
**缓解措施**
1. 按照 proposal 中的 Store 清单逐一检查
2. 代码审查重点关注权限过滤
3. 可考虑在开发环境添加检测中间件,对未过滤的敏感表查询打印告警
### Risk 2: 预计算开销
**风险**:每个请求都预计算 `SubordinateShopIDs`,即使某些请求不需要
**缓解措施**
1. `GetSubordinateShopIDs` 有 Redis 缓存30分钟命中率高
2. Redis 查询通常 < 1ms
3. 实际影响可忽略
### Risk 3: 改造期间的并行开发冲突
**风险**:改造涉及 16 个 Store 文件,可能与其他开发任务冲突
**缓解措施**
1. 分阶段改造:先基础设施,再 Store 层
2. 优先改造高复杂度 Store降低后期风险
3. 改造期间及时 rebase 和解决冲突
## Migration Plan
### Phase 1: 基础设施(不影响现有功能)
1. 扩展 `UserContextInfo`,添加 `SubordinateShopIDs` 字段
2. 新增 `pkg/middleware/data_scope.go`,实现 Helper 函数
3. 修改 Auth 中间件,预计算 `SubordinateShopIDs`
4. 此阶段 GORM Callback 仍然生效,两套机制并存
### Phase 2: 改造权限检查函数
1. 修改 `CanManageShop`,从 Context 获取数据,移除 `shopStore` 参数
2. 修改 `CanManageEnterprise`,从 Context 获取数据
3. 更新所有调用点
### Phase 3: Store 层改造(按复杂度分批)
1. **低复杂度 Store9 个)**agent_wallet、commission_record 等
2. **中复杂度 Store4 个)**device、order、shop_package_allocation 等
3. **高复杂度 Store3 个)**iot_card、account、enterprise_card_authorization
### Phase 4: 清理
1. 移除 `RegisterDataPermissionCallback` 及其调用
2. 移除 `SkipDataPermission` 函数及所有调用点20+ 处)
3. 移除 `pkg/gorm/callback.go` 中的相关代码
### Rollback Strategy
如果发现问题,可以:
1. Phase 1-2 期间直接回滚代码GORM Callback 仍在工作
2. Phase 3 期间:保留 Callback 代码,只回滚 Store 改动
3. Phase 4 后:需要重新启用 Callback 代码
## Open Questions
1. ~~NULL shop_id 的记录对代理用户是否可见?~~ **已确认:不可见(保持现有行为)**
2. ~~是否需要为开发环境添加"未过滤敏感表查询"的告警机制?~~ **暂不需要,通过代码审查保证**

View File

@@ -0,0 +1,93 @@
# Proposal: refactor-data-permission-filter
## Why
当前 GORM Callback 自动数据权限过滤机制存在以下问题:
1. **跳过率高** - 20+ 处使用 `SkipDataPermission`(异步任务、登录、复杂查询等场景都需要跳过)
2. **特殊处理多** - `tb_shop``tb_tag``tb_enterprise_card_authorization` 等表需要特殊逻辑
3. **重复查询** - 同一请求中 Service 层的 `CanManageShop` 和 Callback 都调用 `GetSubordinateShopIDs`
4. **隐式行为** - 开发者不知道 GORM 查询被加了什么条件,调试困难
5. **原生 SQL 失效** - 原生 SQL 无法自动应用,需要手动在 Store 里重复写权限过滤逻辑
需要重构为显式调用模式,让数据权限过滤行为可预测、可控。
## What Changes
### 新增
- **中间件预加载**:扩展 Auth 中间件,对代理用户预计算 `SubordinateShopIDs` 并放入 Context
- **UserContextInfo 扩展**:新增 `SubordinateShopIDs []uint` 字段
- **Helper 函数**
- `GetSubordinateShopIDs(ctx) []uint` - 获取下级店铺 ID 列表
- `ApplyShopFilter(ctx, query) *gorm.DB` - 应用 `WHERE shop_id IN ?` 过滤
- `ApplyEnterpriseFilter(ctx, query) *gorm.DB` - 应用 `WHERE enterprise_id = ?` 过滤
- `ApplyOwnerShopFilter(ctx, query) *gorm.DB` - 应用 `WHERE owner_shop_id IN ?` 过滤
### 移除
- **BREAKING**:移除 `RegisterDataPermissionCallback` 函数
- **BREAKING**:移除 `SkipDataPermission` 函数及所有调用点20+ 处)
- **BREAKING**`CanManageShop` / `CanManageEnterprise` 函数签名变更,不再需要 Store 参数
### 修改
- **Store 层显式过滤**16 个 Store、85+ 个查询方法需要显式调用 Helper 函数添加权限过滤
## Capabilities
### New Capabilities
- `data-scope-middleware`: 数据权限范围中间件,负责预计算用户的数据访问范围并注入 Context
### Modified Capabilities
- `data-permission`: 数据权限过滤机制从 GORM Callback 自动过滤改为业务层显式调用
## Impact
### 代码影响
| 类型 | 文件/模块 | 改动说明 |
|------|-----------|----------|
| 新增 | `pkg/middleware/data_scope.go` | Helper 函数和 Context 操作 |
| 修改 | `pkg/middleware/auth.go` | 扩展 `UserContextInfo`Auth 中间件预计算 |
| 修改 | `pkg/middleware/permission_helper.go` | `CanManageShop` 等函数改为从 Context 取数据 |
| 移除 | `pkg/gorm/callback.go` | 移除 `RegisterDataPermissionCallback` |
| 修改 | 16 个 Store 文件 | 显式调用过滤 Helper 函数 |
| 移除 | 20+ 处 `SkipDataPermission` 调用 | 不再需要跳过机制 |
### 需要改动的 Store按复杂度分级
**高复杂度3 个)**
- `iot_card_store.go` - 9+ 方法,已有部分手动实现
- `account_store.go` - 7+ 方法双字段过滤shop_id / enterprise_id
- `enterprise_card_authorization_store.go` - 8+ 方法,已有参考实现
**中复杂度4 个)**
- `device_store.go` - 4+ 方法NULL shop_id 表示平台库存
- `order_store.go` - 3 方法,使用 seller_shop_id
- `shop_package_allocation_store.go` - 6 方法
- `shop_series_allocation_store.go` - 6 方法
**低复杂度9 个)**
- `agent_wallet_store.go` - 5 方法
- `agent_wallet_transaction_store.go` - 4 方法
- `commission_record_store.go` - 4 方法
- `enterprise_device_authorization_store.go` - 5 方法
- `enterprise_store.go` - 3 方法
- `shop_role_store.go` - 2 方法
- `iot_card_import_task_store.go` - 2 方法
- `commission_withdrawal_request_store.go` - 3 方法
- `card_wallet_store.go` - 5+ 方法
### API 影响
- 无外部 API 变更
- 内部函数签名变更:`CanManageShop(ctx, targetShopID, shopStore)``CanManageShop(ctx, targetShopID)`
### 行为变更
- NULL `shop_id` 的记录对代理用户不可见(保持现有行为)
- 平台用户/超管不受数据权限限制(保持现有行为)
- 数据权限过滤从隐式变为显式,需要业务层主动调用

View File

@@ -0,0 +1,90 @@
# data-permission Delta Specification
## Purpose
数据权限过滤机制从 GORM Callback 自动过滤改为业务层显式调用。
## REMOVED Requirements
### Requirement: GORM Callback Data Permission
**Reason**: GORM Callback 自动过滤存在以下问题跳过率高20+ 处 SkipDataPermission、特殊处理多、重复查询、隐式行为难调试、原生 SQL 失效。改为业务层显式调用模式。
**Migration**:
1. Store 层查询方法显式调用 `ApplyShopFilter``ApplyEnterpriseFilter` 等 Helper 函数
2. 参考 `pkg/middleware/data_scope.go` 中的 Helper 函数用法
### Requirement: Skip Data Permission
**Reason**: 移除 GORM Callback 后,不再需要跳过机制。数据权限过滤由业务层显式控制。
**Migration**:
1. 删除所有 `SkipDataPermission(ctx)` 调用
2. 删除所有 `pkggorm.SkipDataPermission` 引用
3. 业务逻辑直接控制是否调用过滤函数
### Requirement: Callback Registration
**Reason**: 不再使用 GORM Callback 机制。数据权限范围在 Auth 中间件预计算,过滤在业务层显式调用。
**Migration**:
1. 移除 `RegisterDataPermissionCallback` 函数
2. 移除应用启动时的 Callback 注册调用
3. Auth 中间件配置 `ShopStore` 以支持预计算
## MODIFIED Requirements
### Requirement: Subordinate IDs Caching
系统 SHALL 缓存用户的下级店铺 ID 列表以提高查询性能。
#### Scenario: 缓存命中
- **WHEN** 获取用户下级店铺 ID 列表
- **AND** Redis 缓存存在
- **THEN** 直接返回缓存数据
#### Scenario: 缓存未命中
- **WHEN** 获取用户下级店铺 ID 列表
- **AND** Redis 缓存不存在
- **THEN** 执行递归查询获取下级店铺 ID
- **AND** 将结果缓存到 Redis30 分钟过期)
#### Scenario: 请求级别复用
- **WHEN** 同一请求内多次需要下级店铺 ID 列表
- **THEN** 从 Context 中获取预计算的值
- **AND** 不重复查询 Redis 或数据库
## ADDED Requirements
### Requirement: Store 层显式数据权限过滤
系统 SHALL 在 Store 层查询方法中显式调用数据权限过滤函数。
#### Scenario: 有 shop_id 字段的表
- **WHEN** Store 执行列表查询
- **AND** 表包含 `shop_id` 字段
- **THEN** 显式调用 `ApplyShopFilter(ctx, query)`
- **AND** 代理用户只能查询 `shop_id IN (subordinateShopIDs)` 的数据
#### Scenario: 有 enterprise_id 字段的表
- **WHEN** Store 执行列表查询
- **AND** 表包含 `enterprise_id` 字段
- **AND** 当前用户为企业用户
- **THEN** 显式调用 `ApplyEnterpriseFilter(ctx, query)`
- **AND** 企业用户只能查询 `enterprise_id = ?` 的数据
#### Scenario: 有 owner_shop_id 字段的表
- **WHEN** Store 执行列表查询
- **AND** 表包含 `owner_shop_id` 字段(如 Enterprise 表)
- **THEN** 显式调用 `ApplyOwnerShopFilter(ctx, query)`
- **AND** 代理用户只能查询 `owner_shop_id IN (subordinateShopIDs)` 的数据
#### Scenario: NULL shop_id 不可见
- **WHEN** 代理用户查询有 `shop_id` 字段的表
- **AND** 记录的 `shop_id` 为 NULL平台库存
- **THEN** 该记录对代理用户不可见
#### Scenario: 平台用户/超管不过滤
- **WHEN** 平台用户或超级管理员执行查询
- **THEN** Helper 函数不添加任何过滤条件
- **AND** 可查询所有数据

View File

@@ -0,0 +1,130 @@
# data-scope-middleware Specification
## Purpose
数据权限范围中间件,负责在请求入口预计算用户的数据访问范围并注入 Context供业务层显式使用。
## ADDED Requirements
### Requirement: UserContextInfo 扩展
系统 SHALL 扩展 `UserContextInfo` 结构体以包含预计算的数据权限范围。
#### Scenario: 代理用户包含下级店铺 ID 列表
- **WHEN** 代理用户登录成功
- **AND** 用户有关联的店铺 ID
- **THEN** `UserContextInfo.SubordinateShopIDs` 包含自己店铺及所有下级店铺的 ID 列表
#### Scenario: 平台用户/超管不限制
- **WHEN** 平台用户或超级管理员登录成功
- **THEN** `UserContextInfo.SubordinateShopIDs` 为 nil
- **AND** nil 表示不受数据权限限制
#### Scenario: 企业用户使用 EnterpriseID
- **WHEN** 企业用户登录成功
- **THEN** `UserContextInfo.EnterpriseID` 包含用户所属企业 ID
- **AND** `UserContextInfo.SubordinateShopIDs` 为 nil
### Requirement: Auth 中间件预计算
系统 SHALL 在 Auth 中间件中预计算用户的数据访问范围。
#### Scenario: 代理用户预计算下级店铺
- **WHEN** Auth 中间件验证 token 成功
- **AND** 用户类型为代理用户
- **AND** 用户有关联的店铺 ID
- **THEN** 调用 `GetSubordinateShopIDs` 获取下级店铺 ID 列表
- **AND** 将结果设置到 `UserContextInfo.SubordinateShopIDs`
#### Scenario: 获取下级店铺失败降级处理
- **WHEN** 调用 `GetSubordinateShopIDs` 失败
- **THEN** `SubordinateShopIDs` 降级为只包含用户自己的店铺 ID
- **AND** 记录 Error 日志
#### Scenario: 非代理用户跳过预计算
- **WHEN** Auth 中间件验证 token 成功
- **AND** 用户类型不是代理用户
- **THEN** 不调用 `GetSubordinateShopIDs`
- **AND** `SubordinateShopIDs` 保持为 nil
### Requirement: Context 数据获取函数
系统 SHALL 提供从 Context 获取数据权限范围的函数。
#### Scenario: 获取下级店铺 ID 列表
- **WHEN** 调用 `GetSubordinateShopIDs(ctx)`
- **AND** Context 包含 `SubordinateShopIDs`
- **THEN** 返回下级店铺 ID 列表
#### Scenario: 获取空列表表示不限制
- **WHEN** 调用 `GetSubordinateShopIDs(ctx)`
- **AND** Context 中 `SubordinateShopIDs` 为 nil
- **THEN** 返回 nil
- **AND** 调用方应理解 nil 表示不受数据权限限制
### Requirement: 查询过滤 Helper 函数
系统 SHALL 提供查询过滤 Helper 函数,供 Store 层显式调用。
#### Scenario: ApplyShopFilter 过滤店铺数据
- **WHEN** 调用 `ApplyShopFilter(ctx, query)`
- **AND** `SubordinateShopIDs` 不为 nil
- **THEN** 返回添加了 `WHERE shop_id IN (?)` 条件的查询
- **AND** 参数为 `SubordinateShopIDs`
#### Scenario: ApplyShopFilter 不限制时不添加条件
- **WHEN** 调用 `ApplyShopFilter(ctx, query)`
- **AND** `SubordinateShopIDs` 为 nil
- **THEN** 返回原查询,不添加任何条件
#### Scenario: ApplyEnterpriseFilter 过滤企业数据
- **WHEN** 调用 `ApplyEnterpriseFilter(ctx, query)`
- **AND** 用户类型为企业用户
- **AND** `EnterpriseID` 大于 0
- **THEN** 返回添加了 `WHERE enterprise_id = ?` 条件的查询
#### Scenario: ApplyEnterpriseFilter 非企业用户不添加条件
- **WHEN** 调用 `ApplyEnterpriseFilter(ctx, query)`
- **AND** 用户类型不是企业用户
- **THEN** 返回原查询,不添加任何条件
#### Scenario: ApplyOwnerShopFilter 过滤归属店铺数据
- **WHEN** 调用 `ApplyOwnerShopFilter(ctx, query)`
- **AND** `SubordinateShopIDs` 不为 nil
- **THEN** 返回添加了 `WHERE owner_shop_id IN (?)` 条件的查询
### Requirement: 权限检查函数改造
系统 SHALL 改造权限检查函数,从 Context 获取数据而非传入 Store。
#### Scenario: CanManageShop 从 Context 获取数据
- **WHEN** 调用 `CanManageShop(ctx, targetShopID)`
- **AND** 用户类型为代理用户
- **THEN** 从 Context 获取 `SubordinateShopIDs`
- **AND** 检查 `targetShopID` 是否在列表中
#### Scenario: CanManageShop 平台用户自动通过
- **WHEN** 调用 `CanManageShop(ctx, targetShopID)`
- **AND** `SubordinateShopIDs` 为 nil
- **THEN** 返回成功(不受限制)
#### Scenario: CanManageEnterprise 从 Context 获取数据
- **WHEN** 调用 `CanManageEnterprise(ctx, targetEnterpriseID)`
- **AND** 用户类型为代理用户
- **THEN** 从 Context 获取 `SubordinateShopIDs`
- **AND** 查询目标企业的 `owner_shop_id`
- **AND** 检查 `owner_shop_id` 是否在列表中
### Requirement: AuthConfig 扩展
系统 SHALL 扩展 `AuthConfig` 以支持传入 ShopStore。
#### Scenario: AuthConfig 包含 ShopStore
- **WHEN** 初始化 Auth 中间件
- **THEN** `AuthConfig` 可选包含 `ShopStore ShopStoreInterface`
- **AND** 用于调用 `GetSubordinateShopIDs`
#### Scenario: ShopStore 未配置时跳过预计算
- **WHEN** `AuthConfig.ShopStore` 为 nil
- **THEN** 不预计算 `SubordinateShopIDs`
- **AND** 所有用户的 `SubordinateShopIDs` 为 nil

View File

@@ -0,0 +1,83 @@
# Tasks: refactor-data-permission-filter
## 1. 基础设施 - 数据结构和 Helper 函数
- [x] 1.1 扩展 `UserContextInfo` 结构体,添加 `SubordinateShopIDs []uint` 字段(`pkg/middleware/auth.go`
- [x] 1.2 新增 Context key 常量 `ContextKeySubordinateShopIDs``pkg/constants/constants.go`
- [x] 1.3 新增 `SetUserContext``SetUserToFiberContext``SubordinateShopIDs` 的处理
- [x] 1.4 新增 `pkg/middleware/data_scope.go` 文件,实现 `GetSubordinateShopIDs` 函数
- [x] 1.5 实现 `ApplyShopFilter` Helper 函数
- [x] 1.6 实现 `ApplyEnterpriseFilter` Helper 函数
- [x] 1.7 实现 `ApplyOwnerShopFilter` Helper 函数
- [x] 1.8 验证:编译通过,`go build ./...`
## 2. Auth 中间件改造
- [x] 2.1 扩展 `AuthConfig` 结构体,添加 `ShopStore ShopStoreInterface` 字段
- [x] 2.2 定义 `AuthShopStoreInterface` 接口(包含 `GetSubordinateShopIDs` 方法)
- [x] 2.3 修改 `Auth` 中间件,在 token 验证成功后预计算 `SubordinateShopIDs`
- [x] 2.4 实现降级逻辑:获取下级店铺失败时降级为只包含自己的店铺 ID
- [x] 2.5 更新 Admin API 和 H5 API 的 Auth 中间件配置,传入 ShopStore
- [x] 2.6 验证:编译通过(运行时验证将在最终验证阶段进行)
## 3. 权限检查函数改造
- [x] 3.1 修改 `CanManageShop` 函数签名,移除 `shopStore` 参数,改为从 Context 获取数据
- [x] 3.2 修改 `CanManageEnterprise` 函数签名,移除 `shopStore` 参数(保留 enterpriseStore
- [x] 3.3 更新所有 `CanManageShop` 调用点Service 层)
- [x] 3.4 更新所有 `CanManageEnterprise` 调用点Service 层)
- [x] 3.5 验证:编译通过
## 4. Store 层改造 - 低复杂度9 个)
- [x] 4.1 改造 `agent_wallet_store.go`List、GetByShopID 等方法添加 `ApplyShopFilter`
- [x] 4.2 改造 `agent_wallet_transaction_store.go`List 等方法添加 `ApplyShopFilter`
- [x] 4.3 改造 `commission_record_store.go`List 等方法添加 `ApplyShopFilter`
- [x] 4.4 改造 `enterprise_device_authorization_store.go`List 等方法添加 `ApplyEnterpriseFilter`
- [x] 4.5 改造 `enterprise_store.go`List 等方法添加 `ApplyOwnerShopFilter`
- [x] 4.6 改造 `shop_role_store.go`List 等方法添加 `ApplyShopFilter`
- [x] 4.7 改造 `iot_card_import_task_store.go`List 等方法添加 `ApplyShopFilter`
- [x] 4.8 改造 `commission_withdrawal_request_store.go`List 等方法添加 `ApplyShopFilter`
- [x] 4.9 改造 `card_wallet_store.go``card_wallet_transaction_store.go`
- [x] 4.10 验证:编译通过,低复杂度 Store 的列表接口数据过滤正常
## 5. Store 层改造 - 中复杂度4 个)
- [x] 5.1 改造 `device_store.go`List、Count 等方法添加 `ApplyShopFilter`,注意 NULL shop_id 处理
- [x] 5.2 改造 `order_store.go`List 等方法添加店铺过滤(使用 seller_shop_id 字段)
- [x] 5.3 改造 `shop_package_allocation_store.go`List 等方法添加 `ApplyShopFilter`
- [x] 5.4 改造 `shop_series_allocation_store.go`List 等方法添加 `ApplyShopFilter`
- [x] 5.5 验证:编译通过,中复杂度 Store 的列表接口数据过滤正常
## 6. Store 层改造 - 高复杂度3 个)
- [x] 6.1 改造 `account_store.go`List、GetByID 等方法,根据用户类型选择 `ApplyShopFilter``ApplyEnterpriseFilter`
- [x] 6.2 改造 `iot_card_store.go`List、ListStandalone 等方法添加 `ApplyShopFilter`,移除已有的手动过滤逻辑
- [x] 6.3 改造 `enterprise_card_authorization_store.go`List、ListWithJoin 等方法,整合现有的手动权限过滤
- [x] 6.4 验证:编译通过,高复杂度 Store 的列表接口数据过滤正常
## 7. 清理 - 移除 GORM Callback
- [x] 7.1 移除 `pkg/gorm/callback.go` 中的 `RegisterDataPermissionCallback` 函数
- [x] 7.2 移除 `SkipDataPermission` 函数和 `SkipDataPermissionKey` 常量
- [x] 7.3 移除应用启动时的 `RegisterDataPermissionCallback` 调用
- [x] 7.4 验证:编译通过
## 8. 清理 - 移除 SkipDataPermission 调用
- [x] 8.1 移除 `internal/task/*.go` 中的 `SkipDataPermission` 调用6 处)
- [x] 8.2 移除 `internal/service/auth/service.go` 中的 `SkipDataPermission` 调用
- [x] 8.3 移除 `internal/service/shop_series_allocation/service.go` 中的 `SkipDataPermission` 调用
- [x] 8.4 移除 `internal/service/enterprise_device/service.go` 中的 `SkipDataPermission` 调用5 处)
- [x] 8.5 移除 `internal/store/postgres/iot_card_store.go` 中的 `SkipDataPermission` 调用5 处)
- [x] 8.6 移除 `internal/store/postgres/enterprise_card_authorization_store.go` 中的 `SkipDataPermission` 调用
- [x] 8.7 移除 `internal/bootstrap/admin.go` 中的 `SkipDataPermission` 调用
- [x] 8.8 验证:编译通过,全局搜索 `SkipDataPermission` 无结果
## 9. 最终验证
- [ ] 9.1 启动服务,验证平台管理员可以查看所有数据
- [ ] 9.2 验证代理用户只能看到自己店铺及下级店铺的数据
- [ ] 9.3 验证企业用户只能看到自己企业的数据
- [ ] 9.4 验证 NULL shop_id 的记录对代理用户不可见
- [x] 9.5 运行 `go build ./...``go vet ./...` 确认无编译警告

View File

@@ -1,77 +1,60 @@
# data-permission Specification # data-permission Specification
## Purpose ## Purpose
TBD - created by archiving change refactor-framework-cleanup. Update Purpose after archive.
数据权限过滤机制,通过业务层显式调用实现数据隔离。
## Requirements ## Requirements
### Requirement: GORM Callback Data Permission
系统 SHALL 使用 GORM Callback 机制自动为所有查询添加数据权限过滤。
#### Scenario: 自动应用权限过滤
- **WHEN** 执行 GORM 查询
- **AND** Context 包含用户信息
- **AND** 表包含 owner_id 字段
- **THEN** 自动添加 WHERE owner_id IN (subordinateIDs) 条件
#### Scenario: Root 用户跳过过滤
- **WHEN** 当前用户是 Root 用户
- **THEN** 不添加任何数据权限过滤条件
- **AND** 可查询所有数据
#### Scenario: 无 owner_id 字段的表
- **WHEN** 表不包含 owner_id 字段
- **THEN** 不添加数据权限过滤条件
#### Scenario: 授权记录表特殊处理
- **WHEN** 查询 `tb_enterprise_card_authorization`
- **AND** 当前用户是代理用户
- **THEN** 自动添加 WHERE enterprise_id IN (SELECT id FROM tb_enterprise WHERE owner_shop_id = 当前店铺ID) 条件
- **AND** 不包含下级店铺的数据
#### Scenario: 平台用户查询授权记录
- **WHEN** 查询 `tb_enterprise_card_authorization`
- **AND** 当前用户是平台用户或超级管理员
- **THEN** 不添加数据权限过滤条件
- **AND** 可查询所有授权记录
### Requirement: Skip Data Permission
系统 SHALL 支持通过 Context 绕过数据权限过滤。
#### Scenario: 显式跳过权限过滤
- **WHEN** 调用 SkipDataPermission(ctx) 获取新 Context
- **AND** 使用该 Context 执行 GORM 查询
- **THEN** 不添加任何数据权限过滤条件
#### Scenario: 内部操作跳过过滤
- **WHEN** 执行内部同步、批量操作或管理员操作
- **THEN** 应使用 SkipDataPermission 绕过过滤
### Requirement: Subordinate IDs Caching ### Requirement: Subordinate IDs Caching
系统 SHALL 缓存用户的下级 ID 列表以提高查询性能。 系统 SHALL 缓存用户的下级店铺 ID 列表以提高查询性能。
#### Scenario: 缓存命中 #### Scenario: 缓存命中
- **WHEN** 获取用户下级 ID 列表 - **WHEN** 获取用户下级店铺 ID 列表
- **AND** Redis 缓存存在 - **AND** Redis 缓存存在
- **THEN** 直接返回缓存数据 - **THEN** 直接返回缓存数据
#### Scenario: 缓存未命中 #### Scenario: 缓存未命中
- **WHEN** 获取用户下级 ID 列表 - **WHEN** 获取用户下级店铺 ID 列表
- **AND** Redis 缓存不存在 - **AND** Redis 缓存不存在
- **THEN** 执行递归 CTE 查询获取下级 ID - **THEN** 执行递归查询获取下级店铺 ID
- **AND** 将结果缓存到 Redis30 分钟过期) - **AND** 将结果缓存到 Redis30 分钟过期)
### Requirement: Callback Registration #### Scenario: 请求级别复用
- **WHEN** 同一请求内多次需要下级店铺 ID 列表
- **THEN** 从 Context 中获取预计算的值
- **AND** 不重复查询 Redis 或数据库
系统 SHALL 在应用启动时注册 GORM 数据权限 Callback。 ### Requirement: Store 层显式数据权限过滤
#### Scenario: 注册 Callback 系统 SHALL 在 Store 层查询方法中显式调用数据权限过滤函数。
- **WHEN** 调用 RegisterDataPermissionCallback(db, accountStore)
- **THEN** 注册 Query Before Callback
- **AND** Callback 名称为 "data_permission"
#### Scenario: AccountStore 依赖 #### Scenario: 有 shop_id 字段的表
- **WHEN** 注册 Callback 时 - **WHEN** Store 执行列表查询
- **THEN** 需要传入 AccountStore 实例用于获取下级 ID - **AND** 表包含 `shop_id` 字段
- **THEN** 显式调用 `ApplyShopFilter(ctx, query)`
- **AND** 代理用户只能查询 `shop_id IN (subordinateShopIDs)` 的数据
#### Scenario: 有 enterprise_id 字段的表
- **WHEN** Store 执行列表查询
- **AND** 表包含 `enterprise_id` 字段
- **AND** 当前用户为企业用户
- **THEN** 显式调用 `ApplyEnterpriseFilter(ctx, query)`
- **AND** 企业用户只能查询 `enterprise_id = ?` 的数据
#### Scenario: 有 owner_shop_id 字段的表
- **WHEN** Store 执行列表查询
- **AND** 表包含 `owner_shop_id` 字段(如 Enterprise 表)
- **THEN** 显式调用 `ApplyOwnerShopFilter(ctx, query)`
- **AND** 代理用户只能查询 `owner_shop_id IN (subordinateShopIDs)` 的数据
#### Scenario: NULL shop_id 不可见
- **WHEN** 代理用户查询有 `shop_id` 字段的表
- **AND** 记录的 `shop_id` 为 NULL平台库存
- **THEN** 该记录对代理用户不可见
#### Scenario: 平台用户/超管不过滤
- **WHEN** 平台用户或超级管理员执行查询
- **THEN** Helper 函数不添加任何过滤条件
- **AND** 可查询所有数据

View File

@@ -0,0 +1,130 @@
# data-scope-middleware Specification
## Purpose
数据权限范围中间件,负责在请求入口预计算用户的数据访问范围并注入 Context供业务层显式使用。
## Requirements
### Requirement: UserContextInfo 扩展
系统 SHALL 扩展 `UserContextInfo` 结构体以包含预计算的数据权限范围。
#### Scenario: 代理用户包含下级店铺 ID 列表
- **WHEN** 代理用户登录成功
- **AND** 用户有关联的店铺 ID
- **THEN** `UserContextInfo.SubordinateShopIDs` 包含自己店铺及所有下级店铺的 ID 列表
#### Scenario: 平台用户/超管不限制
- **WHEN** 平台用户或超级管理员登录成功
- **THEN** `UserContextInfo.SubordinateShopIDs` 为 nil
- **AND** nil 表示不受数据权限限制
#### Scenario: 企业用户使用 EnterpriseID
- **WHEN** 企业用户登录成功
- **THEN** `UserContextInfo.EnterpriseID` 包含用户所属企业 ID
- **AND** `UserContextInfo.SubordinateShopIDs` 为 nil
### Requirement: Auth 中间件预计算
系统 SHALL 在 Auth 中间件中预计算用户的数据访问范围。
#### Scenario: 代理用户预计算下级店铺
- **WHEN** Auth 中间件验证 token 成功
- **AND** 用户类型为代理用户
- **AND** 用户有关联的店铺 ID
- **THEN** 调用 `GetSubordinateShopIDs` 获取下级店铺 ID 列表
- **AND** 将结果设置到 `UserContextInfo.SubordinateShopIDs`
#### Scenario: 获取下级店铺失败降级处理
- **WHEN** 调用 `GetSubordinateShopIDs` 失败
- **THEN** `SubordinateShopIDs` 降级为只包含用户自己的店铺 ID
- **AND** 记录 Error 日志
#### Scenario: 非代理用户跳过预计算
- **WHEN** Auth 中间件验证 token 成功
- **AND** 用户类型不是代理用户
- **THEN** 不调用 `GetSubordinateShopIDs`
- **AND** `SubordinateShopIDs` 保持为 nil
### Requirement: Context 数据获取函数
系统 SHALL 提供从 Context 获取数据权限范围的函数。
#### Scenario: 获取下级店铺 ID 列表
- **WHEN** 调用 `GetSubordinateShopIDs(ctx)`
- **AND** Context 包含 `SubordinateShopIDs`
- **THEN** 返回下级店铺 ID 列表
#### Scenario: 获取空列表表示不限制
- **WHEN** 调用 `GetSubordinateShopIDs(ctx)`
- **AND** Context 中 `SubordinateShopIDs` 为 nil
- **THEN** 返回 nil
- **AND** 调用方应理解 nil 表示不受数据权限限制
### Requirement: 查询过滤 Helper 函数
系统 SHALL 提供查询过滤 Helper 函数,供 Store 层显式调用。
#### Scenario: ApplyShopFilter 过滤店铺数据
- **WHEN** 调用 `ApplyShopFilter(ctx, query)`
- **AND** `SubordinateShopIDs` 不为 nil
- **THEN** 返回添加了 `WHERE shop_id IN (?)` 条件的查询
- **AND** 参数为 `SubordinateShopIDs`
#### Scenario: ApplyShopFilter 不限制时不添加条件
- **WHEN** 调用 `ApplyShopFilter(ctx, query)`
- **AND** `SubordinateShopIDs` 为 nil
- **THEN** 返回原查询,不添加任何条件
#### Scenario: ApplyEnterpriseFilter 过滤企业数据
- **WHEN** 调用 `ApplyEnterpriseFilter(ctx, query)`
- **AND** 用户类型为企业用户
- **AND** `EnterpriseID` 大于 0
- **THEN** 返回添加了 `WHERE enterprise_id = ?` 条件的查询
#### Scenario: ApplyEnterpriseFilter 非企业用户不添加条件
- **WHEN** 调用 `ApplyEnterpriseFilter(ctx, query)`
- **AND** 用户类型不是企业用户
- **THEN** 返回原查询,不添加任何条件
#### Scenario: ApplyOwnerShopFilter 过滤归属店铺数据
- **WHEN** 调用 `ApplyOwnerShopFilter(ctx, query)`
- **AND** `SubordinateShopIDs` 不为 nil
- **THEN** 返回添加了 `WHERE owner_shop_id IN (?)` 条件的查询
### Requirement: 权限检查函数改造
系统 SHALL 改造权限检查函数,从 Context 获取数据而非传入 Store。
#### Scenario: CanManageShop 从 Context 获取数据
- **WHEN** 调用 `CanManageShop(ctx, targetShopID)`
- **AND** 用户类型为代理用户
- **THEN** 从 Context 获取 `SubordinateShopIDs`
- **AND** 检查 `targetShopID` 是否在列表中
#### Scenario: CanManageShop 平台用户自动通过
- **WHEN** 调用 `CanManageShop(ctx, targetShopID)`
- **AND** `SubordinateShopIDs` 为 nil
- **THEN** 返回成功(不受限制)
#### Scenario: CanManageEnterprise 从 Context 获取数据
- **WHEN** 调用 `CanManageEnterprise(ctx, targetEnterpriseID)`
- **AND** 用户类型为代理用户
- **THEN** 从 Context 获取 `SubordinateShopIDs`
- **AND** 查询目标企业的 `owner_shop_id`
- **AND** 检查 `owner_shop_id` 是否在列表中
### Requirement: AuthConfig 扩展
系统 SHALL 扩展 `AuthConfig` 以支持传入 ShopStore。
#### Scenario: AuthConfig 包含 ShopStore
- **WHEN** 初始化 Auth 中间件
- **THEN** `AuthConfig` 可选包含 `ShopStore ShopStoreInterface`
- **AND** 用于调用 `GetSubordinateShopIDs`
#### Scenario: ShopStore 未配置时跳过预计算
- **WHEN** `AuthConfig.ShopStore` 为 nil
- **THEN** 不预计算 `SubordinateShopIDs`
- **AND** 所有用户的 `SubordinateShopIDs` 为 nil

View File

@@ -4,16 +4,17 @@ import "time"
// Fiber Locals 的上下文键 // Fiber Locals 的上下文键
const ( const (
ContextKeyRequestID = "requestid" // 请求记录ID ContextKeyRequestID = "requestid" // 请求记录ID
ContextKeyStartTime = "start_time" // 请求开始时间 ContextKeyStartTime = "start_time" // 请求开始时间
ContextKeyUserID = "user_id" // 用户ID ContextKeyUserID = "user_id" // 用户ID
ContextKeyUserType = "user_type" // 用户类型 ContextKeyUserType = "user_type" // 用户类型
ContextKeyShopID = "shop_id" // 店铺ID ContextKeyShopID = "shop_id" // 店铺ID
ContextKeyEnterpriseID = "enterprise_id" // 企业ID ContextKeyEnterpriseID = "enterprise_id" // 企业ID
ContextKeyCustomerID = "customer_id" // 个人客户ID ContextKeyCustomerID = "customer_id" // 个人客户ID
ContextKeyUserInfo = "user_info" // 完整的用户信息 ContextKeyUserInfo = "user_info" // 完整的用户信息
ContextKeyIP = "ip_address" // IP地址 ContextKeyIP = "ip_address" // IP地址
ContextKeyUserAgent = "user_agent" // User-Agent ContextKeyUserAgent = "user_agent" // User-Agent
ContextKeySubordinateShopIDs = "subordinate_shop_ids" // 下级店铺ID列表代理用户预计算
) )
// 配置环境变量 // 配置环境变量

View File

@@ -1,250 +1,12 @@
package gorm package gorm
import ( import (
"context"
"reflect" "reflect"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/logger"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"go.uber.org/zap"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/schema"
) )
// contextKey 用于 context value 的 key 类型
type contextKey string
// SkipDataPermissionKey 跳过数据权限过滤的 context key
const SkipDataPermissionKey contextKey = "skip_data_permission"
// SkipDataPermission 返回跳过数据权限过滤的 Context
// 用于需要查询所有数据的场景(如管理后台统计、系统任务等)
//
// 使用示例:
//
// ctx = gorm.SkipDataPermission(ctx)
// db.WithContext(ctx).Find(&accounts)
func SkipDataPermission(ctx context.Context) context.Context {
return context.WithValue(ctx, SkipDataPermissionKey, true)
}
// ShopStoreInterface 店铺 Store 接口
// 用于 Callback 获取下级店铺 ID避免循环依赖
type ShopStoreInterface interface {
GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error)
}
// RegisterDataPermissionCallback 注册 GORM 数据权限过滤 Callback
//
// 自动化数据权限过滤规则:
// 1. 超级管理员跳过过滤,可以查看所有数据
// 2. 平台用户跳过过滤,可以查看所有数据
// 3. 代理用户只能查看自己店铺及下级店铺的数据(基于 shop_id 字段)
// 4. 企业用户只能查看自己企业的数据(基于 enterprise_id 字段)
// 5. 个人客户只能查看自己的数据(基于 creator 字段或 customer_id 字段)
// 6. 通过 SkipDataPermission(ctx) 可以绕过权限过滤
//
// 软删除过滤规则:
// 1. 所有查询自动排除 deleted_at IS NOT NULL 的记录
// 2. 使用 db.Unscoped() 可以查询包含已删除的记录
//
// 注意:
// - Callback 根据表的字段自动选择过滤策略
// - 必须在初始化 Store 之前注册
//
// 参数:
// - db: GORM DB 实例
// - shopStore: 店铺 Store用于查询下级店铺 ID
//
// 返回:
// - error: 注册错误
func RegisterDataPermissionCallback(db *gorm.DB, shopStore ShopStoreInterface) error {
// 注册查询前的 Callback
err := db.Callback().Query().Before("gorm:query").Register("data_permission:query", func(tx *gorm.DB) {
ctx := tx.Statement.Context
if ctx == nil {
return
}
// 1. 检查是否跳过数据权限过滤
if skip, ok := ctx.Value(SkipDataPermissionKey).(bool); ok && skip {
return
}
// 2. 获取用户类型
userType := middleware.GetUserTypeFromContext(ctx)
// 3. 超级管理员和平台用户跳过过滤,可以查看所有数据
if userType == constants.UserTypeSuperAdmin || userType == constants.UserTypePlatform {
return
}
// 4. 获取当前用户信息
userID := middleware.GetUserIDFromContext(ctx)
if userID == 0 {
// 未登录用户返回空结果
logger.GetAppLogger().Warn("数据权限过滤:未获取到用户 ID")
tx.Where("1 = 0")
return
}
shopID := middleware.GetShopIDFromContext(ctx)
// 5. 根据用户类型和表结构应用不同的过滤规则
schema := tx.Statement.Schema
if schema == nil {
return
}
// 5.1 代理用户:基于店铺层级过滤
if userType == constants.UserTypeAgent {
tableName := schema.Table
// 特殊处理:授权记录表(通过企业归属过滤,不含下级店铺)
if tableName == "tb_enterprise_card_authorization" {
if shopID == 0 {
// 代理用户没有 shop_id返回空结果
tx.Where("1 = 0")
return
}
// 只能看到自己店铺下企业的授权记录(不包含下级店铺)
tx.Where("enterprise_id IN (SELECT id FROM tb_enterprise WHERE owner_shop_id = ? AND deleted_at IS NULL)", shopID)
return
}
// 特殊处理:标签表和资源标签表(包含全局标签)
if tableName == "tb_tag" || tableName == "tb_resource_tag" {
if shopID == 0 {
// 没有 shop_id只能看全局标签
tx.Where("enterprise_id IS NULL AND shop_id IS NULL")
return
}
// 查询该店铺及下级店铺的 ID
subordinateShopIDs, err := shopStore.GetSubordinateShopIDs(ctx, shopID)
if err != nil {
logger.GetAppLogger().Error("数据权限过滤:获取下级店铺 ID 失败",
zap.Uint("shop_id", shopID),
zap.Error(err))
subordinateShopIDs = []uint{shopID}
}
// 过滤:店铺标签(自己店铺及下级店铺)或全局标签
tx.Where("shop_id IN ? OR (enterprise_id IS NULL AND shop_id IS NULL)", subordinateShopIDs)
return
}
if !hasShopIDField(schema) {
// 表没有 shop_id 字段,无法过滤
return
}
if shopID == 0 {
// 代理用户没有 shop_id只能看自己创建的数据
if hasCreatorField(schema) {
tx.Where("creator = ?", userID)
} else {
tx.Where("1 = 0")
}
return
}
// 查询该店铺及下级店铺的 ID
subordinateShopIDs, err := shopStore.GetSubordinateShopIDs(ctx, shopID)
if err != nil {
logger.GetAppLogger().Error("数据权限过滤:获取下级店铺 ID 失败",
zap.Uint("shop_id", shopID),
zap.Error(err))
// 降级为只能看自己店铺的数据
subordinateShopIDs = []uint{shopID}
}
// 过滤shop_id IN (自己店铺及下级店铺)
tx.Where("shop_id IN ?", subordinateShopIDs)
return
}
// 5.2 企业用户:基于 enterprise_id 过滤
if userType == constants.UserTypeEnterprise {
enterpriseID := middleware.GetEnterpriseIDFromContext(ctx)
tableName := schema.Table
// 特殊处理:标签表和资源标签表(包含全局标签)
if tableName == "tb_tag" || tableName == "tb_resource_tag" {
if enterpriseID != 0 {
// 过滤:企业标签或全局标签
tx.Where("enterprise_id = ? OR (enterprise_id IS NULL AND shop_id IS NULL)", enterpriseID)
} else {
// 没有 enterprise_id只能看全局标签
tx.Where("enterprise_id IS NULL AND shop_id IS NULL")
}
return
}
if hasEnterpriseIDField(schema) {
if enterpriseID != 0 {
tx.Where("enterprise_id = ?", enterpriseID)
} else {
// 企业用户没有 enterprise_id返回空结果
tx.Where("1 = 0")
}
return
}
// 如果表没有 enterprise_id 字段,但有 creator 字段,则只能看自己创建的数据
if hasCreatorField(schema) {
tx.Where("creator = ?", userID)
return
}
// 无法过滤,返回空结果
tx.Where("1 = 0")
return
}
// 5.3 个人客户:只能看自己的数据
if userType == constants.UserTypePersonalCustomer {
customerID := middleware.GetCustomerIDFromContext(ctx)
tableName := schema.Table
// 特殊处理:标签表和资源标签表(只能看全局标签)
if tableName == "tb_tag" || tableName == "tb_resource_tag" {
tx.Where("enterprise_id IS NULL AND shop_id IS NULL")
return
}
// 优先使用 customer_id 字段
if hasCustomerIDField(schema) {
if customerID != 0 {
tx.Where("customer_id = ?", customerID)
} else {
// 个人客户没有 customer_id返回空结果
tx.Where("1 = 0")
}
return
}
// 降级为使用 creator 字段
if hasCreatorField(schema) {
tx.Where("creator = ?", userID)
return
}
// 无法过滤,返回空结果
tx.Where("1 = 0")
return
}
// 6. 默认:未知用户类型,返回空结果
logger.GetAppLogger().Warn("数据权限过滤:未知用户类型",
zap.Uint("user_id", userID),
zap.Int("user_type", userType))
tx.Where("1 = 0")
})
return err
}
// RegisterSetCreatorUpdaterCallback 注册 GORM 创建数据时创建人更新人 Callback // RegisterSetCreatorUpdaterCallback 注册 GORM 创建数据时创建人更新人 Callback
func RegisterSetCreatorUpdaterCallback(db *gorm.DB) error { func RegisterSetCreatorUpdaterCallback(db *gorm.DB) error {
err := db.Callback().Create().Before("gorm:create").Register("set_creator_updater", func(tx *gorm.DB) { err := db.Callback().Create().Before("gorm:create").Register("set_creator_updater", func(tx *gorm.DB) {
@@ -296,48 +58,3 @@ func RegisterSetCreatorUpdaterCallback(db *gorm.DB) error {
}) })
return err return err
} }
// hasCreatorField 检查 Schema 是否包含 creator 字段
func hasCreatorField(s *schema.Schema) bool {
if s == nil {
return false
}
_, ok := s.FieldsByDBName["creator"]
return ok
}
// hasShopIDField 检查 Schema 是否包含 shop_id 字段
func hasShopIDField(s *schema.Schema) bool {
if s == nil {
return false
}
_, ok := s.FieldsByDBName["shop_id"]
return ok
}
// hasEnterpriseIDField 检查 Schema 是否包含 enterprise_id 字段
func hasEnterpriseIDField(s *schema.Schema) bool {
if s == nil {
return false
}
_, ok := s.FieldsByDBName["enterprise_id"]
return ok
}
// hasCustomerIDField 检查 Schema 是否包含 customer_id 字段
func hasCustomerIDField(s *schema.Schema) bool {
if s == nil {
return false
}
_, ok := s.FieldsByDBName["customer_id"]
return ok
}
// hasDeletedAtField 检查 Schema 是否包含 deleted_at 字段
func hasDeletedAtField(s *schema.Schema) bool {
if s == nil {
return false
}
_, ok := s.FieldsByDBName["deleted_at"]
return ok
}

View File

@@ -5,16 +5,19 @@ import (
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors" "github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/logger"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"go.uber.org/zap"
) )
// UserContextInfo 用户上下文信息 // UserContextInfo 用户上下文信息
type UserContextInfo struct { type UserContextInfo struct {
UserID uint UserID uint
UserType int UserType int
ShopID uint ShopID uint
EnterpriseID uint EnterpriseID uint
CustomerID uint CustomerID uint
SubordinateShopIDs []uint // 代理用户的下级店铺ID列表nil 表示不受数据权限限制
} }
// SetUserContext 将用户信息设置到 context 中 // SetUserContext 将用户信息设置到 context 中
@@ -25,6 +28,10 @@ func SetUserContext(ctx context.Context, info *UserContextInfo) context.Context
ctx = context.WithValue(ctx, constants.ContextKeyShopID, info.ShopID) ctx = context.WithValue(ctx, constants.ContextKeyShopID, info.ShopID)
ctx = context.WithValue(ctx, constants.ContextKeyEnterpriseID, info.EnterpriseID) ctx = context.WithValue(ctx, constants.ContextKeyEnterpriseID, info.EnterpriseID)
ctx = context.WithValue(ctx, constants.ContextKeyCustomerID, info.CustomerID) ctx = context.WithValue(ctx, constants.ContextKeyCustomerID, info.CustomerID)
// SubordinateShopIDs: nil 表示不限制,空切片表示无权限
if info.SubordinateShopIDs != nil {
ctx = context.WithValue(ctx, constants.ContextKeySubordinateShopIDs, info.SubordinateShopIDs)
}
return ctx return ctx
} }
@@ -134,12 +141,21 @@ func SetUserToFiberContext(c *fiber.Ctx, info *UserContextInfo) {
c.Locals(constants.ContextKeyShopID, info.ShopID) c.Locals(constants.ContextKeyShopID, info.ShopID)
c.Locals(constants.ContextKeyEnterpriseID, info.EnterpriseID) c.Locals(constants.ContextKeyEnterpriseID, info.EnterpriseID)
c.Locals(constants.ContextKeyCustomerID, info.CustomerID) c.Locals(constants.ContextKeyCustomerID, info.CustomerID)
if info.SubordinateShopIDs != nil {
c.Locals(constants.ContextKeySubordinateShopIDs, info.SubordinateShopIDs)
}
// 设置到标准 context用于 GORM 数据权限过滤) // 设置到标准 context用于数据权限过滤
ctx := SetUserContext(c.UserContext(), info) ctx := SetUserContext(c.UserContext(), info)
c.SetUserContext(ctx) c.SetUserContext(ctx)
} }
// AuthShopStoreInterface 店铺存储接口
// 用于 Auth 中间件获取下级店铺 ID避免循环依赖
type AuthShopStoreInterface interface {
GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error)
}
// AuthConfig Auth 中间件配置 // AuthConfig Auth 中间件配置
type AuthConfig struct { type AuthConfig struct {
// TokenExtractor 自定义 token 提取函数 // TokenExtractor 自定义 token 提取函数
@@ -153,6 +169,10 @@ type AuthConfig struct {
// SkipPaths 跳过认证的路径列表 // SkipPaths 跳过认证的路径列表
SkipPaths []string SkipPaths []string
// ShopStore 店铺存储,用于预计算代理用户的下级店铺 ID
// 可选,不传则不预计算 SubordinateShopIDs
ShopStore AuthShopStoreInterface
} }
// Auth 认证中间件 // Auth 认证中间件
@@ -196,6 +216,21 @@ func Auth(config AuthConfig) fiber.Handler {
return errors.Wrap(errors.CodeInvalidToken, err, "认证令牌无效") return errors.Wrap(errors.CodeInvalidToken, err, "认证令牌无效")
} }
// 预计算代理用户的下级店铺 ID
if config.ShopStore != nil &&
userInfo.UserType == constants.UserTypeAgent &&
userInfo.ShopID > 0 {
shopIDs, err := config.ShopStore.GetSubordinateShopIDs(c.UserContext(), userInfo.ShopID)
if err != nil {
// 降级处理:只包含自己的店铺 ID
shopIDs = []uint{userInfo.ShopID}
logger.GetAppLogger().Warn("预计算下级店铺失败,降级为只包含自己",
zap.Uint("shop_id", userInfo.ShopID),
zap.Error(err))
}
userInfo.SubordinateShopIDs = shopIDs
}
// 将用户信息设置到 context // 将用户信息设置到 context
SetUserToFiberContext(c, userInfo) SetUserToFiberContext(c, userInfo)

View File

@@ -0,0 +1,91 @@
package middleware
import (
"context"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"gorm.io/gorm"
)
// GetSubordinateShopIDs 获取当前用户可管理的店铺ID列表
// 返回 nil 表示不受数据权限限制(平台用户/超管)
// 返回 []uint 表示限制在这些店铺范围内(代理用户)
func GetSubordinateShopIDs(ctx context.Context) []uint {
if ctx == nil {
return nil
}
if ids, ok := ctx.Value(constants.ContextKeySubordinateShopIDs).([]uint); ok {
return ids
}
return nil
}
// ApplyShopFilter 应用店铺数据权限过滤
// 平台用户/超管不添加条件SubordinateShopIDs 为 nil
// 代理用户WHERE shop_id IN (subordinateShopIDs)
// 注意NULL shop_id 的记录对代理用户不可见
func ApplyShopFilter(ctx context.Context, query *gorm.DB) *gorm.DB {
shopIDs := GetSubordinateShopIDs(ctx)
if shopIDs == nil {
return query
}
return query.Where("shop_id IN ?", shopIDs)
}
// ApplyEnterpriseFilter 应用企业数据权限过滤
// 非企业用户:不添加条件
// 企业用户WHERE enterprise_id = ?
func ApplyEnterpriseFilter(ctx context.Context, query *gorm.DB) *gorm.DB {
userType := GetUserTypeFromContext(ctx)
if userType != constants.UserTypeEnterprise {
return query
}
enterpriseID := GetEnterpriseIDFromContext(ctx)
if enterpriseID == 0 {
// 企业用户但无企业ID返回空结果
return query.Where("1 = 0")
}
return query.Where("enterprise_id = ?", enterpriseID)
}
// ApplyOwnerShopFilter 应用归属店铺数据权限过滤
// 用于 Enterprise 等使用 owner_shop_id 字段的表
// 平台用户/超管:不添加条件
// 代理用户WHERE owner_shop_id IN (subordinateShopIDs)
func ApplyOwnerShopFilter(ctx context.Context, query *gorm.DB) *gorm.DB {
shopIDs := GetSubordinateShopIDs(ctx)
if shopIDs == nil {
return query
}
return query.Where("owner_shop_id IN ?", shopIDs)
}
// IsUnrestricted 检查当前用户是否不受数据权限限制
// 平台用户/超管返回 true代理/企业用户返回 false
func IsUnrestricted(ctx context.Context) bool {
return GetSubordinateShopIDs(ctx) == nil
}
// ApplySellerShopFilter 应用销售店铺数据权限过滤
// 用于 Order 等使用 seller_shop_id 字段的表
// 平台用户/超管:不添加条件
// 代理用户WHERE seller_shop_id IN (subordinateShopIDs)
func ApplySellerShopFilter(ctx context.Context, query *gorm.DB) *gorm.DB {
shopIDs := GetSubordinateShopIDs(ctx)
if shopIDs == nil {
return query
}
return query.Where("seller_shop_id IN ?", shopIDs)
}
// ApplyShopTagFilter 应用店铺标签数据权限过滤
// 用于 CardWallet 等使用 shop_id_tag 字段的表
// 平台用户/超管:不添加条件
// 代理用户WHERE shop_id_tag IN (subordinateShopIDs)
func ApplyShopTagFilter(ctx context.Context, query *gorm.DB) *gorm.DB {
shopIDs := GetSubordinateShopIDs(ctx)
if shopIDs == nil {
return query
}
return query.Where("shop_id_tag IN ?", shopIDs)
}

View File

@@ -2,20 +2,13 @@ package middleware
import ( import (
"context" "context"
"slices"
"github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/model"
"github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors" "github.com/break/junhong_cmp_fiber/pkg/errors"
) )
// ShopStoreInterface 店铺存储接口
// 用于权限检查时查询店铺信息和下级店铺ID
type ShopStoreInterface interface {
GetByID(ctx context.Context, id uint) (*model.Shop, error)
GetByIDs(ctx context.Context, ids []uint) ([]*model.Shop, error)
GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error)
}
// EnterpriseStoreInterface 企业存储接口 // EnterpriseStoreInterface 企业存储接口
// 用于权限检查时查询企业信息 // 用于权限检查时查询企业信息
type EnterpriseStoreInterface interface { type EnterpriseStoreInterface interface {
@@ -23,91 +16,80 @@ type EnterpriseStoreInterface interface {
GetByIDs(ctx context.Context, ids []uint) ([]*model.Enterprise, error) GetByIDs(ctx context.Context, ids []uint) ([]*model.Enterprise, error)
} }
// CanManageShop 检查当前用户是否有权管理目标店铺的账号 // CanManageShop 检查当前用户是否有权管理目标店铺
// 超级管理员和平台用户自动通过 // 超级管理员和平台用户自动通过SubordinateShopIDs 为 nil
// 代理账号只能管理自己店铺及下级店铺的账号 // 代理账号只能管理自己店铺及下级店铺
// 企业账号禁止管理店铺账号 // 企业账号禁止管理店铺
func CanManageShop(ctx context.Context, targetShopID uint, shopStore ShopStoreInterface) error { func CanManageShop(ctx context.Context, targetShopID uint) error {
userType := GetUserTypeFromContext(ctx) userType := GetUserTypeFromContext(ctx)
// 超级管理员和平台用户跳过权限检查 // 企业账号禁止管理店铺
if userType == constants.UserTypeSuperAdmin || userType == constants.UserTypePlatform { if userType == constants.UserTypeEnterprise {
return errors.New(errors.CodeForbidden, "无权限管理店铺")
}
// 从 Context 获取预计算的下级店铺 ID 列表
subordinateIDs := GetSubordinateShopIDs(ctx)
// nil 表示不受限制(超级管理员/平台用户)
if subordinateIDs == nil {
return nil return nil
} }
// 企业账号禁止管理店铺账号
if userType != constants.UserTypeAgent {
return errors.New(errors.CodeForbidden, "无权限管理店铺账号")
}
// 获取当前代理账号的店铺ID
currentShopID := GetShopIDFromContext(ctx)
if currentShopID == 0 {
return errors.New(errors.CodeForbidden, "无权限管理店铺账号")
}
// 递归查询下级店铺ID包含自己
subordinateIDs, err := shopStore.GetSubordinateShopIDs(ctx, currentShopID)
if err != nil {
return errors.Wrap(errors.CodeInternalError, err, "查询下级店铺失败")
}
// 检查目标店铺是否在下级列表中 // 检查目标店铺是否在下级列表中
for _, id := range subordinateIDs { if slices.Contains(subordinateIDs, targetShopID) {
if id == targetShopID {
return nil
}
}
return errors.New(errors.CodeForbidden, "无权限管理该店铺的账号")
}
// CanManageEnterprise 检查当前用户是否有权管理目标企业的账号
// 超级管理员和平台用户自动通过
// 代理账号只能管理归属于自己店铺或下级店铺的企业账号
// 企业账号禁止管理其他企业账号
func CanManageEnterprise(ctx context.Context, targetEnterpriseID uint, enterpriseStore EnterpriseStoreInterface, shopStore ShopStoreInterface) error {
userType := GetUserTypeFromContext(ctx)
// 超级管理员和平台用户跳过权限检查
if userType == constants.UserTypeSuperAdmin || userType == constants.UserTypePlatform {
return nil return nil
} }
// 企业账号禁止管理其他企业账号 return errors.New(errors.CodeForbidden, "无权限管理该店铺")
if userType != constants.UserTypeAgent { }
return errors.New(errors.CodeForbidden, "无权限管理企业账号")
// CanManageEnterprise 检查当前用户是否有权管理目标企业
// 超级管理员和平台用户自动通过SubordinateShopIDs 为 nil
// 代理账号只能管理归属于自己店铺或下级店铺的企业
// 企业账号禁止管理其他企业
func CanManageEnterprise(ctx context.Context, targetEnterpriseID uint, enterpriseStore EnterpriseStoreInterface) error {
userType := GetUserTypeFromContext(ctx)
// 企业账号禁止管理其他企业
if userType == constants.UserTypeEnterprise {
return errors.New(errors.CodeForbidden, "无权限管理企业")
}
// 从 Context 获取预计算的下级店铺 ID 列表
subordinateIDs := GetSubordinateShopIDs(ctx)
// nil 表示不受限制(超级管理员/平台用户)
if subordinateIDs == nil {
return nil
} }
// 获取目标企业信息 // 获取目标企业信息
enterprise, err := enterpriseStore.GetByID(ctx, targetEnterpriseID) enterprise, err := enterpriseStore.GetByID(ctx, targetEnterpriseID)
if err != nil { if err != nil {
return errors.Wrap(errors.CodeForbidden, err, "无权限操作该资源或资源不存在") return errors.New(errors.CodeForbidden, "无权限操作该资源或资源不存在")
} }
// 代理账号不能管理平台级企业owner_shop_idNULL // 代理账号不能管理平台级企业owner_shop_idNULL
if enterprise.OwnerShopID == nil { if enterprise.OwnerShopID == nil {
return errors.New(errors.CodeForbidden, "无权限管理平台级企业账号") return errors.New(errors.CodeForbidden, "无权限管理平台级企业")
}
// 获取当前代理账号的店铺ID
currentShopID := GetShopIDFromContext(ctx)
if currentShopID == 0 {
return errors.New(errors.CodeForbidden, "无权限管理企业账号")
}
// 递归查询下级店铺ID包含自己
subordinateIDs, err := shopStore.GetSubordinateShopIDs(ctx, currentShopID)
if err != nil {
return errors.Wrap(errors.CodeInternalError, err, "查询下级店铺失败")
} }
// 检查企业归属的店铺是否在下级列表中 // 检查企业归属的店铺是否在下级列表中
for _, id := range subordinateIDs { if slices.Contains(subordinateIDs, *enterprise.OwnerShopID) {
if id == *enterprise.OwnerShopID { return nil
return nil
}
} }
return errors.New(errors.CodeForbidden, "无权限管理该企业的账号") return errors.New(errors.CodeForbidden, "无权限管理该企业")
}
// ContainsShopID 检查目标店铺 ID 是否在当前用户可管理的店铺列表中
// 平台用户/超管返回 true不受限制
// 代理用户检查是否在 SubordinateShopIDs 中
func ContainsShopID(ctx context.Context, targetShopID uint) bool {
subordinateIDs := GetSubordinateShopIDs(ctx)
if subordinateIDs == nil {
return true // 不受限制
}
return slices.Contains(subordinateIDs, targetShopID)
} }