Files
junhong_cmp_fiber/pkg/gorm/callback_test.go
huang fdcff33058
All checks were successful
构建并部署到测试环境(无 SSH) / build-and-deploy (push) Successful in 5m9s
feat: 实现企业卡授权和授权记录管理功能
主要功能:
- 添加企业卡授权/回收接口 (POST /enterprises/:id/allocate-cards, recall-cards)
- 添加授权记录管理接口 (GET/PUT /authorizations)
- 实现代理用户数据权限过滤(只能查看自己店铺下企业的授权记录)
- 添加 GORM callback 支持授权记录表的数据权限过滤

技术改进:
- 原生 SQL 查询手动添加数据权限过滤(ListWithJoin, GetByIDWithJoin)
- 移除卡授权预检接口(allocate-cards/preview),保留内部方法
- 完善单元测试和集成测试覆盖
2026-01-26 15:07:03 +08:00

1175 lines
37 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package gorm
import (
"context"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/stretchr/testify/assert"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// mockShopStore 模拟店铺 Store
type mockShopStore struct {
subordinateShopIDs []uint
err error
}
func (m *mockShopStore) GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error) {
if m.err != nil {
return nil, m.err
}
return m.subordinateShopIDs, nil
}
// TestSkipDataPermission 测试跳过数据权限过滤
func TestSkipDataPermission(t *testing.T) {
ctx := context.Background()
// 设置跳过标记
ctx = SkipDataPermission(ctx)
// 验证标记已设置
skip, ok := ctx.Value(SkipDataPermissionKey).(bool)
assert.True(t, ok)
assert.True(t, skip)
}
// TestRegisterDataPermissionCallback 测试注册数据权限 Callback
func TestRegisterDataPermissionCallback(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建 mock ShopStore
mockStore := &mockShopStore{
subordinateShopIDs: []uint{1, 2, 3},
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
}
// TestDataPermissionCallback_SkipForSuperAdmin 测试超级管理员跳过过滤
func TestDataPermissionCallback_SkipForSuperAdmin(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表
type TestModel struct {
ID uint
ShopID uint
Creator uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"})
db.Create(&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"})
// 创建 mock ShopStore
mockStore := &mockShopStore{
subordinateShopIDs: []uint{100}, // 只有店铺 100
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置超级管理员 context
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeSuperAdmin,
ShopID: 0,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// 超级管理员应该看到所有数据
assert.Equal(t, 2, len(results))
}
// TestDataPermissionCallback_SkipForPlatform 测试平台用户跳过过滤
func TestDataPermissionCallback_SkipForPlatform(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表
type TestModel struct {
ID uint
ShopID uint
Creator uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"})
db.Create(&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"})
// 创建 mock ShopStore
mockStore := &mockShopStore{
subordinateShopIDs: []uint{100}, // 只有店铺 100
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置平台用户 context
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypePlatform,
ShopID: 0,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// 平台用户应该看到所有数据
assert.Equal(t, 2, len(results))
}
// TestDataPermissionCallback_FilterForAgent 测试代理用户过滤
func TestDataPermissionCallback_FilterForAgent(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表(包含 shop_id 字段以触发店铺层级过滤)
type TestModel struct {
ID uint
ShopID uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, ShopID: 100, Name: "test1"})
db.Create(&TestModel{ID: 2, ShopID: 200, Name: "test2"})
db.Create(&TestModel{ID: 3, ShopID: 300, Name: "test3"})
// 创建 mock ShopStore
mockStore := &mockShopStore{
subordinateShopIDs: []uint{100, 200}, // 只能看到店铺 100 和 200
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置代理用户 context (shop_id = 100)
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeAgent,
ShopID: 100,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// 代理用户只能看到自己店铺和下级店铺的数据
assert.Equal(t, 2, len(results))
assert.Equal(t, uint(100), results[0].ShopID)
assert.Equal(t, uint(200), results[1].ShopID)
}
// TestDataPermissionCallback_SkipWithContext 测试通过 Context 跳过过滤
func TestDataPermissionCallback_SkipWithContext(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表
type TestModel struct {
ID uint
ShopID uint
Creator uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"})
db.Create(&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"})
// 创建 mock ShopStore
mockStore := &mockShopStore{
subordinateShopIDs: []uint{100}, // 只有店铺 100
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置代理用户 context 并跳过过滤
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeAgent,
ShopID: 100,
EnterpriseID: 0,
CustomerID: 0,
})
ctx = SkipDataPermission(ctx)
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// 跳过过滤后应该看到所有数据
assert.Equal(t, 2, len(results))
}
// TestDataPermissionCallback_WithShopID 测试带 shop_id 的过滤
func TestDataPermissionCallback_WithShopID(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表
type TestModel struct {
ID uint
Creator uint
ShopID uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, Creator: 1, ShopID: 100, Name: "test1"})
db.Create(&TestModel{ID: 2, Creator: 2, ShopID: 100, Name: "test2"})
db.Create(&TestModel{ID: 3, Creator: 2, ShopID: 200, Name: "test3"}) // 不同 shop_id
// 创建 mock ShopStore
mockStore := &mockShopStore{
subordinateShopIDs: []uint{100, 200}, // 可以看到店铺 100 和 200
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置代理用户 context (shop_id = 100)
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeAgent,
ShopID: 100,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// 应该看到 shop_id = 100 和 200 的所有数据(因为 mockStore 返回了这两个店铺 ID
assert.Equal(t, 3, len(results))
}
// TestDataPermissionCallback_FilterForEnterprise 测试企业用户过滤
func TestDataPermissionCallback_FilterForEnterprise(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表(包含 enterprise_id 字段)
type TestModel struct {
ID uint
EnterpriseID uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, EnterpriseID: 1001, Name: "test1"})
db.Create(&TestModel{ID: 2, EnterpriseID: 1001, Name: "test2"})
db.Create(&TestModel{ID: 3, EnterpriseID: 1002, Name: "test3"})
// 创建 mock ShopStore企业用户不需要但注册时需要
mockStore := &mockShopStore{
subordinateShopIDs: []uint{},
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置企业用户 context
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeEnterprise,
ShopID: 0,
EnterpriseID: 1001,
CustomerID: 0,
})
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// 企业用户只能看到自己企业的数据
assert.Equal(t, 2, len(results))
for _, r := range results {
assert.Equal(t, uint(1001), r.EnterpriseID)
}
}
// TestDataPermissionCallback_FilterForPersonalCustomer 测试个人客户过滤
func TestDataPermissionCallback_FilterForPersonalCustomer(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表(包含 creator 字段)
type TestModel struct {
ID uint
Creator uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"})
db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"})
db.Create(&TestModel{ID: 3, Creator: 1, Name: "test3"})
// 创建 mock ShopStore个人客户不需要但注册时需要
mockStore := &mockShopStore{
subordinateShopIDs: []uint{},
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置个人客户 context
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypePersonalCustomer,
ShopID: 0,
EnterpriseID: 0,
CustomerID: 1,
})
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// 个人客户只能看到自己创建的数据
assert.Equal(t, 2, len(results))
for _, r := range results {
assert.Equal(t, uint(1), r.Creator)
}
}
// ============================================================
// 标签表数据权限过滤测试tb_tag / tb_resource_tag 表)
// ============================================================
// TagModel 模拟标签表tb_tag结构
// 注意:必须指定 TableName 为 "tb_tag" 才能触发特殊过滤逻辑
type TagModel struct {
ID uint `gorm:"primaryKey"`
EnterpriseID *uint `gorm:"column:enterprise_id"`
ShopID *uint `gorm:"column:shop_id"`
Name string
}
func (TagModel) TableName() string {
return "tb_tag"
}
// ResourceTagModel 模拟资源标签表tb_resource_tag结构
type ResourceTagModel struct {
ID uint `gorm:"primaryKey"`
EnterpriseID *uint `gorm:"column:enterprise_id"`
ShopID *uint `gorm:"column:shop_id"`
ResourceType string
ResourceID uint
TagID uint
}
func (ResourceTagModel) TableName() string {
return "tb_resource_tag"
}
// uintPtr 辅助函数,将 uint 转换为 *uint
func uintPtr(v uint) *uint {
return &v
}
// setupTagTestDB 创建标签测试数据库和数据
// 返回db 实例和 mock ShopStore
func setupTagTestDB(t *testing.T) (*gorm.DB, *mockShopStore) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表
err = db.AutoMigrate(&TagModel{}, &ResourceTagModel{})
assert.NoError(t, err)
// 插入测试数据
// 1. 全局标签enterprise_id = NULL, shop_id = NULL
db.Create(&TagModel{ID: 1, EnterpriseID: nil, ShopID: nil, Name: "全局标签-VIP"})
db.Create(&TagModel{ID: 2, EnterpriseID: nil, ShopID: nil, Name: "全局标签-重要客户"})
// 2. 企业标签enterprise_id = 1001, shop_id = NULL
db.Create(&TagModel{ID: 3, EnterpriseID: uintPtr(1001), ShopID: nil, Name: "企业A-测试标签"})
db.Create(&TagModel{ID: 4, EnterpriseID: uintPtr(1001), ShopID: nil, Name: "企业A-内部标签"})
// 3. 另一个企业的标签enterprise_id = 1002, shop_id = NULL
db.Create(&TagModel{ID: 5, EnterpriseID: uintPtr(1002), ShopID: nil, Name: "企业B-测试标签"})
// 4. 店铺标签enterprise_id = NULL, shop_id = 100
db.Create(&TagModel{ID: 6, EnterpriseID: nil, ShopID: uintPtr(100), Name: "店铺100-华东区"})
db.Create(&TagModel{ID: 7, EnterpriseID: nil, ShopID: uintPtr(100), Name: "店铺100-大客户"})
// 5. 下级店铺标签enterprise_id = NULL, shop_id = 200
db.Create(&TagModel{ID: 8, EnterpriseID: nil, ShopID: uintPtr(200), Name: "店铺200-华南区"})
// 6. 其他店铺标签enterprise_id = NULL, shop_id = 300
db.Create(&TagModel{ID: 9, EnterpriseID: nil, ShopID: uintPtr(300), Name: "店铺300-华北区"})
// 创建 mock ShopStore
// 假设店铺 100 的下级店铺包括 100 和 200
mockStore := &mockShopStore{
subordinateShopIDs: []uint{100, 200},
}
return db, mockStore
}
// TestTagPermission_SuperAdmin 测试超级管理员查询标签(应看到所有标签)
func TestTagPermission_SuperAdmin(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置超级管理员 context
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeSuperAdmin,
ShopID: 0,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询标签
var tags []TagModel
err = db.WithContext(ctx).Find(&tags).Error
assert.NoError(t, err)
// 超级管理员应该看到所有 9 个标签
assert.Equal(t, 9, len(tags), "超级管理员应该看到所有标签")
}
// TestTagPermission_Platform 测试平台用户查询标签(应看到所有标签)
func TestTagPermission_Platform(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置平台用户 context
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypePlatform,
ShopID: 0,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询标签
var tags []TagModel
err = db.WithContext(ctx).Find(&tags).Error
assert.NoError(t, err)
// 平台用户应该看到所有 9 个标签
assert.Equal(t, 9, len(tags), "平台用户应该看到所有标签")
}
// TestTagPermission_Agent 测试代理用户查询标签
// 预期:看到自己店铺标签 + 下级店铺标签 + 全局标签
func TestTagPermission_Agent(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置代理用户 context店铺 ID = 100
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeAgent,
ShopID: 100,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询标签
var tags []TagModel
err = db.WithContext(ctx).Find(&tags).Error
assert.NoError(t, err)
// 代理用户应该看到:
// - 2 个全局标签ID: 1, 2
// - 2 个店铺 100 的标签ID: 6, 7
// - 1 个店铺 200下级的标签ID: 8
// 总共 5 个标签
assert.Equal(t, 5, len(tags), "代理用户应该看到自己店铺、下级店铺和全局标签")
// 验证标签 ID
expectedIDs := map[uint]bool{1: true, 2: true, 6: true, 7: true, 8: true}
for _, tag := range tags {
assert.True(t, expectedIDs[tag.ID], "标签 ID %d 不应该被代理用户看到", tag.ID)
}
// 验证看不到的标签
// - 企业标签ID: 3, 4, 5
// - 其他店铺标签ID: 9
for _, tag := range tags {
assert.NotEqual(t, uint(3), tag.ID, "代理用户不应该看到企业标签")
assert.NotEqual(t, uint(4), tag.ID, "代理用户不应该看到企业标签")
assert.NotEqual(t, uint(5), tag.ID, "代理用户不应该看到企业标签")
assert.NotEqual(t, uint(9), tag.ID, "代理用户不应该看到其他店铺标签")
}
}
// TestTagPermission_Agent_NoShopID 测试没有 ShopID 的代理用户
// 预期:只能看到全局标签
func TestTagPermission_Agent_NoShopID(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置代理用户 context没有店铺 ID
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeAgent,
ShopID: 0, // 没有店铺
EnterpriseID: 0,
CustomerID: 0,
})
// 查询标签
var tags []TagModel
err = db.WithContext(ctx).Find(&tags).Error
assert.NoError(t, err)
// 没有店铺的代理用户只能看到全局标签
assert.Equal(t, 2, len(tags), "没有店铺的代理用户只能看到全局标签")
// 验证都是全局标签
for _, tag := range tags {
assert.Nil(t, tag.EnterpriseID, "应该是全局标签enterprise_id 为 NULL")
assert.Nil(t, tag.ShopID, "应该是全局标签shop_id 为 NULL")
}
}
// TestTagPermission_Enterprise 测试企业用户查询标签
// 预期:看到自己企业标签 + 全局标签
func TestTagPermission_Enterprise(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置企业用户 context企业 ID = 1001
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeEnterprise,
ShopID: 0,
EnterpriseID: 1001,
CustomerID: 0,
})
// 查询标签
var tags []TagModel
err = db.WithContext(ctx).Find(&tags).Error
assert.NoError(t, err)
// 企业用户应该看到:
// - 2 个全局标签ID: 1, 2
// - 2 个企业 1001 的标签ID: 3, 4
// 总共 4 个标签
assert.Equal(t, 4, len(tags), "企业用户应该看到自己企业和全局标签")
// 验证标签 ID
expectedIDs := map[uint]bool{1: true, 2: true, 3: true, 4: true}
for _, tag := range tags {
assert.True(t, expectedIDs[tag.ID], "标签 ID %d 不应该被企业用户看到", tag.ID)
}
// 验证看不到其他企业的标签
for _, tag := range tags {
assert.NotEqual(t, uint(5), tag.ID, "企业用户不应该看到其他企业的标签")
}
// 验证看不到店铺标签
for _, tag := range tags {
assert.NotEqual(t, uint(6), tag.ID, "企业用户不应该看到店铺标签")
assert.NotEqual(t, uint(7), tag.ID, "企业用户不应该看到店铺标签")
assert.NotEqual(t, uint(8), tag.ID, "企业用户不应该看到店铺标签")
assert.NotEqual(t, uint(9), tag.ID, "企业用户不应该看到店铺标签")
}
}
// TestTagPermission_Enterprise_NoEnterpriseID 测试没有 EnterpriseID 的企业用户
// 预期:只能看到全局标签
func TestTagPermission_Enterprise_NoEnterpriseID(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置企业用户 context没有企业 ID
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeEnterprise,
ShopID: 0,
EnterpriseID: 0, // 没有企业
CustomerID: 0,
})
// 查询标签
var tags []TagModel
err = db.WithContext(ctx).Find(&tags).Error
assert.NoError(t, err)
// 没有企业的企业用户只能看到全局标签
assert.Equal(t, 2, len(tags), "没有企业的企业用户只能看到全局标签")
// 验证都是全局标签
for _, tag := range tags {
assert.Nil(t, tag.EnterpriseID, "应该是全局标签enterprise_id 为 NULL")
assert.Nil(t, tag.ShopID, "应该是全局标签shop_id 为 NULL")
}
}
// TestTagPermission_PersonalCustomer 测试个人客户查询标签
// 预期:只能看到全局标签
func TestTagPermission_PersonalCustomer(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置个人客户 context
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypePersonalCustomer,
ShopID: 0,
EnterpriseID: 0,
CustomerID: 1,
})
// 查询标签
var tags []TagModel
err = db.WithContext(ctx).Find(&tags).Error
assert.NoError(t, err)
// 个人客户只能看到 2 个全局标签
assert.Equal(t, 2, len(tags), "个人客户只能看到全局标签")
// 验证都是全局标签
for _, tag := range tags {
assert.Nil(t, tag.EnterpriseID, "个人客户只能看到全局标签enterprise_id 应为 NULL")
assert.Nil(t, tag.ShopID, "个人客户只能看到全局标签shop_id 应为 NULL")
}
}
// TestTagPermission_ResourceTag_Agent 测试代理用户查询资源标签表
// 预期:与 tb_tag 表相同的过滤规则
func TestTagPermission_ResourceTag_Agent(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 创建资源标签测试数据
// 1. 全局资源标签
db.Create(&ResourceTagModel{ID: 1, EnterpriseID: nil, ShopID: nil, ResourceType: "iot_card", ResourceID: 101, TagID: 1})
// 2. 店铺 100 的资源标签
db.Create(&ResourceTagModel{ID: 2, EnterpriseID: nil, ShopID: uintPtr(100), ResourceType: "iot_card", ResourceID: 102, TagID: 6})
// 3. 店铺 200下级的资源标签
db.Create(&ResourceTagModel{ID: 3, EnterpriseID: nil, ShopID: uintPtr(200), ResourceType: "device", ResourceID: 201, TagID: 8})
// 4. 店铺 300其他的资源标签
db.Create(&ResourceTagModel{ID: 4, EnterpriseID: nil, ShopID: uintPtr(300), ResourceType: "device", ResourceID: 301, TagID: 9})
// 5. 企业的资源标签
db.Create(&ResourceTagModel{ID: 5, EnterpriseID: uintPtr(1001), ShopID: nil, ResourceType: "iot_card", ResourceID: 103, TagID: 3})
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置代理用户 context店铺 ID = 100
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeAgent,
ShopID: 100,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询资源标签
var resourceTags []ResourceTagModel
err = db.WithContext(ctx).Find(&resourceTags).Error
assert.NoError(t, err)
// 代理用户应该看到:
// - 1 个全局资源标签ID: 1
// - 1 个店铺 100 的资源标签ID: 2
// - 1 个店铺 200下级的资源标签ID: 3
// 总共 3 个
assert.Equal(t, 3, len(resourceTags), "代理用户应该看到自己店铺、下级店铺和全局的资源标签")
// 验证看不到的资源标签
for _, rt := range resourceTags {
assert.NotEqual(t, uint(4), rt.ID, "代理用户不应该看到其他店铺的资源标签")
assert.NotEqual(t, uint(5), rt.ID, "代理用户不应该看到企业的资源标签")
}
}
// TestTagPermission_CrossIsolation 测试跨租户隔离
// 验证企业 A 看不到企业 B 的标签
func TestTagPermission_CrossIsolation(t *testing.T) {
db, mockStore := setupTagTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 企业 A 用户enterprise_id = 1001
ctxA := context.Background()
ctxA = middleware.SetUserContext(ctxA, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeEnterprise,
ShopID: 0,
EnterpriseID: 1001,
CustomerID: 0,
})
// 企业 B 用户enterprise_id = 1002
ctxB := context.Background()
ctxB = middleware.SetUserContext(ctxB, &middleware.UserContextInfo{
UserID: 2,
UserType: constants.UserTypeEnterprise,
ShopID: 0,
EnterpriseID: 1002,
CustomerID: 0,
})
// 企业 A 查询标签
var tagsA []TagModel
err = db.WithContext(ctxA).Find(&tagsA).Error
assert.NoError(t, err)
// 企业 B 查询标签
var tagsB []TagModel
err = db.WithContext(ctxB).Find(&tagsB).Error
assert.NoError(t, err)
// 企业 A 应该看到 4 个标签2 全局 + 2 企业 A
assert.Equal(t, 4, len(tagsA), "企业 A 应该看到 4 个标签")
// 企业 B 应该看到 3 个标签2 全局 + 1 企业 B
assert.Equal(t, 3, len(tagsB), "企业 B 应该看到 3 个标签")
// 验证企业 A 看不到企业 B 的标签
for _, tag := range tagsA {
if tag.EnterpriseID != nil {
assert.Equal(t, uint(1001), *tag.EnterpriseID, "企业 A 不应该看到企业 B 的标签")
}
}
// 验证企业 B 看不到企业 A 的标签
for _, tag := range tagsB {
if tag.EnterpriseID != nil {
assert.Equal(t, uint(1002), *tag.EnterpriseID, "企业 B 不应该看到企业 A 的标签")
}
}
}
// ============================================================
// 企业卡授权表数据权限过滤测试tb_enterprise_card_authorization 表)
// ============================================================
// EnterpriseModel 模拟企业表,用于授权表过滤测试
type EnterpriseModel struct {
ID uint `gorm:"primaryKey"`
OwnerShopID *uint `gorm:"column:owner_shop_id"`
DeletedAt *time.Time `gorm:"column:deleted_at"`
Name string
}
func (EnterpriseModel) TableName() string {
return "tb_enterprise"
}
// AuthorizationModel 模拟企业卡授权表结构
type AuthorizationModel struct {
ID uint `gorm:"primaryKey"`
EnterpriseID uint `gorm:"column:enterprise_id"`
CardID uint `gorm:"column:card_id"`
AuthorizedBy uint `gorm:"column:authorized_by"`
AuthorizedAt time.Time `gorm:"column:authorized_at"`
AuthorizerType int `gorm:"column:authorizer_type"`
RevokedBy *uint `gorm:"column:revoked_by"`
RevokedAt *time.Time `gorm:"column:revoked_at"`
Remark string `gorm:"column:remark"`
}
func (AuthorizationModel) TableName() string {
return "tb_enterprise_card_authorization"
}
// setupAuthorizationTestDB 创建授权表测试数据库和数据
func setupAuthorizationTestDB(t *testing.T) (*gorm.DB, *mockShopStore) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表
err = db.AutoMigrate(&EnterpriseModel{}, &AuthorizationModel{})
assert.NoError(t, err)
// 插入企业测试数据
// 1. 店铺 100 下的企业
db.Create(&EnterpriseModel{ID: 1, OwnerShopID: uintPtr(100), Name: "企业A-店铺100"})
db.Create(&EnterpriseModel{ID: 2, OwnerShopID: uintPtr(100), Name: "企业B-店铺100"})
// 2. 店铺 200店铺100的下级下的企业
db.Create(&EnterpriseModel{ID: 3, OwnerShopID: uintPtr(200), Name: "企业C-店铺200"})
// 3. 店铺 300其他店铺下的企业
db.Create(&EnterpriseModel{ID: 4, OwnerShopID: uintPtr(300), Name: "企业D-店铺300"})
// 4. 平台直属企业(无店铺归属)
db.Create(&EnterpriseModel{ID: 5, OwnerShopID: nil, Name: "企业E-平台直属"})
now := time.Now()
// 插入授权记录测试数据
// 1. 企业1的授权记录店铺100
db.Create(&AuthorizationModel{ID: 1, EnterpriseID: 1, CardID: 101, AuthorizedBy: 1, AuthorizedAt: now, AuthorizerType: 3})
db.Create(&AuthorizationModel{ID: 2, EnterpriseID: 1, CardID: 102, AuthorizedBy: 1, AuthorizedAt: now, AuthorizerType: 3})
// 2. 企业2的授权记录店铺100
db.Create(&AuthorizationModel{ID: 3, EnterpriseID: 2, CardID: 201, AuthorizedBy: 1, AuthorizedAt: now, AuthorizerType: 3})
// 3. 企业3的授权记录店铺200 - 下级店铺)
db.Create(&AuthorizationModel{ID: 4, EnterpriseID: 3, CardID: 301, AuthorizedBy: 2, AuthorizedAt: now, AuthorizerType: 3})
// 4. 企业4的授权记录店铺300 - 其他店铺)
db.Create(&AuthorizationModel{ID: 5, EnterpriseID: 4, CardID: 401, AuthorizedBy: 3, AuthorizedAt: now, AuthorizerType: 3})
db.Create(&AuthorizationModel{ID: 6, EnterpriseID: 4, CardID: 402, AuthorizedBy: 3, AuthorizedAt: now, AuthorizerType: 3})
// 5. 企业5的授权记录平台直属
db.Create(&AuthorizationModel{ID: 7, EnterpriseID: 5, CardID: 501, AuthorizedBy: 1, AuthorizedAt: now, AuthorizerType: 2})
// 创建 mock ShopStore
// 店铺 100 的下级店铺包括 100 和 200不含 300
mockStore := &mockShopStore{
subordinateShopIDs: []uint{100, 200},
}
return db, mockStore
}
// TestAuthorizationPermission_SuperAdmin 测试超级管理员查询授权记录(应看到所有记录)
func TestAuthorizationPermission_SuperAdmin(t *testing.T) {
db, mockStore := setupAuthorizationTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置超级管理员 context
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeSuperAdmin,
ShopID: 0,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询授权记录
var auths []AuthorizationModel
err = db.WithContext(ctx).Find(&auths).Error
assert.NoError(t, err)
// 超级管理员应该看到所有 7 条记录
assert.Equal(t, 7, len(auths), "超级管理员应该看到所有授权记录")
}
// TestAuthorizationPermission_Platform 测试平台用户查询授权记录(应看到所有记录)
func TestAuthorizationPermission_Platform(t *testing.T) {
db, mockStore := setupAuthorizationTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置平台用户 context
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypePlatform,
ShopID: 0,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询授权记录
var auths []AuthorizationModel
err = db.WithContext(ctx).Find(&auths).Error
assert.NoError(t, err)
// 平台用户应该看到所有 7 条记录
assert.Equal(t, 7, len(auths), "平台用户应该看到所有授权记录")
}
// TestAuthorizationPermission_Agent_OwnShopOnly 测试代理用户查询授权记录
// 关键业务规则:代理只能看到自己店铺下企业的授权记录,不含下级店铺
func TestAuthorizationPermission_Agent_OwnShopOnly(t *testing.T) {
db, mockStore := setupAuthorizationTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置代理用户 context店铺 ID = 100
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeAgent,
ShopID: 100,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询授权记录
var auths []AuthorizationModel
err = db.WithContext(ctx).Find(&auths).Error
assert.NoError(t, err)
// 代理用户店铺100应该只看到
// - 企业1的2条授权记录ID: 1, 2
// - 企业2的1条授权记录ID: 3
// 总共 3 条记录
// 注意不含下级店铺200的记录ID: 4这是关键业务规则
assert.Equal(t, 3, len(auths), "代理用户应该只看到自己店铺下企业的授权记录(不含下级店铺)")
// 验证授权记录 ID
expectedIDs := map[uint]bool{1: true, 2: true, 3: true}
for _, auth := range auths {
assert.True(t, expectedIDs[auth.ID], "授权记录 ID %d 不应该被代理用户看到", auth.ID)
}
// 验证看不到下级店铺的记录
for _, auth := range auths {
assert.NotEqual(t, uint(4), auth.ID, "代理用户不应该看到下级店铺的授权记录")
}
// 验证看不到其他店铺的记录
for _, auth := range auths {
assert.NotEqual(t, uint(5), auth.ID, "代理用户不应该看到其他店铺的授权记录")
assert.NotEqual(t, uint(6), auth.ID, "代理用户不应该看到其他店铺的授权记录")
}
// 验证看不到平台直属企业的记录
for _, auth := range auths {
assert.NotEqual(t, uint(7), auth.ID, "代理用户不应该看到平台直属企业的授权记录")
}
}
// TestAuthorizationPermission_Agent_SubordinateShop 测试下级店铺代理查询授权记录
// 验证下级店铺代理只能看到自己店铺下企业的授权记录
func TestAuthorizationPermission_Agent_SubordinateShop(t *testing.T) {
db, _ := setupAuthorizationTestDB(t)
// 创建 mock ShopStore店铺 200 只能看到自己
mockStore := &mockShopStore{
subordinateShopIDs: []uint{200},
}
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置代理用户 context店铺 ID = 200是店铺100的下级
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 2,
UserType: constants.UserTypeAgent,
ShopID: 200,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询授权记录
var auths []AuthorizationModel
err = db.WithContext(ctx).Find(&auths).Error
assert.NoError(t, err)
// 店铺200的代理用户应该只看到
// - 企业3的1条授权记录ID: 4
// 总共 1 条记录
assert.Equal(t, 1, len(auths), "下级店铺代理应该只看到自己店铺下企业的授权记录")
// 验证授权记录 ID
assert.Equal(t, uint(4), auths[0].ID, "应该是企业3的授权记录")
}
// TestAuthorizationPermission_Agent_NoShopID 测试没有 ShopID 的代理用户
// 预期:返回空结果
func TestAuthorizationPermission_Agent_NoShopID(t *testing.T) {
db, mockStore := setupAuthorizationTestDB(t)
// 注册 Callback
err := RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置代理用户 context没有店铺 ID
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeAgent,
ShopID: 0, // 没有店铺
EnterpriseID: 0,
CustomerID: 0,
})
// 查询授权记录
var auths []AuthorizationModel
err = db.WithContext(ctx).Find(&auths).Error
assert.NoError(t, err)
// 没有店铺的代理用户应该看不到任何记录
assert.Equal(t, 0, len(auths), "没有店铺的代理用户应该看不到任何授权记录")
}
// TestAuthorizationPermission_Agent_CrossShopIsolation 测试跨店铺隔离
// 验证店铺 A 看不到店铺 B 的授权记录
func TestAuthorizationPermission_Agent_CrossShopIsolation(t *testing.T) {
db, _ := setupAuthorizationTestDB(t)
// 店铺 100 的 mock
mockStore100 := &mockShopStore{
subordinateShopIDs: []uint{100},
}
// 店铺 300 的 mock
mockStore300 := &mockShopStore{
subordinateShopIDs: []uint{300},
}
// 注册 Callback使用店铺100的mock
err := RegisterDataPermissionCallback(db, mockStore100)
assert.NoError(t, err)
// 店铺 100 代理用户
ctx100 := context.Background()
ctx100 = middleware.SetUserContext(ctx100, &middleware.UserContextInfo{
UserID: 1,
UserType: constants.UserTypeAgent,
ShopID: 100,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询店铺100的授权记录
var auths100 []AuthorizationModel
err = db.WithContext(ctx100).Find(&auths100).Error
assert.NoError(t, err)
// 店铺100应该看到3条记录企业1和企业2的
assert.Equal(t, 3, len(auths100), "店铺100应该看到自己店铺下企业的授权记录")
// 重新创建数据库并注册店铺300的 Callback
db2, _ := setupAuthorizationTestDB(t)
err = RegisterDataPermissionCallback(db2, mockStore300)
assert.NoError(t, err)
// 店铺 300 代理用户
ctx300 := context.Background()
ctx300 = middleware.SetUserContext(ctx300, &middleware.UserContextInfo{
UserID: 3,
UserType: constants.UserTypeAgent,
ShopID: 300,
EnterpriseID: 0,
CustomerID: 0,
})
// 查询店铺300的授权记录
var auths300 []AuthorizationModel
err = db2.WithContext(ctx300).Find(&auths300).Error
assert.NoError(t, err)
// 店铺300应该看到2条记录企业4的
assert.Equal(t, 2, len(auths300), "店铺300应该看到自己店铺下企业的授权记录")
// 验证店铺100看不到店铺300的记录
for _, auth := range auths100 {
assert.NotEqual(t, uint(5), auth.ID, "店铺100不应该看到店铺300的授权记录")
assert.NotEqual(t, uint(6), auth.ID, "店铺100不应该看到店铺300的授权记录")
}
// 验证店铺300看不到店铺100的记录
for _, auth := range auths300 {
assert.NotEqual(t, uint(1), auth.ID, "店铺300不应该看到店铺100的授权记录")
assert.NotEqual(t, uint(2), auth.ID, "店铺300不应该看到店铺100的授权记录")
assert.NotEqual(t, uint(3), auth.ID, "店铺300不应该看到店铺100的授权记录")
}
}