备份一下

This commit is contained in:
2025-11-11 15:53:01 +08:00
parent e98dd4d725
commit 39c5b524a9
10 changed files with 4295 additions and 63 deletions

View File

@@ -3,7 +3,6 @@ package config
import (
"errors"
"fmt"
"strings"
"sync/atomic"
"time"
"unsafe"
@@ -93,8 +92,9 @@ func (c *Config) Validate() error {
if c.Redis.Address == "" {
return fmt.Errorf("invalid configuration: redis.address: must be non-empty (current value: empty)")
}
if !strings.Contains(c.Redis.Address, ":") {
return fmt.Errorf("invalid configuration: redis.address: invalid format (current value: %s, expected: HOST:PORT)", c.Redis.Address)
// Port 验证(独立字段)
if c.Redis.Port <= 0 || c.Redis.Port > 65535 {
return fmt.Errorf("invalid configuration: redis.port: port number out of range (current value: %d, expected: 1-65535)", c.Redis.Port)
}
if c.Redis.DB < 0 || c.Redis.DB > 15 {
return fmt.Errorf("invalid configuration: redis.db: database number out of range (current value: %d, expected: 0-15)", c.Redis.DB)

615
pkg/config/config_test.go Normal file
View File

@@ -0,0 +1,615 @@
package config
import (
"testing"
"time"
)
// TestConfig_Validate tests configuration validation rules
func TestConfig_Validate(t *testing.T) {
tests := []struct {
name string
config *Config
wantErr bool
errMsg string
}{
{
name: "valid config",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
DB: 0,
PoolSize: 10,
MinIdleConns: 5,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
MaxBackups: 30,
MaxAge: 30,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
MaxBackups: 90,
MaxAge: 90,
},
},
Middleware: MiddlewareConfig{
RateLimiter: RateLimiterConfig{
Max: 100,
Expiration: 1 * time.Minute,
Storage: "memory",
},
},
},
wantErr: false,
},
{
name: "empty server address",
config: &Config{
Server: ServerConfig{
Address: "",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "server.address",
},
{
name: "read timeout too short",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 1 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "read_timeout",
},
{
name: "read timeout too long",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 400 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "read_timeout",
},
{
name: "write timeout out of range",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 1 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "write_timeout",
},
{
name: "shutdown timeout too short",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 5 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "shutdown_timeout",
},
{
name: "empty redis address",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "redis.address",
},
{
name: "invalid redis port - too high",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 99999,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "redis.port",
},
{
name: "invalid redis port - zero",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 0,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "redis.port",
},
{
name: "redis db out of range",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
DB: 20,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "redis.db",
},
{
name: "redis pool size too large",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 2000,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "pool_size",
},
{
name: "min idle conns exceeds pool size",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
MinIdleConns: 20,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "min_idle_conns",
},
{
name: "invalid log level",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "invalid",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "logging.level",
},
{
name: "empty app log filename",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "app_log.filename",
},
{
name: "app log max size out of range",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 2000,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
},
wantErr: true,
errMsg: "app_log.max_size",
},
{
name: "invalid rate limiter storage",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
Middleware: MiddlewareConfig{
RateLimiter: RateLimiterConfig{
Max: 100,
Storage: "invalid",
},
},
},
wantErr: true,
errMsg: "rate_limiter.storage",
},
{
name: "rate limiter max too high",
config: &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
Middleware: MiddlewareConfig{
RateLimiter: RateLimiterConfig{
Max: 20000,
Storage: "memory",
},
},
},
wantErr: true,
errMsg: "rate_limiter.max",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr && tt.errMsg != "" {
if err == nil {
t.Errorf("expected error containing %q, got nil", tt.errMsg)
} else if err.Error() == "" {
t.Errorf("expected error containing %q, got empty error", tt.errMsg)
}
// Note: We check that error message exists, not exact match
// This is because error messages might change slightly
}
})
}
}
// TestSet tests the Set function
func TestSet(t *testing.T) {
// Valid config
validCfg := &Config{
Server: ServerConfig{
Address: ":3000",
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ShutdownTimeout: 30 * time.Second,
},
Redis: RedisConfig{
Address: "localhost",
Port: 6379,
PoolSize: 10,
},
Logging: LoggingConfig{
Level: "info",
AppLog: LogRotationConfig{
Filename: "logs/app.log",
MaxSize: 100,
},
AccessLog: LogRotationConfig{
Filename: "logs/access.log",
MaxSize: 500,
},
},
}
err := Set(validCfg)
if err != nil {
t.Errorf("Set() with valid config failed: %v", err)
}
// Verify it was set
got := Get()
if got.Server.Address != ":3000" {
t.Errorf("Get() after Set() returned wrong address: got %s, want :3000", got.Server.Address)
}
// Test with nil config
err = Set(nil)
if err == nil {
t.Error("Set(nil) should return error")
}
// Test with invalid config
invalidCfg := &Config{
Server: ServerConfig{
Address: "", // Empty address is invalid
},
}
err = Set(invalidCfg)
if err == nil {
t.Error("Set() with invalid config should return error")
}
}

661
pkg/config/loader_test.go Normal file
View File

@@ -0,0 +1,661 @@
package config
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/spf13/viper"
)
// TestLoad tests the config loading functionality
func TestLoad(t *testing.T) {
tests := []struct {
name string
setupEnv func()
cleanupEnv func()
createConfig func(t *testing.T) string
wantErr bool
validateFunc func(t *testing.T, cfg *Config)
}{
{
name: "valid default config",
setupEnv: func() {
os.Setenv(constants.EnvConfigPath, "")
os.Setenv(constants.EnvConfigEnv, "")
},
cleanupEnv: func() {
os.Unsetenv(constants.EnvConfigPath)
os.Unsetenv(constants.EnvConfigEnv)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "info"
development: false
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
// Set as default config path
os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: false,
validateFunc: func(t *testing.T, cfg *Config) {
if cfg.Server.Address != ":3000" {
t.Errorf("expected server.address :3000, got %s", cfg.Server.Address)
}
if cfg.Server.ReadTimeout != 10*time.Second {
t.Errorf("expected read_timeout 10s, got %v", cfg.Server.ReadTimeout)
}
if cfg.Redis.Address != "localhost" {
t.Errorf("expected redis.address localhost, got %s", cfg.Redis.Address)
}
if cfg.Redis.Port != 6379 {
t.Errorf("expected redis.port 6379, got %d", cfg.Redis.Port)
}
if cfg.Redis.PoolSize != 10 {
t.Errorf("expected redis.pool_size 10, got %d", cfg.Redis.PoolSize)
}
if cfg.Logging.Level != "info" {
t.Errorf("expected logging.level info, got %s", cfg.Logging.Level)
}
if cfg.Middleware.EnableAuth != true {
t.Errorf("expected enable_auth true, got %v", cfg.Middleware.EnableAuth)
}
},
},
{
name: "environment-specific config (dev)",
setupEnv: func() {
os.Setenv(constants.EnvConfigEnv, "dev")
},
cleanupEnv: func() {
os.Unsetenv(constants.EnvConfigEnv)
os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
// Create configs directory in temp
tmpDir := t.TempDir()
configsDir := filepath.Join(tmpDir, "configs")
if err := os.MkdirAll(configsDir, 0755); err != nil {
t.Fatalf("failed to create configs dir: %v", err)
}
// Create dev config
devConfigFile := filepath.Join(configsDir, "config.dev.yaml")
content := `
server:
address: ":8080"
read_timeout: "15s"
write_timeout: "15s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 1
pool_size: 5
min_idle_conns: 2
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "debug"
development: true
app_log:
filename: "logs/app.log"
max_size: 50
max_backups: 10
max_age: 7
compress: false
access_log:
filename: "logs/access.log"
max_size: 100
max_backups: 30
max_age: 30
compress: false
middleware:
enable_auth: false
enable_rate_limiter: false
rate_limiter:
max: 50
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(devConfigFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create dev config file: %v", err)
}
// Change to tmpDir so relative path works
originalWd, _ := os.Getwd()
os.Chdir(tmpDir)
t.Cleanup(func() { os.Chdir(originalWd) })
return devConfigFile
},
wantErr: false,
validateFunc: func(t *testing.T, cfg *Config) {
if cfg.Server.Address != ":8080" {
t.Errorf("expected server.address :8080, got %s", cfg.Server.Address)
}
if cfg.Redis.DB != 1 {
t.Errorf("expected redis.db 1, got %d", cfg.Redis.DB)
}
if cfg.Logging.Level != "debug" {
t.Errorf("expected logging.level debug, got %s", cfg.Logging.Level)
}
if cfg.Middleware.EnableAuth != false {
t.Errorf("expected enable_auth false, got %v", cfg.Middleware.EnableAuth)
}
},
},
{
name: "invalid YAML syntax",
setupEnv: func() {
os.Setenv(constants.EnvConfigPath, "")
os.Setenv(constants.EnvConfigEnv, "")
},
cleanupEnv: func() {
os.Unsetenv(constants.EnvConfigPath)
os.Unsetenv(constants.EnvConfigEnv)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
invalid yaml syntax here!!!
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
},
{
name: "validation error - invalid server address",
setupEnv: func() {
os.Setenv(constants.EnvConfigPath, "")
},
cleanupEnv: func() {
os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ""
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
},
{
name: "validation error - timeout out of range",
setupEnv: func() {
os.Setenv(constants.EnvConfigPath, "")
},
cleanupEnv: func() {
os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "1s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
},
{
name: "validation error - invalid redis port",
setupEnv: func() {
os.Setenv(constants.EnvConfigPath, "")
},
cleanupEnv: func() {
os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 99999
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset viper for each test
viper.Reset()
// Setup environment
if tt.setupEnv != nil {
tt.setupEnv()
}
// Create config file
if tt.createConfig != nil {
tt.createConfig(t)
}
// Cleanup after test
if tt.cleanupEnv != nil {
defer tt.cleanupEnv()
}
// Load config
cfg, err := Load()
// Check error expectation
if (err != nil) != tt.wantErr {
t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr)
return
}
// Validate config if no error expected
if !tt.wantErr && tt.validateFunc != nil {
tt.validateFunc(t, cfg)
}
})
}
}
// TestReload tests the config reload functionality
func TestReload(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
// Initial config
initialContent := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "info"
development: false
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(initialContent), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
// Set config path
os.Setenv(constants.EnvConfigPath, configFile)
defer os.Unsetenv(constants.EnvConfigPath)
// Load initial config
cfg, err := Load()
if err != nil {
t.Fatalf("failed to load initial config: %v", err)
}
// Verify initial values
if cfg.Logging.Level != "info" {
t.Errorf("expected initial logging.level info, got %s", cfg.Logging.Level)
}
if cfg.Server.Address != ":3000" {
t.Errorf("expected initial server.address :3000, got %s", cfg.Server.Address)
}
// Modify config file
updatedContent := `
server:
address: ":8080"
read_timeout: "15s"
write_timeout: "15s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 20
min_idle_conns: 10
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "debug"
development: true
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: false
enable_rate_limiter: true
rate_limiter:
max: 200
expiration: "2m"
storage: "redis"
`
if err := os.WriteFile(configFile, []byte(updatedContent), 0644); err != nil {
t.Fatalf("failed to update config file: %v", err)
}
// Reload config
newCfg, err := Reload()
if err != nil {
t.Fatalf("failed to reload config: %v", err)
}
// Verify updated values
if newCfg.Logging.Level != "debug" {
t.Errorf("expected updated logging.level debug, got %s", newCfg.Logging.Level)
}
if newCfg.Server.Address != ":8080" {
t.Errorf("expected updated server.address :8080, got %s", newCfg.Server.Address)
}
if newCfg.Redis.PoolSize != 20 {
t.Errorf("expected updated redis.pool_size 20, got %d", newCfg.Redis.PoolSize)
}
if newCfg.Middleware.EnableAuth != false {
t.Errorf("expected updated enable_auth false, got %v", newCfg.Middleware.EnableAuth)
}
if newCfg.Middleware.EnableRateLimiter != true {
t.Errorf("expected updated enable_rate_limiter true, got %v", newCfg.Middleware.EnableRateLimiter)
}
// Verify global config was updated
globalCfg := Get()
if globalCfg.Logging.Level != "debug" {
t.Errorf("expected global config updated, got logging.level %s", globalCfg.Logging.Level)
}
}
// TestGetConfigPath tests the GetConfigPath function
func TestGetConfigPath(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
os.Setenv(constants.EnvConfigPath, configFile)
defer os.Unsetenv(constants.EnvConfigPath)
// Load config
_, err := Load()
if err != nil {
t.Fatalf("failed to load config: %v", err)
}
// Get config path
path := GetConfigPath()
if path == "" {
t.Error("expected non-empty config path")
}
// Verify it's an absolute path
if !filepath.IsAbs(path) {
t.Errorf("expected absolute path, got %s", path)
}
}

422
pkg/config/watcher_test.go Normal file
View File

@@ -0,0 +1,422 @@
package config
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/spf13/viper"
"go.uber.org/zap"
"go.uber.org/zap/zaptest"
)
// TestWatch tests the config hot reload watcher
func TestWatch(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
// Initial config
initialContent := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "info"
development: false
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(initialContent), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
// Set config path
os.Setenv(constants.EnvConfigPath, configFile)
defer os.Unsetenv(constants.EnvConfigPath)
// Load initial config
cfg, err := Load()
if err != nil {
t.Fatalf("failed to load initial config: %v", err)
}
// Verify initial values
if cfg.Logging.Level != "info" {
t.Fatalf("expected initial logging.level info, got %s", cfg.Logging.Level)
}
// Create logger for testing
logger := zaptest.NewLogger(t)
// Start watcher in goroutine with context
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go Watch(ctx, logger)
// Give watcher time to initialize
time.Sleep(100 * time.Millisecond)
// Modify config file to trigger hot reload
updatedContent := `
server:
address: ":8080"
read_timeout: "15s"
write_timeout: "15s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 20
min_idle_conns: 10
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "debug"
development: true
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: false
enable_rate_limiter: true
rate_limiter:
max: 200
expiration: "2m"
storage: "redis"
`
if err := os.WriteFile(configFile, []byte(updatedContent), 0644); err != nil {
t.Fatalf("failed to update config file: %v", err)
}
// Wait for watcher to detect and process changes (spec requires detection within 5 seconds)
// We use a more aggressive timeout for testing
time.Sleep(2 * time.Second)
// Verify config was reloaded
reloadedCfg := Get()
if reloadedCfg.Logging.Level != "debug" {
t.Errorf("expected config hot reload, got logging.level %s instead of debug", reloadedCfg.Logging.Level)
}
if reloadedCfg.Server.Address != ":8080" {
t.Errorf("expected config hot reload, got server.address %s instead of :8080", reloadedCfg.Server.Address)
}
if reloadedCfg.Redis.PoolSize != 20 {
t.Errorf("expected config hot reload, got redis.pool_size %d instead of 20", reloadedCfg.Redis.PoolSize)
}
// Cancel context to stop watcher
cancel()
// Give watcher time to shut down gracefully
time.Sleep(100 * time.Millisecond)
}
// TestWatch_InvalidConfigRejected tests that invalid config changes are rejected
func TestWatch_InvalidConfigRejected(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
// Initial valid config
validContent := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "info"
development: false
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(validContent), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
// Set config path
os.Setenv(constants.EnvConfigPath, configFile)
defer os.Unsetenv(constants.EnvConfigPath)
// Load initial config
cfg, err := Load()
if err != nil {
t.Fatalf("failed to load initial config: %v", err)
}
initialLevel := cfg.Logging.Level
if initialLevel != "info" {
t.Fatalf("expected initial logging.level info, got %s", initialLevel)
}
// Create logger for testing
logger := zaptest.NewLogger(t)
// Start watcher
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go Watch(ctx, logger)
// Give watcher time to initialize
time.Sleep(100 * time.Millisecond)
// Write INVALID config (malformed YAML)
invalidContent := `
server:
address: ":3000"
invalid yaml syntax here!!!
`
if err := os.WriteFile(configFile, []byte(invalidContent), 0644); err != nil {
t.Fatalf("failed to write invalid config: %v", err)
}
// Wait for watcher to detect changes
time.Sleep(2 * time.Second)
// Verify config was NOT changed (should keep previous valid config)
currentCfg := Get()
if currentCfg.Logging.Level != initialLevel {
t.Errorf("expected config to remain unchanged after invalid update, got logging.level %s instead of %s", currentCfg.Logging.Level, initialLevel)
}
// Restore valid config
if err := os.WriteFile(configFile, []byte(validContent), 0644); err != nil {
t.Fatalf("failed to restore valid config: %v", err)
}
time.Sleep(500 * time.Millisecond)
// Now write config with validation error (timeout out of range)
invalidValidationContent := `
server:
address: ":3000"
read_timeout: "1s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "debug"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(invalidValidationContent), 0644); err != nil {
t.Fatalf("failed to write config with validation error: %v", err)
}
// Wait for watcher to detect changes
time.Sleep(2 * time.Second)
// Verify config was NOT changed (validation should have failed)
finalCfg := Get()
if finalCfg.Logging.Level != initialLevel {
t.Errorf("expected config to remain unchanged after validation error, got logging.level %s instead of %s", finalCfg.Logging.Level, initialLevel)
}
if finalCfg.Server.ReadTimeout == 1*time.Second {
t.Error("expected config to remain unchanged, but read_timeout was updated to invalid value")
}
// Cancel context
cancel()
time.Sleep(100 * time.Millisecond)
}
// TestWatch_ContextCancellation tests graceful shutdown on context cancellation
func TestWatch_ContextCancellation(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
os.Setenv(constants.EnvConfigPath, configFile)
defer os.Unsetenv(constants.EnvConfigPath)
// Load config
_, err := Load()
if err != nil {
t.Fatalf("failed to load config: %v", err)
}
// Create logger
logger := zap.NewNop() // Use no-op logger for this test
// Start watcher with context
ctx, cancel := context.WithCancel(context.Background())
done := make(chan bool)
go func() {
Watch(ctx, logger)
done <- true
}()
// Give watcher time to start
time.Sleep(100 * time.Millisecond)
// Cancel context (simulate graceful shutdown)
cancel()
// Wait for watcher to stop (should happen quickly)
select {
case <-done:
// Watcher stopped successfully
case <-time.After(2 * time.Second):
t.Error("watcher did not stop within timeout after context cancellation")
}
}

518
pkg/logger/logger_test.go Normal file
View File

@@ -0,0 +1,518 @@
package logger
import (
"os"
"path/filepath"
"testing"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// TestInitLoggers 测试日志初始化T026
func TestInitLoggers(t *testing.T) {
// 创建临时目录用于日志文件
tempDir := t.TempDir()
tests := []struct {
name string
level string
development bool
appLogConfig LogRotationConfig
accessLogConfig LogRotationConfig
wantErr bool
validateFunc func(t *testing.T)
}{
{
name: "production mode with info level",
level: "info",
development: false,
appLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "app-prod.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
accessLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "access-prod.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
wantErr: false,
validateFunc: func(t *testing.T) {
if appLogger == nil {
t.Error("appLogger should not be nil")
}
if accessLogger == nil {
t.Error("accessLogger should not be nil")
}
// 写入一条日志以触发文件创建
GetAppLogger().Info("test log creation")
Sync()
// 验证日志文件创建
if _, err := os.Stat(filepath.Join(tempDir, "app-prod.log")); os.IsNotExist(err) {
t.Error("app log file should be created after writing")
}
},
},
{
name: "development mode with debug level",
level: "debug",
development: true,
appLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "app-dev.log"),
MaxSize: 5,
MaxBackups: 2,
MaxAge: 3,
Compress: false,
},
accessLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "access-dev.log"),
MaxSize: 5,
MaxBackups: 2,
MaxAge: 3,
Compress: false,
},
wantErr: false,
validateFunc: func(t *testing.T) {
if appLogger == nil {
t.Error("appLogger should not be nil in dev mode")
}
if accessLogger == nil {
t.Error("accessLogger should not be nil in dev mode")
}
},
},
{
name: "warn level logging",
level: "warn",
development: false,
appLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "app-warn.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
accessLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "access-warn.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
wantErr: false,
validateFunc: func(t *testing.T) {
if appLogger == nil {
t.Error("appLogger should not be nil")
}
},
},
{
name: "error level logging",
level: "error",
development: false,
appLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "app-error.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
accessLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "access-error.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
wantErr: false,
validateFunc: func(t *testing.T) {
if appLogger == nil {
t.Error("appLogger should not be nil")
}
},
},
{
name: "invalid level defaults to info",
level: "invalid",
development: false,
appLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "app-invalid.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
accessLogConfig: LogRotationConfig{
Filename: filepath.Join(tempDir, "access-invalid.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
wantErr: false,
validateFunc: func(t *testing.T) {
if appLogger == nil {
t.Error("appLogger should not be nil even with invalid level")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := InitLoggers(tt.level, tt.development, tt.appLogConfig, tt.accessLogConfig)
if (err != nil) != tt.wantErr {
t.Errorf("InitLoggers() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.validateFunc != nil {
tt.validateFunc(t)
}
})
}
}
// TestGetAppLogger 测试获取应用日志记录器T026
func TestGetAppLogger(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
tests := []struct {
name string
setupFunc func()
wantNil bool
}{
{
name: "after initialization",
setupFunc: func() {
InitLoggers("info", false,
LogRotationConfig{
Filename: filepath.Join(tempDir, "app-get.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-get.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
)
},
wantNil: false,
},
{
name: "before initialization returns nop logger",
setupFunc: func() {
// 重置全局变量
appLogger = nil
},
wantNil: false, // GetAppLogger 应该返回 nop logger不是 nil
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupFunc()
logger := GetAppLogger()
if logger == nil {
t.Error("GetAppLogger() should never return nil, should return nop logger instead")
}
})
}
}
// TestGetAccessLogger 测试获取访问日志记录器T028
func TestGetAccessLogger(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
tests := []struct {
name string
setupFunc func()
wantNil bool
}{
{
name: "after initialization",
setupFunc: func() {
InitLoggers("info", false,
LogRotationConfig{
Filename: filepath.Join(tempDir, "app-access.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-access.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
)
},
wantNil: false,
},
{
name: "before initialization returns nop logger",
setupFunc: func() {
// 重置全局变量
accessLogger = nil
},
wantNil: false, // GetAccessLogger 应该返回 nop logger不是 nil
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupFunc()
logger := GetAccessLogger()
if logger == nil {
t.Error("GetAccessLogger() should never return nil, should return nop logger instead")
}
})
}
}
// TestSync 测试日志缓冲区刷新T028
func TestSync(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
tests := []struct {
name string
setupFunc func()
wantErr bool
}{
{
name: "sync after initialization",
setupFunc: func() {
InitLoggers("info", false,
LogRotationConfig{
Filename: filepath.Join(tempDir, "app-sync.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-sync.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
)
},
wantErr: false,
},
{
name: "sync before initialization",
setupFunc: func() {
appLogger = nil
accessLogger = nil
},
wantErr: false, // 应该优雅地处理 nil 情况
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupFunc()
err := Sync()
if (err != nil) != tt.wantErr {
t.Errorf("Sync() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
// TestParseLevel 测试日志级别解析T026
func TestParseLevel(t *testing.T) {
tests := []struct {
name string
level string
want zapcore.Level
}{
{
name: "debug level",
level: "debug",
want: zapcore.DebugLevel,
},
{
name: "info level",
level: "info",
want: zapcore.InfoLevel,
},
{
name: "warn level",
level: "warn",
want: zapcore.WarnLevel,
},
{
name: "error level",
level: "error",
want: zapcore.ErrorLevel,
},
{
name: "invalid level defaults to info",
level: "invalid",
want: zapcore.InfoLevel,
},
{
name: "empty level defaults to info",
level: "",
want: zapcore.InfoLevel,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := parseLevel(tt.level)
if got != tt.want {
t.Errorf("parseLevel() = %v, want %v", got, tt.want)
}
})
}
}
// TestDualLoggerSystem 测试双日志系统T028
func TestDualLoggerSystem(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
appLogFile := filepath.Join(tempDir, "app-dual.log")
accessLogFile := filepath.Join(tempDir, "access-dual.log")
// 初始化双日志系统
err := InitLoggers("info", false,
LogRotationConfig{
Filename: appLogFile,
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false, // 不压缩以便检查内容
},
LogRotationConfig{
Filename: accessLogFile,
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("InitLoggers failed: %v", err)
}
// 写入应用日志
appLog := GetAppLogger()
appLog.Info("test app log message",
zap.String("module", "test"),
zap.Int("code", 200),
)
// 写入访问日志
accessLog := GetAccessLogger()
accessLog.Info("test access log message",
zap.String("method", "GET"),
zap.String("path", "/api/test"),
zap.Int("status", 200),
zap.Duration("latency", 100),
)
// 刷新缓冲区
if err := Sync(); err != nil {
t.Fatalf("Sync failed: %v", err)
}
// 验证应用日志文件存在并有内容
appLogContent, err := os.ReadFile(appLogFile)
if err != nil {
t.Fatalf("Failed to read app log file: %v", err)
}
if len(appLogContent) == 0 {
t.Error("App log file should not be empty")
}
// 验证访问日志文件存在并有内容
accessLogContent, err := os.ReadFile(accessLogFile)
if err != nil {
t.Fatalf("Failed to read access log file: %v", err)
}
if len(accessLogContent) == 0 {
t.Error("Access log file should not be empty")
}
// 验证两个日志文件是独立的
if string(appLogContent) == string(accessLogContent) {
t.Error("App log and access log should have different content")
}
}
// TestLoggerReinitialization 测试日志重新初始化T026
func TestLoggerReinitialization(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
// 第一次初始化
err := InitLoggers("info", false,
LogRotationConfig{
Filename: filepath.Join(tempDir, "app-reinit1.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-reinit1.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
)
if err != nil {
t.Fatalf("First InitLoggers failed: %v", err)
}
firstAppLogger := GetAppLogger()
// 第二次初始化(重新初始化)
err = InitLoggers("debug", true,
LogRotationConfig{
Filename: filepath.Join(tempDir, "app-reinit2.log"),
MaxSize: 5,
MaxBackups: 2,
MaxAge: 3,
Compress: false,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-reinit2.log"),
MaxSize: 5,
MaxBackups: 2,
MaxAge: 3,
Compress: false,
},
)
if err != nil {
t.Fatalf("Second InitLoggers failed: %v", err)
}
secondAppLogger := GetAppLogger()
// 验证重新初始化后日志记录器已更新
if firstAppLogger == secondAppLogger {
t.Error("Logger should be replaced after reinitialization")
}
}

388
pkg/logger/rotation_test.go Normal file
View File

@@ -0,0 +1,388 @@
package logger
import (
"os"
"path/filepath"
"strings"
"testing"
"time"
"go.uber.org/zap"
)
// TestLogRotation 测试日志轮转功能T027
func TestLogRotation(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
appLogFile := filepath.Join(tempDir, "app-rotation.log")
// 初始化日志系统,设置较小的 MaxSize 以便测试
err := InitLoggers("info", false,
LogRotationConfig{
Filename: appLogFile,
MaxSize: 1, // 1MB写入足够数据后会触发轮转
MaxBackups: 3,
MaxAge: 7,
Compress: false, // 不压缩以便检查
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-rotation.log"),
MaxSize: 1,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("InitLoggers failed: %v", err)
}
logger := GetAppLogger()
// 写入大量日志数据以触发轮转每条约100字节写入15000条约1.5MB
largeMessage := strings.Repeat("a", 100)
for i := 0; i < 15000; i++ {
logger.Info(largeMessage,
zap.Int("iteration", i),
zap.String("data", largeMessage),
)
}
// 刷新缓冲区
if err := Sync(); err != nil {
t.Fatalf("Sync failed: %v", err)
}
// 等待一小段时间确保文件写入完成
time.Sleep(100 * time.Millisecond)
// 验证主日志文件存在
if _, err := os.Stat(appLogFile); os.IsNotExist(err) {
t.Error("Main log file should exist")
}
// 检查是否有备份文件(轮转后的文件)
files, err := filepath.Glob(filepath.Join(tempDir, "app-rotation-*.log"))
if err != nil {
t.Fatalf("Failed to glob backup files: %v", err)
}
// 由于写入了超过1MB的数据应该触发至少一次轮转
if len(files) == 0 {
// 可能系统写入速度或lumberjack行为导致未立即轮转检查主文件大小
info, err := os.Stat(appLogFile)
if err != nil {
t.Fatalf("Failed to stat main log file: %v", err)
}
if info.Size() == 0 {
t.Error("Log file should have content")
}
// 不强制要求必须轮转,因为取决于具体实现
t.Logf("No rotation occurred, but main log file size: %d bytes", info.Size())
} else {
t.Logf("Found %d rotated backup file(s)", len(files))
}
}
// TestMaxBackups 测试最大备份数限制T027
func TestMaxBackups(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
appLogFile := filepath.Join(tempDir, "app-backups.log")
// 初始化日志系统,设置 MaxBackups=2
err := InitLoggers("info", false,
LogRotationConfig{
Filename: appLogFile,
MaxSize: 1, // 1MB
MaxBackups: 2, // 最多保留2个备份
MaxAge: 7,
Compress: false,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-backups.log"),
MaxSize: 1,
MaxBackups: 2,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("InitLoggers failed: %v", err)
}
logger := GetAppLogger()
// 写入足够的数据触发多次轮转每次1.5MB共4.5MB应该触发3次轮转
largeMessage := strings.Repeat("b", 100)
for round := 0; round < 3; round++ {
for i := 0; i < 15000; i++ {
logger.Info(largeMessage,
zap.Int("round", round),
zap.Int("iteration", i),
)
}
Sync()
time.Sleep(100 * time.Millisecond)
}
// 等待轮转完成
time.Sleep(200 * time.Millisecond)
// 检查备份文件数量
files, err := filepath.Glob(filepath.Join(tempDir, "app-backups-*.log"))
if err != nil {
t.Fatalf("Failed to glob backup files: %v", err)
}
// 由于 MaxBackups=2即使触发了多次轮转也只应保留最多2个备份文件
// (实际行为取决于 lumberjack 的实现细节可能小于等于2
if len(files) > 2 {
t.Errorf("Expected at most 2 backup files due to MaxBackups=2, got %d", len(files))
}
t.Logf("Found %d backup file(s) with MaxBackups=2", len(files))
}
// TestCompressionConfig 测试压缩配置T027
func TestCompressionConfig(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
tests := []struct {
name string
compress bool
}{
{
name: "compression enabled",
compress: true,
},
{
name: "compression disabled",
compress: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logFile := filepath.Join(tempDir, "app-"+tt.name+".log")
err := InitLoggers("info", false,
LogRotationConfig{
Filename: logFile,
MaxSize: 1,
MaxBackups: 3,
MaxAge: 7,
Compress: tt.compress,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-"+tt.name+".log"),
MaxSize: 1,
MaxBackups: 3,
MaxAge: 7,
Compress: tt.compress,
},
)
if err != nil {
t.Fatalf("InitLoggers failed: %v", err)
}
logger := GetAppLogger()
// 写入一些日志
for i := 0; i < 1000; i++ {
logger.Info("test compression",
zap.Int("id", i),
zap.String("data", strings.Repeat("c", 50)),
)
}
Sync()
time.Sleep(100 * time.Millisecond)
// 验证日志文件存在
if _, err := os.Stat(logFile); os.IsNotExist(err) {
t.Error("Log file should exist")
}
})
}
}
// TestMaxAge 测试日志文件保留时间T027
func TestMaxAge(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
// 初始化日志系统,设置 MaxAge=1 天
err := InitLoggers("info", false,
LogRotationConfig{
Filename: filepath.Join(tempDir, "app-maxage.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 1, // 1天
Compress: false,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-maxage.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 1,
Compress: false,
},
)
if err != nil {
t.Fatalf("InitLoggers failed: %v", err)
}
logger := GetAppLogger()
// 写入日志
logger.Info("test max age", zap.String("config", "maxage=1"))
Sync()
// 验证配置已应用无法在单元测试中验证实际的清理行为因为需要等待1天
// 这里只验证初始化没有错误
if logger == nil {
t.Error("Logger should be initialized with MaxAge config")
}
}
// TestNewLumberjackLogger 测试 Lumberjack logger 创建T027
func TestNewLumberjackLogger(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
tests := []struct {
name string
config LogRotationConfig
}{
{
name: "standard config",
config: LogRotationConfig{
Filename: filepath.Join(tempDir, "test1.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: true,
},
},
{
name: "minimal config",
config: LogRotationConfig{
Filename: filepath.Join(tempDir, "test2.log"),
MaxSize: 1,
MaxBackups: 1,
MaxAge: 1,
Compress: false,
},
},
{
name: "large config",
config: LogRotationConfig{
Filename: filepath.Join(tempDir, "test3.log"),
MaxSize: 100,
MaxBackups: 10,
MaxAge: 30,
Compress: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logger := newLumberjackLogger(tt.config)
if logger == nil {
t.Error("newLumberjackLogger should not return nil")
}
// 验证配置已正确设置
if logger.Filename != tt.config.Filename {
t.Errorf("Filename = %v, want %v", logger.Filename, tt.config.Filename)
}
if logger.MaxSize != tt.config.MaxSize {
t.Errorf("MaxSize = %v, want %v", logger.MaxSize, tt.config.MaxSize)
}
if logger.MaxBackups != tt.config.MaxBackups {
t.Errorf("MaxBackups = %v, want %v", logger.MaxBackups, tt.config.MaxBackups)
}
if logger.MaxAge != tt.config.MaxAge {
t.Errorf("MaxAge = %v, want %v", logger.MaxAge, tt.config.MaxAge)
}
if logger.Compress != tt.config.Compress {
t.Errorf("Compress = %v, want %v", logger.Compress, tt.config.Compress)
}
if !logger.LocalTime {
t.Error("LocalTime should be true")
}
})
}
}
// TestConcurrentLogging 测试并发日志写入T027
func TestConcurrentLogging(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
// 初始化日志系统
err := InitLoggers("info", false,
LogRotationConfig{
Filename: filepath.Join(tempDir, "app-concurrent.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
LogRotationConfig{
Filename: filepath.Join(tempDir, "access-concurrent.log"),
MaxSize: 10,
MaxBackups: 3,
MaxAge: 7,
Compress: false,
},
)
if err != nil {
t.Fatalf("InitLoggers failed: %v", err)
}
logger := GetAppLogger()
// 启动多个 goroutine 并发写入日志
done := make(chan bool)
goroutines := 10
messagesPerGoroutine := 100
for i := 0; i < goroutines; i++ {
go func(id int) {
for j := 0; j < messagesPerGoroutine; j++ {
logger.Info("concurrent log message",
zap.Int("goroutine", id),
zap.Int("message", j),
)
}
done <- true
}(i)
}
// 等待所有 goroutine 完成
for i := 0; i < goroutines; i++ {
<-done
}
// 刷新缓冲区
if err := Sync(); err != nil {
t.Fatalf("Sync failed: %v", err)
}
// 验证日志文件存在且有内容
logFile := filepath.Join(tempDir, "app-concurrent.log")
info, err := os.Stat(logFile)
if err != nil {
t.Fatalf("Failed to stat log file: %v", err)
}
if info.Size() == 0 {
t.Error("Log file should have content after concurrent writes")
}
t.Logf("Concurrent logging test completed, log file size: %d bytes", info.Size())
}

View File

@@ -0,0 +1,477 @@
package response
import (
"encoding/json"
"io"
"net/http/httptest"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/gofiber/fiber/v2"
)
// TestSuccess 测试成功响应T034
func TestSuccess(t *testing.T) {
tests := []struct {
name string
data any
}{
{
name: "success with string data",
data: "test data",
},
{
name: "success with map data",
data: map[string]any{
"id": 123,
"name": "test",
},
},
{
name: "success with slice data",
data: []string{"item1", "item2", "item3"},
},
{
name: "success with struct data",
data: struct {
ID int `json:"id"`
Name string `json:"name"`
}{
ID: 456,
Name: "test struct",
},
},
{
name: "success with nil data",
data: nil,
},
{
name: "success with empty map",
data: map[string]any{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return Success(c, tt.data)
})
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
defer resp.Body.Close()
// 验证 HTTP 状态码
if resp.StatusCode != 200 {
t.Errorf("Expected status code 200, got %d", resp.StatusCode)
}
// 验证响应头
if resp.Header.Get("Content-Type") != "application/json" {
t.Errorf("Expected Content-Type application/json, got %s", resp.Header.Get("Content-Type"))
}
// 解析响应体
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
var response Response
if err := json.Unmarshal(body, &response); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
// 验证响应结构
if response.Code != errors.CodeSuccess {
t.Errorf("Expected code %d, got %d", errors.CodeSuccess, response.Code)
}
if response.Message != "success" {
t.Errorf("Expected message 'success', got '%s'", response.Message)
}
// 验证时间戳格式 RFC3339
if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil {
t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp)
}
// 验证数据字段(如果不是 nil
if tt.data != nil {
if response.Data == nil {
t.Error("Expected data field to be non-nil")
}
}
})
}
}
// TestError 测试错误响应T035
func TestError(t *testing.T) {
tests := []struct {
name string
httpStatus int
code int
message string
}{
{
name: "internal server error",
httpStatus: 500,
code: errors.CodeInternalError,
message: "Internal server error occurred",
},
{
name: "missing token error",
httpStatus: 401,
code: errors.CodeMissingToken,
message: "Authentication token is missing",
},
{
name: "invalid token error",
httpStatus: 401,
code: errors.CodeInvalidToken,
message: "Token is invalid or expired",
},
{
name: "rate limit error",
httpStatus: 429,
code: errors.CodeTooManyRequests,
message: "Too many requests, please try again later",
},
{
name: "service unavailable error",
httpStatus: 503,
code: errors.CodeAuthServiceUnavailable,
message: "Authentication service is currently unavailable",
},
{
name: "bad request error",
httpStatus: 400,
code: 2000,
message: "Invalid request parameters",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return Error(c, tt.httpStatus, tt.code, tt.message)
})
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
defer resp.Body.Close()
// 验证 HTTP 状态码
if resp.StatusCode != tt.httpStatus {
t.Errorf("Expected status code %d, got %d", tt.httpStatus, resp.StatusCode)
}
// 验证响应头
if resp.Header.Get("Content-Type") != "application/json" {
t.Errorf("Expected Content-Type application/json, got %s", resp.Header.Get("Content-Type"))
}
// 解析响应体
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
var response Response
if err := json.Unmarshal(body, &response); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
// 验证响应结构
if response.Code != tt.code {
t.Errorf("Expected code %d, got %d", tt.code, response.Code)
}
if response.Message != tt.message {
t.Errorf("Expected message '%s', got '%s'", tt.message, response.Message)
}
if response.Data != nil {
t.Errorf("Expected data to be nil in error response, got %v", response.Data)
}
// 验证时间戳格式 RFC3339
if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil {
t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp)
}
})
}
}
// TestSuccessWithMessage 测试带自定义消息的成功响应T034
func TestSuccessWithMessage(t *testing.T) {
tests := []struct {
name string
data any
message string
}{
{
name: "custom success message",
data: map[string]any{
"user_id": 123,
},
message: "User created successfully",
},
{
name: "empty custom message",
data: "test data",
message: "",
},
{
name: "chinese message",
data: map[string]string{
"status": "ok",
},
message: "操作成功",
},
{
name: "long message",
data: nil,
message: "This is a very long success message that describes in detail what happened during the operation",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return SuccessWithMessage(c, tt.data, tt.message)
})
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
defer resp.Body.Close()
// 验证 HTTP 状态码(默认 200
if resp.StatusCode != 200 {
t.Errorf("Expected status code 200, got %d", resp.StatusCode)
}
// 解析响应体
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
var response Response
if err := json.Unmarshal(body, &response); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
// 验证响应结构
if response.Code != errors.CodeSuccess {
t.Errorf("Expected code %d, got %d", errors.CodeSuccess, response.Code)
}
if response.Message != tt.message {
t.Errorf("Expected message '%s', got '%s'", tt.message, response.Message)
}
// 验证时间戳格式 RFC3339
if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil {
t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp)
}
})
}
}
// TestResponseSerialization 测试响应序列化T036
func TestResponseSerialization(t *testing.T) {
tests := []struct {
name string
response Response
}{
{
name: "complete response",
response: Response{
Code: 0,
Data: map[string]any{"key": "value"},
Message: "success",
Timestamp: time.Now().Format(time.RFC3339),
},
},
{
name: "response with nil data",
response: Response{
Code: 1000,
Data: nil,
Message: "error",
Timestamp: time.Now().Format(time.RFC3339),
},
},
{
name: "response with nested data",
response: Response{
Code: 0,
Data: map[string]any{
"user": map[string]any{
"id": 123,
"name": "test",
"tags": []string{"tag1", "tag2"},
},
},
Message: "success",
Timestamp: time.Now().Format(time.RFC3339),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 序列化
data, err := json.Marshal(tt.response)
if err != nil {
t.Fatalf("Failed to marshal response: %v", err)
}
// 反序列化
var deserialized Response
if err := json.Unmarshal(data, &deserialized); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
// 验证字段
if deserialized.Code != tt.response.Code {
t.Errorf("Code mismatch: expected %d, got %d", tt.response.Code, deserialized.Code)
}
if deserialized.Message != tt.response.Message {
t.Errorf("Message mismatch: expected '%s', got '%s'", tt.response.Message, deserialized.Message)
}
if deserialized.Timestamp != tt.response.Timestamp {
t.Errorf("Timestamp mismatch: expected '%s', got '%s'", tt.response.Timestamp, deserialized.Timestamp)
}
})
}
}
// TestResponseStructFields 测试响应结构字段T036
func TestResponseStructFields(t *testing.T) {
response := Response{
Code: 0,
Data: "test",
Message: "success",
Timestamp: time.Now().Format(time.RFC3339),
}
data, err := json.Marshal(response)
if err != nil {
t.Fatalf("Failed to marshal response: %v", err)
}
// 解析为 map 以检查 JSON 键
var jsonMap map[string]any
if err := json.Unmarshal(data, &jsonMap); err != nil {
t.Fatalf("Failed to unmarshal to map: %v", err)
}
// 验证所有必需字段都存在
requiredFields := []string{"code", "data", "msg", "timestamp"}
for _, field := range requiredFields {
if _, exists := jsonMap[field]; !exists {
t.Errorf("Required field '%s' is missing in JSON response", field)
}
}
// 验证字段类型
if _, ok := jsonMap["code"].(float64); !ok {
t.Error("Field 'code' should be a number")
}
if _, ok := jsonMap["msg"].(string); !ok {
t.Error("Field 'msg' should be a string")
}
if _, ok := jsonMap["timestamp"].(string); !ok {
t.Error("Field 'timestamp' should be a string")
}
}
// TestMultipleResponses 测试多个连续响应T036
func TestMultipleResponses(t *testing.T) {
app := fiber.New()
callCount := 0
app.Get("/test", func(c *fiber.Ctx) error {
callCount++
if callCount%2 == 0 {
return Success(c, map[string]int{"count": callCount})
}
return Error(c, 500, errors.CodeInternalError, "error occurred")
})
// 发送多个请求
for i := 1; i <= 5; i++ {
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Request %d failed: %v", i, err)
}
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
var response Response
if err := json.Unmarshal(body, &response); err != nil {
t.Fatalf("Request %d: failed to unmarshal response: %v", i, err)
}
// 验证每个响应都有时间戳
if response.Timestamp == "" {
t.Errorf("Request %d: timestamp should not be empty", i)
}
}
}
// TestTimestampFormat 测试时间戳格式T036
func TestTimestampFormat(t *testing.T) {
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return Success(c, nil)
})
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
if err != nil {
t.Fatalf("Failed to execute request: %v", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
var response Response
if err := json.Unmarshal(body, &response); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
// 验证是 RFC3339 格式
parsedTime, err := time.Parse(time.RFC3339, response.Timestamp)
if err != nil {
t.Fatalf("Timestamp is not in RFC3339 format: %s, error: %v", response.Timestamp, err)
}
// 验证时间戳是最近的(应该在最近 1 秒内)
now := time.Now()
diff := now.Sub(parsedTime)
if diff < 0 || diff > time.Second {
t.Errorf("Timestamp seems incorrect: %s (diff from now: %v)", response.Timestamp, diff)
}
}