Files
junhong_cmp_fiber/tests/integration/middleware_test.go
huang 23eb0307bb
All checks were successful
构建并部署到测试环境(无 SSH) / build-and-deploy (push) Successful in 5m30s
feat: 实现门店套餐分配功能并统一测试基础设施
新增功能:
- 门店套餐分配管理(shop_package_allocation):支持门店套餐库存管理
- 门店套餐系列分配管理(shop_series_allocation):支持套餐系列分配和佣金层级设置
- 我的套餐查询(my_package):支持门店查询自己的套餐分配情况

测试改进:
- 统一集成测试基础设施,新增 testutils.NewIntegrationTestEnv
- 重构所有集成测试使用新的测试环境设置
- 移除旧的测试辅助函数和冗余测试文件
- 新增 test_helpers_test.go 统一任务测试辅助

技术细节:
- 新增数据库迁移 000025_create_shop_allocation_tables
- 新增 3 个 Handler、Service、Store 和对应的单元测试
- 更新 OpenAPI 文档和文档生成器
- 测试覆盖率:Service 层 > 90%

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-28 10:45:16 +08:00

525 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package integration
import (
"io"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/logger"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/requestid"
"github.com/google/uuid"
)
// TestRequestIDMiddleware verifies that the requestid middleware, configured
// with a uuid.NewString generator, sets a non-empty X-Request-ID response
// header on every request. The header must be a version-4 UUID and must be
// unique across requests (T043).
func TestRequestIDMiddleware(t *testing.T) {
	app := fiber.New()

	// Configure the requestid middleware to generate UUID v4 values.
	app.Use(requestid.New(requestid.Config{
		Generator: func() string {
			return uuid.NewString()
		},
	}))

	app.Get("/test", func(c *fiber.Ctx) error {
		requestID := c.Locals(constants.ContextKeyRequestID)
		return c.JSON(fiber.Map{
			"request_id": requestID,
		})
	})

	tests := []struct {
		name string
	}{
		{name: "request 1"},
		{name: "request 2"},
		{name: "request 3"},
	}

	// Shared by all subtests. This is safe without locking because the
	// subtests run sequentially (none of them calls t.Parallel).
	seenIDs := make(map[string]bool)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest("GET", "/test", nil)
			resp, err := app.Test(req)
			if err != nil {
				t.Fatalf("Failed to execute request: %v", err)
			}
			defer resp.Body.Close()

			// The response must carry an X-Request-ID header.
			requestID := resp.Header.Get("X-Request-ID")
			if requestID == "" {
				t.Error("X-Request-ID header should not be empty")
			}

			// The header must parse as a UUID. Because the generator is
			// uuid.NewString, it must also be a version-4 (random) UUID.
			// uuid.Parse alone would accept any version.
			parsed, err := uuid.Parse(requestID)
			if err != nil {
				t.Errorf("X-Request-ID is not a valid UUID: %s, error: %v", requestID, err)
			} else if parsed.Version() != 4 {
				t.Errorf("X-Request-ID is not a version 4 UUID: %s, version: %d", requestID, parsed.Version())
			}

			// Every request must get a distinct ID.
			if seenIDs[requestID] {
				t.Errorf("Request ID %s is not unique", requestID)
			}
			seenIDs[requestID] = true

			t.Logf("Request ID: %s", requestID)
		})
	}

	// One distinct ID must have been recorded per request.
	if len(seenIDs) != len(tests) {
		t.Errorf("Expected %d unique request IDs, got %d", len(tests), len(seenIDs))
	}
}
// TestLoggerMiddleware verifies that logger.Middleware writes to the access
// log file for GET and POST requests (T044). It checks that the log is
// non-empty, that it mentions each expected field name, and that it has at
// least one line per request.
func TestLoggerMiddleware(t *testing.T) {
	// Use a per-test temporary directory for all log files.
	tempDir := t.TempDir()
	accessLogFile := filepath.Join(tempDir, "access.log")

	// Initialize the loggers. Presumably the first config is the application
	// log and the second the access log this test inspects (inferred from the
	// file names; confirm against logger.InitLoggers).
	err := logger.InitLoggers("info", false,
		logger.LogRotationConfig{
			Filename:   filepath.Join(tempDir, "app.log"),
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		},
		logger.LogRotationConfig{
			Filename:   accessLogFile,
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		},
	)
	if err != nil {
		t.Fatalf("Failed to initialize loggers: %v", err)
	}
	defer func() { _ = logger.Sync() }()

	// Build the app: requestid runs first so the logger can see the ID.
	app := fiber.New()
	app.Use(requestid.New(requestid.Config{
		Generator: func() string {
			return uuid.NewString()
		},
	}))
	app.Use(logger.Middleware())

	app.Get("/test", func(c *fiber.Ctx) error {
		return c.SendString("ok")
	})
	app.Post("/test", func(c *fiber.Ctx) error {
		return c.SendStatus(201)
	})

	tests := []struct {
		name           string
		method         string
		path           string
		expectedStatus int
	}{
		{
			name:           "GET request",
			method:         "GET",
			path:           "/test",
			expectedStatus: 200,
		},
		{
			name:           "POST request",
			method:         "POST",
			path:           "/test",
			expectedStatus: 201,
		},
		{
			name:           "GET with query params",
			method:         "GET",
			path:           "/test?foo=bar",
			expectedStatus: 200,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(tt.method, tt.path, nil)
			req.Header.Set("User-Agent", "test-agent/1.0")
			resp, err := app.Test(req)
			if err != nil {
				t.Fatalf("Failed to execute request: %v", err)
			}
			resp.Body.Close()

			if resp.StatusCode != tt.expectedStatus {
				t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.StatusCode)
			}
		})
	}

	// Flush buffered log output. The short sleep presumably gives any
	// asynchronous writer time to reach the file. TODO confirm whether Sync
	// alone is sufficient.
	_ = logger.Sync()
	time.Sleep(100 * time.Millisecond)

	// The access log file must exist and contain data.
	content, err := os.ReadFile(accessLogFile)
	if err != nil {
		t.Fatalf("Failed to read access log: %v", err)
	}
	if len(content) == 0 {
		t.Error("Access log should not be empty")
	}

	logContent := string(content)
	t.Logf("Access log content:\n%s", logContent)

	// Each expected field name must appear somewhere in the log. This is a
	// substring check, not a structural parse of each entry.
	requiredFields := []string{
		"method",
		"path",
		"status",
		"duration_ms",
		"request_id",
		"ip",
		"user_agent",
	}
	for _, field := range requiredFields {
		if !strings.Contains(logContent, field) {
			t.Errorf("Access log should contain field '%s'", field)
		}
	}

	// Count non-blank lines as entries. This assumes the access logger writes
	// one line per request (TODO confirm). Splitting an empty string yields
	// one empty element, so counting raw split results would report a
	// phantom entry for an empty log.
	entries := 0
	for _, line := range strings.Split(logContent, "\n") {
		if strings.TrimSpace(line) != "" {
			entries++
		}
	}
	if entries < len(tests) {
		t.Errorf("Expected at least %d log entries, got %d", len(tests), entries)
	}
}
// TestRequestIDPropagation checks that the ID produced by the requestid
// middleware is readable from c.Locals in a later middleware and in the route
// handler. It also checks that the same value is returned in the X-Request-ID
// response header and embedded in the JSON response body (T045).
func TestRequestIDPropagation(t *testing.T) {
	logDir := t.TempDir()

	// Both log outputs use the same rotation settings; only the file differs.
	rotation := func(file string) logger.LogRotationConfig {
		return logger.LogRotationConfig{
			Filename:   filepath.Join(logDir, file),
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		}
	}
	if err := logger.InitLoggers("info", false, rotation("app.log"), rotation("access.log")); err != nil {
		t.Fatalf("Failed to initialize loggers: %v", err)
	}
	defer func() { _ = logger.Sync() }()

	app := fiber.New()

	// midID receives whatever ID the probe middleware observed. The probe
	// runs inside app.Test, which has returned before midID is read below.
	var midID string

	// Chain order matters: requestid, then the access logger, then a probe
	// middleware that asserts the ID is already present in Locals.
	app.Use(requestid.New(requestid.Config{
		Generator: func() string { return uuid.NewString() },
	}))
	app.Use(logger.Middleware())
	app.Use(func(c *fiber.Ctx) error {
		value := c.Locals(constants.ContextKeyRequestID)
		if value == nil {
			t.Error("Request ID should be available in middleware chain")
		}
		if id, isString := value.(string); isString {
			midID = id
		}
		return c.Next()
	})

	// The handler echoes the ID back, or fails with 500 if it is missing.
	app.Get("/test", func(c *fiber.Ctx) error {
		value := c.Locals(constants.ContextKeyRequestID)
		if value == nil {
			return c.Status(500).SendString("Request ID not found in handler")
		}
		return c.JSON(fiber.Map{
			"request_id": value,
			"message":    "Request ID propagated successfully",
		})
	})

	resp, err := app.Test(httptest.NewRequest("GET", "/test", nil))
	if err != nil {
		t.Fatalf("Failed to execute request: %v", err)
	}
	defer resp.Body.Close()

	if got := resp.StatusCode; got != 200 {
		t.Errorf("Expected status 200, got %d", got)
	}

	// Header, middleware observation and body must all agree on one ID.
	headerID := resp.Header.Get("X-Request-ID")
	if len(headerID) == 0 {
		t.Error("X-Request-ID header should be set")
	}
	if midID != headerID {
		t.Errorf("Request ID mismatch: middleware=%s, header=%s", midID, headerID)
	}

	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("Failed to read response body: %v", err)
	}
	if !strings.Contains(string(payload), headerID) {
		t.Errorf("Response body should contain request ID %s", headerID)
	}

	t.Logf("Request ID successfully propagated: %s", headerID)
}
// TestMiddlewareOrder records when each middleware enters and leaves, and
// asserts that Fiber nests them in registration order. Each middleware's
// "-start" step runs before the next one's, the handler runs innermost, and
// the "-end" steps unwind in reverse order (T045). The second middleware also
// asserts that the request ID set by the first is already visible.
func TestMiddlewareOrder(t *testing.T) {
	app := fiber.New()

	// trace is appended from inside app.Test, which returns before the
	// slice is inspected below.
	trace := []string{}
	record := func(step string) { trace = append(trace, step) }

	// layer builds a middleware that records startStep, runs the optional
	// before hook, continues the chain, then records endStep and propagates
	// the chain's error.
	layer := func(startStep, endStep string, before func(c *fiber.Ctx)) func(*fiber.Ctx) error {
		return func(c *fiber.Ctx) error {
			record(startStep)
			if before != nil {
				before(c)
			}
			err := c.Next()
			record(endStep)
			return err
		}
	}

	// 1: simulated RequestID middleware.
	app.Use(layer("requestid-start", "requestid-end", func(c *fiber.Ctx) {
		c.Locals(constants.ContextKeyRequestID, uuid.NewString())
	}))
	// 2: simulated Logger middleware; the request ID must already be set.
	app.Use(layer("logger-start", "logger-end", func(c *fiber.Ctx) {
		if c.Locals(constants.ContextKeyRequestID) == nil {
			t.Error("Request ID should be set before logger middleware")
		}
	}))
	// 3: a custom middleware with no extra behavior.
	app.Use(layer("custom-start", "custom-end", nil))

	app.Get("/test", func(c *fiber.Ctx) error {
		record("handler")
		return c.SendString("ok")
	})

	resp, err := app.Test(httptest.NewRequest("GET", "/test", nil))
	if err != nil {
		t.Fatalf("Failed to execute request: %v", err)
	}
	resp.Body.Close()

	want := []string{
		"requestid-start",
		"logger-start",
		"custom-start",
		"handler",
		"custom-end",
		"logger-end",
		"requestid-end",
	}

	if len(trace) != len(want) {
		t.Errorf("Expected %d execution steps, got %d", len(want), len(trace))
	}
	for i := range want {
		switch {
		case i >= len(trace):
			t.Errorf("Missing execution step at index %d: expected '%s'", i, want[i])
		case trace[i] != want[i]:
			t.Errorf("Execution order mismatch at index %d: expected '%s', got '%s'", i, want[i], trace[i])
		}
	}

	t.Logf("Middleware execution order: %v", trace)
}
// TestLoggerMiddlewareWithUserID stores a user ID under
// constants.ContextKeyUserID in a middleware that runs before
// logger.Middleware. It then checks that the written access log contains
// that ID value ("12345"). The check is a plain substring search; it does
// not verify which log field holds the value.
func TestLoggerMiddlewareWithUserID(t *testing.T) {
	tempDir := t.TempDir()
	accessLogFile := filepath.Join(tempDir, "access-userid.log")

	// Presumably the first config is the application log and the second the
	// access log inspected below (inferred from the file names).
	err := logger.InitLoggers("info", false,
		logger.LogRotationConfig{
			Filename:   filepath.Join(tempDir, "app.log"),
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		},
		logger.LogRotationConfig{
			Filename:   accessLogFile,
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		},
	)
	if err != nil {
		t.Fatalf("Failed to initialize loggers: %v", err)
	}
	defer func() { _ = logger.Sync() }()

	app := fiber.New()
	app.Use(requestid.New(requestid.Config{
		Generator: func() string {
			return uuid.NewString()
		},
	}))
	// Simulate an authenticated request by setting the user ID before the
	// logger middleware runs. NOTE(review): the uint type is assumed to match
	// what the real auth middleware stores; confirm.
	app.Use(func(c *fiber.Ctx) error {
		c.Locals(constants.ContextKeyUserID, uint(12345))
		return c.Next()
	})
	app.Use(logger.Middleware())
	app.Get("/test", func(c *fiber.Ctx) error {
		return c.SendString("ok")
	})

	req := httptest.NewRequest("GET", "/test", nil)
	resp, err := app.Test(req)
	if err != nil {
		t.Fatalf("Failed to execute request: %v", err)
	}
	resp.Body.Close()

	// A failed request could still produce a log line containing the user ID,
	// so confirm the request itself succeeded before trusting the log check.
	if resp.StatusCode != 200 {
		t.Errorf("Expected status 200, got %d", resp.StatusCode)
	}

	// Flush buffered log output. The short sleep presumably lets any
	// asynchronous writer reach the file. TODO confirm whether Sync alone is
	// sufficient.
	_ = logger.Sync()
	time.Sleep(100 * time.Millisecond)

	content, err := os.ReadFile(accessLogFile)
	if err != nil {
		t.Fatalf("Failed to read access log: %v", err)
	}
	logContent := string(content)
	if !strings.Contains(logContent, "12345") {
		t.Error("Access log should contain user_id '12345'")
	}
	t.Logf("Access log with user_id:\n%s", logContent)
}
// TestConcurrentRequests sends numRequests requests to the app concurrently.
// It checks that each one succeeds with status 200 and returns a non-empty,
// valid UUID in X-Request-ID, and that no ID is repeated (T043).
func TestConcurrentRequests(t *testing.T) {
	app := fiber.New()
	app.Use(requestid.New(requestid.Config{
		Generator: func() string {
			return uuid.NewString()
		},
	}))
	app.Get("/test", func(c *fiber.Ctx) error {
		// Simulate some processing time so the requests actually overlap.
		time.Sleep(10 * time.Millisecond)
		requestID := c.Locals(constants.ContextKeyRequestID)
		return c.JSON(fiber.Map{
			"request_id": requestID,
		})
	})

	const numRequests = 50

	// result ties each request's ID and status to its own error, so values
	// from different requests can never be mixed up. Using one channel also
	// avoids a local variable named "errors", which would shadow the stdlib
	// package name.
	type result struct {
		requestID string
		status    int
		err       error
	}
	// Buffered to numRequests so every goroutine can send and exit even if
	// the collector below stops early via t.Fatalf.
	results := make(chan result, numRequests)

	for i := 0; i < numRequests; i++ {
		go func() {
			req := httptest.NewRequest("GET", "/test", nil)
			resp, err := app.Test(req)
			if err != nil {
				results <- result{err: err}
				return
			}
			defer resp.Body.Close()
			results <- result{
				requestID: resp.Header.Get("X-Request-ID"),
				status:    resp.StatusCode,
			}
		}()
	}

	// Collect exactly one result per request.
	seenIDs := make(map[string]bool, numRequests)
	for i := 0; i < numRequests; i++ {
		r := <-results
		if r.err != nil {
			t.Fatalf("Request failed: %v", r.err)
		}
		if r.status != 200 {
			t.Errorf("Expected status 200, got %d", r.status)
		}
		if r.requestID == "" {
			t.Error("Request ID should not be empty")
		} else if _, err := uuid.Parse(r.requestID); err != nil {
			t.Errorf("Request ID is not a valid UUID: %s, error: %v", r.requestID, err)
		}
		if seenIDs[r.requestID] {
			t.Errorf("Duplicate request ID found: %s", r.requestID)
		}
		seenIDs[r.requestID] = true
	}

	// Every request must have produced a distinct ID.
	if len(seenIDs) != numRequests {
		t.Errorf("Expected %d unique request IDs, got %d", numRequests, len(seenIDs))
	}
	t.Logf("Successfully generated %d unique request IDs concurrently", len(seenIDs))
}