package gorm import ( "context" "testing" "github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/stretchr/testify/assert" "gorm.io/driver/sqlite" "gorm.io/gorm" ) // mockShopStore 模拟店铺 Store type mockShopStore struct { subordinateShopIDs []uint err error } func (m *mockShopStore) GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error) { if m.err != nil { return nil, m.err } return m.subordinateShopIDs, nil } // TestSkipDataPermission 测试跳过数据权限过滤 func TestSkipDataPermission(t *testing.T) { ctx := context.Background() // 设置跳过标记 ctx = SkipDataPermission(ctx) // 验证标记已设置 skip, ok := ctx.Value(SkipDataPermissionKey).(bool) assert.True(t, ok) assert.True(t, skip) } // TestRegisterDataPermissionCallback 测试注册数据权限 Callback func TestRegisterDataPermissionCallback(t *testing.T) { // 创建内存数据库 db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) assert.NoError(t, err) // 创建 mock ShopStore mockStore := &mockShopStore{ subordinateShopIDs: []uint{1, 2, 3}, } // 注册 Callback err = RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) } // TestDataPermissionCallback_SkipForSuperAdmin 测试超级管理员跳过过滤 func TestDataPermissionCallback_SkipForSuperAdmin(t *testing.T) { // 创建内存数据库 db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) assert.NoError(t, err) // 创建测试表 type TestModel struct { ID uint ShopID uint Creator uint Name string } err = db.AutoMigrate(&TestModel{}) assert.NoError(t, err) // 插入测试数据 db.Create(&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"}) db.Create(&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"}) // 创建 mock ShopStore mockStore := &mockShopStore{ subordinateShopIDs: []uint{100}, // 只有店铺 100 } // 注册 Callback err = RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置超级管理员 context ctx := context.Background() ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ UserID: 1, UserType: constants.UserTypeSuperAdmin, ShopID: 0, EnterpriseID: 0, CustomerID: 0, }) // 查询数据 var results []TestModel err = db.WithContext(ctx).Find(&results).Error assert.NoError(t, err) // 超级管理员应该看到所有数据 assert.Equal(t, 2, len(results)) } // TestDataPermissionCallback_SkipForPlatform 测试平台用户跳过过滤 func TestDataPermissionCallback_SkipForPlatform(t *testing.T) { // 创建内存数据库 db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) assert.NoError(t, err) // 创建测试表 type TestModel struct { ID uint ShopID uint Creator uint Name string } err = db.AutoMigrate(&TestModel{}) assert.NoError(t, err) // 插入测试数据 db.Create(&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"}) db.Create(&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"}) // 创建 mock ShopStore mockStore := &mockShopStore{ subordinateShopIDs: []uint{100}, // 只有店铺 100 } // 注册 Callback err = RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置平台用户 context ctx := context.Background() ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ UserID: 1, UserType: constants.UserTypePlatform, ShopID: 0, EnterpriseID: 0, CustomerID: 0, }) // 查询数据 var results []TestModel err = db.WithContext(ctx).Find(&results).Error assert.NoError(t, err) // 平台用户应该看到所有数据 assert.Equal(t, 2, len(results)) } // TestDataPermissionCallback_FilterForAgent 测试代理用户过滤 func TestDataPermissionCallback_FilterForAgent(t *testing.T) { // 创建内存数据库 db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) assert.NoError(t, err) // 创建测试表(包含 shop_id 字段以触发店铺层级过滤) type TestModel struct { ID uint ShopID uint Name string } err = db.AutoMigrate(&TestModel{}) assert.NoError(t, err) // 插入测试数据 db.Create(&TestModel{ID: 1, ShopID: 100, Name: "test1"}) db.Create(&TestModel{ID: 2, ShopID: 200, Name: "test2"}) db.Create(&TestModel{ID: 3, ShopID: 300, Name: "test3"}) // 创建 mock ShopStore mockStore := &mockShopStore{ subordinateShopIDs: []uint{100, 200}, // 只能看到店铺 100 和 200 } // 注册 Callback err = RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置代理用户 context (shop_id = 100) ctx := context.Background() ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ UserID: 1, UserType: constants.UserTypeAgent, ShopID: 100, EnterpriseID: 0, CustomerID: 0, }) // 查询数据 var results []TestModel err = db.WithContext(ctx).Find(&results).Error assert.NoError(t, err) // 代理用户只能看到自己店铺和下级店铺的数据 assert.Equal(t, 2, len(results)) assert.Equal(t, uint(100), results[0].ShopID) assert.Equal(t, uint(200), results[1].ShopID) } // TestDataPermissionCallback_SkipWithContext 测试通过 Context 跳过过滤 func TestDataPermissionCallback_SkipWithContext(t *testing.T) { // 创建内存数据库 db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) assert.NoError(t, err) // 创建测试表 type TestModel struct { ID uint ShopID uint Creator uint Name string } err = db.AutoMigrate(&TestModel{}) assert.NoError(t, err) // 插入测试数据 db.Create(&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"}) db.Create(&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"}) // 创建 mock ShopStore mockStore := &mockShopStore{ subordinateShopIDs: []uint{100}, // 只有店铺 100 } // 注册 Callback err = RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置代理用户 context 并跳过过滤 ctx := context.Background() ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ UserID: 1, UserType: constants.UserTypeAgent, ShopID: 100, EnterpriseID: 0, CustomerID: 0, }) ctx = SkipDataPermission(ctx) // 查询数据 var results []TestModel err = db.WithContext(ctx).Find(&results).Error assert.NoError(t, err) // 跳过过滤后应该看到所有数据 assert.Equal(t, 2, len(results)) } // TestDataPermissionCallback_WithShopID 测试带 shop_id 的过滤 func TestDataPermissionCallback_WithShopID(t *testing.T) { // 创建内存数据库 db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) assert.NoError(t, err) // 创建测试表 type TestModel struct { ID uint Creator uint ShopID uint Name string } err = db.AutoMigrate(&TestModel{}) assert.NoError(t, err) // 插入测试数据 db.Create(&TestModel{ID: 1, Creator: 1, ShopID: 100, Name: "test1"}) db.Create(&TestModel{ID: 2, Creator: 2, ShopID: 100, Name: "test2"}) db.Create(&TestModel{ID: 3, Creator: 2, ShopID: 200, Name: "test3"}) // 不同 shop_id // 创建 mock ShopStore mockStore := &mockShopStore{ subordinateShopIDs: []uint{100, 200}, // 可以看到店铺 100 和 200 } // 注册 Callback err = RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置代理用户 context (shop_id = 100) ctx := context.Background() ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ UserID: 1, UserType: constants.UserTypeAgent, ShopID: 100, EnterpriseID: 0, CustomerID: 0, }) // 查询数据 var results []TestModel err = db.WithContext(ctx).Find(&results).Error assert.NoError(t, err) // 应该看到 shop_id = 100 和 200 的所有数据(因为 mockStore 返回了这两个店铺 ID) assert.Equal(t, 3, len(results)) } // TestDataPermissionCallback_FilterForEnterprise 测试企业用户过滤 func TestDataPermissionCallback_FilterForEnterprise(t *testing.T) { // 创建内存数据库 db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) assert.NoError(t, err) // 创建测试表(包含 enterprise_id 字段) type TestModel struct { ID uint EnterpriseID uint Name string } err = db.AutoMigrate(&TestModel{}) assert.NoError(t, err) // 插入测试数据 db.Create(&TestModel{ID: 1, EnterpriseID: 1001, Name: "test1"}) db.Create(&TestModel{ID: 2, EnterpriseID: 1001, Name: "test2"}) db.Create(&TestModel{ID: 3, EnterpriseID: 1002, Name: "test3"}) // 创建 mock ShopStore(企业用户不需要,但注册时需要) mockStore := &mockShopStore{ subordinateShopIDs: []uint{}, } // 注册 Callback err = RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置企业用户 context ctx := context.Background() ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ UserID: 1, UserType: constants.UserTypeEnterprise, ShopID: 0, EnterpriseID: 1001, CustomerID: 0, }) // 查询数据 var results []TestModel err = db.WithContext(ctx).Find(&results).Error assert.NoError(t, err) // 企业用户只能看到自己企业的数据 assert.Equal(t, 2, len(results)) for _, r := range results { assert.Equal(t, uint(1001), r.EnterpriseID) } } // TestDataPermissionCallback_FilterForPersonalCustomer 测试个人客户过滤 func TestDataPermissionCallback_FilterForPersonalCustomer(t *testing.T) { // 创建内存数据库 db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) assert.NoError(t, err) // 创建测试表(包含 creator 字段) type TestModel struct { ID uint Creator uint Name string } err = db.AutoMigrate(&TestModel{}) assert.NoError(t, err) // 插入测试数据 db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"}) db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"}) db.Create(&TestModel{ID: 3, Creator: 1, Name: "test3"}) // 创建 mock ShopStore(个人客户不需要,但注册时需要) mockStore := &mockShopStore{ subordinateShopIDs: []uint{}, } // 注册 Callback err = RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置个人客户 context ctx := context.Background() ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ UserID: 1, UserType: constants.UserTypePersonalCustomer, ShopID: 0, EnterpriseID: 0, CustomerID: 1, }) // 查询数据 var results []TestModel err = db.WithContext(ctx).Find(&results).Error assert.NoError(t, err) // 个人客户只能看到自己创建的数据 assert.Equal(t, 2, len(results)) for _, r := range results { assert.Equal(t, uint(1), r.Creator) } } // ============================================================ // 标签表数据权限过滤测试(tb_tag / tb_resource_tag 表) // ============================================================ // TagModel 模拟标签表(tb_tag)结构 // 注意:必须指定 TableName 为 "tb_tag" 才能触发特殊过滤逻辑 type TagModel struct { ID uint `gorm:"primaryKey"` EnterpriseID *uint `gorm:"column:enterprise_id"` ShopID *uint `gorm:"column:shop_id"` Name string } func (TagModel) TableName() string { return "tb_tag" } // ResourceTagModel 模拟资源标签表(tb_resource_tag)结构 type ResourceTagModel struct { ID uint `gorm:"primaryKey"` EnterpriseID *uint `gorm:"column:enterprise_id"` ShopID *uint `gorm:"column:shop_id"` ResourceType string ResourceID uint TagID uint } func (ResourceTagModel) TableName() string { return "tb_resource_tag" } // uintPtr 辅助函数,将 uint 转换为 *uint func uintPtr(v uint) *uint { return &v } // setupTagTestDB 创建标签测试数据库和数据 // 返回:db 实例和 mock ShopStore func setupTagTestDB(t *testing.T) (*gorm.DB, *mockShopStore) { db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) assert.NoError(t, err) // 创建测试表 err = db.AutoMigrate(&TagModel{}, &ResourceTagModel{}) assert.NoError(t, err) // 插入测试数据 // 1. 全局标签(enterprise_id = NULL, shop_id = NULL) db.Create(&TagModel{ID: 1, EnterpriseID: nil, ShopID: nil, Name: "全局标签-VIP"}) db.Create(&TagModel{ID: 2, EnterpriseID: nil, ShopID: nil, Name: "全局标签-重要客户"}) // 2. 企业标签(enterprise_id = 1001, shop_id = NULL) db.Create(&TagModel{ID: 3, EnterpriseID: uintPtr(1001), ShopID: nil, Name: "企业A-测试标签"}) db.Create(&TagModel{ID: 4, EnterpriseID: uintPtr(1001), ShopID: nil, Name: "企业A-内部标签"}) // 3. 另一个企业的标签(enterprise_id = 1002, shop_id = NULL) db.Create(&TagModel{ID: 5, EnterpriseID: uintPtr(1002), ShopID: nil, Name: "企业B-测试标签"}) // 4. 店铺标签(enterprise_id = NULL, shop_id = 100) db.Create(&TagModel{ID: 6, EnterpriseID: nil, ShopID: uintPtr(100), Name: "店铺100-华东区"}) db.Create(&TagModel{ID: 7, EnterpriseID: nil, ShopID: uintPtr(100), Name: "店铺100-大客户"}) // 5. 下级店铺标签(enterprise_id = NULL, shop_id = 200) db.Create(&TagModel{ID: 8, EnterpriseID: nil, ShopID: uintPtr(200), Name: "店铺200-华南区"}) // 6. 其他店铺标签(enterprise_id = NULL, shop_id = 300) db.Create(&TagModel{ID: 9, EnterpriseID: nil, ShopID: uintPtr(300), Name: "店铺300-华北区"}) // 创建 mock ShopStore // 假设店铺 100 的下级店铺包括 100 和 200 mockStore := &mockShopStore{ subordinateShopIDs: []uint{100, 200}, } return db, mockStore } // TestTagPermission_SuperAdmin 测试超级管理员查询标签(应看到所有标签) func TestTagPermission_SuperAdmin(t *testing.T) { db, mockStore := setupTagTestDB(t) // 注册 Callback err := RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置超级管理员 context ctx := context.Background() ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ UserID: 1, UserType: constants.UserTypeSuperAdmin, ShopID: 0, EnterpriseID: 0, CustomerID: 0, }) // 查询标签 var tags []TagModel err = db.WithContext(ctx).Find(&tags).Error assert.NoError(t, err) // 超级管理员应该看到所有 9 个标签 assert.Equal(t, 9, len(tags), "超级管理员应该看到所有标签") } // TestTagPermission_Platform 测试平台用户查询标签(应看到所有标签) func TestTagPermission_Platform(t *testing.T) { db, mockStore := setupTagTestDB(t) // 注册 Callback err := RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置平台用户 context ctx := context.Background() ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ UserID: 1, UserType: constants.UserTypePlatform, ShopID: 0, EnterpriseID: 0, CustomerID: 0, }) // 查询标签 var tags []TagModel err = db.WithContext(ctx).Find(&tags).Error assert.NoError(t, err) // 平台用户应该看到所有 9 个标签 assert.Equal(t, 9, len(tags), "平台用户应该看到所有标签") } // TestTagPermission_Agent 测试代理用户查询标签 // 预期:看到自己店铺标签 + 下级店铺标签 + 全局标签 func TestTagPermission_Agent(t *testing.T) { db, mockStore := setupTagTestDB(t) // 注册 Callback err := RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置代理用户 context(店铺 ID = 100) ctx := context.Background() ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ UserID: 1, UserType: constants.UserTypeAgent, ShopID: 100, EnterpriseID: 0, CustomerID: 0, }) // 查询标签 var tags []TagModel err = db.WithContext(ctx).Find(&tags).Error assert.NoError(t, err) // 代理用户应该看到: // - 2 个全局标签(ID: 1, 2) // - 2 个店铺 100 的标签(ID: 6, 7) // - 1 个店铺 200(下级)的标签(ID: 8) // 总共 5 个标签 assert.Equal(t, 5, len(tags), "代理用户应该看到自己店铺、下级店铺和全局标签") // 验证标签 ID expectedIDs := map[uint]bool{1: true, 2: true, 6: true, 7: true, 8: true} for _, tag := range tags { assert.True(t, expectedIDs[tag.ID], "标签 ID %d 不应该被代理用户看到", tag.ID) } // 验证看不到的标签 // - 企业标签(ID: 3, 4, 5) // - 其他店铺标签(ID: 9) for _, tag := range tags { assert.NotEqual(t, uint(3), tag.ID, "代理用户不应该看到企业标签") assert.NotEqual(t, uint(4), tag.ID, "代理用户不应该看到企业标签") assert.NotEqual(t, uint(5), tag.ID, "代理用户不应该看到企业标签") assert.NotEqual(t, uint(9), tag.ID, "代理用户不应该看到其他店铺标签") } } // TestTagPermission_Agent_NoShopID 测试没有 ShopID 的代理用户 // 预期:只能看到全局标签 func TestTagPermission_Agent_NoShopID(t *testing.T) { db, mockStore := setupTagTestDB(t) // 注册 Callback err := RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置代理用户 context(没有店铺 ID) ctx := context.Background() ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ UserID: 1, UserType: constants.UserTypeAgent, ShopID: 0, // 没有店铺 EnterpriseID: 0, CustomerID: 0, }) // 查询标签 var tags []TagModel err = db.WithContext(ctx).Find(&tags).Error assert.NoError(t, err) // 没有店铺的代理用户只能看到全局标签 assert.Equal(t, 2, len(tags), "没有店铺的代理用户只能看到全局标签") // 验证都是全局标签 for _, tag := range tags { assert.Nil(t, tag.EnterpriseID, "应该是全局标签,enterprise_id 为 NULL") assert.Nil(t, tag.ShopID, "应该是全局标签,shop_id 为 NULL") } } // TestTagPermission_Enterprise 测试企业用户查询标签 // 预期:看到自己企业标签 + 全局标签 func TestTagPermission_Enterprise(t *testing.T) { db, mockStore := setupTagTestDB(t) // 注册 Callback err := RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置企业用户 context(企业 ID = 1001) ctx := context.Background() ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ UserID: 1, UserType: constants.UserTypeEnterprise, ShopID: 0, EnterpriseID: 1001, CustomerID: 0, }) // 查询标签 var tags []TagModel err = db.WithContext(ctx).Find(&tags).Error assert.NoError(t, err) // 企业用户应该看到: // - 2 个全局标签(ID: 1, 2) // - 2 个企业 1001 的标签(ID: 3, 4) // 总共 4 个标签 assert.Equal(t, 4, len(tags), "企业用户应该看到自己企业和全局标签") // 验证标签 ID expectedIDs := map[uint]bool{1: true, 2: true, 3: true, 4: true} for _, tag := range tags { assert.True(t, expectedIDs[tag.ID], "标签 ID %d 不应该被企业用户看到", tag.ID) } // 验证看不到其他企业的标签 for _, tag := range tags { assert.NotEqual(t, uint(5), tag.ID, "企业用户不应该看到其他企业的标签") } // 验证看不到店铺标签 for _, tag := range tags { assert.NotEqual(t, uint(6), tag.ID, "企业用户不应该看到店铺标签") assert.NotEqual(t, uint(7), tag.ID, "企业用户不应该看到店铺标签") assert.NotEqual(t, uint(8), tag.ID, "企业用户不应该看到店铺标签") assert.NotEqual(t, uint(9), tag.ID, "企业用户不应该看到店铺标签") } } // TestTagPermission_Enterprise_NoEnterpriseID 测试没有 EnterpriseID 的企业用户 // 预期:只能看到全局标签 func TestTagPermission_Enterprise_NoEnterpriseID(t *testing.T) { db, mockStore := setupTagTestDB(t) // 注册 Callback err := RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置企业用户 context(没有企业 ID) ctx := context.Background() ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ UserID: 1, UserType: constants.UserTypeEnterprise, ShopID: 0, EnterpriseID: 0, // 没有企业 CustomerID: 0, }) // 查询标签 var tags []TagModel err = db.WithContext(ctx).Find(&tags).Error assert.NoError(t, err) // 没有企业的企业用户只能看到全局标签 assert.Equal(t, 2, len(tags), "没有企业的企业用户只能看到全局标签") // 验证都是全局标签 for _, tag := range tags { assert.Nil(t, tag.EnterpriseID, "应该是全局标签,enterprise_id 为 NULL") assert.Nil(t, tag.ShopID, "应该是全局标签,shop_id 为 NULL") } } // TestTagPermission_PersonalCustomer 测试个人客户查询标签 // 预期:只能看到全局标签 func TestTagPermission_PersonalCustomer(t *testing.T) { db, mockStore := setupTagTestDB(t) // 注册 Callback err := RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置个人客户 context ctx := context.Background() ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ UserID: 1, UserType: constants.UserTypePersonalCustomer, ShopID: 0, EnterpriseID: 0, CustomerID: 1, }) // 查询标签 var tags []TagModel err = db.WithContext(ctx).Find(&tags).Error assert.NoError(t, err) // 个人客户只能看到 2 个全局标签 assert.Equal(t, 2, len(tags), "个人客户只能看到全局标签") // 验证都是全局标签 for _, tag := range tags { assert.Nil(t, tag.EnterpriseID, "个人客户只能看到全局标签,enterprise_id 应为 NULL") assert.Nil(t, tag.ShopID, "个人客户只能看到全局标签,shop_id 应为 NULL") } } // TestTagPermission_ResourceTag_Agent 测试代理用户查询资源标签表 // 预期:与 tb_tag 表相同的过滤规则 func TestTagPermission_ResourceTag_Agent(t *testing.T) { db, mockStore := setupTagTestDB(t) // 创建资源标签测试数据 // 1. 全局资源标签 db.Create(&ResourceTagModel{ID: 1, EnterpriseID: nil, ShopID: nil, ResourceType: "iot_card", ResourceID: 101, TagID: 1}) // 2. 店铺 100 的资源标签 db.Create(&ResourceTagModel{ID: 2, EnterpriseID: nil, ShopID: uintPtr(100), ResourceType: "iot_card", ResourceID: 102, TagID: 6}) // 3. 店铺 200(下级)的资源标签 db.Create(&ResourceTagModel{ID: 3, EnterpriseID: nil, ShopID: uintPtr(200), ResourceType: "device", ResourceID: 201, TagID: 8}) // 4. 店铺 300(其他)的资源标签 db.Create(&ResourceTagModel{ID: 4, EnterpriseID: nil, ShopID: uintPtr(300), ResourceType: "device", ResourceID: 301, TagID: 9}) // 5. 企业的资源标签 db.Create(&ResourceTagModel{ID: 5, EnterpriseID: uintPtr(1001), ShopID: nil, ResourceType: "iot_card", ResourceID: 103, TagID: 3}) // 注册 Callback err := RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置代理用户 context(店铺 ID = 100) ctx := context.Background() ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{ UserID: 1, UserType: constants.UserTypeAgent, ShopID: 100, EnterpriseID: 0, CustomerID: 0, }) // 查询资源标签 var resourceTags []ResourceTagModel err = db.WithContext(ctx).Find(&resourceTags).Error assert.NoError(t, err) // 代理用户应该看到: // - 1 个全局资源标签(ID: 1) // - 1 个店铺 100 的资源标签(ID: 2) // - 1 个店铺 200(下级)的资源标签(ID: 3) // 总共 3 个 assert.Equal(t, 3, len(resourceTags), "代理用户应该看到自己店铺、下级店铺和全局的资源标签") // 验证看不到的资源标签 for _, rt := range resourceTags { assert.NotEqual(t, uint(4), rt.ID, "代理用户不应该看到其他店铺的资源标签") assert.NotEqual(t, uint(5), rt.ID, "代理用户不应该看到企业的资源标签") } } // TestTagPermission_CrossIsolation 测试跨租户隔离 // 验证企业 A 看不到企业 B 的标签 func TestTagPermission_CrossIsolation(t *testing.T) { db, mockStore := setupTagTestDB(t) // 注册 Callback err := RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 企业 A 用户(enterprise_id = 1001) ctxA := context.Background() ctxA = middleware.SetUserContext(ctxA, &middleware.UserContextInfo{ UserID: 1, UserType: constants.UserTypeEnterprise, ShopID: 0, EnterpriseID: 1001, CustomerID: 0, }) // 企业 B 用户(enterprise_id = 1002) ctxB := context.Background() ctxB = middleware.SetUserContext(ctxB, &middleware.UserContextInfo{ UserID: 2, UserType: constants.UserTypeEnterprise, ShopID: 0, EnterpriseID: 1002, CustomerID: 0, }) // 企业 A 查询标签 var tagsA []TagModel err = db.WithContext(ctxA).Find(&tagsA).Error assert.NoError(t, err) // 企业 B 查询标签 var tagsB []TagModel err = db.WithContext(ctxB).Find(&tagsB).Error assert.NoError(t, err) // 企业 A 应该看到 4 个标签(2 全局 + 2 企业 A) assert.Equal(t, 4, len(tagsA), "企业 A 应该看到 4 个标签") // 企业 B 应该看到 3 个标签(2 全局 + 1 企业 B) assert.Equal(t, 3, len(tagsB), "企业 B 应该看到 3 个标签") // 验证企业 A 看不到企业 B 的标签 for _, tag := range tagsA { if tag.EnterpriseID != nil { assert.Equal(t, uint(1001), *tag.EnterpriseID, "企业 A 不应该看到企业 B 的标签") } } // 验证企业 B 看不到企业 A 的标签 for _, tag := range tagsB { if tag.EnterpriseID != nil { assert.Equal(t, uint(1002), *tag.EnterpriseID, "企业 B 不应该看到企业 A 的标签") } } }