Files
junhong_cmp_fiber/tests/integration/recover_test.go
huang eaa70ac255 feat: 实现 RBAC 权限系统和数据权限控制 (004-rbac-data-permission)
主要功能:
- 实现完整的 RBAC 权限系统(账号、角色、权限的多对多关联)
- 基于 owner_id + shop_id 的自动数据权限过滤
- 使用 PostgreSQL WITH RECURSIVE 查询下级账号
- Redis 缓存优化下级账号查询性能(30分钟过期)
- 支持多租户数据隔离和层级权限管理

技术实现:
- 新增 Account、Role、Permission 模型及关联关系表
- 实现 GORM Scopes 自动应用数据权限过滤
- 添加数据库迁移脚本(000002_rbac_data_permission、000003_add_owner_id_shop_id)
- 完善错误码定义(1010-1027 为 RBAC 相关错误)
- 重构 main.go 采用函数拆分提高可读性

测试覆盖:
- 添加 Account、Role、Permission 的集成测试
- 添加数据权限过滤的单元测试和集成测试
- 添加下级账号查询和缓存的单元测试
- 添加 API 回归测试确保向后兼容

文档更新:
- 更新 README.md 添加 RBAC 功能说明
- 更新 CLAUDE.md 添加技术栈和开发原则
- 添加 docs/004-rbac-data-permission/ 功能总结和使用指南

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-18 16:44:06 +08:00

619 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package integration
import (
"io"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/internal/middleware"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/logger"
"github.com/break/junhong_cmp_fiber/pkg/response"
"github.com/bytedance/sonic"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/requestid"
"github.com/google/uuid"
)
// TestPanicRecovery verifies that the Recover middleware converts a handler
// panic into the unified 500 error response, and that the application keeps
// serving ordinary requests afterwards (T052).
func TestPanicRecovery(t *testing.T) {
	logDir := t.TempDir()

	// Both the app log and the access log go into the per-test temp directory.
	rotation := func(name string) logger.LogRotationConfig {
		return logger.LogRotationConfig{
			Filename:   filepath.Join(logDir, name),
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		}
	}
	if err := logger.InitLoggers("info", false, rotation("app-panic.log"), rotation("access-panic.log")); err != nil {
		t.Fatalf("Failed to initialize loggers: %v", err)
	}
	defer func() { _ = logger.Sync() }()

	app := fiber.New()
	// Recover is registered first so that it wraps every later middleware and handler.
	app.Use(middleware.Recover(logger.GetAppLogger()))
	app.Use(requestid.New(requestid.Config{Generator: uuid.NewString}))

	app.Get("/panic", func(c *fiber.Ctx) error {
		panic("intentional panic for testing")
	})
	app.Get("/ok", func(c *fiber.Ctx) error {
		return c.SendString("ok")
	})

	// Cases run in order: the second one proves the app still works after a panic.
	cases := []struct {
		name           string
		path           string
		shouldPanic    bool
		expectedStatus int
		expectedCode   int
	}{
		{"panic endpoint returns 500", "/panic", true, 500, errors.CodeInternalError},
		{"normal endpoint works after panic", "/ok", false, 200, 0},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			resp, err := app.Test(httptest.NewRequest("GET", tc.path, nil))
			if err != nil {
				t.Fatalf("Failed to execute request: %v", err)
			}
			defer resp.Body.Close()

			if got := resp.StatusCode; got != tc.expectedStatus {
				t.Errorf("Expected status %d, got %d", tc.expectedStatus, got)
			}

			body, err := io.ReadAll(resp.Body)
			if err != nil {
				t.Fatalf("Failed to read response body: %v", err)
			}
			if !tc.shouldPanic {
				return
			}

			// A recovered panic must be rendered as the unified error envelope.
			var payload response.Response
			if err := sonic.Unmarshal(body, &payload); err != nil {
				t.Fatalf("Failed to unmarshal response: %v", err)
			}
			if payload.Code != tc.expectedCode {
				t.Errorf("Expected code %d, got %d", tc.expectedCode, payload.Code)
			}
			if payload.Data != nil {
				t.Error("Error response data should be nil")
			}
		})
	}
}
// TestPanicLogging verifies that recovered panics are written to the app log
// together with a stack trace, for string, error and struct panic values (T053).
//
// All subtests share one log file, so each subtest only inspects the bytes
// appended while it ran. Otherwise later cases could pass purely on entries
// written by earlier ones.
func TestPanicLogging(t *testing.T) {
	// Per-test temporary directory for the log files.
	tempDir := t.TempDir()
	appLogFile := filepath.Join(tempDir, "app-panic-log.log")

	// Initialize the logging system.
	err := logger.InitLoggers("info", false,
		logger.LogRotationConfig{
			Filename:   appLogFile,
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		},
		logger.LogRotationConfig{
			Filename:   filepath.Join(tempDir, "access-panic-log.log"),
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		},
	)
	if err != nil {
		t.Fatalf("Failed to initialize loggers: %v", err)
	}
	defer func() { _ = logger.Sync() }()

	appLogger := logger.GetAppLogger()

	// Build the app. Recover is registered first so it wraps everything else.
	app := fiber.New()
	app.Use(middleware.Recover(appLogger))
	app.Use(requestid.New(requestid.Config{
		Generator: func() string {
			return uuid.NewString()
		},
	}))

	// Handlers that panic with different kinds of values.
	app.Get("/panic-string", func(c *fiber.Ctx) error {
		panic("string panic message")
	})
	app.Get("/panic-error", func(c *fiber.Ctx) error {
		panic(fiber.NewError(500, "error panic message"))
	})
	app.Get("/panic-struct", func(c *fiber.Ctx) error {
		panic(struct{ Message string }{"struct panic message"})
	})

	// readAppLog flushes the loggers and returns the current app log bytes.
	// A log file that does not exist yet is treated as empty.
	readAppLog := func(t *testing.T) []byte {
		t.Helper()
		// The Sync error is deliberately ignored: the assertions are made on
		// the file contents read below.
		_ = logger.Sync()
		data, err := os.ReadFile(appLogFile)
		if err != nil {
			if os.IsNotExist(err) {
				return nil
			}
			t.Fatalf("Failed to read app log: %v", err)
		}
		return data
	}

	tests := []struct {
		name            string
		path            string
		expectedInLog   []string
		unexpectedInLog []string
	}{
		{
			name: "string panic logs correctly",
			path: "/panic-string",
			expectedInLog: []string{
				"Panic 已恢复",
				"string panic message",
				"stack",
				"request_id",
				"method",
				"path",
			},
		},
		{
			name: "error panic logs correctly",
			path: "/panic-error",
			expectedInLog: []string{
				"Panic 已恢复",
				"error panic message",
				"stack",
			},
		},
		{
			name: "struct panic logs correctly",
			path: "/panic-struct",
			expectedInLog: []string{
				"Panic 已恢复",
				"stack",
				"Message",
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Remember where the log ends before this subtest's request.
			offset := len(readAppLog(t))

			// Issue the panicking request.
			req := httptest.NewRequest("GET", tt.path, nil)
			resp, err := app.Test(req)
			if err != nil {
				t.Fatalf("Failed to execute request: %v", err)
			}
			resp.Body.Close()

			// Flush log buffers and give the writer a moment before reading back.
			_ = logger.Sync()
			time.Sleep(100 * time.Millisecond)
			logContent := readAppLog(t)

			// Only inspect entries appended by this subtest. If the file
			// shrank (e.g. it was rotated), fall back to the whole file.
			if offset <= len(logContent) {
				logContent = logContent[offset:]
			}
			content := string(logContent)
			if strings.TrimSpace(content) == "" {
				t.Fatalf("No app log output was written for %s", tt.path)
			}

			// The new log output must contain every expected fragment...
			for _, expected := range tt.expectedInLog {
				if !strings.Contains(content, expected) {
					t.Errorf("Log should contain '%s'", expected)
				}
			}
			// ...and none of the unexpected ones.
			for _, unexpected := range tt.unexpectedInLog {
				if strings.Contains(content, unexpected) {
					t.Errorf("Log should NOT contain '%s'", unexpected)
				}
			}

			// The panicking handler lives in this file, so its name should
			// appear in the logged stack trace.
			if !strings.Contains(content, "recover_test.go") {
				t.Error("Stack trace should contain source file name")
			}
			t.Logf("Panic log contains stack trace: %v", strings.Contains(content, "stack"))
		})
	}
}
// TestSubsequentRequestsAfterPanic verifies that requests which follow a
// recovered panic are still handled normally (T054). The handler alternates
// between panicking and succeeding across six sequential requests.
func TestSubsequentRequestsAfterPanic(t *testing.T) {
	// Per-test temporary directory for the log files.
	tempDir := t.TempDir()

	// Initialize the logging system.
	err := logger.InitLoggers("info", false,
		logger.LogRotationConfig{
			Filename:   filepath.Join(tempDir, "app.log"),
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		},
		logger.LogRotationConfig{
			Filename:   filepath.Join(tempDir, "access.log"),
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		},
	)
	if err != nil {
		t.Fatalf("Failed to initialize loggers: %v", err)
	}
	defer func() { _ = logger.Sync() }()

	appLogger := logger.GetAppLogger()

	// Build the app with Recover registered first.
	app := fiber.New()
	app.Use(middleware.Recover(appLogger))
	app.Use(requestid.New(requestid.Config{
		Generator: func() string {
			return uuid.NewString()
		},
	}))

	// Requests run one at a time, so a plain counter is sufficient here.
	callCount := 0
	app.Get("/test", func(c *fiber.Ctx) error {
		callCount++
		// Calls 1, 3 and 5 panic.
		if callCount%2 == 1 {
			panic("test panic")
		}
		// Calls 2, 4 and 6 succeed.
		return c.JSON(fiber.Map{
			"call_count": callCount,
			"status":     "ok",
		})
	})

	// Send several requests and check that a panic does not break later ones.
	for i := 1; i <= 6; i++ {
		req := httptest.NewRequest("GET", "/test", nil)
		resp, err := app.Test(req)
		if err != nil {
			t.Fatalf("Request %d failed: %v", i, err)
		}
		body, readErr := io.ReadAll(resp.Body)
		resp.Body.Close()
		if readErr != nil {
			t.Fatalf("Request %d: failed to read response body: %v", i, readErr)
		}

		if i%2 == 1 {
			// Odd-numbered requests panic and must yield 500.
			if resp.StatusCode != 500 {
				t.Errorf("Request %d: expected status 500, got %d", i, resp.StatusCode)
			}
		} else {
			// Even-numbered requests must succeed with 200.
			if resp.StatusCode != 200 {
				t.Errorf("Request %d: expected status 200, got %d", i, resp.StatusCode)
			}
			// Check the JSON payload of the successful response.
			var payload map[string]any
			if err := sonic.Unmarshal(body, &payload); err != nil {
				t.Fatalf("Request %d: failed to unmarshal response: %v", i, err)
			}
			if status, ok := payload["status"].(string); !ok || status != "ok" {
				t.Errorf("Request %d: expected status 'ok', got %v", i, payload["status"])
			}
		}
		t.Logf("Request %d completed: status=%d", i, resp.StatusCode)
	}

	// Every one of the six requests must have reached the handler.
	if callCount != 6 {
		t.Errorf("Expected 6 calls, got %d", callCount)
	}
}
// TestPanicWithRequestID verifies that the panic log entry carries the request
// ID that was returned to the client in the X-Request-ID header (T053).
func TestPanicWithRequestID(t *testing.T) {
	// Per-test temporary directory for the log files.
	tempDir := t.TempDir()
	appLogFile := filepath.Join(tempDir, "app-panic-reqid.log")

	// Initialize the logging system.
	err := logger.InitLoggers("info", false,
		logger.LogRotationConfig{
			Filename:   appLogFile,
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		},
		logger.LogRotationConfig{
			Filename:   filepath.Join(tempDir, "access-panic-reqid.log"),
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		},
	)
	if err != nil {
		t.Fatalf("Failed to initialize loggers: %v", err)
	}
	defer func() { _ = logger.Sync() }()

	appLogger := logger.GetAppLogger()

	// Build the app. Order matters: Recover first, then RequestID.
	app := fiber.New()
	app.Use(middleware.Recover(appLogger))
	app.Use(requestid.New(requestid.Config{
		Generator: func() string {
			return uuid.NewString()
		},
	}))

	app.Get("/panic", func(c *fiber.Ctx) error {
		panic("test panic with request id")
	})

	// Issue the panicking request.
	req := httptest.NewRequest("GET", "/panic", nil)
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("Failed to execute request: %v", err)
	}
	resp.Body.Close()

	// The response must still carry the request ID after the panic.
	requestID := resp.Header.Get("X-Request-ID")
	if requestID == "" {
		// Fatal rather than Error: strings.Contains(content, "") is always
		// true, so the log check below would pass vacuously.
		t.Fatal("X-Request-ID header should be set even after panic")
	}

	// Flush log buffers. The Sync error is deliberately ignored: the
	// assertions are made on the file contents.
	_ = logger.Sync()
	time.Sleep(100 * time.Millisecond)

	// Read back the app log.
	logContent, err := os.ReadFile(appLogFile)
	if err != nil {
		t.Fatalf("Failed to read app log: %v", err)
	}
	content := string(logContent)

	// The panic log must contain the request ID returned to the client.
	if !strings.Contains(content, requestID) {
		t.Errorf("Panic log should contain request ID '%s'", requestID)
	}

	// The panic log must also contain the key context fields.
	requiredFields := []string{
		"request_id",
		"method",
		"path",
		"panic",
		"stack",
	}
	for _, field := range requiredFields {
		if !strings.Contains(content, field) {
			t.Errorf("Panic log should contain field '%s'", field)
		}
	}

	t.Logf("Panic log successfully includes Request ID: %s", requestID)
}
// TestConcurrentPanics verifies that many simultaneous panicking requests are
// each recovered and answered with 500 (T054).
func TestConcurrentPanics(t *testing.T) {
	// Per-test temporary directory for the log files.
	tempDir := t.TempDir()

	// Initialize the logging system.
	err := logger.InitLoggers("info", false,
		logger.LogRotationConfig{
			Filename:   filepath.Join(tempDir, "app.log"),
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		},
		logger.LogRotationConfig{
			Filename:   filepath.Join(tempDir, "access.log"),
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		},
	)
	if err != nil {
		t.Fatalf("Failed to initialize loggers: %v", err)
	}
	defer func() { _ = logger.Sync() }()

	appLogger := logger.GetAppLogger()

	// Build the app with Recover registered first.
	app := fiber.New()
	app.Use(middleware.Recover(appLogger))
	app.Use(requestid.New(requestid.Config{
		Generator: func() string {
			return uuid.NewString()
		},
	}))

	app.Get("/panic", func(c *fiber.Ctx) error {
		panic("concurrent panic test")
	})

	// Fire many panicking requests concurrently. Each goroutine reports its
	// outcome as a single value, so an error and its status can never be
	// paired with another goroutine's result.
	const numRequests = 20
	type result struct {
		status int
		err    error
	}
	// Buffered to numRequests so that no sender ever blocks.
	results := make(chan result, numRequests)
	for i := 0; i < numRequests; i++ {
		go func() {
			req := httptest.NewRequest("GET", "/panic", nil)
			resp, err := app.Test(req)
			if err != nil {
				results <- result{err: err}
				return
			}
			resp.Body.Close()
			results <- result{status: resp.StatusCode}
		}()
	}

	// Collect every result (even after a failure) so that all goroutines
	// have reported before the test returns.
	for i := 0; i < numRequests; i++ {
		r := <-results
		if r.err != nil {
			t.Errorf("Request failed: %v", r.err)
			continue
		}
		if r.status != 500 {
			t.Errorf("Expected status 500, got %d", r.status)
		}
	}

	t.Logf("Successfully handled %d concurrent panics", numRequests)
}
// TestRecoverMiddlewareOrder verifies that, with Recover registered first
// (Recover → RequestID → Logger), a panic is turned into the unified 500
// error response and the request ID header is still present (T052).
func TestRecoverMiddlewareOrder(t *testing.T) {
	// Per-test temporary directory for the log files.
	tempDir := t.TempDir()

	// Initialize the logging system.
	err := logger.InitLoggers("info", false,
		logger.LogRotationConfig{
			Filename:   filepath.Join(tempDir, "app.log"),
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		},
		logger.LogRotationConfig{
			Filename:   filepath.Join(tempDir, "access.log"),
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		},
	)
	if err != nil {
		t.Fatalf("Failed to initialize loggers: %v", err)
	}
	defer func() { _ = logger.Sync() }()

	appLogger := logger.GetAppLogger()

	// Correct order: Recover → RequestID → Logger.
	app := fiber.New()
	app.Use(middleware.Recover(appLogger))
	app.Use(requestid.New(requestid.Config{
		Generator: func() string {
			return uuid.NewString()
		},
	}))
	app.Use(logger.Middleware())

	app.Get("/panic", func(c *fiber.Ctx) error {
		panic("test panic")
	})

	// Issue the panicking request.
	req := httptest.NewRequest("GET", "/panic", nil)
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("Failed to execute request: %v", err)
	}
	defer resp.Body.Close()

	// The request must be answered with 500 instead of crashing.
	if resp.StatusCode != 500 {
		t.Errorf("Expected status 500, got %d", resp.StatusCode)
	}

	// The request ID must still be present. It is set by the RequestID
	// middleware, which runs inside Recover, before the handler panics.
	requestID := resp.Header.Get("X-Request-ID")
	if requestID == "" {
		t.Error("X-Request-ID should be set even after panic")
	}

	// The body must use the unified error format.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("Failed to read response body: %v", err)
	}
	var payload response.Response
	if err := sonic.Unmarshal(body, &payload); err != nil {
		t.Fatalf("Failed to unmarshal response: %v", err)
	}
	if payload.Code != errors.CodeInternalError {
		t.Errorf("Expected code %d, got %d", errors.CodeInternalError, payload.Code)
	}

	t.Logf("Recover middleware correctly placed first, handled panic gracefully")
}