备份一下

This commit is contained in:
2025-11-11 15:53:01 +08:00
parent e98dd4d725
commit 39c5b524a9
10 changed files with 4295 additions and 63 deletions

View File

@@ -0,0 +1,533 @@
package integration
import (
"io"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/logger"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/requestid"
"github.com/google/uuid"
)
// TestRequestIDMiddleware 测试 RequestID 中间件生成 UUID v4T043
func TestRequestIDMiddleware(t *testing.T) {
app := fiber.New()
// 配置 requestid 中间件使用 UUID v4
app.Use(requestid.New(requestid.Config{
Generator: func() string {
return uuid.NewString()
},
}))
app.Get("/test", func(c *fiber.Ctx) error {
requestID := c.Locals(constants.ContextKeyRequestID)
return c.JSON(fiber.Map{
"request_id": requestID,
})
})
tests := []struct {
name string
}{
{name: "request 1"},
{name: "request 2"},
{name: "request 3"},
}
seenIDs := make(map[string]bool)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
defer resp.Body.Close()
// 验证响应头包含 X-Request-ID
requestID := resp.Header.Get("X-Request-ID")
if requestID == "" {
t.Error("X-Request-ID header should not be empty")
}
// 验证是 UUID v4 格式
if _, err := uuid.Parse(requestID); err != nil {
t.Errorf("X-Request-ID is not a valid UUID: %s, error: %v", requestID, err)
}
// 验证 UUID 是唯一的
if seenIDs[requestID] {
t.Errorf("Request ID %s is not unique", requestID)
}
seenIDs[requestID] = true
t.Logf("Request ID: %s", requestID)
})
}
// 验证生成了多个不同的 ID
if len(seenIDs) != len(tests) {
t.Errorf("Expected %d unique request IDs, got %d", len(tests), len(seenIDs))
}
}
// TestLoggerMiddleware 测试 Logger 中间件记录访问日志T044
func TestLoggerMiddleware(t *testing.T) {
// 创建临时目录用于日志
tempDir := t.TempDir()
accessLogFile := filepath.Join(tempDir, "access.log")
// 初始化日志系统
err := logger.InitLoggers("info", false,
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "app.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
logger.LogRotationConfig{
Filename: accessLogFile,
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("Failed to initialize loggers: %v", err)
}
defer logger.Sync()
// 创建应用
app := fiber.New()
// 注册中间件
app.Use(requestid.New(requestid.Config{
Generator: func() string {
return uuid.NewString()
},
}))
app.Use(logger.Middleware())
app.Get("/test", func(c *fiber.Ctx) error {
return c.SendString("ok")
})
app.Post("/test", func(c *fiber.Ctx) error {
return c.SendStatus(201)
})
tests := []struct {
name string
method string
path string
expectedStatus int
}{
{
name: "GET request",
method: "GET",
path: "/test",
expectedStatus: 200,
},
{
name: "POST request",
method: "POST",
path: "/test",
expectedStatus: 201,
},
{
name: "GET with query params",
method: "GET",
path: "/test?foo=bar",
expectedStatus: 200,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, tt.path, nil)
req.Header.Set("User-Agent", "test-agent/1.0")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
resp.Body.Close()
if resp.StatusCode != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.StatusCode)
}
})
}
// 刷新日志缓冲区
logger.Sync()
time.Sleep(100 * time.Millisecond)
// 验证访问日志文件存在且有内容
content, err := os.ReadFile(accessLogFile)
if err != nil {
t.Fatalf("Failed to read access log: %v", err)
}
if len(content) == 0 {
t.Error("Access log should not be empty")
}
logContent := string(content)
t.Logf("Access log content:\n%s", logContent)
// 验证日志包含必要的字段
requiredFields := []string{
"method",
"path",
"status",
"duration_ms",
"request_id",
"ip",
"user_agent",
}
for _, field := range requiredFields {
if !strings.Contains(logContent, field) {
t.Errorf("Access log should contain field '%s'", field)
}
}
// 验证记录了所有请求
lines := strings.Split(strings.TrimSpace(logContent), "\n")
if len(lines) < len(tests) {
t.Errorf("Expected at least %d log entries, got %d", len(tests), len(lines))
}
}
// TestRequestIDPropagation 测试 Request ID 在中间件链中传播T045
func TestRequestIDPropagation(t *testing.T) {
// 创建临时目录用于日志
tempDir := t.TempDir()
// 初始化日志系统
err := logger.InitLoggers("info", false,
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "app.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "access.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("Failed to initialize loggers: %v", err)
}
defer logger.Sync()
// 创建应用
app := fiber.New()
var capturedRequestID string
// 1. RequestID 中间件(第一个)
app.Use(requestid.New(requestid.Config{
Generator: func() string {
return uuid.NewString()
},
}))
// 2. Logger 中间件(第二个)
app.Use(logger.Middleware())
// 3. 自定义中间件验证 request ID 是否可访问
app.Use(func(c *fiber.Ctx) error {
requestID := c.Locals(constants.ContextKeyRequestID)
if requestID == nil {
t.Error("Request ID should be available in middleware chain")
}
if rid, ok := requestID.(string); ok {
capturedRequestID = rid
}
return c.Next()
})
app.Get("/test", func(c *fiber.Ctx) error {
// 在 handler 中也验证 request ID
requestID := c.Locals(constants.ContextKeyRequestID)
if requestID == nil {
return c.Status(500).SendString("Request ID not found in handler")
}
return c.JSON(fiber.Map{
"request_id": requestID,
"message": "Request ID propagated successfully",
})
})
// 执行请求
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
defer resp.Body.Close()
// 验证响应
if resp.StatusCode != 200 {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
// 验证响应头中的 Request ID
headerRequestID := resp.Header.Get("X-Request-ID")
if headerRequestID == "" {
t.Error("X-Request-ID header should be set")
}
// 验证中间件捕获的 Request ID 与响应头一致
if capturedRequestID != headerRequestID {
t.Errorf("Request ID mismatch: middleware=%s, header=%s", capturedRequestID, headerRequestID)
}
// 验证响应体
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
if !strings.Contains(string(body), headerRequestID) {
t.Errorf("Response body should contain request ID %s", headerRequestID)
}
t.Logf("Request ID successfully propagated: %s", headerRequestID)
}
// TestMiddlewareOrder 测试中间件执行顺序T045
func TestMiddlewareOrder(t *testing.T) {
app := fiber.New()
executionOrder := []string{}
// 中间件 1: RequestID
app.Use(func(c *fiber.Ctx) error {
executionOrder = append(executionOrder, "requestid-start")
c.Locals(constants.ContextKeyRequestID, uuid.NewString())
err := c.Next()
executionOrder = append(executionOrder, "requestid-end")
return err
})
// 中间件 2: Logger
app.Use(func(c *fiber.Ctx) error {
executionOrder = append(executionOrder, "logger-start")
// 验证 Request ID 已经设置
if c.Locals(constants.ContextKeyRequestID) == nil {
t.Error("Request ID should be set before logger middleware")
}
err := c.Next()
executionOrder = append(executionOrder, "logger-end")
return err
})
// 中间件 3: Custom
app.Use(func(c *fiber.Ctx) error {
executionOrder = append(executionOrder, "custom-start")
err := c.Next()
executionOrder = append(executionOrder, "custom-end")
return err
})
app.Get("/test", func(c *fiber.Ctx) error {
executionOrder = append(executionOrder, "handler")
return c.SendString("ok")
})
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
resp.Body.Close()
// 验证执行顺序
expectedOrder := []string{
"requestid-start",
"logger-start",
"custom-start",
"handler",
"custom-end",
"logger-end",
"requestid-end",
}
if len(executionOrder) != len(expectedOrder) {
t.Errorf("Expected %d execution steps, got %d", len(expectedOrder), len(executionOrder))
}
for i, expected := range expectedOrder {
if i >= len(executionOrder) {
t.Errorf("Missing execution step at index %d: expected '%s'", i, expected)
continue
}
if executionOrder[i] != expected {
t.Errorf("Execution order mismatch at index %d: expected '%s', got '%s'", i, expected, executionOrder[i])
}
}
t.Logf("Middleware execution order: %v", executionOrder)
}
// TestLoggerMiddlewareWithUserID 测试 Logger 中间件记录用户 IDT044
func TestLoggerMiddlewareWithUserID(t *testing.T) {
// 创建临时目录用于日志
tempDir := t.TempDir()
accessLogFile := filepath.Join(tempDir, "access-userid.log")
// 初始化日志系统
err := logger.InitLoggers("info", false,
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "app.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
logger.LogRotationConfig{
Filename: accessLogFile,
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("Failed to initialize loggers: %v", err)
}
defer logger.Sync()
// 创建应用
app := fiber.New()
// 注册中间件
app.Use(requestid.New(requestid.Config{
Generator: func() string {
return uuid.NewString()
},
}))
// 模拟 auth 中间件设置 user_id
app.Use(func(c *fiber.Ctx) error {
c.Locals(constants.ContextKeyUserID, "user_12345")
return c.Next()
})
app.Use(logger.Middleware())
app.Get("/test", func(c *fiber.Ctx) error {
return c.SendString("ok")
})
// 执行请求
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
resp.Body.Close()
// 刷新日志缓冲区
logger.Sync()
time.Sleep(100 * time.Millisecond)
// 验证访问日志包含 user_id
content, err := os.ReadFile(accessLogFile)
if err != nil {
t.Fatalf("Failed to read access log: %v", err)
}
logContent := string(content)
if !strings.Contains(logContent, "user_12345") {
t.Error("Access log should contain user_id 'user_12345'")
}
t.Logf("Access log with user_id:\n%s", logContent)
}
// TestConcurrentRequests 测试并发请求的 Request ID 唯一性T043
func TestConcurrentRequests(t *testing.T) {
app := fiber.New()
app.Use(requestid.New(requestid.Config{
Generator: func() string {
return uuid.NewString()
},
}))
app.Get("/test", func(c *fiber.Ctx) error {
// 模拟一些处理时间
time.Sleep(10 * time.Millisecond)
requestID := c.Locals(constants.ContextKeyRequestID)
return c.JSON(fiber.Map{
"request_id": requestID,
})
})
// 并发发送多个请求
const numRequests = 50
requestIDs := make(chan string, numRequests)
errors := make(chan error, numRequests)
for i := 0; i < numRequests; i++ {
go func() {
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
errors <- err
return
}
defer resp.Body.Close()
requestID := resp.Header.Get("X-Request-ID")
requestIDs <- requestID
errors <- nil
}()
}
// 收集所有结果
seenIDs := make(map[string]bool)
for i := 0; i < numRequests; i++ {
if err := <-errors; err != nil {
t.Fatalf("Request failed: %v", err)
}
requestID := <-requestIDs
if requestID == "" {
t.Error("Request ID should not be empty")
}
if seenIDs[requestID] {
t.Errorf("Duplicate request ID found: %s", requestID)
}
seenIDs[requestID] = true
}
// 验证所有 ID 都是唯一的
if len(seenIDs) != numRequests {
t.Errorf("Expected %d unique request IDs, got %d", numRequests, len(seenIDs))
}
t.Logf("Successfully generated %d unique request IDs concurrently", len(seenIDs))
}