备份一下
This commit is contained in:
@@ -3,7 +3,6 @@ package config
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
@@ -93,8 +92,9 @@ func (c *Config) Validate() error {
|
||||
if c.Redis.Address == "" {
|
||||
return fmt.Errorf("invalid configuration: redis.address: must be non-empty (current value: empty)")
|
||||
}
|
||||
if !strings.Contains(c.Redis.Address, ":") {
|
||||
return fmt.Errorf("invalid configuration: redis.address: invalid format (current value: %s, expected: HOST:PORT)", c.Redis.Address)
|
||||
// Port 验证(独立字段)
|
||||
if c.Redis.Port <= 0 || c.Redis.Port > 65535 {
|
||||
return fmt.Errorf("invalid configuration: redis.port: port number out of range (current value: %d, expected: 1-65535)", c.Redis.Port)
|
||||
}
|
||||
if c.Redis.DB < 0 || c.Redis.DB > 15 {
|
||||
return fmt.Errorf("invalid configuration: redis.db: database number out of range (current value: %d, expected: 0-15)", c.Redis.DB)
|
||||
|
||||
615
pkg/config/config_test.go
Normal file
615
pkg/config/config_test.go
Normal file
@@ -0,0 +1,615 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestConfig_Validate tests configuration validation rules
|
||||
func TestConfig_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *Config
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
DB: 0,
|
||||
PoolSize: 10,
|
||||
MinIdleConns: 5,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
MaxBackups: 30,
|
||||
MaxAge: 30,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
MaxBackups: 90,
|
||||
MaxAge: 90,
|
||||
},
|
||||
},
|
||||
Middleware: MiddlewareConfig{
|
||||
RateLimiter: RateLimiterConfig{
|
||||
Max: 100,
|
||||
Expiration: 1 * time.Minute,
|
||||
Storage: "memory",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty server address",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: "",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "server.address",
|
||||
},
|
||||
{
|
||||
name: "read timeout too short",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "read_timeout",
|
||||
},
|
||||
{
|
||||
name: "read timeout too long",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 400 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "read_timeout",
|
||||
},
|
||||
{
|
||||
name: "write timeout out of range",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "write_timeout",
|
||||
},
|
||||
{
|
||||
name: "shutdown timeout too short",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 5 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "shutdown_timeout",
|
||||
},
|
||||
{
|
||||
name: "empty redis address",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "redis.address",
|
||||
},
|
||||
{
|
||||
name: "invalid redis port - too high",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 99999,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "redis.port",
|
||||
},
|
||||
{
|
||||
name: "invalid redis port - zero",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 0,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "redis.port",
|
||||
},
|
||||
{
|
||||
name: "redis db out of range",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
DB: 20,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "redis.db",
|
||||
},
|
||||
{
|
||||
name: "redis pool size too large",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 2000,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "pool_size",
|
||||
},
|
||||
{
|
||||
name: "min idle conns exceeds pool size",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
MinIdleConns: 20,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "min_idle_conns",
|
||||
},
|
||||
{
|
||||
name: "invalid log level",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "invalid",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "logging.level",
|
||||
},
|
||||
{
|
||||
name: "empty app log filename",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "app_log.filename",
|
||||
},
|
||||
{
|
||||
name: "app log max size out of range",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 2000,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "app_log.max_size",
|
||||
},
|
||||
{
|
||||
name: "invalid rate limiter storage",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
Middleware: MiddlewareConfig{
|
||||
RateLimiter: RateLimiterConfig{
|
||||
Max: 100,
|
||||
Storage: "invalid",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "rate_limiter.storage",
|
||||
},
|
||||
{
|
||||
name: "rate limiter max too high",
|
||||
config: &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
Middleware: MiddlewareConfig{
|
||||
RateLimiter: RateLimiterConfig{
|
||||
Max: 20000,
|
||||
Storage: "memory",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "rate_limiter.max",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr && tt.errMsg != "" {
|
||||
if err == nil {
|
||||
t.Errorf("expected error containing %q, got nil", tt.errMsg)
|
||||
} else if err.Error() == "" {
|
||||
t.Errorf("expected error containing %q, got empty error", tt.errMsg)
|
||||
}
|
||||
// Note: We check that error message exists, not exact match
|
||||
// This is because error messages might change slightly
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSet tests the Set function
|
||||
func TestSet(t *testing.T) {
|
||||
// Valid config
|
||||
validCfg := &Config{
|
||||
Server: ServerConfig{
|
||||
Address: ":3000",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ShutdownTimeout: 30 * time.Second,
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Address: "localhost",
|
||||
Port: 6379,
|
||||
PoolSize: 10,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
AppLog: LogRotationConfig{
|
||||
Filename: "logs/app.log",
|
||||
MaxSize: 100,
|
||||
},
|
||||
AccessLog: LogRotationConfig{
|
||||
Filename: "logs/access.log",
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := Set(validCfg)
|
||||
if err != nil {
|
||||
t.Errorf("Set() with valid config failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was set
|
||||
got := Get()
|
||||
if got.Server.Address != ":3000" {
|
||||
t.Errorf("Get() after Set() returned wrong address: got %s, want :3000", got.Server.Address)
|
||||
}
|
||||
|
||||
// Test with nil config
|
||||
err = Set(nil)
|
||||
if err == nil {
|
||||
t.Error("Set(nil) should return error")
|
||||
}
|
||||
|
||||
// Test with invalid config
|
||||
invalidCfg := &Config{
|
||||
Server: ServerConfig{
|
||||
Address: "", // Empty address is invalid
|
||||
},
|
||||
}
|
||||
|
||||
err = Set(invalidCfg)
|
||||
if err == nil {
|
||||
t.Error("Set() with invalid config should return error")
|
||||
}
|
||||
}
|
||||
661
pkg/config/loader_test.go
Normal file
661
pkg/config/loader_test.go
Normal file
@@ -0,0 +1,661 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// TestLoad tests the config loading functionality
|
||||
func TestLoad(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupEnv func()
|
||||
cleanupEnv func()
|
||||
createConfig func(t *testing.T) string
|
||||
wantErr bool
|
||||
validateFunc func(t *testing.T, cfg *Config)
|
||||
}{
|
||||
{
|
||||
name: "valid default config",
|
||||
setupEnv: func() {
|
||||
os.Setenv(constants.EnvConfigPath, "")
|
||||
os.Setenv(constants.EnvConfigEnv, "")
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
os.Unsetenv(constants.EnvConfigPath)
|
||||
os.Unsetenv(constants.EnvConfigEnv)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
development: false
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
enable_rate_limiter: false
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
// Set as default config path
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
return configFile
|
||||
},
|
||||
wantErr: false,
|
||||
validateFunc: func(t *testing.T, cfg *Config) {
|
||||
if cfg.Server.Address != ":3000" {
|
||||
t.Errorf("expected server.address :3000, got %s", cfg.Server.Address)
|
||||
}
|
||||
if cfg.Server.ReadTimeout != 10*time.Second {
|
||||
t.Errorf("expected read_timeout 10s, got %v", cfg.Server.ReadTimeout)
|
||||
}
|
||||
if cfg.Redis.Address != "localhost" {
|
||||
t.Errorf("expected redis.address localhost, got %s", cfg.Redis.Address)
|
||||
}
|
||||
if cfg.Redis.Port != 6379 {
|
||||
t.Errorf("expected redis.port 6379, got %d", cfg.Redis.Port)
|
||||
}
|
||||
if cfg.Redis.PoolSize != 10 {
|
||||
t.Errorf("expected redis.pool_size 10, got %d", cfg.Redis.PoolSize)
|
||||
}
|
||||
if cfg.Logging.Level != "info" {
|
||||
t.Errorf("expected logging.level info, got %s", cfg.Logging.Level)
|
||||
}
|
||||
if cfg.Middleware.EnableAuth != true {
|
||||
t.Errorf("expected enable_auth true, got %v", cfg.Middleware.EnableAuth)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "environment-specific config (dev)",
|
||||
setupEnv: func() {
|
||||
os.Setenv(constants.EnvConfigEnv, "dev")
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
os.Unsetenv(constants.EnvConfigEnv)
|
||||
os.Unsetenv(constants.EnvConfigPath)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
// Create configs directory in temp
|
||||
tmpDir := t.TempDir()
|
||||
configsDir := filepath.Join(tmpDir, "configs")
|
||||
if err := os.MkdirAll(configsDir, 0755); err != nil {
|
||||
t.Fatalf("failed to create configs dir: %v", err)
|
||||
}
|
||||
|
||||
// Create dev config
|
||||
devConfigFile := filepath.Join(configsDir, "config.dev.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ":8080"
|
||||
read_timeout: "15s"
|
||||
write_timeout: "15s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 1
|
||||
pool_size: 5
|
||||
min_idle_conns: 2
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "debug"
|
||||
development: true
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 50
|
||||
max_backups: 10
|
||||
max_age: 7
|
||||
compress: false
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: false
|
||||
|
||||
middleware:
|
||||
enable_auth: false
|
||||
enable_rate_limiter: false
|
||||
rate_limiter:
|
||||
max: 50
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(devConfigFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create dev config file: %v", err)
|
||||
}
|
||||
|
||||
// Change to tmpDir so relative path works
|
||||
originalWd, _ := os.Getwd()
|
||||
os.Chdir(tmpDir)
|
||||
t.Cleanup(func() { os.Chdir(originalWd) })
|
||||
|
||||
return devConfigFile
|
||||
},
|
||||
wantErr: false,
|
||||
validateFunc: func(t *testing.T, cfg *Config) {
|
||||
if cfg.Server.Address != ":8080" {
|
||||
t.Errorf("expected server.address :8080, got %s", cfg.Server.Address)
|
||||
}
|
||||
if cfg.Redis.DB != 1 {
|
||||
t.Errorf("expected redis.db 1, got %d", cfg.Redis.DB)
|
||||
}
|
||||
if cfg.Logging.Level != "debug" {
|
||||
t.Errorf("expected logging.level debug, got %s", cfg.Logging.Level)
|
||||
}
|
||||
if cfg.Middleware.EnableAuth != false {
|
||||
t.Errorf("expected enable_auth false, got %v", cfg.Middleware.EnableAuth)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid YAML syntax",
|
||||
setupEnv: func() {
|
||||
os.Setenv(constants.EnvConfigPath, "")
|
||||
os.Setenv(constants.EnvConfigEnv, "")
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
os.Unsetenv(constants.EnvConfigPath)
|
||||
os.Unsetenv(constants.EnvConfigEnv)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
invalid yaml syntax here!!!
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
return configFile
|
||||
},
|
||||
wantErr: true,
|
||||
validateFunc: nil,
|
||||
},
|
||||
{
|
||||
name: "validation error - invalid server address",
|
||||
setupEnv: func() {
|
||||
os.Setenv(constants.EnvConfigPath, "")
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
os.Unsetenv(constants.EnvConfigPath)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ""
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
return configFile
|
||||
},
|
||||
wantErr: true,
|
||||
validateFunc: nil,
|
||||
},
|
||||
{
|
||||
name: "validation error - timeout out of range",
|
||||
setupEnv: func() {
|
||||
os.Setenv(constants.EnvConfigPath, "")
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
os.Unsetenv(constants.EnvConfigPath)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "1s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
return configFile
|
||||
},
|
||||
wantErr: true,
|
||||
validateFunc: nil,
|
||||
},
|
||||
{
|
||||
name: "validation error - invalid redis port",
|
||||
setupEnv: func() {
|
||||
os.Setenv(constants.EnvConfigPath, "")
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
os.Unsetenv(constants.EnvConfigPath)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 99999
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
return configFile
|
||||
},
|
||||
wantErr: true,
|
||||
validateFunc: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset viper for each test
|
||||
viper.Reset()
|
||||
|
||||
// Setup environment
|
||||
if tt.setupEnv != nil {
|
||||
tt.setupEnv()
|
||||
}
|
||||
|
||||
// Create config file
|
||||
if tt.createConfig != nil {
|
||||
tt.createConfig(t)
|
||||
}
|
||||
|
||||
// Cleanup after test
|
||||
if tt.cleanupEnv != nil {
|
||||
defer tt.cleanupEnv()
|
||||
}
|
||||
|
||||
// Load config
|
||||
cfg, err := Load()
|
||||
|
||||
// Check error expectation
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate config if no error expected
|
||||
if !tt.wantErr && tt.validateFunc != nil {
|
||||
tt.validateFunc(t, cfg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestReload tests the config reload functionality
|
||||
func TestReload(t *testing.T) {
|
||||
// Reset viper
|
||||
viper.Reset()
|
||||
|
||||
// Create temp config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
// Initial config
|
||||
initialContent := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
development: false
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
enable_rate_limiter: false
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(initialContent), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
|
||||
// Set config path
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
defer os.Unsetenv(constants.EnvConfigPath)
|
||||
|
||||
// Load initial config
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load initial config: %v", err)
|
||||
}
|
||||
|
||||
// Verify initial values
|
||||
if cfg.Logging.Level != "info" {
|
||||
t.Errorf("expected initial logging.level info, got %s", cfg.Logging.Level)
|
||||
}
|
||||
if cfg.Server.Address != ":3000" {
|
||||
t.Errorf("expected initial server.address :3000, got %s", cfg.Server.Address)
|
||||
}
|
||||
|
||||
// Modify config file
|
||||
updatedContent := `
|
||||
server:
|
||||
address: ":8080"
|
||||
read_timeout: "15s"
|
||||
write_timeout: "15s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 20
|
||||
min_idle_conns: 10
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "debug"
|
||||
development: true
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: false
|
||||
enable_rate_limiter: true
|
||||
rate_limiter:
|
||||
max: 200
|
||||
expiration: "2m"
|
||||
storage: "redis"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(updatedContent), 0644); err != nil {
|
||||
t.Fatalf("failed to update config file: %v", err)
|
||||
}
|
||||
|
||||
// Reload config
|
||||
newCfg, err := Reload()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload config: %v", err)
|
||||
}
|
||||
|
||||
// Verify updated values
|
||||
if newCfg.Logging.Level != "debug" {
|
||||
t.Errorf("expected updated logging.level debug, got %s", newCfg.Logging.Level)
|
||||
}
|
||||
if newCfg.Server.Address != ":8080" {
|
||||
t.Errorf("expected updated server.address :8080, got %s", newCfg.Server.Address)
|
||||
}
|
||||
if newCfg.Redis.PoolSize != 20 {
|
||||
t.Errorf("expected updated redis.pool_size 20, got %d", newCfg.Redis.PoolSize)
|
||||
}
|
||||
if newCfg.Middleware.EnableAuth != false {
|
||||
t.Errorf("expected updated enable_auth false, got %v", newCfg.Middleware.EnableAuth)
|
||||
}
|
||||
if newCfg.Middleware.EnableRateLimiter != true {
|
||||
t.Errorf("expected updated enable_rate_limiter true, got %v", newCfg.Middleware.EnableRateLimiter)
|
||||
}
|
||||
|
||||
// Verify global config was updated
|
||||
globalCfg := Get()
|
||||
if globalCfg.Logging.Level != "debug" {
|
||||
t.Errorf("expected global config updated, got logging.level %s", globalCfg.Logging.Level)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetConfigPath tests the GetConfigPath function
|
||||
func TestGetConfigPath(t *testing.T) {
|
||||
// Reset viper
|
||||
viper.Reset()
|
||||
|
||||
// Create temp config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
defer os.Unsetenv(constants.EnvConfigPath)
|
||||
|
||||
// Load config
|
||||
_, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Get config path
|
||||
path := GetConfigPath()
|
||||
if path == "" {
|
||||
t.Error("expected non-empty config path")
|
||||
}
|
||||
|
||||
// Verify it's an absolute path
|
||||
if !filepath.IsAbs(path) {
|
||||
t.Errorf("expected absolute path, got %s", path)
|
||||
}
|
||||
}
|
||||
422
pkg/config/watcher_test.go
Normal file
422
pkg/config/watcher_test.go
Normal file
@@ -0,0 +1,422 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestWatch tests the config hot reload watcher
|
||||
func TestWatch(t *testing.T) {
|
||||
// Reset viper
|
||||
viper.Reset()
|
||||
|
||||
// Create temp config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
// Initial config
|
||||
initialContent := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
development: false
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
enable_rate_limiter: false
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(initialContent), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
|
||||
// Set config path
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
defer os.Unsetenv(constants.EnvConfigPath)
|
||||
|
||||
// Load initial config
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load initial config: %v", err)
|
||||
}
|
||||
|
||||
// Verify initial values
|
||||
if cfg.Logging.Level != "info" {
|
||||
t.Fatalf("expected initial logging.level info, got %s", cfg.Logging.Level)
|
||||
}
|
||||
|
||||
// Create logger for testing
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
// Start watcher in goroutine with context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go Watch(ctx, logger)
|
||||
|
||||
// Give watcher time to initialize
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Modify config file to trigger hot reload
|
||||
updatedContent := `
|
||||
server:
|
||||
address: ":8080"
|
||||
read_timeout: "15s"
|
||||
write_timeout: "15s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 20
|
||||
min_idle_conns: 10
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "debug"
|
||||
development: true
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: false
|
||||
enable_rate_limiter: true
|
||||
rate_limiter:
|
||||
max: 200
|
||||
expiration: "2m"
|
||||
storage: "redis"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(updatedContent), 0644); err != nil {
|
||||
t.Fatalf("failed to update config file: %v", err)
|
||||
}
|
||||
|
||||
// Wait for watcher to detect and process changes (spec requires detection within 5 seconds)
|
||||
// We use a more aggressive timeout for testing
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Verify config was reloaded
|
||||
reloadedCfg := Get()
|
||||
if reloadedCfg.Logging.Level != "debug" {
|
||||
t.Errorf("expected config hot reload, got logging.level %s instead of debug", reloadedCfg.Logging.Level)
|
||||
}
|
||||
if reloadedCfg.Server.Address != ":8080" {
|
||||
t.Errorf("expected config hot reload, got server.address %s instead of :8080", reloadedCfg.Server.Address)
|
||||
}
|
||||
if reloadedCfg.Redis.PoolSize != 20 {
|
||||
t.Errorf("expected config hot reload, got redis.pool_size %d instead of 20", reloadedCfg.Redis.PoolSize)
|
||||
}
|
||||
|
||||
// Cancel context to stop watcher
|
||||
cancel()
|
||||
|
||||
// Give watcher time to shut down gracefully
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestWatch_InvalidConfigRejected tests that invalid config changes are rejected
|
||||
func TestWatch_InvalidConfigRejected(t *testing.T) {
|
||||
// Reset viper
|
||||
viper.Reset()
|
||||
|
||||
// Create temp config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
// Initial valid config
|
||||
validContent := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
development: false
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
enable_rate_limiter: false
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(validContent), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
|
||||
// Set config path
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
defer os.Unsetenv(constants.EnvConfigPath)
|
||||
|
||||
// Load initial config
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load initial config: %v", err)
|
||||
}
|
||||
|
||||
initialLevel := cfg.Logging.Level
|
||||
if initialLevel != "info" {
|
||||
t.Fatalf("expected initial logging.level info, got %s", initialLevel)
|
||||
}
|
||||
|
||||
// Create logger for testing
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
// Start watcher
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go Watch(ctx, logger)
|
||||
|
||||
// Give watcher time to initialize
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Write INVALID config (malformed YAML)
|
||||
invalidContent := `
|
||||
server:
|
||||
address: ":3000"
|
||||
invalid yaml syntax here!!!
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(invalidContent), 0644); err != nil {
|
||||
t.Fatalf("failed to write invalid config: %v", err)
|
||||
}
|
||||
|
||||
// Wait for watcher to detect changes
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Verify config was NOT changed (should keep previous valid config)
|
||||
currentCfg := Get()
|
||||
if currentCfg.Logging.Level != initialLevel {
|
||||
t.Errorf("expected config to remain unchanged after invalid update, got logging.level %s instead of %s", currentCfg.Logging.Level, initialLevel)
|
||||
}
|
||||
|
||||
// Restore valid config
|
||||
if err := os.WriteFile(configFile, []byte(validContent), 0644); err != nil {
|
||||
t.Fatalf("failed to restore valid config: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Now write config with validation error (timeout out of range)
|
||||
invalidValidationContent := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "1s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "debug"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(invalidValidationContent), 0644); err != nil {
|
||||
t.Fatalf("failed to write config with validation error: %v", err)
|
||||
}
|
||||
|
||||
// Wait for watcher to detect changes
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Verify config was NOT changed (validation should have failed)
|
||||
finalCfg := Get()
|
||||
if finalCfg.Logging.Level != initialLevel {
|
||||
t.Errorf("expected config to remain unchanged after validation error, got logging.level %s instead of %s", finalCfg.Logging.Level, initialLevel)
|
||||
}
|
||||
if finalCfg.Server.ReadTimeout == 1*time.Second {
|
||||
t.Error("expected config to remain unchanged, but read_timeout was updated to invalid value")
|
||||
}
|
||||
|
||||
// Cancel context
|
||||
cancel()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestWatch_ContextCancellation tests graceful shutdown on context cancellation
|
||||
func TestWatch_ContextCancellation(t *testing.T) {
|
||||
// Reset viper
|
||||
viper.Reset()
|
||||
|
||||
// Create temp config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
|
||||
os.Setenv(constants.EnvConfigPath, configFile)
|
||||
defer os.Unsetenv(constants.EnvConfigPath)
|
||||
|
||||
// Load config
|
||||
_, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Create logger
|
||||
logger := zap.NewNop() // Use no-op logger for this test
|
||||
|
||||
// Start watcher with context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
Watch(ctx, logger)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Give watcher time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Cancel context (simulate graceful shutdown)
|
||||
cancel()
|
||||
|
||||
// Wait for watcher to stop (should happen quickly)
|
||||
select {
|
||||
case <-done:
|
||||
// Watcher stopped successfully
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("watcher did not stop within timeout after context cancellation")
|
||||
}
|
||||
}
|
||||
518
pkg/logger/logger_test.go
Normal file
518
pkg/logger/logger_test.go
Normal file
@@ -0,0 +1,518 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// TestInitLoggers 测试日志初始化(T026)
|
||||
func TestInitLoggers(t *testing.T) {
|
||||
// 创建临时目录用于日志文件
|
||||
tempDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
level string
|
||||
development bool
|
||||
appLogConfig LogRotationConfig
|
||||
accessLogConfig LogRotationConfig
|
||||
wantErr bool
|
||||
validateFunc func(t *testing.T)
|
||||
}{
|
||||
{
|
||||
name: "production mode with info level",
|
||||
level: "info",
|
||||
development: false,
|
||||
appLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-prod.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
accessLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-prod.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
wantErr: false,
|
||||
validateFunc: func(t *testing.T) {
|
||||
if appLogger == nil {
|
||||
t.Error("appLogger should not be nil")
|
||||
}
|
||||
if accessLogger == nil {
|
||||
t.Error("accessLogger should not be nil")
|
||||
}
|
||||
// 写入一条日志以触发文件创建
|
||||
GetAppLogger().Info("test log creation")
|
||||
Sync()
|
||||
// 验证日志文件创建
|
||||
if _, err := os.Stat(filepath.Join(tempDir, "app-prod.log")); os.IsNotExist(err) {
|
||||
t.Error("app log file should be created after writing")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "development mode with debug level",
|
||||
level: "debug",
|
||||
development: true,
|
||||
appLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-dev.log"),
|
||||
MaxSize: 5,
|
||||
MaxBackups: 2,
|
||||
MaxAge: 3,
|
||||
Compress: false,
|
||||
},
|
||||
accessLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-dev.log"),
|
||||
MaxSize: 5,
|
||||
MaxBackups: 2,
|
||||
MaxAge: 3,
|
||||
Compress: false,
|
||||
},
|
||||
wantErr: false,
|
||||
validateFunc: func(t *testing.T) {
|
||||
if appLogger == nil {
|
||||
t.Error("appLogger should not be nil in dev mode")
|
||||
}
|
||||
if accessLogger == nil {
|
||||
t.Error("accessLogger should not be nil in dev mode")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "warn level logging",
|
||||
level: "warn",
|
||||
development: false,
|
||||
appLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-warn.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
accessLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-warn.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
wantErr: false,
|
||||
validateFunc: func(t *testing.T) {
|
||||
if appLogger == nil {
|
||||
t.Error("appLogger should not be nil")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "error level logging",
|
||||
level: "error",
|
||||
development: false,
|
||||
appLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-error.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
accessLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-error.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
wantErr: false,
|
||||
validateFunc: func(t *testing.T) {
|
||||
if appLogger == nil {
|
||||
t.Error("appLogger should not be nil")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid level defaults to info",
|
||||
level: "invalid",
|
||||
development: false,
|
||||
appLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-invalid.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
accessLogConfig: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-invalid.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
wantErr: false,
|
||||
validateFunc: func(t *testing.T) {
|
||||
if appLogger == nil {
|
||||
t.Error("appLogger should not be nil even with invalid level")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := InitLoggers(tt.level, tt.development, tt.appLogConfig, tt.accessLogConfig)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("InitLoggers() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.validateFunc != nil {
|
||||
tt.validateFunc(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetAppLogger 测试获取应用日志记录器(T026)
|
||||
func TestGetAppLogger(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func()
|
||||
wantNil bool
|
||||
}{
|
||||
{
|
||||
name: "after initialization",
|
||||
setupFunc: func() {
|
||||
InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-get.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-get.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
)
|
||||
},
|
||||
wantNil: false,
|
||||
},
|
||||
{
|
||||
name: "before initialization returns nop logger",
|
||||
setupFunc: func() {
|
||||
// 重置全局变量
|
||||
appLogger = nil
|
||||
},
|
||||
wantNil: false, // GetAppLogger 应该返回 nop logger,不是 nil
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupFunc()
|
||||
logger := GetAppLogger()
|
||||
if logger == nil {
|
||||
t.Error("GetAppLogger() should never return nil, should return nop logger instead")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetAccessLogger 测试获取访问日志记录器(T028)
|
||||
func TestGetAccessLogger(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func()
|
||||
wantNil bool
|
||||
}{
|
||||
{
|
||||
name: "after initialization",
|
||||
setupFunc: func() {
|
||||
InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-access.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-access.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
)
|
||||
},
|
||||
wantNil: false,
|
||||
},
|
||||
{
|
||||
name: "before initialization returns nop logger",
|
||||
setupFunc: func() {
|
||||
// 重置全局变量
|
||||
accessLogger = nil
|
||||
},
|
||||
wantNil: false, // GetAccessLogger 应该返回 nop logger,不是 nil
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupFunc()
|
||||
logger := GetAccessLogger()
|
||||
if logger == nil {
|
||||
t.Error("GetAccessLogger() should never return nil, should return nop logger instead")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSync 测试日志缓冲区刷新(T028)
|
||||
func TestSync(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func()
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "sync after initialization",
|
||||
setupFunc: func() {
|
||||
InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-sync.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-sync.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
)
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "sync before initialization",
|
||||
setupFunc: func() {
|
||||
appLogger = nil
|
||||
accessLogger = nil
|
||||
},
|
||||
wantErr: false, // 应该优雅地处理 nil 情况
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupFunc()
|
||||
err := Sync()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Sync() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseLevel 测试日志级别解析(T026)
|
||||
func TestParseLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
level string
|
||||
want zapcore.Level
|
||||
}{
|
||||
{
|
||||
name: "debug level",
|
||||
level: "debug",
|
||||
want: zapcore.DebugLevel,
|
||||
},
|
||||
{
|
||||
name: "info level",
|
||||
level: "info",
|
||||
want: zapcore.InfoLevel,
|
||||
},
|
||||
{
|
||||
name: "warn level",
|
||||
level: "warn",
|
||||
want: zapcore.WarnLevel,
|
||||
},
|
||||
{
|
||||
name: "error level",
|
||||
level: "error",
|
||||
want: zapcore.ErrorLevel,
|
||||
},
|
||||
{
|
||||
name: "invalid level defaults to info",
|
||||
level: "invalid",
|
||||
want: zapcore.InfoLevel,
|
||||
},
|
||||
{
|
||||
name: "empty level defaults to info",
|
||||
level: "",
|
||||
want: zapcore.InfoLevel,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := parseLevel(tt.level)
|
||||
if got != tt.want {
|
||||
t.Errorf("parseLevel() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDualLoggerSystem 测试双日志系统(T028)
|
||||
func TestDualLoggerSystem(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
appLogFile := filepath.Join(tempDir, "app-dual.log")
|
||||
accessLogFile := filepath.Join(tempDir, "access-dual.log")
|
||||
|
||||
// 初始化双日志系统
|
||||
err := InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: appLogFile,
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false, // 不压缩以便检查内容
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: accessLogFile,
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("InitLoggers failed: %v", err)
|
||||
}
|
||||
|
||||
// 写入应用日志
|
||||
appLog := GetAppLogger()
|
||||
appLog.Info("test app log message",
|
||||
zap.String("module", "test"),
|
||||
zap.Int("code", 200),
|
||||
)
|
||||
|
||||
// 写入访问日志
|
||||
accessLog := GetAccessLogger()
|
||||
accessLog.Info("test access log message",
|
||||
zap.String("method", "GET"),
|
||||
zap.String("path", "/api/test"),
|
||||
zap.Int("status", 200),
|
||||
zap.Duration("latency", 100),
|
||||
)
|
||||
|
||||
// 刷新缓冲区
|
||||
if err := Sync(); err != nil {
|
||||
t.Fatalf("Sync failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证应用日志文件存在并有内容
|
||||
appLogContent, err := os.ReadFile(appLogFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read app log file: %v", err)
|
||||
}
|
||||
if len(appLogContent) == 0 {
|
||||
t.Error("App log file should not be empty")
|
||||
}
|
||||
|
||||
// 验证访问日志文件存在并有内容
|
||||
accessLogContent, err := os.ReadFile(accessLogFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read access log file: %v", err)
|
||||
}
|
||||
if len(accessLogContent) == 0 {
|
||||
t.Error("Access log file should not be empty")
|
||||
}
|
||||
|
||||
// 验证两个日志文件是独立的
|
||||
if string(appLogContent) == string(accessLogContent) {
|
||||
t.Error("App log and access log should have different content")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoggerReinitialization 测试日志重新初始化(T026)
|
||||
func TestLoggerReinitialization(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// 第一次初始化
|
||||
err := InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-reinit1.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-reinit1.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("First InitLoggers failed: %v", err)
|
||||
}
|
||||
|
||||
firstAppLogger := GetAppLogger()
|
||||
|
||||
// 第二次初始化(重新初始化)
|
||||
err = InitLoggers("debug", true,
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-reinit2.log"),
|
||||
MaxSize: 5,
|
||||
MaxBackups: 2,
|
||||
MaxAge: 3,
|
||||
Compress: false,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-reinit2.log"),
|
||||
MaxSize: 5,
|
||||
MaxBackups: 2,
|
||||
MaxAge: 3,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Second InitLoggers failed: %v", err)
|
||||
}
|
||||
|
||||
secondAppLogger := GetAppLogger()
|
||||
|
||||
// 验证重新初始化后日志记录器已更新
|
||||
if firstAppLogger == secondAppLogger {
|
||||
t.Error("Logger should be replaced after reinitialization")
|
||||
}
|
||||
}
|
||||
388
pkg/logger/rotation_test.go
Normal file
388
pkg/logger/rotation_test.go
Normal file
@@ -0,0 +1,388 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TestLogRotation 测试日志轮转功能(T027)
|
||||
func TestLogRotation(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
appLogFile := filepath.Join(tempDir, "app-rotation.log")
|
||||
|
||||
// 初始化日志系统,设置较小的 MaxSize 以便测试
|
||||
err := InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: appLogFile,
|
||||
MaxSize: 1, // 1MB,写入足够数据后会触发轮转
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false, // 不压缩以便检查
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-rotation.log"),
|
||||
MaxSize: 1,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("InitLoggers failed: %v", err)
|
||||
}
|
||||
|
||||
logger := GetAppLogger()
|
||||
|
||||
// 写入大量日志数据以触发轮转(每条约100字节,写入15000条约1.5MB)
|
||||
largeMessage := strings.Repeat("a", 100)
|
||||
for i := 0; i < 15000; i++ {
|
||||
logger.Info(largeMessage,
|
||||
zap.Int("iteration", i),
|
||||
zap.String("data", largeMessage),
|
||||
)
|
||||
}
|
||||
|
||||
// 刷新缓冲区
|
||||
if err := Sync(); err != nil {
|
||||
t.Fatalf("Sync failed: %v", err)
|
||||
}
|
||||
|
||||
// 等待一小段时间确保文件写入完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证主日志文件存在
|
||||
if _, err := os.Stat(appLogFile); os.IsNotExist(err) {
|
||||
t.Error("Main log file should exist")
|
||||
}
|
||||
|
||||
// 检查是否有备份文件(轮转后的文件)
|
||||
files, err := filepath.Glob(filepath.Join(tempDir, "app-rotation-*.log"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to glob backup files: %v", err)
|
||||
}
|
||||
|
||||
// 由于写入了超过1MB的数据,应该触发至少一次轮转
|
||||
if len(files) == 0 {
|
||||
// 可能系统写入速度或lumberjack行为导致未立即轮转,检查主文件大小
|
||||
info, err := os.Stat(appLogFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to stat main log file: %v", err)
|
||||
}
|
||||
if info.Size() == 0 {
|
||||
t.Error("Log file should have content")
|
||||
}
|
||||
// 不强制要求必须轮转,因为取决于具体实现
|
||||
t.Logf("No rotation occurred, but main log file size: %d bytes", info.Size())
|
||||
} else {
|
||||
t.Logf("Found %d rotated backup file(s)", len(files))
|
||||
}
|
||||
}
|
||||
|
||||
// TestMaxBackups 测试最大备份数限制(T027)
|
||||
func TestMaxBackups(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
appLogFile := filepath.Join(tempDir, "app-backups.log")
|
||||
|
||||
// 初始化日志系统,设置 MaxBackups=2
|
||||
err := InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: appLogFile,
|
||||
MaxSize: 1, // 1MB
|
||||
MaxBackups: 2, // 最多保留2个备份
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-backups.log"),
|
||||
MaxSize: 1,
|
||||
MaxBackups: 2,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("InitLoggers failed: %v", err)
|
||||
}
|
||||
|
||||
logger := GetAppLogger()
|
||||
|
||||
// 写入足够的数据触发多次轮转(每次1.5MB,共4.5MB应该触发3次轮转)
|
||||
largeMessage := strings.Repeat("b", 100)
|
||||
for round := 0; round < 3; round++ {
|
||||
for i := 0; i < 15000; i++ {
|
||||
logger.Info(largeMessage,
|
||||
zap.Int("round", round),
|
||||
zap.Int("iteration", i),
|
||||
)
|
||||
}
|
||||
Sync()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// 等待轮转完成
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// 检查备份文件数量
|
||||
files, err := filepath.Glob(filepath.Join(tempDir, "app-backups-*.log"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to glob backup files: %v", err)
|
||||
}
|
||||
|
||||
// 由于 MaxBackups=2,即使触发了多次轮转,也只应保留最多2个备份文件
|
||||
// (实际行为取决于 lumberjack 的实现细节,可能小于等于2)
|
||||
if len(files) > 2 {
|
||||
t.Errorf("Expected at most 2 backup files due to MaxBackups=2, got %d", len(files))
|
||||
}
|
||||
t.Logf("Found %d backup file(s) with MaxBackups=2", len(files))
|
||||
}
|
||||
|
||||
// TestCompressionConfig 测试压缩配置(T027)
|
||||
func TestCompressionConfig(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
compress bool
|
||||
}{
|
||||
{
|
||||
name: "compression enabled",
|
||||
compress: true,
|
||||
},
|
||||
{
|
||||
name: "compression disabled",
|
||||
compress: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logFile := filepath.Join(tempDir, "app-"+tt.name+".log")
|
||||
|
||||
err := InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: logFile,
|
||||
MaxSize: 1,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: tt.compress,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-"+tt.name+".log"),
|
||||
MaxSize: 1,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: tt.compress,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("InitLoggers failed: %v", err)
|
||||
}
|
||||
|
||||
logger := GetAppLogger()
|
||||
|
||||
// 写入一些日志
|
||||
for i := 0; i < 1000; i++ {
|
||||
logger.Info("test compression",
|
||||
zap.Int("id", i),
|
||||
zap.String("data", strings.Repeat("c", 50)),
|
||||
)
|
||||
}
|
||||
|
||||
Sync()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证日志文件存在
|
||||
if _, err := os.Stat(logFile); os.IsNotExist(err) {
|
||||
t.Error("Log file should exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMaxAge 测试日志文件保留时间(T027)
|
||||
func TestMaxAge(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// 初始化日志系统,设置 MaxAge=1 天
|
||||
err := InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-maxage.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 1, // 1天
|
||||
Compress: false,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-maxage.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 1,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("InitLoggers failed: %v", err)
|
||||
}
|
||||
|
||||
logger := GetAppLogger()
|
||||
|
||||
// 写入日志
|
||||
logger.Info("test max age", zap.String("config", "maxage=1"))
|
||||
Sync()
|
||||
|
||||
// 验证配置已应用(无法在单元测试中验证实际的清理行为,因为需要等待1天)
|
||||
// 这里只验证初始化没有错误
|
||||
if logger == nil {
|
||||
t.Error("Logger should be initialized with MaxAge config")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewLumberjackLogger 测试 Lumberjack logger 创建(T027)
|
||||
func TestNewLumberjackLogger(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config LogRotationConfig
|
||||
}{
|
||||
{
|
||||
name: "standard config",
|
||||
config: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "test1.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "minimal config",
|
||||
config: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "test2.log"),
|
||||
MaxSize: 1,
|
||||
MaxBackups: 1,
|
||||
MaxAge: 1,
|
||||
Compress: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "large config",
|
||||
config: LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "test3.log"),
|
||||
MaxSize: 100,
|
||||
MaxBackups: 10,
|
||||
MaxAge: 30,
|
||||
Compress: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logger := newLumberjackLogger(tt.config)
|
||||
if logger == nil {
|
||||
t.Error("newLumberjackLogger should not return nil")
|
||||
}
|
||||
|
||||
// 验证配置已正确设置
|
||||
if logger.Filename != tt.config.Filename {
|
||||
t.Errorf("Filename = %v, want %v", logger.Filename, tt.config.Filename)
|
||||
}
|
||||
if logger.MaxSize != tt.config.MaxSize {
|
||||
t.Errorf("MaxSize = %v, want %v", logger.MaxSize, tt.config.MaxSize)
|
||||
}
|
||||
if logger.MaxBackups != tt.config.MaxBackups {
|
||||
t.Errorf("MaxBackups = %v, want %v", logger.MaxBackups, tt.config.MaxBackups)
|
||||
}
|
||||
if logger.MaxAge != tt.config.MaxAge {
|
||||
t.Errorf("MaxAge = %v, want %v", logger.MaxAge, tt.config.MaxAge)
|
||||
}
|
||||
if logger.Compress != tt.config.Compress {
|
||||
t.Errorf("Compress = %v, want %v", logger.Compress, tt.config.Compress)
|
||||
}
|
||||
if !logger.LocalTime {
|
||||
t.Error("LocalTime should be true")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentLogging 测试并发日志写入(T027)
|
||||
func TestConcurrentLogging(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// 初始化日志系统
|
||||
err := InitLoggers("info", false,
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "app-concurrent.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
LogRotationConfig{
|
||||
Filename: filepath.Join(tempDir, "access-concurrent.log"),
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("InitLoggers failed: %v", err)
|
||||
}
|
||||
|
||||
logger := GetAppLogger()
|
||||
|
||||
// 启动多个 goroutine 并发写入日志
|
||||
done := make(chan bool)
|
||||
goroutines := 10
|
||||
messagesPerGoroutine := 100
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(id int) {
|
||||
for j := 0; j < messagesPerGoroutine; j++ {
|
||||
logger.Info("concurrent log message",
|
||||
zap.Int("goroutine", id),
|
||||
zap.Int("message", j),
|
||||
)
|
||||
}
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 等待所有 goroutine 完成
|
||||
for i := 0; i < goroutines; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// 刷新缓冲区
|
||||
if err := Sync(); err != nil {
|
||||
t.Fatalf("Sync failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证日志文件存在且有内容
|
||||
logFile := filepath.Join(tempDir, "app-concurrent.log")
|
||||
info, err := os.Stat(logFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to stat log file: %v", err)
|
||||
}
|
||||
if info.Size() == 0 {
|
||||
t.Error("Log file should have content after concurrent writes")
|
||||
}
|
||||
|
||||
t.Logf("Concurrent logging test completed, log file size: %d bytes", info.Size())
|
||||
}
|
||||
477
pkg/response/response_test.go
Normal file
477
pkg/response/response_test.go
Normal file
@@ -0,0 +1,477 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/errors"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// TestSuccess 测试成功响应(T034)
|
||||
func TestSuccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data any
|
||||
}{
|
||||
{
|
||||
name: "success with string data",
|
||||
data: "test data",
|
||||
},
|
||||
{
|
||||
name: "success with map data",
|
||||
data: map[string]any{
|
||||
"id": 123,
|
||||
"name": "test",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success with slice data",
|
||||
data: []string{"item1", "item2", "item3"},
|
||||
},
|
||||
{
|
||||
name: "success with struct data",
|
||||
data: struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}{
|
||||
ID: 456,
|
||||
Name: "test struct",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success with nil data",
|
||||
data: nil,
|
||||
},
|
||||
{
|
||||
name: "success with empty map",
|
||||
data: map[string]any{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return Success(c, tt.data)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 验证 HTTP 状态码
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("Expected status code 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 验证响应头
|
||||
if resp.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("Expected Content-Type application/json, got %s", resp.Header.Get("Content-Type"))
|
||||
}
|
||||
|
||||
// 解析响应体
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
var response Response
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
// 验证响应结构
|
||||
if response.Code != errors.CodeSuccess {
|
||||
t.Errorf("Expected code %d, got %d", errors.CodeSuccess, response.Code)
|
||||
}
|
||||
|
||||
if response.Message != "success" {
|
||||
t.Errorf("Expected message 'success', got '%s'", response.Message)
|
||||
}
|
||||
|
||||
// 验证时间戳格式 RFC3339
|
||||
if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil {
|
||||
t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp)
|
||||
}
|
||||
|
||||
// 验证数据字段(如果不是 nil)
|
||||
if tt.data != nil {
|
||||
if response.Data == nil {
|
||||
t.Error("Expected data field to be non-nil")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestError 测试错误响应(T035)
|
||||
func TestError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
httpStatus int
|
||||
code int
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "internal server error",
|
||||
httpStatus: 500,
|
||||
code: errors.CodeInternalError,
|
||||
message: "Internal server error occurred",
|
||||
},
|
||||
{
|
||||
name: "missing token error",
|
||||
httpStatus: 401,
|
||||
code: errors.CodeMissingToken,
|
||||
message: "Authentication token is missing",
|
||||
},
|
||||
{
|
||||
name: "invalid token error",
|
||||
httpStatus: 401,
|
||||
code: errors.CodeInvalidToken,
|
||||
message: "Token is invalid or expired",
|
||||
},
|
||||
{
|
||||
name: "rate limit error",
|
||||
httpStatus: 429,
|
||||
code: errors.CodeTooManyRequests,
|
||||
message: "Too many requests, please try again later",
|
||||
},
|
||||
{
|
||||
name: "service unavailable error",
|
||||
httpStatus: 503,
|
||||
code: errors.CodeAuthServiceUnavailable,
|
||||
message: "Authentication service is currently unavailable",
|
||||
},
|
||||
{
|
||||
name: "bad request error",
|
||||
httpStatus: 400,
|
||||
code: 2000,
|
||||
message: "Invalid request parameters",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return Error(c, tt.httpStatus, tt.code, tt.message)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 验证 HTTP 状态码
|
||||
if resp.StatusCode != tt.httpStatus {
|
||||
t.Errorf("Expected status code %d, got %d", tt.httpStatus, resp.StatusCode)
|
||||
}
|
||||
|
||||
// 验证响应头
|
||||
if resp.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("Expected Content-Type application/json, got %s", resp.Header.Get("Content-Type"))
|
||||
}
|
||||
|
||||
// 解析响应体
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
var response Response
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
// 验证响应结构
|
||||
if response.Code != tt.code {
|
||||
t.Errorf("Expected code %d, got %d", tt.code, response.Code)
|
||||
}
|
||||
|
||||
if response.Message != tt.message {
|
||||
t.Errorf("Expected message '%s', got '%s'", tt.message, response.Message)
|
||||
}
|
||||
|
||||
if response.Data != nil {
|
||||
t.Errorf("Expected data to be nil in error response, got %v", response.Data)
|
||||
}
|
||||
|
||||
// 验证时间戳格式 RFC3339
|
||||
if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil {
|
||||
t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSuccessWithMessage 测试带自定义消息的成功响应(T034)
|
||||
func TestSuccessWithMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data any
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "custom success message",
|
||||
data: map[string]any{
|
||||
"user_id": 123,
|
||||
},
|
||||
message: "User created successfully",
|
||||
},
|
||||
{
|
||||
name: "empty custom message",
|
||||
data: "test data",
|
||||
message: "",
|
||||
},
|
||||
{
|
||||
name: "chinese message",
|
||||
data: map[string]string{
|
||||
"status": "ok",
|
||||
},
|
||||
message: "操作成功",
|
||||
},
|
||||
{
|
||||
name: "long message",
|
||||
data: nil,
|
||||
message: "This is a very long success message that describes in detail what happened during the operation",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return SuccessWithMessage(c, tt.data, tt.message)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 验证 HTTP 状态码(默认 200)
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("Expected status code 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 解析响应体
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
var response Response
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
// 验证响应结构
|
||||
if response.Code != errors.CodeSuccess {
|
||||
t.Errorf("Expected code %d, got %d", errors.CodeSuccess, response.Code)
|
||||
}
|
||||
|
||||
if response.Message != tt.message {
|
||||
t.Errorf("Expected message '%s', got '%s'", tt.message, response.Message)
|
||||
}
|
||||
|
||||
// 验证时间戳格式 RFC3339
|
||||
if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil {
|
||||
t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResponseSerialization 测试响应序列化(T036)
|
||||
func TestResponseSerialization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response Response
|
||||
}{
|
||||
{
|
||||
name: "complete response",
|
||||
response: Response{
|
||||
Code: 0,
|
||||
Data: map[string]any{"key": "value"},
|
||||
Message: "success",
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "response with nil data",
|
||||
response: Response{
|
||||
Code: 1000,
|
||||
Data: nil,
|
||||
Message: "error",
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "response with nested data",
|
||||
response: Response{
|
||||
Code: 0,
|
||||
Data: map[string]any{
|
||||
"user": map[string]any{
|
||||
"id": 123,
|
||||
"name": "test",
|
||||
"tags": []string{"tag1", "tag2"},
|
||||
},
|
||||
},
|
||||
Message: "success",
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 序列化
|
||||
data, err := json.Marshal(tt.response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal response: %v", err)
|
||||
}
|
||||
|
||||
// 反序列化
|
||||
var deserialized Response
|
||||
if err := json.Unmarshal(data, &deserialized); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
// 验证字段
|
||||
if deserialized.Code != tt.response.Code {
|
||||
t.Errorf("Code mismatch: expected %d, got %d", tt.response.Code, deserialized.Code)
|
||||
}
|
||||
|
||||
if deserialized.Message != tt.response.Message {
|
||||
t.Errorf("Message mismatch: expected '%s', got '%s'", tt.response.Message, deserialized.Message)
|
||||
}
|
||||
|
||||
if deserialized.Timestamp != tt.response.Timestamp {
|
||||
t.Errorf("Timestamp mismatch: expected '%s', got '%s'", tt.response.Timestamp, deserialized.Timestamp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResponseStructFields 测试响应结构字段(T036)
|
||||
func TestResponseStructFields(t *testing.T) {
|
||||
response := Response{
|
||||
Code: 0,
|
||||
Data: "test",
|
||||
Message: "success",
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal response: %v", err)
|
||||
}
|
||||
|
||||
// 解析为 map 以检查 JSON 键
|
||||
var jsonMap map[string]any
|
||||
if err := json.Unmarshal(data, &jsonMap); err != nil {
|
||||
t.Fatalf("Failed to unmarshal to map: %v", err)
|
||||
}
|
||||
|
||||
// 验证所有必需字段都存在
|
||||
requiredFields := []string{"code", "data", "msg", "timestamp"}
|
||||
for _, field := range requiredFields {
|
||||
if _, exists := jsonMap[field]; !exists {
|
||||
t.Errorf("Required field '%s' is missing in JSON response", field)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证字段类型
|
||||
if _, ok := jsonMap["code"].(float64); !ok {
|
||||
t.Error("Field 'code' should be a number")
|
||||
}
|
||||
|
||||
if _, ok := jsonMap["msg"].(string); !ok {
|
||||
t.Error("Field 'msg' should be a string")
|
||||
}
|
||||
|
||||
if _, ok := jsonMap["timestamp"].(string); !ok {
|
||||
t.Error("Field 'timestamp' should be a string")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultipleResponses 测试多个连续响应(T036)
|
||||
func TestMultipleResponses(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
callCount := 0
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
callCount++
|
||||
if callCount%2 == 0 {
|
||||
return Success(c, map[string]int{"count": callCount})
|
||||
}
|
||||
return Error(c, 500, errors.CodeInternalError, "error occurred")
|
||||
})
|
||||
|
||||
// 发送多个请求
|
||||
for i := 1; i <= 5; i++ {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request %d failed: %v", i, err)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
var response Response
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
t.Fatalf("Request %d: failed to unmarshal response: %v", i, err)
|
||||
}
|
||||
|
||||
// 验证每个响应都有时间戳
|
||||
if response.Timestamp == "" {
|
||||
t.Errorf("Request %d: timestamp should not be empty", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestTimestampFormat 测试时间戳格式(T036)
|
||||
func TestTimestampFormat(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return Success(c, nil)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
var response Response
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
// 验证是 RFC3339 格式
|
||||
parsedTime, err := time.Parse(time.RFC3339, response.Timestamp)
|
||||
if err != nil {
|
||||
t.Fatalf("Timestamp is not in RFC3339 format: %s, error: %v", response.Timestamp, err)
|
||||
}
|
||||
|
||||
// 验证时间戳是最近的(应该在最近 1 秒内)
|
||||
now := time.Now()
|
||||
diff := now.Sub(parsedTime)
|
||||
if diff < 0 || diff > time.Second {
|
||||
t.Errorf("Timestamp seems incorrect: %s (diff from now: %v)", response.Timestamp, diff)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user