diff --git a/pkg/config/config.go b/pkg/config/config.go index 9c87f77..4c1590e 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -3,7 +3,6 @@ package config import ( "errors" "fmt" - "strings" "sync/atomic" "time" "unsafe" @@ -93,8 +92,9 @@ func (c *Config) Validate() error { if c.Redis.Address == "" { return fmt.Errorf("invalid configuration: redis.address: must be non-empty (current value: empty)") } - if !strings.Contains(c.Redis.Address, ":") { - return fmt.Errorf("invalid configuration: redis.address: invalid format (current value: %s, expected: HOST:PORT)", c.Redis.Address) + // Port validation (standalone field) + if c.Redis.Port <= 0 || c.Redis.Port > 65535 { + return fmt.Errorf("invalid configuration: redis.port: port number out of range (current value: %d, expected: 1-65535)", c.Redis.Port) } if c.Redis.DB < 0 || c.Redis.DB > 15 { return fmt.Errorf("invalid configuration: redis.db: database number out of range (current value: %d, expected: 0-15)", c.Redis.DB) diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go new file mode 100644 index 0000000..dae890b --- /dev/null +++ b/pkg/config/config_test.go @@ -0,0 +1,615 @@ +package config + +import ( + "testing" + "time" +) + +// TestConfig_Validate tests configuration validation rules +func TestConfig_Validate(t *testing.T) { + tests := []struct { + name string + config *Config + wantErr bool + errMsg string + }{ + { + name: "valid config", + config: &Config{ + Server: ServerConfig{ + Address: ":3000", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + ShutdownTimeout: 30 * time.Second, + }, + Redis: RedisConfig{ + Address: "localhost", + Port: 6379, + DB: 0, + PoolSize: 10, + MinIdleConns: 5, + }, + Logging: LoggingConfig{ + Level: "info", + AppLog: LogRotationConfig{ + Filename: "logs/app.log", + MaxSize: 100, + MaxBackups: 30, + MaxAge: 30, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + MaxBackups: 90, + MaxAge: 90, + }, + }, + Middleware: 
MiddlewareConfig{ + RateLimiter: RateLimiterConfig{ + Max: 100, + Expiration: 1 * time.Minute, + Storage: "memory", + }, + }, + }, + wantErr: false, + }, + { + name: "empty server address", + config: &Config{ + Server: ServerConfig{ + Address: "", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + ShutdownTimeout: 30 * time.Second, + }, + Redis: RedisConfig{ + Address: "localhost", + Port: 6379, + PoolSize: 10, + }, + Logging: LoggingConfig{ + Level: "info", + AppLog: LogRotationConfig{ + Filename: "logs/app.log", + MaxSize: 100, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + }, + }, + }, + wantErr: true, + errMsg: "server.address", + }, + { + name: "read timeout too short", + config: &Config{ + Server: ServerConfig{ + Address: ":3000", + ReadTimeout: 1 * time.Second, + WriteTimeout: 10 * time.Second, + ShutdownTimeout: 30 * time.Second, + }, + Redis: RedisConfig{ + Address: "localhost", + Port: 6379, + PoolSize: 10, + }, + Logging: LoggingConfig{ + Level: "info", + AppLog: LogRotationConfig{ + Filename: "logs/app.log", + MaxSize: 100, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + }, + }, + }, + wantErr: true, + errMsg: "read_timeout", + }, + { + name: "read timeout too long", + config: &Config{ + Server: ServerConfig{ + Address: ":3000", + ReadTimeout: 400 * time.Second, + WriteTimeout: 10 * time.Second, + ShutdownTimeout: 30 * time.Second, + }, + Redis: RedisConfig{ + Address: "localhost", + Port: 6379, + PoolSize: 10, + }, + Logging: LoggingConfig{ + Level: "info", + AppLog: LogRotationConfig{ + Filename: "logs/app.log", + MaxSize: 100, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + }, + }, + }, + wantErr: true, + errMsg: "read_timeout", + }, + { + name: "write timeout out of range", + config: &Config{ + Server: ServerConfig{ + Address: ":3000", + ReadTimeout: 10 * time.Second, + WriteTimeout: 1 * time.Second, + ShutdownTimeout: 
30 * time.Second, + }, + Redis: RedisConfig{ + Address: "localhost", + Port: 6379, + PoolSize: 10, + }, + Logging: LoggingConfig{ + Level: "info", + AppLog: LogRotationConfig{ + Filename: "logs/app.log", + MaxSize: 100, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + }, + }, + }, + wantErr: true, + errMsg: "write_timeout", + }, + { + name: "shutdown timeout too short", + config: &Config{ + Server: ServerConfig{ + Address: ":3000", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + ShutdownTimeout: 5 * time.Second, + }, + Redis: RedisConfig{ + Address: "localhost", + Port: 6379, + PoolSize: 10, + }, + Logging: LoggingConfig{ + Level: "info", + AppLog: LogRotationConfig{ + Filename: "logs/app.log", + MaxSize: 100, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + }, + }, + }, + wantErr: true, + errMsg: "shutdown_timeout", + }, + { + name: "empty redis address", + config: &Config{ + Server: ServerConfig{ + Address: ":3000", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + ShutdownTimeout: 30 * time.Second, + }, + Redis: RedisConfig{ + Address: "", + Port: 6379, + PoolSize: 10, + }, + Logging: LoggingConfig{ + Level: "info", + AppLog: LogRotationConfig{ + Filename: "logs/app.log", + MaxSize: 100, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + }, + }, + }, + wantErr: true, + errMsg: "redis.address", + }, + { + name: "invalid redis port - too high", + config: &Config{ + Server: ServerConfig{ + Address: ":3000", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + ShutdownTimeout: 30 * time.Second, + }, + Redis: RedisConfig{ + Address: "localhost", + Port: 99999, + PoolSize: 10, + }, + Logging: LoggingConfig{ + Level: "info", + AppLog: LogRotationConfig{ + Filename: "logs/app.log", + MaxSize: 100, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + }, + }, + }, + wantErr: 
true, + errMsg: "redis.port", + }, + { + name: "invalid redis port - zero", + config: &Config{ + Server: ServerConfig{ + Address: ":3000", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + ShutdownTimeout: 30 * time.Second, + }, + Redis: RedisConfig{ + Address: "localhost", + Port: 0, + PoolSize: 10, + }, + Logging: LoggingConfig{ + Level: "info", + AppLog: LogRotationConfig{ + Filename: "logs/app.log", + MaxSize: 100, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + }, + }, + }, + wantErr: true, + errMsg: "redis.port", + }, + { + name: "redis db out of range", + config: &Config{ + Server: ServerConfig{ + Address: ":3000", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + ShutdownTimeout: 30 * time.Second, + }, + Redis: RedisConfig{ + Address: "localhost", + Port: 6379, + DB: 20, + PoolSize: 10, + }, + Logging: LoggingConfig{ + Level: "info", + AppLog: LogRotationConfig{ + Filename: "logs/app.log", + MaxSize: 100, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + }, + }, + }, + wantErr: true, + errMsg: "redis.db", + }, + { + name: "redis pool size too large", + config: &Config{ + Server: ServerConfig{ + Address: ":3000", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + ShutdownTimeout: 30 * time.Second, + }, + Redis: RedisConfig{ + Address: "localhost", + Port: 6379, + PoolSize: 2000, + }, + Logging: LoggingConfig{ + Level: "info", + AppLog: LogRotationConfig{ + Filename: "logs/app.log", + MaxSize: 100, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + }, + }, + }, + wantErr: true, + errMsg: "pool_size", + }, + { + name: "min idle conns exceeds pool size", + config: &Config{ + Server: ServerConfig{ + Address: ":3000", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + ShutdownTimeout: 30 * time.Second, + }, + Redis: RedisConfig{ + Address: "localhost", + Port: 6379, + PoolSize: 10, + 
MinIdleConns: 20, + }, + Logging: LoggingConfig{ + Level: "info", + AppLog: LogRotationConfig{ + Filename: "logs/app.log", + MaxSize: 100, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + }, + }, + }, + wantErr: true, + errMsg: "min_idle_conns", + }, + { + name: "invalid log level", + config: &Config{ + Server: ServerConfig{ + Address: ":3000", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + ShutdownTimeout: 30 * time.Second, + }, + Redis: RedisConfig{ + Address: "localhost", + Port: 6379, + PoolSize: 10, + }, + Logging: LoggingConfig{ + Level: "invalid", + AppLog: LogRotationConfig{ + Filename: "logs/app.log", + MaxSize: 100, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + }, + }, + }, + wantErr: true, + errMsg: "logging.level", + }, + { + name: "empty app log filename", + config: &Config{ + Server: ServerConfig{ + Address: ":3000", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + ShutdownTimeout: 30 * time.Second, + }, + Redis: RedisConfig{ + Address: "localhost", + Port: 6379, + PoolSize: 10, + }, + Logging: LoggingConfig{ + Level: "info", + AppLog: LogRotationConfig{ + Filename: "", + MaxSize: 100, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + }, + }, + }, + wantErr: true, + errMsg: "app_log.filename", + }, + { + name: "app log max size out of range", + config: &Config{ + Server: ServerConfig{ + Address: ":3000", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + ShutdownTimeout: 30 * time.Second, + }, + Redis: RedisConfig{ + Address: "localhost", + Port: 6379, + PoolSize: 10, + }, + Logging: LoggingConfig{ + Level: "info", + AppLog: LogRotationConfig{ + Filename: "logs/app.log", + MaxSize: 2000, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + }, + }, + }, + wantErr: true, + errMsg: "app_log.max_size", + }, + { + name: "invalid rate limiter storage", + 
config: &Config{ + Server: ServerConfig{ + Address: ":3000", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + ShutdownTimeout: 30 * time.Second, + }, + Redis: RedisConfig{ + Address: "localhost", + Port: 6379, + PoolSize: 10, + }, + Logging: LoggingConfig{ + Level: "info", + AppLog: LogRotationConfig{ + Filename: "logs/app.log", + MaxSize: 100, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + }, + }, + Middleware: MiddlewareConfig{ + RateLimiter: RateLimiterConfig{ + Max: 100, + Storage: "invalid", + }, + }, + }, + wantErr: true, + errMsg: "rate_limiter.storage", + }, + { + name: "rate limiter max too high", + config: &Config{ + Server: ServerConfig{ + Address: ":3000", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + ShutdownTimeout: 30 * time.Second, + }, + Redis: RedisConfig{ + Address: "localhost", + Port: 6379, + PoolSize: 10, + }, + Logging: LoggingConfig{ + Level: "info", + AppLog: LogRotationConfig{ + Filename: "logs/app.log", + MaxSize: 100, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + }, + }, + Middleware: MiddlewareConfig{ + RateLimiter: RateLimiterConfig{ + Max: 20000, + Storage: "memory", + }, + }, + }, + wantErr: true, + errMsg: "rate_limiter.max", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + + if (err != nil) != tt.wantErr { + t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr && tt.errMsg != "" { + if err == nil { + t.Errorf("expected error containing %q, got nil", tt.errMsg) + } else if err.Error() == "" { + t.Errorf("expected error containing %q, got empty error", tt.errMsg) + } + // Note: We check that error message exists, not exact match + // This is because error messages might change slightly + } + }) + } +} + +// TestSet tests the Set function +func TestSet(t *testing.T) { + // Valid config + validCfg := &Config{ 
+ Server: ServerConfig{ + Address: ":3000", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + ShutdownTimeout: 30 * time.Second, + }, + Redis: RedisConfig{ + Address: "localhost", + Port: 6379, + PoolSize: 10, + }, + Logging: LoggingConfig{ + Level: "info", + AppLog: LogRotationConfig{ + Filename: "logs/app.log", + MaxSize: 100, + }, + AccessLog: LogRotationConfig{ + Filename: "logs/access.log", + MaxSize: 500, + }, + }, + } + + err := Set(validCfg) + if err != nil { + t.Errorf("Set() with valid config failed: %v", err) + } + + // Verify it was set + got := Get() + if got.Server.Address != ":3000" { + t.Errorf("Get() after Set() returned wrong address: got %s, want :3000", got.Server.Address) + } + + // Test with nil config + err = Set(nil) + if err == nil { + t.Error("Set(nil) should return error") + } + + // Test with invalid config + invalidCfg := &Config{ + Server: ServerConfig{ + Address: "", // Empty address is invalid + }, + } + + err = Set(invalidCfg) + if err == nil { + t.Error("Set() with invalid config should return error") + } +} diff --git a/pkg/config/loader_test.go b/pkg/config/loader_test.go new file mode 100644 index 0000000..9c8cdd0 --- /dev/null +++ b/pkg/config/loader_test.go @@ -0,0 +1,661 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/break/junhong_cmp_fiber/pkg/constants" + "github.com/spf13/viper" +) + +// TestLoad tests the config loading functionality +func TestLoad(t *testing.T) { + tests := []struct { + name string + setupEnv func() + cleanupEnv func() + createConfig func(t *testing.T) string + wantErr bool + validateFunc func(t *testing.T, cfg *Config) + }{ + { + name: "valid default config", + setupEnv: func() { + os.Setenv(constants.EnvConfigPath, "") + os.Setenv(constants.EnvConfigEnv, "") + }, + cleanupEnv: func() { + os.Unsetenv(constants.EnvConfigPath) + os.Unsetenv(constants.EnvConfigEnv) + }, + createConfig: func(t *testing.T) string { + t.Helper() + tmpDir := 
t.TempDir() + configFile := filepath.Join(tmpDir, "config.yaml") + content := ` +server: + address: ":3000" + read_timeout: "10s" + write_timeout: "10s" + shutdown_timeout: "30s" + prefork: false + +redis: + address: "localhost" + port: 6379 + password: "" + db: 0 + pool_size: 10 + min_idle_conns: 5 + dial_timeout: "5s" + read_timeout: "3s" + write_timeout: "3s" + +logging: + level: "info" + development: false + app_log: + filename: "logs/app.log" + max_size: 100 + max_backups: 30 + max_age: 30 + compress: true + access_log: + filename: "logs/access.log" + max_size: 500 + max_backups: 90 + max_age: 90 + compress: true + +middleware: + enable_auth: true + enable_rate_limiter: false + rate_limiter: + max: 100 + expiration: "1m" + storage: "memory" +` + if err := os.WriteFile(configFile, []byte(content), 0644); err != nil { + t.Fatalf("failed to create config file: %v", err) + } + // Set as default config path + os.Setenv(constants.EnvConfigPath, configFile) + return configFile + }, + wantErr: false, + validateFunc: func(t *testing.T, cfg *Config) { + if cfg.Server.Address != ":3000" { + t.Errorf("expected server.address :3000, got %s", cfg.Server.Address) + } + if cfg.Server.ReadTimeout != 10*time.Second { + t.Errorf("expected read_timeout 10s, got %v", cfg.Server.ReadTimeout) + } + if cfg.Redis.Address != "localhost" { + t.Errorf("expected redis.address localhost, got %s", cfg.Redis.Address) + } + if cfg.Redis.Port != 6379 { + t.Errorf("expected redis.port 6379, got %d", cfg.Redis.Port) + } + if cfg.Redis.PoolSize != 10 { + t.Errorf("expected redis.pool_size 10, got %d", cfg.Redis.PoolSize) + } + if cfg.Logging.Level != "info" { + t.Errorf("expected logging.level info, got %s", cfg.Logging.Level) + } + if cfg.Middleware.EnableAuth != true { + t.Errorf("expected enable_auth true, got %v", cfg.Middleware.EnableAuth) + } + }, + }, + { + name: "environment-specific config (dev)", + setupEnv: func() { + os.Setenv(constants.EnvConfigEnv, "dev") + }, + cleanupEnv: func() { 
+ os.Unsetenv(constants.EnvConfigEnv) + os.Unsetenv(constants.EnvConfigPath) + }, + createConfig: func(t *testing.T) string { + t.Helper() + // Create configs directory in temp + tmpDir := t.TempDir() + configsDir := filepath.Join(tmpDir, "configs") + if err := os.MkdirAll(configsDir, 0755); err != nil { + t.Fatalf("failed to create configs dir: %v", err) + } + + // Create dev config + devConfigFile := filepath.Join(configsDir, "config.dev.yaml") + content := ` +server: + address: ":8080" + read_timeout: "15s" + write_timeout: "15s" + shutdown_timeout: "30s" + prefork: false + +redis: + address: "localhost" + port: 6379 + password: "" + db: 1 + pool_size: 5 + min_idle_conns: 2 + dial_timeout: "5s" + read_timeout: "3s" + write_timeout: "3s" + +logging: + level: "debug" + development: true + app_log: + filename: "logs/app.log" + max_size: 50 + max_backups: 10 + max_age: 7 + compress: false + access_log: + filename: "logs/access.log" + max_size: 100 + max_backups: 30 + max_age: 30 + compress: false + +middleware: + enable_auth: false + enable_rate_limiter: false + rate_limiter: + max: 50 + expiration: "1m" + storage: "memory" +` + if err := os.WriteFile(devConfigFile, []byte(content), 0644); err != nil { + t.Fatalf("failed to create dev config file: %v", err) + } + + // Change to tmpDir so relative path works + originalWd, _ := os.Getwd() + os.Chdir(tmpDir) + t.Cleanup(func() { os.Chdir(originalWd) }) + + return devConfigFile + }, + wantErr: false, + validateFunc: func(t *testing.T, cfg *Config) { + if cfg.Server.Address != ":8080" { + t.Errorf("expected server.address :8080, got %s", cfg.Server.Address) + } + if cfg.Redis.DB != 1 { + t.Errorf("expected redis.db 1, got %d", cfg.Redis.DB) + } + if cfg.Logging.Level != "debug" { + t.Errorf("expected logging.level debug, got %s", cfg.Logging.Level) + } + if cfg.Middleware.EnableAuth != false { + t.Errorf("expected enable_auth false, got %v", cfg.Middleware.EnableAuth) + } + }, + }, + { + name: "invalid YAML syntax", + 
setupEnv: func() { + os.Setenv(constants.EnvConfigPath, "") + os.Setenv(constants.EnvConfigEnv, "") + }, + cleanupEnv: func() { + os.Unsetenv(constants.EnvConfigPath) + os.Unsetenv(constants.EnvConfigEnv) + }, + createConfig: func(t *testing.T) string { + t.Helper() + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.yaml") + content := ` +server: + address: ":3000" + invalid yaml syntax here!!! +` + if err := os.WriteFile(configFile, []byte(content), 0644); err != nil { + t.Fatalf("failed to create config file: %v", err) + } + os.Setenv(constants.EnvConfigPath, configFile) + return configFile + }, + wantErr: true, + validateFunc: nil, + }, + { + name: "validation error - invalid server address", + setupEnv: func() { + os.Setenv(constants.EnvConfigPath, "") + }, + cleanupEnv: func() { + os.Unsetenv(constants.EnvConfigPath) + }, + createConfig: func(t *testing.T) string { + t.Helper() + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.yaml") + content := ` +server: + address: "" + read_timeout: "10s" + write_timeout: "10s" + shutdown_timeout: "30s" + +redis: + address: "localhost" + port: 6379 + db: 0 + pool_size: 10 + min_idle_conns: 5 + +logging: + level: "info" + app_log: + filename: "logs/app.log" + max_size: 100 + max_backups: 30 + max_age: 30 + compress: true + access_log: + filename: "logs/access.log" + max_size: 500 + max_backups: 90 + max_age: 90 + compress: true + +middleware: + enable_auth: true + rate_limiter: + max: 100 + expiration: "1m" + storage: "memory" +` + if err := os.WriteFile(configFile, []byte(content), 0644); err != nil { + t.Fatalf("failed to create config file: %v", err) + } + os.Setenv(constants.EnvConfigPath, configFile) + return configFile + }, + wantErr: true, + validateFunc: nil, + }, + { + name: "validation error - timeout out of range", + setupEnv: func() { + os.Setenv(constants.EnvConfigPath, "") + }, + cleanupEnv: func() { + os.Unsetenv(constants.EnvConfigPath) + }, + createConfig: func(t 
*testing.T) string { + t.Helper() + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.yaml") + content := ` +server: + address: ":3000" + read_timeout: "1s" + write_timeout: "10s" + shutdown_timeout: "30s" + +redis: + address: "localhost" + port: 6379 + db: 0 + pool_size: 10 + min_idle_conns: 5 + +logging: + level: "info" + app_log: + filename: "logs/app.log" + max_size: 100 + max_backups: 30 + max_age: 30 + compress: true + access_log: + filename: "logs/access.log" + max_size: 500 + max_backups: 90 + max_age: 90 + compress: true + +middleware: + enable_auth: true + rate_limiter: + max: 100 + expiration: "1m" + storage: "memory" +` + if err := os.WriteFile(configFile, []byte(content), 0644); err != nil { + t.Fatalf("failed to create config file: %v", err) + } + os.Setenv(constants.EnvConfigPath, configFile) + return configFile + }, + wantErr: true, + validateFunc: nil, + }, + { + name: "validation error - invalid redis port", + setupEnv: func() { + os.Setenv(constants.EnvConfigPath, "") + }, + cleanupEnv: func() { + os.Unsetenv(constants.EnvConfigPath) + }, + createConfig: func(t *testing.T) string { + t.Helper() + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.yaml") + content := ` +server: + address: ":3000" + read_timeout: "10s" + write_timeout: "10s" + shutdown_timeout: "30s" + +redis: + address: "localhost" + port: 99999 + db: 0 + pool_size: 10 + min_idle_conns: 5 + +logging: + level: "info" + app_log: + filename: "logs/app.log" + max_size: 100 + max_backups: 30 + max_age: 30 + compress: true + access_log: + filename: "logs/access.log" + max_size: 500 + max_backups: 90 + max_age: 90 + compress: true + +middleware: + enable_auth: true + rate_limiter: + max: 100 + expiration: "1m" + storage: "memory" +` + if err := os.WriteFile(configFile, []byte(content), 0644); err != nil { + t.Fatalf("failed to create config file: %v", err) + } + os.Setenv(constants.EnvConfigPath, configFile) + return configFile + }, + wantErr: true, + 
validateFunc: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset viper for each test + viper.Reset() + + // Setup environment + if tt.setupEnv != nil { + tt.setupEnv() + } + + // Create config file + if tt.createConfig != nil { + tt.createConfig(t) + } + + // Cleanup after test + if tt.cleanupEnv != nil { + defer tt.cleanupEnv() + } + + // Load config + cfg, err := Load() + + // Check error expectation + if (err != nil) != tt.wantErr { + t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // Validate config if no error expected + if !tt.wantErr && tt.validateFunc != nil { + tt.validateFunc(t, cfg) + } + }) + } +} + +// TestReload tests the config reload functionality +func TestReload(t *testing.T) { + // Reset viper + viper.Reset() + + // Create temp config file + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.yaml") + + // Initial config + initialContent := ` +server: + address: ":3000" + read_timeout: "10s" + write_timeout: "10s" + shutdown_timeout: "30s" + prefork: false + +redis: + address: "localhost" + port: 6379 + password: "" + db: 0 + pool_size: 10 + min_idle_conns: 5 + dial_timeout: "5s" + read_timeout: "3s" + write_timeout: "3s" + +logging: + level: "info" + development: false + app_log: + filename: "logs/app.log" + max_size: 100 + max_backups: 30 + max_age: 30 + compress: true + access_log: + filename: "logs/access.log" + max_size: 500 + max_backups: 90 + max_age: 90 + compress: true + +middleware: + enable_auth: true + enable_rate_limiter: false + rate_limiter: + max: 100 + expiration: "1m" + storage: "memory" +` + if err := os.WriteFile(configFile, []byte(initialContent), 0644); err != nil { + t.Fatalf("failed to create config file: %v", err) + } + + // Set config path + os.Setenv(constants.EnvConfigPath, configFile) + defer os.Unsetenv(constants.EnvConfigPath) + + // Load initial config + cfg, err := Load() + if err != nil { + t.Fatalf("failed to load initial 
config: %v", err) + } + + // Verify initial values + if cfg.Logging.Level != "info" { + t.Errorf("expected initial logging.level info, got %s", cfg.Logging.Level) + } + if cfg.Server.Address != ":3000" { + t.Errorf("expected initial server.address :3000, got %s", cfg.Server.Address) + } + + // Modify config file + updatedContent := ` +server: + address: ":8080" + read_timeout: "15s" + write_timeout: "15s" + shutdown_timeout: "30s" + prefork: false + +redis: + address: "localhost" + port: 6379 + password: "" + db: 0 + pool_size: 20 + min_idle_conns: 10 + dial_timeout: "5s" + read_timeout: "3s" + write_timeout: "3s" + +logging: + level: "debug" + development: true + app_log: + filename: "logs/app.log" + max_size: 100 + max_backups: 30 + max_age: 30 + compress: true + access_log: + filename: "logs/access.log" + max_size: 500 + max_backups: 90 + max_age: 90 + compress: true + +middleware: + enable_auth: false + enable_rate_limiter: true + rate_limiter: + max: 200 + expiration: "2m" + storage: "redis" +` + if err := os.WriteFile(configFile, []byte(updatedContent), 0644); err != nil { + t.Fatalf("failed to update config file: %v", err) + } + + // Reload config + newCfg, err := Reload() + if err != nil { + t.Fatalf("failed to reload config: %v", err) + } + + // Verify updated values + if newCfg.Logging.Level != "debug" { + t.Errorf("expected updated logging.level debug, got %s", newCfg.Logging.Level) + } + if newCfg.Server.Address != ":8080" { + t.Errorf("expected updated server.address :8080, got %s", newCfg.Server.Address) + } + if newCfg.Redis.PoolSize != 20 { + t.Errorf("expected updated redis.pool_size 20, got %d", newCfg.Redis.PoolSize) + } + if newCfg.Middleware.EnableAuth != false { + t.Errorf("expected updated enable_auth false, got %v", newCfg.Middleware.EnableAuth) + } + if newCfg.Middleware.EnableRateLimiter != true { + t.Errorf("expected updated enable_rate_limiter true, got %v", newCfg.Middleware.EnableRateLimiter) + } + + // Verify global config was updated 
+ globalCfg := Get() + if globalCfg.Logging.Level != "debug" { + t.Errorf("expected global config updated, got logging.level %s", globalCfg.Logging.Level) + } +} + +// TestGetConfigPath tests the GetConfigPath function +func TestGetConfigPath(t *testing.T) { + // Reset viper + viper.Reset() + + // Create temp config file + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.yaml") + + content := ` +server: + address: ":3000" + read_timeout: "10s" + write_timeout: "10s" + shutdown_timeout: "30s" + +redis: + address: "localhost" + port: 6379 + db: 0 + pool_size: 10 + min_idle_conns: 5 + +logging: + level: "info" + app_log: + filename: "logs/app.log" + max_size: 100 + max_backups: 30 + max_age: 30 + compress: true + access_log: + filename: "logs/access.log" + max_size: 500 + max_backups: 90 + max_age: 90 + compress: true + +middleware: + enable_auth: true + rate_limiter: + max: 100 + expiration: "1m" + storage: "memory" +` + if err := os.WriteFile(configFile, []byte(content), 0644); err != nil { + t.Fatalf("failed to create config file: %v", err) + } + + os.Setenv(constants.EnvConfigPath, configFile) + defer os.Unsetenv(constants.EnvConfigPath) + + // Load config + _, err := Load() + if err != nil { + t.Fatalf("failed to load config: %v", err) + } + + // Get config path + path := GetConfigPath() + if path == "" { + t.Error("expected non-empty config path") + } + + // Verify it's an absolute path + if !filepath.IsAbs(path) { + t.Errorf("expected absolute path, got %s", path) + } +} diff --git a/pkg/config/watcher_test.go b/pkg/config/watcher_test.go new file mode 100644 index 0000000..88111b0 --- /dev/null +++ b/pkg/config/watcher_test.go @@ -0,0 +1,422 @@ +package config + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/break/junhong_cmp_fiber/pkg/constants" + "github.com/spf13/viper" + "go.uber.org/zap" + "go.uber.org/zap/zaptest" +) + +// TestWatch tests the config hot reload watcher +func TestWatch(t *testing.T) { 
+ // Reset viper + viper.Reset() + + // Create temp config file + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.yaml") + + // Initial config + initialContent := ` +server: + address: ":3000" + read_timeout: "10s" + write_timeout: "10s" + shutdown_timeout: "30s" + prefork: false + +redis: + address: "localhost" + port: 6379 + password: "" + db: 0 + pool_size: 10 + min_idle_conns: 5 + dial_timeout: "5s" + read_timeout: "3s" + write_timeout: "3s" + +logging: + level: "info" + development: false + app_log: + filename: "logs/app.log" + max_size: 100 + max_backups: 30 + max_age: 30 + compress: true + access_log: + filename: "logs/access.log" + max_size: 500 + max_backups: 90 + max_age: 90 + compress: true + +middleware: + enable_auth: true + enable_rate_limiter: false + rate_limiter: + max: 100 + expiration: "1m" + storage: "memory" +` + if err := os.WriteFile(configFile, []byte(initialContent), 0644); err != nil { + t.Fatalf("failed to create config file: %v", err) + } + + // Set config path + os.Setenv(constants.EnvConfigPath, configFile) + defer os.Unsetenv(constants.EnvConfigPath) + + // Load initial config + cfg, err := Load() + if err != nil { + t.Fatalf("failed to load initial config: %v", err) + } + + // Verify initial values + if cfg.Logging.Level != "info" { + t.Fatalf("expected initial logging.level info, got %s", cfg.Logging.Level) + } + + // Create logger for testing + logger := zaptest.NewLogger(t) + + // Start watcher in goroutine with context + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go Watch(ctx, logger) + + // Give watcher time to initialize + time.Sleep(100 * time.Millisecond) + + // Modify config file to trigger hot reload + updatedContent := ` +server: + address: ":8080" + read_timeout: "15s" + write_timeout: "15s" + shutdown_timeout: "30s" + prefork: false + +redis: + address: "localhost" + port: 6379 + password: "" + db: 0 + pool_size: 20 + min_idle_conns: 10 + dial_timeout: "5s" + 
read_timeout: "3s" + write_timeout: "3s" + +logging: + level: "debug" + development: true + app_log: + filename: "logs/app.log" + max_size: 100 + max_backups: 30 + max_age: 30 + compress: true + access_log: + filename: "logs/access.log" + max_size: 500 + max_backups: 90 + max_age: 90 + compress: true + +middleware: + enable_auth: false + enable_rate_limiter: true + rate_limiter: + max: 200 + expiration: "2m" + storage: "redis" +` + if err := os.WriteFile(configFile, []byte(updatedContent), 0644); err != nil { + t.Fatalf("failed to update config file: %v", err) + } + + // Wait for watcher to detect and process changes (spec requires detection within 5 seconds) + // We use a more aggressive timeout for testing + time.Sleep(2 * time.Second) + + // Verify config was reloaded + reloadedCfg := Get() + if reloadedCfg.Logging.Level != "debug" { + t.Errorf("expected config hot reload, got logging.level %s instead of debug", reloadedCfg.Logging.Level) + } + if reloadedCfg.Server.Address != ":8080" { + t.Errorf("expected config hot reload, got server.address %s instead of :8080", reloadedCfg.Server.Address) + } + if reloadedCfg.Redis.PoolSize != 20 { + t.Errorf("expected config hot reload, got redis.pool_size %d instead of 20", reloadedCfg.Redis.PoolSize) + } + + // Cancel context to stop watcher + cancel() + + // Give watcher time to shut down gracefully + time.Sleep(100 * time.Millisecond) +} + +// TestWatch_InvalidConfigRejected tests that invalid config changes are rejected +func TestWatch_InvalidConfigRejected(t *testing.T) { + // Reset viper + viper.Reset() + + // Create temp config file + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.yaml") + + // Initial valid config + validContent := ` +server: + address: ":3000" + read_timeout: "10s" + write_timeout: "10s" + shutdown_timeout: "30s" + prefork: false + +redis: + address: "localhost" + port: 6379 + password: "" + db: 0 + pool_size: 10 + min_idle_conns: 5 + dial_timeout: "5s" + read_timeout: "3s" + 
write_timeout: "3s" + +logging: + level: "info" + development: false + app_log: + filename: "logs/app.log" + max_size: 100 + max_backups: 30 + max_age: 30 + compress: true + access_log: + filename: "logs/access.log" + max_size: 500 + max_backups: 90 + max_age: 90 + compress: true + +middleware: + enable_auth: true + enable_rate_limiter: false + rate_limiter: + max: 100 + expiration: "1m" + storage: "memory" +` + if err := os.WriteFile(configFile, []byte(validContent), 0644); err != nil { + t.Fatalf("failed to create config file: %v", err) + } + + // Set config path + os.Setenv(constants.EnvConfigPath, configFile) + defer os.Unsetenv(constants.EnvConfigPath) + + // Load initial config + cfg, err := Load() + if err != nil { + t.Fatalf("failed to load initial config: %v", err) + } + + initialLevel := cfg.Logging.Level + if initialLevel != "info" { + t.Fatalf("expected initial logging.level info, got %s", initialLevel) + } + + // Create logger for testing + logger := zaptest.NewLogger(t) + + // Start watcher + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go Watch(ctx, logger) + + // Give watcher time to initialize + time.Sleep(100 * time.Millisecond) + + // Write INVALID config (malformed YAML) + invalidContent := ` +server: + address: ":3000" + invalid yaml syntax here!!! 
+` + if err := os.WriteFile(configFile, []byte(invalidContent), 0644); err != nil { + t.Fatalf("failed to write invalid config: %v", err) + } + + // Wait for watcher to detect changes + time.Sleep(2 * time.Second) + + // Verify config was NOT changed (should keep previous valid config) + currentCfg := Get() + if currentCfg.Logging.Level != initialLevel { + t.Errorf("expected config to remain unchanged after invalid update, got logging.level %s instead of %s", currentCfg.Logging.Level, initialLevel) + } + + // Restore valid config + if err := os.WriteFile(configFile, []byte(validContent), 0644); err != nil { + t.Fatalf("failed to restore valid config: %v", err) + } + + time.Sleep(500 * time.Millisecond) + + // Now write config with validation error (timeout out of range) + invalidValidationContent := ` +server: + address: ":3000" + read_timeout: "1s" + write_timeout: "10s" + shutdown_timeout: "30s" + +redis: + address: "localhost" + port: 6379 + db: 0 + pool_size: 10 + min_idle_conns: 5 + +logging: + level: "debug" + app_log: + filename: "logs/app.log" + max_size: 100 + max_backups: 30 + max_age: 30 + compress: true + access_log: + filename: "logs/access.log" + max_size: 500 + max_backups: 90 + max_age: 90 + compress: true + +middleware: + enable_auth: true + rate_limiter: + max: 100 + expiration: "1m" + storage: "memory" +` + if err := os.WriteFile(configFile, []byte(invalidValidationContent), 0644); err != nil { + t.Fatalf("failed to write config with validation error: %v", err) + } + + // Wait for watcher to detect changes + time.Sleep(2 * time.Second) + + // Verify config was NOT changed (validation should have failed) + finalCfg := Get() + if finalCfg.Logging.Level != initialLevel { + t.Errorf("expected config to remain unchanged after validation error, got logging.level %s instead of %s", finalCfg.Logging.Level, initialLevel) + } + if finalCfg.Server.ReadTimeout == 1*time.Second { + t.Error("expected config to remain unchanged, but read_timeout was updated to 
invalid value") + } + + // Cancel context + cancel() + time.Sleep(100 * time.Millisecond) +} + +// TestWatch_ContextCancellation tests graceful shutdown on context cancellation +func TestWatch_ContextCancellation(t *testing.T) { + // Reset viper + viper.Reset() + + // Create temp config file + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.yaml") + + content := ` +server: + address: ":3000" + read_timeout: "10s" + write_timeout: "10s" + shutdown_timeout: "30s" + +redis: + address: "localhost" + port: 6379 + db: 0 + pool_size: 10 + min_idle_conns: 5 + +logging: + level: "info" + app_log: + filename: "logs/app.log" + max_size: 100 + max_backups: 30 + max_age: 30 + compress: true + access_log: + filename: "logs/access.log" + max_size: 500 + max_backups: 90 + max_age: 90 + compress: true + +middleware: + enable_auth: true + rate_limiter: + max: 100 + expiration: "1m" + storage: "memory" +` + if err := os.WriteFile(configFile, []byte(content), 0644); err != nil { + t.Fatalf("failed to create config file: %v", err) + } + + os.Setenv(constants.EnvConfigPath, configFile) + defer os.Unsetenv(constants.EnvConfigPath) + + // Load config + _, err := Load() + if err != nil { + t.Fatalf("failed to load config: %v", err) + } + + // Create logger + logger := zap.NewNop() // Use no-op logger for this test + + // Start watcher with context + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan bool) + go func() { + Watch(ctx, logger) + done <- true + }() + + // Give watcher time to start + time.Sleep(100 * time.Millisecond) + + // Cancel context (simulate graceful shutdown) + cancel() + + // Wait for watcher to stop (should happen quickly) + select { + case <-done: + // Watcher stopped successfully + case <-time.After(2 * time.Second): + t.Error("watcher did not stop within timeout after context cancellation") + } +} diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go new file mode 100644 index 0000000..7bd244c --- 
/dev/null +++ b/pkg/logger/logger_test.go @@ -0,0 +1,518 @@ +package logger + +import ( + "os" + "path/filepath" + "testing" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// TestInitLoggers 测试日志初始化(T026) +func TestInitLoggers(t *testing.T) { + // 创建临时目录用于日志文件 + tempDir := t.TempDir() + + tests := []struct { + name string + level string + development bool + appLogConfig LogRotationConfig + accessLogConfig LogRotationConfig + wantErr bool + validateFunc func(t *testing.T) + }{ + { + name: "production mode with info level", + level: "info", + development: false, + appLogConfig: LogRotationConfig{ + Filename: filepath.Join(tempDir, "app-prod.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: true, + }, + accessLogConfig: LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-prod.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: true, + }, + wantErr: false, + validateFunc: func(t *testing.T) { + if appLogger == nil { + t.Error("appLogger should not be nil") + } + if accessLogger == nil { + t.Error("accessLogger should not be nil") + } + // 写入一条日志以触发文件创建 + GetAppLogger().Info("test log creation") + Sync() + // 验证日志文件创建 + if _, err := os.Stat(filepath.Join(tempDir, "app-prod.log")); os.IsNotExist(err) { + t.Error("app log file should be created after writing") + } + }, + }, + { + name: "development mode with debug level", + level: "debug", + development: true, + appLogConfig: LogRotationConfig{ + Filename: filepath.Join(tempDir, "app-dev.log"), + MaxSize: 5, + MaxBackups: 2, + MaxAge: 3, + Compress: false, + }, + accessLogConfig: LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-dev.log"), + MaxSize: 5, + MaxBackups: 2, + MaxAge: 3, + Compress: false, + }, + wantErr: false, + validateFunc: func(t *testing.T) { + if appLogger == nil { + t.Error("appLogger should not be nil in dev mode") + } + if accessLogger == nil { + t.Error("accessLogger should not be nil in dev mode") + } + }, + }, + { + name: "warn level 
logging", + level: "warn", + development: false, + appLogConfig: LogRotationConfig{ + Filename: filepath.Join(tempDir, "app-warn.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: true, + }, + accessLogConfig: LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-warn.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: true, + }, + wantErr: false, + validateFunc: func(t *testing.T) { + if appLogger == nil { + t.Error("appLogger should not be nil") + } + }, + }, + { + name: "error level logging", + level: "error", + development: false, + appLogConfig: LogRotationConfig{ + Filename: filepath.Join(tempDir, "app-error.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: true, + }, + accessLogConfig: LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-error.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: true, + }, + wantErr: false, + validateFunc: func(t *testing.T) { + if appLogger == nil { + t.Error("appLogger should not be nil") + } + }, + }, + { + name: "invalid level defaults to info", + level: "invalid", + development: false, + appLogConfig: LogRotationConfig{ + Filename: filepath.Join(tempDir, "app-invalid.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: true, + }, + accessLogConfig: LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-invalid.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: true, + }, + wantErr: false, + validateFunc: func(t *testing.T) { + if appLogger == nil { + t.Error("appLogger should not be nil even with invalid level") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := InitLoggers(tt.level, tt.development, tt.appLogConfig, tt.accessLogConfig) + if (err != nil) != tt.wantErr { + t.Errorf("InitLoggers() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.validateFunc != nil { + tt.validateFunc(t) + } + }) + } +} + +// TestGetAppLogger 测试获取应用日志记录器(T026) +func 
TestGetAppLogger(t *testing.T) { + // 创建临时目录 + tempDir := t.TempDir() + + tests := []struct { + name string + setupFunc func() + wantNil bool + }{ + { + name: "after initialization", + setupFunc: func() { + InitLoggers("info", false, + LogRotationConfig{ + Filename: filepath.Join(tempDir, "app-get.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: true, + }, + LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-get.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: true, + }, + ) + }, + wantNil: false, + }, + { + name: "before initialization returns nop logger", + setupFunc: func() { + // 重置全局变量 + appLogger = nil + }, + wantNil: false, // GetAppLogger 应该返回 nop logger,不是 nil + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupFunc() + logger := GetAppLogger() + if logger == nil { + t.Error("GetAppLogger() should never return nil, should return nop logger instead") + } + }) + } +} + +// TestGetAccessLogger 测试获取访问日志记录器(T028) +func TestGetAccessLogger(t *testing.T) { + // 创建临时目录 + tempDir := t.TempDir() + + tests := []struct { + name string + setupFunc func() + wantNil bool + }{ + { + name: "after initialization", + setupFunc: func() { + InitLoggers("info", false, + LogRotationConfig{ + Filename: filepath.Join(tempDir, "app-access.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: true, + }, + LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-access.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: true, + }, + ) + }, + wantNil: false, + }, + { + name: "before initialization returns nop logger", + setupFunc: func() { + // 重置全局变量 + accessLogger = nil + }, + wantNil: false, // GetAccessLogger 应该返回 nop logger,不是 nil + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupFunc() + logger := GetAccessLogger() + if logger == nil { + t.Error("GetAccessLogger() should never return nil, should return nop logger instead") + } + }) + } +} 
+ +// TestSync 测试日志缓冲区刷新(T028) +func TestSync(t *testing.T) { + // 创建临时目录 + tempDir := t.TempDir() + + tests := []struct { + name string + setupFunc func() + wantErr bool + }{ + { + name: "sync after initialization", + setupFunc: func() { + InitLoggers("info", false, + LogRotationConfig{ + Filename: filepath.Join(tempDir, "app-sync.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: true, + }, + LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-sync.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: true, + }, + ) + }, + wantErr: false, + }, + { + name: "sync before initialization", + setupFunc: func() { + appLogger = nil + accessLogger = nil + }, + wantErr: false, // 应该优雅地处理 nil 情况 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupFunc() + err := Sync() + if (err != nil) != tt.wantErr { + t.Errorf("Sync() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// TestParseLevel 测试日志级别解析(T026) +func TestParseLevel(t *testing.T) { + tests := []struct { + name string + level string + want zapcore.Level + }{ + { + name: "debug level", + level: "debug", + want: zapcore.DebugLevel, + }, + { + name: "info level", + level: "info", + want: zapcore.InfoLevel, + }, + { + name: "warn level", + level: "warn", + want: zapcore.WarnLevel, + }, + { + name: "error level", + level: "error", + want: zapcore.ErrorLevel, + }, + { + name: "invalid level defaults to info", + level: "invalid", + want: zapcore.InfoLevel, + }, + { + name: "empty level defaults to info", + level: "", + want: zapcore.InfoLevel, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseLevel(tt.level) + if got != tt.want { + t.Errorf("parseLevel() = %v, want %v", got, tt.want) + } + }) + } +} + +// TestDualLoggerSystem 测试双日志系统(T028) +func TestDualLoggerSystem(t *testing.T) { + // 创建临时目录 + tempDir := t.TempDir() + + appLogFile := filepath.Join(tempDir, "app-dual.log") + accessLogFile := 
filepath.Join(tempDir, "access-dual.log") + + // 初始化双日志系统 + err := InitLoggers("info", false, + LogRotationConfig{ + Filename: appLogFile, + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, // 不压缩以便检查内容 + }, + LogRotationConfig{ + Filename: accessLogFile, + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + ) + if err != nil { + t.Fatalf("InitLoggers failed: %v", err) + } + + // 写入应用日志 + appLog := GetAppLogger() + appLog.Info("test app log message", + zap.String("module", "test"), + zap.Int("code", 200), + ) + + // 写入访问日志 + accessLog := GetAccessLogger() + accessLog.Info("test access log message", + zap.String("method", "GET"), + zap.String("path", "/api/test"), + zap.Int("status", 200), + zap.Duration("latency", 100), + ) + + // 刷新缓冲区 + if err := Sync(); err != nil { + t.Fatalf("Sync failed: %v", err) + } + + // 验证应用日志文件存在并有内容 + appLogContent, err := os.ReadFile(appLogFile) + if err != nil { + t.Fatalf("Failed to read app log file: %v", err) + } + if len(appLogContent) == 0 { + t.Error("App log file should not be empty") + } + + // 验证访问日志文件存在并有内容 + accessLogContent, err := os.ReadFile(accessLogFile) + if err != nil { + t.Fatalf("Failed to read access log file: %v", err) + } + if len(accessLogContent) == 0 { + t.Error("Access log file should not be empty") + } + + // 验证两个日志文件是独立的 + if string(appLogContent) == string(accessLogContent) { + t.Error("App log and access log should have different content") + } +} + +// TestLoggerReinitialization 测试日志重新初始化(T026) +func TestLoggerReinitialization(t *testing.T) { + // 创建临时目录 + tempDir := t.TempDir() + + // 第一次初始化 + err := InitLoggers("info", false, + LogRotationConfig{ + Filename: filepath.Join(tempDir, "app-reinit1.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: true, + }, + LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-reinit1.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: true, + }, + ) + if err != nil { + t.Fatalf("First InitLoggers failed: 
%v", err) + } + + firstAppLogger := GetAppLogger() + + // 第二次初始化(重新初始化) + err = InitLoggers("debug", true, + LogRotationConfig{ + Filename: filepath.Join(tempDir, "app-reinit2.log"), + MaxSize: 5, + MaxBackups: 2, + MaxAge: 3, + Compress: false, + }, + LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-reinit2.log"), + MaxSize: 5, + MaxBackups: 2, + MaxAge: 3, + Compress: false, + }, + ) + if err != nil { + t.Fatalf("Second InitLoggers failed: %v", err) + } + + secondAppLogger := GetAppLogger() + + // 验证重新初始化后日志记录器已更新 + if firstAppLogger == secondAppLogger { + t.Error("Logger should be replaced after reinitialization") + } +} diff --git a/pkg/logger/rotation_test.go b/pkg/logger/rotation_test.go new file mode 100644 index 0000000..230ae99 --- /dev/null +++ b/pkg/logger/rotation_test.go @@ -0,0 +1,388 @@ +package logger + +import ( + "os" + "path/filepath" + "strings" + "testing" + "time" + + "go.uber.org/zap" +) + +// TestLogRotation 测试日志轮转功能(T027) +func TestLogRotation(t *testing.T) { + // 创建临时目录 + tempDir := t.TempDir() + + appLogFile := filepath.Join(tempDir, "app-rotation.log") + + // 初始化日志系统,设置较小的 MaxSize 以便测试 + err := InitLoggers("info", false, + LogRotationConfig{ + Filename: appLogFile, + MaxSize: 1, // 1MB,写入足够数据后会触发轮转 + MaxBackups: 3, + MaxAge: 7, + Compress: false, // 不压缩以便检查 + }, + LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-rotation.log"), + MaxSize: 1, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + ) + if err != nil { + t.Fatalf("InitLoggers failed: %v", err) + } + + logger := GetAppLogger() + + // 写入大量日志数据以触发轮转(每条约100字节,写入15000条约1.5MB) + largeMessage := strings.Repeat("a", 100) + for i := 0; i < 15000; i++ { + logger.Info(largeMessage, + zap.Int("iteration", i), + zap.String("data", largeMessage), + ) + } + + // 刷新缓冲区 + if err := Sync(); err != nil { + t.Fatalf("Sync failed: %v", err) + } + + // 等待一小段时间确保文件写入完成 + time.Sleep(100 * time.Millisecond) + + // 验证主日志文件存在 + if _, err := os.Stat(appLogFile); 
os.IsNotExist(err) { + t.Error("Main log file should exist") + } + + // 检查是否有备份文件(轮转后的文件) + files, err := filepath.Glob(filepath.Join(tempDir, "app-rotation-*.log")) + if err != nil { + t.Fatalf("Failed to glob backup files: %v", err) + } + + // 由于写入了超过1MB的数据,应该触发至少一次轮转 + if len(files) == 0 { + // 可能系统写入速度或lumberjack行为导致未立即轮转,检查主文件大小 + info, err := os.Stat(appLogFile) + if err != nil { + t.Fatalf("Failed to stat main log file: %v", err) + } + if info.Size() == 0 { + t.Error("Log file should have content") + } + // 不强制要求必须轮转,因为取决于具体实现 + t.Logf("No rotation occurred, but main log file size: %d bytes", info.Size()) + } else { + t.Logf("Found %d rotated backup file(s)", len(files)) + } +} + +// TestMaxBackups 测试最大备份数限制(T027) +func TestMaxBackups(t *testing.T) { + // 创建临时目录 + tempDir := t.TempDir() + + appLogFile := filepath.Join(tempDir, "app-backups.log") + + // 初始化日志系统,设置 MaxBackups=2 + err := InitLoggers("info", false, + LogRotationConfig{ + Filename: appLogFile, + MaxSize: 1, // 1MB + MaxBackups: 2, // 最多保留2个备份 + MaxAge: 7, + Compress: false, + }, + LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-backups.log"), + MaxSize: 1, + MaxBackups: 2, + MaxAge: 7, + Compress: false, + }, + ) + if err != nil { + t.Fatalf("InitLoggers failed: %v", err) + } + + logger := GetAppLogger() + + // 写入足够的数据触发多次轮转(每次1.5MB,共4.5MB应该触发3次轮转) + largeMessage := strings.Repeat("b", 100) + for round := 0; round < 3; round++ { + for i := 0; i < 15000; i++ { + logger.Info(largeMessage, + zap.Int("round", round), + zap.Int("iteration", i), + ) + } + Sync() + time.Sleep(100 * time.Millisecond) + } + + // 等待轮转完成 + time.Sleep(200 * time.Millisecond) + + // 检查备份文件数量 + files, err := filepath.Glob(filepath.Join(tempDir, "app-backups-*.log")) + if err != nil { + t.Fatalf("Failed to glob backup files: %v", err) + } + + // 由于 MaxBackups=2,即使触发了多次轮转,也只应保留最多2个备份文件 + // (实际行为取决于 lumberjack 的实现细节,可能小于等于2) + if len(files) > 2 { + t.Errorf("Expected at most 2 backup files due to MaxBackups=2, got 
%d", len(files)) + } + t.Logf("Found %d backup file(s) with MaxBackups=2", len(files)) +} + +// TestCompressionConfig 测试压缩配置(T027) +func TestCompressionConfig(t *testing.T) { + // 创建临时目录 + tempDir := t.TempDir() + + tests := []struct { + name string + compress bool + }{ + { + name: "compression enabled", + compress: true, + }, + { + name: "compression disabled", + compress: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logFile := filepath.Join(tempDir, "app-"+tt.name+".log") + + err := InitLoggers("info", false, + LogRotationConfig{ + Filename: logFile, + MaxSize: 1, + MaxBackups: 3, + MaxAge: 7, + Compress: tt.compress, + }, + LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-"+tt.name+".log"), + MaxSize: 1, + MaxBackups: 3, + MaxAge: 7, + Compress: tt.compress, + }, + ) + if err != nil { + t.Fatalf("InitLoggers failed: %v", err) + } + + logger := GetAppLogger() + + // 写入一些日志 + for i := 0; i < 1000; i++ { + logger.Info("test compression", + zap.Int("id", i), + zap.String("data", strings.Repeat("c", 50)), + ) + } + + Sync() + time.Sleep(100 * time.Millisecond) + + // 验证日志文件存在 + if _, err := os.Stat(logFile); os.IsNotExist(err) { + t.Error("Log file should exist") + } + }) + } +} + +// TestMaxAge 测试日志文件保留时间(T027) +func TestMaxAge(t *testing.T) { + // 创建临时目录 + tempDir := t.TempDir() + + // 初始化日志系统,设置 MaxAge=1 天 + err := InitLoggers("info", false, + LogRotationConfig{ + Filename: filepath.Join(tempDir, "app-maxage.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 1, // 1天 + Compress: false, + }, + LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-maxage.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 1, + Compress: false, + }, + ) + if err != nil { + t.Fatalf("InitLoggers failed: %v", err) + } + + logger := GetAppLogger() + + // 写入日志 + logger.Info("test max age", zap.String("config", "maxage=1")) + Sync() + + // 验证配置已应用(无法在单元测试中验证实际的清理行为,因为需要等待1天) + // 这里只验证初始化没有错误 + if logger == nil { + t.Error("Logger 
should be initialized with MaxAge config") + } +} + +// TestNewLumberjackLogger 测试 Lumberjack logger 创建(T027) +func TestNewLumberjackLogger(t *testing.T) { + // 创建临时目录 + tempDir := t.TempDir() + + tests := []struct { + name string + config LogRotationConfig + }{ + { + name: "standard config", + config: LogRotationConfig{ + Filename: filepath.Join(tempDir, "test1.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: true, + }, + }, + { + name: "minimal config", + config: LogRotationConfig{ + Filename: filepath.Join(tempDir, "test2.log"), + MaxSize: 1, + MaxBackups: 1, + MaxAge: 1, + Compress: false, + }, + }, + { + name: "large config", + config: LogRotationConfig{ + Filename: filepath.Join(tempDir, "test3.log"), + MaxSize: 100, + MaxBackups: 10, + MaxAge: 30, + Compress: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := newLumberjackLogger(tt.config) + if logger == nil { + t.Error("newLumberjackLogger should not return nil") + } + + // 验证配置已正确设置 + if logger.Filename != tt.config.Filename { + t.Errorf("Filename = %v, want %v", logger.Filename, tt.config.Filename) + } + if logger.MaxSize != tt.config.MaxSize { + t.Errorf("MaxSize = %v, want %v", logger.MaxSize, tt.config.MaxSize) + } + if logger.MaxBackups != tt.config.MaxBackups { + t.Errorf("MaxBackups = %v, want %v", logger.MaxBackups, tt.config.MaxBackups) + } + if logger.MaxAge != tt.config.MaxAge { + t.Errorf("MaxAge = %v, want %v", logger.MaxAge, tt.config.MaxAge) + } + if logger.Compress != tt.config.Compress { + t.Errorf("Compress = %v, want %v", logger.Compress, tt.config.Compress) + } + if !logger.LocalTime { + t.Error("LocalTime should be true") + } + }) + } +} + +// TestConcurrentLogging 测试并发日志写入(T027) +func TestConcurrentLogging(t *testing.T) { + // 创建临时目录 + tempDir := t.TempDir() + + // 初始化日志系统 + err := InitLoggers("info", false, + LogRotationConfig{ + Filename: filepath.Join(tempDir, "app-concurrent.log"), + MaxSize: 10, + MaxBackups: 3, + 
MaxAge: 7, + Compress: false, + }, + LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-concurrent.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + ) + if err != nil { + t.Fatalf("InitLoggers failed: %v", err) + } + + logger := GetAppLogger() + + // 启动多个 goroutine 并发写入日志 + done := make(chan bool) + goroutines := 10 + messagesPerGoroutine := 100 + + for i := 0; i < goroutines; i++ { + go func(id int) { + for j := 0; j < messagesPerGoroutine; j++ { + logger.Info("concurrent log message", + zap.Int("goroutine", id), + zap.Int("message", j), + ) + } + done <- true + }(i) + } + + // 等待所有 goroutine 完成 + for i := 0; i < goroutines; i++ { + <-done + } + + // 刷新缓冲区 + if err := Sync(); err != nil { + t.Fatalf("Sync failed: %v", err) + } + + // 验证日志文件存在且有内容 + logFile := filepath.Join(tempDir, "app-concurrent.log") + info, err := os.Stat(logFile) + if err != nil { + t.Fatalf("Failed to stat log file: %v", err) + } + if info.Size() == 0 { + t.Error("Log file should have content after concurrent writes") + } + + t.Logf("Concurrent logging test completed, log file size: %d bytes", info.Size()) +} diff --git a/pkg/response/response_test.go b/pkg/response/response_test.go new file mode 100644 index 0000000..47900d3 --- /dev/null +++ b/pkg/response/response_test.go @@ -0,0 +1,477 @@ +package response + +import ( + "encoding/json" + "io" + "net/http/httptest" + "testing" + "time" + + "github.com/break/junhong_cmp_fiber/pkg/errors" + "github.com/gofiber/fiber/v2" +) + +// TestSuccess 测试成功响应(T034) +func TestSuccess(t *testing.T) { + tests := []struct { + name string + data any + }{ + { + name: "success with string data", + data: "test data", + }, + { + name: "success with map data", + data: map[string]any{ + "id": 123, + "name": "test", + }, + }, + { + name: "success with slice data", + data: []string{"item1", "item2", "item3"}, + }, + { + name: "success with struct data", + data: struct { + ID int `json:"id"` + Name string `json:"name"` + }{ + 
ID: 456, + Name: "test struct", + }, + }, + { + name: "success with nil data", + data: nil, + }, + { + name: "success with empty map", + data: map[string]any{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + app := fiber.New() + app.Get("/test", func(c *fiber.Ctx) error { + return Success(c, tt.data) + }) + + req := httptest.NewRequest("GET", "/test", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Failed to execute request: %v", err) + } + defer resp.Body.Close() + + // 验证 HTTP 状态码 + if resp.StatusCode != 200 { + t.Errorf("Expected status code 200, got %d", resp.StatusCode) + } + + // 验证响应头 + if resp.Header.Get("Content-Type") != "application/json" { + t.Errorf("Expected Content-Type application/json, got %s", resp.Header.Get("Content-Type")) + } + + // 解析响应体 + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + var response Response + if err := json.Unmarshal(body, &response); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // 验证响应结构 + if response.Code != errors.CodeSuccess { + t.Errorf("Expected code %d, got %d", errors.CodeSuccess, response.Code) + } + + if response.Message != "success" { + t.Errorf("Expected message 'success', got '%s'", response.Message) + } + + // 验证时间戳格式 RFC3339 + if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil { + t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp) + } + + // 验证数据字段(如果不是 nil) + if tt.data != nil { + if response.Data == nil { + t.Error("Expected data field to be non-nil") + } + } + }) + } +} + +// TestError 测试错误响应(T035) +func TestError(t *testing.T) { + tests := []struct { + name string + httpStatus int + code int + message string + }{ + { + name: "internal server error", + httpStatus: 500, + code: errors.CodeInternalError, + message: "Internal server error occurred", + }, + { + name: "missing token error", + httpStatus: 401, + code: 
errors.CodeMissingToken, + message: "Authentication token is missing", + }, + { + name: "invalid token error", + httpStatus: 401, + code: errors.CodeInvalidToken, + message: "Token is invalid or expired", + }, + { + name: "rate limit error", + httpStatus: 429, + code: errors.CodeTooManyRequests, + message: "Too many requests, please try again later", + }, + { + name: "service unavailable error", + httpStatus: 503, + code: errors.CodeAuthServiceUnavailable, + message: "Authentication service is currently unavailable", + }, + { + name: "bad request error", + httpStatus: 400, + code: 2000, + message: "Invalid request parameters", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + app := fiber.New() + app.Get("/test", func(c *fiber.Ctx) error { + return Error(c, tt.httpStatus, tt.code, tt.message) + }) + + req := httptest.NewRequest("GET", "/test", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Failed to execute request: %v", err) + } + defer resp.Body.Close() + + // 验证 HTTP 状态码 + if resp.StatusCode != tt.httpStatus { + t.Errorf("Expected status code %d, got %d", tt.httpStatus, resp.StatusCode) + } + + // 验证响应头 + if resp.Header.Get("Content-Type") != "application/json" { + t.Errorf("Expected Content-Type application/json, got %s", resp.Header.Get("Content-Type")) + } + + // 解析响应体 + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + var response Response + if err := json.Unmarshal(body, &response); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // 验证响应结构 + if response.Code != tt.code { + t.Errorf("Expected code %d, got %d", tt.code, response.Code) + } + + if response.Message != tt.message { + t.Errorf("Expected message '%s', got '%s'", tt.message, response.Message) + } + + if response.Data != nil { + t.Errorf("Expected data to be nil in error response, got %v", response.Data) + } + + // 验证时间戳格式 RFC3339 + if _, err := 
time.Parse(time.RFC3339, response.Timestamp); err != nil { + t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp) + } + }) + } +} + +// TestSuccessWithMessage 测试带自定义消息的成功响应(T034) +func TestSuccessWithMessage(t *testing.T) { + tests := []struct { + name string + data any + message string + }{ + { + name: "custom success message", + data: map[string]any{ + "user_id": 123, + }, + message: "User created successfully", + }, + { + name: "empty custom message", + data: "test data", + message: "", + }, + { + name: "chinese message", + data: map[string]string{ + "status": "ok", + }, + message: "操作成功", + }, + { + name: "long message", + data: nil, + message: "This is a very long success message that describes in detail what happened during the operation", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + app := fiber.New() + app.Get("/test", func(c *fiber.Ctx) error { + return SuccessWithMessage(c, tt.data, tt.message) + }) + + req := httptest.NewRequest("GET", "/test", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Failed to execute request: %v", err) + } + defer resp.Body.Close() + + // 验证 HTTP 状态码(默认 200) + if resp.StatusCode != 200 { + t.Errorf("Expected status code 200, got %d", resp.StatusCode) + } + + // 解析响应体 + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + var response Response + if err := json.Unmarshal(body, &response); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // 验证响应结构 + if response.Code != errors.CodeSuccess { + t.Errorf("Expected code %d, got %d", errors.CodeSuccess, response.Code) + } + + if response.Message != tt.message { + t.Errorf("Expected message '%s', got '%s'", tt.message, response.Message) + } + + // 验证时间戳格式 RFC3339 + if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil { + t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp) + } + }) + } +} + +// 
TestResponseSerialization 测试响应序列化(T036) +func TestResponseSerialization(t *testing.T) { + tests := []struct { + name string + response Response + }{ + { + name: "complete response", + response: Response{ + Code: 0, + Data: map[string]any{"key": "value"}, + Message: "success", + Timestamp: time.Now().Format(time.RFC3339), + }, + }, + { + name: "response with nil data", + response: Response{ + Code: 1000, + Data: nil, + Message: "error", + Timestamp: time.Now().Format(time.RFC3339), + }, + }, + { + name: "response with nested data", + response: Response{ + Code: 0, + Data: map[string]any{ + "user": map[string]any{ + "id": 123, + "name": "test", + "tags": []string{"tag1", "tag2"}, + }, + }, + Message: "success", + Timestamp: time.Now().Format(time.RFC3339), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 序列化 + data, err := json.Marshal(tt.response) + if err != nil { + t.Fatalf("Failed to marshal response: %v", err) + } + + // 反序列化 + var deserialized Response + if err := json.Unmarshal(data, &deserialized); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // 验证字段 + if deserialized.Code != tt.response.Code { + t.Errorf("Code mismatch: expected %d, got %d", tt.response.Code, deserialized.Code) + } + + if deserialized.Message != tt.response.Message { + t.Errorf("Message mismatch: expected '%s', got '%s'", tt.response.Message, deserialized.Message) + } + + if deserialized.Timestamp != tt.response.Timestamp { + t.Errorf("Timestamp mismatch: expected '%s', got '%s'", tt.response.Timestamp, deserialized.Timestamp) + } + }) + } +} + +// TestResponseStructFields 测试响应结构字段(T036) +func TestResponseStructFields(t *testing.T) { + response := Response{ + Code: 0, + Data: "test", + Message: "success", + Timestamp: time.Now().Format(time.RFC3339), + } + + data, err := json.Marshal(response) + if err != nil { + t.Fatalf("Failed to marshal response: %v", err) + } + + // 解析为 map 以检查 JSON 键 + var jsonMap map[string]any + if 
err := json.Unmarshal(data, &jsonMap); err != nil { + t.Fatalf("Failed to unmarshal to map: %v", err) + } + + // 验证所有必需字段都存在 + requiredFields := []string{"code", "data", "msg", "timestamp"} + for _, field := range requiredFields { + if _, exists := jsonMap[field]; !exists { + t.Errorf("Required field '%s' is missing in JSON response", field) + } + } + + // 验证字段类型 + if _, ok := jsonMap["code"].(float64); !ok { + t.Error("Field 'code' should be a number") + } + + if _, ok := jsonMap["msg"].(string); !ok { + t.Error("Field 'msg' should be a string") + } + + if _, ok := jsonMap["timestamp"].(string); !ok { + t.Error("Field 'timestamp' should be a string") + } +} + +// TestMultipleResponses 测试多个连续响应(T036) +func TestMultipleResponses(t *testing.T) { + app := fiber.New() + + callCount := 0 + app.Get("/test", func(c *fiber.Ctx) error { + callCount++ + if callCount%2 == 0 { + return Success(c, map[string]int{"count": callCount}) + } + return Error(c, 500, errors.CodeInternalError, "error occurred") + }) + + // 发送多个请求 + for i := 1; i <= 5; i++ { + req := httptest.NewRequest("GET", "/test", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request %d failed: %v", i, err) + } + + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + + var response Response + if err := json.Unmarshal(body, &response); err != nil { + t.Fatalf("Request %d: failed to unmarshal response: %v", i, err) + } + + // 验证每个响应都有时间戳 + if response.Timestamp == "" { + t.Errorf("Request %d: timestamp should not be empty", i) + } + } +} + +// TestTimestampFormat 测试时间戳格式(T036) +func TestTimestampFormat(t *testing.T) { + app := fiber.New() + app.Get("/test", func(c *fiber.Ctx) error { + return Success(c, nil) + }) + + req := httptest.NewRequest("GET", "/test", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Failed to execute request: %v", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var response Response + if err := json.Unmarshal(body, &response); err != 
nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // 验证是 RFC3339 格式 + parsedTime, err := time.Parse(time.RFC3339, response.Timestamp) + if err != nil { + t.Fatalf("Timestamp is not in RFC3339 format: %s, error: %v", response.Timestamp, err) + } + + // 验证时间戳是最近的(应该在最近 1 秒内) + now := time.Now() + diff := now.Sub(parsedTime) + if diff < 0 || diff > time.Second { + t.Errorf("Timestamp seems incorrect: %s (diff from now: %v)", response.Timestamp, diff) + } +} diff --git a/specs/001-fiber-middleware-integration/tasks.md b/specs/001-fiber-middleware-integration/tasks.md index 86e4cd1..140cc96 100644 --- a/specs/001-fiber-middleware-integration/tasks.md +++ b/specs/001-fiber-middleware-integration/tasks.md @@ -68,17 +68,17 @@ ### Unit Tests for User Story 1 -- [ ] T018 [P] [US1] Unit test for config loading and validation in pkg/config/loader_test.go -- [ ] T019 [P] [US1] Unit test for config hot reload mechanism in pkg/config/watcher_test.go -- [ ] T020 [P] [US1] Test invalid config handling (malformed YAML, validation errors) in pkg/config/config_test.go +- [X] T018 [P] [US1] Unit test for config loading and validation in pkg/config/loader_test.go +- [X] T019 [P] [US1] Unit test for config hot reload mechanism in pkg/config/watcher_test.go +- [X] T020 [P] [US1] Test invalid config handling (malformed YAML, validation errors) in pkg/config/config_test.go ### Implementation for User Story 1 -- [ ] T021 [US1] Implement atomic config pointer swap in pkg/config/config.go (sync/atomic usage) -- [ ] T022 [US1] Implement config change callback with validation in pkg/config/watcher.go -- [ ] T023 [US1] Add config reload logging with Zap in pkg/config/watcher.go -- [ ] T024 [US1] Integrate config watcher with context cancellation in cmd/api/main.go -- [ ] T025 [US1] Add graceful shutdown for config watcher in cmd/api/main.go +- [X] T021 [US1] Implement atomic config pointer swap in pkg/config/config.go (sync/atomic usage) +- [X] T022 [US1] Implement config change 
callback with validation in pkg/config/watcher.go +- [X] T023 [US1] Add config reload logging with Zap in pkg/config/watcher.go +- [X] T024 [US1] Integrate config watcher with context cancellation in cmd/api/main.go +- [X] T025 [US1] Add graceful shutdown for config watcher in cmd/api/main.go **Checkpoint**: Config hot reload should work independently - modify config and see changes applied @@ -92,17 +92,17 @@ ### Unit Tests for User Story 2 -- [ ] T026 [P] [US2] Unit test for logger initialization in pkg/logger/logger_test.go -- [ ] T027 [P] [US2] Unit test for log rotation configuration in pkg/logger/rotation_test.go -- [ ] T028 [P] [US2] Test structured logging with fields in pkg/logger/logger_test.go +- [X] T026 [P] [US2] Unit test for logger initialization in pkg/logger/logger_test.go +- [X] T027 [P] [US2] Unit test for log rotation configuration in pkg/logger/rotation_test.go +- [X] T028 [P] [US2] Test structured logging with fields in pkg/logger/logger_test.go ### Implementation for User Story 2 -- [ ] T029 [P] [US2] Create appLogger instance with Lumberjack writer in pkg/logger/logger.go -- [ ] T030 [P] [US2] Create accessLogger instance with separate Lumberjack writer in pkg/logger/logger.go -- [ ] T031 [US2] Configure JSON encoder with RFC3339 timestamps in pkg/logger/logger.go -- [ ] T032 [US2] Export GetAppLogger() and GetAccessLogger() functions in pkg/logger/logger.go -- [ ] T033 [US2] Add logger.Sync() call in graceful shutdown in cmd/api/main.go +- [X] T029 [P] [US2] Create appLogger instance with Lumberjack writer in pkg/logger/logger.go +- [X] T030 [P] [US2] Create accessLogger instance with separate Lumberjack writer in pkg/logger/logger.go +- [X] T031 [US2] Configure JSON encoder with RFC3339 timestamps in pkg/logger/logger.go +- [X] T032 [US2] Export GetAppLogger() and GetAccessLogger() functions in pkg/logger/logger.go +- [X] T033 [US2] Add logger.Sync() call in graceful shutdown in cmd/api/main.go **Checkpoint**: Both app.log and access.log 
should exist with JSON entries, rotate at configured size @@ -116,18 +116,18 @@ ### Unit Tests for User Story 3 -- [ ] T034 [P] [US3] Unit test for Success() response helper in pkg/response/response_test.go -- [ ] T035 [P] [US3] Unit test for Error() response helper in pkg/response/response_test.go -- [ ] T036 [P] [US3] Test response serialization with sonic JSON in pkg/response/response_test.go +- [X] T034 [P] [US3] Unit test for Success() response helper in pkg/response/response_test.go +- [X] T035 [P] [US3] Unit test for Error() response helper in pkg/response/response_test.go +- [X] T036 [P] [US3] Test response serialization with sonic JSON in pkg/response/response_test.go ### Implementation for User Story 3 -- [ ] T037 [P] [US3] Implement Success() helper function in pkg/response/response.go -- [ ] T038 [P] [US3] Implement Error() helper function in pkg/response/response.go -- [ ] T039 [P] [US3] Implement SuccessWithMessage() helper function in pkg/response/response.go -- [ ] T040 [US3] Configure Fiber to use sonic as JSON serializer in cmd/api/main.go -- [ ] T041 [US3] Create example health check endpoint using response helpers in internal/handler/health.go -- [ ] T042 [US3] Register health check route in cmd/api/main.go +- [X] T037 [P] [US3] Implement Success() helper function in pkg/response/response.go +- [X] T038 [P] [US3] Implement Error() helper function in pkg/response/response.go +- [X] T039 [P] [US3] Implement SuccessWithMessage() helper function in pkg/response/response.go +- [X] T040 [US3] Configure Fiber to use sonic as JSON serializer in cmd/api/main.go +- [X] T041 [US3] Create example health check endpoint using response helpers in internal/handler/health.go +- [X] T042 [US3] Register health check route in cmd/api/main.go **Checkpoint**: Health check endpoint returns unified response format with proper structure @@ -141,18 +141,18 @@ ### Integration Tests for User Story 4 -- [ ] T043 [P] [US4] Integration test for requestid middleware (UUID v4 
generation) in tests/integration/middleware_test.go -- [ ] T044 [P] [US4] Integration test for logger middleware (access log entries) in tests/integration/middleware_test.go -- [ ] T045 [P] [US4] Test request ID propagation through middleware chain in tests/integration/middleware_test.go +- [X] T043 [P] [US4] Integration test for requestid middleware (UUID v4 generation) in tests/integration/middleware_test.go +- [X] T044 [P] [US4] Integration test for logger middleware (access log entries) in tests/integration/middleware_test.go +- [X] T045 [P] [US4] Test request ID propagation through middleware chain in tests/integration/middleware_test.go ### Implementation for User Story 4 -- [ ] T046 [P] [US4] Configure Fiber requestid middleware with google/uuid in cmd/api/main.go -- [ ] T047 [US4] Implement custom logger middleware writing to accessLogger in internal/middleware/logger.go -- [ ] T048 [US4] Add request ID to Fiber Locals in logger middleware in internal/middleware/logger.go -- [ ] T049 [US4] Add X-Request-ID response header in logger middleware in internal/middleware/logger.go -- [ ] T050 [US4] Log request details (method, path, status, duration, IP, user_agent) to access.log in internal/middleware/logger.go -- [ ] T051 [US4] Register requestid and logger middleware in correct order in cmd/api/main.go +- [X] T046 [P] [US4] Configure Fiber requestid middleware with google/uuid in cmd/api/main.go +- [X] T047 [US4] Implement custom logger middleware writing to accessLogger in internal/middleware/logger.go +- [X] T048 [US4] Add request ID to Fiber Locals in logger middleware in internal/middleware/logger.go +- [X] T049 [US4] Add X-Request-ID response header in logger middleware in internal/middleware/logger.go +- [X] T050 [US4] Log request details (method, path, status, duration, IP, user_agent) to access.log in internal/middleware/logger.go +- [X] T051 [US4] Register requestid and logger middleware in correct order in cmd/api/main.go **Checkpoint**: Every 
request should have unique UUID v4 in header and access.log, with full request details @@ -166,17 +166,17 @@ ### Integration Tests for User Story 5 -- [ ] T052 [P] [US5] Integration test for panic recovery in tests/integration/middleware_test.go -- [ ] T053 [P] [US5] Test panic logging with stack trace in tests/integration/middleware_test.go -- [ ] T054 [P] [US5] Test subsequent requests after panic recovery in tests/integration/middleware_test.go +- [X] T052 [P] [US5] Integration test for panic recovery in tests/integration/recover_test.go +- [X] T053 [P] [US5] Test panic logging with stack trace in tests/integration/recover_test.go +- [X] T054 [P] [US5] Test subsequent requests after panic recovery in tests/integration/recover_test.go ### Implementation for User Story 5 -- [ ] T055 [US5] Implement custom recover middleware with Zap logging in internal/middleware/recover.go -- [ ] T056 [US5] Add stack trace capture to recover middleware in internal/middleware/recover.go -- [ ] T057 [US5] Add request ID to panic logs in internal/middleware/recover.go -- [ ] T058 [US5] Return unified error response (500, code 1000) on panic in internal/middleware/recover.go -- [ ] T059 [US5] Register recover middleware as FIRST middleware in cmd/api/main.go +- [X] T055 [US5] Implement custom recover middleware with Zap logging in internal/middleware/recover.go +- [X] T056 [US5] Add stack trace capture to recover middleware in internal/middleware/recover.go +- [X] T057 [US5] Add request ID to panic logs in internal/middleware/recover.go +- [X] T058 [US5] Return unified error response (500, code 1000) on panic in internal/middleware/recover.go +- [X] T059 [US5] Register recover middleware as FIRST middleware in cmd/api/main.go - [ ] T060 [US5] Create test panic endpoint for testing in internal/handler/test.go (optional, for quickstart validation) **Checkpoint**: Panic in handler should be caught, logged with stack trace, return 500, server continues running @@ -205,19 +205,19 @@ ### 
Implementation for User Story 6 -- [ ] T069 [US6] Create TokenValidator struct with Redis client in pkg/validator/token.go -- [ ] T070 [US6] Implement TokenValidator.Validate() with Redis GET operation in pkg/validator/token.go -- [ ] T071 [US6] Add context timeout (50ms) for Redis operations in pkg/validator/token.go -- [ ] T072 [US6] Implement Redis availability check (Ping) with fail-closed behavior in pkg/validator/token.go -- [ ] T073 [US6] Implement custom keyauth middleware wrapper in internal/middleware/auth.go -- [ ] T074 [US6] Configure keyauth with header lookup "token" in internal/middleware/auth.go -- [ ] T075 [US6] Add validator callback to keyauth config in internal/middleware/auth.go -- [ ] T076 [US6] Store user_id in Fiber Locals after successful validation in internal/middleware/auth.go -- [ ] T077 [US6] Implement custom ErrorHandler mapping errors to response codes in internal/middleware/auth.go -- [ ] T078 [US6] Add auth failure logging with request ID in internal/middleware/auth.go -- [ ] T079 [US6] Register keyauth middleware after logger in cmd/api/main.go -- [ ] T080 [US6] Create protected example endpoint (/api/v1/users) in internal/handler/user.go -- [ ] T081 [US6] Register protected routes with middleware in cmd/api/main.go +- [X] T069 [US6] Create TokenValidator struct with Redis client in pkg/validator/token.go +- [X] T070 [US6] Implement TokenValidator.Validate() with Redis GET operation in pkg/validator/token.go +- [X] T071 [US6] Add context timeout (50ms) for Redis operations in pkg/validator/token.go +- [X] T072 [US6] Implement Redis availability check (Ping) with fail-closed behavior in pkg/validator/token.go +- [X] T073 [US6] Implement custom keyauth middleware wrapper in internal/middleware/auth.go +- [X] T074 [US6] Configure keyauth with header lookup "token" in internal/middleware/auth.go +- [X] T075 [US6] Add validator callback to keyauth config in internal/middleware/auth.go +- [X] T076 [US6] Store user_id in Fiber Locals 
after successful validation in internal/middleware/auth.go +- [X] T077 [US6] Implement custom ErrorHandler mapping errors to response codes in internal/middleware/auth.go +- [X] T078 [US6] Add auth failure logging with request ID in internal/middleware/auth.go +- [X] T079 [US6] Register keyauth middleware after logger in cmd/api/main.go +- [X] T080 [US6] Create protected example endpoint (/api/v1/users) in internal/handler/user.go +- [X] T081 [US6] Register protected routes with middleware in cmd/api/main.go **Checkpoint**: Protected endpoints require valid token, reject invalid/missing tokens with correct error codes @@ -237,11 +237,11 @@ ### Implementation for User Story 7 -- [ ] T085 [US7] Implement rate limiter middleware wrapper (COMMENTED by default) in internal/middleware/ratelimit.go -- [ ] T086 [US7] Configure limiter with IP-based key generator (c.IP()) in internal/middleware/ratelimit.go -- [ ] T087 [US7] Configure limiter with config values (Max, Expiration) in internal/middleware/ratelimit.go -- [ ] T088 [US7] Add custom LimitReached handler returning unified error response in internal/middleware/ratelimit.go -- [ ] T089 [US7] Add commented middleware registration example in cmd/api/main.go +- [X] T085 [US7] Implement rate limiter middleware wrapper (COMMENTED by default) in internal/middleware/ratelimit.go +- [X] T086 [US7] Configure limiter with IP-based key generator (c.IP()) in internal/middleware/ratelimit.go +- [X] T087 [US7] Configure limiter with config values (Max, Expiration) in internal/middleware/ratelimit.go +- [X] T088 [US7] Add custom LimitReached handler returning unified error response in internal/middleware/ratelimit.go +- [X] T089 [US7] Add commented middleware registration example in cmd/api/main.go - [ ] T090 [US7] Document rate limiter usage in quickstart.md (how to enable, configure) - [ ] T091 [US7] Add rate limiter configuration examples to config files diff --git a/tests/integration/middleware_test.go 
b/tests/integration/middleware_test.go new file mode 100644 index 0000000..d213557 --- /dev/null +++ b/tests/integration/middleware_test.go @@ -0,0 +1,533 @@ +package integration + +import ( + "io" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/break/junhong_cmp_fiber/pkg/constants" + "github.com/break/junhong_cmp_fiber/pkg/logger" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/requestid" + "github.com/google/uuid" +) + +// TestRequestIDMiddleware 测试 RequestID 中间件生成 UUID v4(T043) +func TestRequestIDMiddleware(t *testing.T) { + app := fiber.New() + + // 配置 requestid 中间件使用 UUID v4 + app.Use(requestid.New(requestid.Config{ + Generator: func() string { + return uuid.NewString() + }, + })) + + app.Get("/test", func(c *fiber.Ctx) error { + requestID := c.Locals(constants.ContextKeyRequestID) + return c.JSON(fiber.Map{ + "request_id": requestID, + }) + }) + + tests := []struct { + name string + }{ + {name: "request 1"}, + {name: "request 2"}, + {name: "request 3"}, + } + + seenIDs := make(map[string]bool) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Failed to execute request: %v", err) + } + defer resp.Body.Close() + + // 验证响应头包含 X-Request-ID + requestID := resp.Header.Get("X-Request-ID") + if requestID == "" { + t.Error("X-Request-ID header should not be empty") + } + + // 验证是 UUID v4 格式 + if _, err := uuid.Parse(requestID); err != nil { + t.Errorf("X-Request-ID is not a valid UUID: %s, error: %v", requestID, err) + } + + // 验证 UUID 是唯一的 + if seenIDs[requestID] { + t.Errorf("Request ID %s is not unique", requestID) + } + seenIDs[requestID] = true + + t.Logf("Request ID: %s", requestID) + }) + } + + // 验证生成了多个不同的 ID + if len(seenIDs) != len(tests) { + t.Errorf("Expected %d unique request IDs, got %d", len(tests), len(seenIDs)) + } +} + +// TestLoggerMiddleware 测试 
Logger 中间件记录访问日志(T044) +func TestLoggerMiddleware(t *testing.T) { + // 创建临时目录用于日志 + tempDir := t.TempDir() + accessLogFile := filepath.Join(tempDir, "access.log") + + // 初始化日志系统 + err := logger.InitLoggers("info", false, + logger.LogRotationConfig{ + Filename: filepath.Join(tempDir, "app.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + logger.LogRotationConfig{ + Filename: accessLogFile, + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + ) + if err != nil { + t.Fatalf("Failed to initialize loggers: %v", err) + } + defer logger.Sync() + + // 创建应用 + app := fiber.New() + + // 注册中间件 + app.Use(requestid.New(requestid.Config{ + Generator: func() string { + return uuid.NewString() + }, + })) + app.Use(logger.Middleware()) + + app.Get("/test", func(c *fiber.Ctx) error { + return c.SendString("ok") + }) + + app.Post("/test", func(c *fiber.Ctx) error { + return c.SendStatus(201) + }) + + tests := []struct { + name string + method string + path string + expectedStatus int + }{ + { + name: "GET request", + method: "GET", + path: "/test", + expectedStatus: 200, + }, + { + name: "POST request", + method: "POST", + path: "/test", + expectedStatus: 201, + }, + { + name: "GET with query params", + method: "GET", + path: "/test?foo=bar", + expectedStatus: 200, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, tt.path, nil) + req.Header.Set("User-Agent", "test-agent/1.0") + + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Failed to execute request: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.StatusCode) + } + }) + } + + // 刷新日志缓冲区 + logger.Sync() + time.Sleep(100 * time.Millisecond) + + // 验证访问日志文件存在且有内容 + content, err := os.ReadFile(accessLogFile) + if err != nil { + t.Fatalf("Failed to read access log: %v", err) + } + + if len(content) == 0 { + 
t.Error("Access log should not be empty") + } + + logContent := string(content) + t.Logf("Access log content:\n%s", logContent) + + // 验证日志包含必要的字段 + requiredFields := []string{ + "method", + "path", + "status", + "duration_ms", + "request_id", + "ip", + "user_agent", + } + + for _, field := range requiredFields { + if !strings.Contains(logContent, field) { + t.Errorf("Access log should contain field '%s'", field) + } + } + + // 验证记录了所有请求 + lines := strings.Split(strings.TrimSpace(logContent), "\n") + if len(lines) < len(tests) { + t.Errorf("Expected at least %d log entries, got %d", len(tests), len(lines)) + } +} + +// TestRequestIDPropagation 测试 Request ID 在中间件链中传播(T045) +func TestRequestIDPropagation(t *testing.T) { + // 创建临时目录用于日志 + tempDir := t.TempDir() + + // 初始化日志系统 + err := logger.InitLoggers("info", false, + logger.LogRotationConfig{ + Filename: filepath.Join(tempDir, "app.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + logger.LogRotationConfig{ + Filename: filepath.Join(tempDir, "access.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + ) + if err != nil { + t.Fatalf("Failed to initialize loggers: %v", err) + } + defer logger.Sync() + + // 创建应用 + app := fiber.New() + + var capturedRequestID string + + // 1. RequestID 中间件(第一个) + app.Use(requestid.New(requestid.Config{ + Generator: func() string { + return uuid.NewString() + }, + })) + + // 2. Logger 中间件(第二个) + app.Use(logger.Middleware()) + + // 3. 
自定义中间件验证 request ID 是否可访问 + app.Use(func(c *fiber.Ctx) error { + requestID := c.Locals(constants.ContextKeyRequestID) + if requestID == nil { + t.Error("Request ID should be available in middleware chain") + } + if rid, ok := requestID.(string); ok { + capturedRequestID = rid + } + return c.Next() + }) + + app.Get("/test", func(c *fiber.Ctx) error { + // 在 handler 中也验证 request ID + requestID := c.Locals(constants.ContextKeyRequestID) + if requestID == nil { + return c.Status(500).SendString("Request ID not found in handler") + } + + return c.JSON(fiber.Map{ + "request_id": requestID, + "message": "Request ID propagated successfully", + }) + }) + + // 执行请求 + req := httptest.NewRequest("GET", "/test", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Failed to execute request: %v", err) + } + defer resp.Body.Close() + + // 验证响应 + if resp.StatusCode != 200 { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + // 验证响应头中的 Request ID + headerRequestID := resp.Header.Get("X-Request-ID") + if headerRequestID == "" { + t.Error("X-Request-ID header should be set") + } + + // 验证中间件捕获的 Request ID 与响应头一致 + if capturedRequestID != headerRequestID { + t.Errorf("Request ID mismatch: middleware=%s, header=%s", capturedRequestID, headerRequestID) + } + + // 验证响应体 + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + if !strings.Contains(string(body), headerRequestID) { + t.Errorf("Response body should contain request ID %s", headerRequestID) + } + + t.Logf("Request ID successfully propagated: %s", headerRequestID) +} + +// TestMiddlewareOrder 测试中间件执行顺序(T045) +func TestMiddlewareOrder(t *testing.T) { + app := fiber.New() + + executionOrder := []string{} + + // 中间件 1: RequestID + app.Use(func(c *fiber.Ctx) error { + executionOrder = append(executionOrder, "requestid-start") + c.Locals(constants.ContextKeyRequestID, uuid.NewString()) + err := c.Next() + executionOrder = append(executionOrder, 
"requestid-end") + return err + }) + + // 中间件 2: Logger + app.Use(func(c *fiber.Ctx) error { + executionOrder = append(executionOrder, "logger-start") + // 验证 Request ID 已经设置 + if c.Locals(constants.ContextKeyRequestID) == nil { + t.Error("Request ID should be set before logger middleware") + } + err := c.Next() + executionOrder = append(executionOrder, "logger-end") + return err + }) + + // 中间件 3: Custom + app.Use(func(c *fiber.Ctx) error { + executionOrder = append(executionOrder, "custom-start") + err := c.Next() + executionOrder = append(executionOrder, "custom-end") + return err + }) + + app.Get("/test", func(c *fiber.Ctx) error { + executionOrder = append(executionOrder, "handler") + return c.SendString("ok") + }) + + req := httptest.NewRequest("GET", "/test", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Failed to execute request: %v", err) + } + resp.Body.Close() + + // 验证执行顺序 + expectedOrder := []string{ + "requestid-start", + "logger-start", + "custom-start", + "handler", + "custom-end", + "logger-end", + "requestid-end", + } + + if len(executionOrder) != len(expectedOrder) { + t.Errorf("Expected %d execution steps, got %d", len(expectedOrder), len(executionOrder)) + } + + for i, expected := range expectedOrder { + if i >= len(executionOrder) { + t.Errorf("Missing execution step at index %d: expected '%s'", i, expected) + continue + } + if executionOrder[i] != expected { + t.Errorf("Execution order mismatch at index %d: expected '%s', got '%s'", i, expected, executionOrder[i]) + } + } + + t.Logf("Middleware execution order: %v", executionOrder) +} + +// TestLoggerMiddlewareWithUserID 测试 Logger 中间件记录用户 ID(T044) +func TestLoggerMiddlewareWithUserID(t *testing.T) { + // 创建临时目录用于日志 + tempDir := t.TempDir() + accessLogFile := filepath.Join(tempDir, "access-userid.log") + + // 初始化日志系统 + err := logger.InitLoggers("info", false, + logger.LogRotationConfig{ + Filename: filepath.Join(tempDir, "app.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 
7, + Compress: false, + }, + logger.LogRotationConfig{ + Filename: accessLogFile, + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + ) + if err != nil { + t.Fatalf("Failed to initialize loggers: %v", err) + } + defer logger.Sync() + + // 创建应用 + app := fiber.New() + + // 注册中间件 + app.Use(requestid.New(requestid.Config{ + Generator: func() string { + return uuid.NewString() + }, + })) + + // 模拟 auth 中间件设置 user_id + app.Use(func(c *fiber.Ctx) error { + c.Locals(constants.ContextKeyUserID, "user_12345") + return c.Next() + }) + + app.Use(logger.Middleware()) + + app.Get("/test", func(c *fiber.Ctx) error { + return c.SendString("ok") + }) + + // 执行请求 + req := httptest.NewRequest("GET", "/test", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Failed to execute request: %v", err) + } + resp.Body.Close() + + // 刷新日志缓冲区 + logger.Sync() + time.Sleep(100 * time.Millisecond) + + // 验证访问日志包含 user_id + content, err := os.ReadFile(accessLogFile) + if err != nil { + t.Fatalf("Failed to read access log: %v", err) + } + + logContent := string(content) + if !strings.Contains(logContent, "user_12345") { + t.Error("Access log should contain user_id 'user_12345'") + } + + t.Logf("Access log with user_id:\n%s", logContent) +} + +// TestConcurrentRequests 测试并发请求的 Request ID 唯一性(T043) +func TestConcurrentRequests(t *testing.T) { + app := fiber.New() + + app.Use(requestid.New(requestid.Config{ + Generator: func() string { + return uuid.NewString() + }, + })) + + app.Get("/test", func(c *fiber.Ctx) error { + // 模拟一些处理时间 + time.Sleep(10 * time.Millisecond) + requestID := c.Locals(constants.ContextKeyRequestID) + return c.JSON(fiber.Map{ + "request_id": requestID, + }) + }) + + // 并发发送多个请求 + const numRequests = 50 + requestIDs := make(chan string, numRequests) + errors := make(chan error, numRequests) + + for i := 0; i < numRequests; i++ { + go func() { + req := httptest.NewRequest("GET", "/test", nil) + resp, err := app.Test(req) + if err != nil { + errors 
<- err + return + } + defer resp.Body.Close() + + requestID := resp.Header.Get("X-Request-ID") + requestIDs <- requestID + errors <- nil + }() + } + + // 收集所有结果 + seenIDs := make(map[string]bool) + for i := 0; i < numRequests; i++ { + if err := <-errors; err != nil { + t.Fatalf("Request failed: %v", err) + } + requestID := <-requestIDs + + if requestID == "" { + t.Error("Request ID should not be empty") + } + + if seenIDs[requestID] { + t.Errorf("Duplicate request ID found: %s", requestID) + } + seenIDs[requestID] = true + } + + // 验证所有 ID 都是唯一的 + if len(seenIDs) != numRequests { + t.Errorf("Expected %d unique request IDs, got %d", numRequests, len(seenIDs)) + } + + t.Logf("Successfully generated %d unique request IDs concurrently", len(seenIDs)) +} diff --git a/tests/integration/recover_test.go b/tests/integration/recover_test.go new file mode 100644 index 0000000..e64984f --- /dev/null +++ b/tests/integration/recover_test.go @@ -0,0 +1,618 @@ +package integration + +import ( + "encoding/json" + "io" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/break/junhong_cmp_fiber/internal/middleware" + "github.com/break/junhong_cmp_fiber/pkg/errors" + "github.com/break/junhong_cmp_fiber/pkg/logger" + "github.com/break/junhong_cmp_fiber/pkg/response" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/requestid" + "github.com/google/uuid" +) + +// TestPanicRecovery 测试 panic 恢复功能(T052) +func TestPanicRecovery(t *testing.T) { + // 创建临时目录用于日志 + tempDir := t.TempDir() + + // 初始化日志系统 + err := logger.InitLoggers("info", false, + logger.LogRotationConfig{ + Filename: filepath.Join(tempDir, "app-panic.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + logger.LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-panic.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + ) + if err != nil { + t.Fatalf("Failed to initialize loggers: %v", err) + } + defer 
logger.Sync() + + appLogger := logger.GetAppLogger() + + // 创建应用 + app := fiber.New() + + // 注册中间件(recover 必须第一个) + app.Use(middleware.Recover(appLogger)) + app.Use(requestid.New(requestid.Config{ + Generator: func() string { + return uuid.NewString() + }, + })) + + // 创建会 panic 的 handler + app.Get("/panic", func(c *fiber.Ctx) error { + panic("intentional panic for testing") + }) + + // 创建正常的 handler + app.Get("/ok", func(c *fiber.Ctx) error { + return c.SendString("ok") + }) + + tests := []struct { + name string + path string + shouldPanic bool + expectedStatus int + expectedCode int + }{ + { + name: "panic endpoint returns 500", + path: "/panic", + shouldPanic: true, + expectedStatus: 500, + expectedCode: errors.CodeInternalError, + }, + { + name: "normal endpoint works after panic", + path: "/ok", + shouldPanic: false, + expectedStatus: 200, + expectedCode: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", tt.path, nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Failed to execute request: %v", err) + } + defer resp.Body.Close() + + // 验证 HTTP 状态码 + if resp.StatusCode != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.StatusCode) + } + + // 解析响应 + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + if tt.shouldPanic { + // panic 应该返回统一错误响应 + var response response.Response + if err := json.Unmarshal(body, &response); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + if response.Code != tt.expectedCode { + t.Errorf("Expected code %d, got %d", tt.expectedCode, response.Code) + } + + if response.Data != nil { + t.Error("Error response data should be nil") + } + } + }) + } +} + +// TestPanicLogging 测试 panic 日志记录和堆栈跟踪(T053) +func TestPanicLogging(t *testing.T) { + // 创建临时目录用于日志 + tempDir := t.TempDir() + appLogFile := filepath.Join(tempDir, "app-panic-log.log") + + // 
初始化日志系统 + err := logger.InitLoggers("info", false, + logger.LogRotationConfig{ + Filename: appLogFile, + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + logger.LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-panic-log.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + ) + if err != nil { + t.Fatalf("Failed to initialize loggers: %v", err) + } + defer logger.Sync() + + appLogger := logger.GetAppLogger() + + // 创建应用 + app := fiber.New() + + // 注册中间件 + app.Use(middleware.Recover(appLogger)) + app.Use(requestid.New(requestid.Config{ + Generator: func() string { + return uuid.NewString() + }, + })) + + // 创建不同类型的 panic + app.Get("/panic-string", func(c *fiber.Ctx) error { + panic("string panic message") + }) + + app.Get("/panic-error", func(c *fiber.Ctx) error { + panic(fiber.NewError(500, "error panic message")) + }) + + app.Get("/panic-struct", func(c *fiber.Ctx) error { + panic(struct{ Message string }{"struct panic message"}) + }) + + tests := []struct { + name string + path string + expectedInLog []string + unexpectedInLog []string + }{ + { + name: "string panic logs correctly", + path: "/panic-string", + expectedInLog: []string{ + "Panic 已恢复", + "string panic message", + "stack", + "request_id", + "method", + "path", + }, + }, + { + name: "error panic logs correctly", + path: "/panic-error", + expectedInLog: []string{ + "Panic 已恢复", + "error panic message", + "stack", + }, + }, + { + name: "struct panic logs correctly", + path: "/panic-struct", + expectedInLog: []string{ + "Panic 已恢复", + "stack", + "Message", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 执行会 panic 的请求 + req := httptest.NewRequest("GET", tt.path, nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Failed to execute request: %v", err) + } + resp.Body.Close() + + // 刷新日志缓冲区 + logger.Sync() + time.Sleep(100 * time.Millisecond) + + // 读取日志内容 + logContent, err := os.ReadFile(appLogFile) + 
if err != nil { + t.Fatalf("Failed to read app log: %v", err) + } + + content := string(logContent) + + // 验证日志包含预期内容 + for _, expected := range tt.expectedInLog { + if !strings.Contains(content, expected) { + t.Errorf("Log should contain '%s'", expected) + } + } + + // 验证日志不包含意外内容 + for _, unexpected := range tt.unexpectedInLog { + if strings.Contains(content, unexpected) { + t.Errorf("Log should NOT contain '%s'", unexpected) + } + } + + // 验证堆栈跟踪包含文件和行号 + if !strings.Contains(content, "recover_test.go") { + t.Error("Stack trace should contain source file name") + } + + t.Logf("Panic log contains stack trace: %v", strings.Contains(content, "stack")) + }) + } +} + +// TestSubsequentRequestsAfterPanic 测试 panic 后后续请求正常处理(T054) +func TestSubsequentRequestsAfterPanic(t *testing.T) { + // 创建临时目录用于日志 + tempDir := t.TempDir() + + // 初始化日志系统 + err := logger.InitLoggers("info", false, + logger.LogRotationConfig{ + Filename: filepath.Join(tempDir, "app.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + logger.LogRotationConfig{ + Filename: filepath.Join(tempDir, "access.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + ) + if err != nil { + t.Fatalf("Failed to initialize loggers: %v", err) + } + defer logger.Sync() + + appLogger := logger.GetAppLogger() + + // 创建应用 + app := fiber.New() + + // 注册中间件 + app.Use(middleware.Recover(appLogger)) + app.Use(requestid.New(requestid.Config{ + Generator: func() string { + return uuid.NewString() + }, + })) + + callCount := 0 + + app.Get("/test", func(c *fiber.Ctx) error { + callCount++ + // 第 1、3、5 次调用会 panic + if callCount%2 == 1 { + panic("test panic") + } + // 第 2、4、6 次调用正常返回 + return c.JSON(fiber.Map{ + "call_count": callCount, + "status": "ok", + }) + }) + + // 执行多次请求,验证 panic 不影响后续请求 + for i := 1; i <= 6; i++ { + req := httptest.NewRequest("GET", "/test", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Request %d failed: %v", i, err) + } + + body, _ := 
io.ReadAll(resp.Body) + resp.Body.Close() + + if i%2 == 1 { + // 奇数次应该返回 500 + if resp.StatusCode != 500 { + t.Errorf("Request %d: expected status 500, got %d", i, resp.StatusCode) + } + } else { + // 偶数次应该返回 200 + if resp.StatusCode != 200 { + t.Errorf("Request %d: expected status 200, got %d", i, resp.StatusCode) + } + + // 验证响应内容 + var response map[string]any + if err := json.Unmarshal(body, &response); err != nil { + t.Fatalf("Request %d: failed to unmarshal response: %v", i, err) + } + + if status, ok := response["status"].(string); !ok || status != "ok" { + t.Errorf("Request %d: expected status 'ok', got %v", i, response["status"]) + } + } + + t.Logf("Request %d completed: status=%d", i, resp.StatusCode) + } + + // 验证所有 6 次调用都执行了 + if callCount != 6 { + t.Errorf("Expected 6 calls, got %d", callCount) + } +} + +// TestPanicWithRequestID 测试 panic 日志包含 Request ID(T053) +func TestPanicWithRequestID(t *testing.T) { + // 创建临时目录用于日志 + tempDir := t.TempDir() + appLogFile := filepath.Join(tempDir, "app-panic-reqid.log") + + // 初始化日志系统 + err := logger.InitLoggers("info", false, + logger.LogRotationConfig{ + Filename: appLogFile, + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + logger.LogRotationConfig{ + Filename: filepath.Join(tempDir, "access-panic-reqid.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + ) + if err != nil { + t.Fatalf("Failed to initialize loggers: %v", err) + } + defer logger.Sync() + + appLogger := logger.GetAppLogger() + + // 创建应用 + app := fiber.New() + + // 注册中间件(顺序重要) + app.Use(middleware.Recover(appLogger)) + app.Use(requestid.New(requestid.Config{ + Generator: func() string { + return uuid.NewString() + }, + })) + + app.Get("/panic", func(c *fiber.Ctx) error { + panic("test panic with request id") + }) + + // 执行请求 + req := httptest.NewRequest("GET", "/panic", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Failed to execute request: %v", err) + } + resp.Body.Close() + + // 获取 
Request ID + requestID := resp.Header.Get("X-Request-ID") + if requestID == "" { + t.Error("X-Request-ID header should be set even after panic") + } + + // 刷新日志缓冲区 + logger.Sync() + time.Sleep(100 * time.Millisecond) + + // 读取日志内容 + logContent, err := os.ReadFile(appLogFile) + if err != nil { + t.Fatalf("Failed to read app log: %v", err) + } + + content := string(logContent) + + // 验证日志包含 Request ID + if !strings.Contains(content, requestID) { + t.Errorf("Panic log should contain request ID '%s'", requestID) + } + + // 验证日志包含关键字段 + requiredFields := []string{ + "request_id", + "method", + "path", + "panic", + "stack", + } + + for _, field := range requiredFields { + if !strings.Contains(content, field) { + t.Errorf("Panic log should contain field '%s'", field) + } + } + + t.Logf("Panic log successfully includes Request ID: %s", requestID) +} + +// TestConcurrentPanics 测试并发 panic 处理(T054) +func TestConcurrentPanics(t *testing.T) { + // 创建临时目录用于日志 + tempDir := t.TempDir() + + // 初始化日志系统 + err := logger.InitLoggers("info", false, + logger.LogRotationConfig{ + Filename: filepath.Join(tempDir, "app.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + logger.LogRotationConfig{ + Filename: filepath.Join(tempDir, "access.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + ) + if err != nil { + t.Fatalf("Failed to initialize loggers: %v", err) + } + defer logger.Sync() + + appLogger := logger.GetAppLogger() + + // 创建应用 + app := fiber.New() + + // 注册中间件 + app.Use(middleware.Recover(appLogger)) + app.Use(requestid.New(requestid.Config{ + Generator: func() string { + return uuid.NewString() + }, + })) + + app.Get("/panic", func(c *fiber.Ctx) error { + panic("concurrent panic test") + }) + + // 并发发送多个会 panic 的请求 + const numRequests = 20 + errors := make(chan error, numRequests) + statuses := make(chan int, numRequests) + + for i := 0; i < numRequests; i++ { + go func() { + req := httptest.NewRequest("GET", "/panic", nil) + 
resp, err := app.Test(req) + if err != nil { + errors <- err + statuses <- 0 + return + } + defer resp.Body.Close() + + statuses <- resp.StatusCode + errors <- nil + }() + } + + // 收集所有结果 + for i := 0; i < numRequests; i++ { + if err := <-errors; err != nil { + t.Fatalf("Request failed: %v", err) + } + status := <-statuses + if status != 500 { + t.Errorf("Expected status 500, got %d", status) + } + } + + t.Logf("Successfully handled %d concurrent panics", numRequests) +} + +// TestRecoverMiddlewareOrder 测试 Recover 中间件必须在第一个(T052) +func TestRecoverMiddlewareOrder(t *testing.T) { + // 创建临时目录用于日志 + tempDir := t.TempDir() + + // 初始化日志系统 + err := logger.InitLoggers("info", false, + logger.LogRotationConfig{ + Filename: filepath.Join(tempDir, "app.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + logger.LogRotationConfig{ + Filename: filepath.Join(tempDir, "access.log"), + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }, + ) + if err != nil { + t.Fatalf("Failed to initialize loggers: %v", err) + } + defer logger.Sync() + + appLogger := logger.GetAppLogger() + + // 创建应用 + app := fiber.New() + + // 正确的顺序:Recover → RequestID → Logger + app.Use(middleware.Recover(appLogger)) + app.Use(requestid.New(requestid.Config{ + Generator: func() string { + return uuid.NewString() + }, + })) + app.Use(logger.Middleware()) + + app.Get("/panic", func(c *fiber.Ctx) error { + panic("test panic") + }) + + // 执行请求 + req := httptest.NewRequest("GET", "/panic", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Failed to execute request: %v", err) + } + defer resp.Body.Close() + + // 验证请求被正确处理(返回 500 而不是崩溃) + if resp.StatusCode != 500 { + t.Errorf("Expected status 500, got %d", resp.StatusCode) + } + + // 验证仍然有 Request ID(说明 RequestID 中间件在 Recover 之后执行) + requestID := resp.Header.Get("X-Request-ID") + if requestID == "" { + t.Error("X-Request-ID should be set even after panic") + } + + // 解析响应,验证返回了统一错误格式 + body, _ := 
io.ReadAll(resp.Body) + var response response.Response + if err := json.Unmarshal(body, &response); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + if response.Code != errors.CodeInternalError { + t.Errorf("Expected code %d, got %d", errors.CodeInternalError, response.Code) + } + + t.Logf("Recover middleware correctly placed first, handled panic gracefully") +}