diff --git a/internal/bootstrap/admin.go b/internal/bootstrap/admin.go index fca57d5..fc120ac 100644 --- a/internal/bootstrap/admin.go +++ b/internal/bootstrap/admin.go @@ -6,7 +6,6 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/pkg/config" "github.com/break/junhong_cmp_fiber/pkg/constants" - pkgGorm "github.com/break/junhong_cmp_fiber/pkg/gorm" "go.uber.org/zap" ) @@ -15,7 +14,6 @@ func initDefaultAdmin(deps *Dependencies, services *services) error { cfg := config.Get() ctx := context.Background() - ctx = pkgGorm.SkipDataPermission(ctx) var count int64 if err := deps.DB.WithContext(ctx).Model(&model.Account{}).Where("user_type = ?", constants.UserTypeSuperAdmin).Count(&count).Error; err != nil { diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go index 18f60a6..9117184 100644 --- a/internal/bootstrap/bootstrap.go +++ b/internal/bootstrap/bootstrap.go @@ -45,8 +45,8 @@ func Bootstrap(deps *Dependencies) (*BootstrapResult, error) { deps.Logger.Error("初始化默认超级管理员失败", zap.Error(err)) } - // 5. 初始化 Middleware 层 - middlewares := initMiddlewares(deps) + // 5. 初始化 Middleware 层(传入 ShopStore 以支持预计算下级店铺 ID) + middlewares := initMiddlewares(deps, stores) // 6. 初始化 Handler 层 handlers := initHandlers(services, deps) @@ -59,17 +59,12 @@ func Bootstrap(deps *Dependencies) (*BootstrapResult, error) { // registerGORMCallbacks 注册 GORM Callbacks func registerGORMCallbacks(deps *Dependencies, stores *stores) error { - // 注册数据权限过滤 Callback(使用 ShopStore 来查询下级店铺 ID) - if err := pkgGorm.RegisterDataPermissionCallback(deps.DB, stores.Shop); err != nil { - return err - } - // 注册自动添加创建&更新人 Callback if err := pkgGorm.RegisterSetCreatorUpdaterCallback(deps.DB); err != nil { return err } - // TODO: 在此添加其他 GORM Callbacks + // 数据权限过滤已移至 Store 层显式调用 ApplyXxxFilter 函数 return nil } diff --git a/internal/bootstrap/middlewares.go b/internal/bootstrap/middlewares.go index 52ef250..03357de 100644 --- a/internal/bootstrap/middlewares.go +++ b/internal/bootstrap/middlewares.go @@ -14,7 +14,7 @@ import ( ) // initMiddlewares 初始化所有中间件 -func initMiddlewares(deps *Dependencies) *Middlewares { +func initMiddlewares(deps *Dependencies, stores *stores) *Middlewares { // 获取全局配置 cfg := config.Get() @@ -29,11 +29,11 @@ func initMiddlewares(deps *Dependencies) *Middlewares { refreshTTL := time.Duration(cfg.JWT.RefreshTokenTTL) * time.Second tokenManager := pkgauth.NewTokenManager(deps.Redis, accessTTL, refreshTTL) - // 创建后台认证中间件 - adminAuthMiddleware := createAdminAuthMiddleware(tokenManager) + // 创建后台认证中间件(传入 ShopStore 以支持预计算下级店铺 ID) + adminAuthMiddleware := createAdminAuthMiddleware(tokenManager, stores.Shop) - // 创建H5认证中间件 - h5AuthMiddleware := createH5AuthMiddleware(tokenManager) + // 创建H5认证中间件(传入 ShopStore 以支持预计算下级店铺 ID) + h5AuthMiddleware := createH5AuthMiddleware(tokenManager, stores.Shop) return &Middlewares{ PersonalAuth: personalAuthMiddleware, @@ -42,7 +42,7 @@ func initMiddlewares(deps *Dependencies) *Middlewares { } } -func createAdminAuthMiddleware(tokenManager *pkgauth.TokenManager) fiber.Handler { +func createAdminAuthMiddleware(tokenManager *pkgauth.TokenManager, shopStore pkgmiddleware.AuthShopStoreInterface) fiber.Handler { return pkgmiddleware.Auth(pkgmiddleware.AuthConfig{ TokenValidator: func(token string) (*pkgmiddleware.UserContextInfo, error) { tokenInfo, err := tokenManager.ValidateAccessToken(context.Background(), token) @@ -65,10 +65,11 @@ func createAdminAuthMiddleware(tokenManager *pkgauth.TokenManager) fiber.Handler }, nil }, SkipPaths: []string{"/api/admin/login", "/api/admin/refresh-token"}, + ShopStore: shopStore, }) } -func createH5AuthMiddleware(tokenManager *pkgauth.TokenManager) fiber.Handler { +func createH5AuthMiddleware(tokenManager *pkgauth.TokenManager, shopStore pkgmiddleware.AuthShopStoreInterface) fiber.Handler { return pkgmiddleware.Auth(pkgmiddleware.AuthConfig{ TokenValidator: func(token string) (*pkgmiddleware.UserContextInfo, error) { tokenInfo, err := tokenManager.ValidateAccessToken(context.Background(), token) @@ -90,5 +91,6 @@ func createH5AuthMiddleware(tokenManager *pkgauth.TokenManager) fiber.Handler { }, nil }, SkipPaths: []string{"/api/h5/login", "/api/h5/refresh-token"}, + ShopStore: shopStore, }) } diff --git a/internal/bootstrap/services.go b/internal/bootstrap/services.go index 4694f5e..bf2ceed 100644 --- a/internal/bootstrap/services.go +++ b/internal/bootstrap/services.go @@ -147,6 +147,6 @@ func initServices(s *stores, deps *Dependencies) *services { PollingMonitoring: pollingSvc.NewMonitoringService(deps.Redis), PollingAlert: pollingSvc.NewAlertService(s.PollingAlertRule, s.PollingAlertHistory, deps.Redis, deps.Logger), PollingCleanup: pollingSvc.NewCleanupService(s.DataCleanupConfig, s.DataCleanupLog, deps.Logger), - PollingManualTrigger: pollingSvc.NewManualTriggerService(s.PollingManualTriggerLog, s.IotCard, s.Shop, deps.Redis, deps.Logger), + PollingManualTrigger: pollingSvc.NewManualTriggerService(s.PollingManualTriggerLog, s.IotCard, deps.Redis, deps.Logger), } } diff --git a/internal/service/account/service.go b/internal/service/account/service.go index c1b4d99..9de7c96 100644 --- a/internal/service/account/service.go +++ b/internal/service/account/service.go @@ -17,13 +17,18 @@ import ( "gorm.io/gorm" ) +// ShopStoreInterface 店铺存储接口(仅用于获取店铺信息) +type ShopStoreInterface interface { + GetByIDs(ctx context.Context, ids []uint) ([]*model.Shop, error) +} + // Service 账号业务服务 type Service struct { accountStore *postgres.AccountStore roleStore *postgres.RoleStore accountRoleStore *postgres.AccountRoleStore shopRoleStore *postgres.ShopRoleStore - shopStore middleware.ShopStoreInterface + shopStore ShopStoreInterface enterpriseStore middleware.EnterpriseStoreInterface auditService AuditServiceInterface } @@ -38,7 +43,7 @@ func New( roleStore *postgres.RoleStore, accountRoleStore *postgres.AccountRoleStore, shopRoleStore *postgres.ShopRoleStore, - shopStore middleware.ShopStoreInterface, + shopStore ShopStoreInterface, enterpriseStore middleware.EnterpriseStoreInterface, auditService AuditServiceInterface, ) *Service { @@ -79,13 +84,13 @@ func (s *Service) Create(ctx context.Context, req *dto.CreateAccountRequest) (*m } if req.UserType == constants.UserTypeAgent && req.ShopID != nil { - if err := middleware.CanManageShop(ctx, *req.ShopID, s.shopStore); err != nil { + if err := middleware.CanManageShop(ctx, *req.ShopID); err != nil { return nil, err } } if req.UserType == constants.UserTypeEnterprise && req.EnterpriseID != nil { - if err := middleware.CanManageEnterprise(ctx, *req.EnterpriseID, s.enterpriseStore, s.shopStore); err != nil { + if err := middleware.CanManageEnterprise(ctx, *req.EnterpriseID, s.enterpriseStore); err != nil { return nil, err } } @@ -190,7 +195,7 @@ func (s *Service) Update(ctx context.Context, id uint, req *dto.UpdateAccountReq if account.ShopID == nil { return nil, errors.New(errors.CodeForbidden, "无权限操作该账号") } - if err := middleware.CanManageShop(ctx, *account.ShopID, s.shopStore); err != nil { + if err := middleware.CanManageShop(ctx, *account.ShopID); err != nil { return nil, errors.New(errors.CodeForbidden, "无权限操作该资源或资源不存在") } } @@ -291,7 +296,7 @@ func (s *Service) Delete(ctx context.Context, id uint) error { if account.ShopID == nil { return errors.New(errors.CodeForbidden, "无权限操作该账号") } - if err := middleware.CanManageShop(ctx, *account.ShopID, s.shopStore); err != nil { + if err := middleware.CanManageShop(ctx, *account.ShopID); err != nil { return errors.New(errors.CodeForbidden, "无权限操作该资源或资源不存在") } } @@ -407,7 +412,7 @@ func (s *Service) AssignRoles(ctx context.Context, accountID uint, roleIDs []uin if account.ShopID == nil { return nil, errors.New(errors.CodeForbidden, "无权限操作该账号") } - if err := middleware.CanManageShop(ctx, *account.ShopID, s.shopStore); err != nil { + if err := middleware.CanManageShop(ctx, *account.ShopID); err != nil { return nil, errors.New(errors.CodeForbidden, "无权限操作该资源或资源不存在") } } @@ -558,7 +563,7 @@ func (s *Service) RemoveRole(ctx context.Context, accountID, roleID uint) error if account.ShopID == nil { return errors.New(errors.CodeForbidden, "无权限操作该账号") } - if err := middleware.CanManageShop(ctx, *account.ShopID, s.shopStore); err != nil { + if err := middleware.CanManageShop(ctx, *account.ShopID); err != nil { return errors.New(errors.CodeForbidden, "无权限操作该资源或资源不存在") } } diff --git a/internal/service/auth/service.go b/internal/service/auth/service.go index fd0138d..3cca105 100644 --- a/internal/service/auth/service.go +++ b/internal/service/auth/service.go @@ -10,7 +10,6 @@ import ( "github.com/break/junhong_cmp_fiber/pkg/auth" "github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/errors" - pkgGorm "github.com/break/junhong_cmp_fiber/pkg/gorm" "go.uber.org/zap" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" @@ -47,8 +46,6 @@ func New( } func (s *Service) Login(ctx context.Context, req *dto.LoginRequest, clientIP string) (*dto.LoginResponse, error) { - ctx = pkgGorm.SkipDataPermission(ctx) - account, err := s.accountStore.GetByUsernameOrPhone(ctx, req.Username) if err != nil { if err == gorm.ErrRecordNotFound { diff --git a/internal/service/enterprise_device/service.go b/internal/service/enterprise_device/service.go index d643abd..cb4619d 100644 --- a/internal/service/enterprise_device/service.go +++ b/internal/service/enterprise_device/service.go @@ -9,7 +9,6 @@ import ( "github.com/break/junhong_cmp_fiber/internal/store/postgres" "github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/errors" - pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm" "github.com/break/junhong_cmp_fiber/pkg/middleware" "go.uber.org/zap" "gorm.io/gorm" @@ -426,10 +425,8 @@ func (s *Service) ListDevicesForEnterprise(ctx context.Context, req *dto.Enterpr authMap[auth.DeviceID] = auth } - skipCtx := pkggorm.SkipDataPermission(ctx) - var devices []model.Device - query := s.db.WithContext(skipCtx).Where("id IN ?", deviceIDs) + query := s.db.WithContext(ctx).Where("id IN ?", deviceIDs) if req.DeviceNo != "" { query = query.Where("device_no LIKE ?", "%"+req.DeviceNo+"%") } @@ -438,7 +435,7 @@ func (s *Service) ListDevicesForEnterprise(ctx context.Context, req *dto.Enterpr } var bindings []model.DeviceSimBinding - if err := s.db.WithContext(skipCtx). + if err := s.db.WithContext(ctx). Where("device_id IN ? AND bind_status = 1", deviceIDs). Find(&bindings).Error; err != nil { return nil, errors.Wrap(errors.CodeInternalError, err, "查询设备绑定卡失败") @@ -480,15 +477,14 @@ func (s *Service) GetDeviceDetail(ctx context.Context, deviceID uint) (*dto.Ente return nil, errors.New(errors.CodeDeviceNotAuthorized, "设备未授权给此企业") } - skipCtx := pkggorm.SkipDataPermission(ctx) var device model.Device - if err := s.db.WithContext(skipCtx).Where("id = ?", deviceID).First(&device).Error; err != nil { + if err := s.db.WithContext(ctx).Where("id = ?", deviceID).First(&device).Error; err != nil { return nil, errors.Wrap(errors.CodeInternalError, err, "查询设备信息失败") } var bindings []model.DeviceSimBinding - if err := s.db.WithContext(skipCtx). + if err := s.db.WithContext(ctx). Where("device_id = ? AND bind_status = 1", deviceID). Find(&bindings).Error; err != nil { return nil, errors.Wrap(errors.CodeInternalError, err, "查询设备绑定卡失败") @@ -502,7 +498,7 @@ func (s *Service) GetDeviceDetail(ctx context.Context, deviceID uint) (*dto.Ente var cards []model.IotCard cardInfos := make([]dto.DeviceCardInfo, 0) if len(cardIDs) > 0 { - if err := s.db.WithContext(skipCtx).Where("id IN ?", cardIDs).Find(&cards).Error; err != nil { + if err := s.db.WithContext(ctx).Where("id IN ?", cardIDs).Find(&cards).Error; err != nil { return nil, errors.Wrap(errors.CodeInternalError, err, "查询卡信息失败") } @@ -514,7 +510,7 @@ func (s *Service) GetDeviceDetail(ctx context.Context, deviceID uint) (*dto.Ente var carriers []model.Carrier carrierMap := make(map[uint]string) if len(carrierIDs) > 0 { - if err := s.db.WithContext(skipCtx).Where("id IN ?", carrierIDs).Find(&carriers).Error; err == nil { + if err := s.db.WithContext(ctx).Where("id IN ?", carrierIDs).Find(&carriers).Error; err == nil { for _, carrier := range carriers { carrierMap[carrier.ID] = carrier.CarrierName } @@ -551,8 +547,7 @@ func (s *Service) SuspendCard(ctx context.Context, deviceID, cardID uint, req *d return nil, err } - skipCtx := pkggorm.SkipDataPermission(ctx) - if err := s.db.WithContext(skipCtx).Model(&model.IotCard{}). + if err := s.db.WithContext(ctx).Model(&model.IotCard{}). Where("id = ?", cardID). Update("network_status", 0).Error; err != nil { return nil, errors.Wrap(errors.CodeInternalError, err, "停机操作失败") @@ -569,8 +564,7 @@ func (s *Service) ResumeCard(ctx context.Context, deviceID, cardID uint, req *dt return nil, err } - skipCtx := pkggorm.SkipDataPermission(ctx) - if err := s.db.WithContext(skipCtx).Model(&model.IotCard{}). + if err := s.db.WithContext(ctx).Model(&model.IotCard{}). Where("id = ?", cardID). Update("network_status", 1).Error; err != nil { return nil, errors.Wrap(errors.CodeInternalError, err, "复机操作失败") @@ -593,17 +587,16 @@ func (s *Service) validateCardOperation(ctx context.Context, deviceID, cardID ui return errors.New(errors.CodeDeviceNotAuthorized, "设备未授权给此企业") } - skipCtx := pkggorm.SkipDataPermission(ctx) var binding model.DeviceSimBinding - if err := s.db.WithContext(skipCtx). + if err := s.db.WithContext(ctx). Where("device_id = ? AND iot_card_id = ? AND bind_status = 1", deviceID, cardID). First(&binding).Error; err != nil { return errors.New(errors.CodeForbidden, "卡不属于该设备") } var cardAuth model.EnterpriseCardAuthorization - if err := s.db.WithContext(skipCtx). + if err := s.db.WithContext(ctx). Where("enterprise_id = ? AND card_id = ? AND device_auth_id IS NOT NULL AND revoked_at IS NULL", enterpriseID, cardID). First(&cardAuth).Error; err != nil { return errors.New(errors.CodeForbidden, "无权操作此卡") diff --git a/internal/service/polling/manual_trigger_service.go b/internal/service/polling/manual_trigger_service.go index b0a50d2..93b4bfd 100644 --- a/internal/service/polling/manual_trigger_service.go +++ b/internal/service/polling/manual_trigger_service.go @@ -19,7 +19,6 @@ import ( type ManualTriggerService struct { logStore *postgres.PollingManualTriggerLogStore iotCardStore *postgres.IotCardStore - shopStore middleware.ShopStoreInterface redis *redis.Client logger *zap.Logger } @@ -28,14 +27,12 @@ type ManualTriggerService struct { func NewManualTriggerService( logStore *postgres.PollingManualTriggerLogStore, iotCardStore *postgres.IotCardStore, - shopStore middleware.ShopStoreInterface, redis *redis.Client, logger *zap.Logger, ) *ManualTriggerService { return &ManualTriggerService{ logStore: logStore, iotCardStore: iotCardStore, - shopStore: shopStore, redis: redis, logger: logger, } @@ -386,7 +383,7 @@ func (s *ManualTriggerService) canManageCard(ctx context.Context, cardID uint) e } // 检查代理是否有权管理该店铺 - return middleware.CanManageShop(ctx, *card.ShopID, s.shopStore) + return middleware.CanManageShop(ctx, *card.ShopID) } // canManageCards 检查用户是否有权管理多张卡 @@ -403,18 +400,13 @@ func (s *ManualTriggerService) canManageCards(ctx context.Context, cardIDs []uin return errors.New(errors.CodeForbidden, "企业账号无权限手动触发轮询") } - // 代理账号只能管理自己店铺及下级店铺的卡 - currentShopID := middleware.GetShopIDFromContext(ctx) - if currentShopID == 0 { + // 从 Context 获取预计算的下级店铺 ID 列表 + subordinateIDs := middleware.GetSubordinateShopIDs(ctx) + if subordinateIDs == nil { + // 平台用户/超管不受限制,但这里不应该进入(前面已经检查过用户类型) return errors.New(errors.CodeForbidden, "无权限操作") } - // 获取下级店铺ID列表 - subordinateIDs, err := s.shopStore.GetSubordinateShopIDs(ctx, currentShopID) - if err != nil { - return errors.Wrap(errors.CodeInternalError, err, "查询下级店铺失败") - } - // 构建可管理的店铺ID集合 allowedShopIDs := make(map[uint]bool) for _, id := range subordinateIDs { @@ -462,7 +454,7 @@ func (s *ManualTriggerService) applyShopPermissionFilter(ctx context.Context, fi // 如果用户指定了 ShopID,验证是否在可管理范围内 if filter.ShopID != nil { - if err := middleware.CanManageShop(ctx, *filter.ShopID, s.shopStore); err != nil { + if err := middleware.CanManageShop(ctx, *filter.ShopID); err != nil { return err } // 已指定有效的 ShopID,无需修改 diff --git a/internal/service/shop/shop_role.go b/internal/service/shop/shop_role.go index f83a6b1..80ab3e2 100644 --- a/internal/service/shop/shop_role.go +++ b/internal/service/shop/shop_role.go @@ -11,7 +11,7 @@ import ( ) func (s *Service) AssignRolesToShop(ctx context.Context, shopID uint, roleIDs []uint) ([]*model.ShopRole, error) { - if err := middleware.CanManageShop(ctx, shopID, s.shopStore); err != nil { + if err := middleware.CanManageShop(ctx, shopID); err != nil { return nil, err } @@ -70,7 +70,7 @@ func (s *Service) AssignRolesToShop(ctx context.Context, shopID uint, roleIDs [] } func (s *Service) GetShopRoles(ctx context.Context, shopID uint) (*dto.ShopRolesResponse, error) { - if err := middleware.CanManageShop(ctx, shopID, s.shopStore); err != nil { + if err := middleware.CanManageShop(ctx, shopID); err != nil { return nil, err } @@ -128,7 +128,7 @@ func (s *Service) GetShopRoles(ctx context.Context, shopID uint) (*dto.ShopRoles } func (s *Service) DeleteShopRole(ctx context.Context, shopID, roleID uint) error { - if err := middleware.CanManageShop(ctx, shopID, s.shopStore); err != nil { + if err := middleware.CanManageShop(ctx, shopID); err != nil { return err } diff --git a/internal/service/shop_series_allocation/service.go b/internal/service/shop_series_allocation/service.go index 9df41b8..2b94dfc 100644 --- a/internal/service/shop_series_allocation/service.go +++ b/internal/service/shop_series_allocation/service.go @@ -10,7 +10,6 @@ import ( "github.com/break/junhong_cmp_fiber/internal/store/postgres" "github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/errors" - pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm" "github.com/break/junhong_cmp_fiber/pkg/middleware" "gorm.io/gorm" ) @@ -71,9 +70,8 @@ func (s *Service) Create(ctx context.Context, req *dto.CreateShopSeriesAllocatio return nil, errors.Wrap(errors.CodeInternalError, err, "获取套餐系列失败") } - // 检查是否已存在分配(跳过数据权限过滤,避免误判) - skipCtx := pkggorm.SkipDataPermission(ctx) - exists, err := s.seriesAllocationStore.ExistsByShopAndSeries(skipCtx, req.ShopID, req.SeriesID) + // 检查是否已存在分配 + exists, err := s.seriesAllocationStore.ExistsByShopAndSeries(ctx, req.ShopID, req.SeriesID) if err != nil { return nil, errors.Wrap(errors.CodeInternalError, err, "检查分配记录失败") } @@ -84,7 +82,7 @@ func (s *Service) Create(ctx context.Context, req *dto.CreateShopSeriesAllocatio // 代理用户:检查自己是否有该系列的分配权限,且金额不能超过上级给的上限 // 平台用户:无上限限制,可自由设定金额 if userType == constants.UserTypeAgent { - allocatorAllocation, err := s.seriesAllocationStore.GetByShopAndSeries(skipCtx, allocatorShopID, req.SeriesID) + allocatorAllocation, err := s.seriesAllocationStore.GetByShopAndSeries(ctx, allocatorShopID, req.SeriesID) if err != nil { if err == gorm.ErrRecordNotFound { return nil, errors.New(errors.CodeForbidden, "您没有该套餐系列的分配权限") @@ -239,8 +237,7 @@ func (s *Service) Update(ctx context.Context, id uint, req *dto.UpdateShopSeries } func (s *Service) Delete(ctx context.Context, id uint) error { - skipCtx := pkggorm.SkipDataPermission(ctx) - _, err := s.seriesAllocationStore.GetByID(skipCtx, id) + _, err := s.seriesAllocationStore.GetByID(ctx, id) if err != nil { if err == gorm.ErrRecordNotFound { return errors.New(errors.CodeNotFound, "分配记录不存在") @@ -248,7 +245,7 @@ func (s *Service) Delete(ctx context.Context, id uint) error { return errors.Wrap(errors.CodeInternalError, err, "获取分配记录失败") } - count, err := s.packageAllocationStore.CountBySeriesAllocationID(skipCtx, id) + count, err := s.packageAllocationStore.CountBySeriesAllocationID(ctx, id) if err != nil { return errors.Wrap(errors.CodeInternalError, err, "检查关联套餐分配失败") } @@ -256,7 +253,7 @@ func (s *Service) Delete(ctx context.Context, id uint) error { return errors.New(errors.CodeInvalidParam, "存在关联的套餐分配,无法删除") } - if err := s.seriesAllocationStore.Delete(skipCtx, id); err != nil { + if err := s.seriesAllocationStore.Delete(ctx, id); err != nil { return errors.Wrap(errors.CodeInternalError, err, "删除分配失败") } diff --git a/internal/store/postgres/account_store.go b/internal/store/postgres/account_store.go index 33180a3..cc75ac7 100644 --- a/internal/store/postgres/account_store.go +++ b/internal/store/postgres/account_store.go @@ -3,9 +3,9 @@ package postgres import ( "context" - "github.com/break/junhong_cmp_fiber/internal/store" - "github.com/break/junhong_cmp_fiber/internal/model" + "github.com/break/junhong_cmp_fiber/internal/store" + "github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/redis/go-redis/v9" "gorm.io/gorm" ) @@ -32,7 +32,12 @@ func (s *AccountStore) Create(ctx context.Context, account *model.Account) error // GetByID 根据 ID 获取账号 func (s *AccountStore) GetByID(ctx context.Context, id uint) (*model.Account, error) { var account model.Account - if err := s.db.WithContext(ctx).First(&account, id).Error; err != nil { + query := s.db.WithContext(ctx).Where("id = ?", id) + // 根据当前用户类型应用数据权限过滤 + // 代理用户:过滤 shop_id;企业用户:过滤 enterprise_id + query = middleware.ApplyShopFilter(ctx, query) + query = middleware.ApplyEnterpriseFilter(ctx, query) + if err := query.First(&account).Error; err != nil { return nil, err } return &account, nil @@ -68,7 +73,10 @@ func (s *AccountStore) GetByUsernameOrPhone(ctx context.Context, identifier stri // GetByShopID 根据店铺 ID 查询账号列表 func (s *AccountStore) GetByShopID(ctx context.Context, shopID uint) ([]*model.Account, error) { var accounts []*model.Account - if err := s.db.WithContext(ctx).Where("shop_id = ?", shopID).Find(&accounts).Error; err != nil { + query := s.db.WithContext(ctx).Where("shop_id = ?", shopID) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.Find(&accounts).Error; err != nil { return nil, err } return accounts, nil @@ -77,7 +85,10 @@ func (s *AccountStore) GetByShopID(ctx context.Context, shopID uint) ([]*model.A // GetByEnterpriseID 根据企业 ID 查询账号列表 func (s *AccountStore) GetByEnterpriseID(ctx context.Context, enterpriseID uint) ([]*model.Account, error) { var accounts []*model.Account - if err := s.db.WithContext(ctx).Where("enterprise_id = ?", enterpriseID).Find(&accounts).Error; err != nil { + query := s.db.WithContext(ctx).Where("enterprise_id = ?", enterpriseID) + // 应用企业数据权限过滤 + query = middleware.ApplyEnterpriseFilter(ctx, query) + if err := query.Find(&accounts).Error; err != nil { return nil, err } return accounts, nil @@ -99,6 +110,10 @@ func (s *AccountStore) List(ctx context.Context, opts *store.QueryOptions, filte var total int64 query := s.db.WithContext(ctx).Model(&model.Account{}) + // 根据当前用户类型应用数据权限过滤 + // 代理用户:过滤 shop_id;企业用户:过滤 enterprise_id + query = middleware.ApplyShopFilter(ctx, query) + query = middleware.ApplyEnterpriseFilter(ctx, query) // 应用过滤条件 if username, ok := filters["username"].(string); ok && username != "" { @@ -229,7 +244,11 @@ func (s *AccountStore) GetByIDs(ctx context.Context, ids []uint) ([]*model.Accou return []*model.Account{}, nil } var accounts []*model.Account - if err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&accounts).Error; err != nil { + query := s.db.WithContext(ctx).Where("id IN ?", ids) + // 根据当前用户类型应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + query = middleware.ApplyEnterpriseFilter(ctx, query) + if err := query.Find(&accounts).Error; err != nil { return nil, err } return accounts, nil @@ -240,9 +259,11 @@ func (s *AccountStore) GetPrimaryAccountsByShopIDs(ctx context.Context, shopIDs return []*model.Account{}, nil } var accounts []*model.Account - if err := s.db.WithContext(ctx). - Where("shop_id IN ? AND is_primary = ?", shopIDs, true). - Find(&accounts).Error; err != nil { + query := s.db.WithContext(ctx). + Where("shop_id IN ? AND is_primary = ?", shopIDs, true) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.Find(&accounts).Error; err != nil { return nil, err } return accounts, nil @@ -254,6 +275,8 @@ func (s *AccountStore) ListByShopID(ctx context.Context, shopID uint, opts *stor var total int64 query := s.db.WithContext(ctx).Model(&model.Account{}).Where("shop_id = ?", shopID) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) if username, ok := filters["username"].(string); ok && username != "" { query = query.Where("username LIKE ?", "%"+username+"%") diff --git a/internal/store/postgres/agent_wallet_store.go b/internal/store/postgres/agent_wallet_store.go index 613f1d1..fe50cfb 100644 --- a/internal/store/postgres/agent_wallet_store.go +++ b/internal/store/postgres/agent_wallet_store.go @@ -6,6 +6,7 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/pkg/constants" + "github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/redis/go-redis/v9" "gorm.io/gorm" ) @@ -41,9 +42,11 @@ func (s *AgentWalletStore) GetByShopIDAndType(ctx context.Context, shopID uint, // 注意:这里简化处理,实际项目中可以缓存完整的钱包信息 var wallet model.AgentWallet - err := s.db.WithContext(ctx). - Where("shop_id = ? AND wallet_type = ?", shopID, walletType). - First(&wallet).Error + query := s.db.WithContext(ctx). + Where("shop_id = ? AND wallet_type = ?", shopID, walletType) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + err := query.First(&wallet).Error if err != nil { return nil, err } @@ -58,7 +61,10 @@ func (s *AgentWalletStore) GetByShopIDAndType(ctx context.Context, shopID uint, // GetByID 根据钱包 ID 查询 func (s *AgentWalletStore) GetByID(ctx context.Context, id uint) (*model.AgentWallet, error) { var wallet model.AgentWallet - if err := s.db.WithContext(ctx).First(&wallet, id).Error; err != nil { + query := s.db.WithContext(ctx).Where("id = ?", id) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.First(&wallet).Error; err != nil { return nil, err } return &wallet, nil @@ -209,9 +215,11 @@ func (s *AgentWalletStore) GetShopCommissionSummaryBatch(ctx context.Context, sh } var wallets []model.AgentWallet - err := s.db.WithContext(ctx). - Where("shop_id IN ? AND wallet_type = ?", shopIDs, constants.AgentWalletTypeCommission). - Find(&wallets).Error + query := s.db.WithContext(ctx). + Where("shop_id IN ? AND wallet_type = ?", shopIDs, constants.AgentWalletTypeCommission) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + err := query.Find(&wallets).Error if err != nil { return nil, err } diff --git a/internal/store/postgres/agent_wallet_transaction_store.go b/internal/store/postgres/agent_wallet_transaction_store.go index 01945ec..5e11dfe 100644 --- a/internal/store/postgres/agent_wallet_transaction_store.go +++ b/internal/store/postgres/agent_wallet_transaction_store.go @@ -4,6 +4,7 @@ import ( "context" "github.com/break/junhong_cmp_fiber/internal/model" + "github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/redis/go-redis/v9" "gorm.io/gorm" ) @@ -30,9 +31,11 @@ func (s *AgentWalletTransactionStore) CreateWithTx(ctx context.Context, tx *gorm // ListByShopID 按店铺查询交易记录(支持分页) func (s *AgentWalletTransactionStore) ListByShopID(ctx context.Context, shopID uint, offset, limit int) ([]*model.AgentWalletTransaction, error) { var transactions []*model.AgentWalletTransaction - err := s.db.WithContext(ctx). - Where("shop_id = ?", shopID). - Order("created_at DESC"). + query := s.db.WithContext(ctx). + Where("shop_id = ?", shopID) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + err := query.Order("created_at DESC"). Offset(offset). Limit(limit). Find(&transactions).Error @@ -45,19 +48,23 @@ func (s *AgentWalletTransactionStore) ListByShopID(ctx context.Context, shopID u // CountByShopID 统计店铺的交易记录数量 func (s *AgentWalletTransactionStore) CountByShopID(ctx context.Context, shopID uint) (int64, error) { var count int64 - err := s.db.WithContext(ctx). + query := s.db.WithContext(ctx). Model(&model.AgentWalletTransaction{}). - Where("shop_id = ?", shopID). - Count(&count).Error + Where("shop_id = ?", shopID) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + err := query.Count(&count).Error return count, err } // ListByWalletID 按钱包查询交易记录(支持分页) func (s *AgentWalletTransactionStore) ListByWalletID(ctx context.Context, walletID uint, offset, limit int) ([]*model.AgentWalletTransaction, error) { var transactions []*model.AgentWalletTransaction - err := s.db.WithContext(ctx). - Where("agent_wallet_id = ?", walletID). - Order("created_at DESC"). + query := s.db.WithContext(ctx). + Where("agent_wallet_id = ?", walletID) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + err := query.Order("created_at DESC"). Offset(offset). Limit(limit). Find(&transactions).Error @@ -70,9 +77,11 @@ func (s *AgentWalletTransactionStore) ListByWalletID(ctx context.Context, wallet // GetByReference 根据关联业务查询交易记录 func (s *AgentWalletTransactionStore) GetByReference(ctx context.Context, referenceType string, referenceID uint) (*model.AgentWalletTransaction, error) { var transaction model.AgentWalletTransaction - err := s.db.WithContext(ctx). - Where("reference_type = ? AND reference_id = ?", referenceType, referenceID). - First(&transaction).Error + query := s.db.WithContext(ctx). + Where("reference_type = ? AND reference_id = ?", referenceType, referenceID) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + err := query.First(&transaction).Error if err != nil { return nil, err } diff --git a/internal/store/postgres/card_wallet_store.go b/internal/store/postgres/card_wallet_store.go index ee0147e..7ad0e89 100644 --- a/internal/store/postgres/card_wallet_store.go +++ b/internal/store/postgres/card_wallet_store.go @@ -6,6 +6,7 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/pkg/constants" + "github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/redis/go-redis/v9" "gorm.io/gorm" ) @@ -27,9 +28,11 @@ func NewCardWalletStore(db *gorm.DB, redis *redis.Client) *CardWalletStore { // GetByResourceTypeAndID 根据资源类型和 ID 查询钱包 func (s *CardWalletStore) GetByResourceTypeAndID(ctx context.Context, resourceType string, resourceID uint) (*model.CardWallet, error) { var wallet model.CardWallet - err := s.db.WithContext(ctx). - Where("resource_type = ? AND resource_id = ?", resourceType, resourceID). - First(&wallet).Error + query := s.db.WithContext(ctx). + Where("resource_type = ? AND resource_id = ?", resourceType, resourceID) + // 应用数据权限过滤(使用 shop_id_tag 字段) + query = middleware.ApplyShopTagFilter(ctx, query) + err := query.First(&wallet).Error if err != nil { return nil, err } @@ -39,7 +42,10 @@ func (s *CardWalletStore) GetByResourceTypeAndID(ctx context.Context, resourceTy // GetByID 根据钱包 ID 查询 func (s *CardWalletStore) GetByID(ctx context.Context, id uint) (*model.CardWallet, error) { var wallet model.CardWallet - if err := s.db.WithContext(ctx).First(&wallet, id).Error; err != nil { + query := s.db.WithContext(ctx).Where("id = ?", id) + // 应用数据权限过滤(使用 shop_id_tag 字段) + query = middleware.ApplyShopTagFilter(ctx, query) + if err := query.First(&wallet).Error; err != nil { return nil, err } return &wallet, nil diff --git a/internal/store/postgres/card_wallet_transaction_store.go b/internal/store/postgres/card_wallet_transaction_store.go index 89c9a6e..439dfdc 100644 --- a/internal/store/postgres/card_wallet_transaction_store.go +++ b/internal/store/postgres/card_wallet_transaction_store.go @@ -4,6 +4,7 @@ import ( "context" "github.com/break/junhong_cmp_fiber/internal/model" + "github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/redis/go-redis/v9" "gorm.io/gorm" ) @@ -30,9 +31,11 @@ func (s *CardWalletTransactionStore) CreateWithTx(ctx context.Context, tx *gorm. // ListByResourceID 按资源查询交易记录(支持分页) func (s *CardWalletTransactionStore) ListByResourceID(ctx context.Context, resourceType string, resourceID uint, offset, limit int) ([]*model.CardWalletTransaction, error) { var transactions []*model.CardWalletTransaction - err := s.db.WithContext(ctx). - Where("resource_type = ? AND resource_id = ?", resourceType, resourceID). - Order("created_at DESC"). + query := s.db.WithContext(ctx). + Where("resource_type = ? AND resource_id = ?", resourceType, resourceID) + // 应用数据权限过滤(使用 shop_id_tag 字段) + query = middleware.ApplyShopTagFilter(ctx, query) + err := query.Order("created_at DESC"). Offset(offset). Limit(limit). Find(&transactions).Error @@ -45,19 +48,23 @@ func (s *CardWalletTransactionStore) ListByResourceID(ctx context.Context, resou // CountByResourceID 统计资源的交易记录数量 func (s *CardWalletTransactionStore) CountByResourceID(ctx context.Context, resourceType string, resourceID uint) (int64, error) { var count int64 - err := s.db.WithContext(ctx). + query := s.db.WithContext(ctx). Model(&model.CardWalletTransaction{}). - Where("resource_type = ? AND resource_id = ?", resourceType, resourceID). - Count(&count).Error + Where("resource_type = ? AND resource_id = ?", resourceType, resourceID) + // 应用数据权限过滤(使用 shop_id_tag 字段) + query = middleware.ApplyShopTagFilter(ctx, query) + err := query.Count(&count).Error return count, err } // ListByWalletID 按钱包查询交易记录(支持分页) func (s *CardWalletTransactionStore) ListByWalletID(ctx context.Context, walletID uint, offset, limit int) ([]*model.CardWalletTransaction, error) { var transactions []*model.CardWalletTransaction - err := s.db.WithContext(ctx). - Where("card_wallet_id = ?", walletID). - Order("created_at DESC"). + query := s.db.WithContext(ctx). + Where("card_wallet_id = ?", walletID) + // 应用数据权限过滤(使用 shop_id_tag 字段) + query = middleware.ApplyShopTagFilter(ctx, query) + err := query.Order("created_at DESC"). Offset(offset). Limit(limit). Find(&transactions).Error @@ -70,9 +77,11 @@ func (s *CardWalletTransactionStore) ListByWalletID(ctx context.Context, walletI // GetByReference 根据关联业务查询交易记录 func (s *CardWalletTransactionStore) GetByReference(ctx context.Context, referenceType string, referenceID uint) (*model.CardWalletTransaction, error) { var transaction model.CardWalletTransaction - err := s.db.WithContext(ctx). - Where("reference_type = ? AND reference_id = ?", referenceType, referenceID). - First(&transaction).Error + query := s.db.WithContext(ctx). + Where("reference_type = ? AND reference_id = ?", referenceType, referenceID) + // 应用数据权限过滤(使用 shop_id_tag 字段) + query = middleware.ApplyShopTagFilter(ctx, query) + err := query.First(&transaction).Error if err != nil { return nil, err } diff --git a/internal/store/postgres/commission_record_store.go b/internal/store/postgres/commission_record_store.go index 848e9d8..be7626b 100644 --- a/internal/store/postgres/commission_record_store.go +++ b/internal/store/postgres/commission_record_store.go @@ -6,6 +6,7 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/store" "github.com/break/junhong_cmp_fiber/pkg/constants" + "github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/redis/go-redis/v9" "gorm.io/gorm" ) @@ -28,7 +29,10 @@ func (s *CommissionRecordStore) Create(ctx context.Context, record *model.Commis func (s *CommissionRecordStore) GetByID(ctx context.Context, id uint) (*model.CommissionRecord, error) { var record model.CommissionRecord - if err := s.db.WithContext(ctx).First(&record, id).Error; err != nil { + query := s.db.WithContext(ctx).Where("id = ?", id) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.First(&record).Error; err != nil { return nil, err } return &record, nil @@ -50,6 +54,8 @@ func (s *CommissionRecordStore) ListByShopID(ctx context.Context, opts *store.Qu var total int64 query := s.db.WithContext(ctx).Model(&model.CommissionRecord{}) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) if filters != nil { if filters.ShopID > 0 { @@ -107,6 +113,8 @@ type CommissionStats struct { func (s *CommissionRecordStore) GetStats(ctx context.Context, filters *CommissionRecordListFilters) (*CommissionStats, error) { query := s.db.WithContext(ctx).Model(&model.CommissionRecord{}). Where("status = ?", model.CommissionStatusReleased) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) if filters != nil { if filters.ShopID > 0 { @@ -151,6 +159,8 @@ func (s *CommissionRecordStore) GetDailyStats(ctx context.Context, filters *Comm query := s.db.WithContext(ctx).Model(&model.CommissionRecord{}). Where("status = ?", model.CommissionStatusReleased) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) if filters != nil { if filters.ShopID > 0 { diff --git a/internal/store/postgres/commission_withdrawal_request_store.go b/internal/store/postgres/commission_withdrawal_request_store.go index 0713c01..dd58380 100644 --- a/internal/store/postgres/commission_withdrawal_request_store.go +++ b/internal/store/postgres/commission_withdrawal_request_store.go @@ -7,6 +7,7 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/store" "github.com/break/junhong_cmp_fiber/pkg/constants" + "github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/redis/go-redis/v9" "gorm.io/gorm" ) @@ -29,7 +30,10 @@ func (s *CommissionWithdrawalRequestStore) Create(ctx context.Context, req *mode func (s *CommissionWithdrawalRequestStore) GetByID(ctx context.Context, id uint) (*model.CommissionWithdrawalRequest, error) { var req model.CommissionWithdrawalRequest - if err := s.db.WithContext(ctx).First(&req, id).Error; err != nil { + query := s.db.WithContext(ctx).Where("id = ?", id) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.First(&req).Error; err != nil { return nil, err } return &req, nil @@ -52,6 +56,8 @@ func (s *CommissionWithdrawalRequestStore) ListByShopID(ctx context.Context, opt var total int64 query := s.db.WithContext(ctx).Model(&model.CommissionWithdrawalRequest{}) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) if filters != nil { if filters.ShopID > 0 { @@ -146,6 +152,8 @@ func (s *CommissionWithdrawalRequestStore) List(ctx context.Context, opts *store var total int64 query := s.db.WithContext(ctx).Model(&model.CommissionWithdrawalRequest{}) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) if filters != nil { if filters.WithdrawalNo != "" { diff --git a/internal/store/postgres/device_store.go b/internal/store/postgres/device_store.go index 178bc72..8556343 100644 --- a/internal/store/postgres/device_store.go +++ b/internal/store/postgres/device_store.go @@ -7,6 +7,7 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/store" "github.com/break/junhong_cmp_fiber/pkg/constants" + "github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/redis/go-redis/v9" "gorm.io/gorm" ) @@ -36,7 +37,10 @@ func (s *DeviceStore) CreateBatch(ctx context.Context, devices []*model.Device) func (s *DeviceStore) GetByID(ctx context.Context, id uint) (*model.Device, error) { var device model.Device - if err := s.db.WithContext(ctx).First(&device, id).Error; err != nil { + query := s.db.WithContext(ctx).Where("id = ?", id) + // 应用数据权限过滤(NULL shop_id 对代理用户不可见) + query = middleware.ApplyShopFilter(ctx, query) + if err := query.First(&device).Error; err != nil { return nil, err } return &device, nil @@ -44,7 +48,10 @@ func (s *DeviceStore) GetByID(ctx context.Context, id uint) (*model.Device, erro func (s *DeviceStore) GetByDeviceNo(ctx context.Context, deviceNo string) (*model.Device, error) { var device model.Device - if err := s.db.WithContext(ctx).Where("device_no = ?", deviceNo).First(&device).Error; err != nil { + query := s.db.WithContext(ctx).Where("device_no = ?", deviceNo) + // 应用数据权限过滤(NULL shop_id 对代理用户不可见) + query = middleware.ApplyShopFilter(ctx, query) + if err := query.First(&device).Error; err != nil { return nil, err } return &device, nil @@ -55,7 +62,10 @@ func (s *DeviceStore) GetByIDs(ctx context.Context, ids []uint) ([]*model.Device if len(ids) == 0 { return devices, nil } - if err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&devices).Error; err != nil { + query := s.db.WithContext(ctx).Where("id IN ?", ids) + // 应用数据权限过滤(NULL shop_id 对代理用户不可见) + query = middleware.ApplyShopFilter(ctx, query) + if err := query.Find(&devices).Error; err != nil { return nil, err } return devices, nil @@ -74,6 +84,8 @@ func (s *DeviceStore) List(ctx context.Context, opts *store.QueryOptions, filter var total int64 query := s.db.WithContext(ctx).Model(&model.Device{}) + // 应用数据权限过滤(NULL shop_id 对代理用户不可见) + query = middleware.ApplyShopFilter(ctx, query) if deviceNo, ok := filters["device_no"].(string); ok && deviceNo != "" { query = query.Where("device_no LIKE ?", "%"+deviceNo+"%") @@ -179,7 +191,10 @@ func (s *DeviceStore) GetByDeviceNos(ctx context.Context, deviceNos []string) ([ if len(deviceNos) == 0 { return devices, nil } - if err := s.db.WithContext(ctx).Where("device_no IN ?", deviceNos).Find(&devices).Error; err != nil { + query := s.db.WithContext(ctx).Where("device_no IN ?", deviceNos) + // 应用数据权限过滤(NULL shop_id 对代理用户不可见) + query = middleware.ApplyShopFilter(ctx, query) + if err := query.Find(&devices).Error; err != nil { return nil, err } return devices, nil @@ -198,7 +213,10 @@ func (s *DeviceStore) BatchUpdateSeriesID(ctx context.Context, deviceIDs []uint, // ListBySeriesID 根据套餐系列ID查询设备列表 func (s *DeviceStore) ListBySeriesID(ctx context.Context, seriesID uint) ([]*model.Device, error) { var devices []*model.Device - if err := s.db.WithContext(ctx).Where("series_id = ?", seriesID).Find(&devices).Error; err != nil { + query := s.db.WithContext(ctx).Where("series_id = ?", seriesID) + // 应用数据权限过滤(NULL shop_id 对代理用户不可见) + query = middleware.ApplyShopFilter(ctx, query) + if err := query.Find(&devices).Error; err != nil { return nil, err } return devices, nil diff --git a/internal/store/postgres/enterprise_card_authorization_store.go b/internal/store/postgres/enterprise_card_authorization_store.go index 3ed6d19..61bc5c9 100644 --- a/internal/store/postgres/enterprise_card_authorization_store.go +++ b/internal/store/postgres/enterprise_card_authorization_store.go @@ -6,7 +6,6 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/pkg/constants" - pkgGorm "github.com/break/junhong_cmp_fiber/pkg/gorm" "github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/redis/go-redis/v9" "gorm.io/gorm" @@ -50,9 +49,11 @@ func (s *EnterpriseCardAuthorizationStore) RevokeAuthorizations(ctx context.Cont func (s *EnterpriseCardAuthorizationStore) GetByEnterpriseAndCard(ctx context.Context, enterpriseID, cardID uint) (*model.EnterpriseCardAuthorization, error) { var auth model.EnterpriseCardAuthorization - err := s.db.WithContext(ctx). - Where("enterprise_id = ? AND card_id = ?", enterpriseID, cardID). - First(&auth).Error + query := s.db.WithContext(ctx). + Where("enterprise_id = ? AND card_id = ?", enterpriseID, cardID) + // 应用数据权限过滤 + query = s.applyEnterpriseAuthFilter(ctx, query) + err := query.First(&auth).Error if err != nil { return nil, err } @@ -62,6 +63,8 @@ func (s *EnterpriseCardAuthorizationStore) GetByEnterpriseAndCard(ctx context.Co func (s *EnterpriseCardAuthorizationStore) ListByEnterprise(ctx context.Context, enterpriseID uint, includeRevoked bool) ([]*model.EnterpriseCardAuthorization, error) { var auths []*model.EnterpriseCardAuthorization query := s.db.WithContext(ctx).Where("enterprise_id = ?", enterpriseID) + // 应用数据权限过滤 + query = s.applyEnterpriseAuthFilter(ctx, query) if !includeRevoked { query = query.Where("revoked_at IS NULL") } @@ -77,6 +80,8 @@ func (s *EnterpriseCardAuthorizationStore) ListByCards(ctx context.Context, card } var auths []*model.EnterpriseCardAuthorization query := s.db.WithContext(ctx).Where("card_id IN ?", cardIDs) + // 应用数据权限过滤 + query = s.applyEnterpriseAuthFilter(ctx, query) if !includeRevoked { query = query.Where("revoked_at IS NULL") } @@ -88,17 +93,21 @@ func (s *EnterpriseCardAuthorizationStore) ListByCards(ctx context.Context, card func (s *EnterpriseCardAuthorizationStore) GetActiveAuthorizedCardIDs(ctx context.Context, enterpriseID uint) ([]uint, error) { var cardIDs []uint - err := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}). - Where("enterprise_id = ? AND revoked_at IS NULL", enterpriseID). - Pluck("card_id", &cardIDs).Error + query := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}). + Where("enterprise_id = ? AND revoked_at IS NULL", enterpriseID) + // 应用数据权限过滤 + query = s.applyEnterpriseAuthFilter(ctx, query) + err := query.Pluck("card_id", &cardIDs).Error return cardIDs, err } func (s *EnterpriseCardAuthorizationStore) CheckAuthorizationExists(ctx context.Context, enterpriseID, cardID uint) (bool, error) { var count int64 - err := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}). - Where("enterprise_id = ? AND card_id = ? AND revoked_at IS NULL", enterpriseID, cardID). - Count(&count).Error + query := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}). + Where("enterprise_id = ? AND card_id = ? AND revoked_at IS NULL", enterpriseID, cardID) + // 应用数据权限过滤 + query = s.applyEnterpriseAuthFilter(ctx, query) + err := query.Count(&count).Error return count > 0, err } @@ -115,6 +124,8 @@ type AuthorizationListOptions struct { func (s *EnterpriseCardAuthorizationStore) ListWithOptions(ctx context.Context, opts AuthorizationListOptions) ([]*model.EnterpriseCardAuthorization, int64, error) { var auths []*model.EnterpriseCardAuthorization query := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}) + // 应用数据权限过滤 + query = s.applyEnterpriseAuthFilter(ctx, query) if opts.EnterpriseID != nil { query = query.Where("enterprise_id = ?", *opts.EnterpriseID) @@ -154,9 +165,11 @@ func (s *EnterpriseCardAuthorizationStore) GetActiveAuthsByCardIDs(ctx context.C return make(map[uint]bool), nil } var authCardIDs []uint - err := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}). - Where("enterprise_id = ? AND card_id IN ? AND revoked_at IS NULL", enterpriseID, cardIDs). - Pluck("card_id", &authCardIDs).Error + query := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}). + Where("enterprise_id = ? AND card_id IN ? AND revoked_at IS NULL", enterpriseID, cardIDs) + // 应用数据权限过滤 + query = s.applyEnterpriseAuthFilter(ctx, query) + err := query.Pluck("card_id", &authCardIDs).Error if err != nil { return nil, err } @@ -186,9 +199,11 @@ func (s *EnterpriseCardAuthorizationStore) BatchUpdateStatus(ctx context.Context // ListCardIDsByEnterprise 获取企业的有效授权卡ID列表 func (s *EnterpriseCardAuthorizationStore) ListCardIDsByEnterprise(ctx context.Context, enterpriseID uint) ([]uint, error) { var cardIDs []uint - err := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}). - Where("enterprise_id = ? AND revoked_at IS NULL", enterpriseID). - Pluck("card_id", &cardIDs).Error + query := s.db.WithContext(ctx).Model(&model.EnterpriseCardAuthorization{}). + Where("enterprise_id = ? AND revoked_at IS NULL", enterpriseID) + // 应用数据权限过滤 + query = s.applyEnterpriseAuthFilter(ctx, query) + err := query.Pluck("card_id", &cardIDs).Error return cardIDs, err } @@ -233,31 +248,28 @@ func (s *EnterpriseCardAuthorizationStore) ListWithJoin(ctx context.Context, opt args := []interface{}{} // 数据权限过滤(原生 SQL 需要手动处理) - // 检查是否跳过数据权限过滤 - if skip, ok := ctx.Value(pkgGorm.SkipDataPermissionKey).(bool); !ok || !skip { - userType := middleware.GetUserTypeFromContext(ctx) - // 超级管理员和平台用户跳过过滤 - if userType != constants.UserTypeSuperAdmin && userType != constants.UserTypePlatform { - if userType == constants.UserTypeAgent { - shopID := middleware.GetShopIDFromContext(ctx) - if shopID == 0 { - // 代理用户没有 shop_id,返回空结果 - return []AuthorizationWithJoin{}, 0, nil - } - // 只能看到自己店铺下企业的授权记录(不包含下级店铺) - baseQuery += " AND a.enterprise_id IN (SELECT id FROM tb_enterprise WHERE owner_shop_id = ? AND deleted_at IS NULL)" - args = append(args, shopID) - } else if userType == constants.UserTypeEnterprise { - enterpriseID := middleware.GetEnterpriseIDFromContext(ctx) - if enterpriseID == 0 { - return []AuthorizationWithJoin{}, 0, nil - } - baseQuery += " AND a.enterprise_id = ?" - args = append(args, enterpriseID) - } else { - // 其他用户类型(个人客户等)不应访问授权记录 + userType := middleware.GetUserTypeFromContext(ctx) + // 超级管理员和平台用户跳过过滤 + if userType != constants.UserTypeSuperAdmin && userType != constants.UserTypePlatform { + if userType == constants.UserTypeAgent { + // 代理用户:只能看到自己及下级店铺所拥有企业的授权记录 + shopIDs := middleware.GetSubordinateShopIDs(ctx) + if len(shopIDs) == 0 { + // 代理用户没有下级店铺信息,返回空结果 return []AuthorizationWithJoin{}, 0, nil } + baseQuery += " AND a.enterprise_id IN (SELECT id FROM tb_enterprise WHERE owner_shop_id IN (?) AND deleted_at IS NULL)" + args = append(args, shopIDs) + } else if userType == constants.UserTypeEnterprise { + enterpriseID := middleware.GetEnterpriseIDFromContext(ctx) + if enterpriseID == 0 { + return []AuthorizationWithJoin{}, 0, nil + } + baseQuery += " AND a.enterprise_id = ?" + args = append(args, enterpriseID) + } else { + // 其他用户类型(个人客户等)不应访问授权记录 + return []AuthorizationWithJoin{}, 0, nil } } @@ -338,26 +350,25 @@ func (s *EnterpriseCardAuthorizationStore) GetByIDWithJoin(ctx context.Context, args := []interface{}{id} // 数据权限过滤(原生 SQL 需要手动处理) - if skip, ok := ctx.Value(pkgGorm.SkipDataPermissionKey).(bool); !ok || !skip { - userType := middleware.GetUserTypeFromContext(ctx) - if userType != constants.UserTypeSuperAdmin && userType != constants.UserTypePlatform { - if userType == constants.UserTypeAgent { - shopID := middleware.GetShopIDFromContext(ctx) - if shopID == 0 { - return nil, gorm.ErrRecordNotFound - } - baseSQL += " AND a.enterprise_id IN (SELECT id FROM tb_enterprise WHERE owner_shop_id = ? AND deleted_at IS NULL)" - args = append(args, shopID) - } else if userType == constants.UserTypeEnterprise { - enterpriseID := middleware.GetEnterpriseIDFromContext(ctx) - if enterpriseID == 0 { - return nil, gorm.ErrRecordNotFound - } - baseSQL += " AND a.enterprise_id = ?" - args = append(args, enterpriseID) - } else { + userType := middleware.GetUserTypeFromContext(ctx) + if userType != constants.UserTypeSuperAdmin && userType != constants.UserTypePlatform { + if userType == constants.UserTypeAgent { + // 代理用户:只能看到自己及下级店铺所拥有企业的授权记录 + shopIDs := middleware.GetSubordinateShopIDs(ctx) + if len(shopIDs) == 0 { return nil, gorm.ErrRecordNotFound } + baseSQL += " AND a.enterprise_id IN (SELECT id FROM tb_enterprise WHERE owner_shop_id IN (?) AND deleted_at IS NULL)" + args = append(args, shopIDs) + } else if userType == constants.UserTypeEnterprise { + enterpriseID := middleware.GetEnterpriseIDFromContext(ctx) + if enterpriseID == 0 { + return nil, gorm.ErrRecordNotFound + } + baseSQL += " AND a.enterprise_id = ?" + args = append(args, enterpriseID) + } else { + return nil, gorm.ErrRecordNotFound } } @@ -401,7 +412,10 @@ func (s *EnterpriseCardAuthorizationStore) UpdateRemarkWithConstraint(ctx contex func (s *EnterpriseCardAuthorizationStore) GetByID(ctx context.Context, id uint) (*model.EnterpriseCardAuthorization, error) { var auth model.EnterpriseCardAuthorization - err := s.db.WithContext(ctx).Where("id = ?", id).First(&auth).Error + query := s.db.WithContext(ctx).Where("id = ?", id) + // 应用数据权限过滤 + query = s.applyEnterpriseAuthFilter(ctx, query) + err := query.First(&auth).Error if err != nil { return nil, err } @@ -417,3 +431,23 @@ func (s *EnterpriseCardAuthorizationStore) RevokeByDeviceAuthID(ctx context.Cont "revoked_at": now, }).Error } + +// applyEnterpriseAuthFilter 应用企业卡授权表的数据权限过滤 +// 企业用户:只能看到自己企业的授权记录 +// 代理用户:只能看到自己及下级店铺所拥有企业的授权记录 +// 平台/超管:不过滤 +func (s *EnterpriseCardAuthorizationStore) applyEnterpriseAuthFilter(ctx context.Context, query *gorm.DB) *gorm.DB { + // 企业用户过滤 + query = middleware.ApplyEnterpriseFilter(ctx, query) + + // 代理用户:通过企业的 owner_shop_id 过滤 + userType := middleware.GetUserTypeFromContext(ctx) + if userType == constants.UserTypeAgent { + shopIDs := middleware.GetSubordinateShopIDs(ctx) + if shopIDs != nil { + query = query.Where("enterprise_id IN (SELECT id FROM tb_enterprise WHERE owner_shop_id IN ? AND deleted_at IS NULL)", shopIDs) + } + } + + return query +} diff --git a/internal/store/postgres/enterprise_device_authorization_store.go b/internal/store/postgres/enterprise_device_authorization_store.go index 0aa6971..4445598 100644 --- a/internal/store/postgres/enterprise_device_authorization_store.go +++ b/internal/store/postgres/enterprise_device_authorization_store.go @@ -5,6 +5,7 @@ import ( "time" "github.com/break/junhong_cmp_fiber/internal/model" + "github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/redis/go-redis/v9" "gorm.io/gorm" ) @@ -45,7 +46,10 @@ func (s *EnterpriseDeviceAuthorizationStore) BatchCreate(ctx context.Context, au func (s *EnterpriseDeviceAuthorizationStore) GetByID(ctx context.Context, id uint) (*model.EnterpriseDeviceAuthorization, error) { var auth model.EnterpriseDeviceAuthorization - err := s.db.WithContext(ctx).Where("id = ?", id).First(&auth).Error + query := s.db.WithContext(ctx).Where("id = ?", id) + // 应用企业数据权限过滤 + query = middleware.ApplyEnterpriseFilter(ctx, query) + err := query.First(&auth).Error if err != nil { return nil, err } @@ -54,9 +58,11 @@ func (s *EnterpriseDeviceAuthorizationStore) GetByID(ctx context.Context, id uin func (s *EnterpriseDeviceAuthorizationStore) GetByDeviceID(ctx context.Context, deviceID uint) (*model.EnterpriseDeviceAuthorization, error) { var auth model.EnterpriseDeviceAuthorization - err := s.db.WithContext(ctx). - Where("device_id = ? AND revoked_at IS NULL", deviceID). - First(&auth).Error + query := s.db.WithContext(ctx). + Where("device_id = ? AND revoked_at IS NULL", deviceID) + // 应用企业数据权限过滤 + query = middleware.ApplyEnterpriseFilter(ctx, query) + err := query.First(&auth).Error if err != nil { return nil, err } @@ -66,6 +72,8 @@ func (s *EnterpriseDeviceAuthorizationStore) GetByDeviceID(ctx context.Context, func (s *EnterpriseDeviceAuthorizationStore) GetByEnterpriseID(ctx context.Context, enterpriseID uint, includeRevoked bool) ([]*model.EnterpriseDeviceAuthorization, error) { var auths []*model.EnterpriseDeviceAuthorization query := s.db.WithContext(ctx).Where("enterprise_id = ?", enterpriseID) + // 应用企业数据权限过滤 + query = middleware.ApplyEnterpriseFilter(ctx, query) if !includeRevoked { query = query.Where("revoked_at IS NULL") } @@ -87,6 +95,8 @@ func (s *EnterpriseDeviceAuthorizationStore) ListByEnterprise(ctx context.Contex var total int64 query := s.db.WithContext(ctx).Model(&model.EnterpriseDeviceAuthorization{}) + // 应用企业数据权限过滤 + query = middleware.ApplyEnterpriseFilter(ctx, query) if opts.EnterpriseID != nil { query = query.Where("enterprise_id = ?", *opts.EnterpriseID) @@ -134,10 +144,12 @@ func (s *EnterpriseDeviceAuthorizationStore) GetActiveAuthsByDeviceIDs(ctx conte } var auths []model.EnterpriseDeviceAuthorization - err := s.db.WithContext(ctx). + query := s.db.WithContext(ctx). Select("device_id"). - Where("enterprise_id = ? AND device_id IN ? AND revoked_at IS NULL", enterpriseID, deviceIDs). - Find(&auths).Error + Where("enterprise_id = ? AND device_id IN ? AND revoked_at IS NULL", enterpriseID, deviceIDs) + // 应用企业数据权限过滤 + query = middleware.ApplyEnterpriseFilter(ctx, query) + err := query.Find(&auths).Error if err != nil { return nil, err @@ -152,9 +164,11 @@ func (s *EnterpriseDeviceAuthorizationStore) GetActiveAuthsByDeviceIDs(ctx conte func (s *EnterpriseDeviceAuthorizationStore) ListDeviceIDsByEnterprise(ctx context.Context, enterpriseID uint) ([]uint, error) { var deviceIDs []uint - err := s.db.WithContext(ctx). + query := s.db.WithContext(ctx). Model(&model.EnterpriseDeviceAuthorization{}). - Where("enterprise_id = ? AND revoked_at IS NULL", enterpriseID). - Pluck("device_id", &deviceIDs).Error + Where("enterprise_id = ? AND revoked_at IS NULL", enterpriseID) + // 应用企业数据权限过滤 + query = middleware.ApplyEnterpriseFilter(ctx, query) + err := query.Pluck("device_id", &deviceIDs).Error return deviceIDs, err } diff --git a/internal/store/postgres/enterprise_store.go b/internal/store/postgres/enterprise_store.go index 5fba63b..6380454 100644 --- a/internal/store/postgres/enterprise_store.go +++ b/internal/store/postgres/enterprise_store.go @@ -6,6 +6,7 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/store" "github.com/break/junhong_cmp_fiber/pkg/constants" + "github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/redis/go-redis/v9" "gorm.io/gorm" ) @@ -32,7 +33,10 @@ func (s *EnterpriseStore) Create(ctx context.Context, enterprise *model.Enterpri // GetByID 根据 ID 获取企业 func (s *EnterpriseStore) GetByID(ctx context.Context, id uint) (*model.Enterprise, error) { var enterprise model.Enterprise - if err := s.db.WithContext(ctx).First(&enterprise, id).Error; err != nil { + query := s.db.WithContext(ctx).Where("id = ?", id) + // 应用归属店铺数据权限过滤 + query = middleware.ApplyOwnerShopFilter(ctx, query) + if err := query.First(&enterprise).Error; err != nil { return nil, err } return &enterprise, nil @@ -41,7 +45,10 @@ func (s *EnterpriseStore) GetByID(ctx context.Context, id uint) (*model.Enterpri // GetByCode 根据企业编号获取企业 func (s *EnterpriseStore) GetByCode(ctx context.Context, code string) (*model.Enterprise, error) { var enterprise model.Enterprise - if err := s.db.WithContext(ctx).Where("enterprise_code = ?", code).First(&enterprise).Error; err != nil { + query := s.db.WithContext(ctx).Where("enterprise_code = ?", code) + // 应用归属店铺数据权限过滤 + query = middleware.ApplyOwnerShopFilter(ctx, query) + if err := query.First(&enterprise).Error; err != nil { return nil, err } return &enterprise, nil @@ -63,6 +70,8 @@ func (s *EnterpriseStore) List(ctx context.Context, opts *store.QueryOptions, fi var total int64 query := s.db.WithContext(ctx).Model(&model.Enterprise{}) + // 应用归属店铺数据权限过滤 + query = middleware.ApplyOwnerShopFilter(ctx, query) // 应用过滤条件 if enterpriseName, ok := filters["enterprise_name"].(string); ok && enterpriseName != "" { @@ -111,7 +120,10 @@ func (s *EnterpriseStore) List(ctx context.Context, opts *store.QueryOptions, fi // GetByOwnerShopID 根据归属店铺 ID 查询企业列表 func (s *EnterpriseStore) GetByOwnerShopID(ctx context.Context, ownerShopID uint) ([]*model.Enterprise, error) { var enterprises []*model.Enterprise - if err := s.db.WithContext(ctx).Where("owner_shop_id = ?", ownerShopID).Find(&enterprises).Error; err != nil { + query := s.db.WithContext(ctx).Where("owner_shop_id = ?", ownerShopID) + // 应用归属店铺数据权限过滤 + query = middleware.ApplyOwnerShopFilter(ctx, query) + if err := query.Find(&enterprises).Error; err != nil { return nil, err } return enterprises, nil @@ -120,7 +132,10 @@ func (s *EnterpriseStore) GetByOwnerShopID(ctx context.Context, ownerShopID uint // GetPlatformEnterprises 获取平台直属企业列表(owner_shop_id 为 NULL) func (s *EnterpriseStore) GetPlatformEnterprises(ctx context.Context) ([]*model.Enterprise, error) { var enterprises []*model.Enterprise - if err := s.db.WithContext(ctx).Where("owner_shop_id IS NULL").Find(&enterprises).Error; err != nil { + query := s.db.WithContext(ctx).Where("owner_shop_id IS NULL") + // 应用归属店铺数据权限过滤(代理用户无法看到平台直属企业) + query = middleware.ApplyOwnerShopFilter(ctx, query) + if err := query.Find(&enterprises).Error; err != nil { return nil, err } return enterprises, nil @@ -132,7 +147,10 @@ func (s *EnterpriseStore) GetByIDs(ctx context.Context, ids []uint) ([]*model.En return []*model.Enterprise{}, nil } var enterprises []*model.Enterprise - if err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&enterprises).Error; err != nil { + query := s.db.WithContext(ctx).Where("id IN ?", ids) + // 应用归属店铺数据权限过滤 + query = middleware.ApplyOwnerShopFilter(ctx, query) + if err := query.Find(&enterprises).Error; err != nil { return nil, err } return enterprises, nil diff --git a/internal/store/postgres/iot_card_import_task_store.go b/internal/store/postgres/iot_card_import_task_store.go index 5d3f1fc..3226c3a 100644 --- a/internal/store/postgres/iot_card_import_task_store.go +++ b/internal/store/postgres/iot_card_import_task_store.go @@ -8,6 +8,7 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/store" "github.com/break/junhong_cmp_fiber/pkg/constants" + "github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/redis/go-redis/v9" "gorm.io/gorm" ) @@ -30,7 +31,10 @@ func (s *IotCardImportTaskStore) Create(ctx context.Context, task *model.IotCard func (s *IotCardImportTaskStore) GetByID(ctx context.Context, id uint) (*model.IotCardImportTask, error) { var task model.IotCardImportTask - if err := s.db.WithContext(ctx).First(&task, id).Error; err != nil { + query := s.db.WithContext(ctx).Where("id = ?", id) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.First(&task).Error; err != nil { return nil, err } return &task, nil @@ -38,7 +42,10 @@ func (s *IotCardImportTaskStore) GetByID(ctx context.Context, id uint) (*model.I func (s *IotCardImportTaskStore) GetByTaskNo(ctx context.Context, taskNo string) (*model.IotCardImportTask, error) { var task model.IotCardImportTask - if err := s.db.WithContext(ctx).Where("task_no = ?", taskNo).First(&task).Error; err != nil { + query := s.db.WithContext(ctx).Where("task_no = ?", taskNo) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.First(&task).Error; err != nil { return nil, err } return &task, nil @@ -82,6 +89,8 @@ func (s *IotCardImportTaskStore) List(ctx context.Context, opts *store.QueryOpti var total int64 query := s.db.WithContext(ctx).Model(&model.IotCardImportTask{}) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) if status, ok := filters["status"].(int); ok && status > 0 { query = query.Where("status = ?", status) diff --git a/internal/store/postgres/iot_card_store.go b/internal/store/postgres/iot_card_store.go index 7a8fd74..9746bdf 100644 --- a/internal/store/postgres/iot_card_store.go +++ b/internal/store/postgres/iot_card_store.go @@ -11,7 +11,6 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/store" "github.com/break/junhong_cmp_fiber/pkg/constants" - pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm" "github.com/break/junhong_cmp_fiber/pkg/logger" "github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/redis/go-redis/v9" @@ -46,7 +45,10 @@ func (s *IotCardStore) CreateBatch(ctx context.Context, cards []*model.IotCard) func (s *IotCardStore) GetByID(ctx context.Context, id uint) (*model.IotCard, error) { var card model.IotCard - if err := s.db.WithContext(ctx).First(&card, id).Error; err != nil { + query := s.db.WithContext(ctx).Where("id = ?", id) + // 应用数据权限过滤(NULL shop_id 对代理用户不可见) + query = middleware.ApplyShopFilter(ctx, query) + if err := query.First(&card).Error; err != nil { return nil, err } return &card, nil @@ -54,7 +56,10 @@ func (s *IotCardStore) GetByID(ctx context.Context, id uint) (*model.IotCard, er func (s *IotCardStore) GetByICCID(ctx context.Context, iccid string) (*model.IotCard, error) { var card model.IotCard - if err := s.db.WithContext(ctx).Where("iccid = ?", iccid).First(&card).Error; err != nil { + query := s.db.WithContext(ctx).Where("iccid = ?", iccid) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.First(&card).Error; err != nil { return nil, err } return &card, nil @@ -65,7 +70,10 @@ func (s *IotCardStore) GetByIDs(ctx context.Context, ids []uint) ([]*model.IotCa return []*model.IotCard{}, nil } var cards []*model.IotCard - if err := s.db.WithContext(ctx).Where("id IN ?", ids).Find(&cards).Error; err != nil { + query := s.db.WithContext(ctx).Where("id IN ?", ids) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.Find(&cards).Error; err != nil { return nil, err } return cards, nil @@ -111,13 +119,15 @@ func (s *IotCardStore) List(ctx context.Context, opts *store.QueryOptions, filte var total int64 query := s.db.WithContext(ctx).Model(&model.IotCard{}) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) // 企业用户特殊处理:只能看到授权给自己的卡 // 子查询跳过数据权限过滤,权限已由外层查询的 GORM callback 保证 - skipCtx := pkggorm.SkipDataPermission(ctx) + // 子查询无需数据权限过滤(在不同表上执行) if enterpriseID, ok := filters["authorized_enterprise_id"].(uint); ok && enterpriseID > 0 { query = query.Where("id IN (?)", - s.db.WithContext(skipCtx).Table("tb_enterprise_card_authorization"). + s.db.WithContext(ctx).Table("tb_enterprise_card_authorization"). Select("card_id"). Where("enterprise_id = ? AND revoked_at IS NULL AND deleted_at IS NULL", enterpriseID)) } @@ -143,7 +153,7 @@ func (s *IotCardStore) List(ctx context.Context, opts *store.QueryOptions, filte } if packageID, ok := filters["package_id"].(uint); ok && packageID > 0 { query = query.Where("id IN (?)", - s.db.WithContext(skipCtx).Table("tb_package_usage"). + s.db.WithContext(ctx).Table("tb_package_usage"). Select("iot_card_id"). Where("package_id = ? AND deleted_at IS NULL", packageID)) } @@ -249,6 +259,8 @@ func (s *IotCardStore) listStandaloneTwoPhase(ctx context.Context, opts *store.Q query := s.db.WithContext(ctx).Model(&model.IotCard{}). Where("is_standalone = true") + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) query = s.applyStandaloneFilters(ctx, query, filters) if cachedTotal, ok := s.getCachedCount(ctx, "standalone", filters); ok { @@ -309,6 +321,8 @@ func (s *IotCardStore) listStandaloneDefault(ctx context.Context, opts *store.Qu query := s.db.WithContext(ctx).Model(&model.IotCard{}). Where("is_standalone = true") + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) query = s.applyStandaloneFilters(ctx, query, filters) if cachedTotal, ok := s.getCachedCount(ctx, "standalone", filters); ok { @@ -339,7 +353,7 @@ func (s *IotCardStore) listStandaloneDefault(ctx context.Context, opts *store.Qu // 将 shop_id IN (...) 拆分为 per-shop 独立查询,每个查询走 Index Scan // 然后在应用层归并排序,避免 PG 对多值 IN + ORDER BY 选择全表扫描 func (s *IotCardStore) listStandaloneParallel(ctx context.Context, opts *store.QueryOptions, filters map[string]any, shopIDs []uint) ([]*model.IotCard, int64, error) { - skipCtx := pkggorm.SkipDataPermission(ctx) + // 子查询无需数据权限过滤(在不同表上执行) fetchLimit := (opts.Page-1)*opts.PageSize + opts.PageSize @@ -366,9 +380,9 @@ func (s *IotCardStore) listStandaloneParallel(ctx context.Context, opts *store.Q go func(idx int, sid uint) { defer wg.Done() - q := s.db.WithContext(skipCtx).Model(&model.IotCard{}). + q := s.db.WithContext(ctx).Model(&model.IotCard{}). Where("is_standalone = true AND deleted_at IS NULL AND shop_id = ?", sid) - q = s.applyStandaloneFilters(skipCtx, q, filters) + q = s.applyStandaloneFilters(ctx, q, filters) var cards []*model.IotCard if err := q.Select(standaloneListColumns). @@ -381,9 +395,9 @@ func (s *IotCardStore) listStandaloneParallel(ctx context.Context, opts *store.Q var count int64 if !hasCachedTotal { - countQ := s.db.WithContext(skipCtx).Model(&model.IotCard{}). + countQ := s.db.WithContext(ctx).Model(&model.IotCard{}). Where("is_standalone = true AND deleted_at IS NULL AND shop_id = ?", sid) - countQ = s.applyStandaloneFilters(skipCtx, countQ, filters) + countQ = s.applyStandaloneFilters(ctx, countQ, filters) if err := countQ.Count(&count).Error; err != nil { results[idx] = shopResult{err: err} return @@ -455,7 +469,7 @@ type cardIDWithTime struct { // 归并排序后取目标页的 20 个 ID // Phase 2: SELECT 完整列 WHERE id IN (20 IDs)(PK 精确回表) func (s *IotCardStore) listStandaloneParallelTwoPhase(ctx context.Context, opts *store.QueryOptions, filters map[string]any, shopIDs []uint) ([]*model.IotCard, int64, error) { - skipCtx := pkggorm.SkipDataPermission(ctx) + // 子查询无需数据权限过滤(在不同表上执行) fetchLimit := (opts.Page-1)*opts.PageSize + opts.PageSize @@ -476,9 +490,9 @@ func (s *IotCardStore) listStandaloneParallelTwoPhase(ctx context.Context, opts go func(idx int, sid uint) { defer wg.Done() - q := s.db.WithContext(skipCtx).Model(&model.IotCard{}). + q := s.db.WithContext(ctx).Model(&model.IotCard{}). Where("is_standalone = true AND deleted_at IS NULL AND shop_id = ?", sid) - q = s.applyStandaloneFilters(skipCtx, q, filters) + q = s.applyStandaloneFilters(ctx, q, filters) var ids []cardIDWithTime if err := q.Select("id, created_at"). @@ -491,9 +505,9 @@ func (s *IotCardStore) listStandaloneParallelTwoPhase(ctx context.Context, opts var count int64 if !hasCachedTotal { - countQ := s.db.WithContext(skipCtx).Model(&model.IotCard{}). + countQ := s.db.WithContext(ctx).Model(&model.IotCard{}). Where("is_standalone = true AND deleted_at IS NULL AND shop_id = ?", sid) - countQ = s.applyStandaloneFilters(skipCtx, countQ, filters) + countQ = s.applyStandaloneFilters(ctx, countQ, filters) if err := countQ.Count(&count).Error; err != nil { results[idx] = shopResult{err: err} return @@ -553,7 +567,7 @@ func (s *IotCardStore) listStandaloneParallelTwoPhase(ctx context.Context, opts // Phase 2: 用 ID 精确回表获取完整数据(PK Index Scan,仅 20 行) var cards []*model.IotCard - if err := s.db.WithContext(skipCtx).Model(&model.IotCard{}). + if err := s.db.WithContext(ctx).Model(&model.IotCard{}). Select(standaloneListColumns). Where("id IN ?", pageIDs). Find(&cards).Error; err != nil { @@ -584,7 +598,7 @@ func (s *IotCardStore) listStandaloneParallelTwoPhase(ctx context.Context, opts // 注意:不包含 is_standalone、shop_id、deleted_at 条件(由调用方控制) // 也不包含 subordinate_shop_ids(仅用于路由选择,不作为查询条件) func (s *IotCardStore) applyStandaloneFilters(ctx context.Context, query *gorm.DB, filters map[string]any) *gorm.DB { - skipCtx := pkggorm.SkipDataPermission(ctx) + // 子查询无需数据权限过滤(在不同表上执行) if status, ok := filters["status"].(int); ok && status > 0 { query = query.Where("status = ?", status) @@ -607,7 +621,7 @@ func (s *IotCardStore) applyStandaloneFilters(ctx context.Context, query *gorm.D } if packageID, ok := filters["package_id"].(uint); ok && packageID > 0 { query = query.Where("id IN (?)", - s.db.WithContext(skipCtx).Table("tb_package_usage"). + s.db.WithContext(ctx).Table("tb_package_usage"). Select("iot_card_id"). Where("package_id = ? AND deleted_at IS NULL", packageID)) } @@ -627,12 +641,12 @@ func (s *IotCardStore) applyStandaloneFilters(ctx context.Context, query *gorm.D if isReplaced, ok := filters["is_replaced"].(bool); ok { if isReplaced { query = query.Where("id IN (?)", - s.db.WithContext(skipCtx).Table("tb_card_replacement_record"). + s.db.WithContext(ctx).Table("tb_card_replacement_record"). Select("old_iot_card_id"). Where("deleted_at IS NULL")) } else { query = query.Where("id NOT IN (?)", - s.db.WithContext(skipCtx).Table("tb_card_replacement_record"). + s.db.WithContext(ctx).Table("tb_card_replacement_record"). Select("old_iot_card_id"). Where("deleted_at IS NULL")) } @@ -649,7 +663,10 @@ func (s *IotCardStore) GetByICCIDs(ctx context.Context, iccids []string) ([]*mod return nil, nil } var cards []*model.IotCard - if err := s.db.WithContext(ctx).Where("iccid IN ?", iccids).Find(&cards).Error; err != nil { + query := s.db.WithContext(ctx).Where("iccid IN ?", iccids) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.Find(&cards).Error; err != nil { return nil, err } return cards, nil @@ -659,6 +676,8 @@ func (s *IotCardStore) GetStandaloneByICCIDRange(ctx context.Context, iccidStart query := s.db.WithContext(ctx).Model(&model.IotCard{}). Where("is_standalone = true"). Where("iccid >= ? AND iccid <= ?", iccidStart, iccidEnd) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) if shopID == nil { query = query.Where("shop_id IS NULL") @@ -676,11 +695,13 @@ func (s *IotCardStore) GetStandaloneByICCIDRange(ctx context.Context, iccidStart // GetDistributedStandaloneByICCIDRange 根据号段范围查询已分配给店铺的单卡(用于回收) func (s *IotCardStore) GetDistributedStandaloneByICCIDRange(ctx context.Context, iccidStart, iccidEnd string) ([]*model.IotCard, error) { var cards []*model.IotCard - if err := s.db.WithContext(ctx).Model(&model.IotCard{}). + query := s.db.WithContext(ctx).Model(&model.IotCard{}). Where("is_standalone = true"). Where("shop_id IS NOT NULL"). - Where("iccid >= ? AND iccid <= ?", iccidStart, iccidEnd). - Find(&cards).Error; err != nil { + Where("iccid >= ? AND iccid <= ?", iccidStart, iccidEnd) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.Find(&cards).Error; err != nil { return nil, err } return cards, nil @@ -689,6 +710,8 @@ func (s *IotCardStore) GetDistributedStandaloneByICCIDRange(ctx context.Context, func (s *IotCardStore) GetStandaloneByFilters(ctx context.Context, filters map[string]any, shopID *uint) ([]*model.IotCard, error) { query := s.db.WithContext(ctx).Model(&model.IotCard{}). Where("is_standalone = true") + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) if shopID == nil { query = query.Where("shop_id IS NULL") @@ -718,6 +741,8 @@ func (s *IotCardStore) GetDistributedStandaloneByFilters(ctx context.Context, fi query := s.db.WithContext(ctx).Model(&model.IotCard{}). Where("is_standalone = true"). Where("shop_id IS NOT NULL") + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) if carrierID, ok := filters["carrier_id"].(uint); ok && carrierID > 0 { query = query.Where("carrier_id = ?", carrierID) @@ -764,10 +789,10 @@ func (s *IotCardStore) GetByIDsWithEnterpriseFilter(ctx context.Context, cardIDs query := s.db.WithContext(ctx).Model(&model.IotCard{}) if enterpriseID != nil && *enterpriseID > 0 { - skipCtx := pkggorm.SkipDataPermission(ctx) + // 子查询无需数据权限过滤(在不同表上执行) query = query.Where("id IN (?) AND id IN (?)", cardIDs, - s.db.WithContext(skipCtx).Table("tb_enterprise_card_authorization"). + s.db.WithContext(ctx).Table("tb_enterprise_card_authorization"). Select("card_id"). Where("enterprise_id = ? AND revoked_at IS NULL AND deleted_at IS NULL", *enterpriseID)) } else { @@ -796,7 +821,10 @@ func (s *IotCardStore) BatchUpdateSeriesID(ctx context.Context, cardIDs []uint, // 用于查询某个套餐系列下的所有卡 func (s *IotCardStore) ListBySeriesID(ctx context.Context, seriesID uint) ([]*model.IotCard, error) { var cards []*model.IotCard - if err := s.db.WithContext(ctx).Where("series_id = ?", seriesID).Find(&cards).Error; err != nil { + query := s.db.WithContext(ctx).Where("series_id = ?", seriesID) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.Find(&cards).Error; err != nil { return nil, err } return cards, nil diff --git a/internal/store/postgres/order_store.go b/internal/store/postgres/order_store.go index 1df9796..d5c9fdd 100644 --- a/internal/store/postgres/order_store.go +++ b/internal/store/postgres/order_store.go @@ -8,6 +8,7 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/store" + "github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/redis/go-redis/v9" "gorm.io/gorm" ) @@ -43,7 +44,10 @@ func (s *OrderStore) Create(ctx context.Context, order *model.Order, items []*mo func (s *OrderStore) GetByID(ctx context.Context, id uint) (*model.Order, error) { var order model.Order - if err := s.db.WithContext(ctx).First(&order, id).Error; err != nil { + query := s.db.WithContext(ctx).Where("id = ?", id) + // 应用数据权限过滤(使用 seller_shop_id 字段) + query = middleware.ApplySellerShopFilter(ctx, query) + if err := query.First(&order).Error; err != nil { return nil, err } return &order, nil @@ -51,7 +55,10 @@ func (s *OrderStore) GetByID(ctx context.Context, id uint) (*model.Order, error) func (s *OrderStore) GetByIDWithItems(ctx context.Context, id uint) (*model.Order, []*model.OrderItem, error) { var order model.Order - if err := s.db.WithContext(ctx).First(&order, id).Error; err != nil { + query := s.db.WithContext(ctx).Where("id = ?", id) + // 应用数据权限过滤(使用 seller_shop_id 字段) + query = middleware.ApplySellerShopFilter(ctx, query) + if err := query.First(&order).Error; err != nil { return nil, nil, err } @@ -65,7 +72,10 @@ func (s *OrderStore) GetByIDWithItems(ctx context.Context, id uint) (*model.Orde func (s *OrderStore) GetByOrderNo(ctx context.Context, orderNo string) (*model.Order, error) { var order model.Order - if err := s.db.WithContext(ctx).Where("order_no = ?", orderNo).First(&order).Error; err != nil { + query := s.db.WithContext(ctx).Where("order_no = ?", orderNo) + // 应用数据权限过滤(使用 seller_shop_id 字段) + query = middleware.ApplySellerShopFilter(ctx, query) + if err := query.First(&order).Error; err != nil { return nil, err } return &order, nil @@ -80,6 +90,8 @@ func (s *OrderStore) List(ctx context.Context, opts *store.QueryOptions, filters var total int64 query := s.db.WithContext(ctx).Model(&model.Order{}) + // 应用数据权限过滤(使用 seller_shop_id 字段) + query = middleware.ApplySellerShopFilter(ctx, query) if v, ok := filters["payment_status"]; ok { query = query.Where("payment_status = ?", v) diff --git a/internal/store/postgres/shop_package_allocation_store.go b/internal/store/postgres/shop_package_allocation_store.go index db5e8d2..682f541 100644 --- a/internal/store/postgres/shop_package_allocation_store.go +++ b/internal/store/postgres/shop_package_allocation_store.go @@ -5,6 +5,7 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/store" + "github.com/break/junhong_cmp_fiber/pkg/middleware" "gorm.io/gorm" ) @@ -22,7 +23,10 @@ func (s *ShopPackageAllocationStore) Create(ctx context.Context, allocation *mod func (s *ShopPackageAllocationStore) GetByID(ctx context.Context, id uint) (*model.ShopPackageAllocation, error) { var allocation model.ShopPackageAllocation - if err := s.db.WithContext(ctx).First(&allocation, id).Error; err != nil { + query := s.db.WithContext(ctx).Where("id = ?", id) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.First(&allocation).Error; err != nil { return nil, err } return &allocation, nil @@ -30,7 +34,10 @@ func (s *ShopPackageAllocationStore) GetByID(ctx context.Context, id uint) (*mod func (s *ShopPackageAllocationStore) GetByShopAndPackage(ctx context.Context, shopID, packageID uint) (*model.ShopPackageAllocation, error) { var allocation model.ShopPackageAllocation - if err := s.db.WithContext(ctx).Where("shop_id = ? AND package_id = ?", shopID, packageID).First(&allocation).Error; err != nil { + query := s.db.WithContext(ctx).Where("shop_id = ? AND package_id = ?", shopID, packageID) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.First(&allocation).Error; err != nil { return nil, err } return &allocation, nil @@ -49,6 +56,8 @@ func (s *ShopPackageAllocationStore) List(ctx context.Context, opts *store.Query var total int64 query := s.db.WithContext(ctx).Model(&model.ShopPackageAllocation{}) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) if shopID, ok := filters["shop_id"].(uint); ok && shopID > 0 { query = query.Where("shop_id = ?", shopID) @@ -99,7 +108,10 @@ func (s *ShopPackageAllocationStore) UpdateStatus(ctx context.Context, id uint, func (s *ShopPackageAllocationStore) GetByShopID(ctx context.Context, shopID uint) ([]*model.ShopPackageAllocation, error) { var allocations []*model.ShopPackageAllocation - if err := s.db.WithContext(ctx).Where("shop_id = ? AND status = 1", shopID).Find(&allocations).Error; err != nil { + query := s.db.WithContext(ctx).Where("shop_id = ? AND status = 1", shopID) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.Find(&allocations).Error; err != nil { return nil, err } return allocations, nil @@ -107,9 +119,11 @@ func (s *ShopPackageAllocationStore) GetByShopID(ctx context.Context, shopID uin func (s *ShopPackageAllocationStore) GetByShopAndPackages(ctx context.Context, shopID uint, packageIDs []uint) ([]*model.ShopPackageAllocation, error) { var allocations []*model.ShopPackageAllocation - if err := s.db.WithContext(ctx). - Where("shop_id = ? AND package_id IN ? AND status = 1", shopID, packageIDs). - Find(&allocations).Error; err != nil { + query := s.db.WithContext(ctx). + Where("shop_id = ? AND package_id IN ? AND status = 1", shopID, packageIDs) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.Find(&allocations).Error; err != nil { return nil, err } return allocations, nil @@ -117,9 +131,11 @@ func (s *ShopPackageAllocationStore) GetByShopAndPackages(ctx context.Context, s func (s *ShopPackageAllocationStore) GetBySeriesAllocationID(ctx context.Context, seriesAllocationID uint) ([]*model.ShopPackageAllocation, error) { var allocations []*model.ShopPackageAllocation - if err := s.db.WithContext(ctx). - Where("series_allocation_id = ? AND status = 1", seriesAllocationID). - Find(&allocations).Error; err != nil { + query := s.db.WithContext(ctx). + Where("series_allocation_id = ? AND status = 1", seriesAllocationID) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.Find(&allocations).Error; err != nil { return nil, err } return allocations, nil diff --git a/internal/store/postgres/shop_role_store.go b/internal/store/postgres/shop_role_store.go index 6f7be1b..2c69732 100644 --- a/internal/store/postgres/shop_role_store.go +++ b/internal/store/postgres/shop_role_store.go @@ -5,6 +5,7 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/pkg/constants" + "github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/redis/go-redis/v9" "gorm.io/gorm" ) @@ -62,9 +63,10 @@ func (s *ShopRoleStore) DeleteByShopID(ctx context.Context, shopID uint) error { func (s *ShopRoleStore) GetByShopID(ctx context.Context, shopID uint) ([]*model.ShopRole, error) { var srs []*model.ShopRole - if err := s.db.WithContext(ctx). - Where("shop_id = ?", shopID). - Find(&srs).Error; err != nil { + query := s.db.WithContext(ctx).Where("shop_id = ?", shopID) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.Find(&srs).Error; err != nil { return nil, err } return srs, nil @@ -72,10 +74,12 @@ func (s *ShopRoleStore) GetByShopID(ctx context.Context, shopID uint) ([]*model. func (s *ShopRoleStore) GetRoleIDsByShopID(ctx context.Context, shopID uint) ([]uint, error) { var roleIDs []uint - if err := s.db.WithContext(ctx). + query := s.db.WithContext(ctx). Model(&model.ShopRole{}). - Where("shop_id = ?", shopID). - Pluck("role_id", &roleIDs).Error; err != nil { + Where("shop_id = ?", shopID) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.Pluck("role_id", &roleIDs).Error; err != nil { return nil, err } return roleIDs, nil diff --git a/internal/store/postgres/shop_series_allocation_store.go b/internal/store/postgres/shop_series_allocation_store.go index 47f72ab..491a255 100644 --- a/internal/store/postgres/shop_series_allocation_store.go +++ b/internal/store/postgres/shop_series_allocation_store.go @@ -5,6 +5,7 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/store" + "github.com/break/junhong_cmp_fiber/pkg/middleware" "gorm.io/gorm" ) @@ -22,7 +23,10 @@ func (s *ShopSeriesAllocationStore) Create(ctx context.Context, allocation *mode func (s *ShopSeriesAllocationStore) GetByID(ctx context.Context, id uint) (*model.ShopSeriesAllocation, error) { var allocation model.ShopSeriesAllocation - if err := s.db.WithContext(ctx).First(&allocation, id).Error; err != nil { + query := s.db.WithContext(ctx).Where("id = ?", id) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.First(&allocation).Error; err != nil { return nil, err } return &allocation, nil @@ -30,9 +34,11 @@ func (s *ShopSeriesAllocationStore) GetByID(ctx context.Context, id uint) (*mode func (s *ShopSeriesAllocationStore) GetByShopAndSeries(ctx context.Context, shopID, seriesID uint) (*model.ShopSeriesAllocation, error) { var allocation model.ShopSeriesAllocation - if err := s.db.WithContext(ctx). - Where("shop_id = ? AND series_id = ?", shopID, seriesID). - First(&allocation).Error; err != nil { + query := s.db.WithContext(ctx). + Where("shop_id = ? AND series_id = ?", shopID, seriesID) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.First(&allocation).Error; err != nil { return nil, err } return &allocation, nil @@ -51,6 +57,8 @@ func (s *ShopSeriesAllocationStore) List(ctx context.Context, opts *store.QueryO var total int64 query := s.db.WithContext(ctx).Model(&model.ShopSeriesAllocation{}) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) if shopID, ok := filters["shop_id"].(uint); ok && shopID > 0 { query = query.Where("shop_id = ?", shopID) @@ -100,9 +108,11 @@ func (s *ShopSeriesAllocationStore) UpdateStatus(ctx context.Context, id uint, s func (s *ShopSeriesAllocationStore) GetByShopID(ctx context.Context, shopID uint) ([]*model.ShopSeriesAllocation, error) { var allocations []*model.ShopSeriesAllocation - if err := s.db.WithContext(ctx). - Where("shop_id = ? AND status = 1", shopID). - Find(&allocations).Error; err != nil { + query := s.db.WithContext(ctx). + Where("shop_id = ? AND status = 1", shopID) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.Find(&allocations).Error; err != nil { return nil, err } return allocations, nil @@ -132,9 +142,11 @@ func (s *ShopSeriesAllocationStore) ExistsByShopAndSeries(ctx context.Context, s func (s *ShopSeriesAllocationStore) GetByAllocatorShopID(ctx context.Context, allocatorShopID uint) ([]*model.ShopSeriesAllocation, error) { var allocations []*model.ShopSeriesAllocation - if err := s.db.WithContext(ctx). - Where("allocator_shop_id = ? AND status = 1", allocatorShopID). - Find(&allocations).Error; err != nil { + query := s.db.WithContext(ctx). + Where("allocator_shop_id = ? AND status = 1", allocatorShopID) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.Find(&allocations).Error; err != nil { return nil, err } return allocations, nil @@ -145,10 +157,12 @@ func (s *ShopSeriesAllocationStore) GetIDsByShopIDsAndSeries(ctx context.Context return nil, nil } var ids []uint - if err := s.db.WithContext(ctx). + query := s.db.WithContext(ctx). Model(&model.ShopSeriesAllocation{}). - Where("shop_id IN ? AND series_id = ? AND status = 1", shopIDs, seriesID). - Pluck("id", &ids).Error; err != nil { + Where("shop_id IN ? AND series_id = ? AND status = 1", shopIDs, seriesID) + // 应用数据权限过滤 + query = middleware.ApplyShopFilter(ctx, query) + if err := query.Pluck("id", &ids).Error; err != nil { return nil, err } return ids, nil diff --git a/internal/task/commission_calculation.go b/internal/task/commission_calculation.go index 933c2af..53c9b83 100644 --- a/internal/task/commission_calculation.go +++ b/internal/task/commission_calculation.go @@ -9,7 +9,6 @@ import ( "gorm.io/gorm" "github.com/break/junhong_cmp_fiber/internal/service/commission_calculation" - pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm" ) const ( @@ -39,8 +38,6 @@ func NewCommissionCalculationHandler( } func (h *CommissionCalculationHandler) HandleCommissionCalculation(ctx context.Context, task *asynq.Task) error { - ctx = pkggorm.SkipDataPermission(ctx) - var payload CommissionCalculationPayload if err := sonic.Unmarshal(task.Payload(), &payload); err != nil { h.logger.Error("解析佣金计算任务载荷失败", diff --git a/internal/task/commission_stats_archive.go b/internal/task/commission_stats_archive.go index da84346..f0062b0 100644 --- a/internal/task/commission_stats_archive.go +++ b/internal/task/commission_stats_archive.go @@ -12,7 +12,6 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/store/postgres" "github.com/break/junhong_cmp_fiber/pkg/constants" - pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm" ) type CommissionStatsArchiveHandler struct { @@ -37,8 +36,6 @@ func NewCommissionStatsArchiveHandler( } func (h *CommissionStatsArchiveHandler) HandleCommissionStatsArchive(ctx context.Context, task *asynq.Task) error { - ctx = pkggorm.SkipDataPermission(ctx) - now := time.Now() lastMonthStart := now.AddDate(0, -1, 0) lastMonthStart = time.Date(lastMonthStart.Year(), lastMonthStart.Month(), 1, 0, 0, 0, 0, time.UTC) diff --git a/internal/task/commission_stats_sync.go b/internal/task/commission_stats_sync.go index ceaac98..7c9030a 100644 --- a/internal/task/commission_stats_sync.go +++ b/internal/task/commission_stats_sync.go @@ -14,7 +14,6 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/store/postgres" "github.com/break/junhong_cmp_fiber/pkg/constants" - pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm" ) type CommissionStatsSyncHandler struct { @@ -39,8 +38,6 @@ func NewCommissionStatsSyncHandler( } func (h *CommissionStatsSyncHandler) HandleCommissionStatsSync(ctx context.Context, task *asynq.Task) error { - ctx = pkggorm.SkipDataPermission(ctx) - lockKey := constants.RedisCommissionStatsLockKey() locked, err := h.redis.SetNX(ctx, lockKey, "1", 5*time.Minute).Result() if err != nil { diff --git a/internal/task/commission_stats_update.go b/internal/task/commission_stats_update.go index f61dd89..f485130 100644 --- a/internal/task/commission_stats_update.go +++ b/internal/task/commission_stats_update.go @@ -11,7 +11,6 @@ import ( "github.com/break/junhong_cmp_fiber/internal/store/postgres" "github.com/break/junhong_cmp_fiber/pkg/constants" - pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm" ) type CommissionStatsUpdatePayload struct { @@ -42,8 +41,6 @@ func NewCommissionStatsUpdateHandler( } func (h *CommissionStatsUpdateHandler) HandleCommissionStatsUpdate(ctx context.Context, task *asynq.Task) error { - ctx = pkggorm.SkipDataPermission(ctx) - var payload CommissionStatsUpdatePayload if err := sonic.Unmarshal(task.Payload(), &payload); err != nil { h.logger.Error("解析统计更新任务载荷失败", diff --git a/internal/task/device_import.go b/internal/task/device_import.go index c3c922f..4f340c7 100644 --- a/internal/task/device_import.go +++ b/internal/task/device_import.go @@ -17,7 +17,6 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/store/postgres" "github.com/break/junhong_cmp_fiber/pkg/constants" - pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm" "github.com/break/junhong_cmp_fiber/pkg/storage" "github.com/break/junhong_cmp_fiber/pkg/utils" ) @@ -62,8 +61,6 @@ func NewDeviceImportHandler( } func (h *DeviceImportHandler) HandleDeviceImport(ctx context.Context, task *asynq.Task) error { - ctx = pkggorm.SkipDataPermission(ctx) - var payload DeviceImportPayload if err := sonic.Unmarshal(task.Payload(), &payload); err != nil { h.logger.Error("解析设备导入任务载荷失败", diff --git a/internal/task/iot_card_import.go b/internal/task/iot_card_import.go index 53f0aa7..92c3017 100644 --- a/internal/task/iot_card_import.go +++ b/internal/task/iot_card_import.go @@ -17,7 +17,6 @@ import ( "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/internal/store/postgres" "github.com/break/junhong_cmp_fiber/pkg/constants" - pkggorm "github.com/break/junhong_cmp_fiber/pkg/gorm" "github.com/break/junhong_cmp_fiber/pkg/storage" "github.com/break/junhong_cmp_fiber/pkg/utils" "github.com/break/junhong_cmp_fiber/pkg/validator" @@ -72,8 +71,6 @@ func NewIotCardImportHandler( } func (h *IotCardImportHandler) HandleIotCardImport(ctx context.Context, task *asynq.Task) error { - ctx = pkggorm.SkipDataPermission(ctx) - var payload IotCardImportPayload if err := sonic.Unmarshal(task.Payload(), &payload); err != nil { h.logger.Error("解析 IoT 卡导入任务载荷失败", diff --git a/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/.openspec.yaml b/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/.openspec.yaml new file mode 100644 index 0000000..85ae75c --- /dev/null +++ b/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-02-26 diff --git a/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/design.md b/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/design.md new file mode 100644 index 0000000..40a7b99 --- /dev/null +++ b/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/design.md @@ -0,0 +1,330 @@ +# Design: refactor-data-permission-filter + +## Context + +当前系统使用 GORM Callback 在查询前自动注入数据权限过滤条件。该机制存在以下问题: + +1. **隐式行为**:开发者不知道查询被加了什么条件,SQL 调试困难 +2. **跳过率高**:20+ 处使用 `SkipDataPermission`,说明自动过滤不适用于大量场景 +3. **特殊处理多**:`tb_shop`、`tb_tag`、`tb_enterprise_card_authorization` 等需要特殊逻辑 +4. **重复查询**:同一请求中 Service 层 `CanManageShop` 和 Callback 都调用 `GetSubordinateShopIDs` +5. **原生 SQL 失效**:Callback 无法处理原生 SQL,需要在 Store 里手动重复写过滤逻辑 + +**涉及的 Store 统计:** +- 需要改动:16 个 Store、85+ 个查询方法 +- 无需改动:32 个 Store(系统全局表、关联表等) + +## Goals / Non-Goals + +**Goals:** +- 数据权限过滤行为显式可控,便于调试 +- 单次请求内下级店铺 ID 只查询一次(中间件预计算) +- 保持现有的数据隔离行为不变(代理看自己+下级,企业看自己,平台看全部) +- NULL `shop_id` 记录对代理用户不可见(保持现有行为) + +**Non-Goals:** +- 不改变数据权限的业务规则 +- 不修改 Redis 缓存策略(仍保持 30 分钟过期) +- 不处理个人客户的数据权限(保持 `customer_id` / `creator` 字段的现有逻辑) + +## Decisions + +### Decision 1: 中间件预计算 vs 懒加载 + +**选择**:中间件预计算 + +**方案对比:** + +| 方案 | 优点 | 缺点 | +|------|------|------| +| 中间件预计算 | 单次请求只查一次;代码简单 | 所有请求都计算,即使不需要 | +| 懒加载(首次使用时计算) | 按需计算 | 需要加锁防止并发重复计算;代码复杂 | + +**理由**: +1. 绝大多数 API 都需要数据权限过滤,"不需要"是少数情况 +2. `GetSubordinateShopIDs` 有 Redis 缓存,命中率高,预计算开销小 +3. 代码更简单,不需要处理并发问题 + +### Decision 2: Helper 函数设计 + +**选择**:多个专用函数而非通用函数 + +```go +// 专用函数(选择) +func ApplyShopFilter(ctx context.Context, query *gorm.DB) *gorm.DB +func ApplyEnterpriseFilter(ctx context.Context, query *gorm.DB) *gorm.DB +func ApplyOwnerShopFilter(ctx context.Context, query *gorm.DB) *gorm.DB + +// 通用函数(否决) +func ApplyDataPermission(ctx context.Context, query *gorm.DB, field string) *gorm.DB +``` + +**理由**: +1. 字段名固定(`shop_id`、`enterprise_id`、`owner_shop_id`),无需参数化 +2. 专用函数调用更清晰,IDE 自动补全友好 +3. 不同字段的过滤逻辑可能有细微差异,专用函数更灵活 + +### Decision 3: UserContextInfo 扩展 vs 新增 DataScope 结构体 + +**选择**:扩展 UserContextInfo + +```go +// 扩展现有结构体(选择) +type UserContextInfo struct { + UserID uint + UserType int + ShopID uint + EnterpriseID uint + CustomerID uint + SubordinateShopIDs []uint // 新增 +} + +// 新增结构体(否决) +type DataScope struct { + SubordinateShopIDs []uint +} +``` + +**理由**: +1. 减少 Context 中的 key 数量 +2. 用户信息和权限范围本来就是强关联的 +3. 现有代码已经大量使用 `UserContextInfo`,扩展更自然 + +### Decision 4: AuthConfig 传入 ShopStore vs 全局注册 + +**选择**:AuthConfig 传入 + +```go +// AuthConfig 传入(选择) +type AuthConfig struct { + TokenExtractor func(c *fiber.Ctx) string + TokenValidator func(token string) (*UserContextInfo, error) + SkipPaths []string + ShopStore ShopStoreInterface // 新增 +} + +// 全局注册(否决) +var globalShopStore ShopStoreInterface +func RegisterShopStore(s ShopStoreInterface) { ... } +``` + +**理由**: +1. 显式依赖注入,便于测试 +2. 避免全局变量,符合 Go 最佳实践 +3. 与现有 `TokenValidator` 传入方式一致 + +### Decision 5: nil vs 空切片表示"不限制" + +**选择**:nil 表示不限制 + +```go +// nil 表示不限制(选择) +if shopIDs := GetSubordinateShopIDs(ctx); shopIDs != nil { + query = query.Where("shop_id IN ?", shopIDs) +} + +// 空切片表示不限制(否决) +if len(shopIDs) > 0 { + query = query.Where("shop_id IN ?", shopIDs) +} +``` + +**理由**: +1. 语义更清晰:nil = 未设置/不限制,`[]uint{}` = 设置了但列表为空 +2. 空切片 `WHERE shop_id IN ()` 在 SQL 中是无效语法 +3. 便于区分"平台用户不限制"和"代理用户但店铺列表为空(异常情况)" + +## Implementation + +### 文件结构 + +``` +pkg/middleware/ +├── auth.go # 扩展 UserContextInfo,修改 Auth 中间件 +├── data_scope.go # 新增:Helper 函数 +└── permission_helper.go # 修改:CanManageShop 等函数签名 + +pkg/gorm/ +└── callback.go # 移除 RegisterDataPermissionCallback +``` + +### 核心代码结构 + +**1. UserContextInfo 扩展(auth.go)** + +```go +type UserContextInfo struct { + UserID uint + UserType int + ShopID uint + EnterpriseID uint + CustomerID uint + SubordinateShopIDs []uint // 新增:代理用户的下级店铺ID列表,nil表示不限制 +} +``` + +**2. Auth 中间件改造(auth.go)** + +```go +func Auth(config AuthConfig) fiber.Handler { + return func(c *fiber.Ctx) error { + // ... 现有 token 验证逻辑 ... + + // 新增:预计算 SubordinateShopIDs + if config.ShopStore != nil && + userInfo.UserType == constants.UserTypeAgent && + userInfo.ShopID > 0 { + shopIDs, err := config.ShopStore.GetSubordinateShopIDs(c.UserContext(), userInfo.ShopID) + if err != nil { + // 降级处理 + shopIDs = []uint{userInfo.ShopID} + logger.Warn("获取下级店铺失败,降级为只包含自己", zap.Error(err)) + } + userInfo.SubordinateShopIDs = shopIDs + } + + SetUserToFiberContext(c, userInfo) + return c.Next() + } +} +``` + +**3. Helper 函数(data_scope.go)** + +```go +// GetSubordinateShopIDs 获取当前用户可管理的店铺ID列表 +// 返回 nil 表示不受限制(平台用户/超管) +func GetSubordinateShopIDs(ctx context.Context) []uint { + if ctx == nil { + return nil + } + if ids, ok := ctx.Value(constants.ContextKeySubordinateShopIDs).([]uint); ok { + return ids + } + return nil +} + +// ApplyShopFilter 应用店铺数据权限过滤 +// 平台用户/超管:不添加条件 +// 代理用户:WHERE shop_id IN (subordinateShopIDs) +func ApplyShopFilter(ctx context.Context, query *gorm.DB) *gorm.DB { + shopIDs := GetSubordinateShopIDs(ctx) + if shopIDs == nil { + return query + } + return query.Where("shop_id IN ?", shopIDs) +} + +// ApplyEnterpriseFilter 应用企业数据权限过滤 +// 非企业用户:不添加条件 +// 企业用户:WHERE enterprise_id = ? +func ApplyEnterpriseFilter(ctx context.Context, query *gorm.DB) *gorm.DB { + userType := GetUserTypeFromContext(ctx) + if userType != constants.UserTypeEnterprise { + return query + } + enterpriseID := GetEnterpriseIDFromContext(ctx) + if enterpriseID == 0 { + return query.Where("1 = 0") // 企业用户但无企业ID,返回空 + } + return query.Where("enterprise_id = ?", enterpriseID) +} + +// ApplyOwnerShopFilter 应用归属店铺数据权限过滤 +// 用于 Enterprise 等使用 owner_shop_id 的表 +func ApplyOwnerShopFilter(ctx context.Context, query *gorm.DB) *gorm.DB { + shopIDs := GetSubordinateShopIDs(ctx) + if shopIDs == nil { + return query + } + return query.Where("owner_shop_id IN ?", shopIDs) +} +``` + +**4. Store 层调用示例** + +```go +// 改造前 +func (s *DeviceStore) List(ctx context.Context, opts *QueryOptions) ([]*model.Device, error) { + query := s.db.WithContext(ctx).Model(&model.Device{}) + // ... GORM Callback 自动添加 WHERE shop_id IN (...) ... + return ... +} + +// 改造后 +func (s *DeviceStore) List(ctx context.Context, opts *QueryOptions) ([]*model.Device, error) { + query := s.db.WithContext(ctx).Model(&model.Device{}) + query = middleware.ApplyShopFilter(ctx, query) // 显式调用 + // ... + return ... +} +``` + +## Risks / Trade-offs + +### Risk 1: Store 层遗漏过滤调用 + +**风险**:改造过程中可能遗漏某些查询方法,导致数据泄露 + +**缓解措施**: +1. 按照 proposal 中的 Store 清单逐一检查 +2. 代码审查重点关注权限过滤 +3. 可考虑在开发环境添加检测中间件,对未过滤的敏感表查询打印告警 + +### Risk 2: 预计算开销 + +**风险**:每个请求都预计算 `SubordinateShopIDs`,即使某些请求不需要 + +**缓解措施**: +1. `GetSubordinateShopIDs` 有 Redis 缓存(30分钟),命中率高 +2. Redis 查询通常 < 1ms +3. 实际影响可忽略 + +### Risk 3: 改造期间的并行开发冲突 + +**风险**:改造涉及 16 个 Store 文件,可能与其他开发任务冲突 + +**缓解措施**: +1. 分阶段改造:先基础设施,再 Store 层 +2. 优先改造高复杂度 Store,降低后期风险 +3. 改造期间及时 rebase 和解决冲突 + +## Migration Plan + +### Phase 1: 基础设施(不影响现有功能) + +1. 扩展 `UserContextInfo`,添加 `SubordinateShopIDs` 字段 +2. 新增 `pkg/middleware/data_scope.go`,实现 Helper 函数 +3. 修改 Auth 中间件,预计算 `SubordinateShopIDs` +4. 此阶段 GORM Callback 仍然生效,两套机制并存 + +### Phase 2: 改造权限检查函数 + +1. 修改 `CanManageShop`,从 Context 获取数据,移除 `shopStore` 参数 +2. 修改 `CanManageEnterprise`,从 Context 获取数据 +3. 更新所有调用点 + +### Phase 3: Store 层改造(按复杂度分批) + +1. **低复杂度 Store(9 个)**:agent_wallet、commission_record 等 +2. **中复杂度 Store(4 个)**:device、order、shop_package_allocation 等 +3. **高复杂度 Store(3 个)**:iot_card、account、enterprise_card_authorization + +### Phase 4: 清理 + +1. 移除 `RegisterDataPermissionCallback` 及其调用 +2. 移除 `SkipDataPermission` 函数及所有调用点(20+ 处) +3. 移除 `pkg/gorm/callback.go` 中的相关代码 + +### Rollback Strategy + +如果发现问题,可以: +1. Phase 1-2 期间:直接回滚代码,GORM Callback 仍在工作 +2. Phase 3 期间:保留 Callback 代码,只回滚 Store 改动 +3. Phase 4 后:需要重新启用 Callback 代码 + +## Open Questions + +1. ~~NULL shop_id 的记录对代理用户是否可见?~~ **已确认:不可见(保持现有行为)** + +2. ~~是否需要为开发环境添加"未过滤敏感表查询"的告警机制?~~ **暂不需要,通过代码审查保证** diff --git a/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/proposal.md b/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/proposal.md new file mode 100644 index 0000000..69f1699 --- /dev/null +++ b/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/proposal.md @@ -0,0 +1,93 @@ +# Proposal: refactor-data-permission-filter + +## Why + +当前 GORM Callback 自动数据权限过滤机制存在以下问题: + +1. **跳过率高** - 20+ 处使用 `SkipDataPermission`(异步任务、登录、复杂查询等场景都需要跳过) +2. **特殊处理多** - `tb_shop`、`tb_tag`、`tb_enterprise_card_authorization` 等表需要特殊逻辑 +3. **重复查询** - 同一请求中 Service 层的 `CanManageShop` 和 Callback 都调用 `GetSubordinateShopIDs` +4. **隐式行为** - 开发者不知道 GORM 查询被加了什么条件,调试困难 +5. **原生 SQL 失效** - 原生 SQL 无法自动应用,需要手动在 Store 里重复写权限过滤逻辑 + +需要重构为显式调用模式,让数据权限过滤行为可预测、可控。 + +## What Changes + +### 新增 + +- **中间件预加载**:扩展 Auth 中间件,对代理用户预计算 `SubordinateShopIDs` 并放入 Context +- **UserContextInfo 扩展**:新增 `SubordinateShopIDs []uint` 字段 +- **Helper 函数**: + - `GetSubordinateShopIDs(ctx) []uint` - 获取下级店铺 ID 列表 + - `ApplyShopFilter(ctx, query) *gorm.DB` - 应用 `WHERE shop_id IN ?` 过滤 + - `ApplyEnterpriseFilter(ctx, query) *gorm.DB` - 应用 `WHERE enterprise_id = ?` 过滤 + - `ApplyOwnerShopFilter(ctx, query) *gorm.DB` - 应用 `WHERE owner_shop_id IN ?` 过滤 + +### 移除 + +- **BREAKING**:移除 `RegisterDataPermissionCallback` 函数 +- **BREAKING**:移除 `SkipDataPermission` 函数及所有调用点(20+ 处) +- **BREAKING**:`CanManageShop` / `CanManageEnterprise` 函数签名变更,不再需要 Store 参数 + +### 修改 + +- **Store 层显式过滤**:16 个 Store、85+ 个查询方法需要显式调用 Helper 函数添加权限过滤 + +## Capabilities + +### New Capabilities + +- `data-scope-middleware`: 数据权限范围中间件,负责预计算用户的数据访问范围并注入 Context + +### Modified Capabilities + +- `data-permission`: 数据权限过滤机制从 GORM Callback 自动过滤改为业务层显式调用 + +## Impact + +### 代码影响 + +| 类型 | 文件/模块 | 改动说明 | +|------|-----------|----------| +| 新增 | `pkg/middleware/data_scope.go` | Helper 函数和 Context 操作 | +| 修改 | `pkg/middleware/auth.go` | 扩展 `UserContextInfo`,Auth 中间件预计算 | +| 修改 | `pkg/middleware/permission_helper.go` | `CanManageShop` 等函数改为从 Context 取数据 | +| 移除 | `pkg/gorm/callback.go` | 移除 `RegisterDataPermissionCallback` | +| 修改 | 16 个 Store 文件 | 显式调用过滤 Helper 函数 | +| 移除 | 20+ 处 `SkipDataPermission` 调用 | 不再需要跳过机制 | + +### 需要改动的 Store(按复杂度分级) + +**高复杂度(3 个)**: +- `iot_card_store.go` - 9+ 方法,已有部分手动实现 +- `account_store.go` - 7+ 方法,双字段过滤(shop_id / enterprise_id) +- `enterprise_card_authorization_store.go` - 8+ 方法,已有参考实现 + +**中复杂度(4 个)**: +- `device_store.go` - 4+ 方法,NULL shop_id 表示平台库存 +- `order_store.go` - 3 方法,使用 seller_shop_id +- `shop_package_allocation_store.go` - 6 方法 +- `shop_series_allocation_store.go` - 6 方法 + +**低复杂度(9 个)**: +- `agent_wallet_store.go` - 5 方法 +- `agent_wallet_transaction_store.go` - 4 方法 +- `commission_record_store.go` - 4 方法 +- `enterprise_device_authorization_store.go` - 5 方法 +- `enterprise_store.go` - 3 方法 +- `shop_role_store.go` - 2 方法 +- `iot_card_import_task_store.go` - 2 方法 +- `commission_withdrawal_request_store.go` - 3 方法 +- `card_wallet_store.go` - 5+ 方法 + +### API 影响 + +- 无外部 API 变更 +- 内部函数签名变更:`CanManageShop(ctx, targetShopID, shopStore)` → `CanManageShop(ctx, targetShopID)` + +### 行为变更 + +- NULL `shop_id` 的记录对代理用户不可见(保持现有行为) +- 平台用户/超管不受数据权限限制(保持现有行为) +- 数据权限过滤从隐式变为显式,需要业务层主动调用 diff --git a/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/specs/data-permission/spec.md b/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/specs/data-permission/spec.md new file mode 100644 index 0000000..fe9d3e7 --- /dev/null +++ b/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/specs/data-permission/spec.md @@ -0,0 +1,90 @@ +# data-permission Delta Specification + +## Purpose + +数据权限过滤机制从 GORM Callback 自动过滤改为业务层显式调用。 + +## REMOVED Requirements + +### Requirement: GORM Callback Data Permission + +**Reason**: GORM Callback 自动过滤存在以下问题:跳过率高(20+ 处 SkipDataPermission)、特殊处理多、重复查询、隐式行为难调试、原生 SQL 失效。改为业务层显式调用模式。 + +**Migration**: +1. Store 层查询方法显式调用 `ApplyShopFilter`、`ApplyEnterpriseFilter` 等 Helper 函数 +2. 参考 `pkg/middleware/data_scope.go` 中的 Helper 函数用法 + +### Requirement: Skip Data Permission + +**Reason**: 移除 GORM Callback 后,不再需要跳过机制。数据权限过滤由业务层显式控制。 + +**Migration**: +1. 删除所有 `SkipDataPermission(ctx)` 调用 +2. 删除所有 `pkggorm.SkipDataPermission` 引用 +3. 业务逻辑直接控制是否调用过滤函数 + +### Requirement: Callback Registration + +**Reason**: 不再使用 GORM Callback 机制。数据权限范围在 Auth 中间件预计算,过滤在业务层显式调用。 + +**Migration**: +1. 移除 `RegisterDataPermissionCallback` 函数 +2. 移除应用启动时的 Callback 注册调用 +3. Auth 中间件配置 `ShopStore` 以支持预计算 + +## MODIFIED Requirements + +### Requirement: Subordinate IDs Caching + +系统 SHALL 缓存用户的下级店铺 ID 列表以提高查询性能。 + +#### Scenario: 缓存命中 +- **WHEN** 获取用户下级店铺 ID 列表 +- **AND** Redis 缓存存在 +- **THEN** 直接返回缓存数据 + +#### Scenario: 缓存未命中 +- **WHEN** 获取用户下级店铺 ID 列表 +- **AND** Redis 缓存不存在 +- **THEN** 执行递归查询获取下级店铺 ID +- **AND** 将结果缓存到 Redis(30 分钟过期) + +#### Scenario: 请求级别复用 +- **WHEN** 同一请求内多次需要下级店铺 ID 列表 +- **THEN** 从 Context 中获取预计算的值 +- **AND** 不重复查询 Redis 或数据库 + +## ADDED Requirements + +### Requirement: Store 层显式数据权限过滤 + +系统 SHALL 在 Store 层查询方法中显式调用数据权限过滤函数。 + +#### Scenario: 有 shop_id 字段的表 +- **WHEN** Store 执行列表查询 +- **AND** 表包含 `shop_id` 字段 +- **THEN** 显式调用 `ApplyShopFilter(ctx, query)` +- **AND** 代理用户只能查询 `shop_id IN (subordinateShopIDs)` 的数据 + +#### Scenario: 有 enterprise_id 字段的表 +- **WHEN** Store 执行列表查询 +- **AND** 表包含 `enterprise_id` 字段 +- **AND** 当前用户为企业用户 +- **THEN** 显式调用 `ApplyEnterpriseFilter(ctx, query)` +- **AND** 企业用户只能查询 `enterprise_id = ?` 的数据 + +#### Scenario: 有 owner_shop_id 字段的表 +- **WHEN** Store 执行列表查询 +- **AND** 表包含 `owner_shop_id` 字段(如 Enterprise 表) +- **THEN** 显式调用 `ApplyOwnerShopFilter(ctx, query)` +- **AND** 代理用户只能查询 `owner_shop_id IN (subordinateShopIDs)` 的数据 + +#### Scenario: NULL shop_id 不可见 +- **WHEN** 代理用户查询有 `shop_id` 字段的表 +- **AND** 记录的 `shop_id` 为 NULL(平台库存) +- **THEN** 该记录对代理用户不可见 + +#### Scenario: 平台用户/超管不过滤 +- **WHEN** 平台用户或超级管理员执行查询 +- **THEN** Helper 函数不添加任何过滤条件 +- **AND** 可查询所有数据 diff --git a/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/specs/data-scope-middleware/spec.md b/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/specs/data-scope-middleware/spec.md new file mode 100644 index 0000000..a3d4fbd --- /dev/null +++ b/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/specs/data-scope-middleware/spec.md @@ -0,0 +1,130 @@ +# data-scope-middleware Specification + +## Purpose + +数据权限范围中间件,负责在请求入口预计算用户的数据访问范围并注入 Context,供业务层显式使用。 + +## ADDED Requirements + +### Requirement: UserContextInfo 扩展 + +系统 SHALL 扩展 `UserContextInfo` 结构体以包含预计算的数据权限范围。 + +#### Scenario: 代理用户包含下级店铺 ID 列表 +- **WHEN** 代理用户登录成功 +- **AND** 用户有关联的店铺 ID +- **THEN** `UserContextInfo.SubordinateShopIDs` 包含自己店铺及所有下级店铺的 ID 列表 + +#### Scenario: 平台用户/超管不限制 +- **WHEN** 平台用户或超级管理员登录成功 +- **THEN** `UserContextInfo.SubordinateShopIDs` 为 nil +- **AND** nil 表示不受数据权限限制 + +#### Scenario: 企业用户使用 EnterpriseID +- **WHEN** 企业用户登录成功 +- **THEN** `UserContextInfo.EnterpriseID` 包含用户所属企业 ID +- **AND** `UserContextInfo.SubordinateShopIDs` 为 nil + +### Requirement: Auth 中间件预计算 + +系统 SHALL 在 Auth 中间件中预计算用户的数据访问范围。 + +#### Scenario: 代理用户预计算下级店铺 +- **WHEN** Auth 中间件验证 token 成功 +- **AND** 用户类型为代理用户 +- **AND** 用户有关联的店铺 ID +- **THEN** 调用 `GetSubordinateShopIDs` 获取下级店铺 ID 列表 +- **AND** 将结果设置到 `UserContextInfo.SubordinateShopIDs` + +#### Scenario: 获取下级店铺失败降级处理 +- **WHEN** 调用 `GetSubordinateShopIDs` 失败 +- **THEN** `SubordinateShopIDs` 降级为只包含用户自己的店铺 ID +- **AND** 记录 Error 日志 + +#### Scenario: 非代理用户跳过预计算 +- **WHEN** Auth 中间件验证 token 成功 +- **AND** 用户类型不是代理用户 +- **THEN** 不调用 `GetSubordinateShopIDs` +- **AND** `SubordinateShopIDs` 保持为 nil + +### Requirement: Context 数据获取函数 + +系统 SHALL 提供从 Context 获取数据权限范围的函数。 + +#### Scenario: 获取下级店铺 ID 列表 +- **WHEN** 调用 `GetSubordinateShopIDs(ctx)` +- **AND** Context 包含 `SubordinateShopIDs` +- **THEN** 返回下级店铺 ID 列表 + +#### Scenario: 获取空列表表示不限制 +- **WHEN** 调用 `GetSubordinateShopIDs(ctx)` +- **AND** Context 中 `SubordinateShopIDs` 为 nil +- **THEN** 返回 nil +- **AND** 调用方应理解 nil 表示不受数据权限限制 + +### Requirement: 查询过滤 Helper 函数 + +系统 SHALL 提供查询过滤 Helper 函数,供 Store 层显式调用。 + +#### Scenario: ApplyShopFilter 过滤店铺数据 +- **WHEN** 调用 `ApplyShopFilter(ctx, query)` +- **AND** `SubordinateShopIDs` 不为 nil +- **THEN** 返回添加了 `WHERE shop_id IN (?)` 条件的查询 +- **AND** 参数为 `SubordinateShopIDs` + +#### Scenario: ApplyShopFilter 不限制时不添加条件 +- **WHEN** 调用 `ApplyShopFilter(ctx, query)` +- **AND** `SubordinateShopIDs` 为 nil +- **THEN** 返回原查询,不添加任何条件 + +#### Scenario: ApplyEnterpriseFilter 过滤企业数据 +- **WHEN** 调用 `ApplyEnterpriseFilter(ctx, query)` +- **AND** 用户类型为企业用户 +- **AND** `EnterpriseID` 大于 0 +- **THEN** 返回添加了 `WHERE enterprise_id = ?` 条件的查询 + +#### Scenario: ApplyEnterpriseFilter 非企业用户不添加条件 +- **WHEN** 调用 `ApplyEnterpriseFilter(ctx, query)` +- **AND** 用户类型不是企业用户 +- **THEN** 返回原查询,不添加任何条件 + +#### Scenario: ApplyOwnerShopFilter 过滤归属店铺数据 +- **WHEN** 调用 `ApplyOwnerShopFilter(ctx, query)` +- **AND** `SubordinateShopIDs` 不为 nil +- **THEN** 返回添加了 `WHERE owner_shop_id IN (?)` 条件的查询 + +### Requirement: 权限检查函数改造 + +系统 SHALL 改造权限检查函数,从 Context 获取数据而非传入 Store。 + +#### Scenario: CanManageShop 从 Context 获取数据 +- **WHEN** 调用 `CanManageShop(ctx, targetShopID)` +- **AND** 用户类型为代理用户 +- **THEN** 从 Context 获取 `SubordinateShopIDs` +- **AND** 检查 `targetShopID` 是否在列表中 + +#### Scenario: CanManageShop 平台用户自动通过 +- **WHEN** 调用 `CanManageShop(ctx, targetShopID)` +- **AND** `SubordinateShopIDs` 为 nil +- **THEN** 返回成功(不受限制) + +#### Scenario: CanManageEnterprise 从 Context 获取数据 +- **WHEN** 调用 `CanManageEnterprise(ctx, targetEnterpriseID)` +- **AND** 用户类型为代理用户 +- **THEN** 从 Context 获取 `SubordinateShopIDs` +- **AND** 查询目标企业的 `owner_shop_id` +- **AND** 检查 `owner_shop_id` 是否在列表中 + +### Requirement: AuthConfig 扩展 + +系统 SHALL 扩展 `AuthConfig` 以支持传入 ShopStore。 + +#### Scenario: AuthConfig 包含 ShopStore +- **WHEN** 初始化 Auth 中间件 +- **THEN** `AuthConfig` 可选包含 `ShopStore ShopStoreInterface` +- **AND** 用于调用 `GetSubordinateShopIDs` + +#### Scenario: ShopStore 未配置时跳过预计算 +- **WHEN** `AuthConfig.ShopStore` 为 nil +- **THEN** 不预计算 `SubordinateShopIDs` +- **AND** 所有用户的 `SubordinateShopIDs` 为 nil diff --git a/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/tasks.md b/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/tasks.md new file mode 100644 index 0000000..2cd961f --- /dev/null +++ b/openspec/changes/archive/2026-02-26-refactor-data-permission-filter/tasks.md @@ -0,0 +1,83 @@ +# Tasks: refactor-data-permission-filter + +## 1. 基础设施 - 数据结构和 Helper 函数 + +- [x] 1.1 扩展 `UserContextInfo` 结构体,添加 `SubordinateShopIDs []uint` 字段(`pkg/middleware/auth.go`) +- [x] 1.2 新增 Context key 常量 `ContextKeySubordinateShopIDs`(`pkg/constants/constants.go`) +- [x] 1.3 新增 `SetUserContext` 和 `SetUserToFiberContext` 对 `SubordinateShopIDs` 的处理 +- [x] 1.4 新增 `pkg/middleware/data_scope.go` 文件,实现 `GetSubordinateShopIDs` 函数 +- [x] 1.5 实现 `ApplyShopFilter` Helper 函数 +- [x] 1.6 实现 `ApplyEnterpriseFilter` Helper 函数 +- [x] 1.7 实现 `ApplyOwnerShopFilter` Helper 函数 +- [x] 1.8 验证:编译通过,`go build ./...` + +## 2. Auth 中间件改造 + +- [x] 2.1 扩展 `AuthConfig` 结构体,添加 `ShopStore ShopStoreInterface` 字段 +- [x] 2.2 定义 `AuthShopStoreInterface` 接口(包含 `GetSubordinateShopIDs` 方法) +- [x] 2.3 修改 `Auth` 中间件,在 token 验证成功后预计算 `SubordinateShopIDs` +- [x] 2.4 实现降级逻辑:获取下级店铺失败时降级为只包含自己的店铺 ID +- [x] 2.5 更新 Admin API 和 H5 API 的 Auth 中间件配置,传入 ShopStore +- [x] 2.6 验证:编译通过(运行时验证将在最终验证阶段进行) + +## 3. 权限检查函数改造 + +- [x] 3.1 修改 `CanManageShop` 函数签名,移除 `shopStore` 参数,改为从 Context 获取数据 +- [x] 3.2 修改 `CanManageEnterprise` 函数签名,移除 `shopStore` 参数(保留 enterpriseStore) +- [x] 3.3 更新所有 `CanManageShop` 调用点(Service 层) +- [x] 3.4 更新所有 `CanManageEnterprise` 调用点(Service 层) +- [x] 3.5 验证:编译通过 + +## 4. Store 层改造 - 低复杂度(9 个) + +- [x] 4.1 改造 `agent_wallet_store.go`:List、GetByShopID 等方法添加 `ApplyShopFilter` +- [x] 4.2 改造 `agent_wallet_transaction_store.go`:List 等方法添加 `ApplyShopFilter` +- [x] 4.3 改造 `commission_record_store.go`:List 等方法添加 `ApplyShopFilter` +- [x] 4.4 改造 `enterprise_device_authorization_store.go`:List 等方法添加 `ApplyEnterpriseFilter` +- [x] 4.5 改造 `enterprise_store.go`:List 等方法添加 `ApplyOwnerShopFilter` +- [x] 4.6 改造 `shop_role_store.go`:List 等方法添加 `ApplyShopFilter` +- [x] 4.7 改造 `iot_card_import_task_store.go`:List 等方法添加 `ApplyShopFilter` +- [x] 4.8 改造 `commission_withdrawal_request_store.go`:List 等方法添加 `ApplyShopFilter` +- [x] 4.9 改造 `card_wallet_store.go` 和 `card_wallet_transaction_store.go` +- [x] 4.10 验证:编译通过,低复杂度 Store 的列表接口数据过滤正常 + +## 5. Store 层改造 - 中复杂度(4 个) + +- [x] 5.1 改造 `device_store.go`:List、Count 等方法添加 `ApplyShopFilter`,注意 NULL shop_id 处理 +- [x] 5.2 改造 `order_store.go`:List 等方法添加店铺过滤(使用 seller_shop_id 字段) +- [x] 5.3 改造 `shop_package_allocation_store.go`:List 等方法添加 `ApplyShopFilter` +- [x] 5.4 改造 `shop_series_allocation_store.go`:List 等方法添加 `ApplyShopFilter` +- [x] 5.5 验证:编译通过,中复杂度 Store 的列表接口数据过滤正常 + +## 6. Store 层改造 - 高复杂度(3 个) + +- [x] 6.1 改造 `account_store.go`:List、GetByID 等方法,根据用户类型选择 `ApplyShopFilter` 或 `ApplyEnterpriseFilter` +- [x] 6.2 改造 `iot_card_store.go`:List、ListStandalone 等方法添加 `ApplyShopFilter`,移除已有的手动过滤逻辑 +- [x] 6.3 改造 `enterprise_card_authorization_store.go`:List、ListWithJoin 等方法,整合现有的手动权限过滤 +- [x] 6.4 验证:编译通过,高复杂度 Store 的列表接口数据过滤正常 + +## 7. 清理 - 移除 GORM Callback + +- [x] 7.1 移除 `pkg/gorm/callback.go` 中的 `RegisterDataPermissionCallback` 函数 +- [x] 7.2 移除 `SkipDataPermission` 函数和 `SkipDataPermissionKey` 常量 +- [x] 7.3 移除应用启动时的 `RegisterDataPermissionCallback` 调用 +- [x] 7.4 验证:编译通过 + +## 8. 清理 - 移除 SkipDataPermission 调用 + +- [x] 8.1 移除 `internal/task/*.go` 中的 `SkipDataPermission` 调用(6 处) +- [x] 8.2 移除 `internal/service/auth/service.go` 中的 `SkipDataPermission` 调用 +- [x] 8.3 移除 `internal/service/shop_series_allocation/service.go` 中的 `SkipDataPermission` 调用 +- [x] 8.4 移除 `internal/service/enterprise_device/service.go` 中的 `SkipDataPermission` 调用(5 处) +- [x] 8.5 移除 `internal/store/postgres/iot_card_store.go` 中的 `SkipDataPermission` 调用(5 处) +- [x] 8.6 移除 `internal/store/postgres/enterprise_card_authorization_store.go` 中的 `SkipDataPermission` 调用 +- [x] 8.7 移除 `internal/bootstrap/admin.go` 中的 `SkipDataPermission` 调用 +- [x] 8.8 验证:编译通过,全局搜索 `SkipDataPermission` 无结果 + +## 9. 最终验证 + +- [ ] 9.1 启动服务,验证平台管理员可以查看所有数据 +- [ ] 9.2 验证代理用户只能看到自己店铺及下级店铺的数据 +- [ ] 9.3 验证企业用户只能看到自己企业的数据 +- [ ] 9.4 验证 NULL shop_id 的记录对代理用户不可见 +- [x] 9.5 运行 `go build ./...` 和 `go vet ./...` 确认无编译警告 diff --git a/openspec/specs/data-permission/spec.md b/openspec/specs/data-permission/spec.md index af9a9da..4acceaa 100644 --- a/openspec/specs/data-permission/spec.md +++ b/openspec/specs/data-permission/spec.md @@ -1,77 +1,60 @@ # data-permission Specification ## Purpose -TBD - created by archiving change refactor-framework-cleanup. Update Purpose after archive. + +数据权限过滤机制,通过业务层显式调用实现数据隔离。 + ## Requirements -### Requirement: GORM Callback Data Permission - -系统 SHALL 使用 GORM Callback 机制自动为所有查询添加数据权限过滤。 - -#### Scenario: 自动应用权限过滤 -- **WHEN** 执行 GORM 查询 -- **AND** Context 包含用户信息 -- **AND** 表包含 owner_id 字段 -- **THEN** 自动添加 WHERE owner_id IN (subordinateIDs) 条件 - -#### Scenario: Root 用户跳过过滤 -- **WHEN** 当前用户是 Root 用户 -- **THEN** 不添加任何数据权限过滤条件 -- **AND** 可查询所有数据 - -#### Scenario: 无 owner_id 字段的表 -- **WHEN** 表不包含 owner_id 字段 -- **THEN** 不添加数据权限过滤条件 - -#### Scenario: 授权记录表特殊处理 -- **WHEN** 查询 `tb_enterprise_card_authorization` 表 -- **AND** 当前用户是代理用户 -- **THEN** 自动添加 WHERE enterprise_id IN (SELECT id FROM tb_enterprise WHERE owner_shop_id = 当前店铺ID) 条件 -- **AND** 不包含下级店铺的数据 - -#### Scenario: 平台用户查询授权记录 -- **WHEN** 查询 `tb_enterprise_card_authorization` 表 -- **AND** 当前用户是平台用户或超级管理员 -- **THEN** 不添加数据权限过滤条件 -- **AND** 可查询所有授权记录 - -### Requirement: Skip Data Permission - -系统 SHALL 支持通过 Context 绕过数据权限过滤。 - -#### Scenario: 显式跳过权限过滤 -- **WHEN** 调用 SkipDataPermission(ctx) 获取新 Context -- **AND** 使用该 Context 执行 GORM 查询 -- **THEN** 不添加任何数据权限过滤条件 - -#### Scenario: 内部操作跳过过滤 -- **WHEN** 执行内部同步、批量操作或管理员操作 -- **THEN** 应使用 SkipDataPermission 绕过过滤 ### Requirement: Subordinate IDs Caching -系统 SHALL 缓存用户的下级 ID 列表以提高查询性能。 +系统 SHALL 缓存用户的下级店铺 ID 列表以提高查询性能。 #### Scenario: 缓存命中 -- **WHEN** 获取用户下级 ID 列表 +- **WHEN** 获取用户下级店铺 ID 列表 - **AND** Redis 缓存存在 - **THEN** 直接返回缓存数据 #### Scenario: 缓存未命中 -- **WHEN** 获取用户下级 ID 列表 +- **WHEN** 获取用户下级店铺 ID 列表 - **AND** Redis 缓存不存在 -- **THEN** 执行递归 CTE 查询获取下级 ID +- **THEN** 执行递归查询获取下级店铺 ID - **AND** 将结果缓存到 Redis(30 分钟过期) -### Requirement: Callback Registration +#### Scenario: 请求级别复用 +- **WHEN** 同一请求内多次需要下级店铺 ID 列表 +- **THEN** 从 Context 中获取预计算的值 +- **AND** 不重复查询 Redis 或数据库 -系统 SHALL 在应用启动时注册 GORM 数据权限 Callback。 +### Requirement: Store 层显式数据权限过滤 -#### Scenario: 注册 Callback -- **WHEN** 调用 RegisterDataPermissionCallback(db, accountStore) -- **THEN** 注册 Query Before Callback -- **AND** Callback 名称为 "data_permission" +系统 SHALL 在 Store 层查询方法中显式调用数据权限过滤函数。 -#### Scenario: AccountStore 依赖 -- **WHEN** 注册 Callback 时 -- **THEN** 需要传入 AccountStore 实例用于获取下级 ID +#### Scenario: 有 shop_id 字段的表 +- **WHEN** Store 执行列表查询 +- **AND** 表包含 `shop_id` 字段 +- **THEN** 显式调用 `ApplyShopFilter(ctx, query)` +- **AND** 代理用户只能查询 `shop_id IN (subordinateShopIDs)` 的数据 +#### Scenario: 有 enterprise_id 字段的表 +- **WHEN** Store 执行列表查询 +- **AND** 表包含 `enterprise_id` 字段 +- **AND** 当前用户为企业用户 +- **THEN** 显式调用 `ApplyEnterpriseFilter(ctx, query)` +- **AND** 企业用户只能查询 `enterprise_id = ?` 的数据 + +#### Scenario: 有 owner_shop_id 字段的表 +- **WHEN** Store 执行列表查询 +- **AND** 表包含 `owner_shop_id` 字段(如 Enterprise 表) +- **THEN** 显式调用 `ApplyOwnerShopFilter(ctx, query)` +- **AND** 代理用户只能查询 `owner_shop_id IN (subordinateShopIDs)` 的数据 + +#### Scenario: NULL shop_id 不可见 +- **WHEN** 代理用户查询有 `shop_id` 字段的表 +- **AND** 记录的 `shop_id` 为 NULL(平台库存) +- **THEN** 该记录对代理用户不可见 + +#### Scenario: 平台用户/超管不过滤 +- **WHEN** 平台用户或超级管理员执行查询 +- **THEN** Helper 函数不添加任何过滤条件 +- **AND** 可查询所有数据 diff --git a/openspec/specs/data-scope-middleware/spec.md b/openspec/specs/data-scope-middleware/spec.md new file mode 100644 index 0000000..173d9e1 --- /dev/null +++ b/openspec/specs/data-scope-middleware/spec.md @@ -0,0 +1,130 @@ +# data-scope-middleware Specification + +## Purpose + +数据权限范围中间件,负责在请求入口预计算用户的数据访问范围并注入 Context,供业务层显式使用。 + +## Requirements + +### Requirement: UserContextInfo 扩展 + +系统 SHALL 扩展 `UserContextInfo` 结构体以包含预计算的数据权限范围。 + +#### Scenario: 代理用户包含下级店铺 ID 列表 +- **WHEN** 代理用户登录成功 +- **AND** 用户有关联的店铺 ID +- **THEN** `UserContextInfo.SubordinateShopIDs` 包含自己店铺及所有下级店铺的 ID 列表 + +#### Scenario: 平台用户/超管不限制 +- **WHEN** 平台用户或超级管理员登录成功 +- **THEN** `UserContextInfo.SubordinateShopIDs` 为 nil +- **AND** nil 表示不受数据权限限制 + +#### Scenario: 企业用户使用 EnterpriseID +- **WHEN** 企业用户登录成功 +- **THEN** `UserContextInfo.EnterpriseID` 包含用户所属企业 ID +- **AND** `UserContextInfo.SubordinateShopIDs` 为 nil + +### Requirement: Auth 中间件预计算 + +系统 SHALL 在 Auth 中间件中预计算用户的数据访问范围。 + +#### Scenario: 代理用户预计算下级店铺 +- **WHEN** Auth 中间件验证 token 成功 +- **AND** 用户类型为代理用户 +- **AND** 用户有关联的店铺 ID +- **THEN** 调用 `GetSubordinateShopIDs` 获取下级店铺 ID 列表 +- **AND** 将结果设置到 `UserContextInfo.SubordinateShopIDs` + +#### Scenario: 获取下级店铺失败降级处理 +- **WHEN** 调用 `GetSubordinateShopIDs` 失败 +- **THEN** `SubordinateShopIDs` 降级为只包含用户自己的店铺 ID +- **AND** 记录 Error 日志 + +#### Scenario: 非代理用户跳过预计算 +- **WHEN** Auth 中间件验证 token 成功 +- **AND** 用户类型不是代理用户 +- **THEN** 不调用 `GetSubordinateShopIDs` +- **AND** `SubordinateShopIDs` 保持为 nil + +### Requirement: Context 数据获取函数 + +系统 SHALL 提供从 Context 获取数据权限范围的函数。 + +#### Scenario: 获取下级店铺 ID 列表 +- **WHEN** 调用 `GetSubordinateShopIDs(ctx)` +- **AND** Context 包含 `SubordinateShopIDs` +- **THEN** 返回下级店铺 ID 列表 + +#### Scenario: 获取空列表表示不限制 +- **WHEN** 调用 `GetSubordinateShopIDs(ctx)` +- **AND** Context 中 `SubordinateShopIDs` 为 nil +- **THEN** 返回 nil +- **AND** 调用方应理解 nil 表示不受数据权限限制 + +### Requirement: 查询过滤 Helper 函数 + +系统 SHALL 提供查询过滤 Helper 函数,供 Store 层显式调用。 + +#### Scenario: ApplyShopFilter 过滤店铺数据 +- **WHEN** 调用 `ApplyShopFilter(ctx, query)` +- **AND** `SubordinateShopIDs` 不为 nil +- **THEN** 返回添加了 `WHERE shop_id IN (?)` 条件的查询 +- **AND** 参数为 `SubordinateShopIDs` + +#### Scenario: ApplyShopFilter 不限制时不添加条件 +- **WHEN** 调用 `ApplyShopFilter(ctx, query)` +- **AND** `SubordinateShopIDs` 为 nil +- **THEN** 返回原查询,不添加任何条件 + +#### Scenario: ApplyEnterpriseFilter 过滤企业数据 +- **WHEN** 调用 `ApplyEnterpriseFilter(ctx, query)` +- **AND** 用户类型为企业用户 +- **AND** `EnterpriseID` 大于 0 +- **THEN** 返回添加了 `WHERE enterprise_id = ?` 条件的查询 + +#### Scenario: ApplyEnterpriseFilter 非企业用户不添加条件 +- **WHEN** 调用 `ApplyEnterpriseFilter(ctx, query)` +- **AND** 用户类型不是企业用户 +- **THEN** 返回原查询,不添加任何条件 + +#### Scenario: ApplyOwnerShopFilter 过滤归属店铺数据 +- **WHEN** 调用 `ApplyOwnerShopFilter(ctx, query)` +- **AND** `SubordinateShopIDs` 不为 nil +- **THEN** 返回添加了 `WHERE owner_shop_id IN (?)` 条件的查询 + +### Requirement: 权限检查函数改造 + +系统 SHALL 改造权限检查函数,从 Context 获取数据而非传入 Store。 + +#### Scenario: CanManageShop 从 Context 获取数据 +- **WHEN** 调用 `CanManageShop(ctx, targetShopID)` +- **AND** 用户类型为代理用户 +- **THEN** 从 Context 获取 `SubordinateShopIDs` +- **AND** 检查 `targetShopID` 是否在列表中 + +#### Scenario: CanManageShop 平台用户自动通过 +- **WHEN** 调用 `CanManageShop(ctx, targetShopID)` +- **AND** `SubordinateShopIDs` 为 nil +- **THEN** 返回成功(不受限制) + +#### Scenario: CanManageEnterprise 从 Context 获取数据 +- **WHEN** 调用 `CanManageEnterprise(ctx, targetEnterpriseID)` +- **AND** 用户类型为代理用户 +- **THEN** 从 Context 获取 `SubordinateShopIDs` +- **AND** 查询目标企业的 `owner_shop_id` +- **AND** 检查 `owner_shop_id` 是否在列表中 + +### Requirement: AuthConfig 扩展 + +系统 SHALL 扩展 `AuthConfig` 以支持传入 ShopStore。 + +#### Scenario: AuthConfig 包含 ShopStore +- **WHEN** 初始化 Auth 中间件 +- **THEN** `AuthConfig` 可选包含 `ShopStore ShopStoreInterface` +- **AND** 用于调用 `GetSubordinateShopIDs` + +#### Scenario: ShopStore 未配置时跳过预计算 +- **WHEN** `AuthConfig.ShopStore` 为 nil +- **THEN** 不预计算 `SubordinateShopIDs` +- **AND** 所有用户的 `SubordinateShopIDs` 为 nil diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index 7e6d4d4..660d7f9 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -4,16 +4,17 @@ import "time" // Fiber Locals 的上下文键 const ( - ContextKeyRequestID = "requestid" // 请求记录ID - ContextKeyStartTime = "start_time" // 请求开始时间 - ContextKeyUserID = "user_id" // 用户ID - ContextKeyUserType = "user_type" // 用户类型 - ContextKeyShopID = "shop_id" // 店铺ID - ContextKeyEnterpriseID = "enterprise_id" // 企业ID - ContextKeyCustomerID = "customer_id" // 个人客户ID - ContextKeyUserInfo = "user_info" // 完整的用户信息 - ContextKeyIP = "ip_address" // IP地址 - ContextKeyUserAgent = "user_agent" // User-Agent + ContextKeyRequestID = "requestid" // 请求记录ID + ContextKeyStartTime = "start_time" // 请求开始时间 + ContextKeyUserID = "user_id" // 用户ID + ContextKeyUserType = "user_type" // 用户类型 + ContextKeyShopID = "shop_id" // 店铺ID + ContextKeyEnterpriseID = "enterprise_id" // 企业ID + ContextKeyCustomerID = "customer_id" // 个人客户ID + ContextKeyUserInfo = "user_info" // 完整的用户信息 + ContextKeyIP = "ip_address" // IP地址 + ContextKeyUserAgent = "user_agent" // User-Agent + ContextKeySubordinateShopIDs = "subordinate_shop_ids" // 下级店铺ID列表(代理用户预计算) ) // 配置环境变量 diff --git a/pkg/gorm/callback.go b/pkg/gorm/callback.go index ed357b0..b022233 100644 --- a/pkg/gorm/callback.go +++ b/pkg/gorm/callback.go @@ -1,250 +1,12 @@ package gorm import ( - "context" "reflect" "github.com/break/junhong_cmp_fiber/pkg/constants" - "github.com/break/junhong_cmp_fiber/pkg/logger" - "github.com/break/junhong_cmp_fiber/pkg/middleware" - "go.uber.org/zap" "gorm.io/gorm" - "gorm.io/gorm/schema" ) -// contextKey 用于 context value 的 key 类型 -type contextKey string - -// SkipDataPermissionKey 跳过数据权限过滤的 context key -const SkipDataPermissionKey contextKey = "skip_data_permission" - -// SkipDataPermission 返回跳过数据权限过滤的 Context -// 用于需要查询所有数据的场景(如管理后台统计、系统任务等) -// -// 使用示例: -// -// ctx = gorm.SkipDataPermission(ctx) -// db.WithContext(ctx).Find(&accounts) -func SkipDataPermission(ctx context.Context) context.Context { - return context.WithValue(ctx, SkipDataPermissionKey, true) -} - -// ShopStoreInterface 店铺 Store 接口 -// 用于 Callback 获取下级店铺 ID,避免循环依赖 -type ShopStoreInterface interface { - GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error) -} - -// RegisterDataPermissionCallback 注册 GORM 数据权限过滤 Callback -// -// 自动化数据权限过滤规则: -// 1. 超级管理员跳过过滤,可以查看所有数据 -// 2. 平台用户跳过过滤,可以查看所有数据 -// 3. 代理用户只能查看自己店铺及下级店铺的数据(基于 shop_id 字段) -// 4. 企业用户只能查看自己企业的数据(基于 enterprise_id 字段) -// 5. 个人客户只能查看自己的数据(基于 creator 字段或 customer_id 字段) -// 6. 通过 SkipDataPermission(ctx) 可以绕过权限过滤 -// -// 软删除过滤规则: -// 1. 所有查询自动排除 deleted_at IS NOT NULL 的记录 -// 2. 使用 db.Unscoped() 可以查询包含已删除的记录 -// -// 注意: -// - Callback 根据表的字段自动选择过滤策略 -// - 必须在初始化 Store 之前注册 -// -// 参数: -// - db: GORM DB 实例 -// - shopStore: 店铺 Store,用于查询下级店铺 ID -// -// 返回: -// - error: 注册错误 -func RegisterDataPermissionCallback(db *gorm.DB, shopStore ShopStoreInterface) error { - // 注册查询前的 Callback - err := db.Callback().Query().Before("gorm:query").Register("data_permission:query", func(tx *gorm.DB) { - ctx := tx.Statement.Context - if ctx == nil { - return - } - - // 1. 检查是否跳过数据权限过滤 - if skip, ok := ctx.Value(SkipDataPermissionKey).(bool); ok && skip { - return - } - - // 2. 获取用户类型 - userType := middleware.GetUserTypeFromContext(ctx) - - // 3. 超级管理员和平台用户跳过过滤,可以查看所有数据 - if userType == constants.UserTypeSuperAdmin || userType == constants.UserTypePlatform { - return - } - - // 4. 获取当前用户信息 - userID := middleware.GetUserIDFromContext(ctx) - if userID == 0 { - // 未登录用户返回空结果 - logger.GetAppLogger().Warn("数据权限过滤:未获取到用户 ID") - tx.Where("1 = 0") - return - } - - shopID := middleware.GetShopIDFromContext(ctx) - - // 5. 根据用户类型和表结构应用不同的过滤规则 - schema := tx.Statement.Schema - if schema == nil { - return - } - - // 5.1 代理用户:基于店铺层级过滤 - if userType == constants.UserTypeAgent { - tableName := schema.Table - - // 特殊处理:授权记录表(通过企业归属过滤,不含下级店铺) - if tableName == "tb_enterprise_card_authorization" { - if shopID == 0 { - // 代理用户没有 shop_id,返回空结果 - tx.Where("1 = 0") - return - } - // 只能看到自己店铺下企业的授权记录(不包含下级店铺) - tx.Where("enterprise_id IN (SELECT id FROM tb_enterprise WHERE owner_shop_id = ? AND deleted_at IS NULL)", shopID) - return - } - - // 特殊处理:标签表和资源标签表(包含全局标签) - if tableName == "tb_tag" || tableName == "tb_resource_tag" { - if shopID == 0 { - // 没有 shop_id,只能看全局标签 - tx.Where("enterprise_id IS NULL AND shop_id IS NULL") - return - } - - // 查询该店铺及下级店铺的 ID - subordinateShopIDs, err := shopStore.GetSubordinateShopIDs(ctx, shopID) - if err != nil { - logger.GetAppLogger().Error("数据权限过滤:获取下级店铺 ID 失败", - zap.Uint("shop_id", shopID), - zap.Error(err)) - subordinateShopIDs = []uint{shopID} - } - - // 过滤:店铺标签(自己店铺及下级店铺)或全局标签 - tx.Where("shop_id IN ? OR (enterprise_id IS NULL AND shop_id IS NULL)", subordinateShopIDs) - return - } - - if !hasShopIDField(schema) { - // 表没有 shop_id 字段,无法过滤 - return - } - - if shopID == 0 { - // 代理用户没有 shop_id,只能看自己创建的数据 - if hasCreatorField(schema) { - tx.Where("creator = ?", userID) - } else { - tx.Where("1 = 0") - } - return - } - - // 查询该店铺及下级店铺的 ID - subordinateShopIDs, err := shopStore.GetSubordinateShopIDs(ctx, shopID) - if err != nil { - logger.GetAppLogger().Error("数据权限过滤:获取下级店铺 ID 失败", - zap.Uint("shop_id", shopID), - zap.Error(err)) - // 降级为只能看自己店铺的数据 - subordinateShopIDs = []uint{shopID} - } - - // 过滤:shop_id IN (自己店铺及下级店铺) - tx.Where("shop_id IN ?", subordinateShopIDs) - return - } - - // 5.2 企业用户:基于 enterprise_id 过滤 - if userType == constants.UserTypeEnterprise { - enterpriseID := middleware.GetEnterpriseIDFromContext(ctx) - tableName := schema.Table - - // 特殊处理:标签表和资源标签表(包含全局标签) - if tableName == "tb_tag" || tableName == "tb_resource_tag" { - if enterpriseID != 0 { - // 过滤:企业标签或全局标签 - tx.Where("enterprise_id = ? OR (enterprise_id IS NULL AND shop_id IS NULL)", enterpriseID) - } else { - // 没有 enterprise_id,只能看全局标签 - tx.Where("enterprise_id IS NULL AND shop_id IS NULL") - } - return - } - - if hasEnterpriseIDField(schema) { - if enterpriseID != 0 { - tx.Where("enterprise_id = ?", enterpriseID) - } else { - // 企业用户没有 enterprise_id,返回空结果 - tx.Where("1 = 0") - } - return - } - - // 如果表没有 enterprise_id 字段,但有 creator 字段,则只能看自己创建的数据 - if hasCreatorField(schema) { - tx.Where("creator = ?", userID) - return - } - - // 无法过滤,返回空结果 - tx.Where("1 = 0") - return - } - - // 5.3 个人客户:只能看自己的数据 - if userType == constants.UserTypePersonalCustomer { - customerID := middleware.GetCustomerIDFromContext(ctx) - tableName := schema.Table - - // 特殊处理:标签表和资源标签表(只能看全局标签) - if tableName == "tb_tag" || tableName == "tb_resource_tag" { - tx.Where("enterprise_id IS NULL AND shop_id IS NULL") - return - } - - // 优先使用 customer_id 字段 - if hasCustomerIDField(schema) { - if customerID != 0 { - tx.Where("customer_id = ?", customerID) - } else { - // 个人客户没有 customer_id,返回空结果 - tx.Where("1 = 0") - } - return - } - - // 降级为使用 creator 字段 - if hasCreatorField(schema) { - tx.Where("creator = ?", userID) - return - } - - // 无法过滤,返回空结果 - tx.Where("1 = 0") - return - } - - // 6. 默认:未知用户类型,返回空结果 - logger.GetAppLogger().Warn("数据权限过滤:未知用户类型", - zap.Uint("user_id", userID), - zap.Int("user_type", userType)) - tx.Where("1 = 0") - }) - return err -} - // RegisterSetCreatorUpdaterCallback 注册 GORM 创建数据时创建人更新人 Callback func RegisterSetCreatorUpdaterCallback(db *gorm.DB) error { err := db.Callback().Create().Before("gorm:create").Register("set_creator_updater", func(tx *gorm.DB) { @@ -296,48 +58,3 @@ func RegisterSetCreatorUpdaterCallback(db *gorm.DB) error { }) return err } - -// hasCreatorField 检查 Schema 是否包含 creator 字段 -func hasCreatorField(s *schema.Schema) bool { - if s == nil { - return false - } - _, ok := s.FieldsByDBName["creator"] - return ok -} - -// hasShopIDField 检查 Schema 是否包含 shop_id 字段 -func hasShopIDField(s *schema.Schema) bool { - if s == nil { - return false - } - _, ok := s.FieldsByDBName["shop_id"] - return ok -} - -// hasEnterpriseIDField 检查 Schema 是否包含 enterprise_id 字段 -func hasEnterpriseIDField(s *schema.Schema) bool { - if s == nil { - return false - } - _, ok := s.FieldsByDBName["enterprise_id"] - return ok -} - -// hasCustomerIDField 检查 Schema 是否包含 customer_id 字段 -func hasCustomerIDField(s *schema.Schema) bool { - if s == nil { - return false - } - _, ok := s.FieldsByDBName["customer_id"] - return ok -} - -// hasDeletedAtField 检查 Schema 是否包含 deleted_at 字段 -func hasDeletedAtField(s *schema.Schema) bool { - if s == nil { - return false - } - _, ok := s.FieldsByDBName["deleted_at"] - return ok -} diff --git a/pkg/middleware/auth.go b/pkg/middleware/auth.go index 73932f1..8becf75 100644 --- a/pkg/middleware/auth.go +++ b/pkg/middleware/auth.go @@ -5,16 +5,19 @@ import ( "github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/errors" + "github.com/break/junhong_cmp_fiber/pkg/logger" "github.com/gofiber/fiber/v2" + "go.uber.org/zap" ) // UserContextInfo 用户上下文信息 type UserContextInfo struct { - UserID uint - UserType int - ShopID uint - EnterpriseID uint - CustomerID uint + UserID uint + UserType int + ShopID uint + EnterpriseID uint + CustomerID uint + SubordinateShopIDs []uint // 代理用户的下级店铺ID列表,nil 表示不受数据权限限制 } // SetUserContext 将用户信息设置到 context 中 @@ -25,6 +28,10 @@ func SetUserContext(ctx context.Context, info *UserContextInfo) context.Context ctx = context.WithValue(ctx, constants.ContextKeyShopID, info.ShopID) ctx = context.WithValue(ctx, constants.ContextKeyEnterpriseID, info.EnterpriseID) ctx = context.WithValue(ctx, constants.ContextKeyCustomerID, info.CustomerID) + // SubordinateShopIDs: nil 表示不限制,空切片表示无权限 + if info.SubordinateShopIDs != nil { + ctx = context.WithValue(ctx, constants.ContextKeySubordinateShopIDs, info.SubordinateShopIDs) + } return ctx } @@ -134,12 +141,21 @@ func SetUserToFiberContext(c *fiber.Ctx, info *UserContextInfo) { c.Locals(constants.ContextKeyShopID, info.ShopID) c.Locals(constants.ContextKeyEnterpriseID, info.EnterpriseID) c.Locals(constants.ContextKeyCustomerID, info.CustomerID) + if info.SubordinateShopIDs != nil { + c.Locals(constants.ContextKeySubordinateShopIDs, info.SubordinateShopIDs) + } - // 设置到标准 context(用于 GORM 数据权限过滤) + // 设置到标准 context(用于数据权限过滤) ctx := SetUserContext(c.UserContext(), info) c.SetUserContext(ctx) } +// AuthShopStoreInterface 店铺存储接口 +// 用于 Auth 中间件获取下级店铺 ID,避免循环依赖 +type AuthShopStoreInterface interface { + GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error) +} + // AuthConfig Auth 中间件配置 type AuthConfig struct { // TokenExtractor 自定义 token 提取函数 @@ -153,6 +169,10 @@ type AuthConfig struct { // SkipPaths 跳过认证的路径列表 SkipPaths []string + + // ShopStore 店铺存储,用于预计算代理用户的下级店铺 ID + // 可选,不传则不预计算 SubordinateShopIDs + ShopStore AuthShopStoreInterface } // Auth 认证中间件 @@ -196,6 +216,21 @@ func Auth(config AuthConfig) fiber.Handler { return errors.Wrap(errors.CodeInvalidToken, err, "认证令牌无效") } + // 预计算代理用户的下级店铺 ID + if config.ShopStore != nil && + userInfo.UserType == constants.UserTypeAgent && + userInfo.ShopID > 0 { + shopIDs, err := config.ShopStore.GetSubordinateShopIDs(c.UserContext(), userInfo.ShopID) + if err != nil { + // 降级处理:只包含自己的店铺 ID + shopIDs = []uint{userInfo.ShopID} + logger.GetAppLogger().Warn("预计算下级店铺失败,降级为只包含自己", + zap.Uint("shop_id", userInfo.ShopID), + zap.Error(err)) + } + userInfo.SubordinateShopIDs = shopIDs + } + // 将用户信息设置到 context SetUserToFiberContext(c, userInfo) diff --git a/pkg/middleware/data_scope.go b/pkg/middleware/data_scope.go new file mode 100644 index 0000000..aad3d6b --- /dev/null +++ b/pkg/middleware/data_scope.go @@ -0,0 +1,91 @@ +package middleware + +import ( + "context" + + "github.com/break/junhong_cmp_fiber/pkg/constants" + "gorm.io/gorm" +) + +// GetSubordinateShopIDs 获取当前用户可管理的店铺ID列表 +// 返回 nil 表示不受数据权限限制(平台用户/超管) +// 返回 []uint 表示限制在这些店铺范围内(代理用户) +func GetSubordinateShopIDs(ctx context.Context) []uint { + if ctx == nil { + return nil + } + if ids, ok := ctx.Value(constants.ContextKeySubordinateShopIDs).([]uint); ok { + return ids + } + return nil +} + +// ApplyShopFilter 应用店铺数据权限过滤 +// 平台用户/超管:不添加条件(SubordinateShopIDs 为 nil) +// 代理用户:WHERE shop_id IN (subordinateShopIDs) +// 注意:NULL shop_id 的记录对代理用户不可见 +func ApplyShopFilter(ctx context.Context, query *gorm.DB) *gorm.DB { + shopIDs := GetSubordinateShopIDs(ctx) + if shopIDs == nil { + return query + } + return query.Where("shop_id IN ?", shopIDs) +} + +// ApplyEnterpriseFilter 应用企业数据权限过滤 +// 非企业用户:不添加条件 +// 企业用户:WHERE enterprise_id = ? +func ApplyEnterpriseFilter(ctx context.Context, query *gorm.DB) *gorm.DB { + userType := GetUserTypeFromContext(ctx) + if userType != constants.UserTypeEnterprise { + return query + } + enterpriseID := GetEnterpriseIDFromContext(ctx) + if enterpriseID == 0 { + // 企业用户但无企业ID,返回空结果 + return query.Where("1 = 0") + } + return query.Where("enterprise_id = ?", enterpriseID) +} + +// ApplyOwnerShopFilter 应用归属店铺数据权限过滤 +// 用于 Enterprise 等使用 owner_shop_id 字段的表 +// 平台用户/超管:不添加条件 +// 代理用户:WHERE owner_shop_id IN (subordinateShopIDs) +func ApplyOwnerShopFilter(ctx context.Context, query *gorm.DB) *gorm.DB { + shopIDs := GetSubordinateShopIDs(ctx) + if shopIDs == nil { + return query + } + return query.Where("owner_shop_id IN ?", shopIDs) +} + +// IsUnrestricted 检查当前用户是否不受数据权限限制 +// 平台用户/超管返回 true,代理/企业用户返回 false +func IsUnrestricted(ctx context.Context) bool { + return GetSubordinateShopIDs(ctx) == nil +} + +// ApplySellerShopFilter 应用销售店铺数据权限过滤 +// 用于 Order 等使用 seller_shop_id 字段的表 +// 平台用户/超管:不添加条件 +// 代理用户:WHERE seller_shop_id IN (subordinateShopIDs) +func ApplySellerShopFilter(ctx context.Context, query *gorm.DB) *gorm.DB { + shopIDs := GetSubordinateShopIDs(ctx) + if shopIDs == nil { + return query + } + return query.Where("seller_shop_id IN ?", shopIDs) +} + +// ApplyShopTagFilter 应用店铺标签数据权限过滤 +// 用于 CardWallet 等使用 shop_id_tag 字段的表 +// 平台用户/超管:不添加条件 +// 代理用户:WHERE shop_id_tag IN (subordinateShopIDs) +func ApplyShopTagFilter(ctx context.Context, query *gorm.DB) *gorm.DB { + shopIDs := GetSubordinateShopIDs(ctx) + if shopIDs == nil { + return query + } + return query.Where("shop_id_tag IN ?", shopIDs) +} diff --git a/pkg/middleware/permission_helper.go b/pkg/middleware/permission_helper.go index 1337ea3..d6cf232 100644 --- a/pkg/middleware/permission_helper.go +++ b/pkg/middleware/permission_helper.go @@ -2,20 +2,13 @@ package middleware import ( "context" + "slices" "github.com/break/junhong_cmp_fiber/internal/model" "github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/errors" ) -// ShopStoreInterface 店铺存储接口 -// 用于权限检查时查询店铺信息和下级店铺ID -type ShopStoreInterface interface { - GetByID(ctx context.Context, id uint) (*model.Shop, error) - GetByIDs(ctx context.Context, ids []uint) ([]*model.Shop, error) - GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error) -} - // EnterpriseStoreInterface 企业存储接口 // 用于权限检查时查询企业信息 type EnterpriseStoreInterface interface { @@ -23,91 +16,80 @@ type EnterpriseStoreInterface interface { GetByIDs(ctx context.Context, ids []uint) ([]*model.Enterprise, error) } -// CanManageShop 检查当前用户是否有权管理目标店铺的账号 -// 超级管理员和平台用户自动通过 -// 代理账号只能管理自己店铺及下级店铺的账号 -// 企业账号禁止管理店铺账号 -func CanManageShop(ctx context.Context, targetShopID uint, shopStore ShopStoreInterface) error { +// CanManageShop 检查当前用户是否有权管理目标店铺 +// 超级管理员和平台用户自动通过(SubordinateShopIDs 为 nil) +// 代理账号只能管理自己店铺及下级店铺 +// 企业账号禁止管理店铺 +func CanManageShop(ctx context.Context, targetShopID uint) error { userType := GetUserTypeFromContext(ctx) - // 超级管理员和平台用户跳过权限检查 - if userType == constants.UserTypeSuperAdmin || userType == constants.UserTypePlatform { + // 企业账号禁止管理店铺 + if userType == constants.UserTypeEnterprise { + return errors.New(errors.CodeForbidden, "无权限管理店铺") + } + + // 从 Context 获取预计算的下级店铺 ID 列表 + subordinateIDs := GetSubordinateShopIDs(ctx) + + // nil 表示不受限制(超级管理员/平台用户) + if subordinateIDs == nil { return nil } - // 企业账号禁止管理店铺账号 - if userType != constants.UserTypeAgent { - return errors.New(errors.CodeForbidden, "无权限管理店铺账号") - } - - // 获取当前代理账号的店铺ID - currentShopID := GetShopIDFromContext(ctx) - if currentShopID == 0 { - return errors.New(errors.CodeForbidden, "无权限管理店铺账号") - } - - // 递归查询下级店铺ID(包含自己) - subordinateIDs, err := shopStore.GetSubordinateShopIDs(ctx, currentShopID) - if err != nil { - return errors.Wrap(errors.CodeInternalError, err, "查询下级店铺失败") - } - // 检查目标店铺是否在下级列表中 - for _, id := range subordinateIDs { - if id == targetShopID { - return nil - } - } - - return errors.New(errors.CodeForbidden, "无权限管理该店铺的账号") -} - -// CanManageEnterprise 检查当前用户是否有权管理目标企业的账号 -// 超级管理员和平台用户自动通过 -// 代理账号只能管理归属于自己店铺或下级店铺的企业账号 -// 企业账号禁止管理其他企业账号 -func CanManageEnterprise(ctx context.Context, targetEnterpriseID uint, enterpriseStore EnterpriseStoreInterface, shopStore ShopStoreInterface) error { - userType := GetUserTypeFromContext(ctx) - - // 超级管理员和平台用户跳过权限检查 - if userType == constants.UserTypeSuperAdmin || userType == constants.UserTypePlatform { + if slices.Contains(subordinateIDs, targetShopID) { return nil } - // 企业账号禁止管理其他企业账号 - if userType != constants.UserTypeAgent { - return errors.New(errors.CodeForbidden, "无权限管理企业账号") + return errors.New(errors.CodeForbidden, "无权限管理该店铺") +} + +// CanManageEnterprise 检查当前用户是否有权管理目标企业 +// 超级管理员和平台用户自动通过(SubordinateShopIDs 为 nil) +// 代理账号只能管理归属于自己店铺或下级店铺的企业 +// 企业账号禁止管理其他企业 +func CanManageEnterprise(ctx context.Context, targetEnterpriseID uint, enterpriseStore EnterpriseStoreInterface) error { + userType := GetUserTypeFromContext(ctx) + + // 企业账号禁止管理其他企业 + if userType == constants.UserTypeEnterprise { + return errors.New(errors.CodeForbidden, "无权限管理企业") + } + + // 从 Context 获取预计算的下级店铺 ID 列表 + subordinateIDs := GetSubordinateShopIDs(ctx) + + // nil 表示不受限制(超级管理员/平台用户) + if subordinateIDs == nil { + return nil } // 获取目标企业信息 enterprise, err := enterpriseStore.GetByID(ctx, targetEnterpriseID) if err != nil { - return errors.Wrap(errors.CodeForbidden, err, "无权限操作该资源或资源不存在") + return errors.New(errors.CodeForbidden, "无权限操作该资源或资源不存在") } - // 代理账号不能管理平台级企业(owner_shop_id为NULL) + // 代理账号不能管理平台级企业(owner_shop_id 为 NULL) if enterprise.OwnerShopID == nil { - return errors.New(errors.CodeForbidden, "无权限管理平台级企业账号") - } - - // 获取当前代理账号的店铺ID - currentShopID := GetShopIDFromContext(ctx) - if currentShopID == 0 { - return errors.New(errors.CodeForbidden, "无权限管理企业账号") - } - - // 递归查询下级店铺ID(包含自己) - subordinateIDs, err := shopStore.GetSubordinateShopIDs(ctx, currentShopID) - if err != nil { - return errors.Wrap(errors.CodeInternalError, err, "查询下级店铺失败") + return errors.New(errors.CodeForbidden, "无权限管理平台级企业") } // 检查企业归属的店铺是否在下级列表中 - for _, id := range subordinateIDs { - if id == *enterprise.OwnerShopID { - return nil - } + if slices.Contains(subordinateIDs, *enterprise.OwnerShopID) { + return nil } - return errors.New(errors.CodeForbidden, "无权限管理该企业的账号") + return errors.New(errors.CodeForbidden, "无权限管理该企业") +} + +// ContainsShopID 检查目标店铺 ID 是否在当前用户可管理的店铺列表中 +// 平台用户/超管返回 true(不受限制) +// 代理用户检查是否在 SubordinateShopIDs 中 +func ContainsShopID(ctx context.Context, targetShopID uint) bool { + subordinateIDs := GetSubordinateShopIDs(ctx) + if subordinateIDs == nil { + return true // 不受限制 + } + return slices.Contains(subordinateIDs, targetShopID) }