Files
junhong_cmp_fiber/tests/integration/middleware_test.go
2025-11-11 15:53:01 +08:00

534 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package integration
import (
"io"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/logger"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/requestid"
"github.com/google/uuid"
)
// TestRequestIDMiddleware verifies that the requestid middleware, configured
// with uuid.NewString as its generator, stamps every response with a unique
// X-Request-ID header holding a version 4 UUID (T043).
func TestRequestIDMiddleware(t *testing.T) {
	app := fiber.New()

	// Configure the requestid middleware to generate UUID v4 values.
	app.Use(requestid.New(requestid.Config{
		Generator: func() string {
			return uuid.NewString()
		},
	}))

	app.Get("/test", func(c *fiber.Ctx) error {
		requestID := c.Locals(constants.ContextKeyRequestID)
		return c.JSON(fiber.Map{
			"request_id": requestID,
		})
	})

	tests := []struct {
		name string
	}{
		{name: "request 1"},
		{name: "request 2"},
		{name: "request 3"},
	}

	// IDs observed so far; shared across subtests, which run sequentially.
	seenIDs := make(map[string]bool, len(tests))

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest("GET", "/test", nil)
			resp, err := app.Test(req)
			if err != nil {
				t.Fatalf("Failed to execute request: %v", err)
			}
			defer resp.Body.Close()

			// The response must carry an X-Request-ID header. Stop here when it
			// is missing: the remaining checks would only add noise and would
			// record the empty string as a "seen" ID.
			requestID := resp.Header.Get("X-Request-ID")
			if requestID == "" {
				t.Fatal("X-Request-ID header should not be empty")
			}

			// The header must parse as a UUID and specifically be version 4;
			// uuid.Parse alone accepts every UUID version.
			parsed, err := uuid.Parse(requestID)
			if err != nil {
				t.Errorf("X-Request-ID is not a valid UUID: %s, error: %v", requestID, err)
			} else if parsed.Version() != 4 {
				t.Errorf("X-Request-ID %s has UUID version %v, expected version 4", requestID, parsed.Version())
			}

			// Each request must receive an ID not handed out before.
			if seenIDs[requestID] {
				t.Errorf("Request ID %s is not unique", requestID)
			}
			seenIDs[requestID] = true

			t.Logf("Request ID: %s", requestID)
		})
	}

	// Every request should have contributed one distinct ID.
	if len(seenIDs) != len(tests) {
		t.Errorf("Expected %d unique request IDs, got %d", len(tests), len(seenIDs))
	}
}
// TestLoggerMiddleware checks that requests passing through logger.Middleware
// leave entries in the access log file (T044).
//
// It initializes the loggers against files in a temporary directory, sends a
// few requests through an app using the requestid and logger middlewares, and
// then asserts that the access log is non-empty, mentions every expected field
// name, and holds at least one line per request.
func TestLoggerMiddleware(t *testing.T) {
	// Log files live in a per-test temporary directory.
	logDir := t.TempDir()
	accessLogFile := filepath.Join(logDir, "access.log")

	// Both log files use the same rotation settings; only the path differs.
	rotationFor := func(path string) logger.LogRotationConfig {
		return logger.LogRotationConfig{
			Filename:   path,
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		}
	}

	// Initialize the logging system.
	if err := logger.InitLoggers("info", false,
		rotationFor(filepath.Join(logDir, "app.log")),
		rotationFor(accessLogFile),
	); err != nil {
		t.Fatalf("Failed to initialize loggers: %v", err)
	}
	defer logger.Sync()

	// Build the app: request ID generation first, then the logger middleware.
	app := fiber.New()
	app.Use(requestid.New(requestid.Config{
		Generator: func() string {
			return uuid.NewString()
		},
	}))
	app.Use(logger.Middleware())

	app.Get("/test", func(c *fiber.Ctx) error {
		return c.SendString("ok")
	})
	app.Post("/test", func(c *fiber.Ctx) error {
		return c.SendStatus(201)
	})

	cases := []struct {
		name           string
		method         string
		path           string
		expectedStatus int
	}{
		{name: "GET request", method: "GET", path: "/test", expectedStatus: 200},
		{name: "POST request", method: "POST", path: "/test", expectedStatus: 201},
		{name: "GET with query params", method: "GET", path: "/test?foo=bar", expectedStatus: 200},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			request := httptest.NewRequest(tc.method, tc.path, nil)
			request.Header.Set("User-Agent", "test-agent/1.0")

			response, err := app.Test(request)
			if err != nil {
				t.Fatalf("Failed to execute request: %v", err)
			}
			response.Body.Close()

			if got := response.StatusCode; got != tc.expectedStatus {
				t.Errorf("Expected status %d, got %d", tc.expectedStatus, got)
			}
		})
	}

	// Flush the log buffers, then wait briefly before reading the file.
	logger.Sync()
	time.Sleep(100 * time.Millisecond)

	// The access log file must exist and contain data.
	raw, err := os.ReadFile(accessLogFile)
	if err != nil {
		t.Fatalf("Failed to read access log: %v", err)
	}
	if len(raw) == 0 {
		t.Error("Access log should not be empty")
	}

	logContent := string(raw)
	t.Logf("Access log content:\n%s", logContent)

	// Every required field name must appear somewhere in the log output.
	for _, field := range []string{
		"method",
		"path",
		"status",
		"duration_ms",
		"request_id",
		"ip",
		"user_agent",
	} {
		if strings.Contains(logContent, field) {
			continue
		}
		t.Errorf("Access log should contain field '%s'", field)
	}

	// There must be at least one log line per request issued above.
	trimmed := strings.TrimSpace(logContent)
	entries := strings.Split(trimmed, "\n")
	if want, got := len(cases), len(entries); got < want {
		t.Errorf("Expected at least %d log entries, got %d", want, got)
	}
}
// TestRequestIDPropagation verifies that the request ID is reachable from a
// middleware registered after the requestid and logger middlewares and from
// the route handler, and that it matches the X-Request-ID response header and
// the JSON response body (T045).
func TestRequestIDPropagation(t *testing.T) {
	// Log files live in a per-test temporary directory.
	logDir := t.TempDir()

	appLog := logger.LogRotationConfig{
		Filename:   filepath.Join(logDir, "app.log"),
		MaxSize:    10,
		MaxBackups: 3,
		MaxAge:     7,
		Compress:   false,
	}
	accessLog := logger.LogRotationConfig{
		Filename:   filepath.Join(logDir, "access.log"),
		MaxSize:    10,
		MaxBackups: 3,
		MaxAge:     7,
		Compress:   false,
	}

	// Initialize the logging system.
	if err := logger.InitLoggers("info", false, appLog, accessLog); err != nil {
		t.Fatalf("Failed to initialize loggers: %v", err)
	}
	defer logger.Sync()

	app := fiber.New()

	// Filled in by the probe middleware below; compared with the response
	// header once the request has completed.
	var capturedRequestID string

	// 1. RequestID middleware (first in the chain).
	app.Use(requestid.New(requestid.Config{
		Generator: func() string {
			return uuid.NewString()
		},
	}))

	// 2. Logger middleware (second in the chain).
	app.Use(logger.Middleware())

	// 3. Probe middleware: the request ID must already be stored in Locals.
	app.Use(func(c *fiber.Ctx) error {
		switch id := c.Locals(constants.ContextKeyRequestID).(type) {
		case nil:
			t.Error("Request ID should be available in middleware chain")
		case string:
			capturedRequestID = id
		}
		return c.Next()
	})

	// The handler checks for the request ID as well and echoes it back.
	app.Get("/test", func(c *fiber.Ctx) error {
		id := c.Locals(constants.ContextKeyRequestID)
		if id != nil {
			return c.JSON(fiber.Map{
				"request_id": id,
				"message":    "Request ID propagated successfully",
			})
		}
		return c.Status(500).SendString("Request ID not found in handler")
	})

	// Issue the request.
	resp, err := app.Test(httptest.NewRequest("GET", "/test", nil))
	if err != nil {
		t.Fatalf("Failed to execute request: %v", err)
	}
	defer resp.Body.Close()

	// Check the response status.
	if got := resp.StatusCode; got != 200 {
		t.Errorf("Expected status 200, got %d", got)
	}

	// The response header must carry the request ID.
	headerRequestID := resp.Header.Get("X-Request-ID")
	if len(headerRequestID) == 0 {
		t.Error("X-Request-ID header should be set")
	}

	// The ID seen inside the middleware must equal the one in the header.
	if headerRequestID != capturedRequestID {
		t.Errorf("Request ID mismatch: middleware=%s, header=%s", capturedRequestID, headerRequestID)
	}

	// The response body must contain the same ID.
	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("Failed to read response body: %v", err)
	}
	if bodyText := string(payload); !strings.Contains(bodyText, headerRequestID) {
		t.Errorf("Response body should contain request ID %s", headerRequestID)
	}

	t.Logf("Request ID successfully propagated: %s", headerRequestID)
}
// TestMiddlewareOrder verifies that middlewares run in registration order on
// the way in and in reverse order on the way out, with the route handler in
// between (T045).
//
// Three stand-in middlewares (request ID, logger, custom) record a "-start"
// step before calling c.Next and an "-end" step after it returns.
func TestMiddlewareOrder(t *testing.T) {
	app := fiber.New()

	// trace collects the steps in the order they actually execute.
	var trace []string
	record := func(step string) {
		trace = append(trace, step)
	}

	// Middleware 1: stands in for RequestID and stores an ID in Locals.
	app.Use(func(c *fiber.Ctx) error {
		record("requestid-start")
		c.Locals(constants.ContextKeyRequestID, uuid.NewString())
		nextErr := c.Next()
		record("requestid-end")
		return nextErr
	})

	// Middleware 2: stands in for Logger; the request ID must already be set.
	app.Use(func(c *fiber.Ctx) error {
		record("logger-start")
		if id := c.Locals(constants.ContextKeyRequestID); id == nil {
			t.Error("Request ID should be set before logger middleware")
		}
		nextErr := c.Next()
		record("logger-end")
		return nextErr
	})

	// Middleware 3: a custom middleware at the end of the chain.
	app.Use(func(c *fiber.Ctx) error {
		record("custom-start")
		nextErr := c.Next()
		record("custom-end")
		return nextErr
	})

	app.Get("/test", func(c *fiber.Ctx) error {
		record("handler")
		return c.SendString("ok")
	})

	resp, err := app.Test(httptest.NewRequest("GET", "/test", nil))
	if err != nil {
		t.Fatalf("Failed to execute request: %v", err)
	}
	resp.Body.Close()

	// Expected order: inbound in registration order, handler, then outbound
	// in reverse.
	expectedOrder := []string{
		"requestid-start",
		"logger-start",
		"custom-start",
		"handler",
		"custom-end",
		"logger-end",
		"requestid-end",
	}

	if want, got := len(expectedOrder), len(trace); got != want {
		t.Errorf("Expected %d execution steps, got %d", want, got)
	}

	// Compare step by step so each missing or wrong step is reported.
	for idx, want := range expectedOrder {
		switch {
		case idx >= len(trace):
			t.Errorf("Missing execution step at index %d: expected '%s'", idx, want)
		case trace[idx] != want:
			t.Errorf("Execution order mismatch at index %d: expected '%s', got '%s'", idx, want, trace[idx])
		}
	}

	t.Logf("Middleware execution order: %v", trace)
}
// TestLoggerMiddlewareWithUserID checks that a user ID stored in Locals by an
// earlier middleware shows up in the access log written for the request
// (T044).
func TestLoggerMiddlewareWithUserID(t *testing.T) {
	// Log files live in a per-test temporary directory.
	logDir := t.TempDir()
	accessLogFile := filepath.Join(logDir, "access-userid.log")

	appLogConfig := logger.LogRotationConfig{
		Filename:   filepath.Join(logDir, "app.log"),
		MaxSize:    10,
		MaxBackups: 3,
		MaxAge:     7,
		Compress:   false,
	}
	accessLogConfig := logger.LogRotationConfig{
		Filename:   accessLogFile,
		MaxSize:    10,
		MaxBackups: 3,
		MaxAge:     7,
		Compress:   false,
	}

	// Initialize the logging system.
	if err := logger.InitLoggers("info", false, appLogConfig, accessLogConfig); err != nil {
		t.Fatalf("Failed to initialize loggers: %v", err)
	}
	defer logger.Sync()

	app := fiber.New()

	// Request ID generation comes first.
	app.Use(requestid.New(requestid.Config{
		Generator: func() string {
			return uuid.NewString()
		},
	}))

	// Simulate an auth middleware that stores the user ID before logging.
	app.Use(func(c *fiber.Ctx) error {
		c.Locals(constants.ContextKeyUserID, "user_12345")
		return c.Next()
	})

	app.Use(logger.Middleware())

	app.Get("/test", func(c *fiber.Ctx) error {
		return c.SendString("ok")
	})

	// Issue the request.
	resp, err := app.Test(httptest.NewRequest("GET", "/test", nil))
	if err != nil {
		t.Fatalf("Failed to execute request: %v", err)
	}
	resp.Body.Close()

	// Flush the log buffers, then wait briefly before reading the file.
	logger.Sync()
	time.Sleep(100 * time.Millisecond)

	// The access log must mention the user ID set above.
	raw, err := os.ReadFile(accessLogFile)
	if err != nil {
		t.Fatalf("Failed to read access log: %v", err)
	}

	logContent := string(raw)
	if found := strings.Contains(logContent, "user_12345"); !found {
		t.Error("Access log should contain user_id 'user_12345'")
	}

	t.Logf("Access log with user_id:\n%s", logContent)
}
// TestConcurrentRequests verifies that request IDs stay unique when many
// requests are served concurrently (T043).
//
// It fires numRequests goroutines at the app, collects one result per
// goroutine, and asserts that every X-Request-ID header is non-empty and
// distinct.
func TestConcurrentRequests(t *testing.T) {
	app := fiber.New()

	app.Use(requestid.New(requestid.Config{
		Generator: func() string {
			return uuid.NewString()
		},
	}))

	app.Get("/test", func(c *fiber.Ctx) error {
		// Simulate some processing time so requests overlap.
		time.Sleep(10 * time.Millisecond)
		requestID := c.Locals(constants.ContextKeyRequestID)
		return c.JSON(fiber.Map{
			"request_id": requestID,
		})
	})

	const numRequests = 50

	// outcome keeps the ID and the error of a single request together, so the
	// collector never pairs one request's error with another request's ID.
	type outcome struct {
		requestID string
		err       error
	}

	// Buffered to numRequests so every goroutine can deliver its outcome
	// without blocking, even if the collector stops early via t.Fatalf.
	outcomes := make(chan outcome, numRequests)

	// Send the requests concurrently.
	for i := 0; i < numRequests; i++ {
		go func() {
			req := httptest.NewRequest("GET", "/test", nil)
			resp, err := app.Test(req)
			if err != nil {
				outcomes <- outcome{err: err}
				return
			}
			defer resp.Body.Close()

			outcomes <- outcome{requestID: resp.Header.Get("X-Request-ID")}
		}()
	}

	// Collect exactly one outcome per request.
	seenIDs := make(map[string]bool, numRequests)
	for i := 0; i < numRequests; i++ {
		res := <-outcomes
		if res.err != nil {
			t.Fatalf("Request failed: %v", res.err)
		}

		if res.requestID == "" {
			t.Error("Request ID should not be empty")
		}
		if seenIDs[res.requestID] {
			t.Errorf("Duplicate request ID found: %s", res.requestID)
		}
		seenIDs[res.requestID] = true
	}

	// All IDs must be distinct.
	if len(seenIDs) != numRequests {
		t.Errorf("Expected %d unique request IDs, got %d", numRequests, len(seenIDs))
	}

	t.Logf("Successfully generated %d unique request IDs concurrently", len(seenIDs))
}