refactor: align framework cleanup with new bootstrap flow
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
This commit is contained in:
@@ -24,16 +24,17 @@ type RedisConfig struct {
|
||||
// NewRedisClient 创建新的 Redis 客户端
|
||||
func NewRedisClient(cfg RedisConfig, logger *zap.Logger) (*redis.Client, error) {
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Address,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
PoolSize: cfg.PoolSize,
|
||||
MinIdleConns: cfg.MinIdleConns,
|
||||
DialTimeout: cfg.DialTimeout,
|
||||
ReadTimeout: cfg.ReadTimeout,
|
||||
WriteTimeout: cfg.WriteTimeout,
|
||||
MaxRetries: 3,
|
||||
PoolTimeout: 4 * time.Second,
|
||||
Addr: cfg.Address,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
PoolSize: cfg.PoolSize,
|
||||
MinIdleConns: cfg.MinIdleConns,
|
||||
DialTimeout: cfg.DialTimeout,
|
||||
ReadTimeout: cfg.ReadTimeout,
|
||||
WriteTimeout: cfg.WriteTimeout,
|
||||
MaxRetries: 3,
|
||||
PoolTimeout: 4 * time.Second,
|
||||
DisableIndentity: true,
|
||||
})
|
||||
|
||||
// 测试连接
|
||||
@@ -41,7 +42,7 @@ func NewRedisClient(cfg RedisConfig, logger *zap.Logger) (*redis.Client, error)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to redis: %w", err)
|
||||
return nil, fmt.Errorf("redis连接错误: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Redis 连接成功",
|
||||
|
||||
@@ -43,10 +43,6 @@ const (
|
||||
CodeServiceUnavailable = 2004 // 服务不可用
|
||||
CodeTimeout = 2005 // 请求超时
|
||||
CodeTaskQueueError = 2006 // 任务队列错误
|
||||
|
||||
// 向后兼容的别名(供现有代码使用)
|
||||
CodeBadRequest = CodeInvalidParam // 别名:参数验证失败
|
||||
CodeAuthServiceUnavailable = CodeServiceUnavailable // 别名:认证服务不可用
|
||||
)
|
||||
|
||||
// errorMessages 错误消息映射表(中文)
|
||||
|
||||
@@ -15,10 +15,9 @@ var (
|
||||
|
||||
// AppError 表示带错误码的应用错误
|
||||
type AppError struct {
|
||||
Code int // 应用错误码
|
||||
Message string // 错误消息
|
||||
HTTPStatus int // HTTP 状态码(自动从 Code 映射,可通过 WithHTTPStatus 覆盖)
|
||||
Err error // 底层错误(可选)
|
||||
Code int // 应用错误码
|
||||
Message string // 错误消息
|
||||
Err error // 底层错误(可选)
|
||||
}
|
||||
|
||||
func (e *AppError) Error() string {
|
||||
@@ -39,9 +38,8 @@ func New(code int, message string) *AppError {
|
||||
message = GetMessage(code, "zh-CN")
|
||||
}
|
||||
return &AppError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
HTTPStatus: GetHTTPStatus(code), // 自动从错误码映射 HTTP 状态码
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,15 +50,8 @@ func Wrap(code int, message string, err error) *AppError {
|
||||
message = GetMessage(code, "zh-CN")
|
||||
}
|
||||
return &AppError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
HTTPStatus: GetHTTPStatus(code), // 自动从错误码映射 HTTP 状态码
|
||||
Err: err,
|
||||
Code: code,
|
||||
Message: message,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPStatus 设置自定义 HTTP 状态码(用于特殊场景)
|
||||
func (e *AppError) WithHTTPStatus(status int) *AppError {
|
||||
e.HTTPStatus = status
|
||||
return e
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ func handleError(c *fiber.Ctx, err error, logger *zap.Logger) error {
|
||||
// 应用自定义错误
|
||||
code = e.Code
|
||||
message = e.Message
|
||||
httpStatus = e.HTTPStatus
|
||||
httpStatus = GetHTTPStatus(e.Code)
|
||||
|
||||
// 记录错误日志(包含完整上下文)
|
||||
logFields := append(errCtx.ToLogFields(),
|
||||
|
||||
@@ -87,32 +87,22 @@ func TestSafeErrorHandler(t *testing.T) {
|
||||
// TestAppErrorMethods 测试 AppError 的方法
|
||||
func TestAppErrorMethods(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *AppError
|
||||
expectedError string
|
||||
expectedHTTPStatus int
|
||||
expectedCode int
|
||||
name string
|
||||
err *AppError
|
||||
expectedError string
|
||||
expectedCode int
|
||||
}{
|
||||
{
|
||||
name: "基本 AppError",
|
||||
err: New(CodeInvalidParam, "参数错误"),
|
||||
expectedError: "参数错误",
|
||||
expectedHTTPStatus: 400,
|
||||
expectedCode: CodeInvalidParam,
|
||||
name: "基本 AppError",
|
||||
err: New(CodeInvalidParam, "参数错误"),
|
||||
expectedError: "参数错误",
|
||||
expectedCode: CodeInvalidParam,
|
||||
},
|
||||
{
|
||||
name: "带自定义 HTTP 状态码",
|
||||
err: New(CodeNotFound, "用户不存在").WithHTTPStatus(404),
|
||||
expectedError: "用户不存在",
|
||||
expectedHTTPStatus: 404,
|
||||
expectedCode: CodeNotFound,
|
||||
},
|
||||
{
|
||||
name: "空消息使用默认",
|
||||
err: New(CodeDatabaseError, ""),
|
||||
expectedError: "数据库错误",
|
||||
expectedHTTPStatus: 500,
|
||||
expectedCode: CodeDatabaseError,
|
||||
name: "空消息使用默认",
|
||||
err: New(CodeDatabaseError, ""),
|
||||
expectedError: "数据库错误",
|
||||
expectedCode: CodeDatabaseError,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -127,11 +117,6 @@ func TestAppErrorMethods(t *testing.T) {
|
||||
if tt.err.Code != tt.expectedCode {
|
||||
t.Errorf("Code = %d, expected %d", tt.err.Code, tt.expectedCode)
|
||||
}
|
||||
|
||||
// 测试 HTTPStatus 字段
|
||||
if tt.err.HTTPStatus != tt.expectedHTTPStatus {
|
||||
t.Errorf("HTTPStatus = %d, expected %d", tt.err.HTTPStatus, tt.expectedHTTPStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
133
pkg/gorm/callback.go
Normal file
133
pkg/gorm/callback.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/logger"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/middleware"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// contextKey 用于 context value 的 key 类型
|
||||
type contextKey string
|
||||
|
||||
// SkipDataPermissionKey 跳过数据权限过滤的 context key
|
||||
const SkipDataPermissionKey contextKey = "skip_data_permission"
|
||||
|
||||
// SkipDataPermission 返回跳过数据权限过滤的 Context
|
||||
// 用于需要查询所有数据的场景(如管理后台统计、系统任务等)
|
||||
//
|
||||
// 使用示例:
|
||||
//
|
||||
// ctx = gorm.SkipDataPermission(ctx)
|
||||
// db.WithContext(ctx).Find(&accounts)
|
||||
func SkipDataPermission(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, SkipDataPermissionKey, true)
|
||||
}
|
||||
|
||||
// AccountStoreInterface 账号 Store 接口
|
||||
// 用于 Callback 获取下级 ID,避免循环依赖
|
||||
type AccountStoreInterface interface {
|
||||
GetSubordinateIDs(ctx context.Context, accountID uint) ([]uint, error)
|
||||
}
|
||||
|
||||
// RegisterDataPermissionCallback 注册 GORM 数据权限过滤 Callback
|
||||
//
|
||||
// 自动化数据权限过滤规则:
|
||||
// 1. root 用户跳过过滤,可以查看所有数据
|
||||
// 2. 普通用户只能查看自己和下级的数据(通过递归查询下级 ID)
|
||||
// 3. 同时限制 shop_id 相同(如果配置了 shop_id)
|
||||
// 4. 通过 SkipDataPermission(ctx) 可以绕过权限过滤
|
||||
//
|
||||
// 注意:
|
||||
// - Callback 只对包含 creator 字段的表生效
|
||||
// - 必须在初始化 Store 之前注册
|
||||
//
|
||||
// 参数:
|
||||
// - db: GORM DB 实例
|
||||
// - accountStore: 账号 Store,用于查询下级 ID
|
||||
//
|
||||
// 返回:
|
||||
// - error: 注册错误
|
||||
func RegisterDataPermissionCallback(db *gorm.DB, accountStore AccountStoreInterface) error {
|
||||
// 注册查询前的 Callback
|
||||
err := db.Callback().Query().Before("gorm:query").Register("data_permission:query", func(tx *gorm.DB) {
|
||||
ctx := tx.Statement.Context
|
||||
if ctx == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 1. 检查是否跳过数据权限过滤
|
||||
if skip, ok := ctx.Value(SkipDataPermissionKey).(bool); ok && skip {
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 检查是否为 root 用户,root 用户跳过过滤
|
||||
if middleware.IsRootUser(ctx) {
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 检查表是否有 creator 字段(只对有 creator 字段的表生效)
|
||||
if !hasCreatorField(tx.Statement.Schema) {
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 获取当前用户 ID
|
||||
userID := middleware.GetUserIDFromContext(ctx)
|
||||
if userID == 0 {
|
||||
// 未登录用户返回空结果
|
||||
logger.GetAppLogger().Warn("数据权限过滤:未获取到用户 ID")
|
||||
tx.Where("1 = 0")
|
||||
return
|
||||
}
|
||||
|
||||
// 5. 获取当前用户及所有下级的 ID
|
||||
subordinateIDs, err := accountStore.GetSubordinateIDs(ctx, userID)
|
||||
if err != nil {
|
||||
// 查询失败时,降级为只能看自己的数据
|
||||
logger.GetAppLogger().Error("数据权限过滤:获取下级 ID 失败",
|
||||
zap.Uint("user_id", userID),
|
||||
zap.Error(err))
|
||||
subordinateIDs = []uint{userID}
|
||||
}
|
||||
|
||||
if len(subordinateIDs) == 0 {
|
||||
subordinateIDs = []uint{userID}
|
||||
}
|
||||
|
||||
// 6. 获取当前用户的 shop_id
|
||||
shopID := middleware.GetShopIDFromContext(ctx)
|
||||
|
||||
// 7. 应用数据权限过滤条件
|
||||
// creator IN (用户自己及所有下级) AND shop_id = 当前用户 shop_id
|
||||
if shopID != 0 && hasShopIDField(tx.Statement.Schema) {
|
||||
// 同时过滤 creator 和 shop_id
|
||||
tx.Where("creator IN ? AND shop_id = ?", subordinateIDs, shopID)
|
||||
} else {
|
||||
// 只根据 creator 过滤
|
||||
tx.Where("creator IN ?", subordinateIDs)
|
||||
}
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// hasCreatorField 检查 Schema 是否包含 creator 字段
|
||||
func hasCreatorField(s *schema.Schema) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := s.FieldsByDBName["creator"]
|
||||
return ok
|
||||
}
|
||||
|
||||
// hasShopIDField 检查 Schema 是否包含 shop_id 字段
|
||||
func hasShopIDField(s *schema.Schema) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := s.FieldsByDBName["shop_id"]
|
||||
return ok
|
||||
}
|
||||
312
pkg/gorm/callback_test.go
Normal file
312
pkg/gorm/callback_test.go
Normal file
@@ -0,0 +1,312 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// mockAccountStore 模拟账号 Store
|
||||
type mockAccountStore struct {
|
||||
subordinateIDs []uint
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockAccountStore) GetSubordinateIDs(ctx context.Context, accountID uint) ([]uint, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.subordinateIDs, nil
|
||||
}
|
||||
|
||||
// TestSkipDataPermission 测试跳过数据权限过滤
|
||||
func TestSkipDataPermission(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// 设置跳过标记
|
||||
ctx = SkipDataPermission(ctx)
|
||||
|
||||
// 验证标记已设置
|
||||
skip, ok := ctx.Value(SkipDataPermissionKey).(bool)
|
||||
assert.True(t, ok)
|
||||
assert.True(t, skip)
|
||||
}
|
||||
|
||||
// TestHasCreatorField 测试检查 creator 字段
|
||||
func TestHasCreatorField(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
schema *schema.Schema
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil schema",
|
||||
schema: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "schema with creator field",
|
||||
schema: &schema.Schema{
|
||||
FieldsByDBName: map[string]*schema.Field{
|
||||
"creator": {},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "schema without creator field",
|
||||
schema: &schema.Schema{
|
||||
FieldsByDBName: map[string]*schema.Field{
|
||||
"id": {},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := hasCreatorField(tt.schema)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHasShopIDField 测试检查 shop_id 字段
|
||||
func TestHasShopIDField(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
schema *schema.Schema
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil schema",
|
||||
schema: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "schema with shop_id field",
|
||||
schema: &schema.Schema{
|
||||
FieldsByDBName: map[string]*schema.Field{
|
||||
"shop_id": {},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "schema without shop_id field",
|
||||
schema: &schema.Schema{
|
||||
FieldsByDBName: map[string]*schema.Field{
|
||||
"id": {},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := hasShopIDField(tt.schema)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegisterDataPermissionCallback 测试注册数据权限 Callback
|
||||
func TestRegisterDataPermissionCallback(t *testing.T) {
|
||||
// 创建内存数据库
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 创建 mock AccountStore
|
||||
mockStore := &mockAccountStore{
|
||||
subordinateIDs: []uint{1, 2, 3},
|
||||
}
|
||||
|
||||
// 注册 Callback
|
||||
err = RegisterDataPermissionCallback(db, mockStore)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestDataPermissionCallback_SkipForRootUser 测试 root 用户跳过过滤
|
||||
func TestDataPermissionCallback_SkipForRootUser(t *testing.T) {
|
||||
// 创建内存数据库
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 创建测试表
|
||||
type TestModel struct {
|
||||
ID uint
|
||||
Creator uint
|
||||
Name string
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&TestModel{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 插入测试数据
|
||||
db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"})
|
||||
db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"})
|
||||
|
||||
// 创建 mock AccountStore
|
||||
mockStore := &mockAccountStore{
|
||||
subordinateIDs: []uint{1}, // 只有 ID 1
|
||||
}
|
||||
|
||||
// 注册 Callback
|
||||
err = RegisterDataPermissionCallback(db, mockStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 设置 root 用户 context
|
||||
ctx := context.Background()
|
||||
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeRoot, 0)
|
||||
|
||||
// 查询数据
|
||||
var results []TestModel
|
||||
err = db.WithContext(ctx).Find(&results).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// root 用户应该看到所有数据
|
||||
assert.Equal(t, 2, len(results))
|
||||
}
|
||||
|
||||
// TestDataPermissionCallback_FilterForNormalUser 测试普通用户过滤
|
||||
func TestDataPermissionCallback_FilterForNormalUser(t *testing.T) {
|
||||
// 创建内存数据库
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 创建测试表
|
||||
type TestModel struct {
|
||||
ID uint
|
||||
Creator uint
|
||||
Name string
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&TestModel{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 插入测试数据
|
||||
db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"})
|
||||
db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"})
|
||||
db.Create(&TestModel{ID: 3, Creator: 3, Name: "test3"})
|
||||
|
||||
// 创建 mock AccountStore
|
||||
mockStore := &mockAccountStore{
|
||||
subordinateIDs: []uint{1, 2}, // 只能看到 1 和 2
|
||||
}
|
||||
|
||||
// 注册 Callback
|
||||
err = RegisterDataPermissionCallback(db, mockStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 设置普通用户 context (非 root)
|
||||
ctx := context.Background()
|
||||
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeAgent, 0)
|
||||
|
||||
// 查询数据
|
||||
var results []TestModel
|
||||
err = db.WithContext(ctx).Find(&results).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 普通用户只能看到自己和下级的数据
|
||||
assert.Equal(t, 2, len(results))
|
||||
assert.Equal(t, uint(1), results[0].Creator)
|
||||
assert.Equal(t, uint(2), results[1].Creator)
|
||||
}
|
||||
|
||||
// TestDataPermissionCallback_SkipWithContext 测试通过 Context 跳过过滤
|
||||
func TestDataPermissionCallback_SkipWithContext(t *testing.T) {
|
||||
// 创建内存数据库
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 创建测试表
|
||||
type TestModel struct {
|
||||
ID uint
|
||||
Creator uint
|
||||
Name string
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&TestModel{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 插入测试数据
|
||||
db.Create(&TestModel{ID: 1, Creator: 1, Name: "test1"})
|
||||
db.Create(&TestModel{ID: 2, Creator: 2, Name: "test2"})
|
||||
|
||||
// 创建 mock AccountStore
|
||||
mockStore := &mockAccountStore{
|
||||
subordinateIDs: []uint{1}, // 只有 ID 1
|
||||
}
|
||||
|
||||
// 注册 Callback
|
||||
err = RegisterDataPermissionCallback(db, mockStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 设置普通用户 context 并跳过过滤
|
||||
ctx := context.Background()
|
||||
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeAgent, 0)
|
||||
ctx = SkipDataPermission(ctx)
|
||||
|
||||
// 查询数据
|
||||
var results []TestModel
|
||||
err = db.WithContext(ctx).Find(&results).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 跳过过滤后应该看到所有数据
|
||||
assert.Equal(t, 2, len(results))
|
||||
}
|
||||
|
||||
// TestDataPermissionCallback_WithShopID 测试带 shop_id 的过滤
|
||||
func TestDataPermissionCallback_WithShopID(t *testing.T) {
|
||||
// 创建内存数据库
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 创建测试表
|
||||
type TestModel struct {
|
||||
ID uint
|
||||
Creator uint
|
||||
ShopID uint
|
||||
Name string
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&TestModel{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 插入测试数据
|
||||
db.Create(&TestModel{ID: 1, Creator: 1, ShopID: 100, Name: "test1"})
|
||||
db.Create(&TestModel{ID: 2, Creator: 2, ShopID: 100, Name: "test2"})
|
||||
db.Create(&TestModel{ID: 3, Creator: 2, ShopID: 200, Name: "test3"}) // 不同 shop_id
|
||||
|
||||
// 创建 mock AccountStore
|
||||
mockStore := &mockAccountStore{
|
||||
subordinateIDs: []uint{1, 2}, // 可以看到 1 和 2
|
||||
}
|
||||
|
||||
// 注册 Callback
|
||||
err = RegisterDataPermissionCallback(db, mockStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 设置普通用户 context (shop_id = 100)
|
||||
ctx := context.Background()
|
||||
ctx = middleware.SetUserContext(ctx, 1, constants.UserTypeAgent, 100)
|
||||
|
||||
// 查询数据
|
||||
var results []TestModel
|
||||
err = db.WithContext(ctx).Find(&results).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 只能看到 shop_id = 100 的数据
|
||||
assert.Equal(t, 2, len(results))
|
||||
for _, r := range results {
|
||||
assert.Equal(t, uint(100), r.ShopID)
|
||||
}
|
||||
}
|
||||
@@ -35,7 +35,7 @@ func GetUserTypeFromContext(ctx context.Context) int {
|
||||
if ctx == nil {
|
||||
return 0
|
||||
}
|
||||
if userType, ok := ctx.Value(constants.ContextKeyUserID).(int); ok {
|
||||
if userType, ok := ctx.Value(constants.ContextKeyUserType).(int); ok {
|
||||
return userType
|
||||
}
|
||||
return 0
|
||||
@@ -84,17 +84,18 @@ type AuthConfig struct {
|
||||
// 验证失败返回 error
|
||||
TokenValidator func(token string) (userID uint, userType int, shopID uint, err error)
|
||||
|
||||
// Skip 跳过认证的路径
|
||||
Skip []string
|
||||
// SkipPaths 跳过认证的路径列表
|
||||
SkipPaths []string
|
||||
}
|
||||
|
||||
// Auth 认证中间件
|
||||
// 从请求中提取 token,验证后将用户信息设置到 context
|
||||
// 所有错误统一返回 AppError,由全局 ErrorHandler 处理
|
||||
func Auth(config AuthConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// 检查是否跳过认证
|
||||
path := c.Path()
|
||||
for _, skipPath := range config.Skip {
|
||||
for _, skipPath := range config.SkipPaths {
|
||||
if path == skipPath {
|
||||
return c.Next()
|
||||
}
|
||||
@@ -110,26 +111,22 @@ func Auth(config AuthConfig) fiber.Handler {
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
"code": errors.CodeUnauthorized,
|
||||
"message": "未提供认证令牌",
|
||||
})
|
||||
return errors.New(errors.CodeMissingToken, "未提供认证令牌")
|
||||
}
|
||||
|
||||
// 验证 token
|
||||
if config.TokenValidator == nil {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
"code": errors.CodeInternalError,
|
||||
"message": "认证验证器未配置",
|
||||
})
|
||||
return errors.New(errors.CodeInternalError, "认证验证器未配置")
|
||||
}
|
||||
|
||||
userID, userType, shopID, err := config.TokenValidator(token)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
"code": errors.CodeUnauthorized,
|
||||
"message": "认证令牌无效",
|
||||
})
|
||||
// 如果验证器返回的是 AppError,直接返回
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
return appErr
|
||||
}
|
||||
// 否则包装为 AppError
|
||||
return errors.Wrap(errors.CodeInvalidToken, "认证令牌无效", err)
|
||||
}
|
||||
|
||||
// 将用户信息设置到 context
|
||||
|
||||
@@ -25,16 +25,6 @@ func Success(c *fiber.Ctx, data any) error {
|
||||
})
|
||||
}
|
||||
|
||||
// Error 返回错误响应
|
||||
func Error(c *fiber.Ctx, httpStatus int, code int, message string) error {
|
||||
return c.Status(httpStatus).JSON(Response{
|
||||
Code: code,
|
||||
Data: nil,
|
||||
Message: message,
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
// SuccessWithMessage 返回带自定义消息的成功响应
|
||||
func SuccessWithMessage(c *fiber.Ctx, data any, message string) error {
|
||||
return c.JSON(Response{
|
||||
|
||||
@@ -36,17 +36,8 @@ func BenchmarkSuccess(b *testing.B) {
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkError 测试错误响应性能
|
||||
func BenchmarkError(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
_ = Error(ctx, 400, 1001, "无效的请求")
|
||||
app.ReleaseCtx(ctx)
|
||||
}
|
||||
}
|
||||
// BenchmarkError 基准测试已被删除 - Error() 函数已在重构中移除
|
||||
// 错误响应现在由全局 ErrorHandler 统一处理
|
||||
|
||||
// BenchmarkSuccessWithMessage 测试带自定义消息的成功响应性能
|
||||
func BenchmarkSuccessWithMessage(b *testing.B) {
|
||||
|
||||
@@ -111,107 +111,9 @@ func TestSuccess(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestError 测试错误响应(T035)
|
||||
func TestError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
httpStatus int
|
||||
code int
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "internal server error",
|
||||
httpStatus: 500,
|
||||
code: errors.CodeInternalError,
|
||||
message: "Internal server error occurred",
|
||||
},
|
||||
{
|
||||
name: "missing token error",
|
||||
httpStatus: 401,
|
||||
code: errors.CodeMissingToken,
|
||||
message: "Authentication token is missing",
|
||||
},
|
||||
{
|
||||
name: "invalid token error",
|
||||
httpStatus: 401,
|
||||
code: errors.CodeInvalidToken,
|
||||
message: "Token is invalid or expired",
|
||||
},
|
||||
{
|
||||
name: "rate limit error",
|
||||
httpStatus: 429,
|
||||
code: errors.CodeTooManyRequests,
|
||||
message: "Too many requests, please try again later",
|
||||
},
|
||||
{
|
||||
name: "service unavailable error",
|
||||
httpStatus: 503,
|
||||
code: errors.CodeAuthServiceUnavailable,
|
||||
message: "Authentication service is currently unavailable",
|
||||
},
|
||||
{
|
||||
name: "bad request error",
|
||||
httpStatus: 400,
|
||||
code: 2000,
|
||||
message: "Invalid request parameters",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return Error(c, tt.httpStatus, tt.code, tt.message)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 验证 HTTP 状态码
|
||||
if resp.StatusCode != tt.httpStatus {
|
||||
t.Errorf("Expected status code %d, got %d", tt.httpStatus, resp.StatusCode)
|
||||
}
|
||||
|
||||
// 验证响应头
|
||||
if resp.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("Expected Content-Type application/json, got %s", resp.Header.Get("Content-Type"))
|
||||
}
|
||||
|
||||
// 解析响应体
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
var response Response
|
||||
if err := sonic.Unmarshal(body, &response); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
// 验证响应结构
|
||||
if response.Code != tt.code {
|
||||
t.Errorf("Expected code %d, got %d", tt.code, response.Code)
|
||||
}
|
||||
|
||||
if response.Message != tt.message {
|
||||
t.Errorf("Expected message '%s', got '%s'", tt.message, response.Message)
|
||||
}
|
||||
|
||||
if response.Data != nil {
|
||||
t.Errorf("Expected data to be nil in error response, got %v", response.Data)
|
||||
}
|
||||
|
||||
// 验证时间戳格式 RFC3339
|
||||
if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil {
|
||||
t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
// TestError 测试已被删除 - Error() 函数已在重构中移除
|
||||
// 错误响应现在由全局 ErrorHandler 统一处理
|
||||
// 相关测试已迁移到 pkg/errors/handler_test.go
|
||||
|
||||
// TestSuccessWithMessage 测试带自定义消息的成功响应(T034)
|
||||
func TestSuccessWithMessage(t *testing.T) {
|
||||
@@ -413,10 +315,8 @@ func TestMultipleResponses(t *testing.T) {
|
||||
callCount := 0
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
callCount++
|
||||
if callCount%2 == 0 {
|
||||
return Success(c, map[string]int{"count": callCount})
|
||||
}
|
||||
return Error(c, 500, errors.CodeInternalError, "error occurred")
|
||||
// 只返回成功响应,因为 Error() 函数已被删除
|
||||
return Success(c, map[string]int{"count": callCount})
|
||||
})
|
||||
|
||||
// 发送多个请求
|
||||
|
||||
Reference in New Issue
Block a user