refactor: align framework cleanup with new bootstrap flow

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
This commit is contained in:
2025-11-19 12:47:25 +08:00
parent 39d14ec093
commit d66323487b
67 changed files with 3020 additions and 3992 deletions

312
pkg/gorm/callback_test.go Normal file
View File

@@ -0,0 +1,312 @@
package gorm
import (
"context"
"testing"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/stretchr/testify/assert"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/schema"
)
// mockAccountStore 模拟账号 Store
type mockAccountStore struct {
subordinateIDs []uint
err error
}
func (m *mockAccountStore) GetSubordinateIDs(ctx context.Context, accountID uint) ([]uint, error) {
if m.err != nil {
return nil, m.err
}
return m.subordinateIDs, nil
}
// TestSkipDataPermission 测试跳过数据权限过滤
func TestSkipDataPermission(t *testing.T) {
ctx := context.Background()
// 设置跳过标记
ctx = SkipDataPermission(ctx)
// 验证标记已设置
skip, ok := ctx.Value(SkipDataPermissionKey).(bool)
assert.True(t, ok)
assert.True(t, skip)
}
// TestHasCreatorField 测试检查 creator 字段
func TestHasCreatorField(t *testing.T) {
tests := []struct {
name string
schema *schema.Schema
expected bool
}{
{
name: "nil schema",
schema: nil,
expected: false,
},
{
name: "schema with creator field",
schema: &schema.Schema{
FieldsByDBName: map[string]*schema.Field{
"creator": {},
},
},
expected: true,
},
{
name: "schema without creator field",
schema: &schema.Schema{
FieldsByDBName: map[string]*schema.Field{
"id": {},
},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := hasCreatorField(tt.schema)
assert.Equal(t, tt.expected, result)
})
}
}
// TestHasShopIDField 测试检查 shop_id 字段
func TestHasShopIDField(t *testing.T) {
tests := []struct {
name string
schema *schema.Schema
expected bool
}{
{
name: "nil schema",
schema: nil,
expected: false,
},
{
name: "schema with shop_id field",
schema: &schema.Schema{
FieldsByDBName: map[string]*schema.Field{
"shop_id": {},
},
},
expected: true,
},
{
name: "schema without shop_id field",
schema: &schema.Schema{
FieldsByDBName: map[string]*schema.Field{
"id": {},
},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := hasShopIDField(tt.schema)
assert.Equal(t, tt.expected, result)
})
}
}
// TestRegisterDataPermissionCallback 测试注册数据权限 Callback
func TestRegisterDataPermissionCallback(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建 mock AccountStore
mockStore := &mockAccountStore{
subordinateIDs: []uint{1, 2, 3},
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
}
// TestDataPermissionCallback_SkipForRootUser 测试 root 用户跳过过滤
func TestDataPermissionCallback_SkipForRootUser(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表
type TestModel struct {
ID uint
Creator uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"})
db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"})
// 创建 mock AccountStore
mockStore := &mockAccountStore{
subordinateIDs: []uint{1}, // 只有 ID 1
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置 root 用户 context
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeRoot, 0)
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// root 用户应该看到所有数据
assert.Equal(t, 2, len(results))
}
// TestDataPermissionCallback_FilterForNormalUser 测试普通用户过滤
func TestDataPermissionCallback_FilterForNormalUser(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表
type TestModel struct {
ID uint
Creator uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"})
db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"})
db.Create(&TestModel{ID: 3, Creator: 3, Name: "test3"})
// 创建 mock AccountStore
mockStore := &mockAccountStore{
subordinateIDs: []uint{1, 2}, // 只能看到 1 和 2
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置普通用户 context (非 root)
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeAgent, 0)
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// 普通用户只能看到自己和下级的数据
assert.Equal(t, 2, len(results))
assert.Equal(t, uint(1), results[0].Creator)
assert.Equal(t, uint(2), results[1].Creator)
}
// TestDataPermissionCallback_SkipWithContext 测试通过 Context 跳过过滤
func TestDataPermissionCallback_SkipWithContext(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表
type TestModel struct {
ID uint
Creator uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"})
db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"})
// 创建 mock AccountStore
mockStore := &mockAccountStore{
subordinateIDs: []uint{1}, // 只有 ID 1
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置普通用户 context 并跳过过滤
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeAgent, 0)
ctx = SkipDataPermission(ctx)
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// 跳过过滤后应该看到所有数据
assert.Equal(t, 2, len(results))
}
// TestDataPermissionCallback_WithShopID 测试带 shop_id 的过滤
func TestDataPermissionCallback_WithShopID(t *testing.T) {
// 创建内存数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
assert.NoError(t, err)
// 创建测试表
type TestModel struct {
ID uint
Creator uint
ShopID uint
Name string
}
err = db.AutoMigrate(&TestModel{})
assert.NoError(t, err)
// 插入测试数据
db.Create(&TestModel{ID: 1, Creator: 1, ShopID: 100, Name: "test1"})
db.Create(&TestModel{ID: 2, Creator: 2, ShopID: 100, Name: "test2"})
db.Create(&TestModel{ID: 3, Creator: 2, ShopID: 200, Name: "test3"}) // 不同 shop_id
// 创建 mock AccountStore
mockStore := &mockAccountStore{
subordinateIDs: []uint{1, 2}, // 可以看到 1 和 2
}
// 注册 Callback
err = RegisterDataPermissionCallback(db, mockStore)
assert.NoError(t, err)
// 设置普通用户 context (shop_id = 100)
ctx := context.Background()
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeAgent, 100)
// 查询数据
var results []TestModel
err = db.WithContext(ctx).Find(&results).Error
assert.NoError(t, err)
// 只能看到 shop_id = 100 的数据
assert.Equal(t, 2, len(results))
for _, r := range results {
assert.Equal(t, uint(100), r.ShopID)
}
}