package middleware import ( "context" "github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/errors" "github.com/gofiber/fiber/v2" ) // SetUserContext 将用户信息设置到 context 中 // 在 Auth 中间件认证成功后调用 func SetUserContext(ctx context.Context, userID uint, userType int, shopID uint) context.Context { ctx = context.WithValue(ctx, constants.ContextKeyUserID, userID) ctx = context.WithValue(ctx, constants.ContextKeyUserType, userType) ctx = context.WithValue(ctx, constants.ContextKeyShopID, shopID) return ctx } // GetUserIDFromContext 从 context 中提取用户 ID // 如果未设置,返回 0 func GetUserIDFromContext(ctx context.Context) uint { if ctx == nil { return 0 } if userID, ok := ctx.Value(constants.ContextKeyUserID).(uint); ok { return userID } return 0 } // GetUserTypeFromContext 从 context 中提取用户类型 // 如果未设置,返回 0 func GetUserTypeFromContext(ctx context.Context) int { if ctx == nil { return 0 } if userType, ok := ctx.Value(constants.ContextKeyUserType).(int); ok { return userType } return 0 } // GetShopIDFromContext 从 context 中提取店铺 ID // 如果未设置,返回 0 func GetShopIDFromContext(ctx context.Context) uint { if ctx == nil { return 0 } if shopID, ok := ctx.Value(constants.ContextKeyShopID).(uint); ok { return shopID } return 0 } // IsRootUser 检查当前用户是否为 root 用户 // root 用户跳过数据权限过滤 func IsRootUser(ctx context.Context) bool { userType := GetUserTypeFromContext(ctx) return userType == constants.UserTypeSuperAdmin } // SetUserToFiberContext 将用户信息设置到 Fiber context 的 Locals 中 // 同时也设置到标准 context 中,便于 GORM 查询使用 func SetUserToFiberContext(c *fiber.Ctx, userID uint, userType int, shopID uint) { // 设置到 Fiber Locals c.Locals(constants.ContextKeyUserID, userID) c.Locals(constants.ContextKeyUserType, userType) c.Locals(constants.ContextKeyShopID, shopID) // 设置到标准 context(用于 GORM 数据权限过滤) ctx := SetUserContext(c.UserContext(), userID, userType, shopID) c.SetUserContext(ctx) } // AuthConfig Auth 中间件配置 type AuthConfig struct { // TokenExtractor 自定义 token 提取函数 // 如果为空,默认从 Authorization header 提取 Bearer token TokenExtractor func(c *fiber.Ctx) string // TokenValidator token 验证函数 // 验证成功返回用户 ID、用户类型、店铺 ID // 验证失败返回 error TokenValidator func(token string) (userID uint, userType int, shopID uint, err error) // SkipPaths 跳过认证的路径列表 SkipPaths []string } // Auth 认证中间件 // 从请求中提取 token,验证后将用户信息设置到 context // 所有错误统一返回 AppError,由全局 ErrorHandler 处理 func Auth(config AuthConfig) fiber.Handler { return func(c *fiber.Ctx) error { // 检查是否跳过认证 path := c.Path() for _, skipPath := range config.SkipPaths { if path == skipPath { return c.Next() } } // 提取 token var token string if config.TokenExtractor != nil { token = config.TokenExtractor(c) } else { // 默认从 Authorization header 提取 Bearer token token = extractBearerToken(c) } if token == "" { return errors.New(errors.CodeMissingToken, "未提供认证令牌") } // 验证 token if config.TokenValidator == nil { return errors.New(errors.CodeInternalError, "认证验证器未配置") } userID, userType, shopID, err := config.TokenValidator(token) if err != nil { // 如果验证器返回的是 AppError,直接返回 if appErr, ok := err.(*errors.AppError); ok { return appErr } // 否则包装为 AppError return errors.Wrap(errors.CodeInvalidToken, "认证令牌无效", err) } // 将用户信息设置到 context SetUserToFiberContext(c, userID, userType, shopID) return c.Next() } } // extractBearerToken 从 Authorization header 提取 Bearer token func extractBearerToken(c *fiber.Ctx) string { auth := c.Get("Authorization") if len(auth) > 7 && auth[:7] == "Bearer " { return auth[7:] } return "" }