Files
junhong_cmp_fiber/pkg/gorm/callback_test.go
huang 743db126f7 重构数据权限模型并清理旧RBAC代码
核心变更:
- 数据权限过滤从基于账号层级改为基于用户类型的多策略过滤
- 移除 AccountStore 中的 GetSubordinateIDs 等旧方法
- 重构认证中间件,支持 enterprise_id 和 customer_id
- 更新 GORM Callback,根据用户类型自动选择过滤策略(代理/企业/个人客户)
- 更新所有集成测试以适配新的 API 签名
- 添加功能总结文档和 OpenSpec 归档

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-10 15:08:11 +08:00

408 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package gorm
import (
"context"
"testing"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/stretchr/testify/assert"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// mockShopStore is a test double for the shop store handed to
// RegisterDataPermissionCallback. It returns canned values instead of
// querying any real storage.
type mockShopStore struct {
	// subordinateShopIDs is the list returned by GetSubordinateShopIDs on success.
	subordinateShopIDs []uint
	// err, when non-nil, is returned by GetSubordinateShopIDs instead of the list.
	err error
}

// GetSubordinateShopIDs returns the preconfigured shop ID list, or
// (nil, err) when the mock was configured with an error. Both ctx and
// shopID are ignored.
func (s *mockShopStore) GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error) {
	if s.err == nil {
		return s.subordinateShopIDs, nil
	}
	return nil, s.err
}
// TestSkipDataPermission checks that SkipDataPermission stores a boolean
// true under SkipDataPermissionKey in the returned context.
func TestSkipDataPermission(t *testing.T) {
	// Derive a context carrying the skip marker.
	marked := SkipDataPermission(context.Background())

	// The marker must be present, typed as bool, and set to true.
	flag, isBool := marked.Value(SkipDataPermissionKey).(bool)
	assert.True(t, isBool)
	assert.True(t, flag)
}
// TestRegisterDataPermissionCallback checks that RegisterDataPermissionCallback
// returns no error when given a fresh in-memory SQLite database and a mock
// shop store.
func TestRegisterDataPermissionCallback(t *testing.T) {
	// Open an in-memory database. Stop on failure: registering against an
	// unusable db would only obscure the real cause.
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if !assert.NoError(t, err) {
		return
	}

	// Mock shop store that reports shops 1, 2 and 3.
	mockStore := &mockShopStore{
		subordinateShopIDs: []uint{1, 2, 3},
	}

	// Registering the callback must succeed.
	assert.NoError(t, RegisterDataPermissionCallback(db, mockStore))
}
// TestDataPermissionCallback_SkipForSuperAdmin checks that a query issued
// with a super-admin user context returns every row, i.e. the test expects
// no data-permission filtering for constants.UserTypeSuperAdmin.
func TestDataPermissionCallback_SkipForSuperAdmin(t *testing.T) {
	// Open an in-memory database; abort if setup fails so later assertions
	// do not report misleading follow-on failures.
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if !assert.NoError(t, err) {
		return
	}

	// Test table.
	type TestModel struct {
		ID      uint
		ShopID  uint
		Creator uint
		Name    string
	}
	if !assert.NoError(t, db.AutoMigrate(&TestModel{})) {
		return
	}

	// Seed rows, one insert per row, checking each insert's error.
	seed := []TestModel{
		{ID: 1, ShopID: 100, Creator: 1, Name: "test1"},
		{ID: 2, ShopID: 200, Creator: 2, Name: "test2"},
	}
	for i := range seed {
		if !assert.NoError(t, db.Create(&seed[i]).Error) {
			return
		}
	}

	// Mock shop store that only lists shop 100.
	mockStore := &mockShopStore{
		subordinateShopIDs: []uint{100},
	}

	// Register the callback after seeding, as the data is needed first.
	if !assert.NoError(t, RegisterDataPermissionCallback(db, mockStore)) {
		return
	}

	// Super-admin user context.
	ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
		UserID:       1,
		UserType:     constants.UserTypeSuperAdmin,
		ShopID:       0,
		EnterpriseID: 0,
		CustomerID:   0,
	})

	// Query all rows.
	var results []TestModel
	err = db.WithContext(ctx).Find(&results).Error
	assert.NoError(t, err)

	// A super admin is expected to see both rows.
	assert.Len(t, results, 2)
}
// TestDataPermissionCallback_SkipForPlatform checks that a query issued with
// a platform user context returns every row, i.e. the test expects no
// data-permission filtering for constants.UserTypePlatform.
func TestDataPermissionCallback_SkipForPlatform(t *testing.T) {
	// Open an in-memory database; abort if setup fails so later assertions
	// do not report misleading follow-on failures.
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if !assert.NoError(t, err) {
		return
	}

	// Test table.
	type TestModel struct {
		ID      uint
		ShopID  uint
		Creator uint
		Name    string
	}
	if !assert.NoError(t, db.AutoMigrate(&TestModel{})) {
		return
	}

	// Seed rows, one insert per row, checking each insert's error.
	seed := []TestModel{
		{ID: 1, ShopID: 100, Creator: 1, Name: "test1"},
		{ID: 2, ShopID: 200, Creator: 2, Name: "test2"},
	}
	for i := range seed {
		if !assert.NoError(t, db.Create(&seed[i]).Error) {
			return
		}
	}

	// Mock shop store that only lists shop 100.
	mockStore := &mockShopStore{
		subordinateShopIDs: []uint{100},
	}

	// Register the callback.
	if !assert.NoError(t, RegisterDataPermissionCallback(db, mockStore)) {
		return
	}

	// Platform user context.
	ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
		UserID:       1,
		UserType:     constants.UserTypePlatform,
		ShopID:       0,
		EnterpriseID: 0,
		CustomerID:   0,
	})

	// Query all rows.
	var results []TestModel
	err = db.WithContext(ctx).Find(&results).Error
	assert.NoError(t, err)

	// A platform user is expected to see both rows.
	assert.Len(t, results, 2)
}
// TestDataPermissionCallback_FilterForAgent checks that a query issued with
// an agent user context only returns rows whose ShopID is in the list
// reported by the shop store (100 and 200 here), hiding the row for shop 300.
func TestDataPermissionCallback_FilterForAgent(t *testing.T) {
	// Open an in-memory database; abort if setup fails so later assertions
	// do not report misleading follow-on failures.
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if !assert.NoError(t, err) {
		return
	}

	// Test table with a ShopID field (original note: the shop_id column is
	// what triggers shop-hierarchy filtering).
	type TestModel struct {
		ID     uint
		ShopID uint
		Name   string
	}
	if !assert.NoError(t, db.AutoMigrate(&TestModel{})) {
		return
	}

	// Seed rows, one insert per row, checking each insert's error.
	seed := []TestModel{
		{ID: 1, ShopID: 100, Name: "test1"},
		{ID: 2, ShopID: 200, Name: "test2"},
		{ID: 3, ShopID: 300, Name: "test3"},
	}
	for i := range seed {
		if !assert.NoError(t, db.Create(&seed[i]).Error) {
			return
		}
	}

	// Mock shop store: only shops 100 and 200 are visible.
	mockStore := &mockShopStore{
		subordinateShopIDs: []uint{100, 200},
	}

	// Register the callback.
	if !assert.NoError(t, RegisterDataPermissionCallback(db, mockStore)) {
		return
	}

	// Agent user context (shop_id = 100).
	ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
		UserID:       1,
		UserType:     constants.UserTypeAgent,
		ShopID:       100,
		EnterpriseID: 0,
		CustomerID:   0,
	})

	// Query all rows.
	var results []TestModel
	err = db.WithContext(ctx).Find(&results).Error
	assert.NoError(t, err)

	// The agent must see exactly two rows. Bail out before touching the
	// rows if the count is wrong, so a failure is reported as an assertion
	// instead of an index-out-of-range panic.
	if !assert.Len(t, results, 2) {
		return
	}

	// Compare as a set: the query has no ORDER BY, so row order is not
	// something this test should depend on.
	shopIDs := make([]uint, 0, len(results))
	for _, r := range results {
		shopIDs = append(shopIDs, r.ShopID)
	}
	assert.ElementsMatch(t, []uint{100, 200}, shopIDs)
}
// TestDataPermissionCallback_SkipWithContext checks that marking the context
// with SkipDataPermission makes a query return every row, even for an agent
// user whose shop store only lists shop 100.
func TestDataPermissionCallback_SkipWithContext(t *testing.T) {
	// Open an in-memory database; abort if setup fails so later assertions
	// do not report misleading follow-on failures.
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if !assert.NoError(t, err) {
		return
	}

	// Test table.
	type TestModel struct {
		ID      uint
		ShopID  uint
		Creator uint
		Name    string
	}
	if !assert.NoError(t, db.AutoMigrate(&TestModel{})) {
		return
	}

	// Seed rows, one insert per row, checking each insert's error.
	seed := []TestModel{
		{ID: 1, ShopID: 100, Creator: 1, Name: "test1"},
		{ID: 2, ShopID: 200, Creator: 2, Name: "test2"},
	}
	for i := range seed {
		if !assert.NoError(t, db.Create(&seed[i]).Error) {
			return
		}
	}

	// Mock shop store that only lists shop 100.
	mockStore := &mockShopStore{
		subordinateShopIDs: []uint{100},
	}

	// Register the callback.
	if !assert.NoError(t, RegisterDataPermissionCallback(db, mockStore)) {
		return
	}

	// Agent user context, then mark it so that filtering is skipped.
	ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
		UserID:       1,
		UserType:     constants.UserTypeAgent,
		ShopID:       100,
		EnterpriseID: 0,
		CustomerID:   0,
	})
	ctx = SkipDataPermission(ctx)

	// Query all rows.
	var results []TestModel
	err = db.WithContext(ctx).Find(&results).Error
	assert.NoError(t, err)

	// With filtering skipped, both rows are expected.
	assert.Len(t, results, 2)
}
// TestDataPermissionCallback_WithShopID checks shop-based filtering on a
// table that has both Creator and ShopID fields: an agent whose shop store
// lists shops 100 and 200 is expected to see all three seeded rows.
func TestDataPermissionCallback_WithShopID(t *testing.T) {
	// Open an in-memory database; abort if setup fails so later assertions
	// do not report misleading follow-on failures.
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if !assert.NoError(t, err) {
		return
	}

	// Test table.
	type TestModel struct {
		ID      uint
		Creator uint
		ShopID  uint
		Name    string
	}
	if !assert.NoError(t, db.AutoMigrate(&TestModel{})) {
		return
	}

	// Seed rows, one insert per row, checking each insert's error.
	// Row 3 belongs to a different shop than rows 1 and 2.
	seed := []TestModel{
		{ID: 1, Creator: 1, ShopID: 100, Name: "test1"},
		{ID: 2, Creator: 2, ShopID: 100, Name: "test2"},
		{ID: 3, Creator: 2, ShopID: 200, Name: "test3"},
	}
	for i := range seed {
		if !assert.NoError(t, db.Create(&seed[i]).Error) {
			return
		}
	}

	// Mock shop store: shops 100 and 200 are visible.
	mockStore := &mockShopStore{
		subordinateShopIDs: []uint{100, 200},
	}

	// Register the callback.
	if !assert.NoError(t, RegisterDataPermissionCallback(db, mockStore)) {
		return
	}

	// Agent user context (shop_id = 100).
	ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
		UserID:       1,
		UserType:     constants.UserTypeAgent,
		ShopID:       100,
		EnterpriseID: 0,
		CustomerID:   0,
	})

	// Query all rows.
	var results []TestModel
	err = db.WithContext(ctx).Find(&results).Error
	assert.NoError(t, err)

	// All rows for shops 100 and 200 are expected, regardless of Creator,
	// because the mock store returned both shop IDs.
	assert.Len(t, results, 3)
}
// TestDataPermissionCallback_FilterForEnterprise checks that a query issued
// with an enterprise user context (EnterpriseID 1001) only returns rows
// whose EnterpriseID is 1001.
func TestDataPermissionCallback_FilterForEnterprise(t *testing.T) {
	// Open an in-memory database; abort if setup fails so later assertions
	// do not report misleading follow-on failures.
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if !assert.NoError(t, err) {
		return
	}

	// Test table with an EnterpriseID field.
	type TestModel struct {
		ID           uint
		EnterpriseID uint
		Name         string
	}
	if !assert.NoError(t, db.AutoMigrate(&TestModel{})) {
		return
	}

	// Seed rows, one insert per row, checking each insert's error.
	seed := []TestModel{
		{ID: 1, EnterpriseID: 1001, Name: "test1"},
		{ID: 2, EnterpriseID: 1001, Name: "test2"},
		{ID: 3, EnterpriseID: 1002, Name: "test3"},
	}
	for i := range seed {
		if !assert.NoError(t, db.Create(&seed[i]).Error) {
			return
		}
	}

	// Mock shop store. Per the original note it is not needed for
	// enterprise users, but registration requires one.
	mockStore := &mockShopStore{
		subordinateShopIDs: []uint{},
	}

	// Register the callback.
	if !assert.NoError(t, RegisterDataPermissionCallback(db, mockStore)) {
		return
	}

	// Enterprise user context.
	ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
		UserID:       1,
		UserType:     constants.UserTypeEnterprise,
		ShopID:       0,
		EnterpriseID: 1001,
		CustomerID:   0,
	})

	// Query all rows.
	var results []TestModel
	err = db.WithContext(ctx).Find(&results).Error
	assert.NoError(t, err)

	// Only the two rows of enterprise 1001 are expected.
	assert.Len(t, results, 2)
	for _, r := range results {
		assert.Equal(t, uint(1001), r.EnterpriseID)
	}
}
// TestDataPermissionCallback_FilterForPersonalCustomer checks that a query
// issued with a personal-customer user context (UserID 1, CustomerID 1) only
// returns rows whose Creator is 1.
func TestDataPermissionCallback_FilterForPersonalCustomer(t *testing.T) {
	// Open an in-memory database; abort if setup fails so later assertions
	// do not report misleading follow-on failures.
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if !assert.NoError(t, err) {
		return
	}

	// Test table with a Creator field.
	type TestModel struct {
		ID      uint
		Creator uint
		Name    string
	}
	if !assert.NoError(t, db.AutoMigrate(&TestModel{})) {
		return
	}

	// Seed rows, one insert per row, checking each insert's error.
	seed := []TestModel{
		{ID: 1, Creator: 1, Name: "test1"},
		{ID: 2, Creator: 2, Name: "test2"},
		{ID: 3, Creator: 1, Name: "test3"},
	}
	for i := range seed {
		if !assert.NoError(t, db.Create(&seed[i]).Error) {
			return
		}
	}

	// Mock shop store. Per the original note it is not needed for
	// personal customers, but registration requires one.
	mockStore := &mockShopStore{
		subordinateShopIDs: []uint{},
	}

	// Register the callback.
	if !assert.NoError(t, RegisterDataPermissionCallback(db, mockStore)) {
		return
	}

	// Personal-customer user context.
	ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
		UserID:       1,
		UserType:     constants.UserTypePersonalCustomer,
		ShopID:       0,
		EnterpriseID: 0,
		CustomerID:   1,
	})

	// Query all rows.
	var results []TestModel
	err = db.WithContext(ctx).Find(&results).Error
	assert.NoError(t, err)

	// Only the two rows created by user 1 are expected.
	assert.Len(t, results, 2)
	for _, r := range results {
		assert.Equal(t, uint(1), r.Creator)
	}
}