// Package integration contains integration tests for the request ID and
// access-log middleware used by the Fiber application.
package integration
|
||
|
||
import (
|
||
"io"
|
||
"net/http/httptest"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||
"github.com/break/junhong_cmp_fiber/pkg/logger"
|
||
"github.com/gofiber/fiber/v2"
|
||
"github.com/gofiber/fiber/v2/middleware/requestid"
|
||
"github.com/google/uuid"
|
||
)
|
||
|
||
// TestRequestIDMiddleware 测试 RequestID 中间件生成 UUID v4(T043)
|
||
func TestRequestIDMiddleware(t *testing.T) {
|
||
app := fiber.New()
|
||
|
||
// 配置 requestid 中间件使用 UUID v4
|
||
app.Use(requestid.New(requestid.Config{
|
||
Generator: func() string {
|
||
return uuid.NewString()
|
||
},
|
||
}))
|
||
|
||
app.Get("/test", func(c *fiber.Ctx) error {
|
||
requestID := c.Locals(constants.ContextKeyRequestID)
|
||
return c.JSON(fiber.Map{
|
||
"request_id": requestID,
|
||
})
|
||
})
|
||
|
||
tests := []struct {
|
||
name string
|
||
}{
|
||
{name: "request 1"},
|
||
{name: "request 2"},
|
||
{name: "request 3"},
|
||
}
|
||
|
||
seenIDs := make(map[string]bool)
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
req := httptest.NewRequest("GET", "/test", nil)
|
||
resp, err := app.Test(req)
|
||
if err != nil {
|
||
t.Fatalf("Failed to execute request: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 验证响应头包含 X-Request-ID
|
||
requestID := resp.Header.Get("X-Request-ID")
|
||
if requestID == "" {
|
||
t.Error("X-Request-ID header should not be empty")
|
||
}
|
||
|
||
// 验证是 UUID v4 格式
|
||
if _, err := uuid.Parse(requestID); err != nil {
|
||
t.Errorf("X-Request-ID is not a valid UUID: %s, error: %v", requestID, err)
|
||
}
|
||
|
||
// 验证 UUID 是唯一的
|
||
if seenIDs[requestID] {
|
||
t.Errorf("Request ID %s is not unique", requestID)
|
||
}
|
||
seenIDs[requestID] = true
|
||
|
||
t.Logf("Request ID: %s", requestID)
|
||
})
|
||
}
|
||
|
||
// 验证生成了多个不同的 ID
|
||
if len(seenIDs) != len(tests) {
|
||
t.Errorf("Expected %d unique request IDs, got %d", len(tests), len(seenIDs))
|
||
}
|
||
}
|
||
|
||
// TestLoggerMiddleware 测试 Logger 中间件记录访问日志(T044)
|
||
func TestLoggerMiddleware(t *testing.T) {
|
||
// 创建临时目录用于日志
|
||
tempDir := t.TempDir()
|
||
accessLogFile := filepath.Join(tempDir, "access.log")
|
||
|
||
// 初始化日志系统
|
||
err := logger.InitLoggers("info", false,
|
||
logger.LogRotationConfig{
|
||
Filename: filepath.Join(tempDir, "app.log"),
|
||
MaxSize: 10,
|
||
MaxBackups: 3,
|
||
MaxAge: 7,
|
||
Compress: false,
|
||
},
|
||
logger.LogRotationConfig{
|
||
Filename: accessLogFile,
|
||
MaxSize: 10,
|
||
MaxBackups: 3,
|
||
MaxAge: 7,
|
||
Compress: false,
|
||
},
|
||
)
|
||
if err != nil {
|
||
t.Fatalf("Failed to initialize loggers: %v", err)
|
||
}
|
||
defer logger.Sync()
|
||
|
||
// 创建应用
|
||
app := fiber.New()
|
||
|
||
// 注册中间件
|
||
app.Use(requestid.New(requestid.Config{
|
||
Generator: func() string {
|
||
return uuid.NewString()
|
||
},
|
||
}))
|
||
app.Use(logger.Middleware())
|
||
|
||
app.Get("/test", func(c *fiber.Ctx) error {
|
||
return c.SendString("ok")
|
||
})
|
||
|
||
app.Post("/test", func(c *fiber.Ctx) error {
|
||
return c.SendStatus(201)
|
||
})
|
||
|
||
tests := []struct {
|
||
name string
|
||
method string
|
||
path string
|
||
expectedStatus int
|
||
}{
|
||
{
|
||
name: "GET request",
|
||
method: "GET",
|
||
path: "/test",
|
||
expectedStatus: 200,
|
||
},
|
||
{
|
||
name: "POST request",
|
||
method: "POST",
|
||
path: "/test",
|
||
expectedStatus: 201,
|
||
},
|
||
{
|
||
name: "GET with query params",
|
||
method: "GET",
|
||
path: "/test?foo=bar",
|
||
expectedStatus: 200,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
req := httptest.NewRequest(tt.method, tt.path, nil)
|
||
req.Header.Set("User-Agent", "test-agent/1.0")
|
||
|
||
resp, err := app.Test(req)
|
||
if err != nil {
|
||
t.Fatalf("Failed to execute request: %v", err)
|
||
}
|
||
resp.Body.Close()
|
||
|
||
if resp.StatusCode != tt.expectedStatus {
|
||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.StatusCode)
|
||
}
|
||
})
|
||
}
|
||
|
||
// 刷新日志缓冲区
|
||
logger.Sync()
|
||
time.Sleep(100 * time.Millisecond)
|
||
|
||
// 验证访问日志文件存在且有内容
|
||
content, err := os.ReadFile(accessLogFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to read access log: %v", err)
|
||
}
|
||
|
||
if len(content) == 0 {
|
||
t.Error("Access log should not be empty")
|
||
}
|
||
|
||
logContent := string(content)
|
||
t.Logf("Access log content:\n%s", logContent)
|
||
|
||
// 验证日志包含必要的字段
|
||
requiredFields := []string{
|
||
"method",
|
||
"path",
|
||
"status",
|
||
"duration_ms",
|
||
"request_id",
|
||
"ip",
|
||
"user_agent",
|
||
}
|
||
|
||
for _, field := range requiredFields {
|
||
if !strings.Contains(logContent, field) {
|
||
t.Errorf("Access log should contain field '%s'", field)
|
||
}
|
||
}
|
||
|
||
// 验证记录了所有请求
|
||
lines := strings.Split(strings.TrimSpace(logContent), "\n")
|
||
if len(lines) < len(tests) {
|
||
t.Errorf("Expected at least %d log entries, got %d", len(tests), len(lines))
|
||
}
|
||
}
|
||
|
||
// TestRequestIDPropagation 测试 Request ID 在中间件链中传播(T045)
|
||
func TestRequestIDPropagation(t *testing.T) {
|
||
// 创建临时目录用于日志
|
||
tempDir := t.TempDir()
|
||
|
||
// 初始化日志系统
|
||
err := logger.InitLoggers("info", false,
|
||
logger.LogRotationConfig{
|
||
Filename: filepath.Join(tempDir, "app.log"),
|
||
MaxSize: 10,
|
||
MaxBackups: 3,
|
||
MaxAge: 7,
|
||
Compress: false,
|
||
},
|
||
logger.LogRotationConfig{
|
||
Filename: filepath.Join(tempDir, "access.log"),
|
||
MaxSize: 10,
|
||
MaxBackups: 3,
|
||
MaxAge: 7,
|
||
Compress: false,
|
||
},
|
||
)
|
||
if err != nil {
|
||
t.Fatalf("Failed to initialize loggers: %v", err)
|
||
}
|
||
defer logger.Sync()
|
||
|
||
// 创建应用
|
||
app := fiber.New()
|
||
|
||
var capturedRequestID string
|
||
|
||
// 1. RequestID 中间件(第一个)
|
||
app.Use(requestid.New(requestid.Config{
|
||
Generator: func() string {
|
||
return uuid.NewString()
|
||
},
|
||
}))
|
||
|
||
// 2. Logger 中间件(第二个)
|
||
app.Use(logger.Middleware())
|
||
|
||
// 3. 自定义中间件验证 request ID 是否可访问
|
||
app.Use(func(c *fiber.Ctx) error {
|
||
requestID := c.Locals(constants.ContextKeyRequestID)
|
||
if requestID == nil {
|
||
t.Error("Request ID should be available in middleware chain")
|
||
}
|
||
if rid, ok := requestID.(string); ok {
|
||
capturedRequestID = rid
|
||
}
|
||
return c.Next()
|
||
})
|
||
|
||
app.Get("/test", func(c *fiber.Ctx) error {
|
||
// 在 handler 中也验证 request ID
|
||
requestID := c.Locals(constants.ContextKeyRequestID)
|
||
if requestID == nil {
|
||
return c.Status(500).SendString("Request ID not found in handler")
|
||
}
|
||
|
||
return c.JSON(fiber.Map{
|
||
"request_id": requestID,
|
||
"message": "Request ID propagated successfully",
|
||
})
|
||
})
|
||
|
||
// 执行请求
|
||
req := httptest.NewRequest("GET", "/test", nil)
|
||
resp, err := app.Test(req)
|
||
if err != nil {
|
||
t.Fatalf("Failed to execute request: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 验证响应
|
||
if resp.StatusCode != 200 {
|
||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||
}
|
||
|
||
// 验证响应头中的 Request ID
|
||
headerRequestID := resp.Header.Get("X-Request-ID")
|
||
if headerRequestID == "" {
|
||
t.Error("X-Request-ID header should be set")
|
||
}
|
||
|
||
// 验证中间件捕获的 Request ID 与响应头一致
|
||
if capturedRequestID != headerRequestID {
|
||
t.Errorf("Request ID mismatch: middleware=%s, header=%s", capturedRequestID, headerRequestID)
|
||
}
|
||
|
||
// 验证响应体
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
t.Fatalf("Failed to read response body: %v", err)
|
||
}
|
||
|
||
if !strings.Contains(string(body), headerRequestID) {
|
||
t.Errorf("Response body should contain request ID %s", headerRequestID)
|
||
}
|
||
|
||
t.Logf("Request ID successfully propagated: %s", headerRequestID)
|
||
}
|
||
|
||
// TestMiddlewareOrder 测试中间件执行顺序(T045)
|
||
func TestMiddlewareOrder(t *testing.T) {
|
||
app := fiber.New()
|
||
|
||
executionOrder := []string{}
|
||
|
||
// 中间件 1: RequestID
|
||
app.Use(func(c *fiber.Ctx) error {
|
||
executionOrder = append(executionOrder, "requestid-start")
|
||
c.Locals(constants.ContextKeyRequestID, uuid.NewString())
|
||
err := c.Next()
|
||
executionOrder = append(executionOrder, "requestid-end")
|
||
return err
|
||
})
|
||
|
||
// 中间件 2: Logger
|
||
app.Use(func(c *fiber.Ctx) error {
|
||
executionOrder = append(executionOrder, "logger-start")
|
||
// 验证 Request ID 已经设置
|
||
if c.Locals(constants.ContextKeyRequestID) == nil {
|
||
t.Error("Request ID should be set before logger middleware")
|
||
}
|
||
err := c.Next()
|
||
executionOrder = append(executionOrder, "logger-end")
|
||
return err
|
||
})
|
||
|
||
// 中间件 3: Custom
|
||
app.Use(func(c *fiber.Ctx) error {
|
||
executionOrder = append(executionOrder, "custom-start")
|
||
err := c.Next()
|
||
executionOrder = append(executionOrder, "custom-end")
|
||
return err
|
||
})
|
||
|
||
app.Get("/test", func(c *fiber.Ctx) error {
|
||
executionOrder = append(executionOrder, "handler")
|
||
return c.SendString("ok")
|
||
})
|
||
|
||
req := httptest.NewRequest("GET", "/test", nil)
|
||
resp, err := app.Test(req)
|
||
if err != nil {
|
||
t.Fatalf("Failed to execute request: %v", err)
|
||
}
|
||
resp.Body.Close()
|
||
|
||
// 验证执行顺序
|
||
expectedOrder := []string{
|
||
"requestid-start",
|
||
"logger-start",
|
||
"custom-start",
|
||
"handler",
|
||
"custom-end",
|
||
"logger-end",
|
||
"requestid-end",
|
||
}
|
||
|
||
if len(executionOrder) != len(expectedOrder) {
|
||
t.Errorf("Expected %d execution steps, got %d", len(expectedOrder), len(executionOrder))
|
||
}
|
||
|
||
for i, expected := range expectedOrder {
|
||
if i >= len(executionOrder) {
|
||
t.Errorf("Missing execution step at index %d: expected '%s'", i, expected)
|
||
continue
|
||
}
|
||
if executionOrder[i] != expected {
|
||
t.Errorf("Execution order mismatch at index %d: expected '%s', got '%s'", i, expected, executionOrder[i])
|
||
}
|
||
}
|
||
|
||
t.Logf("Middleware execution order: %v", executionOrder)
|
||
}
|
||
|
||
// TestLoggerMiddlewareWithUserID 测试 Logger 中间件记录用户 ID(T044)
|
||
func TestLoggerMiddlewareWithUserID(t *testing.T) {
|
||
// 创建临时目录用于日志
|
||
tempDir := t.TempDir()
|
||
accessLogFile := filepath.Join(tempDir, "access-userid.log")
|
||
|
||
// 初始化日志系统
|
||
err := logger.InitLoggers("info", false,
|
||
logger.LogRotationConfig{
|
||
Filename: filepath.Join(tempDir, "app.log"),
|
||
MaxSize: 10,
|
||
MaxBackups: 3,
|
||
MaxAge: 7,
|
||
Compress: false,
|
||
},
|
||
logger.LogRotationConfig{
|
||
Filename: accessLogFile,
|
||
MaxSize: 10,
|
||
MaxBackups: 3,
|
||
MaxAge: 7,
|
||
Compress: false,
|
||
},
|
||
)
|
||
if err != nil {
|
||
t.Fatalf("Failed to initialize loggers: %v", err)
|
||
}
|
||
defer logger.Sync()
|
||
|
||
// 创建应用
|
||
app := fiber.New()
|
||
|
||
// 注册中间件
|
||
app.Use(requestid.New(requestid.Config{
|
||
Generator: func() string {
|
||
return uuid.NewString()
|
||
},
|
||
}))
|
||
|
||
// 模拟 auth 中间件设置 user_id
|
||
app.Use(func(c *fiber.Ctx) error {
|
||
c.Locals(constants.ContextKeyUserID, "user_12345")
|
||
return c.Next()
|
||
})
|
||
|
||
app.Use(logger.Middleware())
|
||
|
||
app.Get("/test", func(c *fiber.Ctx) error {
|
||
return c.SendString("ok")
|
||
})
|
||
|
||
// 执行请求
|
||
req := httptest.NewRequest("GET", "/test", nil)
|
||
resp, err := app.Test(req)
|
||
if err != nil {
|
||
t.Fatalf("Failed to execute request: %v", err)
|
||
}
|
||
resp.Body.Close()
|
||
|
||
// 刷新日志缓冲区
|
||
logger.Sync()
|
||
time.Sleep(100 * time.Millisecond)
|
||
|
||
// 验证访问日志包含 user_id
|
||
content, err := os.ReadFile(accessLogFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to read access log: %v", err)
|
||
}
|
||
|
||
logContent := string(content)
|
||
if !strings.Contains(logContent, "user_12345") {
|
||
t.Error("Access log should contain user_id 'user_12345'")
|
||
}
|
||
|
||
t.Logf("Access log with user_id:\n%s", logContent)
|
||
}
|
||
|
||
// TestConcurrentRequests 测试并发请求的 Request ID 唯一性(T043)
|
||
func TestConcurrentRequests(t *testing.T) {
|
||
app := fiber.New()
|
||
|
||
app.Use(requestid.New(requestid.Config{
|
||
Generator: func() string {
|
||
return uuid.NewString()
|
||
},
|
||
}))
|
||
|
||
app.Get("/test", func(c *fiber.Ctx) error {
|
||
// 模拟一些处理时间
|
||
time.Sleep(10 * time.Millisecond)
|
||
requestID := c.Locals(constants.ContextKeyRequestID)
|
||
return c.JSON(fiber.Map{
|
||
"request_id": requestID,
|
||
})
|
||
})
|
||
|
||
// 并发发送多个请求
|
||
const numRequests = 50
|
||
requestIDs := make(chan string, numRequests)
|
||
errors := make(chan error, numRequests)
|
||
|
||
for i := 0; i < numRequests; i++ {
|
||
go func() {
|
||
req := httptest.NewRequest("GET", "/test", nil)
|
||
resp, err := app.Test(req)
|
||
if err != nil {
|
||
errors <- err
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
requestID := resp.Header.Get("X-Request-ID")
|
||
requestIDs <- requestID
|
||
errors <- nil
|
||
}()
|
||
}
|
||
|
||
// 收集所有结果
|
||
seenIDs := make(map[string]bool)
|
||
for i := 0; i < numRequests; i++ {
|
||
if err := <-errors; err != nil {
|
||
t.Fatalf("Request failed: %v", err)
|
||
}
|
||
requestID := <-requestIDs
|
||
|
||
if requestID == "" {
|
||
t.Error("Request ID should not be empty")
|
||
}
|
||
|
||
if seenIDs[requestID] {
|
||
t.Errorf("Duplicate request ID found: %s", requestID)
|
||
}
|
||
seenIDs[requestID] = true
|
||
}
|
||
|
||
// 验证所有 ID 都是唯一的
|
||
if len(seenIDs) != numRequests {
|
||
t.Errorf("Expected %d unique request IDs, got %d", numRequests, len(seenIDs))
|
||
}
|
||
|
||
t.Logf("Successfully generated %d unique request IDs concurrently", len(seenIDs))
|
||
}
|