Files
junhong_cmp_fiber/pkg/gorm/callback_test.go
huang 2570269c8d feat(wallet,tag): 钱包和标签系统多租户改造
核心变更:
- 钱包表:删除 user_id,添加 resource_type/resource_id(绑定资源而非用户)
- 标签表:添加 enterprise_id/shop_id(实现三级隔离:全局/企业/店铺)
- GORM Callback:自动数据权限过滤
- 迁移脚本:可重复执行,已验证回滚功能

钱包归属重构原因:
- 旧设计:钱包绑定用户账号,个人客户卡/设备转手后新用户无法使用余额
- 新设计:钱包绑定资源(卡/设备/店铺),余额随资源流转

标签三级隔离:
- 平台全局标签:所有用户可见
- 企业标签:仅该企业可见(企业内唯一)
- 店铺标签:该店铺及下级可见(店铺内唯一)

测试覆盖:
- 9 个单元测试验证标签多租户过滤(全部通过)
- 迁移和回滚功能测试通过(测试环境)
- OpenSpec 验证通过

变更 ID: fix-wallet-tag-multi-tenant
迁移版本: 000008
参考: openspec/changes/archive/2026-01-13-fix-wallet-tag-multi-tenant/
2026-01-13 16:52:37 +08:00

849 lines
25 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package gorm
import (
"context"
"testing"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/stretchr/testify/assert"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// mockShopStore 模拟店铺 Store
type mockShopStore struct {
subordinateShopIDs []uint
err error
}
func (m *mockShopStore) GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error) {
if m.err != nil {
return nil, m.err
}
return m.subordinateShopIDs, nil
}
// TestSkipDataPermission 测试跳过数据权限过滤
func TestSkipDataPermission(t *testing.T) {
ctx := context.Background()
// 设置跳过标记
ctx = SkipDataPermission(ctx)
// 验证标记已设置
skip, ok := ctx.Value(SkipDataPermissionKey).(bool)
assert.True(t, ok)
assert.True(t, skip)
}
// TestRegisterDataPermissionCallback 测试注册数据权限 Callback
func TestRegisterDataPermissionCallback(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建 mock ShopStore
mockStore := &mockShopStore{
subordinateShopIDs: []uint{1, 2, 3},
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
}
// TestDataPermissionCallback_SkipForSuperAdmin 测试超级管理员跳过过滤
func TestDataPermissionCallback_SkipForSuperAdmin(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表
type TestModel struct {
ID uint
ShopID uint
Creator uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"})
db.Create(&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"})
// 创建 mock ShopStore
mockStore := &mockShopStore{
subordinateShopIDs: []uint{100}, // 只有店铺 100
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置超级管理员 context
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeSuperAdmin,
ShopID: 0,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// 超级管理员应该看到所有数据
assert.Equal(t, 2, len(results))
}
// TestDataPermissionCallback_SkipForPlatform 测试平台用户跳过过滤
func TestDataPermissionCallback_SkipForPlatform(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表
type TestModel struct {
ID uint
ShopID uint
Creator uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"})
db.Create(&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"})
// 创建 mock ShopStore
mockStore := &mockShopStore{
subordinateShopIDs: []uint{100}, // 只有店铺 100
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置平台用户 context
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypePlatform,
ShopID: 0,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// 平台用户应该看到所有数据
assert.Equal(t, 2, len(results))
}
// TestDataPermissionCallback_FilterForAgent 测试代理用户过滤
func TestDataPermissionCallback_FilterForAgent(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表(包含 shop_id 字段以触发店铺层级过滤)
type TestModel struct {
ID uint
ShopID uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, ShopID: 100, Name: "test1"})
db.Create(&TestModel{ID: 2, ShopID: 200, Name: "test2"})
db.Create(&TestModel{ID: 3, ShopID: 300, Name: "test3"})
// 创建 mock ShopStore
mockStore := &mockShopStore{
subordinateShopIDs: []uint{100, 200}, // 只能看到店铺 100 和 200
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置代理用户 context (shop_id = 100)
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeAgent,
ShopID: 100,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// 代理用户只能看到自己店铺和下级店铺的数据
assert.Equal(t, 2, len(results))
assert.Equal(t, uint(100), results[0].ShopID)
assert.Equal(t, uint(200), results[1].ShopID)
}
// TestDataPermissionCallback_SkipWithContext 测试通过 Context 跳过过滤
func TestDataPermissionCallback_SkipWithContext(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表
type TestModel struct {
ID uint
ShopID uint
Creator uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"})
db.Create(&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"})
// 创建 mock ShopStore
mockStore := &mockShopStore{
subordinateShopIDs: []uint{100}, // 只有店铺 100
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置代理用户 context 并跳过过滤
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeAgent,
ShopID: 100,
EnterpriseID: 0,
CustomerID: 0,
})
ctx = SkipDataPermission(ctx)
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// 跳过过滤后应该看到所有数据
assert.Equal(t, 2, len(results))
}
// TestDataPermissionCallback_WithShopID 测试带 shop_id 的过滤
func TestDataPermissionCallback_WithShopID(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表
type TestModel struct {
ID uint
Creator uint
ShopID uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, Creator: 1, ShopID: 100, Name: "test1"})
db.Create(&TestModel{ID: 2, Creator: 2, ShopID: 100, Name: "test2"})
db.Create(&TestModel{ID: 3, Creator: 2, ShopID: 200, Name: "test3"}) // 不同 shop_id
// 创建 mock ShopStore
mockStore := &mockShopStore{
subordinateShopIDs: []uint{100, 200}, // 可以看到店铺 100 和 200
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置代理用户 context (shop_id = 100)
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeAgent,
ShopID: 100,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// 应该看到 shop_id = 100 和 200 的所有数据(因为 mockStore 返回了这两个店铺 ID
assert.Equal(t, 3, len(results))
}
// TestDataPermissionCallback_FilterForEnterprise 测试企业用户过滤
func TestDataPermissionCallback_FilterForEnterprise(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表(包含 enterprise_id 字段)
type TestModel struct {
ID uint
EnterpriseID uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, EnterpriseID: 1001, Name: "test1"})
db.Create(&TestModel{ID: 2, EnterpriseID: 1001, Name: "test2"})
db.Create(&TestModel{ID: 3, EnterpriseID: 1002, Name: "test3"})
// 创建 mock ShopStore企业用户不需要但注册时需要
mockStore := &mockShopStore{
subordinateShopIDs: []uint{},
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置企业用户 context
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeEnterprise,
ShopID: 0,
EnterpriseID: 1001,
CustomerID: 0,
})
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// 企业用户只能看到自己企业的数据
assert.Equal(t, 2, len(results))
for _, r := range results {
assert.Equal(t, uint(1001), r.EnterpriseID)
}
}
// TestDataPermissionCallback_FilterForPersonalCustomer 测试个人客户过滤
func TestDataPermissionCallback_FilterForPersonalCustomer(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表(包含 creator 字段)
type TestModel struct {
ID uint
Creator uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"})
db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"})
db.Create(&TestModel{ID: 3, Creator: 1, Name: "test3"})
// 创建 mock ShopStore个人客户不需要但注册时需要
mockStore := &mockShopStore{
subordinateShopIDs: []uint{},
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置个人客户 context
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypePersonalCustomer,
ShopID: 0,
EnterpriseID: 0,
CustomerID: 1,
})
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// 个人客户只能看到自己创建的数据
assert.Equal(t, 2, len(results))
for _, r := range results {
assert.Equal(t, uint(1), r.Creator)
}
}
// ============================================================
// 标签表数据权限过滤测试tb_tag / tb_resource_tag 表)
// ============================================================
// TagModel 模拟标签表tb_tag结构
// 注意:必须指定 TableName 为 "tb_tag" 才能触发特殊过滤逻辑
type TagModel struct {
ID uint `gorm:"primaryKey"`
EnterpriseID *uint `gorm:"column:enterprise_id"`
ShopID *uint `gorm:"column:shop_id"`
Name string
}
func (TagModel) TableName() string {
return "tb_tag"
}
// ResourceTagModel 模拟资源标签表tb_resource_tag结构
type ResourceTagModel struct {
ID uint `gorm:"primaryKey"`
EnterpriseID *uint `gorm:"column:enterprise_id"`
ShopID *uint `gorm:"column:shop_id"`
ResourceType string
ResourceID uint
TagID uint
}
func (ResourceTagModel) TableName() string {
return "tb_resource_tag"
}
// uintPtr 辅助函数,将 uint 转换为 *uint
func uintPtr(v uint) *uint {
return &v
}
// setupTagTestDB 创建标签测试数据库和数据
// 返回db 实例和 mock ShopStore
func setupTagTestDB(t *testing.T) (*gorm.DB, *mockShopStore) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表
err = db.AutoMigrate(&TagModel{}, &ResourceTagModel{})
assert.NoError(t, err)
// 插入测试数据
// 1. 全局标签enterprise_id = NULL, shop_id = NULL
db.Create(&TagModel{ID: 1, EnterpriseID: nil, ShopID: nil, Name: "全局标签-VIP"})
db.Create(&TagModel{ID: 2, EnterpriseID: nil, ShopID: nil, Name: "全局标签-重要客户"})
// 2. 企业标签enterprise_id = 1001, shop_id = NULL
db.Create(&TagModel{ID: 3, EnterpriseID: uintPtr(1001), ShopID: nil, Name: "企业A-测试标签"})
db.Create(&TagModel{ID: 4, EnterpriseID: uintPtr(1001), ShopID: nil, Name: "企业A-内部标签"})
// 3. 另一个企业的标签enterprise_id = 1002, shop_id = NULL
db.Create(&TagModel{ID: 5, EnterpriseID: uintPtr(1002), ShopID: nil, Name: "企业B-测试标签"})
// 4. 店铺标签enterprise_id = NULL, shop_id = 100
db.Create(&TagModel{ID: 6, EnterpriseID: nil, ShopID: uintPtr(100), Name: "店铺100-华东区"})
db.Create(&TagModel{ID: 7, EnterpriseID: nil, ShopID: uintPtr(100), Name: "店铺100-大客户"})
// 5. 下级店铺标签enterprise_id = NULL, shop_id = 200
db.Create(&TagModel{ID: 8, EnterpriseID: nil, ShopID: uintPtr(200), Name: "店铺200-华南区"})
// 6. 其他店铺标签enterprise_id = NULL, shop_id = 300
db.Create(&TagModel{ID: 9, EnterpriseID: nil, ShopID: uintPtr(300), Name: "店铺300-华北区"})
// 创建 mock ShopStore
// 假设店铺 100 的下级店铺包括 100 和 200
mockStore := &mockShopStore{
subordinateShopIDs: []uint{100, 200},
}
return db, mockStore
}
// TestTagPermission_SuperAdmin 测试超级管理员查询标签(应看到所有标签)
func TestTagPermission_SuperAdmin(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置超级管理员 context
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeSuperAdmin,
ShopID: 0,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询标签
var tags []TagModel
err = db.WithContext(ctx).Find(&tags).Error
assert.NoError(t, err)
// 超级管理员应该看到所有 9 个标签
assert.Equal(t, 9, len(tags), "超级管理员应该看到所有标签")
}
// TestTagPermission_Platform 测试平台用户查询标签(应看到所有标签)
func TestTagPermission_Platform(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置平台用户 context
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypePlatform,
ShopID: 0,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询标签
var tags []TagModel
err = db.WithContext(ctx).Find(&tags).Error
assert.NoError(t, err)
// 平台用户应该看到所有 9 个标签
assert.Equal(t, 9, len(tags), "平台用户应该看到所有标签")
}
// TestTagPermission_Agent 测试代理用户查询标签
// 预期:看到自己店铺标签 + 下级店铺标签 + 全局标签
func TestTagPermission_Agent(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置代理用户 context店铺 ID = 100
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeAgent,
ShopID: 100,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询标签
var tags []TagModel
err = db.WithContext(ctx).Find(&tags).Error
assert.NoError(t, err)
// 代理用户应该看到:
// - 2 个全局标签ID: 1, 2
// - 2 个店铺 100 的标签ID: 6, 7
// - 1 个店铺 200下级的标签ID: 8
// 总共 5 个标签
assert.Equal(t, 5, len(tags), "代理用户应该看到自己店铺、下级店铺和全局标签")
// 验证标签 ID
expectedIDs := map[uint]bool{1: true, 2: true, 6: true, 7: true, 8: true}
for _, tag := range tags {
assert.True(t, expectedIDs[tag.ID], "标签 ID %d 不应该被代理用户看到", tag.ID)
}
// 验证看不到的标签
// - 企业标签ID: 3, 4, 5
// - 其他店铺标签ID: 9
for _, tag := range tags {
assert.NotEqual(t, uint(3), tag.ID, "代理用户不应该看到企业标签")
assert.NotEqual(t, uint(4), tag.ID, "代理用户不应该看到企业标签")
assert.NotEqual(t, uint(5), tag.ID, "代理用户不应该看到企业标签")
assert.NotEqual(t, uint(9), tag.ID, "代理用户不应该看到其他店铺标签")
}
}
// TestTagPermission_Agent_NoShopID 测试没有 ShopID 的代理用户
// 预期:只能看到全局标签
func TestTagPermission_Agent_NoShopID(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置代理用户 context没有店铺 ID
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeAgent,
ShopID: 0, // 没有店铺
EnterpriseID: 0,
CustomerID: 0,
})
// 查询标签
var tags []TagModel
err = db.WithContext(ctx).Find(&tags).Error
assert.NoError(t, err)
// 没有店铺的代理用户只能看到全局标签
assert.Equal(t, 2, len(tags), "没有店铺的代理用户只能看到全局标签")
// 验证都是全局标签
for _, tag := range tags {
assert.Nil(t, tag.EnterpriseID, "应该是全局标签enterprise_id 为 NULL")
assert.Nil(t, tag.ShopID, "应该是全局标签shop_id 为 NULL")
}
}
// TestTagPermission_Enterprise 测试企业用户查询标签
// 预期:看到自己企业标签 + 全局标签
func TestTagPermission_Enterprise(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置企业用户 context企业 ID = 1001
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeEnterprise,
ShopID: 0,
EnterpriseID: 1001,
CustomerID: 0,
})
// 查询标签
var tags []TagModel
err = db.WithContext(ctx).Find(&tags).Error
assert.NoError(t, err)
// 企业用户应该看到:
// - 2 个全局标签ID: 1, 2
// - 2 个企业 1001 的标签ID: 3, 4
// 总共 4 个标签
assert.Equal(t, 4, len(tags), "企业用户应该看到自己企业和全局标签")
// 验证标签 ID
expectedIDs := map[uint]bool{1: true, 2: true, 3: true, 4: true}
for _, tag := range tags {
assert.True(t, expectedIDs[tag.ID], "标签 ID %d 不应该被企业用户看到", tag.ID)
}
// 验证看不到其他企业的标签
for _, tag := range tags {
assert.NotEqual(t, uint(5), tag.ID, "企业用户不应该看到其他企业的标签")
}
// 验证看不到店铺标签
for _, tag := range tags {
assert.NotEqual(t, uint(6), tag.ID, "企业用户不应该看到店铺标签")
assert.NotEqual(t, uint(7), tag.ID, "企业用户不应该看到店铺标签")
assert.NotEqual(t, uint(8), tag.ID, "企业用户不应该看到店铺标签")
assert.NotEqual(t, uint(9), tag.ID, "企业用户不应该看到店铺标签")
}
}
// TestTagPermission_Enterprise_NoEnterpriseID 测试没有 EnterpriseID 的企业用户
// 预期:只能看到全局标签
func TestTagPermission_Enterprise_NoEnterpriseID(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置企业用户 context没有企业 ID
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeEnterprise,
ShopID: 0,
EnterpriseID: 0, // 没有企业
CustomerID: 0,
})
// 查询标签
var tags []TagModel
err = db.WithContext(ctx).Find(&tags).Error
assert.NoError(t, err)
// 没有企业的企业用户只能看到全局标签
assert.Equal(t, 2, len(tags), "没有企业的企业用户只能看到全局标签")
// 验证都是全局标签
for _, tag := range tags {
assert.Nil(t, tag.EnterpriseID, "应该是全局标签enterprise_id 为 NULL")
assert.Nil(t, tag.ShopID, "应该是全局标签shop_id 为 NULL")
}
}
// TestTagPermission_PersonalCustomer 测试个人客户查询标签
// 预期:只能看到全局标签
func TestTagPermission_PersonalCustomer(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置个人客户 context
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypePersonalCustomer,
ShopID: 0,
EnterpriseID: 0,
CustomerID: 1,
})
// 查询标签
var tags []TagModel
err = db.WithContext(ctx).Find(&tags).Error
assert.NoError(t, err)
// 个人客户只能看到 2 个全局标签
assert.Equal(t, 2, len(tags), "个人客户只能看到全局标签")
// 验证都是全局标签
for _, tag := range tags {
assert.Nil(t, tag.EnterpriseID, "个人客户只能看到全局标签enterprise_id 应为 NULL")
assert.Nil(t, tag.ShopID, "个人客户只能看到全局标签shop_id 应为 NULL")
}
}
// TestTagPermission_ResourceTag_Agent 测试代理用户查询资源标签表
// 预期:与 tb_tag 表相同的过滤规则
func TestTagPermission_ResourceTag_Agent(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 创建资源标签测试数据
// 1. 全局资源标签
db.Create(&ResourceTagModel{ID: 1, EnterpriseID: nil, ShopID: nil, ResourceType: "iot_card", ResourceID: 101, TagID: 1})
// 2. 店铺 100 的资源标签
db.Create(&ResourceTagModel{ID: 2, EnterpriseID: nil, ShopID: uintPtr(100), ResourceType: "iot_card", ResourceID: 102, TagID: 6})
// 3. 店铺 200下级的资源标签
db.Create(&ResourceTagModel{ID: 3, EnterpriseID: nil, ShopID: uintPtr(200), ResourceType: "device", ResourceID: 201, TagID: 8})
// 4. 店铺 300其他的资源标签
db.Create(&ResourceTagModel{ID: 4, EnterpriseID: nil, ShopID: uintPtr(300), ResourceType: "device", ResourceID: 301, TagID: 9})
// 5. 企业的资源标签
db.Create(&ResourceTagModel{ID: 5, EnterpriseID: uintPtr(1001), ShopID: nil, ResourceType: "iot_card", ResourceID: 103, TagID: 3})
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置代理用户 context店铺 ID = 100
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeAgent,
ShopID: 100,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询资源标签
var resourceTags []ResourceTagModel
err = db.WithContext(ctx).Find(&resourceTags).Error
assert.NoError(t, err)
// 代理用户应该看到:
// - 1 个全局资源标签ID: 1
// - 1 个店铺 100 的资源标签ID: 2
// - 1 个店铺 200下级的资源标签ID: 3
// 总共 3 个
assert.Equal(t, 3, len(resourceTags), "代理用户应该看到自己店铺、下级店铺和全局的资源标签")
// 验证看不到的资源标签
for _, rt := range resourceTags {
assert.NotEqual(t, uint(4), rt.ID, "代理用户不应该看到其他店铺的资源标签")
assert.NotEqual(t, uint(5), rt.ID, "代理用户不应该看到企业的资源标签")
}
}
// TestTagPermission_CrossIsolation 测试跨租户隔离
// 验证企业 A 看不到企业 B 的标签
func TestTagPermission_CrossIsolation(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 企业 A 用户enterprise_id = 1001
ctxA := context.Background()
ctxA = middleware.SetUserContext(ctxA, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeEnterprise,
ShopID: 0,
EnterpriseID: 1001,
CustomerID: 0,
})
// 企业 B 用户enterprise_id = 1002
ctxB := context.Background()
ctxB = middleware.SetUserContext(ctxB, &middleware.UserContextInfo{
UserID: 2,
UserType: constants.UserTypeEnterprise,
ShopID: 0,
EnterpriseID: 1002,
CustomerID: 0,
})
// 企业 A 查询标签
var tagsA []TagModel
err = db.WithContext(ctxA).Find(&tagsA).Error
assert.NoError(t, err)
// 企业 B 查询标签
var tagsB []TagModel
err = db.WithContext(ctxB).Find(&tagsB).Error
assert.NoError(t, err)
// 企业 A 应该看到 4 个标签2 全局 + 2 企业 A
assert.Equal(t, 4, len(tagsA), "企业 A 应该看到 4 个标签")
// 企业 B 应该看到 3 个标签2 全局 + 1 企业 B
assert.Equal(t, 3, len(tagsB), "企业 B 应该看到 3 个标签")
// 验证企业 A 看不到企业 B 的标签
for _, tag := range tagsA {
if tag.EnterpriseID != nil {
assert.Equal(t, uint(1001), *tag.EnterpriseID, "企业 A 不应该看到企业 B 的标签")
}
}
// 验证企业 B 看不到企业 A 的标签
for _, tag := range tagsB {
if tag.EnterpriseID != nil {
assert.Equal(t, uint(1002), *tag.EnterpriseID, "企业 B 不应该看到企业 A 的标签")
}
}
}