package middleware import ( "context" "github.com/break/junhong_cmp_fiber/pkg/constants" "github.com/break/junhong_cmp_fiber/pkg/errors" "github.com/gofiber/fiber/v2" ) // UserContextInfo 用户上下文信息 type UserContextInfo struct { UserID uint UserType int ShopID uint EnterpriseID uint CustomerID uint } // SetUserContext 将用户信息设置到 context 中 // 在 Auth 中间件认证成功后调用 func SetUserContext(ctx context.Context, info *UserContextInfo) context.Context { ctx = context.WithValue(ctx, constants.ContextKeyUserID, info.UserID) ctx = context.WithValue(ctx, constants.ContextKeyUserType, info.UserType) ctx = context.WithValue(ctx, constants.ContextKeyShopID, info.ShopID) ctx = context.WithValue(ctx, constants.ContextKeyEnterpriseID, info.EnterpriseID) ctx = context.WithValue(ctx, constants.ContextKeyCustomerID, info.CustomerID) return ctx } // GetUserIDFromContext 从 context 中提取用户 ID // 如果未设置,返回 0 func GetUserIDFromContext(ctx context.Context) uint { if ctx == nil { return 0 } if userID, ok := ctx.Value(constants.ContextKeyUserID).(uint); ok { return userID } return 0 } // GetUserTypeFromContext 从 context 中提取用户类型 // 如果未设置,返回 0 func GetUserTypeFromContext(ctx context.Context) int { if ctx == nil { return 0 } if userType, ok := ctx.Value(constants.ContextKeyUserType).(int); ok { return userType } return 0 } // GetShopIDFromContext 从 context 中提取店铺 ID // 如果未设置,返回 0 func GetShopIDFromContext(ctx context.Context) uint { if ctx == nil { return 0 } if shopID, ok := ctx.Value(constants.ContextKeyShopID).(uint); ok { return shopID } return 0 } // GetEnterpriseIDFromContext 从 context 中提取企业 ID // 如果未设置,返回 0 func GetEnterpriseIDFromContext(ctx context.Context) uint { if ctx == nil { return 0 } if enterpriseID, ok := ctx.Value(constants.ContextKeyEnterpriseID).(uint); ok { return enterpriseID } return 0 } // GetCustomerIDFromContext 从 context 中提取个人客户 ID // 如果未设置,返回 0 func GetCustomerIDFromContext(ctx context.Context) uint { if ctx == nil { return 0 } if customerID, ok := ctx.Value(constants.ContextKeyCustomerID).(uint); ok { return customerID } return 0 } // IsRootUser 检查当前用户是否为 root 用户 // root 用户跳过数据权限过滤 func IsRootUser(ctx context.Context) bool { userType := GetUserTypeFromContext(ctx) return userType == constants.UserTypeSuperAdmin } // SetUserToFiberContext 将用户信息设置到 Fiber context 的 Locals 中 // 同时也设置到标准 context 中,便于 GORM 查询使用 func SetUserToFiberContext(c *fiber.Ctx, info *UserContextInfo) { // 设置到 Fiber Locals c.Locals(constants.ContextKeyUserID, info.UserID) c.Locals(constants.ContextKeyUserType, info.UserType) c.Locals(constants.ContextKeyShopID, info.ShopID) c.Locals(constants.ContextKeyEnterpriseID, info.EnterpriseID) c.Locals(constants.ContextKeyCustomerID, info.CustomerID) // 设置到标准 context(用于 GORM 数据权限过滤) ctx := SetUserContext(c.UserContext(), info) c.SetUserContext(ctx) } // AuthConfig Auth 中间件配置 type AuthConfig struct { // TokenExtractor 自定义 token 提取函数 // 如果为空,默认从 Authorization header 提取 Bearer token TokenExtractor func(c *fiber.Ctx) string // TokenValidator token 验证函数 // 验证成功返回用户上下文信息 // 验证失败返回 error TokenValidator func(token string) (*UserContextInfo, error) // SkipPaths 跳过认证的路径列表 SkipPaths []string } // Auth 认证中间件 // 从请求中提取 token,验证后将用户信息设置到 context // 所有错误统一返回 AppError,由全局 ErrorHandler 处理 func Auth(config AuthConfig) fiber.Handler { return func(c *fiber.Ctx) error { // 检查是否跳过认证 path := c.Path() for _, skipPath := range config.SkipPaths { if path == skipPath { return c.Next() } } // 提取 token var token string if config.TokenExtractor != nil { token = config.TokenExtractor(c) } else { // 默认从 Authorization header 提取 Bearer token token = extractBearerToken(c) } if token == "" { return errors.New(errors.CodeMissingToken, "未提供认证令牌") } // 验证 token if config.TokenValidator == nil { return errors.New(errors.CodeInternalError, "认证验证器未配置") } userInfo, err := config.TokenValidator(token) if err != nil { // 如果验证器返回的是 AppError,直接返回 if appErr, ok := err.(*errors.AppError); ok { return appErr } // 否则包装为 AppError return errors.Wrap(errors.CodeInvalidToken, "认证令牌无效", err) } // 将用户信息设置到 context SetUserToFiberContext(c, userInfo) return c.Next() } } // extractBearerToken 从 Authorization header 提取 Bearer token func extractBearerToken(c *fiber.Ctx) string { auth := c.Get("Authorization") if len(auth) > 7 && auth[:7] == "Bearer " { return auth[7:] } return "" } // NewSimpleUserContext 创建简单的用户上下文信息(仅包含基本字段) // 这是一个兼容性辅助函数,用于快速创建只包含 userID, userType, shopID 的上下文 // 适用于测试代码和不需要完整上下文信息的场景 func NewSimpleUserContext(userID uint, userType int, shopID uint) *UserContextInfo { return &UserContextInfo{ UserID: userID, UserType: userType, ShopID: shopID, EnterpriseID: 0, CustomerID: 0, } }