Files
junhong_cmp_fiber/pkg/middleware/auth.go
huang eaa70ac255 feat: 实现 RBAC 权限系统和数据权限控制 (004-rbac-data-permission)
主要功能:
- 实现完整的 RBAC 权限系统(账号、角色、权限的多对多关联)
- 基于 owner_id + shop_id 的自动数据权限过滤
- 使用 PostgreSQL WITH RECURSIVE 查询下级账号
- Redis 缓存优化下级账号查询性能(30分钟过期)
- 支持多租户数据隔离和层级权限管理

技术实现:
- 新增 Account、Role、Permission 模型及关联关系表
- 实现 GORM Scopes 自动应用数据权限过滤
- 添加数据库迁移脚本(000002_rbac_data_permission、000003_add_owner_id_shop_id)
- 完善错误码定义(1010-1027 为 RBAC 相关错误)
- 重构 main.go 采用函数拆分提高可读性

测试覆盖:
- 添加 Account、Role、Permission 的集成测试
- 添加数据权限过滤的单元测试和集成测试
- 添加下级账号查询和缓存的单元测试
- 添加 API 回归测试确保向后兼容

文档更新:
- 更新 README.md 添加 RBAC 功能说明
- 更新 CLAUDE.md 添加技术栈和开发原则
- 添加 docs/004-rbac-data-permission/ 功能总结和使用指南

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-18 16:44:06 +08:00

150 lines
4.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"context"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/gofiber/fiber/v2"
)
// SetUserContext 将用户信息设置到 context 中
// 在 Auth 中间件认证成功后调用
func SetUserContext(ctx context.Context, userID uint, userType int, shopID uint) context.Context {
ctx = context.WithValue(ctx, constants.ContextKeyUserID, userID)
ctx = context.WithValue(ctx, constants.ContextKeyUserType, userType)
ctx = context.WithValue(ctx, constants.ContextKeyShopID, shopID)
return ctx
}
// GetUserIDFromContext 从 context 中提取用户 ID
// 如果未设置,返回 0
func GetUserIDFromContext(ctx context.Context) uint {
if ctx == nil {
return 0
}
if userID, ok := ctx.Value(constants.ContextKeyUserID).(uint); ok {
return userID
}
return 0
}
// GetUserTypeFromContext 从 context 中提取用户类型
// 如果未设置,返回 0
func GetUserTypeFromContext(ctx context.Context) int {
if ctx == nil {
return 0
}
if userType, ok := ctx.Value(constants.ContextKeyUserID).(int); ok {
return userType
}
return 0
}
// GetShopIDFromContext 从 context 中提取店铺 ID
// 如果未设置,返回 0
func GetShopIDFromContext(ctx context.Context) uint {
if ctx == nil {
return 0
}
if shopID, ok := ctx.Value(constants.ContextKeyShopID).(uint); ok {
return shopID
}
return 0
}
// IsRootUser 检查当前用户是否为 root 用户
// root 用户跳过数据权限过滤
func IsRootUser(ctx context.Context) bool {
userType := GetUserTypeFromContext(ctx)
return userType == constants.UserTypeRoot
}
// SetUserToFiberContext 将用户信息设置到 Fiber context 的 Locals 中
// 同时也设置到标准 context 中,便于 GORM 查询使用
func SetUserToFiberContext(c *fiber.Ctx, userID uint, userType int, shopID uint) {
// 设置到 Fiber Locals
c.Locals(constants.ContextKeyUserID, userID)
c.Locals(constants.ContextKeyUserType, userType)
c.Locals(constants.ContextKeyShopID, shopID)
// 设置到标准 context用于 GORM 数据权限过滤)
ctx := SetUserContext(c.UserContext(), userID, userType, shopID)
c.SetUserContext(ctx)
}
// AuthConfig Auth 中间件配置
type AuthConfig struct {
// TokenExtractor 自定义 token 提取函数
// 如果为空,默认从 Authorization header 提取 Bearer token
TokenExtractor func(c *fiber.Ctx) string
// TokenValidator token 验证函数
// 验证成功返回用户 ID、用户类型、店铺 ID
// 验证失败返回 error
TokenValidator func(token string) (userID uint, userType int, shopID uint, err error)
// Skip 跳过认证的路径
Skip []string
}
// Auth 认证中间件
// 从请求中提取 token验证后将用户信息设置到 context
func Auth(config AuthConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
// 检查是否跳过认证
path := c.Path()
for _, skipPath := range config.Skip {
if path == skipPath {
return c.Next()
}
}
// 提取 token
var token string
if config.TokenExtractor != nil {
token = config.TokenExtractor(c)
} else {
// 默认从 Authorization header 提取 Bearer token
token = extractBearerToken(c)
}
if token == "" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
"code": errors.CodeUnauthorized,
"message": "未提供认证令牌",
})
}
// 验证 token
if config.TokenValidator == nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"code": errors.CodeInternalError,
"message": "认证验证器未配置",
})
}
userID, userType, shopID, err := config.TokenValidator(token)
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
"code": errors.CodeUnauthorized,
"message": "认证令牌无效",
})
}
// 将用户信息设置到 context
SetUserToFiberContext(c, userID, userType, shopID)
return c.Next()
}
}
// extractBearerToken 从 Authorization header 提取 Bearer token
func extractBearerToken(c *fiber.Ctx) string {
auth := c.Get("Authorization")
if len(auth) > 7 && auth[:7] == "Bearer " {
return auth[7:]
}
return ""
}