package integration

import (
	"io"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"github.com/break/junhong_cmp_fiber/internal/middleware"
	"github.com/break/junhong_cmp_fiber/pkg/errors"
	"github.com/break/junhong_cmp_fiber/pkg/logger"
	"github.com/break/junhong_cmp_fiber/pkg/response"
	"github.com/bytedance/sonic"
	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/requestid"
	"github.com/google/uuid"
)

// recoverTestLogFlushDelay is how long the tests wait after logger.Sync()
// before reading the app log file back.
// NOTE(review): kept from the original tests; presumably only needed if the
// logger writes asynchronously — confirm and drop if Sync() alone suffices.
const recoverTestLogFlushDelay = 100 * time.Millisecond

// initRecoverTestLoggers initializes the global loggers so that the app log
// and the access log are written to appLogName and accessLogName inside a
// per-test temporary directory. It returns the full path of the app log and
// registers a cleanup that flushes the loggers when the test finishes.
func initRecoverTestLoggers(t *testing.T, appLogName, accessLogName string) string {
	t.Helper()

	tempDir := t.TempDir()
	appLogFile := filepath.Join(tempDir, appLogName)

	err := logger.InitLoggers("info", false,
		logger.LogRotationConfig{
			Filename:   appLogFile,
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		},
		logger.LogRotationConfig{
			Filename:   filepath.Join(tempDir, accessLogName),
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		},
	)
	if err != nil {
		t.Fatalf("Failed to initialize loggers: %v", err)
	}
	// Registered after t.TempDir(), so it runs before the temp dir is removed
	// (cleanups run in LIFO order). A flush failure at teardown is not fatal.
	t.Cleanup(func() { _ = logger.Sync() })

	return appLogFile
}

// newRecoverTestApp builds a Fiber app with the panic-handling middleware
// chain under test: Recover first, then RequestID with a UUID generator.
// Routes and any further middleware are added by the caller.
func newRecoverTestApp() *fiber.App {
	app := fiber.New()
	// Recover must be registered first so it wraps every later middleware
	// and handler.
	app.Use(middleware.Recover(logger.GetAppLogger()))
	app.Use(requestid.New(requestid.Config{
		Generator: uuid.NewString,
	}))
	return app
}

// recoverTestLogOffset returns the current size of the log file at path,
// or 0 if it has not been created yet.
func recoverTestLogOffset(t *testing.T, path string) int64 {
	t.Helper()

	info, err := os.Stat(path)
	switch {
	case err == nil:
		return info.Size()
	case os.IsNotExist(err):
		return 0
	default:
		t.Fatalf("Failed to stat app log: %v", err)
		return 0
	}
}

// readRecoverTestLog flushes the loggers, waits briefly, and returns the
// content of the log file at path starting at byte offset.
func readRecoverTestLog(t *testing.T, path string, offset int64) string {
	t.Helper()

	// A Sync failure is not fatal here; the file read below is what the
	// tests assert on.
	_ = logger.Sync()
	time.Sleep(recoverTestLogFlushDelay)

	data, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("Failed to read app log: %v", err)
	}
	// If the file is now shorter than the recorded offset (e.g. it was
	// rotated), fall back to the whole content.
	if offset < 0 || offset > int64(len(data)) {
		offset = 0
	}
	return string(data[offset:])
}

// TestPanicRecovery verifies that a panicking handler yields HTTP 500 with
// the unified error body, and that a normal route still works afterwards (T052).
func TestPanicRecovery(t *testing.T) {
	initRecoverTestLoggers(t, "app-panic.log", "access-panic.log")

	app := newRecoverTestApp()

	// Handler that always panics.
	app.Get("/panic", func(c *fiber.Ctx) error {
		panic("intentional panic for testing")
	})
	// Handler that behaves normally.
	app.Get("/ok", func(c *fiber.Ctx) error {
		return c.SendString("ok")
	})

	tests := []struct {
		name           string
		path           string
		shouldPanic    bool
		expectedStatus int
		expectedCode   int
		expectedBody   string // only checked when shouldPanic is false
	}{
		{
			name:           "panic endpoint returns 500",
			path:           "/panic",
			shouldPanic:    true,
			expectedStatus: 500,
			expectedCode:   errors.CodeInternalError,
		},
		{
			name:           "normal endpoint works after panic",
			path:           "/ok",
			shouldPanic:    false,
			expectedStatus: 200,
			expectedCode:   0,
			expectedBody:   "ok",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(fiber.MethodGet, tt.path, nil)
			resp, err := app.Test(req)
			if err != nil {
				t.Fatalf("Failed to execute request: %v", err)
			}
			defer resp.Body.Close()

			if resp.StatusCode != tt.expectedStatus {
				t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.StatusCode)
			}

			body, err := io.ReadAll(resp.Body)
			if err != nil {
				t.Fatalf("Failed to read response body: %v", err)
			}

			if !tt.shouldPanic {
				if got := string(body); got != tt.expectedBody {
					t.Errorf("Expected body %q, got %q", tt.expectedBody, got)
				}
				return
			}

			// A recovered panic must produce the unified error response.
			var apiResp response.Response
			if err := sonic.Unmarshal(body, &apiResp); err != nil {
				t.Fatalf("Failed to unmarshal response: %v", err)
			}
			if apiResp.Code != tt.expectedCode {
				t.Errorf("Expected code %d, got %d", tt.expectedCode, apiResp.Code)
			}
			if apiResp.Data != nil {
				t.Error("Error response data should be nil")
			}
		})
	}
}

// TestPanicLogging verifies that recovered panics of several value types
// are logged together with a stack trace (T053).
func TestPanicLogging(t *testing.T) {
	appLogFile := initRecoverTestLoggers(t, "app-panic-log.log", "access-panic-log.log")

	app := newRecoverTestApp()

	// Panics with different value types.
	app.Get("/panic-string", func(c *fiber.Ctx) error {
		panic("string panic message")
	})
	app.Get("/panic-error", func(c *fiber.Ctx) error {
		panic(fiber.NewError(500, "error panic message"))
	})
	app.Get("/panic-struct", func(c *fiber.Ctx) error {
		panic(struct{ Message string }{"struct panic message"})
	})

	tests := []struct {
		name            string
		path            string
		expectedInLog   []string
		unexpectedInLog []string
	}{
		{
			name: "string panic logs correctly",
			path: "/panic-string",
			expectedInLog: []string{
				"Panic 已恢复",
				"string panic message",
				"stack",
				"request_id",
				"method",
				"path",
			},
		},
		{
			name: "error panic logs correctly",
			path: "/panic-error",
			expectedInLog: []string{
				"Panic 已恢复",
				"error panic message",
				"stack",
			},
		},
		{
			name: "struct panic logs correctly",
			path: "/panic-struct",
			expectedInLog: []string{
				"Panic 已恢复",
				"stack",
				"Message",
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Only inspect log output written for this subtest's request, so
			// entries left by earlier subtests cannot satisfy its assertions.
			offset := recoverTestLogOffset(t, appLogFile)

			req := httptest.NewRequest(fiber.MethodGet, tt.path, nil)
			resp, err := app.Test(req)
			if err != nil {
				t.Fatalf("Failed to execute request: %v", err)
			}
			resp.Body.Close()

			content := readRecoverTestLog(t, appLogFile, offset)

			for _, expected := range tt.expectedInLog {
				if !strings.Contains(content, expected) {
					t.Errorf("Log should contain '%s'", expected)
				}
			}
			for _, unexpected := range tt.unexpectedInLog {
				if strings.Contains(content, unexpected) {
					t.Errorf("Log should NOT contain '%s'", unexpected)
				}
			}

			// The stack trace should point back into this test file.
			// NOTE(review): assumes this file is named recover_test.go.
			if !strings.Contains(content, "recover_test.go") {
				t.Error("Stack trace should contain source file name")
			}

			t.Logf("Panic log contains stack trace: %v", strings.Contains(content, "stack"))
		})
	}
}

// TestSubsequentRequestsAfterPanic verifies that requests following a
// recovered panic are still served normally (T054).
func TestSubsequentRequestsAfterPanic(t *testing.T) {
	initRecoverTestLoggers(t, "app.log", "access.log")

	app := newRecoverTestApp()

	callCount := 0
	app.Get("/test", func(c *fiber.Ctx) error {
		callCount++

		// Odd-numbered calls (1st, 3rd, 5th) panic.
		if callCount%2 == 1 {
			panic("test panic")
		}

		// Even-numbered calls (2nd, 4th, 6th) return normally.
		return c.JSON(fiber.Map{
			"call_count": callCount,
			"status":     "ok",
		})
	})

	// Alternate panicking and normal requests; a panic must not affect later ones.
	for i := 1; i <= 6; i++ {
		req := httptest.NewRequest(fiber.MethodGet, "/test", nil)
		resp, err := app.Test(req)
		if err != nil {
			t.Fatalf("Request %d failed: %v", i, err)
		}

		body, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			t.Fatalf("Request %d: failed to read response body: %v", i, err)
		}

		if i%2 == 1 {
			// Odd requests should return 500.
			if resp.StatusCode != 500 {
				t.Errorf("Request %d: expected status 500, got %d", i, resp.StatusCode)
			}
		} else {
			// Even requests should return 200 with a JSON body.
			if resp.StatusCode != 200 {
				t.Errorf("Request %d: expected status 200, got %d", i, resp.StatusCode)
			}

			var payload map[string]any
			if err := sonic.Unmarshal(body, &payload); err != nil {
				t.Fatalf("Request %d: failed to unmarshal response: %v", i, err)
			}
			if status, ok := payload["status"].(string); !ok || status != "ok" {
				t.Errorf("Request %d: expected status 'ok', got %v", i, payload["status"])
			}
		}

		t.Logf("Request %d completed: status=%d", i, resp.StatusCode)
	}

	// Every one of the 6 requests must have reached the handler.
	if callCount != 6 {
		t.Errorf("Expected 6 calls, got %d", callCount)
	}
}

// TestPanicWithRequestID verifies that the panic log entry carries the
// request ID returned in the X-Request-ID response header (T053).
func TestPanicWithRequestID(t *testing.T) {
	appLogFile := initRecoverTestLoggers(t, "app-panic-reqid.log", "access-panic-reqid.log")

	// Middleware order matters: Recover, then RequestID.
	app := newRecoverTestApp()

	app.Get("/panic", func(c *fiber.Ctx) error {
		panic("test panic with request id")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/panic", nil)
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("Failed to execute request: %v", err)
	}
	resp.Body.Close()

	requestID := resp.Header.Get("X-Request-ID")
	if requestID == "" {
		// Fatal rather than Error: an empty ID would make the
		// strings.Contains check below pass vacuously.
		t.Fatal("X-Request-ID header should be set even after panic")
	}

	content := readRecoverTestLog(t, appLogFile, 0)

	if !strings.Contains(content, requestID) {
		t.Errorf("Panic log should contain request ID '%s'", requestID)
	}

	requiredFields := []string{
		"request_id",
		"method",
		"path",
		"panic",
		"stack",
	}
	for _, field := range requiredFields {
		if !strings.Contains(content, field) {
			t.Errorf("Panic log should contain field '%s'", field)
		}
	}

	t.Logf("Panic log successfully includes Request ID: %s", requestID)
}

// TestConcurrentPanics verifies that many concurrently panicking requests
// are each recovered into an HTTP 500 response (T054).
func TestConcurrentPanics(t *testing.T) {
	initRecoverTestLoggers(t, "app.log", "access.log")

	app := newRecoverTestApp()

	app.Get("/panic", func(c *fiber.Ctx) error {
		panic("concurrent panic test")
	})

	const numRequests = 20

	// Each goroutine reports its status and error together, so results are
	// never mis-paired across requests.
	type result struct {
		status int
		err    error
	}
	// Buffered to numRequests so no sender ever blocks.
	results := make(chan result, numRequests)

	for i := 0; i < numRequests; i++ {
		go func() {
			req := httptest.NewRequest(fiber.MethodGet, "/panic", nil)
			resp, err := app.Test(req)
			if err != nil {
				results <- result{err: err}
				return
			}
			resp.Body.Close()
			results <- result{status: resp.StatusCode}
		}()
	}

	// Drain every result (rather than failing fast) so no request goroutine
	// outlives the test and its temporary log directory.
	for i := 0; i < numRequests; i++ {
		r := <-results
		if r.err != nil {
			t.Errorf("Request failed: %v", r.err)
			continue
		}
		if r.status != 500 {
			t.Errorf("Expected status 500, got %d", r.status)
		}
	}

	t.Logf("Successfully handled %d concurrent panics", numRequests)
}

// TestRecoverMiddlewareOrder verifies that with Recover registered first,
// a panic is turned into the unified 500 response and the request still
// carries a Request ID (T052).
func TestRecoverMiddlewareOrder(t *testing.T) {
	initRecoverTestLoggers(t, "app.log", "access.log")

	// Correct order: Recover → RequestID → Logger.
	app := newRecoverTestApp()
	app.Use(logger.Middleware())

	app.Get("/panic", func(c *fiber.Ctx) error {
		panic("test panic")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/panic", nil)
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("Failed to execute request: %v", err)
	}
	defer resp.Body.Close()

	// The panic is handled (500), not propagated.
	if resp.StatusCode != 500 {
		t.Errorf("Expected status 500, got %d", resp.StatusCode)
	}

	// A Request ID is still present, showing RequestID ran inside Recover.
	requestID := resp.Header.Get("X-Request-ID")
	if requestID == "" {
		t.Error("X-Request-ID should be set even after panic")
	}

	// The body must use the unified error format.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("Failed to read response body: %v", err)
	}
	var apiResp response.Response
	if err := sonic.Unmarshal(body, &apiResp); err != nil {
		t.Fatalf("Failed to unmarshal response: %v", err)
	}

	if apiResp.Code != errors.CodeInternalError {
		t.Errorf("Expected code %d, got %d", errors.CodeInternalError, apiResp.Code)
	}

	t.Logf("Recover middleware correctly placed first, handled panic gracefully")
}