package gorm

import (
	"context"
	"testing"

	"github.com/break/junhong_cmp_fiber/pkg/constants"
	"github.com/break/junhong_cmp_fiber/pkg/middleware"
	"github.com/stretchr/testify/assert"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

// mockShopStore is a test double for the shop store handed to
// RegisterDataPermissionCallback. It returns the configured shop IDs, or the
// configured error when err is non-nil.
type mockShopStore struct {
	subordinateShopIDs []uint
	err                error
}

// GetSubordinateShopIDs returns the preconfigured shop IDs regardless of the
// ctx and shopID arguments. If m.err is set, it returns (nil, m.err) instead.
func (m *mockShopStore) GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error) {
	if m.err != nil {
		return nil, m.err
	}
	return m.subordinateShopIDs, nil
}

// openDataPermissionTestDB opens a fresh in-memory SQLite database for a
// single test. The test is aborted immediately if the database cannot be
// opened, and the underlying connection pool is closed when the test ends.
func openDataPermissionTestDB(t *testing.T) *gorm.DB {
	t.Helper()

	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		t.Fatalf("open in-memory sqlite database: %v", err)
	}

	if sqlDB, err := db.DB(); err == nil {
		t.Cleanup(func() {
			// Best-effort cleanup: a close failure on a throwaway
			// in-memory database is not worth failing the test for.
			_ = sqlDB.Close()
		})
	}
	return db
}

// setupDataPermissionDB builds the fixture shared by the callback tests, in
// the same order the tests originally used: open the database, migrate model,
// insert rows, and only then register the data permission callback with
// store. Any failure aborts the test, so later assertions never run against a
// half-built fixture.
func setupDataPermissionDB(t *testing.T, store *mockShopStore, model interface{}, rows ...interface{}) *gorm.DB {
	t.Helper()

	db := openDataPermissionTestDB(t)

	if err := db.AutoMigrate(model); err != nil {
		t.Fatalf("auto-migrate test model: %v", err)
	}
	for i, row := range rows {
		if err := db.Create(row).Error; err != nil {
			t.Fatalf("insert seed row %d: %v", i, err)
		}
	}
	if err := RegisterDataPermissionCallback(db, store); err != nil {
		t.Fatalf("register data permission callback: %v", err)
	}
	return db
}

// TestSkipDataPermission verifies that SkipDataPermission stores a true bool
// under SkipDataPermissionKey in the returned context.
func TestSkipDataPermission(t *testing.T) {
	ctx := SkipDataPermission(context.Background())

	skip, ok := ctx.Value(SkipDataPermissionKey).(bool)
	assert.True(t, ok)
	assert.True(t, skip)
}

// TestRegisterDataPermissionCallback verifies that the callback can be
// registered on a fresh database without error.
func TestRegisterDataPermissionCallback(t *testing.T) {
	db := openDataPermissionTestDB(t)

	mockStore := &mockShopStore{
		subordinateShopIDs: []uint{1, 2, 3},
	}

	err := RegisterDataPermissionCallback(db, mockStore)
	assert.NoError(t, err)
}

// TestDataPermissionCallback_SkipForSuperAdmin verifies that a super admin
// sees every row even though the mock store only exposes shop 100.
func TestDataPermissionCallback_SkipForSuperAdmin(t *testing.T) {
	type TestModel struct {
		ID      uint
		ShopID  uint
		Creator uint
		Name    string
	}

	db := setupDataPermissionDB(t,
		&mockShopStore{subordinateShopIDs: []uint{100}}, // only shop 100
		&TestModel{},
		&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"},
		&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"},
	)

	ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
		UserID:       1,
		UserType:     constants.UserTypeSuperAdmin,
		ShopID:       0,
		EnterpriseID: 0,
		CustomerID:   0,
	})

	var results []TestModel
	err := db.WithContext(ctx).Find(&results).Error
	assert.NoError(t, err)

	// A super admin should see all rows.
	assert.Len(t, results, 2)
}

// TestDataPermissionCallback_SkipForPlatform verifies that a platform user
// sees every row even though the mock store only exposes shop 100.
func TestDataPermissionCallback_SkipForPlatform(t *testing.T) {
	type TestModel struct {
		ID      uint
		ShopID  uint
		Creator uint
		Name    string
	}

	db := setupDataPermissionDB(t,
		&mockShopStore{subordinateShopIDs: []uint{100}}, // only shop 100
		&TestModel{},
		&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"},
		&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"},
	)

	ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
		UserID:       1,
		UserType:     constants.UserTypePlatform,
		ShopID:       0,
		EnterpriseID: 0,
		CustomerID:   0,
	})

	var results []TestModel
	err := db.WithContext(ctx).Find(&results).Error
	assert.NoError(t, err)

	// A platform user should see all rows.
	assert.Len(t, results, 2)
}

// TestDataPermissionCallback_FilterForAgent verifies that an agent user only
// sees rows whose shop_id is among the shop IDs returned by the store.
func TestDataPermissionCallback_FilterForAgent(t *testing.T) {
	// The model carries a shop_id column so the shop-based filter applies.
	type TestModel struct {
		ID     uint
		ShopID uint
		Name   string
	}

	db := setupDataPermissionDB(t,
		&mockShopStore{subordinateShopIDs: []uint{100, 200}}, // shops 100 and 200 only
		&TestModel{},
		&TestModel{ID: 1, ShopID: 100, Name: "test1"},
		&TestModel{ID: 2, ShopID: 200, Name: "test2"},
		&TestModel{ID: 3, ShopID: 300, Name: "test3"},
	)

	// Agent user bound to shop 100.
	ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
		UserID:       1,
		UserType:     constants.UserTypeAgent,
		ShopID:       100,
		EnterpriseID: 0,
		CustomerID:   0,
	})

	var results []TestModel
	err := db.WithContext(ctx).Find(&results).Error
	assert.NoError(t, err)

	// Compare as a set: the query has no ORDER BY, so row order is not
	// guaranteed, and indexing results directly would panic if the filter
	// returned fewer rows than expected.
	shopIDs := make([]uint, 0, len(results))
	for _, r := range results {
		shopIDs = append(shopIDs, r.ShopID)
	}
	assert.ElementsMatch(t, []uint{100, 200}, shopIDs)
}

// TestDataPermissionCallback_SkipWithContext verifies that marking the
// context with SkipDataPermission disables filtering for an agent user.
func TestDataPermissionCallback_SkipWithContext(t *testing.T) {
	type TestModel struct {
		ID      uint
		ShopID  uint
		Creator uint
		Name    string
	}

	db := setupDataPermissionDB(t,
		&mockShopStore{subordinateShopIDs: []uint{100}}, // only shop 100
		&TestModel{},
		&TestModel{ID: 1, ShopID: 100, Creator: 1, Name: "test1"},
		&TestModel{ID: 2, ShopID: 200, Creator: 2, Name: "test2"},
	)

	// Agent user context, then explicitly opt out of filtering.
	ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
		UserID:       1,
		UserType:     constants.UserTypeAgent,
		ShopID:       100,
		EnterpriseID: 0,
		CustomerID:   0,
	})
	ctx = SkipDataPermission(ctx)

	var results []TestModel
	err := db.WithContext(ctx).Find(&results).Error
	assert.NoError(t, err)

	// With filtering skipped, all rows should be visible.
	assert.Len(t, results, 2)
}

// TestDataPermissionCallback_WithShopID verifies that an agent user sees all
// rows of every shop returned by the store, regardless of the row's creator.
func TestDataPermissionCallback_WithShopID(t *testing.T) {
	type TestModel struct {
		ID      uint
		Creator uint
		ShopID  uint
		Name    string
	}

	db := setupDataPermissionDB(t,
		&mockShopStore{subordinateShopIDs: []uint{100, 200}}, // shops 100 and 200
		&TestModel{},
		&TestModel{ID: 1, Creator: 1, ShopID: 100, Name: "test1"},
		&TestModel{ID: 2, Creator: 2, ShopID: 100, Name: "test2"},
		&TestModel{ID: 3, Creator: 2, ShopID: 200, Name: "test3"}, // different shop_id
	)

	// Agent user bound to shop 100.
	ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
		UserID:       1,
		UserType:     constants.UserTypeAgent,
		ShopID:       100,
		EnterpriseID: 0,
		CustomerID:   0,
	})

	var results []TestModel
	err := db.WithContext(ctx).Find(&results).Error
	assert.NoError(t, err)

	// Rows for shop 100 and shop 200 should all be visible because the mock
	// store returns both shop IDs.
	assert.Len(t, results, 3)
}

// TestDataPermissionCallback_FilterForEnterprise verifies that an enterprise
// user only sees rows belonging to its own enterprise.
func TestDataPermissionCallback_FilterForEnterprise(t *testing.T) {
	// The model carries an enterprise_id column.
	type TestModel struct {
		ID           uint
		EnterpriseID uint
		Name         string
	}

	// The store is required for registration; this test gives it no shops.
	db := setupDataPermissionDB(t,
		&mockShopStore{subordinateShopIDs: []uint{}},
		&TestModel{},
		&TestModel{ID: 1, EnterpriseID: 1001, Name: "test1"},
		&TestModel{ID: 2, EnterpriseID: 1001, Name: "test2"},
		&TestModel{ID: 3, EnterpriseID: 1002, Name: "test3"},
	)

	ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
		UserID:       1,
		UserType:     constants.UserTypeEnterprise,
		ShopID:       0,
		EnterpriseID: 1001,
		CustomerID:   0,
	})

	var results []TestModel
	err := db.WithContext(ctx).Find(&results).Error
	assert.NoError(t, err)

	// Only rows of enterprise 1001 should be visible.
	assert.Len(t, results, 2)
	for _, r := range results {
		assert.Equal(t, uint(1001), r.EnterpriseID)
	}
}

// TestDataPermissionCallback_FilterForPersonalCustomer verifies that a
// personal customer only sees rows whose creator column is 1 — the value this
// test uses for both UserID and CustomerID.
func TestDataPermissionCallback_FilterForPersonalCustomer(t *testing.T) {
	// The model carries a creator column.
	type TestModel struct {
		ID      uint
		Creator uint
		Name    string
	}

	// The store is required for registration; this test gives it no shops.
	db := setupDataPermissionDB(t,
		&mockShopStore{subordinateShopIDs: []uint{}},
		&TestModel{},
		&TestModel{ID: 1, Creator: 1, Name: "test1"},
		&TestModel{ID: 2, Creator: 2, Name: "test2"},
		&TestModel{ID: 3, Creator: 1, Name: "test3"},
	)

	ctx := middleware.SetUserContext(context.Background(), &middleware.UserContextInfo{
		UserID:       1,
		UserType:     constants.UserTypePersonalCustomer,
		ShopID:       0,
		EnterpriseID: 0,
		CustomerID:   1,
	})

	var results []TestModel
	err := db.WithContext(ctx).Find(&results).Error
	assert.NoError(t, err)

	// Only rows created by user 1 should be visible.
	assert.Len(t, results, 2)
	for _, r := range results {
		assert.Equal(t, uint(1), r.Creator)
	}
}