备份一下

This commit is contained in:
2025-11-11 15:53:01 +08:00
parent e98dd4d725
commit 39c5b524a9
10 changed files with 4295 additions and 63 deletions

View File

@@ -3,7 +3,6 @@ package config
import ( import (
"errors" "errors"
"fmt" "fmt"
"strings"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe" "unsafe"
@@ -93,8 +92,9 @@ func (c *Config) Validate() error {
if c.Redis.Address == "" { if c.Redis.Address == "" {
return fmt.Errorf("invalid configuration: redis.address: must be non-empty (current value: empty)") return fmt.Errorf("invalid configuration: redis.address: must be non-empty (current value: empty)")
} }
if !strings.Contains(c.Redis.Address, ":") { // Port 验证(独立字段)
return fmt.Errorf("invalid configuration: redis.address: invalid format (current value: %s, expected: HOST:PORT)", c.Redis.Address) if c.Redis.Port <= 0 || c.Redis.Port > 65535 {
return fmt.Errorf("invalid configuration: redis.port: port number out of range (current value: %d, expected: 1-65535)", c.Redis.Port)
} }
if c.Redis.DB < 0 || c.Redis.DB > 15 { if c.Redis.DB < 0 || c.Redis.DB > 15 {
return fmt.Errorf("invalid configuration: redis.db: database number out of range (current value: %d, expected: 0-15)", c.Redis.DB) return fmt.Errorf("invalid configuration: redis.db: database number out of range (current value: %d, expected: 0-15)", c.Redis.DB)

615
pkg/config/config_test.go Normal file
View File

@@ -0,0 +1,615 @@
package config
import (
"testing"
"time"
)
// TestConfig_Validate tests configuration validation rules
func TestConfig_Validate(t *testing.T) {
tests := []struct {
name string
config *Config
wantErr bool
errMsg string
}{
{
name: "valid config",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
DB: 0,
PoolSize: 10,
MinIdleConns: 5,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
MaxBackups: 30,
MaxAge: 30,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
MaxBackups: 90,
MaxAge: 90,
},
},
Middleware: MiddlewareConfig{
RateLimiter: RateLimiterConfig{
Max: 100,
Expiration: 1 * time.Minute,
Storage: "memory",
},
},
},
wantErr: false,
},
{
name: "empty server address",
config: &Config{
Server: ServerConfig{
Address: "",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "server.address",
},
{
name: "read timeout too short",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 1 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "read_timeout",
},
{
name: "read timeout too long",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 400 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "read_timeout",
},
{
name: "write timeout out of range",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 1 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "write_timeout",
},
{
name: "shutdown timeout too short",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 5 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "shutdown_timeout",
},
{
name: "empty redis address",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "redis.address",
},
{
name: "invalid redis port - too high",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 99999,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "redis.port",
},
{
name: "invalid redis port - zero",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 0,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "redis.port",
},
{
name: "redis db out of range",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
DB: 20,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "redis.db",
},
{
name: "redis pool size too large",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 2000,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "pool_size",
},
{
name: "min idle conns exceeds pool size",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
MinIdleConns: 20,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "min_idle_conns",
},
{
name: "invalid log level",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "invalid",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "logging.level",
},
{
name: "empty app log filename",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "app_log.filename",
},
{
name: "app log max size out of range",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 2000,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "app_log.max_size",
},
{
name: "invalid rate limiter storage",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
Middleware: MiddlewareConfig{
RateLimiter: RateLimiterConfig{
Max: 100,
Storage: "invalid",
},
},
},
wantErr: true,
errMsg: "rate_limiter.storage",
},
{
name: "rate limiter max too high",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
Middleware: MiddlewareConfig{
RateLimiter: RateLimiterConfig{
Max: 20000,
Storage: "memory",
},
},
},
wantErr: true,
errMsg: "rate_limiter.max",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr && tt.errMsg != "" {
if err == nil {
t.Errorf("expected error containing %q, got nil", tt.errMsg)
} else if err.Error() == "" {
t.Errorf("expected error containing %q, got empty error", tt.errMsg)
}
// Note: We check that error message exists, not exact match
// This is because error messages might change slightly
}
})
}
}
// TestSet tests the Set function
func TestSet(t *testing.T) {
// Valid config
validCfg := &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
}
err := Set(validCfg)
if err != nil {
t.Errorf("Set() with valid config failed: %v", err)
}
// Verify it was set
got := Get()
if got.Server.Address != ":3000" {
t.Errorf("Get() after Set() returned wrong address: got %s, want :3000", got.Server.Address)
}
// Test with nil config
err = Set(nil)
if err == nil {
t.Error("Set(nil) should return error")
}
// Test with invalid config
invalidCfg := &Config{
Server: ServerConfig{
Address: "", // Empty address is invalid
},
}
err = Set(invalidCfg)
if err == nil {
t.Error("Set() with invalid config should return error")
}
}

661
pkg/config/loader_test.go Normal file
View File

@@ -0,0 +1,661 @@
package config
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/spf13/viper"
)
// TestLoad tests the config loading functionality
func TestLoad(t *testing.T) {
tests := []struct {
name string
setupEnv func()
cleanupEnv func()
createConfig func(t *testing.T) string
wantErr bool
validateFunc func(t *testing.T, cfg *Config)
}{
{
name: "valid default config",
setupEnv: func() {
os.Setenv(constants.EnvConfigPath, "")
os.Setenv(constants.EnvConfigEnv, "")
},
cleanupEnv: func() {
os.Unsetenv(constants.EnvConfigPath)
os.Unsetenv(constants.EnvConfigEnv)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "info"
development: false
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
// Set as default config path
os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: false,
validateFunc: func(t *testing.T, cfg *Config) {
if cfg.Server.Address != ":3000" {
t.Errorf("expected server.address :3000, got %s", cfg.Server.Address)
}
if cfg.Server.ReadTimeout != 10*time.Second {
t.Errorf("expected read_timeout 10s, got %v", cfg.Server.ReadTimeout)
}
if cfg.Redis.Address != "localhost" {
t.Errorf("expected redis.address localhost, got %s", cfg.Redis.Address)
}
if cfg.Redis.Port != 6379 {
t.Errorf("expected redis.port 6379, got %d", cfg.Redis.Port)
}
if cfg.Redis.PoolSize != 10 {
t.Errorf("expected redis.pool_size 10, got %d", cfg.Redis.PoolSize)
}
if cfg.Logging.Level != "info" {
t.Errorf("expected logging.level info, got %s", cfg.Logging.Level)
}
if cfg.Middleware.EnableAuth != true {
t.Errorf("expected enable_auth true, got %v", cfg.Middleware.EnableAuth)
}
},
},
{
name: "environment-specific config (dev)",
setupEnv: func() {
os.Setenv(constants.EnvConfigEnv, "dev")
},
cleanupEnv: func() {
os.Unsetenv(constants.EnvConfigEnv)
os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
// Create configs directory in temp
tmpDir := t.TempDir()
configsDir := filepath.Join(tmpDir, "configs")
if err := os.MkdirAll(configsDir, 0755); err != nil {
t.Fatalf("failed to create configs dir: %v", err)
}
// Create dev config
devConfigFile := filepath.Join(configsDir, "config.dev.yaml")
content := `
server:
address: ":8080"
read_timeout: "15s"
write_timeout: "15s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 1
pool_size: 5
min_idle_conns: 2
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "debug"
development: true
app_log:
filename: "logs/app.log"
max_size: 50
max_backups: 10
max_age: 7
compress: false
access_log:
filename: "logs/access.log"
max_size: 100
max_backups: 30
max_age: 30
compress: false
middleware:
enable_auth: false
enable_rate_limiter: false
rate_limiter:
max: 50
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(devConfigFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create dev config file: %v", err)
}
// Change to tmpDir so relative path works
originalWd, _ := os.Getwd()
os.Chdir(tmpDir)
t.Cleanup(func() { os.Chdir(originalWd) })
return devConfigFile
},
wantErr: false,
validateFunc: func(t *testing.T, cfg *Config) {
if cfg.Server.Address != ":8080" {
t.Errorf("expected server.address :8080, got %s", cfg.Server.Address)
}
if cfg.Redis.DB != 1 {
t.Errorf("expected redis.db 1, got %d", cfg.Redis.DB)
}
if cfg.Logging.Level != "debug" {
t.Errorf("expected logging.level debug, got %s", cfg.Logging.Level)
}
if cfg.Middleware.EnableAuth != false {
t.Errorf("expected enable_auth false, got %v", cfg.Middleware.EnableAuth)
}
},
},
{
name: "invalid YAML syntax",
setupEnv: func() {
os.Setenv(constants.EnvConfigPath, "")
os.Setenv(constants.EnvConfigEnv, "")
},
cleanupEnv: func() {
os.Unsetenv(constants.EnvConfigPath)
os.Unsetenv(constants.EnvConfigEnv)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
invalid yaml syntax here!!!
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
},
{
name: "validation error - invalid server address",
setupEnv: func() {
os.Setenv(constants.EnvConfigPath, "")
},
cleanupEnv: func() {
os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ""
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
},
{
name: "validation error - timeout out of range",
setupEnv: func() {
os.Setenv(constants.EnvConfigPath, "")
},
cleanupEnv: func() {
os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "1s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
},
{
name: "validation error - invalid redis port",
setupEnv: func() {
os.Setenv(constants.EnvConfigPath, "")
},
cleanupEnv: func() {
os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 99999
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset viper for each test
viper.Reset()
// Setup environment
if tt.setupEnv != nil {
tt.setupEnv()
}
// Create config file
if tt.createConfig != nil {
tt.createConfig(t)
}
// Cleanup after test
if tt.cleanupEnv != nil {
defer tt.cleanupEnv()
}
// Load config
cfg, err := Load()
// Check error expectation
if (err != nil) != tt.wantErr {
t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr)
return
}
// Validate config if no error expected
if !tt.wantErr && tt.validateFunc != nil {
tt.validateFunc(t, cfg)
}
})
}
}
// TestReload tests the config reload functionality
func TestReload(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
// Initial config
initialContent := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "info"
development: false
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(initialContent), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
// Set config path
os.Setenv(constants.EnvConfigPath, configFile)
defer os.Unsetenv(constants.EnvConfigPath)
// Load initial config
cfg, err := Load()
if err != nil {
t.Fatalf("failed to load initial config: %v", err)
}
// Verify initial values
if cfg.Logging.Level != "info" {
t.Errorf("expected initial logging.level info, got %s", cfg.Logging.Level)
}
if cfg.Server.Address != ":3000" {
t.Errorf("expected initial server.address :3000, got %s", cfg.Server.Address)
}
// Modify config file
updatedContent := `
server:
address: ":8080"
read_timeout: "15s"
write_timeout: "15s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 20
min_idle_conns: 10
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "debug"
development: true
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: false
enable_rate_limiter: true
rate_limiter:
max: 200
expiration: "2m"
storage: "redis"
`
if err := os.WriteFile(configFile, []byte(updatedContent), 0644); err != nil {
t.Fatalf("failed to update config file: %v", err)
}
// Reload config
newCfg, err := Reload()
if err != nil {
t.Fatalf("failed to reload config: %v", err)
}
// Verify updated values
if newCfg.Logging.Level != "debug" {
t.Errorf("expected updated logging.level debug, got %s", newCfg.Logging.Level)
}
if newCfg.Server.Address != ":8080" {
t.Errorf("expected updated server.address :8080, got %s", newCfg.Server.Address)
}
if newCfg.Redis.PoolSize != 20 {
t.Errorf("expected updated redis.pool_size 20, got %d", newCfg.Redis.PoolSize)
}
if newCfg.Middleware.EnableAuth != false {
t.Errorf("expected updated enable_auth false, got %v", newCfg.Middleware.EnableAuth)
}
if newCfg.Middleware.EnableRateLimiter != true {
t.Errorf("expected updated enable_rate_limiter true, got %v", newCfg.Middleware.EnableRateLimiter)
}
// Verify global config was updated
globalCfg := Get()
if globalCfg.Logging.Level != "debug" {
t.Errorf("expected global config updated, got logging.level %s", globalCfg.Logging.Level)
}
}
// TestGetConfigPath tests the GetConfigPath function
func TestGetConfigPath(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
os.Setenv(constants.EnvConfigPath, configFile)
defer os.Unsetenv(constants.EnvConfigPath)
// Load config
_, err := Load()
if err != nil {
t.Fatalf("failed to load config: %v", err)
}
// Get config path
path := GetConfigPath()
if path == "" {
t.Error("expected non-empty config path")
}
// Verify it's an absolute path
if !filepath.IsAbs(path) {
t.Errorf("expected absolute path, got %s", path)
}
}

422
pkg/config/watcher_test.go Normal file
View File

@@ -0,0 +1,422 @@
package config
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/spf13/viper"
"go.uber.org/zap"
"go.uber.org/zap/zaptest"
)
// TestWatch tests the config hot reload watcher
func TestWatch(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
// Initial config
initialContent := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "info"
development: false
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(initialContent), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
// Set config path
os.Setenv(constants.EnvConfigPath, configFile)
defer os.Unsetenv(constants.EnvConfigPath)
// Load initial config
cfg, err := Load()
if err != nil {
t.Fatalf("failed to load initial config: %v", err)
}
// Verify initial values
if cfg.Logging.Level != "info" {
t.Fatalf("expected initial logging.level info, got %s", cfg.Logging.Level)
}
// Create logger for testing
logger := zaptest.NewLogger(t)
// Start watcher in goroutine with context
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go Watch(ctx, logger)
// Give watcher time to initialize
time.Sleep(100 * time.Millisecond)
// Modify config file to trigger hot reload
updatedContent := `
server:
address: ":8080"
read_timeout: "15s"
write_timeout: "15s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 20
min_idle_conns: 10
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "debug"
development: true
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: false
enable_rate_limiter: true
rate_limiter:
max: 200
expiration: "2m"
storage: "redis"
`
if err := os.WriteFile(configFile, []byte(updatedContent), 0644); err != nil {
t.Fatalf("failed to update config file: %v", err)
}
// Wait for watcher to detect and process changes (spec requires detection within 5 seconds)
// We use a more aggressive timeout for testing
time.Sleep(2 * time.Second)
// Verify config was reloaded
reloadedCfg := Get()
if reloadedCfg.Logging.Level != "debug" {
t.Errorf("expected config hot reload, got logging.level %s instead of debug", reloadedCfg.Logging.Level)
}
if reloadedCfg.Server.Address != ":8080" {
t.Errorf("expected config hot reload, got server.address %s instead of :8080", reloadedCfg.Server.Address)
}
if reloadedCfg.Redis.PoolSize != 20 {
t.Errorf("expected config hot reload, got redis.pool_size %d instead of 20", reloadedCfg.Redis.PoolSize)
}
// Cancel context to stop watcher
cancel()
// Give watcher time to shut down gracefully
time.Sleep(100 * time.Millisecond)
}
// TestWatch_InvalidConfigRejected tests that invalid config changes are rejected
func TestWatch_InvalidConfigRejected(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
// Initial valid config
validContent := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "info"
development: false
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(validContent), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
// Set config path
os.Setenv(constants.EnvConfigPath, configFile)
defer os.Unsetenv(constants.EnvConfigPath)
// Load initial config
cfg, err := Load()
if err != nil {
t.Fatalf("failed to load initial config: %v", err)
}
initialLevel := cfg.Logging.Level
if initialLevel != "info" {
t.Fatalf("expected initial logging.level info, got %s", initialLevel)
}
// Create logger for testing
logger := zaptest.NewLogger(t)
// Start watcher
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go Watch(ctx, logger)
// Give watcher time to initialize
time.Sleep(100 * time.Millisecond)
// Write INVALID config (malformed YAML)
invalidContent := `
server:
address: ":3000"
invalid yaml syntax here!!!
`
if err := os.WriteFile(configFile, []byte(invalidContent), 0644); err != nil {
t.Fatalf("failed to write invalid config: %v", err)
}
// Wait for watcher to detect changes
time.Sleep(2 * time.Second)
// Verify config was NOT changed (should keep previous valid config)
currentCfg := Get()
if currentCfg.Logging.Level != initialLevel {
t.Errorf("expected config to remain unchanged after invalid update, got logging.level %s instead of %s", currentCfg.Logging.Level, initialLevel)
}
// Restore valid config
if err := os.WriteFile(configFile, []byte(validContent), 0644); err != nil {
t.Fatalf("failed to restore valid config: %v", err)
}
time.Sleep(500 * time.Millisecond)
// Now write config with validation error (timeout out of range)
invalidValidationContent := `
server:
address: ":3000"
read_timeout: "1s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "debug"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(invalidValidationContent), 0644); err != nil {
t.Fatalf("failed to write config with validation error: %v", err)
}
// Wait for watcher to detect changes
time.Sleep(2 * time.Second)
// Verify config was NOT changed (validation should have failed)
finalCfg := Get()
if finalCfg.Logging.Level != initialLevel {
t.Errorf("expected config to remain unchanged after validation error, got logging.level %s instead of %s", finalCfg.Logging.Level, initialLevel)
}
if finalCfg.Server.ReadTimeout == 1*time.Second {
t.Error("expected config to remain unchanged, but read_timeout was updated to invalid value")
}
// Cancel context
cancel()
time.Sleep(100 * time.Millisecond)
}
// TestWatch_ContextCancellation tests graceful shutdown on context cancellation
func TestWatch_ContextCancellation(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
os.Setenv(constants.EnvConfigPath, configFile)
defer os.Unsetenv(constants.EnvConfigPath)
// Load config
_, err := Load()
if err != nil {
t.Fatalf("failed to load config: %v", err)
}
// Create logger
logger := zap.NewNop() // Use no-op logger for this test
// Start watcher with context
ctx, cancel := context.WithCancel(context.Background())
done := make(chan bool)
go func() {
Watch(ctx, logger)
done <- true
}()
// Give watcher time to start
time.Sleep(100 * time.Millisecond)
// Cancel context (simulate graceful shutdown)
cancel()
// Wait for watcher to stop (should happen quickly)
select {
case <-done:
// Watcher stopped successfully
case <-time.After(2 * time.Second):
t.Error("watcher did not stop within timeout after context cancellation")
}
}

518
pkg/logger/logger_test.go Normal file
View File

@@ -0,0 +1,518 @@
package logger
import (
"os"
"path/filepath"
"testing"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// TestInitLoggers 测试日志初始化T026
func TestInitLoggers(t *testing.T) {
// 创建临时目录用于日志文件
tempDir := t.TempDir()
tests := []struct {
name string
level string
development bool
appLogConfig LogRotationConfig
accessLogConfig LogRotationConfig
wantErr bool
validateFunc func(t *testing.T)
}{
{
name: "production mode with info level",
level: "info",
development: false,
appLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "app-prod.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
accessLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "access-prod.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
wantErr: false,
validateFunc: func(t *testing.T) {
if appLogger == nil {
t.Error("appLogger should not be nil")
}
if accessLogger == nil {
t.Error("accessLogger should not be nil")
}
// 写入一条日志以触发文件创建
GetAppLogger().Info("test log creation")
Sync()
// 验证日志文件创建
if _, err := os.Stat(filepath.Join(tempDir, "app-prod.log")); os.IsNotExist(err) {
t.Error("app log file should be created after writing")
}
},
},
{
name: "development mode with debug level",
level: "debug",
development: true,
appLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "app-dev.log"),
MaxSize: 5,
MaxBackups: 2,
MaxAge: 3,
Compress: false,
},
accessLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "access-dev.log"),
MaxSize: 5,
MaxBackups: 2,
MaxAge: 3,
Compress: false,
},
wantErr: false,
validateFunc: func(t *testing.T) {
if appLogger == nil {
t.Error("appLogger should not be nil in dev mode")
}
if accessLogger == nil {
t.Error("accessLogger should not be nil in dev mode")
}
},
},
{
name: "warn level logging",
level: "warn",
development: false,
appLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "app-warn.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
accessLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "access-warn.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
wantErr: false,
validateFunc: func(t *testing.T) {
if appLogger == nil {
t.Error("appLogger should not be nil")
}
},
},
{
name: "error level logging",
level: "error",
development: false,
appLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "app-error.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
accessLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "access-error.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
wantErr: false,
validateFunc: func(t *testing.T) {
if appLogger == nil {
t.Error("appLogger should not be nil")
}
},
},
{
name: "invalid level defaults to info",
level: "invalid",
development: false,
appLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "app-invalid.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
accessLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "access-invalid.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
wantErr: false,
validateFunc: func(t *testing.T) {
if appLogger == nil {
t.Error("appLogger should not be nil even with invalid level")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := InitLoggers(tt.level, tt.development, tt.appLogConfig, tt.accessLogConfig)
if (err != nil) != tt.wantErr {
t.Errorf("InitLoggers() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.validateFunc != nil {
tt.validateFunc(t)
}
})
}
}
// TestGetAppLogger 测试获取应用日志记录器T026
func TestGetAppLogger(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
tests := []struct {
name string
setupFunc func()
wantNil bool
}{
{
name: "after initialization",
setupFunc: func() {
InitLoggers("info", false,
LogRotationConfig{
Filename: filepath.Join(tempDir, "app-get.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-get.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
)
},
wantNil: false,
},
{
name: "before initialization returns nop logger",
setupFunc: func() {
// 重置全局变量
appLogger = nil
},
wantNil: false, // GetAppLogger 应该返回 nop logger不是 nil
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupFunc()
logger := GetAppLogger()
if logger == nil {
t.Error("GetAppLogger() should never return nil, should return nop logger instead")
}
})
}
}
// TestGetAccessLogger 测试获取访问日志记录器T028
func TestGetAccessLogger(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
tests := []struct {
name string
setupFunc func()
wantNil bool
}{
{
name: "after initialization",
setupFunc: func() {
InitLoggers("info", false,
LogRotationConfig{
Filename: filepath.Join(tempDir, "app-access.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-access.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
)
},
wantNil: false,
},
{
name: "before initialization returns nop logger",
setupFunc: func() {
// 重置全局变量
accessLogger = nil
},
wantNil: false, // GetAccessLogger 应该返回 nop logger不是 nil
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupFunc()
logger := GetAccessLogger()
if logger == nil {
t.Error("GetAccessLogger() should never return nil, should return nop logger instead")
}
})
}
}
// TestSync 测试日志缓冲区刷新T028
func TestSync(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
tests := []struct {
name string
setupFunc func()
wantErr bool
}{
{
name: "sync after initialization",
setupFunc: func() {
InitLoggers("info", false,
LogRotationConfig{
Filename: filepath.Join(tempDir, "app-sync.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-sync.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
)
},
wantErr: false,
},
{
name: "sync before initialization",
setupFunc: func() {
appLogger = nil
accessLogger = nil
},
wantErr: false, // 应该优雅地处理 nil 情况
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupFunc()
err := Sync()
if (err != nil) != tt.wantErr {
t.Errorf("Sync() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
// TestParseLevel 测试日志级别解析T026
func TestParseLevel(t *testing.T) {
tests := []struct {
name string
level string
want zapcore.Level
}{
{
name: "debug level",
level: "debug",
want: zapcore.DebugLevel,
},
{
name: "info level",
level: "info",
want: zapcore.InfoLevel,
},
{
name: "warn level",
level: "warn",
want: zapcore.WarnLevel,
},
{
name: "error level",
level: "error",
want: zapcore.ErrorLevel,
},
{
name: "invalid level defaults to info",
level: "invalid",
want: zapcore.InfoLevel,
},
{
name: "empty level defaults to info",
level: "",
want: zapcore.InfoLevel,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := parseLevel(tt.level)
if got != tt.want {
t.Errorf("parseLevel() = %v, want %v", got, tt.want)
}
})
}
}
// TestDualLoggerSystem 测试双日志系统T028
func TestDualLoggerSystem(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
appLogFile := filepath.Join(tempDir, "app-dual.log")
accessLogFile := filepath.Join(tempDir, "access-dual.log")
// 初始化双日志系统
err := InitLoggers("info", false,
LogRotationConfig{
Filename: appLogFile,
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false, // 不压缩以便检查内容
},
LogRotationConfig{
Filename: accessLogFile,
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("InitLoggers failed: %v", err)
}
// 写入应用日志
appLog := GetAppLogger()
appLog.Info("test app log message",
zap.String("module", "test"),
zap.Int("code", 200),
)
// 写入访问日志
accessLog := GetAccessLogger()
accessLog.Info("test access log message",
zap.String("method", "GET"),
zap.String("path", "/api/test"),
zap.Int("status", 200),
zap.Duration("latency", 100),
)
// 刷新缓冲区
if err := Sync(); err != nil {
t.Fatalf("Sync failed: %v", err)
}
// 验证应用日志文件存在并有内容
appLogContent, err := os.ReadFile(appLogFile)
if err != nil {
t.Fatalf("Failed to read app log file: %v", err)
}
if len(appLogContent) == 0 {
t.Error("App log file should not be empty")
}
// 验证访问日志文件存在并有内容
accessLogContent, err := os.ReadFile(accessLogFile)
if err != nil {
t.Fatalf("Failed to read access log file: %v", err)
}
if len(accessLogContent) == 0 {
t.Error("Access log file should not be empty")
}
// 验证两个日志文件是独立的
if string(appLogContent) == string(accessLogContent) {
t.Error("App log and access log should have different content")
}
}
// TestLoggerReinitialization 测试日志重新初始化T026
func TestLoggerReinitialization(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
// 第一次初始化
err := InitLoggers("info", false,
LogRotationConfig{
Filename: filepath.Join(tempDir, "app-reinit1.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-reinit1.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
)
if err != nil {
t.Fatalf("First InitLoggers failed: %v", err)
}
firstAppLogger := GetAppLogger()
// 第二次初始化(重新初始化)
err = InitLoggers("debug", true,
LogRotationConfig{
Filename: filepath.Join(tempDir, "app-reinit2.log"),
MaxSize: 5,
MaxBackups: 2,
MaxAge: 3,
Compress: false,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-reinit2.log"),
MaxSize: 5,
MaxBackups: 2,
MaxAge: 3,
Compress: false,
},
)
if err != nil {
t.Fatalf("Second InitLoggers failed: %v", err)
}
secondAppLogger := GetAppLogger()
// 验证重新初始化后日志记录器已更新
if firstAppLogger == secondAppLogger {
t.Error("Logger should be replaced after reinitialization")
}
}

388
pkg/logger/rotation_test.go Normal file
View File

@@ -0,0 +1,388 @@
package logger
import (
"os"
"path/filepath"
"strings"
"testing"
"time"
"go.uber.org/zap"
)
// TestLogRotation 测试日志轮转功能T027
func TestLogRotation(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
appLogFile := filepath.Join(tempDir, "app-rotation.log")
// 初始化日志系统,设置较小的 MaxSize 以便测试
err := InitLoggers("info", false,
LogRotationConfig{
Filename: appLogFile,
MaxSize: 1, // 1MB写入足够数据后会触发轮转
MaxBackups: 3,
MaxAge: 7,
Compress: false, // 不压缩以便检查
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-rotation.log"),
MaxSize: 1,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("InitLoggers failed: %v", err)
}
logger := GetAppLogger()
// 写入大量日志数据以触发轮转每条约100字节写入15000条约1.5MB
largeMessage := strings.Repeat("a", 100)
for i := 0; i < 15000; i++ {
logger.Info(largeMessage,
zap.Int("iteration", i),
zap.String("data", largeMessage),
)
}
// 刷新缓冲区
if err := Sync(); err != nil {
t.Fatalf("Sync failed: %v", err)
}
// 等待一小段时间确保文件写入完成
time.Sleep(100 * time.Millisecond)
// 验证主日志文件存在
if _, err := os.Stat(appLogFile); os.IsNotExist(err) {
t.Error("Main log file should exist")
}
// 检查是否有备份文件(轮转后的文件)
files, err := filepath.Glob(filepath.Join(tempDir, "app-rotation-*.log"))
if err != nil {
t.Fatalf("Failed to glob backup files: %v", err)
}
// 由于写入了超过1MB的数据应该触发至少一次轮转
if len(files) == 0 {
// 可能系统写入速度或lumberjack行为导致未立即轮转检查主文件大小
info, err := os.Stat(appLogFile)
if err != nil {
t.Fatalf("Failed to stat main log file: %v", err)
}
if info.Size() == 0 {
t.Error("Log file should have content")
}
// 不强制要求必须轮转,因为取决于具体实现
t.Logf("No rotation occurred, but main log file size: %d bytes", info.Size())
} else {
t.Logf("Found %d rotated backup file(s)", len(files))
}
}
// TestMaxBackups 测试最大备份数限制T027
func TestMaxBackups(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
appLogFile := filepath.Join(tempDir, "app-backups.log")
// 初始化日志系统,设置 MaxBackups=2
err := InitLoggers("info", false,
LogRotationConfig{
Filename: appLogFile,
MaxSize: 1, // 1MB
MaxBackups: 2, // 最多保留2个备份
MaxAge: 7,
Compress: false,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-backups.log"),
MaxSize: 1,
MaxBackups: 2,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("InitLoggers failed: %v", err)
}
logger := GetAppLogger()
// 写入足够的数据触发多次轮转每次1.5MB共4.5MB应该触发3次轮转
largeMessage := strings.Repeat("b", 100)
for round := 0; round < 3; round++ {
for i := 0; i < 15000; i++ {
logger.Info(largeMessage,
zap.Int("round", round),
zap.Int("iteration", i),
)
}
Sync()
time.Sleep(100 * time.Millisecond)
}
// 等待轮转完成
time.Sleep(200 * time.Millisecond)
// 检查备份文件数量
files, err := filepath.Glob(filepath.Join(tempDir, "app-backups-*.log"))
if err != nil {
t.Fatalf("Failed to glob backup files: %v", err)
}
// 由于 MaxBackups=2即使触发了多次轮转也只应保留最多2个备份文件
// (实际行为取决于 lumberjack 的实现细节可能小于等于2
if len(files) > 2 {
t.Errorf("Expected at most 2 backup files due to MaxBackups=2, got %d", len(files))
}
t.Logf("Found %d backup file(s) with MaxBackups=2", len(files))
}
// TestCompressionConfig 测试压缩配置T027
func TestCompressionConfig(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
tests := []struct {
name string
compress bool
}{
{
name: "compression enabled",
compress: true,
},
{
name: "compression disabled",
compress: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logFile := filepath.Join(tempDir, "app-"+tt.name+".log")
err := InitLoggers("info", false,
LogRotationConfig{
Filename: logFile,
MaxSize: 1,
MaxBackups: 3,
MaxAge: 7,
Compress: tt.compress,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-"+tt.name+".log"),
MaxSize: 1,
MaxBackups: 3,
MaxAge: 7,
Compress: tt.compress,
},
)
if err != nil {
t.Fatalf("InitLoggers failed: %v", err)
}
logger := GetAppLogger()
// 写入一些日志
for i := 0; i < 1000; i++ {
logger.Info("test compression",
zap.Int("id", i),
zap.String("data", strings.Repeat("c", 50)),
)
}
Sync()
time.Sleep(100 * time.Millisecond)
// 验证日志文件存在
if _, err := os.Stat(logFile); os.IsNotExist(err) {
t.Error("Log file should exist")
}
})
}
}
// TestMaxAge 测试日志文件保留时间T027
func TestMaxAge(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
// 初始化日志系统,设置 MaxAge=1 天
err := InitLoggers("info", false,
LogRotationConfig{
Filename: filepath.Join(tempDir, "app-maxage.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 1, // 1天
Compress: false,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-maxage.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 1,
Compress: false,
},
)
if err != nil {
t.Fatalf("InitLoggers failed: %v", err)
}
logger := GetAppLogger()
// 写入日志
logger.Info("test max age", zap.String("config", "maxage=1"))
Sync()
// 验证配置已应用无法在单元测试中验证实际的清理行为因为需要等待1天
// 这里只验证初始化没有错误
if logger == nil {
t.Error("Logger should be initialized with MaxAge config")
}
}
// TestNewLumberjackLogger 测试 Lumberjack logger 创建T027
func TestNewLumberjackLogger(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
tests := []struct {
name string
config LogRotationConfig
}{
{
name: "standard config",
config: LogRotationConfig{
Filename: filepath.Join(tempDir, "test1.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
},
{
name: "minimal config",
config: LogRotationConfig{
Filename: filepath.Join(tempDir, "test2.log"),
MaxSize: 1,
MaxBackups: 1,
MaxAge: 1,
Compress: false,
},
},
{
name: "large config",
config: LogRotationConfig{
Filename: filepath.Join(tempDir, "test3.log"),
MaxSize: 100,
MaxBackups: 10,
MaxAge: 30,
Compress: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logger := newLumberjackLogger(tt.config)
if logger == nil {
t.Error("newLumberjackLogger should not return nil")
}
// 验证配置已正确设置
if logger.Filename != tt.config.Filename {
t.Errorf("Filename = %v, want %v", logger.Filename, tt.config.Filename)
}
if logger.MaxSize != tt.config.MaxSize {
t.Errorf("MaxSize = %v, want %v", logger.MaxSize, tt.config.MaxSize)
}
if logger.MaxBackups != tt.config.MaxBackups {
t.Errorf("MaxBackups = %v, want %v", logger.MaxBackups, tt.config.MaxBackups)
}
if logger.MaxAge != tt.config.MaxAge {
t.Errorf("MaxAge = %v, want %v", logger.MaxAge, tt.config.MaxAge)
}
if logger.Compress != tt.config.Compress {
t.Errorf("Compress = %v, want %v", logger.Compress, tt.config.Compress)
}
if !logger.LocalTime {
t.Error("LocalTime should be true")
}
})
}
}
// TestConcurrentLogging 测试并发日志写入T027
func TestConcurrentLogging(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
// 初始化日志系统
err := InitLoggers("info", false,
LogRotationConfig{
Filename: filepath.Join(tempDir, "app-concurrent.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-concurrent.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("InitLoggers failed: %v", err)
}
logger := GetAppLogger()
// 启动多个 goroutine 并发写入日志
done := make(chan bool)
goroutines := 10
messagesPerGoroutine := 100
for i := 0; i < goroutines; i++ {
go func(id int) {
for j := 0; j < messagesPerGoroutine; j++ {
logger.Info("concurrent log message",
zap.Int("goroutine", id),
zap.Int("message", j),
)
}
done <- true
}(i)
}
// 等待所有 goroutine 完成
for i := 0; i < goroutines; i++ {
<-done
}
// 刷新缓冲区
if err := Sync(); err != nil {
t.Fatalf("Sync failed: %v", err)
}
// 验证日志文件存在且有内容
logFile := filepath.Join(tempDir, "app-concurrent.log")
info, err := os.Stat(logFile)
if err != nil {
t.Fatalf("Failed to stat log file: %v", err)
}
if info.Size() == 0 {
t.Error("Log file should have content after concurrent writes")
}
t.Logf("Concurrent logging test completed, log file size: %d bytes", info.Size())
}

View File

@@ -0,0 +1,477 @@
package response
import (
"encoding/json"
"io"
"net/http/httptest"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/gofiber/fiber/v2"
)
// TestSuccess 测试成功响应T034
func TestSuccess(t *testing.T) {
tests := []struct {
name string
data any
}{
{
name: "success with string data",
data: "test data",
},
{
name: "success with map data",
data: map[string]any{
"id": 123,
"name": "test",
},
},
{
name: "success with slice data",
data: []string{"item1", "item2", "item3"},
},
{
name: "success with struct data",
data: struct {
ID int `json:"id"`
Name string `json:"name"`
}{
ID: 456,
Name: "test struct",
},
},
{
name: "success with nil data",
data: nil,
},
{
name: "success with empty map",
data: map[string]any{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return Success(c, tt.data)
})
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
defer resp.Body.Close()
// 验证 HTTP 状态码
if resp.StatusCode != 200 {
t.Errorf("Expected status code 200, got %d", resp.StatusCode)
}
// 验证响应头
if resp.Header.Get("Content-Type") != "application/json" {
t.Errorf("Expected Content-Type application/json, got %s", resp.Header.Get("Content-Type"))
}
// 解析响应体
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
var response Response
if err := json.Unmarshal(body, &response); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
// 验证响应结构
if response.Code != errors.CodeSuccess {
t.Errorf("Expected code %d, got %d", errors.CodeSuccess, response.Code)
}
if response.Message != "success" {
t.Errorf("Expected message 'success', got '%s'", response.Message)
}
// 验证时间戳格式 RFC3339
if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil {
t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp)
}
// 验证数据字段(如果不是 nil
if tt.data != nil {
if response.Data == nil {
t.Error("Expected data field to be non-nil")
}
}
})
}
}
// TestError 测试错误响应T035
func TestError(t *testing.T) {
tests := []struct {
name string
httpStatus int
code int
message string
}{
{
name: "internal server error",
httpStatus: 500,
code: errors.CodeInternalError,
message: "Internal server error occurred",
},
{
name: "missing token error",
httpStatus: 401,
code: errors.CodeMissingToken,
message: "Authentication token is missing",
},
{
name: "invalid token error",
httpStatus: 401,
code: errors.CodeInvalidToken,
message: "Token is invalid or expired",
},
{
name: "rate limit error",
httpStatus: 429,
code: errors.CodeTooManyRequests,
message: "Too many requests, please try again later",
},
{
name: "service unavailable error",
httpStatus: 503,
code: errors.CodeAuthServiceUnavailable,
message: "Authentication service is currently unavailable",
},
{
name: "bad request error",
httpStatus: 400,
code: 2000,
message: "Invalid request parameters",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return Error(c, tt.httpStatus, tt.code, tt.message)
})
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
defer resp.Body.Close()
// 验证 HTTP 状态码
if resp.StatusCode != tt.httpStatus {
t.Errorf("Expected status code %d, got %d", tt.httpStatus, resp.StatusCode)
}
// 验证响应头
if resp.Header.Get("Content-Type") != "application/json" {
t.Errorf("Expected Content-Type application/json, got %s", resp.Header.Get("Content-Type"))
}
// 解析响应体
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
var response Response
if err := json.Unmarshal(body, &response); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
// 验证响应结构
if response.Code != tt.code {
t.Errorf("Expected code %d, got %d", tt.code, response.Code)
}
if response.Message != tt.message {
t.Errorf("Expected message '%s', got '%s'", tt.message, response.Message)
}
if response.Data != nil {
t.Errorf("Expected data to be nil in error response, got %v", response.Data)
}
// 验证时间戳格式 RFC3339
if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil {
t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp)
}
})
}
}
// TestSuccessWithMessage 测试带自定义消息的成功响应T034
func TestSuccessWithMessage(t *testing.T) {
tests := []struct {
name string
data any
message string
}{
{
name: "custom success message",
data: map[string]any{
"user_id": 123,
},
message: "User created successfully",
},
{
name: "empty custom message",
data: "test data",
message: "",
},
{
name: "chinese message",
data: map[string]string{
"status": "ok",
},
message: "操作成功",
},
{
name: "long message",
data: nil,
message: "This is a very long success message that describes in detail what happened during the operation",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return SuccessWithMessage(c, tt.data, tt.message)
})
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
defer resp.Body.Close()
// 验证 HTTP 状态码(默认 200
if resp.StatusCode != 200 {
t.Errorf("Expected status code 200, got %d", resp.StatusCode)
}
// 解析响应体
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
var response Response
if err := json.Unmarshal(body, &response); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
// 验证响应结构
if response.Code != errors.CodeSuccess {
t.Errorf("Expected code %d, got %d", errors.CodeSuccess, response.Code)
}
if response.Message != tt.message {
t.Errorf("Expected message '%s', got '%s'", tt.message, response.Message)
}
// 验证时间戳格式 RFC3339
if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil {
t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp)
}
})
}
}
// TestResponseSerialization 测试响应序列化T036
func TestResponseSerialization(t *testing.T) {
tests := []struct {
name string
response Response
}{
{
name: "complete response",
response: Response{
Code: 0,
Data: map[string]any{"key": "value"},
Message: "success",
Timestamp: time.Now().Format(time.RFC3339),
},
},
{
name: "response with nil data",
response: Response{
Code: 1000,
Data: nil,
Message: "error",
Timestamp: time.Now().Format(time.RFC3339),
},
},
{
name: "response with nested data",
response: Response{
Code: 0,
Data: map[string]any{
"user": map[string]any{
"id": 123,
"name": "test",
"tags": []string{"tag1", "tag2"},
},
},
Message: "success",
Timestamp: time.Now().Format(time.RFC3339),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 序列化
data, err := json.Marshal(tt.response)
if err != nil {
t.Fatalf("Failed to marshal response: %v", err)
}
// 反序列化
var deserialized Response
if err := json.Unmarshal(data, &deserialized); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
// 验证字段
if deserialized.Code != tt.response.Code {
t.Errorf("Code mismatch: expected %d, got %d", tt.response.Code, deserialized.Code)
}
if deserialized.Message != tt.response.Message {
t.Errorf("Message mismatch: expected '%s', got '%s'", tt.response.Message, deserialized.Message)
}
if deserialized.Timestamp != tt.response.Timestamp {
t.Errorf("Timestamp mismatch: expected '%s', got '%s'", tt.response.Timestamp, deserialized.Timestamp)
}
})
}
}
// TestResponseStructFields 测试响应结构字段T036
func TestResponseStructFields(t *testing.T) {
response := Response{
Code: 0,
Data: "test",
Message: "success",
Timestamp: time.Now().Format(time.RFC3339),
}
data, err := json.Marshal(response)
if err != nil {
t.Fatalf("Failed to marshal response: %v", err)
}
// 解析为 map 以检查 JSON 键
var jsonMap map[string]any
if err := json.Unmarshal(data, &jsonMap); err != nil {
t.Fatalf("Failed to unmarshal to map: %v", err)
}
// 验证所有必需字段都存在
requiredFields := []string{"code", "data", "msg", "timestamp"}
for _, field := range requiredFields {
if _, exists := jsonMap[field]; !exists {
t.Errorf("Required field '%s' is missing in JSON response", field)
}
}
// 验证字段类型
if _, ok := jsonMap["code"].(float64); !ok {
t.Error("Field 'code' should be a number")
}
if _, ok := jsonMap["msg"].(string); !ok {
t.Error("Field 'msg' should be a string")
}
if _, ok := jsonMap["timestamp"].(string); !ok {
t.Error("Field 'timestamp' should be a string")
}
}
// TestMultipleResponses 测试多个连续响应T036
func TestMultipleResponses(t *testing.T) {
app := fiber.New()
callCount := 0
app.Get("/test", func(c *fiber.Ctx) error {
callCount++
if callCount%2 == 0 {
return Success(c, map[string]int{"count": callCount})
}
return Error(c, 500, errors.CodeInternalError, "error occurred")
})
// 发送多个请求
for i := 1; i <= 5; i++ {
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Request %d failed: %v", i, err)
}
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
var response Response
if err := json.Unmarshal(body, &response); err != nil {
t.Fatalf("Request %d: failed to unmarshal response: %v", i, err)
}
// 验证每个响应都有时间戳
if response.Timestamp == "" {
t.Errorf("Request %d: timestamp should not be empty", i)
}
}
}
// TestTimestampFormat 测试时间戳格式T036
func TestTimestampFormat(t *testing.T) {
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return Success(c, nil)
})
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
var response Response
if err := json.Unmarshal(body, &response); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
// 验证是 RFC3339 格式
parsedTime, err := time.Parse(time.RFC3339, response.Timestamp)
if err != nil {
t.Fatalf("Timestamp is not in RFC3339 format: %s, error: %v", response.Timestamp, err)
}
// 验证时间戳是最近的(应该在最近 1 秒内)
now := time.Now()
diff := now.Sub(parsedTime)
if diff < 0 || diff > time.Second {
t.Errorf("Timestamp seems incorrect: %s (diff from now: %v)", response.Timestamp, diff)
}
}

View File

@@ -68,17 +68,17 @@
### Unit Tests for User Story 1 ### Unit Tests for User Story 1
- [ ] T018 [P] [US1] Unit test for config loading and validation in pkg/config/loader_test.go - [X] T018 [P] [US1] Unit test for config loading and validation in pkg/config/loader_test.go
- [ ] T019 [P] [US1] Unit test for config hot reload mechanism in pkg/config/watcher_test.go - [X] T019 [P] [US1] Unit test for config hot reload mechanism in pkg/config/watcher_test.go
- [ ] T020 [P] [US1] Test invalid config handling (malformed YAML, validation errors) in pkg/config/config_test.go - [X] T020 [P] [US1] Test invalid config handling (malformed YAML, validation errors) in pkg/config/config_test.go
### Implementation for User Story 1 ### Implementation for User Story 1
- [ ] T021 [US1] Implement atomic config pointer swap in pkg/config/config.go (sync/atomic usage) - [X] T021 [US1] Implement atomic config pointer swap in pkg/config/config.go (sync/atomic usage)
- [ ] T022 [US1] Implement config change callback with validation in pkg/config/watcher.go - [X] T022 [US1] Implement config change callback with validation in pkg/config/watcher.go
- [ ] T023 [US1] Add config reload logging with Zap in pkg/config/watcher.go - [X] T023 [US1] Add config reload logging with Zap in pkg/config/watcher.go
- [ ] T024 [US1] Integrate config watcher with context cancellation in cmd/api/main.go - [X] T024 [US1] Integrate config watcher with context cancellation in cmd/api/main.go
- [ ] T025 [US1] Add graceful shutdown for config watcher in cmd/api/main.go - [X] T025 [US1] Add graceful shutdown for config watcher in cmd/api/main.go
**Checkpoint**: Config hot reload should work independently - modify config and see changes applied **Checkpoint**: Config hot reload should work independently - modify config and see changes applied
@@ -92,17 +92,17 @@
### Unit Tests for User Story 2 ### Unit Tests for User Story 2
- [ ] T026 [P] [US2] Unit test for logger initialization in pkg/logger/logger_test.go - [X] T026 [P] [US2] Unit test for logger initialization in pkg/logger/logger_test.go
- [ ] T027 [P] [US2] Unit test for log rotation configuration in pkg/logger/rotation_test.go - [X] T027 [P] [US2] Unit test for log rotation configuration in pkg/logger/rotation_test.go
- [ ] T028 [P] [US2] Test structured logging with fields in pkg/logger/logger_test.go - [X] T028 [P] [US2] Test structured logging with fields in pkg/logger/logger_test.go
### Implementation for User Story 2 ### Implementation for User Story 2
- [ ] T029 [P] [US2] Create appLogger instance with Lumberjack writer in pkg/logger/logger.go - [X] T029 [P] [US2] Create appLogger instance with Lumberjack writer in pkg/logger/logger.go
- [ ] T030 [P] [US2] Create accessLogger instance with separate Lumberjack writer in pkg/logger/logger.go - [X] T030 [P] [US2] Create accessLogger instance with separate Lumberjack writer in pkg/logger/logger.go
- [ ] T031 [US2] Configure JSON encoder with RFC3339 timestamps in pkg/logger/logger.go - [X] T031 [US2] Configure JSON encoder with RFC3339 timestamps in pkg/logger/logger.go
- [ ] T032 [US2] Export GetAppLogger() and GetAccessLogger() functions in pkg/logger/logger.go - [X] T032 [US2] Export GetAppLogger() and GetAccessLogger() functions in pkg/logger/logger.go
- [ ] T033 [US2] Add logger.Sync() call in graceful shutdown in cmd/api/main.go - [X] T033 [US2] Add logger.Sync() call in graceful shutdown in cmd/api/main.go
**Checkpoint**: Both app.log and access.log should exist with JSON entries, rotate at configured size **Checkpoint**: Both app.log and access.log should exist with JSON entries, rotate at configured size
@@ -116,18 +116,18 @@
### Unit Tests for User Story 3 ### Unit Tests for User Story 3
- [ ] T034 [P] [US3] Unit test for Success() response helper in pkg/response/response_test.go - [X] T034 [P] [US3] Unit test for Success() response helper in pkg/response/response_test.go
- [ ] T035 [P] [US3] Unit test for Error() response helper in pkg/response/response_test.go - [X] T035 [P] [US3] Unit test for Error() response helper in pkg/response/response_test.go
- [ ] T036 [P] [US3] Test response serialization with sonic JSON in pkg/response/response_test.go - [X] T036 [P] [US3] Test response serialization with sonic JSON in pkg/response/response_test.go
### Implementation for User Story 3 ### Implementation for User Story 3
- [ ] T037 [P] [US3] Implement Success() helper function in pkg/response/response.go - [X] T037 [P] [US3] Implement Success() helper function in pkg/response/response.go
- [ ] T038 [P] [US3] Implement Error() helper function in pkg/response/response.go - [X] T038 [P] [US3] Implement Error() helper function in pkg/response/response.go
- [ ] T039 [P] [US3] Implement SuccessWithMessage() helper function in pkg/response/response.go - [X] T039 [P] [US3] Implement SuccessWithMessage() helper function in pkg/response/response.go
- [ ] T040 [US3] Configure Fiber to use sonic as JSON serializer in cmd/api/main.go - [X] T040 [US3] Configure Fiber to use sonic as JSON serializer in cmd/api/main.go
- [ ] T041 [US3] Create example health check endpoint using response helpers in internal/handler/health.go - [X] T041 [US3] Create example health check endpoint using response helpers in internal/handler/health.go
- [ ] T042 [US3] Register health check route in cmd/api/main.go - [X] T042 [US3] Register health check route in cmd/api/main.go
**Checkpoint**: Health check endpoint returns unified response format with proper structure **Checkpoint**: Health check endpoint returns unified response format with proper structure
@@ -141,18 +141,18 @@
### Integration Tests for User Story 4 ### Integration Tests for User Story 4
- [ ] T043 [P] [US4] Integration test for requestid middleware (UUID v4 generation) in tests/integration/middleware_test.go - [X] T043 [P] [US4] Integration test for requestid middleware (UUID v4 generation) in tests/integration/middleware_test.go
- [ ] T044 [P] [US4] Integration test for logger middleware (access log entries) in tests/integration/middleware_test.go - [X] T044 [P] [US4] Integration test for logger middleware (access log entries) in tests/integration/middleware_test.go
- [ ] T045 [P] [US4] Test request ID propagation through middleware chain in tests/integration/middleware_test.go - [X] T045 [P] [US4] Test request ID propagation through middleware chain in tests/integration/middleware_test.go
### Implementation for User Story 4 ### Implementation for User Story 4
- [ ] T046 [P] [US4] Configure Fiber requestid middleware with google/uuid in cmd/api/main.go - [X] T046 [P] [US4] Configure Fiber requestid middleware with google/uuid in cmd/api/main.go
- [ ] T047 [US4] Implement custom logger middleware writing to accessLogger in internal/middleware/logger.go - [X] T047 [US4] Implement custom logger middleware writing to accessLogger in internal/middleware/logger.go
- [ ] T048 [US4] Add request ID to Fiber Locals in logger middleware in internal/middleware/logger.go - [X] T048 [US4] Add request ID to Fiber Locals in logger middleware in internal/middleware/logger.go
- [ ] T049 [US4] Add X-Request-ID response header in logger middleware in internal/middleware/logger.go - [X] T049 [US4] Add X-Request-ID response header in logger middleware in internal/middleware/logger.go
- [ ] T050 [US4] Log request details (method, path, status, duration, IP, user_agent) to access.log in internal/middleware/logger.go - [X] T050 [US4] Log request details (method, path, status, duration, IP, user_agent) to access.log in internal/middleware/logger.go
- [ ] T051 [US4] Register requestid and logger middleware in correct order in cmd/api/main.go - [X] T051 [US4] Register requestid and logger middleware in correct order in cmd/api/main.go
**Checkpoint**: Every request should have unique UUID v4 in header and access.log, with full request details **Checkpoint**: Every request should have unique UUID v4 in header and access.log, with full request details
@@ -166,17 +166,17 @@
### Integration Tests for User Story 5 ### Integration Tests for User Story 5
- [ ] T052 [P] [US5] Integration test for panic recovery in tests/integration/middleware_test.go - [X] T052 [P] [US5] Integration test for panic recovery in tests/integration/recover_test.go
- [ ] T053 [P] [US5] Test panic logging with stack trace in tests/integration/middleware_test.go - [X] T053 [P] [US5] Test panic logging with stack trace in tests/integration/recover_test.go
- [ ] T054 [P] [US5] Test subsequent requests after panic recovery in tests/integration/middleware_test.go - [X] T054 [P] [US5] Test subsequent requests after panic recovery in tests/integration/recover_test.go
### Implementation for User Story 5 ### Implementation for User Story 5
- [ ] T055 [US5] Implement custom recover middleware with Zap logging in internal/middleware/recover.go - [X] T055 [US5] Implement custom recover middleware with Zap logging in internal/middleware/recover.go
- [ ] T056 [US5] Add stack trace capture to recover middleware in internal/middleware/recover.go - [X] T056 [US5] Add stack trace capture to recover middleware in internal/middleware/recover.go
- [ ] T057 [US5] Add request ID to panic logs in internal/middleware/recover.go - [X] T057 [US5] Add request ID to panic logs in internal/middleware/recover.go
- [ ] T058 [US5] Return unified error response (500, code 1000) on panic in internal/middleware/recover.go - [X] T058 [US5] Return unified error response (500, code 1000) on panic in internal/middleware/recover.go
- [ ] T059 [US5] Register recover middleware as FIRST middleware in cmd/api/main.go - [X] T059 [US5] Register recover middleware as FIRST middleware in cmd/api/main.go
- [ ] T060 [US5] Create test panic endpoint for testing in internal/handler/test.go (optional, for quickstart validation) - [ ] T060 [US5] Create test panic endpoint for testing in internal/handler/test.go (optional, for quickstart validation)
**Checkpoint**: Panic in handler should be caught, logged with stack trace, return 500, server continues running **Checkpoint**: Panic in handler should be caught, logged with stack trace, return 500, server continues running
@@ -205,19 +205,19 @@
### Implementation for User Story 6 ### Implementation for User Story 6
- [ ] T069 [US6] Create TokenValidator struct with Redis client in pkg/validator/token.go - [X] T069 [US6] Create TokenValidator struct with Redis client in pkg/validator/token.go
- [ ] T070 [US6] Implement TokenValidator.Validate() with Redis GET operation in pkg/validator/token.go - [X] T070 [US6] Implement TokenValidator.Validate() with Redis GET operation in pkg/validator/token.go
- [ ] T071 [US6] Add context timeout (50ms) for Redis operations in pkg/validator/token.go - [X] T071 [US6] Add context timeout (50ms) for Redis operations in pkg/validator/token.go
- [ ] T072 [US6] Implement Redis availability check (Ping) with fail-closed behavior in pkg/validator/token.go - [X] T072 [US6] Implement Redis availability check (Ping) with fail-closed behavior in pkg/validator/token.go
- [ ] T073 [US6] Implement custom keyauth middleware wrapper in internal/middleware/auth.go - [X] T073 [US6] Implement custom keyauth middleware wrapper in internal/middleware/auth.go
- [ ] T074 [US6] Configure keyauth with header lookup "token" in internal/middleware/auth.go - [X] T074 [US6] Configure keyauth with header lookup "token" in internal/middleware/auth.go
- [ ] T075 [US6] Add validator callback to keyauth config in internal/middleware/auth.go - [X] T075 [US6] Add validator callback to keyauth config in internal/middleware/auth.go
- [ ] T076 [US6] Store user_id in Fiber Locals after successful validation in internal/middleware/auth.go - [X] T076 [US6] Store user_id in Fiber Locals after successful validation in internal/middleware/auth.go
- [ ] T077 [US6] Implement custom ErrorHandler mapping errors to response codes in internal/middleware/auth.go - [X] T077 [US6] Implement custom ErrorHandler mapping errors to response codes in internal/middleware/auth.go
- [ ] T078 [US6] Add auth failure logging with request ID in internal/middleware/auth.go - [X] T078 [US6] Add auth failure logging with request ID in internal/middleware/auth.go
- [ ] T079 [US6] Register keyauth middleware after logger in cmd/api/main.go - [X] T079 [US6] Register keyauth middleware after logger in cmd/api/main.go
- [ ] T080 [US6] Create protected example endpoint (/api/v1/users) in internal/handler/user.go - [X] T080 [US6] Create protected example endpoint (/api/v1/users) in internal/handler/user.go
- [ ] T081 [US6] Register protected routes with middleware in cmd/api/main.go - [X] T081 [US6] Register protected routes with middleware in cmd/api/main.go
**Checkpoint**: Protected endpoints require valid token, reject invalid/missing tokens with correct error codes **Checkpoint**: Protected endpoints require valid token, reject invalid/missing tokens with correct error codes
@@ -237,11 +237,11 @@
### Implementation for User Story 7 ### Implementation for User Story 7
- [ ] T085 [US7] Implement rate limiter middleware wrapper (COMMENTED by default) in internal/middleware/ratelimit.go - [X] T085 [US7] Implement rate limiter middleware wrapper (COMMENTED by default) in internal/middleware/ratelimit.go
- [ ] T086 [US7] Configure limiter with IP-based key generator (c.IP()) in internal/middleware/ratelimit.go - [X] T086 [US7] Configure limiter with IP-based key generator (c.IP()) in internal/middleware/ratelimit.go
- [ ] T087 [US7] Configure limiter with config values (Max, Expiration) in internal/middleware/ratelimit.go - [X] T087 [US7] Configure limiter with config values (Max, Expiration) in internal/middleware/ratelimit.go
- [ ] T088 [US7] Add custom LimitReached handler returning unified error response in internal/middleware/ratelimit.go - [X] T088 [US7] Add custom LimitReached handler returning unified error response in internal/middleware/ratelimit.go
- [ ] T089 [US7] Add commented middleware registration example in cmd/api/main.go - [X] T089 [US7] Add commented middleware registration example in cmd/api/main.go
- [ ] T090 [US7] Document rate limiter usage in quickstart.md (how to enable, configure) - [ ] T090 [US7] Document rate limiter usage in quickstart.md (how to enable, configure)
- [ ] T091 [US7] Add rate limiter configuration examples to config files - [ ] T091 [US7] Add rate limiter configuration examples to config files

View File

@@ -0,0 +1,533 @@
package integration
import (
"io"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/logger"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/requestid"
"github.com/google/uuid"
)
// TestRequestIDMiddleware 测试 RequestID 中间件生成 UUID v4T043
func TestRequestIDMiddleware(t *testing.T) {
app := fiber.New()
// 配置 requestid 中间件使用 UUID v4
app.Use(requestid.New(requestid.Config{
Generator: func() string {
return uuid.NewString()
},
}))
app.Get("/test", func(c *fiber.Ctx) error {
requestID := c.Locals(constants.ContextKeyRequestID)
return c.JSON(fiber.Map{
"request_id": requestID,
})
})
tests := []struct {
name string
}{
{name: "request 1"},
{name: "request 2"},
{name: "request 3"},
}
seenIDs := make(map[string]bool)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
defer resp.Body.Close()
// 验证响应头包含 X-Request-ID
requestID := resp.Header.Get("X-Request-ID")
if requestID == "" {
t.Error("X-Request-ID header should not be empty")
}
// 验证是 UUID v4 格式
if _, err := uuid.Parse(requestID); err != nil {
t.Errorf("X-Request-ID is not a valid UUID: %s, error: %v", requestID, err)
}
// 验证 UUID 是唯一的
if seenIDs[requestID] {
t.Errorf("Request ID %s is not unique", requestID)
}
seenIDs[requestID] = true
t.Logf("Request ID: %s", requestID)
})
}
// 验证生成了多个不同的 ID
if len(seenIDs) != len(tests) {
t.Errorf("Expected %d unique request IDs, got %d", len(tests), len(seenIDs))
}
}
// TestLoggerMiddleware 测试 Logger 中间件记录访问日志T044
func TestLoggerMiddleware(t *testing.T) {
// 创建临时目录用于日志
tempDir := t.TempDir()
accessLogFile := filepath.Join(tempDir, "access.log")
// 初始化日志系统
err := logger.InitLoggers("info", false,
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "app.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
logger.LogRotationConfig{
Filename: accessLogFile,
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("Failed to initialize loggers: %v", err)
}
defer logger.Sync()
// 创建应用
app := fiber.New()
// 注册中间件
app.Use(requestid.New(requestid.Config{
Generator: func() string {
return uuid.NewString()
},
}))
app.Use(logger.Middleware())
app.Get("/test", func(c *fiber.Ctx) error {
return c.SendString("ok")
})
app.Post("/test", func(c *fiber.Ctx) error {
return c.SendStatus(201)
})
tests := []struct {
name string
method string
path string
expectedStatus int
}{
{
name: "GET request",
method: "GET",
path: "/test",
expectedStatus: 200,
},
{
name: "POST request",
method: "POST",
path: "/test",
expectedStatus: 201,
},
{
name: "GET with query params",
method: "GET",
path: "/test?foo=bar",
expectedStatus: 200,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, tt.path, nil)
req.Header.Set("User-Agent", "test-agent/1.0")
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
resp.Body.Close()
if resp.StatusCode != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.StatusCode)
}
})
}
// 刷新日志缓冲区
logger.Sync()
time.Sleep(100 * time.Millisecond)
// 验证访问日志文件存在且有内容
content, err := os.ReadFile(accessLogFile)
if err != nil {
t.Fatalf("Failed to read access log: %v", err)
}
if len(content) == 0 {
t.Error("Access log should not be empty")
}
logContent := string(content)
t.Logf("Access log content:\n%s", logContent)
// 验证日志包含必要的字段
requiredFields := []string{
"method",
"path",
"status",
"duration_ms",
"request_id",
"ip",
"user_agent",
}
for _, field := range requiredFields {
if !strings.Contains(logContent, field) {
t.Errorf("Access log should contain field '%s'", field)
}
}
// 验证记录了所有请求
lines := strings.Split(strings.TrimSpace(logContent), "\n")
if len(lines) < len(tests) {
t.Errorf("Expected at least %d log entries, got %d", len(tests), len(lines))
}
}
// TestRequestIDPropagation 测试 Request ID 在中间件链中传播T045
func TestRequestIDPropagation(t *testing.T) {
// 创建临时目录用于日志
tempDir := t.TempDir()
// 初始化日志系统
err := logger.InitLoggers("info", false,
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "app.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "access.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("Failed to initialize loggers: %v", err)
}
defer logger.Sync()
// 创建应用
app := fiber.New()
var capturedRequestID string
// 1. RequestID 中间件(第一个)
app.Use(requestid.New(requestid.Config{
Generator: func() string {
return uuid.NewString()
},
}))
// 2. Logger 中间件(第二个)
app.Use(logger.Middleware())
// 3. 自定义中间件验证 request ID 是否可访问
app.Use(func(c *fiber.Ctx) error {
requestID := c.Locals(constants.ContextKeyRequestID)
if requestID == nil {
t.Error("Request ID should be available in middleware chain")
}
if rid, ok := requestID.(string); ok {
capturedRequestID = rid
}
return c.Next()
})
app.Get("/test", func(c *fiber.Ctx) error {
// 在 handler 中也验证 request ID
requestID := c.Locals(constants.ContextKeyRequestID)
if requestID == nil {
return c.Status(500).SendString("Request ID not found in handler")
}
return c.JSON(fiber.Map{
"request_id": requestID,
"message": "Request ID propagated successfully",
})
})
// 执行请求
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
defer resp.Body.Close()
// 验证响应
if resp.StatusCode != 200 {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
// 验证响应头中的 Request ID
headerRequestID := resp.Header.Get("X-Request-ID")
if headerRequestID == "" {
t.Error("X-Request-ID header should be set")
}
// 验证中间件捕获的 Request ID 与响应头一致
if capturedRequestID != headerRequestID {
t.Errorf("Request ID mismatch: middleware=%s, header=%s", capturedRequestID, headerRequestID)
}
// 验证响应体
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
if !strings.Contains(string(body), headerRequestID) {
t.Errorf("Response body should contain request ID %s", headerRequestID)
}
t.Logf("Request ID successfully propagated: %s", headerRequestID)
}
// TestMiddlewareOrder 测试中间件执行顺序T045
func TestMiddlewareOrder(t *testing.T) {
app := fiber.New()
executionOrder := []string{}
// 中间件 1: RequestID
app.Use(func(c *fiber.Ctx) error {
executionOrder = append(executionOrder, "requestid-start")
c.Locals(constants.ContextKeyRequestID, uuid.NewString())
err := c.Next()
executionOrder = append(executionOrder, "requestid-end")
return err
})
// 中间件 2: Logger
app.Use(func(c *fiber.Ctx) error {
executionOrder = append(executionOrder, "logger-start")
// 验证 Request ID 已经设置
if c.Locals(constants.ContextKeyRequestID) == nil {
t.Error("Request ID should be set before logger middleware")
}
err := c.Next()
executionOrder = append(executionOrder, "logger-end")
return err
})
// 中间件 3: Custom
app.Use(func(c *fiber.Ctx) error {
executionOrder = append(executionOrder, "custom-start")
err := c.Next()
executionOrder = append(executionOrder, "custom-end")
return err
})
app.Get("/test", func(c *fiber.Ctx) error {
executionOrder = append(executionOrder, "handler")
return c.SendString("ok")
})
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
resp.Body.Close()
// 验证执行顺序
expectedOrder := []string{
"requestid-start",
"logger-start",
"custom-start",
"handler",
"custom-end",
"logger-end",
"requestid-end",
}
if len(executionOrder) != len(expectedOrder) {
t.Errorf("Expected %d execution steps, got %d", len(expectedOrder), len(executionOrder))
}
for i, expected := range expectedOrder {
if i >= len(executionOrder) {
t.Errorf("Missing execution step at index %d: expected '%s'", i, expected)
continue
}
if executionOrder[i] != expected {
t.Errorf("Execution order mismatch at index %d: expected '%s', got '%s'", i, expected, executionOrder[i])
}
}
t.Logf("Middleware execution order: %v", executionOrder)
}
// TestLoggerMiddlewareWithUserID 测试 Logger 中间件记录用户 IDT044
func TestLoggerMiddlewareWithUserID(t *testing.T) {
// 创建临时目录用于日志
tempDir := t.TempDir()
accessLogFile := filepath.Join(tempDir, "access-userid.log")
// 初始化日志系统
err := logger.InitLoggers("info", false,
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "app.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
logger.LogRotationConfig{
Filename: accessLogFile,
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("Failed to initialize loggers: %v", err)
}
defer logger.Sync()
// 创建应用
app := fiber.New()
// 注册中间件
app.Use(requestid.New(requestid.Config{
Generator: func() string {
return uuid.NewString()
},
}))
// 模拟 auth 中间件设置 user_id
app.Use(func(c *fiber.Ctx) error {
c.Locals(constants.ContextKeyUserID, "user_12345")
return c.Next()
})
app.Use(logger.Middleware())
app.Get("/test", func(c *fiber.Ctx) error {
return c.SendString("ok")
})
// 执行请求
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
resp.Body.Close()
// 刷新日志缓冲区
logger.Sync()
time.Sleep(100 * time.Millisecond)
// 验证访问日志包含 user_id
content, err := os.ReadFile(accessLogFile)
if err != nil {
t.Fatalf("Failed to read access log: %v", err)
}
logContent := string(content)
if !strings.Contains(logContent, "user_12345") {
t.Error("Access log should contain user_id 'user_12345'")
}
t.Logf("Access log with user_id:\n%s", logContent)
}
// TestConcurrentRequests 测试并发请求的 Request ID 唯一性T043
func TestConcurrentRequests(t *testing.T) {
app := fiber.New()
app.Use(requestid.New(requestid.Config{
Generator: func() string {
return uuid.NewString()
},
}))
app.Get("/test", func(c *fiber.Ctx) error {
// 模拟一些处理时间
time.Sleep(10 * time.Millisecond)
requestID := c.Locals(constants.ContextKeyRequestID)
return c.JSON(fiber.Map{
"request_id": requestID,
})
})
// 并发发送多个请求
const numRequests = 50
requestIDs := make(chan string, numRequests)
errors := make(chan error, numRequests)
for i := 0; i < numRequests; i++ {
go func() {
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
errors <- err
return
}
defer resp.Body.Close()
requestID := resp.Header.Get("X-Request-ID")
requestIDs <- requestID
errors <- nil
}()
}
// 收集所有结果
seenIDs := make(map[string]bool)
for i := 0; i < numRequests; i++ {
if err := <-errors; err != nil {
t.Fatalf("Request failed: %v", err)
}
requestID := <-requestIDs
if requestID == "" {
t.Error("Request ID should not be empty")
}
if seenIDs[requestID] {
t.Errorf("Duplicate request ID found: %s", requestID)
}
seenIDs[requestID] = true
}
// 验证所有 ID 都是唯一的
if len(seenIDs) != numRequests {
t.Errorf("Expected %d unique request IDs, got %d", numRequests, len(seenIDs))
}
t.Logf("Successfully generated %d unique request IDs concurrently", len(seenIDs))
}

View File

@@ -0,0 +1,618 @@
package integration
import (
"encoding/json"
"io"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/internal/middleware"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/logger"
"github.com/break/junhong_cmp_fiber/pkg/response"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/requestid"
"github.com/google/uuid"
)
// TestPanicRecovery 测试 panic 恢复功能T052
func TestPanicRecovery(t *testing.T) {
// 创建临时目录用于日志
tempDir := t.TempDir()
// 初始化日志系统
err := logger.InitLoggers("info", false,
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "app-panic.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "access-panic.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("Failed to initialize loggers: %v", err)
}
defer logger.Sync()
appLogger := logger.GetAppLogger()
// 创建应用
app := fiber.New()
// 注册中间件recover 必须第一个)
app.Use(middleware.Recover(appLogger))
app.Use(requestid.New(requestid.Config{
Generator: func() string {
return uuid.NewString()
},
}))
// 创建会 panic 的 handler
app.Get("/panic", func(c *fiber.Ctx) error {
panic("intentional panic for testing")
})
// 创建正常的 handler
app.Get("/ok", func(c *fiber.Ctx) error {
return c.SendString("ok")
})
tests := []struct {
name string
path string
shouldPanic bool
expectedStatus int
expectedCode int
}{
{
name: "panic endpoint returns 500",
path: "/panic",
shouldPanic: true,
expectedStatus: 500,
expectedCode: errors.CodeInternalError,
},
{
name: "normal endpoint works after panic",
path: "/ok",
shouldPanic: false,
expectedStatus: 200,
expectedCode: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", tt.path, nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
defer resp.Body.Close()
// 验证 HTTP 状态码
if resp.StatusCode != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.StatusCode)
}
// 解析响应
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
if tt.shouldPanic {
// panic 应该返回统一错误响应
var response response.Response
if err := json.Unmarshal(body, &response); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if response.Code != tt.expectedCode {
t.Errorf("Expected code %d, got %d", tt.expectedCode, response.Code)
}
if response.Data != nil {
t.Error("Error response data should be nil")
}
}
})
}
}
// TestPanicLogging 测试 panic 日志记录和堆栈跟踪T053
func TestPanicLogging(t *testing.T) {
// 创建临时目录用于日志
tempDir := t.TempDir()
appLogFile := filepath.Join(tempDir, "app-panic-log.log")
// 初始化日志系统
err := logger.InitLoggers("info", false,
logger.LogRotationConfig{
Filename: appLogFile,
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "access-panic-log.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("Failed to initialize loggers: %v", err)
}
defer logger.Sync()
appLogger := logger.GetAppLogger()
// 创建应用
app := fiber.New()
// 注册中间件
app.Use(middleware.Recover(appLogger))
app.Use(requestid.New(requestid.Config{
Generator: func() string {
return uuid.NewString()
},
}))
// 创建不同类型的 panic
app.Get("/panic-string", func(c *fiber.Ctx) error {
panic("string panic message")
})
app.Get("/panic-error", func(c *fiber.Ctx) error {
panic(fiber.NewError(500, "error panic message"))
})
app.Get("/panic-struct", func(c *fiber.Ctx) error {
panic(struct{ Message string }{"struct panic message"})
})
tests := []struct {
name string
path string
expectedInLog []string
unexpectedInLog []string
}{
{
name: "string panic logs correctly",
path: "/panic-string",
expectedInLog: []string{
"Panic 已恢复",
"string panic message",
"stack",
"request_id",
"method",
"path",
},
},
{
name: "error panic logs correctly",
path: "/panic-error",
expectedInLog: []string{
"Panic 已恢复",
"error panic message",
"stack",
},
},
{
name: "struct panic logs correctly",
path: "/panic-struct",
expectedInLog: []string{
"Panic 已恢复",
"stack",
"Message",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 执行会 panic 的请求
req := httptest.NewRequest("GET", tt.path, nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
resp.Body.Close()
// 刷新日志缓冲区
logger.Sync()
time.Sleep(100 * time.Millisecond)
// 读取日志内容
logContent, err := os.ReadFile(appLogFile)
if err != nil {
t.Fatalf("Failed to read app log: %v", err)
}
content := string(logContent)
// 验证日志包含预期内容
for _, expected := range tt.expectedInLog {
if !strings.Contains(content, expected) {
t.Errorf("Log should contain '%s'", expected)
}
}
// 验证日志不包含意外内容
for _, unexpected := range tt.unexpectedInLog {
if strings.Contains(content, unexpected) {
t.Errorf("Log should NOT contain '%s'", unexpected)
}
}
// 验证堆栈跟踪包含文件和行号
if !strings.Contains(content, "recover_test.go") {
t.Error("Stack trace should contain source file name")
}
t.Logf("Panic log contains stack trace: %v", strings.Contains(content, "stack"))
})
}
}
// TestSubsequentRequestsAfterPanic 测试 panic 后后续请求正常处理T054
func TestSubsequentRequestsAfterPanic(t *testing.T) {
// 创建临时目录用于日志
tempDir := t.TempDir()
// 初始化日志系统
err := logger.InitLoggers("info", false,
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "app.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "access.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("Failed to initialize loggers: %v", err)
}
defer logger.Sync()
appLogger := logger.GetAppLogger()
// 创建应用
app := fiber.New()
// 注册中间件
app.Use(middleware.Recover(appLogger))
app.Use(requestid.New(requestid.Config{
Generator: func() string {
return uuid.NewString()
},
}))
callCount := 0
app.Get("/test", func(c *fiber.Ctx) error {
callCount++
// 第 1、3、5 次调用会 panic
if callCount%2 == 1 {
panic("test panic")
}
// 第 2、4、6 次调用正常返回
return c.JSON(fiber.Map{
"call_count": callCount,
"status": "ok",
})
})
// 执行多次请求,验证 panic 不影响后续请求
for i := 1; i <= 6; i++ {
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Request %d failed: %v", i, err)
}
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
if i%2 == 1 {
// 奇数次应该返回 500
if resp.StatusCode != 500 {
t.Errorf("Request %d: expected status 500, got %d", i, resp.StatusCode)
}
} else {
// 偶数次应该返回 200
if resp.StatusCode != 200 {
t.Errorf("Request %d: expected status 200, got %d", i, resp.StatusCode)
}
// 验证响应内容
var response map[string]any
if err := json.Unmarshal(body, &response); err != nil {
t.Fatalf("Request %d: failed to unmarshal response: %v", i, err)
}
if status, ok := response["status"].(string); !ok || status != "ok" {
t.Errorf("Request %d: expected status 'ok', got %v", i, response["status"])
}
}
t.Logf("Request %d completed: status=%d", i, resp.StatusCode)
}
// 验证所有 6 次调用都执行了
if callCount != 6 {
t.Errorf("Expected 6 calls, got %d", callCount)
}
}
// TestPanicWithRequestID 测试 panic 日志包含 Request IDT053
func TestPanicWithRequestID(t *testing.T) {
// 创建临时目录用于日志
tempDir := t.TempDir()
appLogFile := filepath.Join(tempDir, "app-panic-reqid.log")
// 初始化日志系统
err := logger.InitLoggers("info", false,
logger.LogRotationConfig{
Filename: appLogFile,
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "access-panic-reqid.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("Failed to initialize loggers: %v", err)
}
defer logger.Sync()
appLogger := logger.GetAppLogger()
// 创建应用
app := fiber.New()
// 注册中间件(顺序重要)
app.Use(middleware.Recover(appLogger))
app.Use(requestid.New(requestid.Config{
Generator: func() string {
return uuid.NewString()
},
}))
app.Get("/panic", func(c *fiber.Ctx) error {
panic("test panic with request id")
})
// 执行请求
req := httptest.NewRequest("GET", "/panic", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
resp.Body.Close()
// 获取 Request ID
requestID := resp.Header.Get("X-Request-ID")
if requestID == "" {
t.Error("X-Request-ID header should be set even after panic")
}
// 刷新日志缓冲区
logger.Sync()
time.Sleep(100 * time.Millisecond)
// 读取日志内容
logContent, err := os.ReadFile(appLogFile)
if err != nil {
t.Fatalf("Failed to read app log: %v", err)
}
content := string(logContent)
// 验证日志包含 Request ID
if !strings.Contains(content, requestID) {
t.Errorf("Panic log should contain request ID '%s'", requestID)
}
// 验证日志包含关键字段
requiredFields := []string{
"request_id",
"method",
"path",
"panic",
"stack",
}
for _, field := range requiredFields {
if !strings.Contains(content, field) {
t.Errorf("Panic log should contain field '%s'", field)
}
}
t.Logf("Panic log successfully includes Request ID: %s", requestID)
}
// TestConcurrentPanics 测试并发 panic 处理T054
func TestConcurrentPanics(t *testing.T) {
// 创建临时目录用于日志
tempDir := t.TempDir()
// 初始化日志系统
err := logger.InitLoggers("info", false,
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "app.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "access.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("Failed to initialize loggers: %v", err)
}
defer logger.Sync()
appLogger := logger.GetAppLogger()
// 创建应用
app := fiber.New()
// 注册中间件
app.Use(middleware.Recover(appLogger))
app.Use(requestid.New(requestid.Config{
Generator: func() string {
return uuid.NewString()
},
}))
app.Get("/panic", func(c *fiber.Ctx) error {
panic("concurrent panic test")
})
// 并发发送多个会 panic 的请求
const numRequests = 20
errors := make(chan error, numRequests)
statuses := make(chan int, numRequests)
for i := 0; i < numRequests; i++ {
go func() {
req := httptest.NewRequest("GET", "/panic", nil)
resp, err := app.Test(req)
if err != nil {
errors <- err
statuses <- 0
return
}
defer resp.Body.Close()
statuses <- resp.StatusCode
errors <- nil
}()
}
// 收集所有结果
for i := 0; i < numRequests; i++ {
if err := <-errors; err != nil {
t.Fatalf("Request failed: %v", err)
}
status := <-statuses
if status != 500 {
t.Errorf("Expected status 500, got %d", status)
}
}
t.Logf("Successfully handled %d concurrent panics", numRequests)
}
// TestRecoverMiddlewareOrder 测试 Recover 中间件必须在第一个T052
func TestRecoverMiddlewareOrder(t *testing.T) {
// 创建临时目录用于日志
tempDir := t.TempDir()
// 初始化日志系统
err := logger.InitLoggers("info", false,
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "app.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
logger.LogRotationConfig{
Filename: filepath.Join(tempDir, "access.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("Failed to initialize loggers: %v", err)
}
defer logger.Sync()
appLogger := logger.GetAppLogger()
// 创建应用
app := fiber.New()
// 正确的顺序Recover → RequestID → Logger
app.Use(middleware.Recover(appLogger))
app.Use(requestid.New(requestid.Config{
Generator: func() string {
return uuid.NewString()
},
}))
app.Use(logger.Middleware())
app.Get("/panic", func(c *fiber.Ctx) error {
panic("test panic")
})
// 执行请求
req := httptest.NewRequest("GET", "/panic", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
defer resp.Body.Close()
// 验证请求被正确处理(返回 500 而不是崩溃)
if resp.StatusCode != 500 {
t.Errorf("Expected status 500, got %d", resp.StatusCode)
}
// 验证仍然有 Request ID说明 RequestID 中间件在 Recover 之后执行)
requestID := resp.Header.Get("X-Request-ID")
if requestID == "" {
t.Error("X-Request-ID should be set even after panic")
}
// 解析响应,验证返回了统一错误格式
body, _ := io.ReadAll(resp.Body)
var response response.Response
if err := json.Unmarshal(body, &response); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if response.Code != errors.CodeInternalError {
t.Errorf("Expected code %d, got %d", errors.CodeInternalError, response.Code)
}
t.Logf("Recover middleware correctly placed first, handled panic gracefully")
}