Files
junhong_cmp_fiber/pkg/middleware/auth.go
huang 03a0960c4d
All checks were successful
构建并部署到测试环境(无 SSH) / build-and-deploy (push) Successful in 7m2s
refactor: 数据权限过滤从 GORM Callback 改为 Store 层显式调用
- 移除 RegisterDataPermissionCallback 和 SkipDataPermission 机制
- 在 Auth 中间件预计算 SubordinateShopIDs 并注入 Context
- 新增 ApplyShopFilter/ApplyEnterpriseFilter/ApplyOwnerShopFilter 等 Helper 函数
- 所有 Store 层查询方法显式调用数据权限过滤函数
- 权限检查函数 CanManageShop/CanManageEnterprise 改为从 Context 获取数据

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-26 16:38:52 +08:00

262 lines
7.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"context"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/logger"
"github.com/gofiber/fiber/v2"
"go.uber.org/zap"
)
// UserContextInfo 用户上下文信息
type UserContextInfo struct {
UserID uint
UserType int
ShopID uint
EnterpriseID uint
CustomerID uint
SubordinateShopIDs []uint // 代理用户的下级店铺ID列表nil 表示不受数据权限限制
}
// SetUserContext 将用户信息设置到 context 中
// 在 Auth 中间件认证成功后调用
func SetUserContext(ctx context.Context, info *UserContextInfo) context.Context {
ctx = context.WithValue(ctx, constants.ContextKeyUserID, info.UserID)
ctx = context.WithValue(ctx, constants.ContextKeyUserType, info.UserType)
ctx = context.WithValue(ctx, constants.ContextKeyShopID, info.ShopID)
ctx = context.WithValue(ctx, constants.ContextKeyEnterpriseID, info.EnterpriseID)
ctx = context.WithValue(ctx, constants.ContextKeyCustomerID, info.CustomerID)
// SubordinateShopIDs: nil 表示不限制,空切片表示无权限
if info.SubordinateShopIDs != nil {
ctx = context.WithValue(ctx, constants.ContextKeySubordinateShopIDs, info.SubordinateShopIDs)
}
return ctx
}
// GetUserIDFromContext 从 context 中提取用户 ID
// 如果未设置,返回 0
func GetUserIDFromContext(ctx context.Context) uint {
if ctx == nil {
return 0
}
if userID, ok := ctx.Value(constants.ContextKeyUserID).(uint); ok {
return userID
}
return 0
}
// GetUserTypeFromContext 从 context 中提取用户类型
// 如果未设置,返回 0
func GetUserTypeFromContext(ctx context.Context) int {
if ctx == nil {
return 0
}
if userType, ok := ctx.Value(constants.ContextKeyUserType).(int); ok {
return userType
}
return 0
}
// GetShopIDFromContext 从 context 中提取店铺 ID
// 如果未设置,返回 0
func GetShopIDFromContext(ctx context.Context) uint {
if ctx == nil {
return 0
}
if shopID, ok := ctx.Value(constants.ContextKeyShopID).(uint); ok {
return shopID
}
return 0
}
// GetEnterpriseIDFromContext 从 context 中提取企业 ID
// 如果未设置,返回 0
func GetEnterpriseIDFromContext(ctx context.Context) uint {
if ctx == nil {
return 0
}
if enterpriseID, ok := ctx.Value(constants.ContextKeyEnterpriseID).(uint); ok {
return enterpriseID
}
return 0
}
// GetCustomerIDFromContext 从 context 中提取个人客户 ID
// 如果未设置,返回 0
func GetCustomerIDFromContext(ctx context.Context) uint {
if ctx == nil {
return 0
}
if customerID, ok := ctx.Value(constants.ContextKeyCustomerID).(uint); ok {
return customerID
}
return 0
}
// IsRootUser 检查当前用户是否为 root 用户
// root 用户跳过数据权限过滤
func IsRootUser(ctx context.Context) bool {
userType := GetUserTypeFromContext(ctx)
return userType == constants.UserTypeSuperAdmin
}
func GetRequestIDFromContext(ctx context.Context) *string {
if ctx == nil {
return nil
}
if requestID, ok := ctx.Value(constants.ContextKeyRequestID).(string); ok {
return &requestID
}
return nil
}
func GetIPFromContext(ctx context.Context) *string {
if ctx == nil {
return nil
}
if ip, ok := ctx.Value(constants.ContextKeyIP).(string); ok {
return &ip
}
return nil
}
func GetUserAgentFromContext(ctx context.Context) *string {
if ctx == nil {
return nil
}
if userAgent, ok := ctx.Value(constants.ContextKeyUserAgent).(string); ok {
return &userAgent
}
return nil
}
// SetUserToFiberContext 将用户信息设置到 Fiber context 的 Locals 中
// 同时也设置到标准 context 中,便于 GORM 查询使用
func SetUserToFiberContext(c *fiber.Ctx, info *UserContextInfo) {
// 设置到 Fiber Locals
c.Locals(constants.ContextKeyUserID, info.UserID)
c.Locals(constants.ContextKeyUserType, info.UserType)
c.Locals(constants.ContextKeyShopID, info.ShopID)
c.Locals(constants.ContextKeyEnterpriseID, info.EnterpriseID)
c.Locals(constants.ContextKeyCustomerID, info.CustomerID)
if info.SubordinateShopIDs != nil {
c.Locals(constants.ContextKeySubordinateShopIDs, info.SubordinateShopIDs)
}
// 设置到标准 context用于数据权限过滤
ctx := SetUserContext(c.UserContext(), info)
c.SetUserContext(ctx)
}
// AuthShopStoreInterface 店铺存储接口
// 用于 Auth 中间件获取下级店铺 ID避免循环依赖
type AuthShopStoreInterface interface {
GetSubordinateShopIDs(ctx context.Context, shopID uint) ([]uint, error)
}
// AuthConfig Auth 中间件配置
type AuthConfig struct {
// TokenExtractor 自定义 token 提取函数
// 如果为空,默认从 Authorization header 提取 Bearer token
TokenExtractor func(c *fiber.Ctx) string
// TokenValidator token 验证函数
// 验证成功返回用户上下文信息
// 验证失败返回 error
TokenValidator func(token string) (*UserContextInfo, error)
// SkipPaths 跳过认证的路径列表
SkipPaths []string
// ShopStore 店铺存储,用于预计算代理用户的下级店铺 ID
// 可选,不传则不预计算 SubordinateShopIDs
ShopStore AuthShopStoreInterface
}
// Auth 认证中间件
// 从请求中提取 token验证后将用户信息设置到 context
// 所有错误统一返回 AppError由全局 ErrorHandler 处理
func Auth(config AuthConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
// 检查是否跳过认证
path := c.Path()
for _, skipPath := range config.SkipPaths {
if path == skipPath {
return c.Next()
}
}
// 提取 token
var token string
if config.TokenExtractor != nil {
token = config.TokenExtractor(c)
} else {
// 默认从 Authorization header 提取 Bearer token
token = extractBearerToken(c)
}
if token == "" {
return errors.New(errors.CodeMissingToken, "未提供认证令牌")
}
// 验证 token
if config.TokenValidator == nil {
return errors.New(errors.CodeInternalError, "认证验证器未配置")
}
userInfo, err := config.TokenValidator(token)
if err != nil {
// 如果验证器返回的是 AppError直接返回
if appErr, ok := err.(*errors.AppError); ok {
return appErr
}
// 否则包装为 AppError
return errors.Wrap(errors.CodeInvalidToken, err, "认证令牌无效")
}
// 预计算代理用户的下级店铺 ID
if config.ShopStore != nil &&
userInfo.UserType == constants.UserTypeAgent &&
userInfo.ShopID > 0 {
shopIDs, err := config.ShopStore.GetSubordinateShopIDs(c.UserContext(), userInfo.ShopID)
if err != nil {
// 降级处理:只包含自己的店铺 ID
shopIDs = []uint{userInfo.ShopID}
logger.GetAppLogger().Warn("预计算下级店铺失败,降级为只包含自己",
zap.Uint("shop_id", userInfo.ShopID),
zap.Error(err))
}
userInfo.SubordinateShopIDs = shopIDs
}
// 将用户信息设置到 context
SetUserToFiberContext(c, userInfo)
return c.Next()
}
}
// extractBearerToken 从 Authorization header 提取 Bearer token
func extractBearerToken(c *fiber.Ctx) string {
auth := c.Get("Authorization")
if len(auth) > 7 && auth[:7] == "Bearer " {
return auth[7:]
}
return ""
}
// NewSimpleUserContext 创建简单的用户上下文信息(仅包含基本字段)
// 这是一个兼容性辅助函数,用于快速创建只包含 userID, userType, shopID 的上下文
// 适用于测试代码和不需要完整上下文信息的场景
func NewSimpleUserContext(userID uint, userType int, shopID uint) *UserContextInfo {
return &UserContextInfo{
UserID: userID,
UserType: userType,
ShopID: shopID,
EnterpriseID: 0,
CustomerID: 0,
}
}