备份一下
This commit is contained in:
@@ -3,7 +3,6 @@ package config
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
@@ -93,8 +92,9 @@ func (c *Config) Validate() error {
|
||||
if c.Redis.Address == "" {
|
||||
return fmt.Errorf("invalid configuration: redis.address: must be non-empty (current value: empty)")
|
||||
}
|
||||
if !strings.Contains(c.Redis.Address, ":") {
|
||||
return fmt.Errorf("invalid configuration: redis.address: invalid format (current value: %s, expected: HOST:PORT)", c.Redis.Address)
|
||||
// Port 验证(独立字段)
|
||||
if c.Redis.Port <= 0 || c.Redis.Port > 65535 {
|
||||
return fmt.Errorf("invalid configuration: redis.port: port number out of range (current value: %d, expected: 1-65535)", c.Redis.Port)
|
||||
}
|
||||
if c.Redis.DB < 0 || c.Redis.DB > 15 {
|
||||
return fmt.Errorf("invalid configuration: redis.db: database number out of range (current value: %d, expected: 0-15)", c.Redis.DB)
|
||||
|
||||
615
pkg/config/config_test.go
Normal file
615
pkg/config/config_test.go
Normal file
@@ -0,0 +1,615 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestConfig_Validate tests configuration validation rules
|
||||
func TestConfig_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *Config
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
DB: 0,
|
||||
PoolSize: 10,
|
||||
MinIdleConns: 5,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
MaxBackups: 30,
|
||||
MaxAge: 30,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
MaxBackups: 90,
|
||||
MaxAge: 90,
|
||||
},
|
||||
},
|
||||
Middleware: MiddlewareConfig{
|
||||
RateLimiter: RateLimiterConfig{
|
||||
Max: 100,
|
||||
Expiration: 1 * time.Minute,
|
||||
Storage: "memory",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty server address",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: "",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "server.address",
|
||||
},
|
||||
{
|
||||
name: "read timeout too short",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "read_timeout",
|
||||
},
|
||||
{
|
||||
name: "read timeout too long",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 400 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "read_timeout",
|
||||
},
|
||||
{
|
||||
name: "write timeout out of range",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "write_timeout",
|
||||
},
|
||||
{
|
||||
name: "shutdown timeout too short",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 5 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "shutdown_timeout",
|
||||
},
|
||||
{
|
||||
name: "empty redis address",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "redis.address",
|
||||
},
|
||||
{
|
||||
name: "invalid redis port - too high",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 99999,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "redis.port",
|
||||
},
|
||||
{
|
||||
name: "invalid redis port - zero",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 0,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "redis.port",
|
||||
},
|
||||
{
|
||||
name: "redis db out of range",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
DB: 20,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "redis.db",
|
||||
},
|
||||
{
|
||||
name: "redis pool size too large",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 2000,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "pool_size",
|
||||
},
|
||||
{
|
||||
name: "min idle conns exceeds pool size",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
MinIdleConns: 20,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "min_idle_conns",
|
||||
},
|
||||
{
|
||||
name: "invalid log level",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "invalid",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "logging.level",
|
||||
},
|
||||
{
|
||||
name: "empty app log filename",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "app_log.filename",
|
||||
},
|
||||
{
|
||||
name: "app log max size out of range",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 2000,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "app_log.max_size",
|
||||
},
|
||||
{
|
||||
name: "invalid rate limiter storage",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
Middleware: MiddlewareConfig{
|
||||
RateLimiter: RateLimiterConfig{
|
||||
Max: 100,
|
||||
Storage: "invalid",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "rate_limiter.storage",
|
||||
},
|
||||
{
|
||||
name: "rate limiter max too high",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
Middleware: MiddlewareConfig{
|
||||
RateLimiter: RateLimiterConfig{
|
||||
Max: 20000,
|
||||
Storage: "memory",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "rate_limiter.max",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr && tt.errMsg != "" {
|
||||
if err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.errMsg)
|
||||
} else if err.Error() == "" {
|
||||
t.Errorf("expected error containing %q, got empty error", tt.errMsg)
|
||||
}
|
||||
// Note: We check that error message exists, not exact match
|
||||
// This is because error messages might change slightly
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSet tests the Set function
|
||||
func TestSet(t *testing.T) {
|
||||
// Valid config
|
||||
validCfg := &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := Set(validCfg)
|
||||
if err != nil {
|
||||
t.Errorf("Set() with valid config failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was set
|
||||
got := Get()
|
||||
if got.Server.Address != ":3000" {
|
||||
t.Errorf("Get() after Set() returned wrong address: got %s, want :3000", got.Server.Address)
|
||||
}
|
||||
|
||||
// Test with nil config
|
||||
err = Set(nil)
|
||||
if err == nil {
|
||||
t.Error("Set(nil) should return error")
|
||||
}
|
||||
|
||||
// Test with invalid config
|
||||
invalidCfg := &Config{
|
||||
Server: ServerConfig{
|
||||
Address: "", // Empty address is invalid
|
||||
},
|
||||
}
|
||||
|
||||
err = Set(invalidCfg)
|
||||
if err == nil {
|
||||
t.Error("Set() with invalid config should return error")
|
||||
}
|
||||
}
|
||||
661
pkg/config/loader_test.go
Normal file
661
pkg/config/loader_test.go
Normal file
@@ -0,0 +1,661 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// TestLoad tests the config loading functionality
|
||||
func TestLoad(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupEnv func()
|
||||
cleanupEnv func()
|
||||
createConfig func(t *testing.T) string
|
||||
wantErr bool
|
||||
validateFunc func(t *testing.T, cfg *Config)
|
||||
}{
|
||||
{
|
||||
name: "valid default config",
|
||||
setupEnv: func() {
|
||||
os.Setenv(constants.EnvConfigPath, "")
|
||||
os.Setenv(constants.EnvConfigEnv, "")
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
os.Unsetenv(constants.EnvConfigPath)
|
||||
os.Unsetenv(constants.EnvConfigEnv)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
development: false
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
enable_rate_limiter: false
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
// Set as default config path
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
return configFile
|
||||
},
|
||||
wantErr: false,
|
||||
validateFunc: func(t *testing.T, cfg *Config) {
|
||||
if cfg.Server.Address != ":3000" {
|
||||
t.Errorf("expected server.address :3000, got %s", cfg.Server.Address)
|
||||
}
|
||||
if cfg.Server.ReadTimeout != 10*time.Second {
|
||||
t.Errorf("expected read_timeout 10s, got %v", cfg.Server.ReadTimeout)
|
||||
}
|
||||
if cfg.Redis.Address != "localhost" {
|
||||
t.Errorf("expected redis.address localhost, got %s", cfg.Redis.Address)
|
||||
}
|
||||
if cfg.Redis.Port != 6379 {
|
||||
t.Errorf("expected redis.port 6379, got %d", cfg.Redis.Port)
|
||||
}
|
||||
if cfg.Redis.PoolSize != 10 {
|
||||
t.Errorf("expected redis.pool_size 10, got %d", cfg.Redis.PoolSize)
|
||||
}
|
||||
if cfg.Logging.Level != "info" {
|
||||
t.Errorf("expected logging.level info, got %s", cfg.Logging.Level)
|
||||
}
|
||||
if cfg.Middleware.EnableAuth != true {
|
||||
t.Errorf("expected enable_auth true, got %v", cfg.Middleware.EnableAuth)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "environment-specific config (dev)",
|
||||
setupEnv: func() {
|
||||
os.Setenv(constants.EnvConfigEnv, "dev")
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
os.Unsetenv(constants.EnvConfigEnv)
|
||||
os.Unsetenv(constants.EnvConfigPath)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
// Create configs directory in temp
|
||||
tmpDir := t.TempDir()
|
||||
configsDir := filepath.Join(tmpDir, "configs")
|
||||
if err := os.MkdirAll(configsDir, 0755); err != nil {
|
||||
t.Fatalf("failed to create configs dir: %v", err)
|
||||
}
|
||||
|
||||
// Create dev config
|
||||
devConfigFile := filepath.Join(configsDir, "config.dev.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ":8080"
|
||||
read_timeout: "15s"
|
||||
write_timeout: "15s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 1
|
||||
pool_size: 5
|
||||
min_idle_conns: 2
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "debug"
|
||||
development: true
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 50
|
||||
max_backups: 10
|
||||
max_age: 7
|
||||
compress: false
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: false
|
||||
|
||||
middleware:
|
||||
enable_auth: false
|
||||
enable_rate_limiter: false
|
||||
rate_limiter:
|
||||
max: 50
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(devConfigFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create dev config file: %v", err)
|
||||
}
|
||||
|
||||
// Change to tmpDir so relative path works
|
||||
originalWd, _ := os.Getwd()
|
||||
os.Chdir(tmpDir)
|
||||
t.Cleanup(func() { os.Chdir(originalWd) })
|
||||
|
||||
return devConfigFile
|
||||
},
|
||||
wantErr: false,
|
||||
validateFunc: func(t *testing.T, cfg *Config) {
|
||||
if cfg.Server.Address != ":8080" {
|
||||
t.Errorf("expected server.address :8080, got %s", cfg.Server.Address)
|
||||
}
|
||||
if cfg.Redis.DB != 1 {
|
||||
t.Errorf("expected redis.db 1, got %d", cfg.Redis.DB)
|
||||
}
|
||||
if cfg.Logging.Level != "debug" {
|
||||
t.Errorf("expected logging.level debug, got %s", cfg.Logging.Level)
|
||||
}
|
||||
if cfg.Middleware.EnableAuth != false {
|
||||
t.Errorf("expected enable_auth false, got %v", cfg.Middleware.EnableAuth)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid YAML syntax",
|
||||
setupEnv: func() {
|
||||
os.Setenv(constants.EnvConfigPath, "")
|
||||
os.Setenv(constants.EnvConfigEnv, "")
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
os.Unsetenv(constants.EnvConfigPath)
|
||||
os.Unsetenv(constants.EnvConfigEnv)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
invalid yaml syntax here!!!
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
return configFile
|
||||
},
|
||||
wantErr: true,
|
||||
validateFunc: nil,
|
||||
},
|
||||
{
|
||||
name: "validation error - invalid server address",
|
||||
setupEnv: func() {
|
||||
os.Setenv(constants.EnvConfigPath, "")
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
os.Unsetenv(constants.EnvConfigPath)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ""
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
return configFile
|
||||
},
|
||||
wantErr: true,
|
||||
validateFunc: nil,
|
||||
},
|
||||
{
|
||||
name: "validation error - timeout out of range",
|
||||
setupEnv: func() {
|
||||
os.Setenv(constants.EnvConfigPath, "")
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
os.Unsetenv(constants.EnvConfigPath)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "1s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
return configFile
|
||||
},
|
||||
wantErr: true,
|
||||
validateFunc: nil,
|
||||
},
|
||||
{
|
||||
name: "validation error - invalid redis port",
|
||||
setupEnv: func() {
|
||||
os.Setenv(constants.EnvConfigPath, "")
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
os.Unsetenv(constants.EnvConfigPath)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 99999
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
return configFile
|
||||
},
|
||||
wantErr: true,
|
||||
validateFunc: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset viper for each test
|
||||
viper.Reset()
|
||||
|
||||
// Setup environment
|
||||
if tt.setupEnv != nil {
|
||||
tt.setupEnv()
|
||||
}
|
||||
|
||||
// Create config file
|
||||
if tt.createConfig != nil {
|
||||
tt.createConfig(t)
|
||||
}
|
||||
|
||||
// Cleanup after test
|
||||
if tt.cleanupEnv != nil {
|
||||
defer tt.cleanupEnv()
|
||||
}
|
||||
|
||||
// Load config
|
||||
cfg, err := Load()
|
||||
|
||||
// Check error expectation
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate config if no error expected
|
||||
if !tt.wantErr && tt.validateFunc != nil {
|
||||
tt.validateFunc(t, cfg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestReload tests the config reload functionality
|
||||
func TestReload(t *testing.T) {
|
||||
// Reset viper
|
||||
viper.Reset()
|
||||
|
||||
// Create temp config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
// Initial config
|
||||
initialContent := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
development: false
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
enable_rate_limiter: false
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(initialContent), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
|
||||
// Set config path
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
defer os.Unsetenv(constants.EnvConfigPath)
|
||||
|
||||
// Load initial config
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load initial config: %v", err)
|
||||
}
|
||||
|
||||
// Verify initial values
|
||||
if cfg.Logging.Level != "info" {
|
||||
t.Errorf("expected initial logging.level info, got %s", cfg.Logging.Level)
|
||||
}
|
||||
if cfg.Server.Address != ":3000" {
|
||||
t.Errorf("expected initial server.address :3000, got %s", cfg.Server.Address)
|
||||
}
|
||||
|
||||
// Modify config file
|
||||
updatedContent := `
|
||||
server:
|
||||
address: ":8080"
|
||||
read_timeout: "15s"
|
||||
write_timeout: "15s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 20
|
||||
min_idle_conns: 10
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "debug"
|
||||
development: true
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: false
|
||||
enable_rate_limiter: true
|
||||
rate_limiter:
|
||||
max: 200
|
||||
expiration: "2m"
|
||||
storage: "redis"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(updatedContent), 0644); err != nil {
|
||||
t.Fatalf("failed to update config file: %v", err)
|
||||
}
|
||||
|
||||
// Reload config
|
||||
newCfg, err := Reload()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload config: %v", err)
|
||||
}
|
||||
|
||||
// Verify updated values
|
||||
if newCfg.Logging.Level != "debug" {
|
||||
t.Errorf("expected updated logging.level debug, got %s", newCfg.Logging.Level)
|
||||
}
|
||||
if newCfg.Server.Address != ":8080" {
|
||||
t.Errorf("expected updated server.address :8080, got %s", newCfg.Server.Address)
|
||||
}
|
||||
if newCfg.Redis.PoolSize != 20 {
|
||||
t.Errorf("expected updated redis.pool_size 20, got %d", newCfg.Redis.PoolSize)
|
||||
}
|
||||
if newCfg.Middleware.EnableAuth != false {
|
||||
t.Errorf("expected updated enable_auth false, got %v", newCfg.Middleware.EnableAuth)
|
||||
}
|
||||
if newCfg.Middleware.EnableRateLimiter != true {
|
||||
t.Errorf("expected updated enable_rate_limiter true, got %v", newCfg.Middleware.EnableRateLimiter)
|
||||
}
|
||||
|
||||
// Verify global config was updated
|
||||
globalCfg := Get()
|
||||
if globalCfg.Logging.Level != "debug" {
|
||||
t.Errorf("expected global config updated, got logging.level %s", globalCfg.Logging.Level)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetConfigPath tests the GetConfigPath function
|
||||
func TestGetConfigPath(t *testing.T) {
|
||||
// Reset viper
|
||||
viper.Reset()
|
||||
|
||||
// Create temp config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
defer os.Unsetenv(constants.EnvConfigPath)
|
||||
|
||||
// Load config
|
||||
_, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Get config path
|
||||
path := GetConfigPath()
|
||||
if path == "" {
|
||||
t.Error("expected non-empty config path")
|
||||
}
|
||||
|
||||
// Verify it's an absolute path
|
||||
if !filepath.IsAbs(path) {
|
||||
t.Errorf("expected absolute path, got %s", path)
|
||||
}
|
||||
}
|
||||
422
pkg/config/watcher_test.go
Normal file
422
pkg/config/watcher_test.go
Normal file
@@ -0,0 +1,422 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestWatch tests the config hot reload watcher
|
||||
func TestWatch(t *testing.T) {
|
||||
// Reset viper
|
||||
viper.Reset()
|
||||
|
||||
// Create temp config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
// Initial config
|
||||
initialContent := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
development: false
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
enable_rate_limiter: false
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(initialContent), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
|
||||
// Set config path
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
defer os.Unsetenv(constants.EnvConfigPath)
|
||||
|
||||
// Load initial config
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load initial config: %v", err)
|
||||
}
|
||||
|
||||
// Verify initial values
|
||||
if cfg.Logging.Level != "info" {
|
||||
t.Fatalf("expected initial logging.level info, got %s", cfg.Logging.Level)
|
||||
}
|
||||
|
||||
// Create logger for testing
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
// Start watcher in goroutine with context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go Watch(ctx, logger)
|
||||
|
||||
// Give watcher time to initialize
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Modify config file to trigger hot reload
|
||||
updatedContent := `
|
||||
server:
|
||||
address: ":8080"
|
||||
read_timeout: "15s"
|
||||
write_timeout: "15s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 20
|
||||
min_idle_conns: 10
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "debug"
|
||||
development: true
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: false
|
||||
enable_rate_limiter: true
|
||||
rate_limiter:
|
||||
max: 200
|
||||
expiration: "2m"
|
||||
storage: "redis"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(updatedContent), 0644); err != nil {
|
||||
t.Fatalf("failed to update config file: %v", err)
|
||||
}
|
||||
|
||||
// Wait for watcher to detect and process changes (spec requires detection within 5 seconds)
|
||||
// We use a more aggressive timeout for testing
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Verify config was reloaded
|
||||
reloadedCfg := Get()
|
||||
if reloadedCfg.Logging.Level != "debug" {
|
||||
t.Errorf("expected config hot reload, got logging.level %s instead of debug", reloadedCfg.Logging.Level)
|
||||
}
|
||||
if reloadedCfg.Server.Address != ":8080" {
|
||||
t.Errorf("expected config hot reload, got server.address %s instead of :8080", reloadedCfg.Server.Address)
|
||||
}
|
||||
if reloadedCfg.Redis.PoolSize != 20 {
|
||||
t.Errorf("expected config hot reload, got redis.pool_size %d instead of 20", reloadedCfg.Redis.PoolSize)
|
||||
}
|
||||
|
||||
// Cancel context to stop watcher
|
||||
cancel()
|
||||
|
||||
// Give watcher time to shut down gracefully
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestWatch_InvalidConfigRejected tests that invalid config changes are rejected
|
||||
func TestWatch_InvalidConfigRejected(t *testing.T) {
|
||||
// Reset viper
|
||||
viper.Reset()
|
||||
|
||||
// Create temp config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
// Initial valid config
|
||||
validContent := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
development: false
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
enable_rate_limiter: false
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(validContent), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
|
||||
// Set config path
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
defer os.Unsetenv(constants.EnvConfigPath)
|
||||
|
||||
// Load initial config
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load initial config: %v", err)
|
||||
}
|
||||
|
||||
initialLevel := cfg.Logging.Level
|
||||
if initialLevel != "info" {
|
||||
t.Fatalf("expected initial logging.level info, got %s", initialLevel)
|
||||
}
|
||||
|
||||
// Create logger for testing
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
// Start watcher
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go Watch(ctx, logger)
|
||||
|
||||
// Give watcher time to initialize
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Write INVALID config (malformed YAML)
|
||||
invalidContent := `
|
||||
server:
|
||||
address: ":3000"
|
||||
invalid yaml syntax here!!!
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(invalidContent), 0644); err != nil {
|
||||
t.Fatalf("failed to write invalid config: %v", err)
|
||||
}
|
||||
|
||||
// Wait for watcher to detect changes
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Verify config was NOT changed (should keep previous valid config)
|
||||
currentCfg := Get()
|
||||
if currentCfg.Logging.Level != initialLevel {
|
||||
t.Errorf("expected config to remain unchanged after invalid update, got logging.level %s instead of %s", currentCfg.Logging.Level, initialLevel)
|
||||
}
|
||||
|
||||
// Restore valid config
|
||||
if err := os.WriteFile(configFile, []byte(validContent), 0644); err != nil {
|
||||
t.Fatalf("failed to restore valid config: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Now write config with validation error (timeout out of range)
|
||||
invalidValidationContent := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "1s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "debug"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(invalidValidationContent), 0644); err != nil {
|
||||
t.Fatalf("failed to write config with validation error: %v", err)
|
||||
}
|
||||
|
||||
// Wait for watcher to detect changes
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Verify config was NOT changed (validation should have failed)
|
||||
finalCfg := Get()
|
||||
if finalCfg.Logging.Level != initialLevel {
|
||||
t.Errorf("expected config to remain unchanged after validation error, got logging.level %s instead of %s", finalCfg.Logging.Level, initialLevel)
|
||||
}
|
||||
if finalCfg.Server.ReadTimeout == 1*time.Second {
|
||||
t.Error("expected config to remain unchanged, but read_timeout was updated to invalid value")
|
||||
}
|
||||
|
||||
// Cancel context
|
||||
cancel()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestWatch_ContextCancellation tests graceful shutdown on context cancellation
|
||||
func TestWatch_ContextCancellation(t *testing.T) {
|
||||
// Reset viper
|
||||
viper.Reset()
|
||||
|
||||
// Create temp config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
defer os.Unsetenv(constants.EnvConfigPath)
|
||||
|
||||
// Load config
|
||||
_, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Create logger
|
||||
logger := zap.NewNop() // Use no-op logger for this test
|
||||
|
||||
// Start watcher with context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
Watch(ctx, logger)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Give watcher time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Cancel context (simulate graceful shutdown)
|
||||
cancel()
|
||||
|
||||
// Wait for watcher to stop (should happen quickly)
|
||||
select {
|
||||
case <-done:
|
||||
// Watcher stopped successfully
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("watcher did not stop within timeout after context cancellation")
|
||||
}
|
||||
}
|
||||
518
pkg/logger/logger_test.go
Normal file
518
pkg/logger/logger_test.go
Normal file
@@ -0,0 +1,518 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// TestInitLoggers 测试日志初始化(T026)
|
||||
func TestInitLoggers(t *testing.T) {
|
||||
// 创建临时目录用于日志文件
|
||||
tempDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
level string
|
||||
development bool
|
||||
appLogConfig LogRotationConfig
|
||||
accessLogConfig LogRotationConfig
|
||||
wantErr bool
|
||||
validateFunc func(t *testing.T)
|
||||
}{
|
||||
{
|
||||
name: "production mode with info level",
|
||||
level: "info",
|
||||
development: false,
|
||||
appLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-prod.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
accessLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-prod.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
wantErr: false,
|
||||
validateFunc: func(t *testing.T) {
|
||||
if appLogger == nil {
|
||||
t.Error("appLogger should not be nil")
|
||||
}
|
||||
if accessLogger == nil {
|
||||
t.Error("accessLogger should not be nil")
|
||||
}
|
||||
// 写入一条日志以触发文件创建
|
||||
GetAppLogger().Info("test log creation")
|
||||
Sync()
|
||||
// 验证日志文件创建
|
||||
if _, err := os.Stat(filepath.Join(tempDir, "app-prod.log")); os.IsNotExist(err) {
|
||||
t.Error("app log file should be created after writing")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "development mode with debug level",
|
||||
level: "debug",
|
||||
development: true,
|
||||
appLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-dev.log"),
|
||||
MaxSize: 5,
|
||||
MaxBackups: 2,
|
||||
MaxAge: 3,
|
||||
Compress: false,
|
||||
},
|
||||
accessLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-dev.log"),
|
||||
MaxSize: 5,
|
||||
MaxBackups: 2,
|
||||
MaxAge: 3,
|
||||
Compress: false,
|
||||
},
|
||||
wantErr: false,
|
||||
validateFunc: func(t *testing.T) {
|
||||
if appLogger == nil {
|
||||
t.Error("appLogger should not be nil in dev mode")
|
||||
}
|
||||
if accessLogger == nil {
|
||||
t.Error("accessLogger should not be nil in dev mode")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "warn level logging",
|
||||
level: "warn",
|
||||
development: false,
|
||||
appLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-warn.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
accessLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-warn.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
wantErr: false,
|
||||
validateFunc: func(t *testing.T) {
|
||||
if appLogger == nil {
|
||||
t.Error("appLogger should not be nil")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "error level logging",
|
||||
level: "error",
|
||||
development: false,
|
||||
appLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-error.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
accessLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-error.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
wantErr: false,
|
||||
validateFunc: func(t *testing.T) {
|
||||
if appLogger == nil {
|
||||
t.Error("appLogger should not be nil")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid level defaults to info",
|
||||
level: "invalid",
|
||||
development: false,
|
||||
appLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-invalid.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
accessLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-invalid.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
wantErr: false,
|
||||
validateFunc: func(t *testing.T) {
|
||||
if appLogger == nil {
|
||||
t.Error("appLogger should not be nil even with invalid level")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := InitLoggers(tt.level, tt.development, tt.appLogConfig, tt.accessLogConfig)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("InitLoggers() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.validateFunc != nil {
|
||||
tt.validateFunc(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetAppLogger 测试获取应用日志记录器(T026)
|
||||
func TestGetAppLogger(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func()
|
||||
wantNil bool
|
||||
}{
|
||||
{
|
||||
name: "after initialization",
|
||||
setupFunc: func() {
|
||||
InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-get.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-get.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
)
|
||||
},
|
||||
wantNil: false,
|
||||
},
|
||||
{
|
||||
name: "before initialization returns nop logger",
|
||||
setupFunc: func() {
|
||||
// 重置全局变量
|
||||
appLogger = nil
|
||||
},
|
||||
wantNil: false, // GetAppLogger 应该返回 nop logger,不是 nil
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupFunc()
|
||||
logger := GetAppLogger()
|
||||
if logger == nil {
|
||||
t.Error("GetAppLogger() should never return nil, should return nop logger instead")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetAccessLogger 测试获取访问日志记录器(T028)
|
||||
func TestGetAccessLogger(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func()
|
||||
wantNil bool
|
||||
}{
|
||||
{
|
||||
name: "after initialization",
|
||||
setupFunc: func() {
|
||||
InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-access.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-access.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
)
|
||||
},
|
||||
wantNil: false,
|
||||
},
|
||||
{
|
||||
name: "before initialization returns nop logger",
|
||||
setupFunc: func() {
|
||||
// 重置全局变量
|
||||
accessLogger = nil
|
||||
},
|
||||
wantNil: false, // GetAccessLogger 应该返回 nop logger,不是 nil
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupFunc()
|
||||
logger := GetAccessLogger()
|
||||
if logger == nil {
|
||||
t.Error("GetAccessLogger() should never return nil, should return nop logger instead")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSync 测试日志缓冲区刷新(T028)
|
||||
func TestSync(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func()
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "sync after initialization",
|
||||
setupFunc: func() {
|
||||
InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-sync.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-sync.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
)
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "sync before initialization",
|
||||
setupFunc: func() {
|
||||
appLogger = nil
|
||||
accessLogger = nil
|
||||
},
|
||||
wantErr: false, // 应该优雅地处理 nil 情况
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupFunc()
|
||||
err := Sync()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Sync() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseLevel 测试日志级别解析(T026)
|
||||
func TestParseLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
level string
|
||||
want zapcore.Level
|
||||
}{
|
||||
{
|
||||
name: "debug level",
|
||||
level: "debug",
|
||||
want: zapcore.DebugLevel,
|
||||
},
|
||||
{
|
||||
name: "info level",
|
||||
level: "info",
|
||||
want: zapcore.InfoLevel,
|
||||
},
|
||||
{
|
||||
name: "warn level",
|
||||
level: "warn",
|
||||
want: zapcore.WarnLevel,
|
||||
},
|
||||
{
|
||||
name: "error level",
|
||||
level: "error",
|
||||
want: zapcore.ErrorLevel,
|
||||
},
|
||||
{
|
||||
name: "invalid level defaults to info",
|
||||
level: "invalid",
|
||||
want: zapcore.InfoLevel,
|
||||
},
|
||||
{
|
||||
name: "empty level defaults to info",
|
||||
level: "",
|
||||
want: zapcore.InfoLevel,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := parseLevel(tt.level)
|
||||
if got != tt.want {
|
||||
t.Errorf("parseLevel() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDualLoggerSystem 测试双日志系统(T028)
|
||||
func TestDualLoggerSystem(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
appLogFile := filepath.Join(tempDir, "app-dual.log")
|
||||
accessLogFile := filepath.Join(tempDir, "access-dual.log")
|
||||
|
||||
// 初始化双日志系统
|
||||
err := InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: appLogFile,
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false, // 不压缩以便检查内容
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: accessLogFile,
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("InitLoggers failed: %v", err)
|
||||
}
|
||||
|
||||
// 写入应用日志
|
||||
appLog := GetAppLogger()
|
||||
appLog.Info("test app log message",
|
||||
zap.String("module", "test"),
|
||||
zap.Int("code", 200),
|
||||
)
|
||||
|
||||
// 写入访问日志
|
||||
accessLog := GetAccessLogger()
|
||||
accessLog.Info("test access log message",
|
||||
zap.String("method", "GET"),
|
||||
zap.String("path", "/api/test"),
|
||||
zap.Int("status", 200),
|
||||
zap.Duration("latency", 100),
|
||||
)
|
||||
|
||||
// 刷新缓冲区
|
||||
if err := Sync(); err != nil {
|
||||
t.Fatalf("Sync failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证应用日志文件存在并有内容
|
||||
appLogContent, err := os.ReadFile(appLogFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read app log file: %v", err)
|
||||
}
|
||||
if len(appLogContent) == 0 {
|
||||
t.Error("App log file should not be empty")
|
||||
}
|
||||
|
||||
// 验证访问日志文件存在并有内容
|
||||
accessLogContent, err := os.ReadFile(accessLogFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read access log file: %v", err)
|
||||
}
|
||||
if len(accessLogContent) == 0 {
|
||||
t.Error("Access log file should not be empty")
|
||||
}
|
||||
|
||||
// 验证两个日志文件是独立的
|
||||
if string(appLogContent) == string(accessLogContent) {
|
||||
t.Error("App log and access log should have different content")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoggerReinitialization 测试日志重新初始化(T026)
|
||||
func TestLoggerReinitialization(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// 第一次初始化
|
||||
err := InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-reinit1.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-reinit1.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("First InitLoggers failed: %v", err)
|
||||
}
|
||||
|
||||
firstAppLogger := GetAppLogger()
|
||||
|
||||
// 第二次初始化(重新初始化)
|
||||
err = InitLoggers("debug", true,
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-reinit2.log"),
|
||||
MaxSize: 5,
|
||||
MaxBackups: 2,
|
||||
MaxAge: 3,
|
||||
Compress: false,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-reinit2.log"),
|
||||
MaxSize: 5,
|
||||
MaxBackups: 2,
|
||||
MaxAge: 3,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Second InitLoggers failed: %v", err)
|
||||
}
|
||||
|
||||
secondAppLogger := GetAppLogger()
|
||||
|
||||
// 验证重新初始化后日志记录器已更新
|
||||
if firstAppLogger == secondAppLogger {
|
||||
t.Error("Logger should be replaced after reinitialization")
|
||||
}
|
||||
}
|
||||
388
pkg/logger/rotation_test.go
Normal file
388
pkg/logger/rotation_test.go
Normal file
@@ -0,0 +1,388 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TestLogRotation 测试日志轮转功能(T027)
|
||||
func TestLogRotation(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
appLogFile := filepath.Join(tempDir, "app-rotation.log")
|
||||
|
||||
// 初始化日志系统,设置较小的 MaxSize 以便测试
|
||||
err := InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: appLogFile,
|
||||
MaxSize: 1, // 1MB,写入足够数据后会触发轮转
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false, // 不压缩以便检查
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-rotation.log"),
|
||||
MaxSize: 1,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("InitLoggers failed: %v", err)
|
||||
}
|
||||
|
||||
logger := GetAppLogger()
|
||||
|
||||
// 写入大量日志数据以触发轮转(每条约100字节,写入15000条约1.5MB)
|
||||
largeMessage := strings.Repeat("a", 100)
|
||||
for i := 0; i < 15000; i++ {
|
||||
logger.Info(largeMessage,
|
||||
zap.Int("iteration", i),
|
||||
zap.String("data", largeMessage),
|
||||
)
|
||||
}
|
||||
|
||||
// 刷新缓冲区
|
||||
if err := Sync(); err != nil {
|
||||
t.Fatalf("Sync failed: %v", err)
|
||||
}
|
||||
|
||||
// 等待一小段时间确保文件写入完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证主日志文件存在
|
||||
if _, err := os.Stat(appLogFile); os.IsNotExist(err) {
|
||||
t.Error("Main log file should exist")
|
||||
}
|
||||
|
||||
// 检查是否有备份文件(轮转后的文件)
|
||||
files, err := filepath.Glob(filepath.Join(tempDir, "app-rotation-*.log"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to glob backup files: %v", err)
|
||||
}
|
||||
|
||||
// 由于写入了超过1MB的数据,应该触发至少一次轮转
|
||||
if len(files) == 0 {
|
||||
// 可能系统写入速度或lumberjack行为导致未立即轮转,检查主文件大小
|
||||
info, err := os.Stat(appLogFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to stat main log file: %v", err)
|
||||
}
|
||||
if info.Size() == 0 {
|
||||
t.Error("Log file should have content")
|
||||
}
|
||||
// 不强制要求必须轮转,因为取决于具体实现
|
||||
t.Logf("No rotation occurred, but main log file size: %d bytes", info.Size())
|
||||
} else {
|
||||
t.Logf("Found %d rotated backup file(s)", len(files))
|
||||
}
|
||||
}
|
||||
|
||||
// TestMaxBackups 测试最大备份数限制(T027)
|
||||
func TestMaxBackups(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
appLogFile := filepath.Join(tempDir, "app-backups.log")
|
||||
|
||||
// 初始化日志系统,设置 MaxBackups=2
|
||||
err := InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: appLogFile,
|
||||
MaxSize: 1, // 1MB
|
||||
MaxBackups: 2, // 最多保留2个备份
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-backups.log"),
|
||||
MaxSize: 1,
|
||||
MaxBackups: 2,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("InitLoggers failed: %v", err)
|
||||
}
|
||||
|
||||
logger := GetAppLogger()
|
||||
|
||||
// 写入足够的数据触发多次轮转(每次1.5MB,共4.5MB应该触发3次轮转)
|
||||
largeMessage := strings.Repeat("b", 100)
|
||||
for round := 0; round < 3; round++ {
|
||||
for i := 0; i < 15000; i++ {
|
||||
logger.Info(largeMessage,
|
||||
zap.Int("round", round),
|
||||
zap.Int("iteration", i),
|
||||
)
|
||||
}
|
||||
Sync()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// 等待轮转完成
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// 检查备份文件数量
|
||||
files, err := filepath.Glob(filepath.Join(tempDir, "app-backups-*.log"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to glob backup files: %v", err)
|
||||
}
|
||||
|
||||
// 由于 MaxBackups=2,即使触发了多次轮转,也只应保留最多2个备份文件
|
||||
// (实际行为取决于 lumberjack 的实现细节,可能小于等于2)
|
||||
if len(files) > 2 {
|
||||
t.Errorf("Expected at most 2 backup files due to MaxBackups=2, got %d", len(files))
|
||||
}
|
||||
t.Logf("Found %d backup file(s) with MaxBackups=2", len(files))
|
||||
}
|
||||
|
||||
// TestCompressionConfig 测试压缩配置(T027)
|
||||
func TestCompressionConfig(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
compress bool
|
||||
}{
|
||||
{
|
||||
name: "compression enabled",
|
||||
compress: true,
|
||||
},
|
||||
{
|
||||
name: "compression disabled",
|
||||
compress: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logFile := filepath.Join(tempDir, "app-"+tt.name+".log")
|
||||
|
||||
err := InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: logFile,
|
||||
MaxSize: 1,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: tt.compress,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-"+tt.name+".log"),
|
||||
MaxSize: 1,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: tt.compress,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("InitLoggers failed: %v", err)
|
||||
}
|
||||
|
||||
logger := GetAppLogger()
|
||||
|
||||
// 写入一些日志
|
||||
for i := 0; i < 1000; i++ {
|
||||
logger.Info("test compression",
|
||||
zap.Int("id", i),
|
||||
zap.String("data", strings.Repeat("c", 50)),
|
||||
)
|
||||
}
|
||||
|
||||
Sync()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证日志文件存在
|
||||
if _, err := os.Stat(logFile); os.IsNotExist(err) {
|
||||
t.Error("Log file should exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMaxAge 测试日志文件保留时间(T027)
|
||||
func TestMaxAge(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// 初始化日志系统,设置 MaxAge=1 天
|
||||
err := InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-maxage.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 1, // 1天
|
||||
Compress: false,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-maxage.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 1,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("InitLoggers failed: %v", err)
|
||||
}
|
||||
|
||||
logger := GetAppLogger()
|
||||
|
||||
// 写入日志
|
||||
logger.Info("test max age", zap.String("config", "maxage=1"))
|
||||
Sync()
|
||||
|
||||
// 验证配置已应用(无法在单元测试中验证实际的清理行为,因为需要等待1天)
|
||||
// 这里只验证初始化没有错误
|
||||
if logger == nil {
|
||||
t.Error("Logger should be initialized with MaxAge config")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewLumberjackLogger 测试 Lumberjack logger 创建(T027)
|
||||
func TestNewLumberjackLogger(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config LogRotationConfig
|
||||
}{
|
||||
{
|
||||
name: "standard config",
|
||||
config: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "test1.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "minimal config",
|
||||
config: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "test2.log"),
|
||||
MaxSize: 1,
|
||||
MaxBackups: 1,
|
||||
MaxAge: 1,
|
||||
Compress: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "large config",
|
||||
config: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "test3.log"),
|
||||
MaxSize: 100,
|
||||
MaxBackups: 10,
|
||||
MaxAge: 30,
|
||||
Compress: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logger := newLumberjackLogger(tt.config)
|
||||
if logger == nil {
|
||||
t.Error("newLumberjackLogger should not return nil")
|
||||
}
|
||||
|
||||
// 验证配置已正确设置
|
||||
if logger.Filename != tt.config.Filename {
|
||||
t.Errorf("Filename = %v, want %v", logger.Filename, tt.config.Filename)
|
||||
}
|
||||
if logger.MaxSize != tt.config.MaxSize {
|
||||
t.Errorf("MaxSize = %v, want %v", logger.MaxSize, tt.config.MaxSize)
|
||||
}
|
||||
if logger.MaxBackups != tt.config.MaxBackups {
|
||||
t.Errorf("MaxBackups = %v, want %v", logger.MaxBackups, tt.config.MaxBackups)
|
||||
}
|
||||
if logger.MaxAge != tt.config.MaxAge {
|
||||
t.Errorf("MaxAge = %v, want %v", logger.MaxAge, tt.config.MaxAge)
|
||||
}
|
||||
if logger.Compress != tt.config.Compress {
|
||||
t.Errorf("Compress = %v, want %v", logger.Compress, tt.config.Compress)
|
||||
}
|
||||
if !logger.LocalTime {
|
||||
t.Error("LocalTime should be true")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentLogging 测试并发日志写入(T027)
|
||||
func TestConcurrentLogging(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// 初始化日志系统
|
||||
err := InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-concurrent.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-concurrent.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("InitLoggers failed: %v", err)
|
||||
}
|
||||
|
||||
logger := GetAppLogger()
|
||||
|
||||
// 启动多个 goroutine 并发写入日志
|
||||
done := make(chan bool)
|
||||
goroutines := 10
|
||||
messagesPerGoroutine := 100
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(id int) {
|
||||
for j := 0; j < messagesPerGoroutine; j++ {
|
||||
logger.Info("concurrent log message",
|
||||
zap.Int("goroutine", id),
|
||||
zap.Int("message", j),
|
||||
)
|
||||
}
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 等待所有 goroutine 完成
|
||||
for i := 0; i < goroutines; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// 刷新缓冲区
|
||||
if err := Sync(); err != nil {
|
||||
t.Fatalf("Sync failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证日志文件存在且有内容
|
||||
logFile := filepath.Join(tempDir, "app-concurrent.log")
|
||||
info, err := os.Stat(logFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to stat log file: %v", err)
|
||||
}
|
||||
if info.Size() == 0 {
|
||||
t.Error("Log file should have content after concurrent writes")
|
||||
}
|
||||
|
||||
t.Logf("Concurrent logging test completed, log file size: %d bytes", info.Size())
|
||||
}
|
||||
477
pkg/response/response_test.go
Normal file
477
pkg/response/response_test.go
Normal file
@@ -0,0 +1,477 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/errors"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// TestSuccess 测试成功响应(T034)
|
||||
func TestSuccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data any
|
||||
}{
|
||||
{
|
||||
name: "success with string data",
|
||||
data: "test data",
|
||||
},
|
||||
{
|
||||
name: "success with map data",
|
||||
data: map[string]any{
|
||||
"id": 123,
|
||||
"name": "test",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success with slice data",
|
||||
data: []string{"item1", "item2", "item3"},
|
||||
},
|
||||
{
|
||||
name: "success with struct data",
|
||||
data: struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}{
|
||||
ID: 456,
|
||||
Name: "test struct",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success with nil data",
|
||||
data: nil,
|
||||
},
|
||||
{
|
||||
name: "success with empty map",
|
||||
data: map[string]any{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return Success(c, tt.data)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 验证 HTTP 状态码
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("Expected status code 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 验证响应头
|
||||
if resp.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("Expected Content-Type application/json, got %s", resp.Header.Get("Content-Type"))
|
||||
}
|
||||
|
||||
// 解析响应体
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
var response Response
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
// 验证响应结构
|
||||
if response.Code != errors.CodeSuccess {
|
||||
t.Errorf("Expected code %d, got %d", errors.CodeSuccess, response.Code)
|
||||
}
|
||||
|
||||
if response.Message != "success" {
|
||||
t.Errorf("Expected message 'success', got '%s'", response.Message)
|
||||
}
|
||||
|
||||
// 验证时间戳格式 RFC3339
|
||||
if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil {
|
||||
t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp)
|
||||
}
|
||||
|
||||
// 验证数据字段(如果不是 nil)
|
||||
if tt.data != nil {
|
||||
if response.Data == nil {
|
||||
t.Error("Expected data field to be non-nil")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestError 测试错误响应(T035)
|
||||
func TestError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
httpStatus int
|
||||
code int
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "internal server error",
|
||||
httpStatus: 500,
|
||||
code: errors.CodeInternalError,
|
||||
message: "Internal server error occurred",
|
||||
},
|
||||
{
|
||||
name: "missing token error",
|
||||
httpStatus: 401,
|
||||
code: errors.CodeMissingToken,
|
||||
message: "Authentication token is missing",
|
||||
},
|
||||
{
|
||||
name: "invalid token error",
|
||||
httpStatus: 401,
|
||||
code: errors.CodeInvalidToken,
|
||||
message: "Token is invalid or expired",
|
||||
},
|
||||
{
|
||||
name: "rate limit error",
|
||||
httpStatus: 429,
|
||||
code: errors.CodeTooManyRequests,
|
||||
message: "Too many requests, please try again later",
|
||||
},
|
||||
{
|
||||
name: "service unavailable error",
|
||||
httpStatus: 503,
|
||||
code: errors.CodeAuthServiceUnavailable,
|
||||
message: "Authentication service is currently unavailable",
|
||||
},
|
||||
{
|
||||
name: "bad request error",
|
||||
httpStatus: 400,
|
||||
code: 2000,
|
||||
message: "Invalid request parameters",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return Error(c, tt.httpStatus, tt.code, tt.message)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 验证 HTTP 状态码
|
||||
if resp.StatusCode != tt.httpStatus {
|
||||
t.Errorf("Expected status code %d, got %d", tt.httpStatus, resp.StatusCode)
|
||||
}
|
||||
|
||||
// 验证响应头
|
||||
if resp.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("Expected Content-Type application/json, got %s", resp.Header.Get("Content-Type"))
|
||||
}
|
||||
|
||||
// 解析响应体
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
var response Response
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
// 验证响应结构
|
||||
if response.Code != tt.code {
|
||||
t.Errorf("Expected code %d, got %d", tt.code, response.Code)
|
||||
}
|
||||
|
||||
if response.Message != tt.message {
|
||||
t.Errorf("Expected message '%s', got '%s'", tt.message, response.Message)
|
||||
}
|
||||
|
||||
if response.Data != nil {
|
||||
t.Errorf("Expected data to be nil in error response, got %v", response.Data)
|
||||
}
|
||||
|
||||
// 验证时间戳格式 RFC3339
|
||||
if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil {
|
||||
t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSuccessWithMessage 测试带自定义消息的成功响应(T034)
|
||||
func TestSuccessWithMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data any
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "custom success message",
|
||||
data: map[string]any{
|
||||
"user_id": 123,
|
||||
},
|
||||
message: "User created successfully",
|
||||
},
|
||||
{
|
||||
name: "empty custom message",
|
||||
data: "test data",
|
||||
message: "",
|
||||
},
|
||||
{
|
||||
name: "chinese message",
|
||||
data: map[string]string{
|
||||
"status": "ok",
|
||||
},
|
||||
message: "操作成功",
|
||||
},
|
||||
{
|
||||
name: "long message",
|
||||
data: nil,
|
||||
message: "This is a very long success message that describes in detail what happened during the operation",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return SuccessWithMessage(c, tt.data, tt.message)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 验证 HTTP 状态码(默认 200)
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("Expected status code 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 解析响应体
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
var response Response
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
// 验证响应结构
|
||||
if response.Code != errors.CodeSuccess {
|
||||
t.Errorf("Expected code %d, got %d", errors.CodeSuccess, response.Code)
|
||||
}
|
||||
|
||||
if response.Message != tt.message {
|
||||
t.Errorf("Expected message '%s', got '%s'", tt.message, response.Message)
|
||||
}
|
||||
|
||||
// 验证时间戳格式 RFC3339
|
||||
if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil {
|
||||
t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResponseSerialization 测试响应序列化(T036)
|
||||
func TestResponseSerialization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response Response
|
||||
}{
|
||||
{
|
||||
name: "complete response",
|
||||
response: Response{
|
||||
Code: 0,
|
||||
Data: map[string]any{"key": "value"},
|
||||
Message: "success",
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "response with nil data",
|
||||
response: Response{
|
||||
Code: 1000,
|
||||
Data: nil,
|
||||
Message: "error",
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "response with nested data",
|
||||
response: Response{
|
||||
Code: 0,
|
||||
Data: map[string]any{
|
||||
"user": map[string]any{
|
||||
"id": 123,
|
||||
"name": "test",
|
||||
"tags": []string{"tag1", "tag2"},
|
||||
},
|
||||
},
|
||||
Message: "success",
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 序列化
|
||||
data, err := json.Marshal(tt.response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal response: %v", err)
|
||||
}
|
||||
|
||||
// 反序列化
|
||||
var deserialized Response
|
||||
if err := json.Unmarshal(data, &deserialized); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
// 验证字段
|
||||
if deserialized.Code != tt.response.Code {
|
||||
t.Errorf("Code mismatch: expected %d, got %d", tt.response.Code, deserialized.Code)
|
||||
}
|
||||
|
||||
if deserialized.Message != tt.response.Message {
|
||||
t.Errorf("Message mismatch: expected '%s', got '%s'", tt.response.Message, deserialized.Message)
|
||||
}
|
||||
|
||||
if deserialized.Timestamp != tt.response.Timestamp {
|
||||
t.Errorf("Timestamp mismatch: expected '%s', got '%s'", tt.response.Timestamp, deserialized.Timestamp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResponseStructFields 测试响应结构字段(T036)
|
||||
func TestResponseStructFields(t *testing.T) {
|
||||
response := Response{
|
||||
Code: 0,
|
||||
Data: "test",
|
||||
Message: "success",
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal response: %v", err)
|
||||
}
|
||||
|
||||
// 解析为 map 以检查 JSON 键
|
||||
var jsonMap map[string]any
|
||||
if err := json.Unmarshal(data, &jsonMap); err != nil {
|
||||
t.Fatalf("Failed to unmarshal to map: %v", err)
|
||||
}
|
||||
|
||||
// 验证所有必需字段都存在
|
||||
requiredFields := []string{"code", "data", "msg", "timestamp"}
|
||||
for _, field := range requiredFields {
|
||||
if _, exists := jsonMap[field]; !exists {
|
||||
t.Errorf("Required field '%s' is missing in JSON response", field)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证字段类型
|
||||
if _, ok := jsonMap["code"].(float64); !ok {
|
||||
t.Error("Field 'code' should be a number")
|
||||
}
|
||||
|
||||
if _, ok := jsonMap["msg"].(string); !ok {
|
||||
t.Error("Field 'msg' should be a string")
|
||||
}
|
||||
|
||||
if _, ok := jsonMap["timestamp"].(string); !ok {
|
||||
t.Error("Field 'timestamp' should be a string")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultipleResponses 测试多个连续响应(T036)
|
||||
func TestMultipleResponses(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
callCount := 0
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
callCount++
|
||||
if callCount%2 == 0 {
|
||||
return Success(c, map[string]int{"count": callCount})
|
||||
}
|
||||
return Error(c, 500, errors.CodeInternalError, "error occurred")
|
||||
})
|
||||
|
||||
// 发送多个请求
|
||||
for i := 1; i <= 5; i++ {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request %d failed: %v", i, err)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
var response Response
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
t.Fatalf("Request %d: failed to unmarshal response: %v", i, err)
|
||||
}
|
||||
|
||||
// 验证每个响应都有时间戳
|
||||
if response.Timestamp == "" {
|
||||
t.Errorf("Request %d: timestamp should not be empty", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestTimestampFormat 测试时间戳格式(T036)
|
||||
func TestTimestampFormat(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return Success(c, nil)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
var response Response
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
// 验证是 RFC3339 格式
|
||||
parsedTime, err := time.Parse(time.RFC3339, response.Timestamp)
|
||||
if err != nil {
|
||||
t.Fatalf("Timestamp is not in RFC3339 format: %s, error: %v", response.Timestamp, err)
|
||||
}
|
||||
|
||||
// 验证时间戳是最近的(应该在最近 1 秒内)
|
||||
now := time.Now()
|
||||
diff := now.Sub(parsedTime)
|
||||
if diff < 0 || diff > time.Second {
|
||||
t.Errorf("Timestamp seems incorrect: %s (diff from now: %v)", response.Timestamp, diff)
|
||||
}
|
||||
}
|
||||
@@ -68,17 +68,17 @@
|
||||
|
||||
### Unit Tests for User Story 1
|
||||
|
||||
- [ ] T018 [P] [US1] Unit test for config loading and validation in pkg/config/loader_test.go
|
||||
- [ ] T019 [P] [US1] Unit test for config hot reload mechanism in pkg/config/watcher_test.go
|
||||
- [ ] T020 [P] [US1] Test invalid config handling (malformed YAML, validation errors) in pkg/config/config_test.go
|
||||
- [X] T018 [P] [US1] Unit test for config loading and validation in pkg/config/loader_test.go
|
||||
- [X] T019 [P] [US1] Unit test for config hot reload mechanism in pkg/config/watcher_test.go
|
||||
- [X] T020 [P] [US1] Test invalid config handling (malformed YAML, validation errors) in pkg/config/config_test.go
|
||||
|
||||
### Implementation for User Story 1
|
||||
|
||||
- [ ] T021 [US1] Implement atomic config pointer swap in pkg/config/config.go (sync/atomic usage)
|
||||
- [ ] T022 [US1] Implement config change callback with validation in pkg/config/watcher.go
|
||||
- [ ] T023 [US1] Add config reload logging with Zap in pkg/config/watcher.go
|
||||
- [ ] T024 [US1] Integrate config watcher with context cancellation in cmd/api/main.go
|
||||
- [ ] T025 [US1] Add graceful shutdown for config watcher in cmd/api/main.go
|
||||
- [X] T021 [US1] Implement atomic config pointer swap in pkg/config/config.go (sync/atomic usage)
|
||||
- [X] T022 [US1] Implement config change callback with validation in pkg/config/watcher.go
|
||||
- [X] T023 [US1] Add config reload logging with Zap in pkg/config/watcher.go
|
||||
- [X] T024 [US1] Integrate config watcher with context cancellation in cmd/api/main.go
|
||||
- [X] T025 [US1] Add graceful shutdown for config watcher in cmd/api/main.go
|
||||
|
||||
**Checkpoint**: Config hot reload should work independently - modify config and see changes applied
|
||||
|
||||
@@ -92,17 +92,17 @@
|
||||
|
||||
### Unit Tests for User Story 2
|
||||
|
||||
- [ ] T026 [P] [US2] Unit test for logger initialization in pkg/logger/logger_test.go
|
||||
- [ ] T027 [P] [US2] Unit test for log rotation configuration in pkg/logger/rotation_test.go
|
||||
- [ ] T028 [P] [US2] Test structured logging with fields in pkg/logger/logger_test.go
|
||||
- [X] T026 [P] [US2] Unit test for logger initialization in pkg/logger/logger_test.go
|
||||
- [X] T027 [P] [US2] Unit test for log rotation configuration in pkg/logger/rotation_test.go
|
||||
- [X] T028 [P] [US2] Test structured logging with fields in pkg/logger/logger_test.go
|
||||
|
||||
### Implementation for User Story 2
|
||||
|
||||
- [ ] T029 [P] [US2] Create appLogger instance with Lumberjack writer in pkg/logger/logger.go
|
||||
- [ ] T030 [P] [US2] Create accessLogger instance with separate Lumberjack writer in pkg/logger/logger.go
|
||||
- [ ] T031 [US2] Configure JSON encoder with RFC3339 timestamps in pkg/logger/logger.go
|
||||
- [ ] T032 [US2] Export GetAppLogger() and GetAccessLogger() functions in pkg/logger/logger.go
|
||||
- [ ] T033 [US2] Add logger.Sync() call in graceful shutdown in cmd/api/main.go
|
||||
- [X] T029 [P] [US2] Create appLogger instance with Lumberjack writer in pkg/logger/logger.go
|
||||
- [X] T030 [P] [US2] Create accessLogger instance with separate Lumberjack writer in pkg/logger/logger.go
|
||||
- [X] T031 [US2] Configure JSON encoder with RFC3339 timestamps in pkg/logger/logger.go
|
||||
- [X] T032 [US2] Export GetAppLogger() and GetAccessLogger() functions in pkg/logger/logger.go
|
||||
- [X] T033 [US2] Add logger.Sync() call in graceful shutdown in cmd/api/main.go
|
||||
|
||||
**Checkpoint**: Both app.log and access.log should exist with JSON entries, rotate at configured size
|
||||
|
||||
@@ -116,18 +116,18 @@
|
||||
|
||||
### Unit Tests for User Story 3
|
||||
|
||||
- [ ] T034 [P] [US3] Unit test for Success() response helper in pkg/response/response_test.go
|
||||
- [ ] T035 [P] [US3] Unit test for Error() response helper in pkg/response/response_test.go
|
||||
- [ ] T036 [P] [US3] Test response serialization with sonic JSON in pkg/response/response_test.go
|
||||
- [X] T034 [P] [US3] Unit test for Success() response helper in pkg/response/response_test.go
|
||||
- [X] T035 [P] [US3] Unit test for Error() response helper in pkg/response/response_test.go
|
||||
- [X] T036 [P] [US3] Test response serialization with sonic JSON in pkg/response/response_test.go
|
||||
|
||||
### Implementation for User Story 3
|
||||
|
||||
- [ ] T037 [P] [US3] Implement Success() helper function in pkg/response/response.go
|
||||
- [ ] T038 [P] [US3] Implement Error() helper function in pkg/response/response.go
|
||||
- [ ] T039 [P] [US3] Implement SuccessWithMessage() helper function in pkg/response/response.go
|
||||
- [ ] T040 [US3] Configure Fiber to use sonic as JSON serializer in cmd/api/main.go
|
||||
- [ ] T041 [US3] Create example health check endpoint using response helpers in internal/handler/health.go
|
||||
- [ ] T042 [US3] Register health check route in cmd/api/main.go
|
||||
- [X] T037 [P] [US3] Implement Success() helper function in pkg/response/response.go
|
||||
- [X] T038 [P] [US3] Implement Error() helper function in pkg/response/response.go
|
||||
- [X] T039 [P] [US3] Implement SuccessWithMessage() helper function in pkg/response/response.go
|
||||
- [X] T040 [US3] Configure Fiber to use sonic as JSON serializer in cmd/api/main.go
|
||||
- [X] T041 [US3] Create example health check endpoint using response helpers in internal/handler/health.go
|
||||
- [X] T042 [US3] Register health check route in cmd/api/main.go
|
||||
|
||||
**Checkpoint**: Health check endpoint returns unified response format with proper structure
|
||||
|
||||
@@ -141,18 +141,18 @@
|
||||
|
||||
### Integration Tests for User Story 4
|
||||
|
||||
- [ ] T043 [P] [US4] Integration test for requestid middleware (UUID v4 generation) in tests/integration/middleware_test.go
|
||||
- [ ] T044 [P] [US4] Integration test for logger middleware (access log entries) in tests/integration/middleware_test.go
|
||||
- [ ] T045 [P] [US4] Test request ID propagation through middleware chain in tests/integration/middleware_test.go
|
||||
- [X] T043 [P] [US4] Integration test for requestid middleware (UUID v4 generation) in tests/integration/middleware_test.go
|
||||
- [X] T044 [P] [US4] Integration test for logger middleware (access log entries) in tests/integration/middleware_test.go
|
||||
- [X] T045 [P] [US4] Test request ID propagation through middleware chain in tests/integration/middleware_test.go
|
||||
|
||||
### Implementation for User Story 4
|
||||
|
||||
- [ ] T046 [P] [US4] Configure Fiber requestid middleware with google/uuid in cmd/api/main.go
|
||||
- [ ] T047 [US4] Implement custom logger middleware writing to accessLogger in internal/middleware/logger.go
|
||||
- [ ] T048 [US4] Add request ID to Fiber Locals in logger middleware in internal/middleware/logger.go
|
||||
- [ ] T049 [US4] Add X-Request-ID response header in logger middleware in internal/middleware/logger.go
|
||||
- [ ] T050 [US4] Log request details (method, path, status, duration, IP, user_agent) to access.log in internal/middleware/logger.go
|
||||
- [ ] T051 [US4] Register requestid and logger middleware in correct order in cmd/api/main.go
|
||||
- [X] T046 [P] [US4] Configure Fiber requestid middleware with google/uuid in cmd/api/main.go
|
||||
- [X] T047 [US4] Implement custom logger middleware writing to accessLogger in internal/middleware/logger.go
|
||||
- [X] T048 [US4] Add request ID to Fiber Locals in logger middleware in internal/middleware/logger.go
|
||||
- [X] T049 [US4] Add X-Request-ID response header in logger middleware in internal/middleware/logger.go
|
||||
- [X] T050 [US4] Log request details (method, path, status, duration, IP, user_agent) to access.log in internal/middleware/logger.go
|
||||
- [X] T051 [US4] Register requestid and logger middleware in correct order in cmd/api/main.go
|
||||
|
||||
**Checkpoint**: Every request should have unique UUID v4 in header and access.log, with full request details
|
||||
|
||||
@@ -166,17 +166,17 @@
|
||||
|
||||
### Integration Tests for User Story 5
|
||||
|
||||
- [ ] T052 [P] [US5] Integration test for panic recovery in tests/integration/middleware_test.go
|
||||
- [ ] T053 [P] [US5] Test panic logging with stack trace in tests/integration/middleware_test.go
|
||||
- [ ] T054 [P] [US5] Test subsequent requests after panic recovery in tests/integration/middleware_test.go
|
||||
- [X] T052 [P] [US5] Integration test for panic recovery in tests/integration/recover_test.go
|
||||
- [X] T053 [P] [US5] Test panic logging with stack trace in tests/integration/recover_test.go
|
||||
- [X] T054 [P] [US5] Test subsequent requests after panic recovery in tests/integration/recover_test.go
|
||||
|
||||
### Implementation for User Story 5
|
||||
|
||||
- [ ] T055 [US5] Implement custom recover middleware with Zap logging in internal/middleware/recover.go
|
||||
- [ ] T056 [US5] Add stack trace capture to recover middleware in internal/middleware/recover.go
|
||||
- [ ] T057 [US5] Add request ID to panic logs in internal/middleware/recover.go
|
||||
- [ ] T058 [US5] Return unified error response (500, code 1000) on panic in internal/middleware/recover.go
|
||||
- [ ] T059 [US5] Register recover middleware as FIRST middleware in cmd/api/main.go
|
||||
- [X] T055 [US5] Implement custom recover middleware with Zap logging in internal/middleware/recover.go
|
||||
- [X] T056 [US5] Add stack trace capture to recover middleware in internal/middleware/recover.go
|
||||
- [X] T057 [US5] Add request ID to panic logs in internal/middleware/recover.go
|
||||
- [X] T058 [US5] Return unified error response (500, code 1000) on panic in internal/middleware/recover.go
|
||||
- [X] T059 [US5] Register recover middleware as FIRST middleware in cmd/api/main.go
|
||||
- [ ] T060 [US5] Create test panic endpoint for testing in internal/handler/test.go (optional, for quickstart validation)
|
||||
|
||||
**Checkpoint**: Panic in handler should be caught, logged with stack trace, return 500, server continues running
|
||||
@@ -205,19 +205,19 @@
|
||||
|
||||
### Implementation for User Story 6
|
||||
|
||||
- [ ] T069 [US6] Create TokenValidator struct with Redis client in pkg/validator/token.go
|
||||
- [ ] T070 [US6] Implement TokenValidator.Validate() with Redis GET operation in pkg/validator/token.go
|
||||
- [ ] T071 [US6] Add context timeout (50ms) for Redis operations in pkg/validator/token.go
|
||||
- [ ] T072 [US6] Implement Redis availability check (Ping) with fail-closed behavior in pkg/validator/token.go
|
||||
- [ ] T073 [US6] Implement custom keyauth middleware wrapper in internal/middleware/auth.go
|
||||
- [ ] T074 [US6] Configure keyauth with header lookup "token" in internal/middleware/auth.go
|
||||
- [ ] T075 [US6] Add validator callback to keyauth config in internal/middleware/auth.go
|
||||
- [ ] T076 [US6] Store user_id in Fiber Locals after successful validation in internal/middleware/auth.go
|
||||
- [ ] T077 [US6] Implement custom ErrorHandler mapping errors to response codes in internal/middleware/auth.go
|
||||
- [ ] T078 [US6] Add auth failure logging with request ID in internal/middleware/auth.go
|
||||
- [ ] T079 [US6] Register keyauth middleware after logger in cmd/api/main.go
|
||||
- [ ] T080 [US6] Create protected example endpoint (/api/v1/users) in internal/handler/user.go
|
||||
- [ ] T081 [US6] Register protected routes with middleware in cmd/api/main.go
|
||||
- [X] T069 [US6] Create TokenValidator struct with Redis client in pkg/validator/token.go
|
||||
- [X] T070 [US6] Implement TokenValidator.Validate() with Redis GET operation in pkg/validator/token.go
|
||||
- [X] T071 [US6] Add context timeout (50ms) for Redis operations in pkg/validator/token.go
|
||||
- [X] T072 [US6] Implement Redis availability check (Ping) with fail-closed behavior in pkg/validator/token.go
|
||||
- [X] T073 [US6] Implement custom keyauth middleware wrapper in internal/middleware/auth.go
|
||||
- [X] T074 [US6] Configure keyauth with header lookup "token" in internal/middleware/auth.go
|
||||
- [X] T075 [US6] Add validator callback to keyauth config in internal/middleware/auth.go
|
||||
- [X] T076 [US6] Store user_id in Fiber Locals after successful validation in internal/middleware/auth.go
|
||||
- [X] T077 [US6] Implement custom ErrorHandler mapping errors to response codes in internal/middleware/auth.go
|
||||
- [X] T078 [US6] Add auth failure logging with request ID in internal/middleware/auth.go
|
||||
- [X] T079 [US6] Register keyauth middleware after logger in cmd/api/main.go
|
||||
- [X] T080 [US6] Create protected example endpoint (/api/v1/users) in internal/handler/user.go
|
||||
- [X] T081 [US6] Register protected routes with middleware in cmd/api/main.go
|
||||
|
||||
**Checkpoint**: Protected endpoints require valid token, reject invalid/missing tokens with correct error codes
|
||||
|
||||
@@ -237,11 +237,11 @@
|
||||
|
||||
### Implementation for User Story 7
|
||||
|
||||
- [ ] T085 [US7] Implement rate limiter middleware wrapper (COMMENTED by default) in internal/middleware/ratelimit.go
|
||||
- [ ] T086 [US7] Configure limiter with IP-based key generator (c.IP()) in internal/middleware/ratelimit.go
|
||||
- [ ] T087 [US7] Configure limiter with config values (Max, Expiration) in internal/middleware/ratelimit.go
|
||||
- [ ] T088 [US7] Add custom LimitReached handler returning unified error response in internal/middleware/ratelimit.go
|
||||
- [ ] T089 [US7] Add commented middleware registration example in cmd/api/main.go
|
||||
- [X] T085 [US7] Implement rate limiter middleware wrapper (COMMENTED by default) in internal/middleware/ratelimit.go
|
||||
- [X] T086 [US7] Configure limiter with IP-based key generator (c.IP()) in internal/middleware/ratelimit.go
|
||||
- [X] T087 [US7] Configure limiter with config values (Max, Expiration) in internal/middleware/ratelimit.go
|
||||
- [X] T088 [US7] Add custom LimitReached handler returning unified error response in internal/middleware/ratelimit.go
|
||||
- [X] T089 [US7] Add commented middleware registration example in cmd/api/main.go
|
||||
- [ ] T090 [US7] Document rate limiter usage in quickstart.md (how to enable, configure)
|
||||
- [ ] T091 [US7] Add rate limiter configuration examples to config files
|
||||
|
||||
|
||||
533
tests/integration/middleware_test.go
Normal file
533
tests/integration/middleware_test.go
Normal file
@@ -0,0 +1,533 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/logger"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/requestid"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestRequestIDMiddleware 测试 RequestID 中间件生成 UUID v4(T043)
|
||||
func TestRequestIDMiddleware(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
// 配置 requestid 中间件使用 UUID v4
|
||||
app.Use(requestid.New(requestid.Config{
|
||||
Generator: func() string {
|
||||
return uuid.NewString()
|
||||
},
|
||||
}))
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
requestID := c.Locals(constants.ContextKeyRequestID)
|
||||
return c.JSON(fiber.Map{
|
||||
"request_id": requestID,
|
||||
})
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
}{
|
||||
{name: "request 1"},
|
||||
{name: "request 2"},
|
||||
{name: "request 3"},
|
||||
}
|
||||
|
||||
seenIDs := make(map[string]bool)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 验证响应头包含 X-Request-ID
|
||||
requestID := resp.Header.Get("X-Request-ID")
|
||||
if requestID == "" {
|
||||
t.Error("X-Request-ID header should not be empty")
|
||||
}
|
||||
|
||||
// 验证是 UUID v4 格式
|
||||
if _, err := uuid.Parse(requestID); err != nil {
|
||||
t.Errorf("X-Request-ID is not a valid UUID: %s, error: %v", requestID, err)
|
||||
}
|
||||
|
||||
// 验证 UUID 是唯一的
|
||||
if seenIDs[requestID] {
|
||||
t.Errorf("Request ID %s is not unique", requestID)
|
||||
}
|
||||
seenIDs[requestID] = true
|
||||
|
||||
t.Logf("Request ID: %s", requestID)
|
||||
})
|
||||
}
|
||||
|
||||
// 验证生成了多个不同的 ID
|
||||
if len(seenIDs) != len(tests) {
|
||||
t.Errorf("Expected %d unique request IDs, got %d", len(tests), len(seenIDs))
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoggerMiddleware 测试 Logger 中间件记录访问日志(T044)
|
||||
func TestLoggerMiddleware(t *testing.T) {
|
||||
// 创建临时目录用于日志
|
||||
tempDir := t.TempDir()
|
||||
accessLogFile := filepath.Join(tempDir, "access.log")
|
||||
|
||||
// 初始化日志系统
|
||||
err := logger.InitLoggers("info", false,
|
||||
logger.LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
logger.LogRotationConfig{
|
||||
Filename: accessLogFile,
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize loggers: %v", err)
|
||||
}
|
||||
defer logger.Sync()
|
||||
|
||||
// 创建应用
|
||||
app := fiber.New()
|
||||
|
||||
// 注册中间件
|
||||
app.Use(requestid.New(requestid.Config{
|
||||
Generator: func() string {
|
||||
return uuid.NewString()
|
||||
},
|
||||
}))
|
||||
app.Use(logger.Middleware())
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendString("ok")
|
||||
})
|
||||
|
||||
app.Post("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(201)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "GET request",
|
||||
method: "GET",
|
||||
path: "/test",
|
||||
expectedStatus: 200,
|
||||
},
|
||||
{
|
||||
name: "POST request",
|
||||
method: "POST",
|
||||
path: "/test",
|
||||
expectedStatus: 201,
|
||||
},
|
||||
{
|
||||
name: "GET with query params",
|
||||
method: "GET",
|
||||
path: "/test?foo=bar",
|
||||
expectedStatus: 200,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tt.method, tt.path, nil)
|
||||
req.Header.Set("User-Agent", "test-agent/1.0")
|
||||
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 刷新日志缓冲区
|
||||
logger.Sync()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证访问日志文件存在且有内容
|
||||
content, err := os.ReadFile(accessLogFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read access log: %v", err)
|
||||
}
|
||||
|
||||
if len(content) == 0 {
|
||||
t.Error("Access log should not be empty")
|
||||
}
|
||||
|
||||
logContent := string(content)
|
||||
t.Logf("Access log content:\n%s", logContent)
|
||||
|
||||
// 验证日志包含必要的字段
|
||||
requiredFields := []string{
|
||||
"method",
|
||||
"path",
|
||||
"status",
|
||||
"duration_ms",
|
||||
"request_id",
|
||||
"ip",
|
||||
"user_agent",
|
||||
}
|
||||
|
||||
for _, field := range requiredFields {
|
||||
if !strings.Contains(logContent, field) {
|
||||
t.Errorf("Access log should contain field '%s'", field)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证记录了所有请求
|
||||
lines := strings.Split(strings.TrimSpace(logContent), "\n")
|
||||
if len(lines) < len(tests) {
|
||||
t.Errorf("Expected at least %d log entries, got %d", len(tests), len(lines))
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequestIDPropagation 测试 Request ID 在中间件链中传播(T045)
|
||||
func TestRequestIDPropagation(t *testing.T) {
|
||||
// 创建临时目录用于日志
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// 初始化日志系统
|
||||
err := logger.InitLoggers("info", false,
|
||||
logger.LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
logger.LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize loggers: %v", err)
|
||||
}
|
||||
defer logger.Sync()
|
||||
|
||||
// 创建应用
|
||||
app := fiber.New()
|
||||
|
||||
var capturedRequestID string
|
||||
|
||||
// 1. RequestID 中间件(第一个)
|
||||
app.Use(requestid.New(requestid.Config{
|
||||
Generator: func() string {
|
||||
return uuid.NewString()
|
||||
},
|
||||
}))
|
||||
|
||||
// 2. Logger 中间件(第二个)
|
||||
app.Use(logger.Middleware())
|
||||
|
||||
// 3. 自定义中间件验证 request ID 是否可访问
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
requestID := c.Locals(constants.ContextKeyRequestID)
|
||||
if requestID == nil {
|
||||
t.Error("Request ID should be available in middleware chain")
|
||||
}
|
||||
if rid, ok := requestID.(string); ok {
|
||||
capturedRequestID = rid
|
||||
}
|
||||
return c.Next()
|
||||
})
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
// 在 handler 中也验证 request ID
|
||||
requestID := c.Locals(constants.ContextKeyRequestID)
|
||||
if requestID == nil {
|
||||
return c.Status(500).SendString("Request ID not found in handler")
|
||||
}
|
||||
|
||||
return c.JSON(fiber.Map{
|
||||
"request_id": requestID,
|
||||
"message": "Request ID propagated successfully",
|
||||
})
|
||||
})
|
||||
|
||||
// 执行请求
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 验证响应
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 验证响应头中的 Request ID
|
||||
headerRequestID := resp.Header.Get("X-Request-ID")
|
||||
if headerRequestID == "" {
|
||||
t.Error("X-Request-ID header should be set")
|
||||
}
|
||||
|
||||
// 验证中间件捕获的 Request ID 与响应头一致
|
||||
if capturedRequestID != headerRequestID {
|
||||
t.Errorf("Request ID mismatch: middleware=%s, header=%s", capturedRequestID, headerRequestID)
|
||||
}
|
||||
|
||||
// 验证响应体
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(body), headerRequestID) {
|
||||
t.Errorf("Response body should contain request ID %s", headerRequestID)
|
||||
}
|
||||
|
||||
t.Logf("Request ID successfully propagated: %s", headerRequestID)
|
||||
}
|
||||
|
||||
// TestMiddlewareOrder 测试中间件执行顺序(T045)
|
||||
func TestMiddlewareOrder(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
executionOrder := []string{}
|
||||
|
||||
// 中间件 1: RequestID
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
executionOrder = append(executionOrder, "requestid-start")
|
||||
c.Locals(constants.ContextKeyRequestID, uuid.NewString())
|
||||
err := c.Next()
|
||||
executionOrder = append(executionOrder, "requestid-end")
|
||||
return err
|
||||
})
|
||||
|
||||
// 中间件 2: Logger
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
executionOrder = append(executionOrder, "logger-start")
|
||||
// 验证 Request ID 已经设置
|
||||
if c.Locals(constants.ContextKeyRequestID) == nil {
|
||||
t.Error("Request ID should be set before logger middleware")
|
||||
}
|
||||
err := c.Next()
|
||||
executionOrder = append(executionOrder, "logger-end")
|
||||
return err
|
||||
})
|
||||
|
||||
// 中间件 3: Custom
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
executionOrder = append(executionOrder, "custom-start")
|
||||
err := c.Next()
|
||||
executionOrder = append(executionOrder, "custom-end")
|
||||
return err
|
||||
})
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
executionOrder = append(executionOrder, "handler")
|
||||
return c.SendString("ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
// 验证执行顺序
|
||||
expectedOrder := []string{
|
||||
"requestid-start",
|
||||
"logger-start",
|
||||
"custom-start",
|
||||
"handler",
|
||||
"custom-end",
|
||||
"logger-end",
|
||||
"requestid-end",
|
||||
}
|
||||
|
||||
if len(executionOrder) != len(expectedOrder) {
|
||||
t.Errorf("Expected %d execution steps, got %d", len(expectedOrder), len(executionOrder))
|
||||
}
|
||||
|
||||
for i, expected := range expectedOrder {
|
||||
if i >= len(executionOrder) {
|
||||
t.Errorf("Missing execution step at index %d: expected '%s'", i, expected)
|
||||
continue
|
||||
}
|
||||
if executionOrder[i] != expected {
|
||||
t.Errorf("Execution order mismatch at index %d: expected '%s', got '%s'", i, expected, executionOrder[i])
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Middleware execution order: %v", executionOrder)
|
||||
}
|
||||
|
||||
// TestLoggerMiddlewareWithUserID 测试 Logger 中间件记录用户 ID(T044)
|
||||
func TestLoggerMiddlewareWithUserID(t *testing.T) {
|
||||
// 创建临时目录用于日志
|
||||
tempDir := t.TempDir()
|
||||
accessLogFile := filepath.Join(tempDir, "access-userid.log")
|
||||
|
||||
// 初始化日志系统
|
||||
err := logger.InitLoggers("info", false,
|
||||
logger.LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
logger.LogRotationConfig{
|
||||
Filename: accessLogFile,
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize loggers: %v", err)
|
||||
}
|
||||
defer logger.Sync()
|
||||
|
||||
// 创建应用
|
||||
app := fiber.New()
|
||||
|
||||
// 注册中间件
|
||||
app.Use(requestid.New(requestid.Config{
|
||||
Generator: func() string {
|
||||
return uuid.NewString()
|
||||
},
|
||||
}))
|
||||
|
||||
// 模拟 auth 中间件设置 user_id
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
c.Locals(constants.ContextKeyUserID, "user_12345")
|
||||
return c.Next()
|
||||
})
|
||||
|
||||
app.Use(logger.Middleware())
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return c.SendString("ok")
|
||||
})
|
||||
|
||||
// 执行请求
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
// 刷新日志缓冲区
|
||||
logger.Sync()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证访问日志包含 user_id
|
||||
content, err := os.ReadFile(accessLogFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read access log: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(content)
|
||||
if !strings.Contains(logContent, "user_12345") {
|
||||
t.Error("Access log should contain user_id 'user_12345'")
|
||||
}
|
||||
|
||||
t.Logf("Access log with user_id:\n%s", logContent)
|
||||
}
|
||||
|
||||
// TestConcurrentRequests 测试并发请求的 Request ID 唯一性(T043)
|
||||
func TestConcurrentRequests(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(requestid.New(requestid.Config{
|
||||
Generator: func() string {
|
||||
return uuid.NewString()
|
||||
},
|
||||
}))
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
// 模拟一些处理时间
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
requestID := c.Locals(constants.ContextKeyRequestID)
|
||||
return c.JSON(fiber.Map{
|
||||
"request_id": requestID,
|
||||
})
|
||||
})
|
||||
|
||||
// 并发发送多个请求
|
||||
const numRequests = 50
|
||||
requestIDs := make(chan string, numRequests)
|
||||
errors := make(chan error, numRequests)
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func() {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
requestID := resp.Header.Get("X-Request-ID")
|
||||
requestIDs <- requestID
|
||||
errors <- nil
|
||||
}()
|
||||
}
|
||||
|
||||
// 收集所有结果
|
||||
seenIDs := make(map[string]bool)
|
||||
for i := 0; i < numRequests; i++ {
|
||||
if err := <-errors; err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
requestID := <-requestIDs
|
||||
|
||||
if requestID == "" {
|
||||
t.Error("Request ID should not be empty")
|
||||
}
|
||||
|
||||
if seenIDs[requestID] {
|
||||
t.Errorf("Duplicate request ID found: %s", requestID)
|
||||
}
|
||||
seenIDs[requestID] = true
|
||||
}
|
||||
|
||||
// 验证所有 ID 都是唯一的
|
||||
if len(seenIDs) != numRequests {
|
||||
t.Errorf("Expected %d unique request IDs, got %d", numRequests, len(seenIDs))
|
||||
}
|
||||
|
||||
t.Logf("Successfully generated %d unique request IDs concurrently", len(seenIDs))
|
||||
}
|
||||
618
tests/integration/recover_test.go
Normal file
618
tests/integration/recover_test.go
Normal file
@@ -0,0 +1,618 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/internal/middleware"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/errors"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/logger"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/response"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/requestid"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestPanicRecovery 测试 panic 恢复功能(T052)
|
||||
func TestPanicRecovery(t *testing.T) {
|
||||
// 创建临时目录用于日志
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// 初始化日志系统
|
||||
err := logger.InitLoggers("info", false,
|
||||
logger.LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-panic.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
logger.LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-panic.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize loggers: %v", err)
|
||||
}
|
||||
defer logger.Sync()
|
||||
|
||||
appLogger := logger.GetAppLogger()
|
||||
|
||||
// 创建应用
|
||||
app := fiber.New()
|
||||
|
||||
// 注册中间件(recover 必须第一个)
|
||||
app.Use(middleware.Recover(appLogger))
|
||||
app.Use(requestid.New(requestid.Config{
|
||||
Generator: func() string {
|
||||
return uuid.NewString()
|
||||
},
|
||||
}))
|
||||
|
||||
// 创建会 panic 的 handler
|
||||
app.Get("/panic", func(c *fiber.Ctx) error {
|
||||
panic("intentional panic for testing")
|
||||
})
|
||||
|
||||
// 创建正常的 handler
|
||||
app.Get("/ok", func(c *fiber.Ctx) error {
|
||||
return c.SendString("ok")
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
shouldPanic bool
|
||||
expectedStatus int
|
||||
expectedCode int
|
||||
}{
|
||||
{
|
||||
name: "panic endpoint returns 500",
|
||||
path: "/panic",
|
||||
shouldPanic: true,
|
||||
expectedStatus: 500,
|
||||
expectedCode: errors.CodeInternalError,
|
||||
},
|
||||
{
|
||||
name: "normal endpoint works after panic",
|
||||
path: "/ok",
|
||||
shouldPanic: false,
|
||||
expectedStatus: 200,
|
||||
expectedCode: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", tt.path, nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 验证 HTTP 状态码
|
||||
if resp.StatusCode != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.StatusCode)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
if tt.shouldPanic {
|
||||
// panic 应该返回统一错误响应
|
||||
var response response.Response
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if response.Code != tt.expectedCode {
|
||||
t.Errorf("Expected code %d, got %d", tt.expectedCode, response.Code)
|
||||
}
|
||||
|
||||
if response.Data != nil {
|
||||
t.Error("Error response data should be nil")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPanicLogging 测试 panic 日志记录和堆栈跟踪(T053)
|
||||
func TestPanicLogging(t *testing.T) {
|
||||
// 创建临时目录用于日志
|
||||
tempDir := t.TempDir()
|
||||
appLogFile := filepath.Join(tempDir, "app-panic-log.log")
|
||||
|
||||
// 初始化日志系统
|
||||
err := logger.InitLoggers("info", false,
|
||||
logger.LogRotationConfig{
|
||||
Filename: appLogFile,
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
logger.LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-panic-log.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize loggers: %v", err)
|
||||
}
|
||||
defer logger.Sync()
|
||||
|
||||
appLogger := logger.GetAppLogger()
|
||||
|
||||
// 创建应用
|
||||
app := fiber.New()
|
||||
|
||||
// 注册中间件
|
||||
app.Use(middleware.Recover(appLogger))
|
||||
app.Use(requestid.New(requestid.Config{
|
||||
Generator: func() string {
|
||||
return uuid.NewString()
|
||||
},
|
||||
}))
|
||||
|
||||
// 创建不同类型的 panic
|
||||
app.Get("/panic-string", func(c *fiber.Ctx) error {
|
||||
panic("string panic message")
|
||||
})
|
||||
|
||||
app.Get("/panic-error", func(c *fiber.Ctx) error {
|
||||
panic(fiber.NewError(500, "error panic message"))
|
||||
})
|
||||
|
||||
app.Get("/panic-struct", func(c *fiber.Ctx) error {
|
||||
panic(struct{ Message string }{"struct panic message"})
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expectedInLog []string
|
||||
unexpectedInLog []string
|
||||
}{
|
||||
{
|
||||
name: "string panic logs correctly",
|
||||
path: "/panic-string",
|
||||
expectedInLog: []string{
|
||||
"Panic 已恢复",
|
||||
"string panic message",
|
||||
"stack",
|
||||
"request_id",
|
||||
"method",
|
||||
"path",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "error panic logs correctly",
|
||||
path: "/panic-error",
|
||||
expectedInLog: []string{
|
||||
"Panic 已恢复",
|
||||
"error panic message",
|
||||
"stack",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "struct panic logs correctly",
|
||||
path: "/panic-struct",
|
||||
expectedInLog: []string{
|
||||
"Panic 已恢复",
|
||||
"stack",
|
||||
"Message",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 执行会 panic 的请求
|
||||
req := httptest.NewRequest("GET", tt.path, nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
// 刷新日志缓冲区
|
||||
logger.Sync()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 读取日志内容
|
||||
logContent, err := os.ReadFile(appLogFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read app log: %v", err)
|
||||
}
|
||||
|
||||
content := string(logContent)
|
||||
|
||||
// 验证日志包含预期内容
|
||||
for _, expected := range tt.expectedInLog {
|
||||
if !strings.Contains(content, expected) {
|
||||
t.Errorf("Log should contain '%s'", expected)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证日志不包含意外内容
|
||||
for _, unexpected := range tt.unexpectedInLog {
|
||||
if strings.Contains(content, unexpected) {
|
||||
t.Errorf("Log should NOT contain '%s'", unexpected)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证堆栈跟踪包含文件和行号
|
||||
if !strings.Contains(content, "recover_test.go") {
|
||||
t.Error("Stack trace should contain source file name")
|
||||
}
|
||||
|
||||
t.Logf("Panic log contains stack trace: %v", strings.Contains(content, "stack"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubsequentRequestsAfterPanic 测试 panic 后后续请求正常处理(T054)
|
||||
func TestSubsequentRequestsAfterPanic(t *testing.T) {
|
||||
// 创建临时目录用于日志
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// 初始化日志系统
|
||||
err := logger.InitLoggers("info", false,
|
||||
logger.LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
logger.LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize loggers: %v", err)
|
||||
}
|
||||
defer logger.Sync()
|
||||
|
||||
appLogger := logger.GetAppLogger()
|
||||
|
||||
// 创建应用
|
||||
app := fiber.New()
|
||||
|
||||
// 注册中间件
|
||||
app.Use(middleware.Recover(appLogger))
|
||||
app.Use(requestid.New(requestid.Config{
|
||||
Generator: func() string {
|
||||
return uuid.NewString()
|
||||
},
|
||||
}))
|
||||
|
||||
callCount := 0
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
callCount++
|
||||
// 第 1、3、5 次调用会 panic
|
||||
if callCount%2 == 1 {
|
||||
panic("test panic")
|
||||
}
|
||||
// 第 2、4、6 次调用正常返回
|
||||
return c.JSON(fiber.Map{
|
||||
"call_count": callCount,
|
||||
"status": "ok",
|
||||
})
|
||||
})
|
||||
|
||||
// 执行多次请求,验证 panic 不影响后续请求
|
||||
for i := 1; i <= 6; i++ {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request %d failed: %v", i, err)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if i%2 == 1 {
|
||||
// 奇数次应该返回 500
|
||||
if resp.StatusCode != 500 {
|
||||
t.Errorf("Request %d: expected status 500, got %d", i, resp.StatusCode)
|
||||
}
|
||||
} else {
|
||||
// 偶数次应该返回 200
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("Request %d: expected status 200, got %d", i, resp.StatusCode)
|
||||
}
|
||||
|
||||
// 验证响应内容
|
||||
var response map[string]any
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
t.Fatalf("Request %d: failed to unmarshal response: %v", i, err)
|
||||
}
|
||||
|
||||
if status, ok := response["status"].(string); !ok || status != "ok" {
|
||||
t.Errorf("Request %d: expected status 'ok', got %v", i, response["status"])
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Request %d completed: status=%d", i, resp.StatusCode)
|
||||
}
|
||||
|
||||
// 验证所有 6 次调用都执行了
|
||||
if callCount != 6 {
|
||||
t.Errorf("Expected 6 calls, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPanicWithRequestID 测试 panic 日志包含 Request ID(T053)
|
||||
func TestPanicWithRequestID(t *testing.T) {
|
||||
// 创建临时目录用于日志
|
||||
tempDir := t.TempDir()
|
||||
appLogFile := filepath.Join(tempDir, "app-panic-reqid.log")
|
||||
|
||||
// 初始化日志系统
|
||||
err := logger.InitLoggers("info", false,
|
||||
logger.LogRotationConfig{
|
||||
Filename: appLogFile,
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
logger.LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-panic-reqid.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize loggers: %v", err)
|
||||
}
|
||||
defer logger.Sync()
|
||||
|
||||
appLogger := logger.GetAppLogger()
|
||||
|
||||
// 创建应用
|
||||
app := fiber.New()
|
||||
|
||||
// 注册中间件(顺序重要)
|
||||
app.Use(middleware.Recover(appLogger))
|
||||
app.Use(requestid.New(requestid.Config{
|
||||
Generator: func() string {
|
||||
return uuid.NewString()
|
||||
},
|
||||
}))
|
||||
|
||||
app.Get("/panic", func(c *fiber.Ctx) error {
|
||||
panic("test panic with request id")
|
||||
})
|
||||
|
||||
// 执行请求
|
||||
req := httptest.NewRequest("GET", "/panic", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
// 获取 Request ID
|
||||
requestID := resp.Header.Get("X-Request-ID")
|
||||
if requestID == "" {
|
||||
t.Error("X-Request-ID header should be set even after panic")
|
||||
}
|
||||
|
||||
// 刷新日志缓冲区
|
||||
logger.Sync()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 读取日志内容
|
||||
logContent, err := os.ReadFile(appLogFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read app log: %v", err)
|
||||
}
|
||||
|
||||
content := string(logContent)
|
||||
|
||||
// 验证日志包含 Request ID
|
||||
if !strings.Contains(content, requestID) {
|
||||
t.Errorf("Panic log should contain request ID '%s'", requestID)
|
||||
}
|
||||
|
||||
// 验证日志包含关键字段
|
||||
requiredFields := []string{
|
||||
"request_id",
|
||||
"method",
|
||||
"path",
|
||||
"panic",
|
||||
"stack",
|
||||
}
|
||||
|
||||
for _, field := range requiredFields {
|
||||
if !strings.Contains(content, field) {
|
||||
t.Errorf("Panic log should contain field '%s'", field)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Panic log successfully includes Request ID: %s", requestID)
|
||||
}
|
||||
|
||||
// TestConcurrentPanics 测试并发 panic 处理(T054)
|
||||
func TestConcurrentPanics(t *testing.T) {
|
||||
// 创建临时目录用于日志
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// 初始化日志系统
|
||||
err := logger.InitLoggers("info", false,
|
||||
logger.LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
logger.LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize loggers: %v", err)
|
||||
}
|
||||
defer logger.Sync()
|
||||
|
||||
appLogger := logger.GetAppLogger()
|
||||
|
||||
// 创建应用
|
||||
app := fiber.New()
|
||||
|
||||
// 注册中间件
|
||||
app.Use(middleware.Recover(appLogger))
|
||||
app.Use(requestid.New(requestid.Config{
|
||||
Generator: func() string {
|
||||
return uuid.NewString()
|
||||
},
|
||||
}))
|
||||
|
||||
app.Get("/panic", func(c *fiber.Ctx) error {
|
||||
panic("concurrent panic test")
|
||||
})
|
||||
|
||||
// 并发发送多个会 panic 的请求
|
||||
const numRequests = 20
|
||||
errors := make(chan error, numRequests)
|
||||
statuses := make(chan int, numRequests)
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func() {
|
||||
req := httptest.NewRequest("GET", "/panic", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
statuses <- 0
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
statuses <- resp.StatusCode
|
||||
errors <- nil
|
||||
}()
|
||||
}
|
||||
|
||||
// 收集所有结果
|
||||
for i := 0; i < numRequests; i++ {
|
||||
if err := <-errors; err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
status := <-statuses
|
||||
if status != 500 {
|
||||
t.Errorf("Expected status 500, got %d", status)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Successfully handled %d concurrent panics", numRequests)
|
||||
}
|
||||
|
||||
// TestRecoverMiddlewareOrder 测试 Recover 中间件必须在第一个(T052)
|
||||
func TestRecoverMiddlewareOrder(t *testing.T) {
|
||||
// 创建临时目录用于日志
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// 初始化日志系统
|
||||
err := logger.InitLoggers("info", false,
|
||||
logger.LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
logger.LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize loggers: %v", err)
|
||||
}
|
||||
defer logger.Sync()
|
||||
|
||||
appLogger := logger.GetAppLogger()
|
||||
|
||||
// 创建应用
|
||||
app := fiber.New()
|
||||
|
||||
// 正确的顺序:Recover → RequestID → Logger
|
||||
app.Use(middleware.Recover(appLogger))
|
||||
app.Use(requestid.New(requestid.Config{
|
||||
Generator: func() string {
|
||||
return uuid.NewString()
|
||||
},
|
||||
}))
|
||||
app.Use(logger.Middleware())
|
||||
|
||||
app.Get("/panic", func(c *fiber.Ctx) error {
|
||||
panic("test panic")
|
||||
})
|
||||
|
||||
// 执行请求
|
||||
req := httptest.NewRequest("GET", "/panic", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 验证请求被正确处理(返回 500 而不是崩溃)
|
||||
if resp.StatusCode != 500 {
|
||||
t.Errorf("Expected status 500, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 验证仍然有 Request ID(说明 RequestID 中间件在 Recover 之后执行)
|
||||
requestID := resp.Header.Get("X-Request-ID")
|
||||
if requestID == "" {
|
||||
t.Error("X-Request-ID should be set even after panic")
|
||||
}
|
||||
|
||||
// 解析响应,验证返回了统一错误格式
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
var response response.Response
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if response.Code != errors.CodeInternalError {
|
||||
t.Errorf("Expected code %d, got %d", errors.CodeInternalError, response.Code)
|
||||
}
|
||||
|
||||
t.Logf("Recover middleware correctly placed first, handled panic gracefully")
|
||||
}
|
||||
Reference in New Issue
Block a user