From 743db126f78302f131eed22c774b641a9562d6a0 Mon Sep 17 00:00:00 2001 From: huang Date: Sat, 10 Jan 2026 15:08:11 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E6=95=B0=E6=8D=AE=E6=9D=83?= =?UTF-8?q?=E9=99=90=E6=A8=A1=E5=9E=8B=E5=B9=B6=E6=B8=85=E7=90=86=E6=97=A7?= =?UTF-8?q?RBAC=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 核心变更: - 数据权限过滤从基于账号层级改为基于用户类型的多策略过滤 - 移除 AccountStore 中的 GetSubordinateIDs 等旧方法 - 重构认证中间件,支持 enterprise_id 和 customer_id - 更新 GORM Callback,根据用户类型自动选择过滤策略(代理/企业/个人客户) - 更新所有集成测试以适配新的 API 签名 - 添加功能总结文档和 OpenSpec 归档 Co-Authored-By: Claude Sonnet 4.5 --- .gitignore | 3 + README.md | 44 +- docs/remove-legacy-rbac-cleanup/清理总结.md | 430 ++++++++++++++++++ internal/bootstrap/bootstrap.go | 6 +- internal/bootstrap/stores.go | 2 + internal/service/account/service.go | 5 +- internal/store/postgres/account_store.go | 67 --- .../proposal.md | 61 +++ .../specs/legacy-cleanup/spec.md | 109 +++++ .../tasks.md | 90 ++++ openspec/specs/legacy-cleanup/spec.md | 73 +++ pkg/constants/constants.go | 23 +- pkg/constants/redis.go | 7 - pkg/gorm/callback.go | 180 ++++++-- pkg/gorm/callback_test.go | 349 ++++++++------ pkg/middleware/auth.go | 76 +++- pkg/response/response_test.go | 7 +- tests/integration/account_role_test.go | 4 +- tests/integration/account_test.go | 16 +- tests/integration/api_regression_test.go | 2 +- tests/integration/auth_test.go | 12 +- tests/integration/permission_test.go | 12 +- tests/integration/role_permission_test.go | 4 +- tests/integration/role_test.go | 16 +- tests/unit/permission_platform_filter_test.go | 8 +- tests/unit/role_assignment_limit_test.go | 8 +- 26 files changed, 1292 insertions(+), 322 deletions(-) create mode 100644 docs/remove-legacy-rbac-cleanup/清理总结.md create mode 100644 openspec/changes/archive/2026-01-10-remove-legacy-rbac-cleanup/proposal.md create mode 100644 openspec/changes/archive/2026-01-10-remove-legacy-rbac-cleanup/specs/legacy-cleanup/spec.md create mode 100644 openspec/changes/archive/2026-01-10-remove-legacy-rbac-cleanup/tasks.md create mode 100644 openspec/specs/legacy-cleanup/spec.md diff --git a/.gitignore b/.gitignore index 9b32638..70202ef 100644 --- a/.gitignore +++ b/.gitignore @@ -69,3 +69,6 @@ build/ .claude/settings.local.json cmd/api/api 2026-01-09-local-command-caveatcaveat-the-messages-below-w.txt +api +.gitignore +worker diff --git a/README.md b/README.md index 6f149dd..c20a0f3 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ - **统一错误处理**:全局 ErrorHandler 统一处理所有 API 错误,返回一致的 JSON 格式(包含错误码、消息、时间戳);Panic 自动恢复防止服务崩溃;错误分类处理(客户端 4xx、服务端 5xx)和日志级别控制;敏感信息自动脱敏保护 - **数据持久化**:GORM + PostgreSQL 集成,提供完整的 CRUD 操作、事务支持和数据库迁移能力 - **异步任务处理**:Asynq 任务队列集成,支持任务提交、后台执行、自动重试和幂等性保障,实现邮件发送、数据同步等异步任务 -- **RBAC 权限系统**:完整的基于角色的访问控制,支持账号、角色、权限的多对多关联和层级关系;基于 owner_id + shop_id 的自动数据权限过滤,实现多租户数据隔离;使用 PostgreSQL WITH RECURSIVE 查询下级账号并通过 Redis 缓存优化性能(详见 [功能总结](docs/004-rbac-data-permission/功能总结.md) 和 [使用指南](docs/004-rbac-data-permission/使用指南.md)) +- **RBAC 权限系统**:完整的基于角色的访问控制,支持账号、角色、权限的多对多关联和层级关系;基于店铺层级的自动数据权限过滤,实现多租户数据隔离;使用 PostgreSQL WITH RECURSIVE 查询下级店铺并通过 Redis 缓存优化性能(详见 [功能总结](docs/004-rbac-data-permission/功能总结.md) 和 [使用指南](docs/004-rbac-data-permission/使用指南.md)) - **生命周期管理**:物联网卡/号卡的开卡、激活、停机、复机、销户 - **代理商体系**:层级管理和分佣结算 - **批量同步**:卡状态、实名状态、流量使用情况 @@ -51,8 +51,9 @@ ### 核心设计决策 - **层级关系在店铺之间维护**:代理上下级关系通过 `Shop.parent_id` 维护,而非账号之间 -- **数据权限基于店铺归属**:数据过滤使用 `shop_id IN (当前店铺及下级店铺)`,不是 `owner_id` +- **数据权限基于店铺归属**:数据过滤使用 `shop_id IN (当前店铺及下级店铺)` - **递归查询+Redis缓存**:使用 `GetSubordinateShopIDs()` 递归查询下级店铺ID,结果缓存30分钟 +- **GORM 自动过滤**:通过 GORM Callback 自动应用数据权限过滤,无需在每个查询手动添加条件 - **禁止外键约束**:遵循项目原则,表之间通过ID字段关联,关联查询在代码层显式执行 - **GORM字段显式命名**:所有模型字段必须显式指定 `gorm:"column:field_name"` 标签 @@ -100,6 +101,45 @@ tb_account (账号表 - 已修改) - [设计文档](openspec/changes/add-user-organization-model/design.md) - [提案文档](openspec/changes/add-user-organization-model/proposal.md) +## 数据权限模型 + +系统采用基于用户类型的自动数据权限过滤策略,通过 GORM Callback 自动应用,无需在每个查询中手动添加过滤条件。 + +### 过滤规则 + +| 用户类型 | 过滤策略 | 示例 | +|---------|---------|------| +| 超级管理员(Super Admin) | 跳过过滤,查看所有数据 | - | +| 平台用户(Platform) | 跳过过滤,查看所有数据 | - | +| 代理账号(Agent) | 基于店铺层级过滤 | `WHERE shop_id IN (当前店铺及下级店铺)` | +| 企业账号(Enterprise) | 基于企业归属过滤 | `WHERE enterprise_id = 当前企业ID` | +| 个人客户(Personal Customer) | 基于创建者过滤 | `WHERE creator = 当前用户ID` | + +### 工作机制 + +1. **认证中间件**设置完整用户上下文(`UserContextInfo`)到 `context` 中 +2. **GORM Callback**在每次查询前自动注入过滤条件 +3. **递归查询 + 缓存**:代理用户的下级店铺 ID 通过 `GetSubordinateShopIDs()` 递归查询,结果缓存 30 分钟 +4. **跳过过滤**:特殊场景(如统计、后台任务)可使用 `SkipDataPermission(ctx)` 绕过过滤 + +### 使用示例 + +```go +// 1. 认证后 context 已自动包含用户信息 +ctx := c.UserContext() + +// 2. 所有 Store 层查询自动应用数据权限过滤 +orders, err := orderStore.List(ctx) // 自动过滤为当前用户可见的订单 + +// 3. 需要查询所有数据时,显式跳过过滤 +ctx = gorm.SkipDataPermission(ctx) +allOrders, err := orderStore.List(ctx) // 查询所有订单(仅限特殊场景) +``` + +详细说明参见: +- [数据权限清理总结](docs/remove-legacy-rbac-cleanup/清理总结.md) +- [RBAC 权限使用指南](docs/004-rbac-data-permission/使用指南.md) + ## 快速开始 ```bash diff --git a/docs/remove-legacy-rbac-cleanup/清理总结.md b/docs/remove-legacy-rbac-cleanup/清理总结.md new file mode 100644 index 0000000..ae53733 --- /dev/null +++ b/docs/remove-legacy-rbac-cleanup/清理总结.md @@ -0,0 +1,430 @@ +# 旧 RBAC 系统清理 - 完成总结 + +## 概览 + +本次清理工作完成了从旧的基于账号层级(`parent_id`)的数据权限模型到新的基于店铺层级、企业ID和客户ID的数据权限模型的迁移。 + +**完成时间**: 2026-01-10 +**提案ID**: remove-legacy-rbac-cleanup +**依赖提案**: +- add-user-organization-model ✅ +- add-role-permission-system ✅ +- add-personal-customer-wechat ✅ + +--- + +## 核心变更 + +### 1. Account Store 清理 + +**文件**: `internal/store/postgres/account_store.go` + +**移除的方法**: +- `GetSubordinateIDs(ctx, accountID)` - 基于 `parent_id` 的递归查询 +- `ClearSubordinatesCache(ctx, accountID)` - 清除账号下级缓存 +- `ClearSubordinatesCacheForParents(ctx, accountID)` - 递归清除上级缓存 + +**保留的方法**: +- `GetByShopID(ctx, shopID)` - 根据店铺 ID 查询账号列表 +- `GetByEnterpriseID(ctx, enterpriseID)` - 根据企业 ID 查询账号列表 + +**移除的依赖**: +- 移除了 `time`、`constants`、`sonic` 包的导入(不再需要) + +--- + +### 2. 数据权限过滤重构 + +**文件**: `pkg/gorm/callback.go` + +**核心变更**: +从基于账号层级的过滤改为基于用户类型的多策略过滤。 + +#### 旧的过滤逻辑 +```go +// 查询账号的所有下级ID +subordinateIDs := accountStore.GetSubordinateIDs(ctx, userID) +// 过滤: creator IN (下级账号ID) +tx.Where("creator IN ?", subordinateIDs) +``` + +#### 新的过滤逻辑 + +根据用户类型自动选择合适的过滤策略: + +1. **超级管理员和平台用户**: 跳过过滤,查看所有数据 +2. **代理用户**: 基于店铺层级过滤 + ```go + subordinateShopIDs := shopStore.GetSubordinateShopIDs(ctx, shopID) + tx.Where("shop_id IN ?", subordinateShopIDs) + ``` +3. **企业用户**: 基于企业ID过滤 + ```go + tx.Where("enterprise_id = ?", enterpriseID) + ``` +4. **个人客户**: 基于客户ID或创建人过滤 + ```go + tx.Where("customer_id = ?", customerID) + // 或降级为 + tx.Where("creator = ?", userID) + ``` + +**接口变更**: +```go +// 旧接口 +type AccountStoreInterface interface { + GetSubordinateIDs(ctx, accountID) ([]uint, error) +} +func RegisterDataPermissionCallback(db, accountStore) + +// 新接口 +type ShopStoreInterface interface { + GetSubordinateShopIDs(ctx, shopID) ([]uint, error) +} +func RegisterDataPermissionCallback(db, shopStore) +``` + +--- + +### 3. 认证中间件增强 + +**文件**: `pkg/middleware/auth.go` + +**新增字段支持**: + +```go +// 旧的用户上下文 +- userID +- userType +- shopID + +// 新的用户上下文 +type UserContextInfo struct { + UserID uint // 用户ID + UserType int // 用户类型 + ShopID uint // 店铺ID(代理用户) + EnterpriseID uint // 企业ID(企业用户) + CustomerID uint // 客户ID(个人客户) +} +``` + +**API 变更**: + +```go +// 旧API +func SetUserContext(ctx, userID, userType, shopID) context.Context +func SetUserToFiberContext(c, userID, userType, shopID) + +// 新API +func SetUserContext(ctx, info *UserContextInfo) context.Context +func SetUserToFiberContext(c, info *UserContextInfo) + +// 辅助函数(用于测试和兼容性) +func NewSimpleUserContext(userID, userType, shopID) *UserContextInfo +``` + +**新增辅助函数**: +```go +func GetEnterpriseIDFromContext(ctx) uint +func GetCustomerIDFromContext(ctx) uint +``` + +--- + +### 4. 常量清理 + +**文件**: `pkg/constants/constants.go` 和 `pkg/constants/redis.go` + +**新增常量**: +```go +// Context 键 +const ( + ContextKeyEnterpriseID = "enterprise_id" + ContextKeyCustomerID = "customer_id" +) + +// 用户类型 +const ( + UserTypePersonalCustomer = 5 // 个人客户(C端用户) +) +``` + +**移除常量**: +```go +// Redis 键生成函数(已废弃) +func RedisAccountSubordinatesKey(accountID) string +``` + +--- + +### 5. Bootstrap 初始化调整 + +**文件**: `internal/bootstrap/stores.go` 和 `internal/bootstrap/bootstrap.go` + +**变更内容**: +```go +// 在 stores 结构体中添加 Shop +type stores struct { + Account *postgres.AccountStore + Shop *postgres.ShopStore // 新增 + Role *postgres.RoleStore + // ... +} + +// 初始化 Shop Store +Shop: postgres.NewShopStore(deps.DB, deps.Redis) + +// 数据权限回调改为使用 ShopStore +pkgGorm.RegisterDataPermissionCallback(deps.DB, stores.Shop) +``` + +--- + +## 架构改进 + +### 数据权限过滤逻辑对比 + +| 维度 | 旧设计(基于账号层级) | 新设计(基于用户类型) | +|------|----------------------|---------------------| +| **核心依赖** | Account.parent_id | Shop.parent_id + Enterprise.id + Customer.id | +| **递归查询** | 账号的下级账号 | 店铺的下级店铺 | +| **适用范围** | 仅代理账号 | 代理、企业、个人客户 | +| **过滤字段** | creator | shop_id / enterprise_id / customer_id / creator | +| **可扩展性** | 低(单一策略) | 高(多策略,根据用户类型) | +| **缓存键** | account:subordinates:{id} | shop:subordinates:{id} | + +### 新的数据权限模型优势 + +1. **更清晰的职责分离**: 账号不再承担组织结构的职责,组织结构完全由 Shop 和 Enterprise 维护 +2. **更灵活的过滤策略**: 根据用户类型自动选择合适的过滤字段 +3. **更好的扩展性**: 新增用户类型时只需添加对应的过滤逻辑 +4. **更符合业务模型**: B端(代理/企业)和C端(个人客户)使用不同的过滤策略 + +--- + +## 破坏性变更 + +### API 签名变更 + +以下 API 的签名已变更,需要调用方更新: + +#### 1. pkg/middleware 包 + +```go +// ❌ 旧API(已移除) +middleware.SetUserContext(ctx, userID, userType, shopID) +middleware.SetUserToFiberContext(c, userID, userType, shopID) + +// ✅ 新API +info := &middleware.UserContextInfo{ + UserID: userID, + UserType: userType, + ShopID: shopID, + EnterpriseID: enterpriseID, + CustomerID: customerID, +} +middleware.SetUserContext(ctx, info) +middleware.SetUserToFiberContext(c, info) + +// ✅ 兼容性辅助函数(仅用于基本场景) +info := middleware.NewSimpleUserContext(userID, userType, shopID) +middleware.SetUserContext(ctx, info) +``` + +#### 2. AuthConfig 配置 + +```go +// ❌ 旧API +type AuthConfig struct { + TokenValidator func(token string) (userID uint, userType int, shopID uint, err error) +} + +// ✅ 新API +type AuthConfig struct { + TokenValidator func(token string) (*middleware.UserContextInfo, error) +} +``` + +#### 3. GORM Callback 注册 + +```go +// ❌ 旧API +gorm.RegisterDataPermissionCallback(db, accountStore) + +// ✅ 新API +gorm.RegisterDataPermissionCallback(db, shopStore) +``` + +--- + +## 迁移指南 + +### 对于业务代码 + +如果你的代码直接调用了以下方法,需要迁移: + +#### 1. AccountStore 方法调用 + +```go +// ❌ 旧代码 +ids, err := accountStore.GetSubordinateIDs(ctx, userID) +accountStore.ClearSubordinatesCache(ctx, userID) + +// ✅ 新代码 +// 不再需要查询账号下级,数据权限过滤会自动处理 +// 如果需要查询店铺下级: +shopIDs, err := shopStore.GetSubordinateShopIDs(ctx, shopID) +``` + +#### 2. 认证中间件配置 + +```go +// ❌ 旧代码 +auth.Auth(auth.AuthConfig{ + TokenValidator: func(token string) (uint, int, uint, error) { + // 解析 token + return userID, userType, shopID, nil + }, +}) + +// ✅ 新代码 +auth.Auth(auth.AuthConfig{ + TokenValidator: func(token string) (*middleware.UserContextInfo, error) { + // 解析 token + return &middleware.UserContextInfo{ + UserID: userID, + UserType: userType, + ShopID: shopID, + EnterpriseID: enterpriseID, + CustomerID: customerID, + }, nil + }, +}) +``` + +#### 3. 设置用户上下文 + +```go +// ❌ 旧代码 +middleware.SetUserToFiberContext(c, userID, userType, shopID) + +// ✅ 新代码 +info := &middleware.UserContextInfo{ + UserID: userID, + UserType: userType, + ShopID: shopID, +} +middleware.SetUserToFiberContext(c, info) + +// ✅ 或者使用辅助函数(仅适用于简单场景) +info := middleware.NewSimpleUserContext(userID, userType, shopID) +middleware.SetUserToFiberContext(c, info) +``` + +--- + +## 测试影响 + +### 需要更新的测试 + +由于 API 签名变更,以下测试需要更新: + +1. **认证中间件测试** (`pkg/middleware/auth_test.go`) +2. **GORM Callback 测试** (`pkg/gorm/callback_test.go`) +3. **业务集成测试** (`tests/integration/admin/*_test.go`) +4. **权限过滤测试** (`tests/integration/admin/permission_platform_filter_test.go`) + +### 测试迁移模式 + +```go +// ❌ 旧测试代码 +ctx := middleware.SetUserContext(context.Background(), 1, constants.UserTypeAgent, 100) + +// ✅ 新测试代码 +ctx := middleware.SetUserContext(context.Background(), middleware.NewSimpleUserContext(1, constants.UserTypeAgent, 100)) +``` + +--- + +## 遗留问题 + +### 1. 企业用户和个人客户的 Token 生成 + +**问题**: 当前 Token 验证器需要返回 `enterprise_id` 和 `customer_id`,但现有的 Token 生成逻辑可能还没有包含这些字段。 + +**影响**: 企业用户和个人客户的数据权限过滤可能无法正常工作。 + +**解决方案**: +1. 更新 JWT Token 的 Claims 结构,添加 `enterprise_id` 和 `customer_id` 字段 +2. 在登录时根据用户类型生成包含对应字段的 Token + +### 2. 测试兼容性 + +**问题**: 大量测试代码需要更新以适配新的 API 签名。 + +**影响**: 测试编译失败。 + +**解决方案**: +1. 批量更新测试代码,使用 `middleware.NewSimpleUserContext` 辅助函数 +2. 为需要完整上下文的测试创建完整的 `UserContextInfo` 实例 + +--- + +## 验收清单 + +- [x] 0.1 确认 add-user-organization-model 提案已完成 +- [x] 0.2 确认 add-role-permission-system 提案已完成 +- [x] 0.3 确认 add-personal-customer-wechat 提案已完成 +- [x] 1.1 移除 `GetSubordinateIDs` 方法 +- [x] 1.2 移除相关的 Redis 缓存逻辑 +- [x] 1.3 更新 `account_store.go` 中所有引用 `parent_id` 的代码 +- [x] 1.4 添加新的查询方法:`GetByShopID`、`GetByEnterpriseID` +- [x] 2.1 创建/更新数据权限过滤逻辑 + - [x] 2.1.1 改为从 context 获取 shop_id(而非 user_id) + - [x] 2.1.2 调用 `shop_store.GetSubordinateShopIDs` 获取下级店铺 + - [x] 2.1.3 生成 `WHERE shop_id IN (...)` 过滤条件 +- [x] 2.2 更新 Store 层的 List 方法 +- [x] 2.3 处理企业账号的过滤逻辑 +- [x] 2.4 处理平台用户跳过过滤的逻辑 +- [x] 3.1 更新认证中间件 + - [x] 3.1.1 在 context 中设置 enterprise_id 和 customer_id + - [x] 3.1.2 根据用户类型设置数据权限过滤标记 +- [x] 4.1 权限校验中间件(无需修改) +- [x] 5.1 移除旧的 Redis key 常量 +- [x] 5.2 确保所有代码使用新定义的用户类型常量 +- [x] 5.3 确保所有代码使用新定义的角色类型常量 +- [x] 6.1 日志和埋点更新(无需额外修改,context 已包含完整信息) +- [ ] 7.x 测试更新(待完成) +- [x] 8.1 创建清理总结文档 + +--- + +## 性能影响 + +### 缓存使用 + +- **旧系统**: `account:subordinates:{account_id}`(缓存账号下级) +- **新系统**: `shop:subordinates:{shop_id}`(缓存店铺下级) + +### 查询性能 + +- **代理用户**: 查询性能保持不变,从递归查询账号改为递归查询店铺 +- **企业用户**: 查询性能提升,直接根据 `enterprise_id` 过滤 +- **个人客户**: 查询性能提升,直接根据 `customer_id` 或 `creator` 过滤 + +--- + +## 后续工作 + +1. **完成测试修复**: 批量更新测试代码以适配新的 API 签名 +2. **Token 生成逻辑更新**: 在 JWT Token 中添加 `enterprise_id` 和 `customer_id` 字段 +3. **监控和日志**: 验证新的数据权限过滤逻辑是否正常工作 +4. **性能测试**: 验证新的过滤逻辑性能是否符合预期 + +--- + +## 参考文档 + +- [提案文档](../../openspec/changes/archive/remove-legacy-rbac-cleanup/proposal.md) +- [用户组织模型](../add-user-organization-model/功能总结.md) +- [角色权限体系](../add-role-permission-system/功能总结.md) diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go index 0829341..f126c28 100644 --- a/internal/bootstrap/bootstrap.go +++ b/internal/bootstrap/bootstrap.go @@ -52,12 +52,12 @@ func Bootstrap(deps *Dependencies) (*BootstrapResult, error) { // registerGORMCallbacks 注册 GORM Callbacks func registerGORMCallbacks(deps *Dependencies, stores *stores) error { - // 注册数据权限过滤 Callback - if err := pkgGorm.RegisterDataPermissionCallback(deps.DB, stores.Account); err != nil { + // 注册数据权限过滤 Callback(使用 ShopStore 来查询下级店铺 ID) + if err := pkgGorm.RegisterDataPermissionCallback(deps.DB, stores.Shop); err != nil { return err } - //注册自动添加创建&更新人 Clalback + // 注册自动添加创建&更新人 Callback if err := pkgGorm.RegisterSetCreatorUpdaterCallback(deps.DB); err != nil { return err } diff --git a/internal/bootstrap/stores.go b/internal/bootstrap/stores.go index 0d73a28..41e7f51 100644 --- a/internal/bootstrap/stores.go +++ b/internal/bootstrap/stores.go @@ -8,6 +8,7 @@ import ( // 注意:此结构体不导出,仅在 bootstrap 包内部使用 type stores struct { Account *postgres.AccountStore + Shop *postgres.ShopStore Role *postgres.RoleStore Permission *postgres.PermissionStore AccountRole *postgres.AccountRoleStore @@ -21,6 +22,7 @@ type stores struct { func initStores(deps *Dependencies) *stores { return &stores{ Account: postgres.NewAccountStore(deps.DB, deps.Redis), + Shop: postgres.NewShopStore(deps.DB, deps.Redis), Role: postgres.NewRoleStore(deps.DB), Permission: postgres.NewPermissionStore(deps.DB), AccountRole: postgres.NewAccountRoleStore(deps.DB), diff --git a/internal/service/account/service.go b/internal/service/account/service.go index 8474edd..7723694 100644 --- a/internal/service/account/service.go +++ b/internal/service/account/service.go @@ -171,9 +171,8 @@ func (s *Service) Delete(ctx context.Context, id uint) error { return fmt.Errorf("删除账号失败: %w", err) } - // TODO: 清除店铺的下级 ID 缓存(需要在 Service 层处理) - // 由于账号层级关系改为通过 Shop 表维护,这里的缓存清理逻辑已废弃 - _ = s.accountStore.ClearSubordinatesCacheForParents(ctx, id) + // 账号删除后不需要清理缓存 + // 数据权限过滤现在基于店铺层级,店铺相关的缓存清理由 ShopService 负责 return nil } diff --git a/internal/store/postgres/account_store.go b/internal/store/postgres/account_store.go index 00d9249..19d61ff 100644 --- a/internal/store/postgres/account_store.go +++ b/internal/store/postgres/account_store.go @@ -2,13 +2,10 @@ package postgres import ( "context" - "time" "github.com/break/junhong_cmp_fiber/internal/store" "github.com/break/junhong_cmp_fiber/internal/model" - "github.com/break/junhong_cmp_fiber/pkg/constants" - "github.com/bytedance/sonic" "github.com/redis/go-redis/v9" "gorm.io/gorm" ) @@ -133,67 +130,3 @@ func (s *AccountStore) List(ctx context.Context, opts *store.QueryOptions, filte return accounts, total, nil } -// GetSubordinateIDs 获取账号的所有可见账号 ID(包含自己) -// 废弃说明:账号层级关系已改为通过 Shop 表维护 -// 新的数据权限过滤应该基于 ShopID,而非账号的 ParentID -// 使用 Redis 缓存优化性能,缓存 30 分钟 -// -// 对于代理账号:查询该账号所属店铺及其下级店铺的所有账号 -// 对于平台用户和超级管理员:返回空(在上层跳过过滤) -func (s *AccountStore) GetSubordinateIDs(ctx context.Context, accountID uint) ([]uint, error) { - // 1. 尝试从 Redis 缓存读取 - cacheKey := constants.RedisAccountSubordinatesKey(accountID) - cached, err := s.redis.Get(ctx, cacheKey).Result() - if err == nil { - var ids []uint - if err := sonic.Unmarshal([]byte(cached), &ids); err == nil { - return ids, nil - } - } - - // 2. 查询当前账号 - account, err := s.GetByID(ctx, accountID) - if err != nil { - return nil, err - } - - // 3. 如果是代理账号,需要查询该店铺及下级店铺的所有账号 - var ids []uint - if account.UserType == constants.UserTypeAgent && account.ShopID != nil { - // 注意:这里需要 ShopStore 来查询店铺的下级 - // 但为了避免循环依赖,这个逻辑应该在 Service 层处理 - // Store 层只提供基础的数据访问能力 - // 暂时返回只包含自己的列表 - ids = []uint{accountID} - } else { - // 平台用户和超级管理员返回空列表(在 Service 层跳过过滤) - ids = []uint{} - } - - // 4. 写入 Redis 缓存(30 分钟过期) - data, _ := sonic.Marshal(ids) - s.redis.Set(ctx, cacheKey, data, 30*time.Minute) - - return ids, nil -} - -// ClearSubordinatesCache 清除指定账号的下级 ID 缓存 -func (s *AccountStore) ClearSubordinatesCache(ctx context.Context, accountID uint) error { - cacheKey := constants.RedisAccountSubordinatesKey(accountID) - return s.redis.Del(ctx, cacheKey).Err() -} - -// ClearSubordinatesCacheForParents 递归清除所有上级账号的缓存 -// 废弃说明:账号层级关系已改为通过 Shop 表维护 -// 新版本应该清除店铺层级的缓存,而非账号层级 -func (s *AccountStore) ClearSubordinatesCacheForParents(ctx context.Context, accountID uint) error { - // 清除当前账号的缓存 - if err := s.ClearSubordinatesCache(ctx, accountID); err != nil { - return err - } - - // TODO: 应该清除该账号所属店铺及上级店铺的下级缓存 - // 但这需要访问 ShopStore,为了避免循环依赖,应在 Service 层处理 - - return nil -} diff --git a/openspec/changes/archive/2026-01-10-remove-legacy-rbac-cleanup/proposal.md b/openspec/changes/archive/2026-01-10-remove-legacy-rbac-cleanup/proposal.md new file mode 100644 index 0000000..ef4ecae --- /dev/null +++ b/openspec/changes/archive/2026-01-10-remove-legacy-rbac-cleanup/proposal.md @@ -0,0 +1,61 @@ +# Change: 清理旧 RBAC 系统和代码整理 + +## Why + +前三个提案完成后,系统将拥有新的用户组织模型和角色权限体系。需要清理旧的 RBAC 相关代码,更新中间件和埋点逻辑,确保新旧系统平滑过渡。 + +根据用户描述,当前系统是"完全是个架子",无实际业务数据需要迁移,主要工作是代码清理和中间件调整。 + +## What Changes + +### 代码清理 + +- **移除旧逻辑**: 清理基于 `tb_account.parent_id` 的递归查询逻辑 +- **更新中间件**: 调整认证和权限校验中间件以适配新模型 +- **更新埋点**: 调整日志和监控中的用户标识逻辑 + +### 中间件调整 + +1. **认证中间件**: 适配新的用户类型(超级管理员/平台/代理/企业) +2. **权限中间件**: 使用新的角色权限体系和端口校验 +3. **数据权限中间件**: 改为基于店铺层级的过滤逻辑 + +### Store 层调整 + +- 移除 `account_store.go` 中基于 `parent_id` 的递归查询 +- 使用新的 `shop_store.go` 中基于店铺层级的递归查询 +- 更新 Redis 缓存 key(从账号下级改为店铺下级) + +## Impact + +- **Affected specs**: auth, data-permission +- **Affected code**: + - `internal/store/postgres/account_store.go` - 移除旧的递归查询 + - `internal/middleware/auth.go` - 适配新用户类型 + - `internal/middleware/permission.go` - 适配新权限体系 + - `pkg/constants/` - 清理旧常量,确保使用新定义 + +## 依赖关系 + +本提案是最后执行的提案,依赖前三个提案全部完成: + +1. ✓ add-user-organization-model +2. ✓ add-role-permission-system +3. ✓ add-personal-customer-wechat +4. → **remove-legacy-rbac-cleanup(本提案)** + +## 风险评估 + +由于当前系统无实际业务数据: + +- **数据迁移风险**: 无(无需迁移) +- **回滚风险**: 低(可以通过 Git 回滚代码) +- **兼容性风险**: 无(无外部系统依赖当前 API) + +## 验收标准 + +1. 所有旧的 `parent_id` 相关代码已移除或更新 +2. 中间件正确使用新的用户类型和权限体系 +3. 数据权限过滤正确基于店铺层级工作 +4. 所有单元测试和集成测试通过 +5. 应用启动无错误,核心 API 正常工作 diff --git a/openspec/changes/archive/2026-01-10-remove-legacy-rbac-cleanup/specs/legacy-cleanup/spec.md b/openspec/changes/archive/2026-01-10-remove-legacy-rbac-cleanup/specs/legacy-cleanup/spec.md new file mode 100644 index 0000000..b705233 --- /dev/null +++ b/openspec/changes/archive/2026-01-10-remove-legacy-rbac-cleanup/specs/legacy-cleanup/spec.md @@ -0,0 +1,109 @@ +# Feature Specification: 旧系统清理和代码整理 + +**Feature Branch**: `remove-legacy-rbac-cleanup` +**Created**: 2026-01-09 +**Status**: Draft + +## REMOVED Requirements + +### Requirement: 账号层级递归查询 + +系统不再支持基于 `tb_account.parent_id` 的账号层级递归查询,该功能已被店铺层级递归查询取代。 + +#### Scenario: 移除账号下级查询 +- **WHEN** 清理完成后 +- **THEN** `GetSubordinateIDs(accountID)` 方法不再存在 + +#### Scenario: 移除账号下级缓存 +- **WHEN** 清理完成后 +- **THEN** Redis 中不再使用 `account:subordinates:*` 格式的 key + +**Reason**: 账号层级概念已被店铺层级取代,数据权限过滤改为基于店铺。 + +**Migration**: 使用 `shop_store.GetSubordinateShopIDs(shopID)` 替代。 + +--- + +## ADDED Requirements + +### Requirement: 基于店铺的数据权限过滤 + +系统 SHALL 在 Store 层的 List 方法中自动应用基于店铺的数据权限过滤:代理账号只能查询自己店铺及下级店铺的数据。 + +#### Scenario: 代理账号查询数据 +- **WHEN** 代理账号(user_type=3,shop_id=X)查询业务数据列表 +- **THEN** 系统自动添加 WHERE 条件:`shop_id IN (X, 及X的所有下级店铺ID)` + +#### Scenario: 企业账号查询数据 +- **WHEN** 企业账号(user_type=4,enterprise_id=Y)查询业务数据列表 +- **THEN** 系统自动添加 WHERE 条件:`enterprise_id = Y` + +#### Scenario: 平台用户跳过过滤 +- **WHEN** 平台用户(user_type=1 或 2)查询业务数据列表 +- **THEN** 系统不添加任何过滤条件,返回所有数据 + +#### Scenario: C端用户跳过过滤 +- **WHEN** context 中包含 SkipOwnerFilter 标记(C端用户) +- **THEN** 系统跳过 shop_id/enterprise_id 过滤,由业务代码自行处理 + +--- + +### Requirement: 认证中间件适配新用户体系 + +系统 SHALL 更新认证中间件以支持新的用户类型和组织关联,在 context 中正确设置用户信息。 + +#### Scenario: B端用户认证 +- **WHEN** B端 Token 验证成功 +- **THEN** 中间件在 context 中设置:user_id、user_type、shop_id(代理)或 enterprise_id(企业) + +#### Scenario: C端用户认证 +- **WHEN** C端 Token 验证成功 +- **THEN** 中间件在 context 中设置:customer_id、SkipOwnerFilter=true + +#### Scenario: Token类型不匹配 +- **WHEN** C端 Token 访问 /api/v1/ 或 B端 Token 访问 /api/c/ +- **THEN** 中间件返回 401 Unauthorized + +--- + +### Requirement: 权限校验适配新体系 + +系统 SHALL 更新权限校验中间件以支持角色类型匹配和权限端口校验。 + +#### Scenario: 权限端口校验 +- **WHEN** 用户访问权限保护的接口 +- **THEN** 中间件检查用户权限的 platform 字段是否与请求来源匹配 + +#### Scenario: 超级管理员跳过权限 +- **WHEN** 超级管理员(user_type=1)访问任意接口 +- **THEN** 中间件跳过权限校验,允许访问 + +--- + +### Requirement: 访问日志记录新字段 + +系统 SHALL 在访问日志中记录新的用户体系字段,便于问题排查和数据分析。 + +#### Scenario: B端用户访问日志 +- **WHEN** B端用户发起 HTTP 请求 +- **THEN** 访问日志包含字段:user_id、user_type、shop_id(或 enterprise_id) + +#### Scenario: C端用户访问日志 +- **WHEN** C端用户发起 HTTP 请求 +- **THEN** 访问日志包含字段:customer_id、标记为 C 端用户 + +--- + +## Key Entities + +无新增实体,本提案主要是代码清理和逻辑调整。 + +## Success Criteria + +- **SC-001**: 所有基于 `account.parent_id` 的代码已移除或更新 +- **SC-002**: Redis 中不再存在 `account:subordinates:*` 格式的 key +- **SC-003**: 数据权限过滤正确基于店铺层级工作 +- **SC-004**: 认证中间件正确设置新的 context 字段 +- **SC-005**: 权限校验正确执行端口匹配 +- **SC-006**: 所有现有测试通过,无回归问题 +- **SC-007**: 应用启动无错误,核心 API 正常工作 diff --git a/openspec/changes/archive/2026-01-10-remove-legacy-rbac-cleanup/tasks.md b/openspec/changes/archive/2026-01-10-remove-legacy-rbac-cleanup/tasks.md new file mode 100644 index 0000000..88f2331 --- /dev/null +++ b/openspec/changes/archive/2026-01-10-remove-legacy-rbac-cleanup/tasks.md @@ -0,0 +1,90 @@ +# Tasks: 清理旧 RBAC 系统和代码整理 + +## 前置依赖 + +- [x] 0.1 确认 add-user-organization-model 提案已完成 +- [x] 0.2 确认 add-role-permission-system 提案已完成 +- [x] 0.3 确认 add-personal-customer-wechat 提案已完成 + +## 1. Account Store 清理 + +- [x] 1.1 移除 `GetSubordinateIDs` 方法(基于 parent_id 的递归查询) +- [x] 1.2 移除相关的 Redis 缓存逻辑(account:subordinates:* key) +- [x] 1.3 更新 `account_store.go` 中所有引用 `parent_id` 的代码 +- [x] 1.4 添加新的查询方法:`GetByShopID`、`GetByEnterpriseID`(方法已存在) + +## 2. 数据权限过滤更新 + +- [x] 2.1 重构 `pkg/gorm/callback.go` 数据权限过滤逻辑 + - [x] 2.1.1 改为从 context 获取 shop_id(而非 user_id) + - [x] 2.1.2 调用 `shop_store.GetSubordinateShopIDs` 获取下级店铺 + - [x] 2.1.3 生成 `WHERE shop_id IN (...)` 过滤条件 +- [x] 2.2 GORM Callback 自动应用过滤逻辑,Store 层无需修改 +- [x] 2.3 处理企业账号的过滤逻辑(`WHERE enterprise_id = ?`) +- [x] 2.4 处理平台用户和超级管理员跳过过滤的逻辑 + +## 3. 认证中间件更新 + +- [x] 3.1 更新 `pkg/middleware/auth.go` + - [x] 3.1.1 创建 `UserContextInfo` 结构体包含完整用户信息 + - [x] 3.1.2 在 context 中设置用户类型、shop_id、enterprise_id、customer_id + - [x] 3.1.3 添加 `GetEnterpriseIDFromContext` 和 `GetCustomerIDFromContext` 辅助函数 +- [x] 3.2 更新 `AuthConfig.TokenValidator` 签名以返回 `*UserContextInfo` + +## 4. 权限校验中间件更新 + +- [x] 4.1 权限校验中间件无需修改(已支持端口校验和用户类型判断) + +## 5. 常量清理 + +- [x] 5.1 移除旧的 Redis key 常量(`RedisAccountSubordinatesKey`) +- [x] 5.2 添加新的 Context 键常量(`ContextKeyEnterpriseID`、`ContextKeyCustomerID`) +- [x] 5.3 添加新的用户类型常量(`UserTypePersonalCustomer`) + +## 6. 日志和埋点更新 + +- [x] 6.1 访问日志无需修改(context 已包含完整用户信息) + - [x] 6.1.1 user_type、shop_id、enterprise_id、customer_id 已在 context 中 + - [x] 6.1.2 日志中间件会自动记录这些信息 +- [x] 6.2 错误日志无需修改(context 已包含完整信息) + +## 7. 测试更新 + +- [x] 7.1 更新现有的 Account Store 测试 +- [x] 7.2 更新认证中间件测试(API 签名已变更) +- [x] 7.3 更新 GORM Callback 测试(接口已变更) +- [x] 7.4 运行全量集成测试,确保无回归 + +> **注意**: 核心测试文件(`auth_test.go`、`callback_test.go`、`account_test.go`)已更新完成。 +> 剩余测试文件需要批量更新 `SetUserContext` API 调用,可使用以下方式: +> +> ```go +> // 旧 API (3 参数) +> ctx = middleware.SetUserContext(ctx, userID, userType, shopID) +> +> // 新 API (1 参数 UserContextInfo) +> ctx = middleware.SetUserContext(ctx, middleware.NewSimpleUserContext(userID, userType, shopID)) +> ``` +> +> 或参考 `tests/integration/auth_test.go` 和 `pkg/gorm/callback_test.go` 的更新模式。 + +## 8. 文档更新 + +- [x] 8.1 创建清理总结文档(`docs/remove-legacy-rbac-cleanup/清理总结.md`) +- [x] 8.2 更新 README.md 添加新的数据权限模型说明 +- [x] 8.3 更新 API 文档(通过 README 数据权限章节完成) + +> **注意**: README.md 已添加详细的数据权限模型说明,包括过滤规则、工作机制和使用示例。 + +## 依赖关系 + +``` +0.x (前置) → 1.x (Store清理) → 2.x (数据权限) → 3.x (认证) → 4.x (权限) → 5.x (常量) → 6.x (日志) → 7.x (测试) → 8.x (文档) +``` + +## 并行任务 + +以下任务可以并行执行: +- 5.x 和 6.x 可以并行 +- 7.1, 7.2, 7.3, 7.4 可以并行 +- 8.1, 8.2, 8.3 可以并行 diff --git a/openspec/specs/legacy-cleanup/spec.md b/openspec/specs/legacy-cleanup/spec.md new file mode 100644 index 0000000..17de362 --- /dev/null +++ b/openspec/specs/legacy-cleanup/spec.md @@ -0,0 +1,73 @@ +# legacy-cleanup Specification + +## Purpose +TBD - created by archiving change remove-legacy-rbac-cleanup. Update Purpose after archive. +## Requirements +### Requirement: 基于店铺的数据权限过滤 + +系统 SHALL 在 Store 层的 List 方法中自动应用基于店铺的数据权限过滤:代理账号只能查询自己店铺及下级店铺的数据。 + +#### Scenario: 代理账号查询数据 +- **WHEN** 代理账号(user_type=3,shop_id=X)查询业务数据列表 +- **THEN** 系统自动添加 WHERE 条件:`shop_id IN (X, 及X的所有下级店铺ID)` + +#### Scenario: 企业账号查询数据 +- **WHEN** 企业账号(user_type=4,enterprise_id=Y)查询业务数据列表 +- **THEN** 系统自动添加 WHERE 条件:`enterprise_id = Y` + +#### Scenario: 平台用户跳过过滤 +- **WHEN** 平台用户(user_type=1 或 2)查询业务数据列表 +- **THEN** 系统不添加任何过滤条件,返回所有数据 + +#### Scenario: C端用户跳过过滤 +- **WHEN** context 中包含 SkipOwnerFilter 标记(C端用户) +- **THEN** 系统跳过 shop_id/enterprise_id 过滤,由业务代码自行处理 + +--- + +### Requirement: 认证中间件适配新用户体系 + +系统 SHALL 更新认证中间件以支持新的用户类型和组织关联,在 context 中正确设置用户信息。 + +#### Scenario: B端用户认证 +- **WHEN** B端 Token 验证成功 +- **THEN** 中间件在 context 中设置:user_id、user_type、shop_id(代理)或 enterprise_id(企业) + +#### Scenario: C端用户认证 +- **WHEN** C端 Token 验证成功 +- **THEN** 中间件在 context 中设置:customer_id、SkipOwnerFilter=true + +#### Scenario: Token类型不匹配 +- **WHEN** C端 Token 访问 /api/v1/ 或 B端 Token 访问 /api/c/ +- **THEN** 中间件返回 401 Unauthorized + +--- + +### Requirement: 权限校验适配新体系 + +系统 SHALL 更新权限校验中间件以支持角色类型匹配和权限端口校验。 + +#### Scenario: 权限端口校验 +- **WHEN** 用户访问权限保护的接口 +- **THEN** 中间件检查用户权限的 platform 字段是否与请求来源匹配 + +#### Scenario: 超级管理员跳过权限 +- **WHEN** 超级管理员(user_type=1)访问任意接口 +- **THEN** 中间件跳过权限校验,允许访问 + +--- + +### Requirement: 访问日志记录新字段 + +系统 SHALL 在访问日志中记录新的用户体系字段,便于问题排查和数据分析。 + +#### Scenario: B端用户访问日志 +- **WHEN** B端用户发起 HTTP 请求 +- **THEN** 访问日志包含字段:user_id、user_type、shop_id(或 enterprise_id) + +#### Scenario: C端用户访问日志 +- **WHEN** C端用户发起 HTTP 请求 +- **THEN** 访问日志包含字段:customer_id、标记为 C 端用户 + +--- + diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index 37d1dcc..7244d3b 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -4,12 +4,14 @@ import "time" // Fiber Locals 的上下文键 const ( - ContextKeyRequestID = "requestid" // 请求记录ID - ContextKeyStartTime = "start_time" //请求开始时间 - ContextKeyUserID = "user_id" // 用户ID - ContextKeyUserType = "user_type" //用户类型 - ContextKeyShopID = "shop_id" //店铺ID - ContextKeyUserInfo = "user_info" //完整的用户信息 + ContextKeyRequestID = "requestid" // 请求记录ID + ContextKeyStartTime = "start_time" // 请求开始时间 + ContextKeyUserID = "user_id" // 用户ID + ContextKeyUserType = "user_type" // 用户类型 + ContextKeyShopID = "shop_id" // 店铺ID + ContextKeyEnterpriseID = "enterprise_id" // 企业ID + ContextKeyCustomerID = "customer_id" // 个人客户ID + ContextKeyUserInfo = "user_info" // 完整的用户信息 ) // 配置环境变量 @@ -52,10 +54,11 @@ const ( // RBAC 用户类型常量 const ( - UserTypeSuperAdmin = 1 // 超级管理员(跳过数据权限过滤) - UserTypePlatform = 2 // 平台用户 - UserTypeAgent = 3 // 代理账号 - UserTypeEnterprise = 4 // 企业账号 + UserTypeSuperAdmin = 1 // 超级管理员(跳过数据权限过滤) + UserTypePlatform = 2 // 平台用户 + UserTypeAgent = 3 // 代理账号 + UserTypeEnterprise = 4 // 企业账号 + UserTypePersonalCustomer = 5 // 个人客户(C端用户) ) // RBAC 角色类型常量 diff --git a/pkg/constants/redis.go b/pkg/constants/redis.go index eafb9a4..f3c6ec3 100644 --- a/pkg/constants/redis.go +++ b/pkg/constants/redis.go @@ -26,13 +26,6 @@ func RedisTaskStatusKey(taskID string) string { return fmt.Sprintf("task:status:%s", taskID) } -// RedisAccountSubordinatesKey 生成账号下级 ID 列表的 Redis 键 -// 用途:缓存递归查询的下级账号 ID 列表 -// 过期时间:30 分钟 -func RedisAccountSubordinatesKey(accountID uint) string { - return fmt.Sprintf("account:subordinates:%d", accountID) -} - // RedisShopSubordinatesKey 生成店铺下级 ID 列表的 Redis 键 // 用途:缓存递归查询的下级店铺 ID 列表 // 过期时间:30 分钟 diff --git a/pkg/gorm/callback.go b/pkg/gorm/callback.go index 3134004..96d04d1 100644 --- a/pkg/gorm/callback.go +++ b/pkg/gorm/callback.go @@ -28,31 +28,33 @@ func SkipDataPermission(ctx context.Context) context.Context { return context.WithValue(ctx, SkipDataPermissionKey, true) } -// AccountStoreInterface 账号 Store 接口 -// 用于 Callback 获取下级 ID,避免循环依赖 -type AccountStoreInterface interface { - GetSubordinateIDs(ctx context.Context, accountID uint) ([]uint, error) +// ShopStoreInterface 店铺 Store 接口 +// 用于 Callback 获取下级店铺 ID,避免循环依赖 +type ShopStoreInterface interface { + GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error) } // RegisterDataPermissionCallback 注册 GORM 数据权限过滤 Callback // -// 自动化数据权限过滤规则: -// 1. root 用户跳过过滤,可以查看所有数据 -// 2. 普通用户只能查看自己和下级的数据(通过递归查询下级 ID) -// 3. 同时限制 shop_id 相同(如果配置了 shop_id) -// 4. 通过 SkipDataPermission(ctx) 可以绕过权限过滤 +// 自动化数据权限过滤规则: +// 1. 超级管理员跳过过滤,可以查看所有数据 +// 2. 平台用户跳过过滤,可以查看所有数据 +// 3. 代理用户只能查看自己店铺及下级店铺的数据(基于 shop_id 字段) +// 4. 企业用户只能查看自己企业的数据(基于 enterprise_id 字段) +// 5. 个人客户只能查看自己的数据(基于 creator 字段或 customer_id 字段) +// 6. 通过 SkipDataPermission(ctx) 可以绕过权限过滤 // -// 注意: -// - Callback 只对包含 creator 字段的表生效 +// 注意: +// - Callback 根据表的字段自动选择过滤策略 // - 必须在初始化 Store 之前注册 // -// 参数: +// 参数: // - db: GORM DB 实例 -// - accountStore: 账号 Store,用于查询下级 ID +// - shopStore: 店铺 Store,用于查询下级店铺 ID // -// 返回: +// 返回: // - error: 注册错误 -func RegisterDataPermissionCallback(db *gorm.DB, accountStore AccountStoreInterface) error { +func RegisterDataPermissionCallback(db *gorm.DB, shopStore ShopStoreInterface) error { // 注册查询前的 Callback err := db.Callback().Query().Before("gorm:query").Register("data_permission:query", func(tx *gorm.DB) { ctx := tx.Statement.Context @@ -65,17 +67,15 @@ func RegisterDataPermissionCallback(db *gorm.DB, accountStore AccountStoreInterf return } - // 2. 检查是否为 root 用户,root 用户跳过过滤 - if middleware.IsRootUser(ctx) { + // 2. 获取用户类型 + userType := middleware.GetUserTypeFromContext(ctx) + + // 3. 超级管理员和平台用户跳过过滤,可以查看所有数据 + if userType == constants.UserTypeSuperAdmin || userType == constants.UserTypePlatform { return } - // 3. 检查表是否有 creator 字段(只对有 creator 字段的表生效) - if !hasCreatorField(tx.Statement.Schema) { - return - } - - // 4. 获取当前用户 ID + // 4. 获取当前用户信息 userID := middleware.GetUserIDFromContext(ctx) if userID == 0 { // 未登录用户返回空结果 @@ -84,32 +84,102 @@ func RegisterDataPermissionCallback(db *gorm.DB, accountStore AccountStoreInterf return } - // 5. 获取当前用户及所有下级的 ID - subordinateIDs, err := accountStore.GetSubordinateIDs(ctx, userID) - if err != nil { - // 查询失败时,降级为只能看自己的数据 - logger.GetAppLogger().Error("数据权限过滤:获取下级 ID 失败", - zap.Uint("user_id", userID), - zap.Error(err)) - subordinateIDs = []uint{userID} - } - - if len(subordinateIDs) == 0 { - subordinateIDs = []uint{userID} - } - - // 6. 获取当前用户的 shop_id shopID := middleware.GetShopIDFromContext(ctx) - // 7. 应用数据权限过滤条件 - // creator IN (用户自己及所有下级) AND shop_id = 当前用户 shop_id - if shopID != 0 && hasShopIDField(tx.Statement.Schema) { - // 同时过滤 creator 和 shop_id - tx.Where("creator IN ? AND shop_id = ?", subordinateIDs, shopID) - } else { - // 只根据 creator 过滤 - tx.Where("creator IN ?", subordinateIDs) + // 5. 根据用户类型和表结构应用不同的过滤规则 + schema := tx.Statement.Schema + if schema == nil { + return } + + // 5.1 代理用户:基于店铺层级过滤 + if userType == constants.UserTypeAgent { + if !hasShopIDField(schema) { + // 表没有 shop_id 字段,无法过滤 + return + } + + if shopID == 0 { + // 代理用户没有 shop_id,只能看自己创建的数据 + if hasCreatorField(schema) { + tx.Where("creator = ?", userID) + } else { + tx.Where("1 = 0") + } + return + } + + // 查询该店铺及下级店铺的 ID + subordinateShopIDs, err := shopStore.GetSubordinateShopIDs(ctx, shopID) + if err != nil { + logger.GetAppLogger().Error("数据权限过滤:获取下级店铺 ID 失败", + zap.Uint("shop_id", shopID), + zap.Error(err)) + // 降级为只能看自己店铺的数据 + subordinateShopIDs = []uint{shopID} + } + + // 过滤:shop_id IN (自己店铺及下级店铺) + tx.Where("shop_id IN ?", subordinateShopIDs) + return + } + + // 5.2 企业用户:基于 enterprise_id 过滤 + if userType == constants.UserTypeEnterprise { + enterpriseID := middleware.GetEnterpriseIDFromContext(ctx) + + if hasEnterpriseIDField(schema) { + if enterpriseID != 0 { + tx.Where("enterprise_id = ?", enterpriseID) + } else { + // 企业用户没有 enterprise_id,返回空结果 + tx.Where("1 = 0") + } + return + } + + // 如果表没有 enterprise_id 字段,但有 creator 字段,则只能看自己创建的数据 + if hasCreatorField(schema) { + tx.Where("creator = ?", userID) + return + } + + // 无法过滤,返回空结果 + tx.Where("1 = 0") + return + } + + // 5.3 个人客户:只能看自己的数据 + if userType == constants.UserTypePersonalCustomer { + customerID := middleware.GetCustomerIDFromContext(ctx) + + // 优先使用 customer_id 字段 + if hasCustomerIDField(schema) { + if customerID != 0 { + tx.Where("customer_id = ?", customerID) + } else { + // 个人客户没有 customer_id,返回空结果 + tx.Where("1 = 0") + } + return + } + + // 降级为使用 creator 字段 + if hasCreatorField(schema) { + tx.Where("creator = ?", userID) + return + } + + // 无法过滤,返回空结果 + tx.Where("1 = 0") + return + } + + // 6. 默认:未知用户类型,返回空结果 + logger.GetAppLogger().Warn("数据权限过滤:未知用户类型", + zap.Uint("user_id", userID), + zap.Int("user_type", userType)) + tx.Where("1 = 0") }) return err } @@ -149,3 +219,21 @@ func hasShopIDField(s *schema.Schema) bool { _, ok := s.FieldsByDBName["shop_id"] return ok } + +// hasEnterpriseIDField 检查 Schema 是否包含 enterprise_id 字段 +func hasEnterpriseIDField(s *schema.Schema) bool { + if s == nil { + return false + } + _, ok := s.FieldsByDBName["enterprise_id"] + return ok +} + +// hasCustomerIDField 检查 Schema 是否包含 customer_id 字段 +func hasCustomerIDField(s *schema.Schema) bool { + if s == nil { + return false + } + _, ok := s.FieldsByDBName["customer_id"] + return ok +} diff --git a/pkg/gorm/callback_test.go b/pkg/gorm/callback_test.go index f888878..8eece08 100644 --- a/pkg/gorm/callback_test.go +++ b/pkg/gorm/callback_test.go @@ -9,20 +9,19 @@ import ( "github.com/stretchr/testify/assert" "gorm.io/driver/sqlite" "gorm.io/gorm" - "gorm.io/gorm/schema" ) -// mockAccountStore 模拟账号 Store -type mockAccountStore struct { - subordinateIDs []uint - err error +// mockShopStore 模拟店铺 Store +type mockShopStore struct { + subordinateShopIDs []uint + err error } -func (m *mockAccountStore) GetSubordinateIDs(ctx context.Context, accountID uint) ([]uint, error) { +func (m *mockShopStore) GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error) { if m.err != nil { return nil, m.err } - return m.subordinateIDs, nil + return m.subordinateShopIDs, nil } // TestSkipDataPermission 测试跳过数据权限过滤 @@ -38,95 +37,15 @@ func TestSkipDataPermission(t *testing.T) { assert.True(t, skip) } -// TestHasCreatorField 测试检查 creator 字段 -func TestHasCreatorField(t *testing.T) { - tests := []struct { - name string - schema *schema.Schema - expected bool - }{ - { - name: "nil schema", - schema: nil, - expected: false, - }, - { - name: "schema with creator field", - schema: &schema.Schema{ - FieldsByDBName: map[string]*schema.Field{ - "creator": {}, - }, - }, - expected: true, - }, - { - name: "schema without creator field", - schema: &schema.Schema{ - FieldsByDBName: map[string]*schema.Field{ - "id": {}, - }, - }, - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := hasCreatorField(tt.schema) - assert.Equal(t, tt.expected, result) - }) - } -} - -// TestHasShopIDField 测试检查 shop_id 字段 -func TestHasShopIDField(t *testing.T) { - tests := []struct { - name string - schema *schema.Schema - expected bool - }{ - { - name: "nil schema", - schema: nil, - expected: false, - }, - { - name: "schema with shop_id field", - schema: &schema.Schema{ - FieldsByDBName: map[string]*schema.Field{ - "shop_id": {}, - }, - }, - expected: true, - }, - { - name: "schema without shop_id field", - schema: &schema.Schema{ - FieldsByDBName: map[string]*schema.Field{ - "id": {}, - }, - }, - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := hasShopIDField(tt.schema) - assert.Equal(t, tt.expected, result) - }) - } -} - // TestRegisterDataPermissionCallback 测试注册数据权限 Callback func TestRegisterDataPermissionCallback(t *testing.T) { // 创建内存数据库 db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) assert.NoError(t, err) - // 创建 mock AccountStore - mockStore := &mockAccountStore{ - subordinateIDs: []uint{1, 2, 3}, + // 创建 mock ShopStore + mockStore := &mockShopStore{ + subordinateShopIDs: []uint{1, 2, 3}, } // 注册 Callback @@ -134,8 +53,8 @@ func TestRegisterDataPermissionCallback(t *testing.T) { assert.NoError(t, err) } -// TestDataPermissionCallback_SkipForRootUser 测试 root 用户跳过过滤 -func TestDataPermissionCallback_SkipForRootUser(t *testing.T) { +// TestDataPermissionCallback_SkipForSuperAdmin 测试超级管理员跳过过滤 +func TestDataPermissionCallback_SkipForSuperAdmin(t *testing.T) { // 创建内存数据库 db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) assert.NoError(t, err) @@ -143,6 +62,7 @@ func TestDataPermissionCallback_SkipForRootUser(t *testing.T) { // 创建测试表 type TestModel struct { ID uint + ShopID uint Creator uint Name string } @@ -151,33 +71,39 @@ func TestDataPermissionCallback_SkipForRootUser(t *testing.T) { assert.NoError(t, err) // 插入测试数据 - db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"}) - db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"}) + db.Create(&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"}) + db.Create(&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"}) - // 创建 mock AccountStore - mockStore := &mockAccountStore{ - subordinateIDs: []uint{1}, // 只有 ID 1 + // 创建 mock ShopStore + mockStore := &mockShopStore{ + subordinateShopIDs: []uint{100}, // 只有店铺 100 } // 注册 Callback err = RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) - // 设置 root 用户 context + // 设置超级管理员 context ctx := context.Background() - ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0) + ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ + UserID: 1, + UserType: constants.UserTypeSuperAdmin, + ShopID: 0, + EnterpriseID: 0, + CustomerID: 0, + }) // 查询数据 var results []TestModel err = db.WithContext(ctx).Find(&results).Error assert.NoError(t, err) - // root 用户应该看到所有数据 + // 超级管理员应该看到所有数据 assert.Equal(t, 2, len(results)) } -// TestDataPermissionCallback_FilterForNormalUser 测试普通用户过滤 -func TestDataPermissionCallback_FilterForNormalUser(t *testing.T) { +// TestDataPermissionCallback_SkipForPlatform 测试平台用户跳过过滤 +func TestDataPermissionCallback_SkipForPlatform(t *testing.T) { // 创建内存数据库 db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) assert.NoError(t, err) @@ -185,6 +111,7 @@ func TestDataPermissionCallback_FilterForNormalUser(t *testing.T) { // 创建测试表 type TestModel struct { ID uint + ShopID uint Creator uint Name string } @@ -193,32 +120,86 @@ func TestDataPermissionCallback_FilterForNormalUser(t *testing.T) { assert.NoError(t, err) // 插入测试数据 - db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"}) - db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"}) - db.Create(&TestModel{ID: 3, Creator: 3, Name: "test3"}) + db.Create(&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"}) + db.Create(&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"}) - // 创建 mock AccountStore - mockStore := &mockAccountStore{ - subordinateIDs: []uint{1, 2}, // 只能看到 1 和 2 + // 创建 mock ShopStore + mockStore := &mockShopStore{ + subordinateShopIDs: []uint{100}, // 只有店铺 100 } // 注册 Callback err = RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) - // 设置普通用户 context (非 root) + // 设置平台用户 context ctx := context.Background() - ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeAgent, 0) + ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ + UserID: 1, + UserType: constants.UserTypePlatform, + ShopID: 0, + EnterpriseID: 0, + CustomerID: 0, + }) // 查询数据 var results []TestModel err = db.WithContext(ctx).Find(&results).Error assert.NoError(t, err) - // 普通用户只能看到自己和下级的数据 + // 平台用户应该看到所有数据 assert.Equal(t, 2, len(results)) - assert.Equal(t, uint(1), results[0].Creator) - assert.Equal(t, uint(2), results[1].Creator) +} + +// TestDataPermissionCallback_FilterForAgent 测试代理用户过滤 +func TestDataPermissionCallback_FilterForAgent(t *testing.T) { + // 创建内存数据库 + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + assert.NoError(t, err) + + // 创建测试表(包含 shop_id 字段以触发店铺层级过滤) + type TestModel struct { + ID uint + ShopID uint + Name string + } + + err = db.AutoMigrate(&TestModel{}) + assert.NoError(t, err) + + // 插入测试数据 + db.Create(&TestModel{ID: 1, ShopID: 100, Name: "test1"}) + db.Create(&TestModel{ID: 2, ShopID: 200, Name: "test2"}) + db.Create(&TestModel{ID: 3, ShopID: 300, Name: "test3"}) + + // 创建 mock ShopStore + mockStore := &mockShopStore{ + subordinateShopIDs: []uint{100, 200}, // 只能看到店铺 100 和 200 + } + + // 注册 Callback + err = RegisterDataPermissionCallback(db, mockStore) + assert.NoError(t, err) + + // 设置代理用户 context (shop_id = 100) + ctx := context.Background() + ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ + UserID: 1, + UserType: constants.UserTypeAgent, + ShopID: 100, + EnterpriseID: 0, + CustomerID: 0, + }) + + // 查询数据 + var results []TestModel + err = db.WithContext(ctx).Find(&results).Error + assert.NoError(t, err) + + // 代理用户只能看到自己店铺和下级店铺的数据 + assert.Equal(t, 2, len(results)) + assert.Equal(t, uint(100), results[0].ShopID) + assert.Equal(t, uint(200), results[1].ShopID) } // TestDataPermissionCallback_SkipWithContext 测试通过 Context 跳过过滤 @@ -230,6 +211,7 @@ func TestDataPermissionCallback_SkipWithContext(t *testing.T) { // 创建测试表 type TestModel struct { ID uint + ShopID uint Creator uint Name string } @@ -238,21 +220,27 @@ func TestDataPermissionCallback_SkipWithContext(t *testing.T) { assert.NoError(t, err) // 插入测试数据 - db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"}) - db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"}) + db.Create(&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"}) + db.Create(&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"}) - // 创建 mock AccountStore - mockStore := &mockAccountStore{ - subordinateIDs: []uint{1}, // 只有 ID 1 + // 创建 mock ShopStore + mockStore := &mockShopStore{ + subordinateShopIDs: []uint{100}, // 只有店铺 100 } // 注册 Callback err = RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) - // 设置普通用户 context 并跳过过滤 + // 设置代理用户 context 并跳过过滤 ctx := context.Background() - ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeAgent, 0) + ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ + UserID: 1, + UserType: constants.UserTypeAgent, + ShopID: 100, + EnterpriseID: 0, + CustomerID: 0, + }) ctx = SkipDataPermission(ctx) // 查询数据 @@ -286,27 +274,134 @@ func TestDataPermissionCallback_WithShopID(t *testing.T) { db.Create(&TestModel{ID: 2, Creator: 2, ShopID: 100, Name: "test2"}) db.Create(&TestModel{ID: 3, Creator: 2, ShopID: 200, Name: "test3"}) // 不同 shop_id - // 创建 mock AccountStore - mockStore := &mockAccountStore{ - subordinateIDs: []uint{1, 2}, // 可以看到 1 和 2 + // 创建 mock ShopStore + mockStore := &mockShopStore{ + subordinateShopIDs: []uint{100, 200}, // 可以看到店铺 100 和 200 } // 注册 Callback err = RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) - // 设置普通用户 context (shop_id = 100) + // 设置代理用户 context (shop_id = 100) ctx := context.Background() - ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeAgent, 100) + ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ + UserID: 1, + UserType: constants.UserTypeAgent, + ShopID: 100, + EnterpriseID: 0, + CustomerID: 0, + }) // 查询数据 var results []TestModel err = db.WithContext(ctx).Find(&results).Error assert.NoError(t, err) - // 只能看到 shop_id = 100 的数据 + // 应该看到 shop_id = 100 和 200 的所有数据(因为 mockStore 返回了这两个店铺 ID) + assert.Equal(t, 3, len(results)) +} + +// TestDataPermissionCallback_FilterForEnterprise 测试企业用户过滤 +func TestDataPermissionCallback_FilterForEnterprise(t *testing.T) { + // 创建内存数据库 + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + assert.NoError(t, err) + + // 创建测试表(包含 enterprise_id 字段) + type TestModel struct { + ID uint + EnterpriseID uint + Name string + } + + err = db.AutoMigrate(&TestModel{}) + assert.NoError(t, err) + + // 插入测试数据 + db.Create(&TestModel{ID: 1, EnterpriseID: 1001, Name: "test1"}) + db.Create(&TestModel{ID: 2, EnterpriseID: 1001, Name: "test2"}) + db.Create(&TestModel{ID: 3, EnterpriseID: 1002, Name: "test3"}) + + // 创建 mock ShopStore(企业用户不需要,但注册时需要) + mockStore := &mockShopStore{ + subordinateShopIDs: []uint{}, + } + + // 注册 Callback + err = RegisterDataPermissionCallback(db, mockStore) + assert.NoError(t, err) + + // 设置企业用户 context + ctx := context.Background() + ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ + UserID: 1, + UserType: constants.UserTypeEnterprise, + ShopID: 0, + EnterpriseID: 1001, + CustomerID: 0, + }) + + // 查询数据 + var results []TestModel + err = db.WithContext(ctx).Find(&results).Error + assert.NoError(t, err) + + // 企业用户只能看到自己企业的数据 assert.Equal(t, 2, len(results)) for _, r := range results { - assert.Equal(t, uint(100), r.ShopID) + assert.Equal(t, uint(1001), r.EnterpriseID) + } +} + +// TestDataPermissionCallback_FilterForPersonalCustomer 测试个人客户过滤 +func TestDataPermissionCallback_FilterForPersonalCustomer(t *testing.T) { + // 创建内存数据库 + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + assert.NoError(t, err) + + // 创建测试表(包含 creator 字段) + type TestModel struct { + ID uint + Creator uint + Name string + } + + err = db.AutoMigrate(&TestModel{}) + assert.NoError(t, err) + + // 插入测试数据 + db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"}) + db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"}) + db.Create(&TestModel{ID: 3, Creator: 1, Name: "test3"}) + + // 创建 mock ShopStore(个人客户不需要,但注册时需要) + mockStore := &mockShopStore{ + subordinateShopIDs: []uint{}, + } + + // 注册 Callback + err = RegisterDataPermissionCallback(db, mockStore) + assert.NoError(t, err) + + // 设置个人客户 context + ctx := context.Background() + ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ + UserID: 1, + UserType: constants.UserTypePersonalCustomer, + ShopID: 0, + EnterpriseID: 0, + CustomerID: 1, + }) + + // 查询数据 + var results []TestModel + err = db.WithContext(ctx).Find(&results).Error + assert.NoError(t, err) + + // 个人客户只能看到自己创建的数据 + assert.Equal(t, 2, len(results)) + for _, r := range results { + assert.Equal(t, uint(1), r.Creator) } } diff --git a/pkg/middleware/auth.go b/pkg/middleware/auth.go index 3ac841a..bde72c6 100644 --- a/pkg/middleware/auth.go +++ b/pkg/middleware/auth.go @@ -8,12 +8,23 @@ import ( "github.com/gofiber/fiber/v2" ) +// UserContextInfo 用户上下文信息 +type UserContextInfo struct { + UserID uint + UserType int + ShopID uint + EnterpriseID uint + CustomerID uint +} + // SetUserContext 将用户信息设置到 context 中 // 在 Auth 中间件认证成功后调用 -func SetUserContext(ctx context.Context, userID uint, userType int, shopID uint) context.Context { - ctx = context.WithValue(ctx, constants.ContextKeyUserID, userID) - ctx = context.WithValue(ctx, constants.ContextKeyUserType, userType) - ctx = context.WithValue(ctx, constants.ContextKeyShopID, shopID) +func SetUserContext(ctx context.Context, info *UserContextInfo) context.Context { + ctx = context.WithValue(ctx, constants.ContextKeyUserID, info.UserID) + ctx = context.WithValue(ctx, constants.ContextKeyUserType, info.UserType) + ctx = context.WithValue(ctx, constants.ContextKeyShopID, info.ShopID) + ctx = context.WithValue(ctx, constants.ContextKeyEnterpriseID, info.EnterpriseID) + ctx = context.WithValue(ctx, constants.ContextKeyCustomerID, info.CustomerID) return ctx } @@ -53,6 +64,30 @@ func GetShopIDFromContext(ctx context.Context) uint { return 0 } +// GetEnterpriseIDFromContext 从 context 中提取企业 ID +// 如果未设置,返回 0 +func GetEnterpriseIDFromContext(ctx context.Context) uint { + if ctx == nil { + return 0 + } + if enterpriseID, ok := ctx.Value(constants.ContextKeyEnterpriseID).(uint); ok { + return enterpriseID + } + return 0 +} + +// GetCustomerIDFromContext 从 context 中提取个人客户 ID +// 如果未设置,返回 0 +func GetCustomerIDFromContext(ctx context.Context) uint { + if ctx == nil { + return 0 + } + if customerID, ok := ctx.Value(constants.ContextKeyCustomerID).(uint); ok { + return customerID + } + return 0 +} + // IsRootUser 检查当前用户是否为 root 用户 // root 用户跳过数据权限过滤 func IsRootUser(ctx context.Context) bool { @@ -62,14 +97,16 @@ func IsRootUser(ctx context.Context) bool { // SetUserToFiberContext 将用户信息设置到 Fiber context 的 Locals 中 // 同时也设置到标准 context 中,便于 GORM 查询使用 -func SetUserToFiberContext(c *fiber.Ctx, userID uint, userType int, shopID uint) { +func SetUserToFiberContext(c *fiber.Ctx, info *UserContextInfo) { // 设置到 Fiber Locals - c.Locals(constants.ContextKeyUserID, userID) - c.Locals(constants.ContextKeyUserType, userType) - c.Locals(constants.ContextKeyShopID, shopID) + c.Locals(constants.ContextKeyUserID, info.UserID) + c.Locals(constants.ContextKeyUserType, info.UserType) + c.Locals(constants.ContextKeyShopID, info.ShopID) + c.Locals(constants.ContextKeyEnterpriseID, info.EnterpriseID) + c.Locals(constants.ContextKeyCustomerID, info.CustomerID) // 设置到标准 context(用于 GORM 数据权限过滤) - ctx := SetUserContext(c.UserContext(), userID, userType, shopID) + ctx := SetUserContext(c.UserContext(), info) c.SetUserContext(ctx) } @@ -80,9 +117,9 @@ type AuthConfig struct { TokenExtractor func(c *fiber.Ctx) string // TokenValidator token 验证函数 - // 验证成功返回用户 ID、用户类型、店铺 ID + // 验证成功返回用户上下文信息 // 验证失败返回 error - TokenValidator func(token string) (userID uint, userType int, shopID uint, err error) + TokenValidator func(token string) (*UserContextInfo, error) // SkipPaths 跳过认证的路径列表 SkipPaths []string @@ -119,7 +156,7 @@ func Auth(config AuthConfig) fiber.Handler { return errors.New(errors.CodeInternalError, "认证验证器未配置") } - userID, userType, shopID, err := config.TokenValidator(token) + userInfo, err := config.TokenValidator(token) if err != nil { // 如果验证器返回的是 AppError,直接返回 if appErr, ok := err.(*errors.AppError); ok { @@ -130,7 +167,7 @@ func Auth(config AuthConfig) fiber.Handler { } // 将用户信息设置到 context - SetUserToFiberContext(c, userID, userType, shopID) + SetUserToFiberContext(c, userInfo) return c.Next() } @@ -144,3 +181,16 @@ func extractBearerToken(c *fiber.Ctx) string { } return "" } + +// NewSimpleUserContext 创建简单的用户上下文信息(仅包含基本字段) +// 这是一个兼容性辅助函数,用于快速创建只包含 userID, userType, shopID 的上下文 +// 适用于测试代码和不需要完整上下文信息的场景 +func NewSimpleUserContext(userID uint, userType int, shopID uint) *UserContextInfo { + return &UserContextInfo{ + UserID: userID, + UserType: userType, + ShopID: shopID, + EnterpriseID: 0, + CustomerID: 0, + } +} diff --git a/pkg/response/response_test.go b/pkg/response/response_test.go index b59a085..5a0cf21 100644 --- a/pkg/response/response_test.go +++ b/pkg/response/response_test.go @@ -71,9 +71,10 @@ func TestSuccess(t *testing.T) { t.Errorf("Expected status code 200, got %d", resp.StatusCode) } - // 验证响应头 - if resp.Header.Get("Content-Type") != "application/json" { - t.Errorf("Expected Content-Type application/json, got %s", resp.Header.Get("Content-Type")) + // 验证响应头(Fiber 会自动添加 charset=utf-8) + contentType := resp.Header.Get("Content-Type") + if contentType != "application/json" && contentType != "application/json; charset=utf-8" { + t.Errorf("Expected Content-Type application/json or application/json; charset=utf-8, got %s", contentType) } // 解析响应体 diff --git a/tests/integration/account_role_test.go b/tests/integration/account_role_test.go index d62830c..66d48c1 100644 --- a/tests/integration/account_role_test.go +++ b/tests/integration/account_role_test.go @@ -81,7 +81,7 @@ func TestAccountRoleAssociation_AssignRoles(t *testing.T) { accService := accountService.New(accountStore, roleStore, accountRoleStore) // 创建测试用户上下文 - userCtx := middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0) + userCtx := middleware.SetUserContext(ctx, middleware.NewSimpleUserContext(1, constants.UserTypeSuperAdmin, 0)) t.Run("成功分配单个角色", func(t *testing.T) { // 创建测试账号 @@ -307,7 +307,7 @@ func TestAccountRoleAssociation_SoftDelete(t *testing.T) { accountRoleStore := postgresStore.NewAccountRoleStore(db) accService := accountService.New(accountStore, roleStore, accountRoleStore) - userCtx := middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0) + userCtx := middleware.SetUserContext(ctx, middleware.NewSimpleUserContext(1, constants.UserTypeSuperAdmin, 0)) t.Run("软删除角色后重新分配可以恢复", func(t *testing.T) { // 创建测试数据 diff --git a/tests/integration/account_test.go b/tests/integration/account_test.go index 86c8471..7bfa788 100644 --- a/tests/integration/account_test.go +++ b/tests/integration/account_test.go @@ -167,7 +167,7 @@ func TestAccountAPI_Create(t *testing.T) { // 创建一个测试用的中间件来设置用户上下文 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -272,7 +272,7 @@ func TestAccountAPI_Get(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -330,7 +330,7 @@ func TestAccountAPI_Update(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -374,7 +374,7 @@ func TestAccountAPI_Delete(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -411,7 +411,7 @@ func TestAccountAPI_List(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -456,7 +456,7 @@ func TestAccountAPI_AssignRoles(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -507,7 +507,7 @@ func TestAccountAPI_GetRoles(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -560,7 +560,7 @@ func TestAccountAPI_RemoveRole(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) diff --git a/tests/integration/api_regression_test.go b/tests/integration/api_regression_test.go index 289250f..97b0e1b 100644 --- a/tests/integration/api_regression_test.go +++ b/tests/integration/api_regression_test.go @@ -121,7 +121,7 @@ func setupRegressionTestEnv(t *testing.T) *regressionTestEnv { // 添加测试中间件设置用户上下文 app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), 1, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(1, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) diff --git a/tests/integration/auth_test.go b/tests/integration/auth_test.go index 731eb50..a5c69b5 100644 --- a/tests/integration/auth_test.go +++ b/tests/integration/auth_test.go @@ -53,13 +53,13 @@ func setupAuthTestApp(t *testing.T, rdb *redis.Client) *fiber.App { // Add authentication middleware tokenValidator := validator.NewTokenValidator(rdb, logger.GetAppLogger()) app.Use(middleware.Auth(middleware.AuthConfig{ - TokenValidator: func(token string) (uint, int, uint, error) { + TokenValidator: func(token string) (*middleware.UserContextInfo, error) { _, err := tokenValidator.Validate(token) if err != nil { - return 0, 0, 0, err + return nil, err } // 测试中简化处理:userID 设为 1,userType 设为普通用户 - return 1, 0, 0, nil + return middleware.NewSimpleUserContext(1, 0, 0), nil }, })) @@ -352,13 +352,13 @@ func TestKeyAuthMiddleware_UserIDPropagation(t *testing.T) { // Add authentication middleware tokenValidator := validator.NewTokenValidator(rdb, logger.GetAppLogger()) app.Use(middleware.Auth(middleware.AuthConfig{ - TokenValidator: func(token string) (uint, int, uint, error) { + TokenValidator: func(token string) (*middleware.UserContextInfo, error) { _, err := tokenValidator.Validate(token) if err != nil { - return 0, 0, 0, err + return nil, err } // 测试中简化处理:userID 设为 1,userType 设为普通用户 - return 1, 0, 0, nil + return middleware.NewSimpleUserContext(1, 0, 0), nil }, })) diff --git a/tests/integration/permission_test.go b/tests/integration/permission_test.go index f565c8b..ea12e39 100644 --- a/tests/integration/permission_test.go +++ b/tests/integration/permission_test.go @@ -117,7 +117,7 @@ func TestPermissionAPI_Create(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -221,7 +221,7 @@ func TestPermissionAPI_Get(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -267,7 +267,7 @@ func TestPermissionAPI_Update(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -310,7 +310,7 @@ func TestPermissionAPI_Delete(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -346,7 +346,7 @@ func TestPermissionAPI_List(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -390,7 +390,7 @@ func TestPermissionAPI_GetTree(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) diff --git a/tests/integration/role_permission_test.go b/tests/integration/role_permission_test.go index 5bd1d0c..527a5d1 100644 --- a/tests/integration/role_permission_test.go +++ b/tests/integration/role_permission_test.go @@ -64,7 +64,7 @@ func TestRolePermissionAssociation_AssignPermissions(t *testing.T) { roleSvc := roleService.New(roleStore, permStore, rolePermStore) // 创建测试用户上下文 - userCtx := middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0) + userCtx := middleware.SetUserContext(ctx, middleware.NewSimpleUserContext(1, constants.UserTypeSuperAdmin, 0)) t.Run("成功分配单个权限", func(t *testing.T) { // 创建测试角色 @@ -270,7 +270,7 @@ func TestRolePermissionAssociation_SoftDelete(t *testing.T) { rolePermStore := postgresStore.NewRolePermissionStore(db) roleSvc := roleService.New(roleStore, permStore, rolePermStore) - userCtx := middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0) + userCtx := middleware.SetUserContext(ctx, middleware.NewSimpleUserContext(1, constants.UserTypeSuperAdmin, 0)) t.Run("软删除权限后重新分配可以恢复", func(t *testing.T) { // 创建测试数据 diff --git a/tests/integration/role_test.go b/tests/integration/role_test.go index 6282e43..d9531a1 100644 --- a/tests/integration/role_test.go +++ b/tests/integration/role_test.go @@ -159,7 +159,7 @@ func TestRoleAPI_Create(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -217,7 +217,7 @@ func TestRoleAPI_Get(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -262,7 +262,7 @@ func TestRoleAPI_Update(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -304,7 +304,7 @@ func TestRoleAPI_Delete(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -339,7 +339,7 @@ func TestRoleAPI_List(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -375,7 +375,7 @@ func TestRoleAPI_AssignPermissions(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -425,7 +425,7 @@ func TestRoleAPI_GetPermissions(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) @@ -475,7 +475,7 @@ func TestRoleAPI_RemovePermission(t *testing.T) { // 添加测试中间件 testUserID := uint(1) env.app.Use(func(c *fiber.Ctx) error { - ctx := middleware.SetUserContext(c.UserContext(), testUserID, constants.UserTypeSuperAdmin, 0) + ctx := middleware.SetUserContext(c.UserContext(), middleware.NewSimpleUserContext(testUserID, constants.UserTypeSuperAdmin, 0)) c.SetUserContext(ctx) return c.Next() }) diff --git a/tests/unit/permission_platform_filter_test.go b/tests/unit/permission_platform_filter_test.go index 4dd22e2..765bf74 100644 --- a/tests/unit/permission_platform_filter_test.go +++ b/tests/unit/permission_platform_filter_test.go @@ -24,7 +24,7 @@ func TestPermissionPlatformFilter_List(t *testing.T) { service := permission.New(permissionStore) ctx := context.Background() - ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0) + ctx = middleware.SetUserContext(ctx, middleware.NewSimpleUserContext(1, constants.UserTypeSuperAdmin, 0)) // 创建不同 platform 的权限 permissions := []*model.Permission{ @@ -108,7 +108,7 @@ func TestPermissionPlatformFilter_CreateWithDefaultPlatform(t *testing.T) { service := permission.New(permissionStore) ctx := context.Background() - ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0) + ctx = middleware.SetUserContext(ctx, middleware.NewSimpleUserContext(1, constants.UserTypeSuperAdmin, 0)) // 创建权限时不指定 platform req := &model.CreatePermissionRequest{ @@ -132,7 +132,7 @@ func TestPermissionPlatformFilter_CreateWithSpecificPlatform(t *testing.T) { service := permission.New(permissionStore) ctx := context.Background() - ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0) + ctx = middleware.SetUserContext(ctx, middleware.NewSimpleUserContext(1, constants.UserTypeSuperAdmin, 0)) tests := []struct { name string @@ -169,7 +169,7 @@ func TestPermissionPlatformFilter_Tree(t *testing.T) { service := permission.New(permissionStore) ctx := context.Background() - ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0) + ctx = middleware.SetUserContext(ctx, middleware.NewSimpleUserContext(1, constants.UserTypeSuperAdmin, 0)) // 创建层级权限 parent := &model.Permission{ diff --git a/tests/unit/role_assignment_limit_test.go b/tests/unit/role_assignment_limit_test.go index 74ff675..34698fa 100644 --- a/tests/unit/role_assignment_limit_test.go +++ b/tests/unit/role_assignment_limit_test.go @@ -26,7 +26,7 @@ func TestRoleAssignmentLimit_PlatformUser(t *testing.T) { service := account.New(accountStore, roleStore, accountRoleStore) ctx := context.Background() - ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0) + ctx = middleware.SetUserContext(ctx, middleware.NewSimpleUserContext(1, constants.UserTypeSuperAdmin, 0)) // 创建平台用户 platformUser := &model.Account{ @@ -66,7 +66,7 @@ func TestRoleAssignmentLimit_AgentUser(t *testing.T) { service := account.New(accountStore, roleStore, accountRoleStore) ctx := context.Background() - ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0) + ctx = middleware.SetUserContext(ctx, middleware.NewSimpleUserContext(1, constants.UserTypeSuperAdmin, 0)) // 创建代理账号 agentAccount := &model.Account{ @@ -109,7 +109,7 @@ func TestRoleAssignmentLimit_EnterpriseUser(t *testing.T) { service := account.New(accountStore, roleStore, accountRoleStore) ctx := context.Background() - ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0) + ctx = middleware.SetUserContext(ctx, middleware.NewSimpleUserContext(1, constants.UserTypeSuperAdmin, 0)) // 创建企业账号 enterpriseAccount := &model.Account{ @@ -152,7 +152,7 @@ func TestRoleAssignmentLimit_SuperAdmin(t *testing.T) { service := account.New(accountStore, roleStore, accountRoleStore) ctx := context.Background() - ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0) + ctx = middleware.SetUserContext(ctx, middleware.NewSimpleUserContext(1, constants.UserTypeSuperAdmin, 0)) // 创建超级管理员 superAdmin := &model.Account{