Files
junhong_cmp_fiber/pkg/middleware/auth.go
huang a36e4a79c0 实现用户和组织模型(店铺、企业、个人客户)
核心功能:
- 实现 7 级店铺层级体系(Shop 模型 + 层级校验)
- 实现企业管理模型(Enterprise 模型)
- 实现个人客户管理模型(PersonalCustomer 模型)
- 重构 Account 模型关联关系(基于 EnterpriseID 而非 ParentID)
- 完整的 Store 层和 Service 层实现
- 递归查询下级店铺功能(含 Redis 缓存)
- 全面的单元测试覆盖(Shop/Enterprise/PersonalCustomer Store + Shop Service)

技术要点:
- 显式指定所有 GORM 模型的数据库字段名(column: 标签)
- 统一的字段命名规范(数据库用 snake_case,Go 用 PascalCase)
- 完整的中文字段注释和业务逻辑说明
- 100% 测试覆盖(20+ 测试用例全部通过)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-09 18:02:46 +08:00

147 lines
4.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"context"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/gofiber/fiber/v2"
)
// SetUserContext 将用户信息设置到 context 中
// 在 Auth 中间件认证成功后调用
func SetUserContext(ctx context.Context, userID uint, userType int, shopID uint) context.Context {
ctx = context.WithValue(ctx, constants.ContextKeyUserID, userID)
ctx = context.WithValue(ctx, constants.ContextKeyUserType, userType)
ctx = context.WithValue(ctx, constants.ContextKeyShopID, shopID)
return ctx
}
// GetUserIDFromContext 从 context 中提取用户 ID
// 如果未设置,返回 0
func GetUserIDFromContext(ctx context.Context) uint {
if ctx == nil {
return 0
}
if userID, ok := ctx.Value(constants.ContextKeyUserID).(uint); ok {
return userID
}
return 0
}
// GetUserTypeFromContext 从 context 中提取用户类型
// 如果未设置,返回 0
func GetUserTypeFromContext(ctx context.Context) int {
if ctx == nil {
return 0
}
if userType, ok := ctx.Value(constants.ContextKeyUserType).(int); ok {
return userType
}
return 0
}
// GetShopIDFromContext 从 context 中提取店铺 ID
// 如果未设置,返回 0
func GetShopIDFromContext(ctx context.Context) uint {
if ctx == nil {
return 0
}
if shopID, ok := ctx.Value(constants.ContextKeyShopID).(uint); ok {
return shopID
}
return 0
}
// IsRootUser 检查当前用户是否为 root 用户
// root 用户跳过数据权限过滤
func IsRootUser(ctx context.Context) bool {
userType := GetUserTypeFromContext(ctx)
return userType == constants.UserTypeSuperAdmin
}
// SetUserToFiberContext 将用户信息设置到 Fiber context 的 Locals 中
// 同时也设置到标准 context 中,便于 GORM 查询使用
func SetUserToFiberContext(c *fiber.Ctx, userID uint, userType int, shopID uint) {
// 设置到 Fiber Locals
c.Locals(constants.ContextKeyUserID, userID)
c.Locals(constants.ContextKeyUserType, userType)
c.Locals(constants.ContextKeyShopID, shopID)
// 设置到标准 context用于 GORM 数据权限过滤)
ctx := SetUserContext(c.UserContext(), userID, userType, shopID)
c.SetUserContext(ctx)
}
// AuthConfig Auth 中间件配置
type AuthConfig struct {
// TokenExtractor 自定义 token 提取函数
// 如果为空,默认从 Authorization header 提取 Bearer token
TokenExtractor func(c *fiber.Ctx) string
// TokenValidator token 验证函数
// 验证成功返回用户 ID、用户类型、店铺 ID
// 验证失败返回 error
TokenValidator func(token string) (userID uint, userType int, shopID uint, err error)
// SkipPaths 跳过认证的路径列表
SkipPaths []string
}
// Auth 认证中间件
// 从请求中提取 token验证后将用户信息设置到 context
// 所有错误统一返回 AppError由全局 ErrorHandler 处理
func Auth(config AuthConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
// 检查是否跳过认证
path := c.Path()
for _, skipPath := range config.SkipPaths {
if path == skipPath {
return c.Next()
}
}
// 提取 token
var token string
if config.TokenExtractor != nil {
token = config.TokenExtractor(c)
} else {
// 默认从 Authorization header 提取 Bearer token
token = extractBearerToken(c)
}
if token == "" {
return errors.New(errors.CodeMissingToken, "未提供认证令牌")
}
// 验证 token
if config.TokenValidator == nil {
return errors.New(errors.CodeInternalError, "认证验证器未配置")
}
userID, userType, shopID, err := config.TokenValidator(token)
if err != nil {
// 如果验证器返回的是 AppError直接返回
if appErr, ok := err.(*errors.AppError); ok {
return appErr
}
// 否则包装为 AppError
return errors.Wrap(errors.CodeInvalidToken, "认证令牌无效", err)
}
// 将用户信息设置到 context
SetUserToFiberContext(c, userID, userType, shopID)
return c.Next()
}
}
// extractBearerToken 从 Authorization header 提取 Bearer token
func extractBearerToken(c *fiber.Ctx) string {
auth := c.Get("Authorization")
if len(auth) > 7 && auth[:7] == "Bearer " {
return auth[7:]
}
return ""
}