package middleware import ( "context" "github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/errors" "github.com/gofiber/fiber/v2" ) // SetUserContext 将用户信息设置到 context 中 // 在 Auth 中间件认证成功后调用 func SetUserContext(ctx context.Context, userID uint, userType int, shopID uint) context.Context { ctx = context.WithValue(ctx, constants.ContextKeyUserID, userID) ctx = context.WithValue(ctx, constants.ContextKeyUserType, userType) ctx = context.WithValue(ctx, constants.ContextKeyShopID, shopID) return ctx } // GetUserIDFromContext 从 context 中提取用户 ID // 如果未设置,返回 0 func GetUserIDFromContext(ctx context.Context) uint { if ctx == nil { return 0 } if userID, ok := ctx.Value(constants.ContextKeyUserID).(uint); ok { return userID } return 0 } // GetUserTypeFromContext 从 context 中提取用户类型 // 如果未设置,返回 0 func GetUserTypeFromContext(ctx context.Context) int { if ctx == nil { return 0 } if userType, ok := ctx.Value(constants.ContextKeyUserID).(int); ok { return userType } return 0 } // GetShopIDFromContext 从 context 中提取店铺 ID // 如果未设置,返回 0 func GetShopIDFromContext(ctx context.Context) uint { if ctx == nil { return 0 } if shopID, ok := ctx.Value(constants.ContextKeyShopID).(uint); ok { return shopID } return 0 } // IsRootUser 检查当前用户是否为 root 用户 // root 用户跳过数据权限过滤 func IsRootUser(ctx context.Context) bool { userType := GetUserTypeFromContext(ctx) return userType == constants.UserTypeRoot } // SetUserToFiberContext 将用户信息设置到 Fiber context 的 Locals 中 // 同时也设置到标准 context 中,便于 GORM 查询使用 func SetUserToFiberContext(c *fiber.Ctx, userID uint, userType int, shopID uint) { // 设置到 Fiber Locals c.Locals(constants.ContextKeyUserID, userID) c.Locals(constants.ContextKeyUserType, userType) c.Locals(constants.ContextKeyShopID, shopID) // 设置到标准 context(用于 GORM 数据权限过滤) ctx := SetUserContext(c.UserContext(), userID, userType, shopID) c.SetUserContext(ctx) } // AuthConfig Auth 中间件配置 type AuthConfig struct { // TokenExtractor 自定义 token 提取函数 // 如果为空,默认从 Authorization header 提取 Bearer token TokenExtractor func(c *fiber.Ctx) string // TokenValidator token 验证函数 // 验证成功返回用户 ID、用户类型、店铺 ID // 验证失败返回 error TokenValidator func(token string) (userID uint, userType int, shopID uint, err error) // Skip 跳过认证的路径 Skip []string } // Auth 认证中间件 // 从请求中提取 token,验证后将用户信息设置到 context func Auth(config AuthConfig) fiber.Handler { return func(c *fiber.Ctx) error { // 检查是否跳过认证 path := c.Path() for _, skipPath := range config.Skip { if path == skipPath { return c.Next() } } // 提取 token var token string if config.TokenExtractor != nil { token = config.TokenExtractor(c) } else { // 默认从 Authorization header 提取 Bearer token token = extractBearerToken(c) } if token == "" { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ "code": errors.CodeUnauthorized, "message": "未提供认证令牌", }) } // 验证 token if config.TokenValidator == nil { return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ "code": errors.CodeInternalError, "message": "认证验证器未配置", }) } userID, userType, shopID, err := config.TokenValidator(token) if err != nil { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ "code": errors.CodeUnauthorized, "message": "认证令牌无效", }) } // 将用户信息设置到 context SetUserToFiberContext(c, userID, userType, shopID) return c.Next() } } // extractBearerToken 从 Authorization header 提取 Bearer token func extractBearerToken(c *fiber.Ctx) string { auth := c.Get("Authorization") if len(auth) > 7 && auth[:7] == "Bearer " { return auth[7:] } return "" }