Files
junhong_cmp_fiber/pkg/middleware/auth.go
huang 80f560df33
All checks were successful
构建并部署到测试环境(无 SSH) / build-and-deploy (push) Successful in 6m17s
refactor(account): 统一账号管理API、完善权限检查和操作审计
- 合并 customer_account 和 shop_account 路由到统一的 account 接口
- 新增统一认证接口 (auth handler)
- 实现越权防护中间件和权限检查工具函数
- 新增操作审计日志模型和服务
- 更新数据库迁移 (版本 39: account_operation_log 表)
- 补充集成测试覆盖权限检查和审计日志场景
2026-02-02 17:23:20 +08:00

227 lines
6.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"context"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/gofiber/fiber/v2"
)
// UserContextInfo 用户上下文信息
type UserContextInfo struct {
UserID uint
UserType int
ShopID uint
EnterpriseID uint
CustomerID uint
}
// SetUserContext 将用户信息设置到 context 中
// 在 Auth 中间件认证成功后调用
func SetUserContext(ctx context.Context, info *UserContextInfo) context.Context {
ctx = context.WithValue(ctx, constants.ContextKeyUserID, info.UserID)
ctx = context.WithValue(ctx, constants.ContextKeyUserType, info.UserType)
ctx = context.WithValue(ctx, constants.ContextKeyShopID, info.ShopID)
ctx = context.WithValue(ctx, constants.ContextKeyEnterpriseID, info.EnterpriseID)
ctx = context.WithValue(ctx, constants.ContextKeyCustomerID, info.CustomerID)
return ctx
}
// GetUserIDFromContext 从 context 中提取用户 ID
// 如果未设置,返回 0
func GetUserIDFromContext(ctx context.Context) uint {
if ctx == nil {
return 0
}
if userID, ok := ctx.Value(constants.ContextKeyUserID).(uint); ok {
return userID
}
return 0
}
// GetUserTypeFromContext 从 context 中提取用户类型
// 如果未设置,返回 0
func GetUserTypeFromContext(ctx context.Context) int {
if ctx == nil {
return 0
}
if userType, ok := ctx.Value(constants.ContextKeyUserType).(int); ok {
return userType
}
return 0
}
// GetShopIDFromContext 从 context 中提取店铺 ID
// 如果未设置,返回 0
func GetShopIDFromContext(ctx context.Context) uint {
if ctx == nil {
return 0
}
if shopID, ok := ctx.Value(constants.ContextKeyShopID).(uint); ok {
return shopID
}
return 0
}
// GetEnterpriseIDFromContext 从 context 中提取企业 ID
// 如果未设置,返回 0
func GetEnterpriseIDFromContext(ctx context.Context) uint {
if ctx == nil {
return 0
}
if enterpriseID, ok := ctx.Value(constants.ContextKeyEnterpriseID).(uint); ok {
return enterpriseID
}
return 0
}
// GetCustomerIDFromContext 从 context 中提取个人客户 ID
// 如果未设置,返回 0
func GetCustomerIDFromContext(ctx context.Context) uint {
if ctx == nil {
return 0
}
if customerID, ok := ctx.Value(constants.ContextKeyCustomerID).(uint); ok {
return customerID
}
return 0
}
// IsRootUser 检查当前用户是否为 root 用户
// root 用户跳过数据权限过滤
func IsRootUser(ctx context.Context) bool {
userType := GetUserTypeFromContext(ctx)
return userType == constants.UserTypeSuperAdmin
}
func GetRequestIDFromContext(ctx context.Context) *string {
if ctx == nil {
return nil
}
if requestID, ok := ctx.Value(constants.ContextKeyRequestID).(string); ok {
return &requestID
}
return nil
}
func GetIPFromContext(ctx context.Context) *string {
if ctx == nil {
return nil
}
if ip, ok := ctx.Value(constants.ContextKeyIP).(string); ok {
return &ip
}
return nil
}
func GetUserAgentFromContext(ctx context.Context) *string {
if ctx == nil {
return nil
}
if userAgent, ok := ctx.Value(constants.ContextKeyUserAgent).(string); ok {
return &userAgent
}
return nil
}
// SetUserToFiberContext 将用户信息设置到 Fiber context 的 Locals 中
// 同时也设置到标准 context 中,便于 GORM 查询使用
func SetUserToFiberContext(c *fiber.Ctx, info *UserContextInfo) {
// 设置到 Fiber Locals
c.Locals(constants.ContextKeyUserID, info.UserID)
c.Locals(constants.ContextKeyUserType, info.UserType)
c.Locals(constants.ContextKeyShopID, info.ShopID)
c.Locals(constants.ContextKeyEnterpriseID, info.EnterpriseID)
c.Locals(constants.ContextKeyCustomerID, info.CustomerID)
// 设置到标准 context用于 GORM 数据权限过滤)
ctx := SetUserContext(c.UserContext(), info)
c.SetUserContext(ctx)
}
// AuthConfig Auth 中间件配置
type AuthConfig struct {
// TokenExtractor 自定义 token 提取函数
// 如果为空,默认从 Authorization header 提取 Bearer token
TokenExtractor func(c *fiber.Ctx) string
// TokenValidator token 验证函数
// 验证成功返回用户上下文信息
// 验证失败返回 error
TokenValidator func(token string) (*UserContextInfo, error)
// SkipPaths 跳过认证的路径列表
SkipPaths []string
}
// Auth 认证中间件
// 从请求中提取 token验证后将用户信息设置到 context
// 所有错误统一返回 AppError由全局 ErrorHandler 处理
func Auth(config AuthConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
// 检查是否跳过认证
path := c.Path()
for _, skipPath := range config.SkipPaths {
if path == skipPath {
return c.Next()
}
}
// 提取 token
var token string
if config.TokenExtractor != nil {
token = config.TokenExtractor(c)
} else {
// 默认从 Authorization header 提取 Bearer token
token = extractBearerToken(c)
}
if token == "" {
return errors.New(errors.CodeMissingToken, "未提供认证令牌")
}
// 验证 token
if config.TokenValidator == nil {
return errors.New(errors.CodeInternalError, "认证验证器未配置")
}
userInfo, err := config.TokenValidator(token)
if err != nil {
// 如果验证器返回的是 AppError直接返回
if appErr, ok := err.(*errors.AppError); ok {
return appErr
}
// 否则包装为 AppError
return errors.Wrap(errors.CodeInvalidToken, err, "认证令牌无效")
}
// 将用户信息设置到 context
SetUserToFiberContext(c, userInfo)
return c.Next()
}
}
// extractBearerToken 从 Authorization header 提取 Bearer token
func extractBearerToken(c *fiber.Ctx) string {
auth := c.Get("Authorization")
if len(auth) > 7 && auth[:7] == "Bearer " {
return auth[7:]
}
return ""
}
// NewSimpleUserContext 创建简单的用户上下文信息(仅包含基本字段)
// 这是一个兼容性辅助函数,用于快速创建只包含 userID, userType, shopID 的上下文
// 适用于测试代码和不需要完整上下文信息的场景
func NewSimpleUserContext(userID uint, userType int, shopID uint) *UserContextInfo {
return &UserContextInfo{
UserID: userID,
UserType: userType,
ShopID: shopID,
EnterpriseID: 0,
CustomerID: 0,
}
}