package gorm import ( "context" "testing" "github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/middleware" "github.com/stretchr/testify/assert" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/schema" ) // mockAccountStore 模拟账号 Store type mockAccountStore struct { subordinateIDs []uint err error } func (m *mockAccountStore) GetSubordinateIDs(ctx context.Context, accountID uint) ([]uint, error) { if m.err != nil { return nil, m.err } return m.subordinateIDs, nil } // TestSkipDataPermission 测试跳过数据权限过滤 func TestSkipDataPermission(t *testing.T) { ctx := context.Background() // 设置跳过标记 ctx = SkipDataPermission(ctx) // 验证标记已设置 skip, ok := ctx.Value(SkipDataPermissionKey).(bool) assert.True(t, ok) assert.True(t, skip) } // TestHasCreatorField 测试检查 creator 字段 func TestHasCreatorField(t *testing.T) { tests := []struct { name string schema *schema.Schema expected bool }{ { name: "nil schema", schema: nil, expected: false, }, { name: "schema with creator field", schema: &schema.Schema{ FieldsByDBName: map[string]*schema.Field{ "creator": {}, }, }, expected: true, }, { name: "schema without creator field", schema: &schema.Schema{ FieldsByDBName: map[string]*schema.Field{ "id": {}, }, }, expected: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := hasCreatorField(tt.schema) assert.Equal(t, tt.expected, result) }) } } // TestHasShopIDField 测试检查 shop_id 字段 func TestHasShopIDField(t *testing.T) { tests := []struct { name string schema *schema.Schema expected bool }{ { name: "nil schema", schema: nil, expected: false, }, { name: "schema with shop_id field", schema: &schema.Schema{ FieldsByDBName: map[string]*schema.Field{ "shop_id": {}, }, }, expected: true, }, { name: "schema without shop_id field", schema: &schema.Schema{ FieldsByDBName: map[string]*schema.Field{ "id": {}, }, }, expected: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := hasShopIDField(tt.schema) assert.Equal(t, tt.expected, result) }) } } // TestRegisterDataPermissionCallback 测试注册数据权限 Callback func TestRegisterDataPermissionCallback(t *testing.T) { // 创建内存数据库 db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) assert.NoError(t, err) // 创建 mock AccountStore mockStore := &mockAccountStore{ subordinateIDs: []uint{1, 2, 3}, } // 注册 Callback err = RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) } // TestDataPermissionCallback_SkipForRootUser 测试 root 用户跳过过滤 func TestDataPermissionCallback_SkipForRootUser(t *testing.T) { // 创建内存数据库 db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) assert.NoError(t, err) // 创建测试表 type TestModel struct { ID uint Creator uint Name string } err = db.AutoMigrate(&TestModel{}) assert.NoError(t, err) // 插入测试数据 db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"}) db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"}) // 创建 mock AccountStore mockStore := &mockAccountStore{ subordinateIDs: []uint{1}, // 只有 ID 1 } // 注册 Callback err = RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置 root 用户 context ctx := context.Background() ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeSuperAdmin, 0) // 查询数据 var results []TestModel err = db.WithContext(ctx).Find(&results).Error assert.NoError(t, err) // root 用户应该看到所有数据 assert.Equal(t, 2, len(results)) } // TestDataPermissionCallback_FilterForNormalUser 测试普通用户过滤 func TestDataPermissionCallback_FilterForNormalUser(t *testing.T) { // 创建内存数据库 db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) assert.NoError(t, err) // 创建测试表 type TestModel struct { ID uint Creator uint Name string } err = db.AutoMigrate(&TestModel{}) assert.NoError(t, err) // 插入测试数据 db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"}) db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"}) db.Create(&TestModel{ID: 3, Creator: 3, Name: "test3"}) // 创建 mock AccountStore mockStore := &mockAccountStore{ subordinateIDs: []uint{1, 2}, // 只能看到 1 和 2 } // 注册 Callback err = RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置普通用户 context (非 root) ctx := context.Background() ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeAgent, 0) // 查询数据 var results []TestModel err = db.WithContext(ctx).Find(&results).Error assert.NoError(t, err) // 普通用户只能看到自己和下级的数据 assert.Equal(t, 2, len(results)) assert.Equal(t, uint(1), results[0].Creator) assert.Equal(t, uint(2), results[1].Creator) } // TestDataPermissionCallback_SkipWithContext 测试通过 Context 跳过过滤 func TestDataPermissionCallback_SkipWithContext(t *testing.T) { // 创建内存数据库 db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) assert.NoError(t, err) // 创建测试表 type TestModel struct { ID uint Creator uint Name string } err = db.AutoMigrate(&TestModel{}) assert.NoError(t, err) // 插入测试数据 db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"}) db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"}) // 创建 mock AccountStore mockStore := &mockAccountStore{ subordinateIDs: []uint{1}, // 只有 ID 1 } // 注册 Callback err = RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置普通用户 context 并跳过过滤 ctx := context.Background() ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeAgent, 0) ctx = SkipDataPermission(ctx) // 查询数据 var results []TestModel err = db.WithContext(ctx).Find(&results).Error assert.NoError(t, err) // 跳过过滤后应该看到所有数据 assert.Equal(t, 2, len(results)) } // TestDataPermissionCallback_WithShopID 测试带 shop_id 的过滤 func TestDataPermissionCallback_WithShopID(t *testing.T) { // 创建内存数据库 db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) assert.NoError(t, err) // 创建测试表 type TestModel struct { ID uint Creator uint ShopID uint Name string } err = db.AutoMigrate(&TestModel{}) assert.NoError(t, err) // 插入测试数据 db.Create(&TestModel{ID: 1, Creator: 1, ShopID: 100, Name: "test1"}) db.Create(&TestModel{ID: 2, Creator: 2, ShopID: 100, Name: "test2"}) db.Create(&TestModel{ID: 3, Creator: 2, ShopID: 200, Name: "test3"}) // 不同 shop_id // 创建 mock AccountStore mockStore := &mockAccountStore{ subordinateIDs: []uint{1, 2}, // 可以看到 1 和 2 } // 注册 Callback err = RegisterDataPermissionCallback(db, mockStore) assert.NoError(t, err) // 设置普通用户 context (shop_id = 100) ctx := context.Background() ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeAgent, 100) // 查询数据 var results []TestModel err = db.WithContext(ctx).Find(&results).Error assert.NoError(t, err) // 只能看到 shop_id = 100 的数据 assert.Equal(t, 2, len(results)) for _, r := range results { assert.Equal(t, uint(100), r.ShopID) } }