package integration

import (
	"io"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"github.com/break/junhong_cmp_fiber/pkg/constants"
	"github.com/break/junhong_cmp_fiber/pkg/logger"
	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/requestid"
	"github.com/google/uuid"
)

// newUUIDRequestIDMiddleware returns the requestid middleware configured to
// generate every request ID with uuid.NewString (a random, version 4 UUID).
func newUUIDRequestIDMiddleware() fiber.Handler {
	return requestid.New(requestid.Config{
		Generator: uuid.NewString,
	})
}

// initMiddlewareTestLoggers initializes the global loggers at "info" level,
// writing the application log to dir/app.log and the access log to
// accessLogFile. It fails the test immediately if initialization fails and
// registers a cleanup that flushes the loggers when the test finishes.
func initMiddlewareTestLoggers(t *testing.T, dir, accessLogFile string) {
	t.Helper()

	rotation := func(filename string) logger.LogRotationConfig {
		return logger.LogRotationConfig{
			Filename:   filename,
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		}
	}

	if err := logger.InitLoggers("info", false,
		rotation(filepath.Join(dir, "app.log")),
		rotation(accessLogFile),
	); err != nil {
		t.Fatalf("Failed to initialize loggers: %v", err)
	}

	// Registered after t.TempDir, so it runs before the directory is removed.
	t.Cleanup(func() { logger.Sync() })
}

// readMiddlewareAccessLog flushes the loggers and returns the contents of the
// access log at path, failing the test if the file cannot be read.
func readMiddlewareAccessLog(t *testing.T, path string) string {
	t.Helper()

	logger.Sync()
	// NOTE(review): the short pause is a safety margin in case the logger
	// writes asynchronously; confirm whether logger.Sync alone is sufficient.
	time.Sleep(100 * time.Millisecond)

	content, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("Failed to read access log: %v", err)
	}
	return string(content)
}

// TestRequestIDMiddleware verifies that the RequestID middleware assigns every
// request a unique version 4 UUID and returns it in the X-Request-ID response
// header (T043).
func TestRequestIDMiddleware(t *testing.T) {
	app := fiber.New()
	app.Use(newUUIDRequestIDMiddleware())
	app.Get("/test", func(c *fiber.Ctx) error {
		return c.JSON(fiber.Map{
			"request_id": c.Locals(constants.ContextKeyRequestID),
		})
	})

	tests := []struct {
		name string
	}{
		{name: "request 1"},
		{name: "request 2"},
		{name: "request 3"},
	}

	// Shared across subtests; safe because the subtests run sequentially.
	seenIDs := make(map[string]bool, len(tests))
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(fiber.MethodGet, "/test", nil)
			resp, err := app.Test(req)
			if err != nil {
				t.Fatalf("Failed to execute request: %v", err)
			}
			defer resp.Body.Close()

			requestID := resp.Header.Get(fiber.HeaderXRequestID)
			if requestID == "" {
				t.Fatal("X-Request-ID header should not be empty")
			}

			// uuid.Parse accepts every UUID version, so the v4 requirement
			// has to be checked explicitly.
			parsed, err := uuid.Parse(requestID)
			if err != nil {
				t.Fatalf("X-Request-ID is not a valid UUID: %s, error: %v", requestID, err)
			}
			if parsed.Version() != 4 {
				t.Errorf("X-Request-ID should be a version 4 UUID, got %v: %s", parsed.Version(), requestID)
			}

			if seenIDs[requestID] {
				t.Errorf("Request ID %s is not unique", requestID)
			}
			seenIDs[requestID] = true

			t.Logf("Request ID: %s", requestID)
		})
	}

	// Every request must have produced a distinct ID.
	if len(seenIDs) != len(tests) {
		t.Errorf("Expected %d unique request IDs, got %d", len(tests), len(seenIDs))
	}
}

// TestLoggerMiddleware verifies that the Logger middleware writes one access
// log entry per request, containing the expected fields (T044).
func TestLoggerMiddleware(t *testing.T) {
	tempDir := t.TempDir()
	accessLogFile := filepath.Join(tempDir, "access.log")
	initMiddlewareTestLoggers(t, tempDir, accessLogFile)

	app := fiber.New()
	app.Use(newUUIDRequestIDMiddleware())
	app.Use(logger.Middleware())
	app.Get("/test", func(c *fiber.Ctx) error {
		return c.SendString("ok")
	})
	app.Post("/test", func(c *fiber.Ctx) error {
		return c.SendStatus(fiber.StatusCreated)
	})

	tests := []struct {
		name           string
		method         string
		path           string
		expectedStatus int
	}{
		{
			name:           "GET request",
			method:         fiber.MethodGet,
			path:           "/test",
			expectedStatus: fiber.StatusOK,
		},
		{
			name:           "POST request",
			method:         fiber.MethodPost,
			path:           "/test",
			expectedStatus: fiber.StatusCreated,
		},
		{
			name:           "GET with query params",
			method:         fiber.MethodGet,
			path:           "/test?foo=bar",
			expectedStatus: fiber.StatusOK,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(tt.method, tt.path, nil)
			req.Header.Set("User-Agent", "test-agent/1.0")

			resp, err := app.Test(req)
			if err != nil {
				t.Fatalf("Failed to execute request: %v", err)
			}
			resp.Body.Close()

			if resp.StatusCode != tt.expectedStatus {
				t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.StatusCode)
			}
		})
	}

	logContent := readMiddlewareAccessLog(t, accessLogFile)
	t.Logf("Access log content:\n%s", logContent)

	// Without any entries the field and line checks below are meaningless.
	if strings.TrimSpace(logContent) == "" {
		t.Fatal("Access log should not be empty")
	}

	requiredFields := []string{
		"method",
		"path",
		"status",
		"duration_ms",
		"request_id",
		"ip",
		"user_agent",
	}
	for _, field := range requiredFields {
		if !strings.Contains(logContent, field) {
			t.Errorf("Access log should contain field '%s'", field)
		}
	}

	// Every request must have been logged.
	lines := strings.Split(strings.TrimSpace(logContent), "\n")
	if len(lines) < len(tests) {
		t.Errorf("Expected at least %d log entries, got %d", len(tests), len(lines))
	}
}

// TestRequestIDPropagation verifies that the request ID generated by the
// first middleware is visible to later middleware and the handler, and that
// it matches the X-Request-ID response header (T045).
func TestRequestIDPropagation(t *testing.T) {
	tempDir := t.TempDir()
	initMiddlewareTestLoggers(t, tempDir, filepath.Join(tempDir, "access.log"))

	app := fiber.New()

	// Set by the custom middleware below; compared with the response header
	// once app.Test has returned.
	var capturedRequestID string

	// 1. RequestID middleware (first).
	app.Use(newUUIDRequestIDMiddleware())

	// 2. Logger middleware (second).
	app.Use(logger.Middleware())

	// 3. Custom middleware checking that the request ID is reachable downstream.
	app.Use(func(c *fiber.Ctx) error {
		requestID := c.Locals(constants.ContextKeyRequestID)
		if requestID == nil {
			t.Error("Request ID should be available in middleware chain")
		}
		if rid, ok := requestID.(string); ok {
			capturedRequestID = rid
		}
		return c.Next()
	})

	app.Get("/test", func(c *fiber.Ctx) error {
		// The handler must see the request ID as well.
		requestID := c.Locals(constants.ContextKeyRequestID)
		if requestID == nil {
			return c.Status(fiber.StatusInternalServerError).SendString("Request ID not found in handler")
		}
		return c.JSON(fiber.Map{
			"request_id": requestID,
			"message":    "Request ID propagated successfully",
		})
	})

	req := httptest.NewRequest(fiber.MethodGet, "/test", nil)
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("Failed to execute request: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != fiber.StatusOK {
		t.Errorf("Expected status 200, got %d", resp.StatusCode)
	}

	headerRequestID := resp.Header.Get(fiber.HeaderXRequestID)
	if headerRequestID == "" {
		t.Error("X-Request-ID header should be set")
	}

	// The ID seen inside the middleware chain must match the response header.
	if capturedRequestID != headerRequestID {
		t.Errorf("Request ID mismatch: middleware=%s, header=%s", capturedRequestID, headerRequestID)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("Failed to read response body: %v", err)
	}
	if !strings.Contains(string(body), headerRequestID) {
		t.Errorf("Response body should contain request ID %s", headerRequestID)
	}

	t.Logf("Request ID successfully propagated: %s", headerRequestID)
}

// TestMiddlewareOrder verifies that middleware registered with app.Use runs
// in registration order on the way in and in reverse order on the way out,
// and that the request ID is set before the logger stage runs (T045).
func TestMiddlewareOrder(t *testing.T) {
	app := fiber.New()
	var executionOrder []string

	// Middleware 1: RequestID.
	app.Use(func(c *fiber.Ctx) error {
		executionOrder = append(executionOrder, "requestid-start")
		c.Locals(constants.ContextKeyRequestID, uuid.NewString())
		err := c.Next()
		executionOrder = append(executionOrder, "requestid-end")
		return err
	})

	// Middleware 2: Logger.
	app.Use(func(c *fiber.Ctx) error {
		executionOrder = append(executionOrder, "logger-start")
		// The request ID must already be set by the previous middleware.
		if c.Locals(constants.ContextKeyRequestID) == nil {
			t.Error("Request ID should be set before logger middleware")
		}
		err := c.Next()
		executionOrder = append(executionOrder, "logger-end")
		return err
	})

	// Middleware 3: Custom.
	app.Use(func(c *fiber.Ctx) error {
		executionOrder = append(executionOrder, "custom-start")
		err := c.Next()
		executionOrder = append(executionOrder, "custom-end")
		return err
	})

	app.Get("/test", func(c *fiber.Ctx) error {
		executionOrder = append(executionOrder, "handler")
		return c.SendString("ok")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/test", nil)
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("Failed to execute request: %v", err)
	}
	resp.Body.Close()

	expectedOrder := []string{
		"requestid-start",
		"logger-start",
		"custom-start",
		"handler",
		"custom-end",
		"logger-end",
		"requestid-end",
	}

	if len(executionOrder) != len(expectedOrder) {
		t.Errorf("Expected %d execution steps, got %d", len(expectedOrder), len(executionOrder))
	}
	for i, expected := range expectedOrder {
		if i >= len(executionOrder) {
			t.Errorf("Missing execution step at index %d: expected '%s'", i, expected)
			continue
		}
		if executionOrder[i] != expected {
			t.Errorf("Execution order mismatch at index %d: expected '%s', got '%s'", i, expected, executionOrder[i])
		}
	}

	t.Logf("Middleware execution order: %v", executionOrder)
}

// TestLoggerMiddlewareWithUserID verifies that the Logger middleware records
// the user ID stored in the request locals by an earlier (auth) middleware
// (T044).
func TestLoggerMiddlewareWithUserID(t *testing.T) {
	tempDir := t.TempDir()
	accessLogFile := filepath.Join(tempDir, "access-userid.log")
	initMiddlewareTestLoggers(t, tempDir, accessLogFile)

	app := fiber.New()
	app.Use(newUUIDRequestIDMiddleware())

	// Simulates an auth middleware that stores the authenticated user ID.
	app.Use(func(c *fiber.Ctx) error {
		c.Locals(constants.ContextKeyUserID, "user_12345")
		return c.Next()
	})

	app.Use(logger.Middleware())
	app.Get("/test", func(c *fiber.Ctx) error {
		return c.SendString("ok")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/test", nil)
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("Failed to execute request: %v", err)
	}
	resp.Body.Close()

	logContent := readMiddlewareAccessLog(t, accessLogFile)
	if !strings.Contains(logContent, "user_12345") {
		t.Error("Access log should contain user_id 'user_12345'")
	}

	t.Logf("Access log with user_id:\n%s", logContent)
}

// TestConcurrentRequests verifies that request IDs stay valid and unique when
// many requests are served concurrently (T043).
func TestConcurrentRequests(t *testing.T) {
	app := fiber.New()
	app.Use(newUUIDRequestIDMiddleware())
	app.Get("/test", func(c *fiber.Ctx) error {
		// Simulate some processing time so that requests overlap.
		time.Sleep(10 * time.Millisecond)
		return c.JSON(fiber.Map{
			"request_id": c.Locals(constants.ContextKeyRequestID),
		})
	})

	const numRequests = 50

	// result pairs each request's ID with its error, so the collector never
	// has to match values arriving on two independent channels.
	type result struct {
		requestID string
		err       error
	}
	// Buffered to numRequests so no sender blocks, even if the test stops early.
	results := make(chan result, numRequests)

	for i := 0; i < numRequests; i++ {
		go func() {
			req := httptest.NewRequest(fiber.MethodGet, "/test", nil)
			resp, err := app.Test(req)
			if err != nil {
				results <- result{err: err}
				return
			}
			defer resp.Body.Close()
			results <- result{requestID: resp.Header.Get(fiber.HeaderXRequestID)}
		}()
	}

	seenIDs := make(map[string]bool, numRequests)
	for i := 0; i < numRequests; i++ {
		r := <-results
		if r.err != nil {
			t.Fatalf("Request failed: %v", r.err)
		}
		if r.requestID == "" {
			t.Error("Request ID should not be empty")
			continue
		}
		if _, err := uuid.Parse(r.requestID); err != nil {
			t.Errorf("X-Request-ID is not a valid UUID: %s, error: %v", r.requestID, err)
		}
		if seenIDs[r.requestID] {
			t.Errorf("Duplicate request ID found: %s", r.requestID)
		}
		seenIDs[r.requestID] = true
	}

	if len(seenIDs) != numRequests {
		t.Errorf("Expected %d unique request IDs, got %d", numRequests, len(seenIDs))
	}

	t.Logf("Successfully generated %d unique request IDs concurrently", len(seenIDs))
}