feat: 添加环境变量管理工具和部署配置改版
All checks were successful
构建并部署到测试环境(无 SSH) / build-and-deploy (push) Successful in 5m33s

主要改动:
- 新增交互式环境配置脚本 (scripts/setup-env.sh)
- 新增本地启动快捷脚本 (scripts/run-local.sh)
- 新增环境变量模板文件 (.env.example)
- 部署模式改版:使用嵌入式配置 + 环境变量覆盖
- 添加对象存储功能支持
- 改进 IoT 卡片导入任务
- 优化 OpenAPI 文档生成
- 删除旧的配置文件,改用嵌入式默认配置
This commit is contained in:
2026-01-26 10:28:29 +08:00
parent 194078674a
commit 45aa7deb87
94 changed files with 6532 additions and 1967 deletions

View File

@@ -3,6 +3,7 @@ package config
import (
"errors"
"fmt"
"strings"
"sync/atomic"
"time"
)
@@ -21,6 +22,7 @@ type Config struct {
SMS SMSConfig `mapstructure:"sms"`
JWT JWTConfig `mapstructure:"jwt"`
DefaultAdmin DefaultAdminConfig `mapstructure:"default_admin"`
Storage StorageConfig `mapstructure:"storage"`
}
// ServerConfig HTTP 服务器配置
@@ -120,7 +122,61 @@ type DefaultAdminConfig struct {
Phone string `mapstructure:"phone"`
}
// Validate 验证配置
// StorageConfig 对象存储配置
type StorageConfig struct {
Provider string `mapstructure:"provider"` // 存储提供商s3
S3 S3Config `mapstructure:"s3"` // S3 兼容存储配置
Presign PresignConfig `mapstructure:"presign"` // 预签名 URL 配置
TempDir string `mapstructure:"temp_dir"` // 临时文件目录
}
// S3Config S3 兼容存储配置
type S3Config struct {
Endpoint string `mapstructure:"endpoint"` // 服务端点http://obs-helf.cucloud.cn
Region string `mapstructure:"region"` // 区域cn-langfang-2
Bucket string `mapstructure:"bucket"` // 存储桶名称
AccessKeyID string `mapstructure:"access_key_id"` // 访问密钥 ID
SecretAccessKey string `mapstructure:"secret_access_key"` // 访问密钥
UseSSL bool `mapstructure:"use_ssl"` // 是否使用 SSL
PathStyle bool `mapstructure:"path_style"` // 是否使用路径风格(兼容性)
}
// PresignConfig 预签名 URL 配置
type PresignConfig struct {
UploadExpires time.Duration `mapstructure:"upload_expires"` // 上传 URL 有效期默认15m
DownloadExpires time.Duration `mapstructure:"download_expires"` // 下载 URL 有效期默认24h
}
type requiredField struct {
value string
name string
envName string
}
func (c *Config) ValidateRequired() error {
fields := []requiredField{
{c.Database.Host, "database.host", "JUNHONG_DATABASE_HOST"},
{c.Database.User, "database.user", "JUNHONG_DATABASE_USER"},
{c.Database.Password, "database.password", "JUNHONG_DATABASE_PASSWORD"},
{c.Database.DBName, "database.dbname", "JUNHONG_DATABASE_DBNAME"},
{c.Redis.Address, "redis.address", "JUNHONG_REDIS_ADDRESS"},
{c.JWT.SecretKey, "jwt.secret_key", "JUNHONG_JWT_SECRET_KEY"},
}
var missing []string
for _, f := range fields {
if f.value == "" {
missing = append(missing, fmt.Sprintf(" - %s (环境变量: %s)", f.name, f.envName))
}
}
if len(missing) > 0 {
return fmt.Errorf("缺少必填配置项:\n%s", strings.Join(missing, "\n"))
}
return nil
}
func (c *Config) Validate() error {
// 服务器验证
if c.Server.Address == "" {
@@ -184,28 +240,24 @@ func (c *Config) Validate() error {
return fmt.Errorf("invalid configuration: middleware.rate_limiter.storage: invalid storage type (current value: %s, expected: memory or redis)", c.Middleware.RateLimiter.Storage)
}
// 短信服务验证
if c.SMS.GatewayURL == "" {
return fmt.Errorf("invalid configuration: sms.gateway_url: must be non-empty (current value: empty)")
}
if c.SMS.Username == "" {
return fmt.Errorf("invalid configuration: sms.username: must be non-empty (current value: empty)")
}
if c.SMS.Password == "" {
return fmt.Errorf("invalid configuration: sms.password: must be non-empty (current value: empty)")
}
if c.SMS.Signature == "" {
return fmt.Errorf("invalid configuration: sms.signature: must be non-empty (current value: empty)")
}
if c.SMS.Timeout < 5*time.Second || c.SMS.Timeout > 60*time.Second {
return fmt.Errorf("invalid configuration: sms.timeout: duration out of range (current value: %s, expected: 5s-60s)", c.SMS.Timeout)
// 短信服务验证(可选,配置 GatewayURL 时才验证其他字段)
if c.SMS.GatewayURL != "" {
if c.SMS.Username == "" {
return fmt.Errorf("invalid configuration: sms.username: must be non-empty when gateway_url is configured")
}
if c.SMS.Password == "" {
return fmt.Errorf("invalid configuration: sms.password: must be non-empty when gateway_url is configured")
}
if c.SMS.Signature == "" {
return fmt.Errorf("invalid configuration: sms.signature: must be non-empty when gateway_url is configured")
}
if c.SMS.Timeout > 0 && (c.SMS.Timeout < 5*time.Second || c.SMS.Timeout > 60*time.Second) {
return fmt.Errorf("invalid configuration: sms.timeout: duration out of range (current value: %s, expected: 5s-60s)", c.SMS.Timeout)
}
}
// JWT 验证
if c.JWT.SecretKey == "" {
return fmt.Errorf("invalid configuration: jwt.secret_key: must be non-empty (current value: empty)")
}
if len(c.JWT.SecretKey) < 32 {
// JWT 验证SecretKey 必填验证在 ValidateRequired 中处理)
if len(c.JWT.SecretKey) > 0 && len(c.JWT.SecretKey) < 32 {
return fmt.Errorf("invalid configuration: jwt.secret_key: secret key too short (current length: %d, expected: >= 32)", len(c.JWT.SecretKey))
}
if c.JWT.TokenDuration < 1*time.Hour || c.JWT.TokenDuration > 720*time.Hour {

View File

@@ -51,6 +51,11 @@ func TestConfig_Validate(t *testing.T) {
Storage: "memory",
},
},
JWT: JWTConfig{
TokenDuration: 24 * time.Hour,
AccessTokenTTL: 24 * time.Hour,
RefreshTokenTTL: 168 * time.Hour,
},
},
wantErr: false,
},
@@ -582,6 +587,11 @@ func TestSet(t *testing.T) {
MaxSize: 500,
},
},
JWT: JWTConfig{
TokenDuration: 24 * time.Hour,
AccessTokenTTL: 24 * time.Hour,
RefreshTokenTTL: 168 * time.Hour,
},
}
err := Set(validCfg)

View File

@@ -0,0 +1,106 @@
# 默认配置文件(嵌入二进制)
# 敏感配置和必填配置为空,必须通过环境变量设置
# 环境变量格式JUNHONG_{SECTION}_{KEY}
server:
address: ":3000"
read_timeout: "30s"
write_timeout: "30s"
shutdown_timeout: "30s"
prefork: false
# 数据库配置(必填项需通过环境变量设置)
database:
host: "" # 必填JUNHONG_DATABASE_HOST
port: 5432
user: "" # 必填JUNHONG_DATABASE_USER
password: "" # 必填JUNHONG_DATABASE_PASSWORD敏感
dbname: "" # 必填JUNHONG_DATABASE_DBNAME
sslmode: "disable"
max_open_conns: 25
max_idle_conns: 10
conn_max_lifetime: "5m"
# Redis 配置(必填项需通过环境变量设置)
redis:
address: "" # 必填JUNHONG_REDIS_ADDRESS
port: 6379
password: "" # 可选JUNHONG_REDIS_PASSWORD敏感
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
# 对象存储配置
storage:
provider: "s3"
temp_dir: "/tmp/junhong-storage"
s3:
endpoint: "" # 可选JUNHONG_STORAGE_S3_ENDPOINT
region: "" # 可选JUNHONG_STORAGE_S3_REGION
bucket: "" # 可选JUNHONG_STORAGE_S3_BUCKET
access_key_id: "" # 可选JUNHONG_STORAGE_S3_ACCESS_KEY_ID敏感
secret_access_key: "" # 可选JUNHONG_STORAGE_S3_SECRET_ACCESS_KEY敏感
use_ssl: false
path_style: true
presign:
upload_expires: "15m"
download_expires: "24h"
# 日志配置
logging:
level: "info"
development: false
app_log:
filename: "/app/logs/app.log"
max_size: 100
max_backups: 3
max_age: 7
compress: true
access_log:
filename: "/app/logs/access.log"
max_size: 100
max_backups: 3
max_age: 7
compress: true
# 任务队列配置
queue:
concurrency: 10
queues:
critical: 6
default: 3
low: 1
retry_max: 5
timeout: "10m"
# JWT 配置(必填项需通过环境变量设置)
jwt:
secret_key: "" # 必填JUNHONG_JWT_SECRET_KEY敏感
token_duration: "24h"
access_token_ttl: "24h"
refresh_token_ttl: "168h"
# 中间件配置
middleware:
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
# 短信服务配置
sms:
gateway_url: "" # 可选JUNHONG_SMS_GATEWAY_URL
username: "" # 可选JUNHONG_SMS_USERNAME
password: "" # 可选JUNHONG_SMS_PASSWORD敏感
signature: "" # 可选JUNHONG_SMS_SIGNATURE
timeout: "10s"
# 默认超级管理员配置(可选)
default_admin:
username: ""
password: ""
phone: ""

17
pkg/config/embedded.go Normal file
View File

@@ -0,0 +1,17 @@
package config
import (
"bytes"
"embed"
)
//go:embed defaults/config.yaml
var defaultConfigFS embed.FS
func getEmbeddedConfig() (*bytes.Reader, error) {
data, err := defaultConfigFS.ReadFile("defaults/config.yaml")
if err != nil {
return nil, err
}
return bytes.NewReader(data), nil
}

View File

@@ -2,92 +2,120 @@ package config
import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/spf13/viper"
)
// Load 从文件和环境变量加载配置
const envPrefix = "JUNHONG"
func Load() (*Config, error) {
// 确定配置路径
configPath := os.Getenv(constants.EnvConfigPath)
if configPath == "" {
configPath = constants.DefaultConfigPath
}
v := viper.New()
// 检查环境特定配置dev, staging, prod
configEnv := os.Getenv(constants.EnvConfigEnv)
if configEnv != "" {
// 优先尝试环境特定配置
envConfigPath := fmt.Sprintf("configs/config.%s.yaml", configEnv)
if _, err := os.Stat(envConfigPath); err == nil {
configPath = envConfigPath
}
}
// 设置 Viper
viper.SetConfigFile(configPath)
viper.SetConfigType("yaml")
// 启用环境变量覆盖
viper.AutomaticEnv()
viper.SetEnvPrefix("APP")
// 读取配置文件
if err := viper.ReadInConfig(); err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
// 反序列化到 Config 结构体
cfg := &Config{}
if err := viper.Unmarshal(cfg); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
// 验证配置
if err := cfg.Validate(); err != nil {
return nil, err
}
// 设为全局配置
globalConfig.Store(cfg)
return cfg, nil
}
// Reload 重新加载当前配置文件
func Reload() (*Config, error) {
if err := viper.ReadInConfig(); err != nil {
return nil, fmt.Errorf("failed to reload config: %w", err)
}
cfg := &Config{}
if err := viper.Unmarshal(cfg); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
// 设置前验证
if err := cfg.Validate(); err != nil {
return nil, err
}
// 原子交换
globalConfig.Store(cfg)
return cfg, nil
}
// GetConfigPath 返回当前已加载配置文件的绝对路径
func GetConfigPath() string {
configFile := viper.ConfigFileUsed()
if configFile == "" {
return ""
}
absPath, err := filepath.Abs(configFile)
embeddedReader, err := getEmbeddedConfig()
if err != nil {
return configFile
return nil, fmt.Errorf("读取嵌入配置失败: %w", err)
}
v.SetConfigType("yaml")
if err := v.ReadConfig(embeddedReader); err != nil {
return nil, fmt.Errorf("解析嵌入配置失败: %w", err)
}
v.SetEnvPrefix(envPrefix)
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.AutomaticEnv()
bindEnvVariables(v)
cfg := &Config{}
if err := v.Unmarshal(cfg); err != nil {
return nil, fmt.Errorf("反序列化配置失败: %w", err)
}
if err := cfg.ValidateRequired(); err != nil {
return nil, err
}
if err := cfg.Validate(); err != nil {
return nil, err
}
globalConfig.Store(cfg)
return cfg, nil
}
func bindEnvVariables(v *viper.Viper) {
bindings := []string{
"server.address",
"server.read_timeout",
"server.write_timeout",
"server.shutdown_timeout",
"server.prefork",
"database.host",
"database.port",
"database.user",
"database.password",
"database.dbname",
"database.sslmode",
"database.max_open_conns",
"database.max_idle_conns",
"database.conn_max_lifetime",
"redis.address",
"redis.port",
"redis.password",
"redis.db",
"redis.pool_size",
"redis.min_idle_conns",
"redis.dial_timeout",
"redis.read_timeout",
"redis.write_timeout",
"storage.provider",
"storage.temp_dir",
"storage.s3.endpoint",
"storage.s3.region",
"storage.s3.bucket",
"storage.s3.access_key_id",
"storage.s3.secret_access_key",
"storage.s3.use_ssl",
"storage.s3.path_style",
"storage.presign.upload_expires",
"storage.presign.download_expires",
"logging.level",
"logging.development",
"logging.app_log.filename",
"logging.app_log.max_size",
"logging.app_log.max_backups",
"logging.app_log.max_age",
"logging.app_log.compress",
"logging.access_log.filename",
"logging.access_log.max_size",
"logging.access_log.max_backups",
"logging.access_log.max_age",
"logging.access_log.compress",
"queue.concurrency",
"queue.retry_max",
"queue.timeout",
"jwt.secret_key",
"jwt.token_duration",
"jwt.access_token_ttl",
"jwt.refresh_token_ttl",
"middleware.enable_rate_limiter",
"middleware.rate_limiter.max",
"middleware.rate_limiter.expiration",
"middleware.rate_limiter.storage",
"sms.gateway_url",
"sms.username",
"sms.password",
"sms.signature",
"sms.timeout",
"default_admin.username",
"default_admin.password",
"default_admin.phone",
}
for _, key := range bindings {
_ = v.BindEnv(key)
}
return absPath
}

View File

@@ -2,650 +2,219 @@ package config
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/spf13/viper"
)
// TestLoad tests the config loading functionality
func TestLoad(t *testing.T) {
func TestLoad_EmbeddedConfig(t *testing.T) {
clearEnvVars(t)
setRequiredEnvVars(t)
defer clearEnvVars(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() 失败: %v", err)
}
if cfg.Server.Address != ":3000" {
t.Errorf("server.address 期望 :3000, 实际 %s", cfg.Server.Address)
}
if cfg.Server.ReadTimeout != 30*time.Second {
t.Errorf("server.read_timeout 期望 30s, 实际 %v", cfg.Server.ReadTimeout)
}
if cfg.Logging.Level != "info" {
t.Errorf("logging.level 期望 info, 实际 %s", cfg.Logging.Level)
}
}
func TestLoad_EnvOverride(t *testing.T) {
clearEnvVars(t)
setRequiredEnvVars(t)
defer clearEnvVars(t)
os.Setenv("JUNHONG_SERVER_ADDRESS", ":8080")
os.Setenv("JUNHONG_LOGGING_LEVEL", "debug")
defer func() {
os.Unsetenv("JUNHONG_SERVER_ADDRESS")
os.Unsetenv("JUNHONG_LOGGING_LEVEL")
}()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() 失败: %v", err)
}
if cfg.Server.Address != ":8080" {
t.Errorf("server.address 期望 :8080, 实际 %s", cfg.Server.Address)
}
if cfg.Logging.Level != "debug" {
t.Errorf("logging.level 期望 debug, 实际 %s", cfg.Logging.Level)
}
}
func TestLoad_MissingRequired(t *testing.T) {
clearEnvVars(t)
defer clearEnvVars(t)
_, err := Load()
if err == nil {
t.Fatal("Load() 缺少必填配置时应返回错误")
}
expectedFields := []string{"database.host", "database.user", "database.password", "database.dbname", "redis.address", "jwt.secret_key"}
for _, field := range expectedFields {
if !containsString(err.Error(), field) {
t.Errorf("错误信息应包含 %q, 实际: %s", field, err.Error())
}
}
}
func TestLoad_PartialRequired(t *testing.T) {
clearEnvVars(t)
defer clearEnvVars(t)
os.Setenv("JUNHONG_DATABASE_HOST", "localhost")
os.Setenv("JUNHONG_DATABASE_USER", "user")
_, err := Load()
if err == nil {
t.Fatal("Load() 部分必填配置缺失时应返回错误")
}
if containsString(err.Error(), "database.host") {
t.Error("database.host 已设置,不应在错误信息中")
}
if containsString(err.Error(), "database.user") {
t.Error("database.user 已设置,不应在错误信息中")
}
if !containsString(err.Error(), "database.password") {
t.Error("database.password 未设置,应在错误信息中")
}
}
func TestLoad_GlobalConfig(t *testing.T) {
clearEnvVars(t)
setRequiredEnvVars(t)
defer clearEnvVars(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() 失败: %v", err)
}
globalCfg := Get()
if globalCfg == nil {
t.Fatal("Get() 返回 nil")
}
if globalCfg.Server.Address != cfg.Server.Address {
t.Errorf("全局配置与返回配置不一致")
}
}
func TestValidateRequired(t *testing.T) {
tests := []struct {
name string
setupEnv func()
cleanupEnv func()
createConfig func(t *testing.T) string
wantErr bool
validateFunc func(t *testing.T, cfg *Config)
name string
cfg *Config
wantErr bool
}{
{
name: "valid default config",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigPath, "")
_ = os.Setenv(constants.EnvConfigEnv, "")
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigPath)
_ = os.Unsetenv(constants.EnvConfigEnv)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "info"
development: false
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
// Set as default config path
_ = os.Setenv(constants.EnvConfigPath, configFile)
return configFile
name: "all required set",
cfg: &Config{
Database: DatabaseConfig{
Host: "localhost",
User: "user",
Password: "pass",
DBName: "db",
},
Redis: RedisConfig{Address: "localhost"},
JWT: JWTConfig{SecretKey: "12345678901234567890123456789012"},
},
wantErr: false,
validateFunc: func(t *testing.T, cfg *Config) {
if cfg.Server.Address != ":3000" {
t.Errorf("expected server.address :3000, got %s", cfg.Server.Address)
}
if cfg.Server.ReadTimeout != 10*time.Second {
t.Errorf("expected read_timeout 10s, got %v", cfg.Server.ReadTimeout)
}
if cfg.Redis.Address != "localhost" {
t.Errorf("expected redis.address localhost, got %s", cfg.Redis.Address)
}
if cfg.Redis.Port != 6379 {
t.Errorf("expected redis.port 6379, got %d", cfg.Redis.Port)
}
if cfg.Redis.PoolSize != 10 {
t.Errorf("expected redis.pool_size 10, got %d", cfg.Redis.PoolSize)
}
if cfg.Logging.Level != "info" {
t.Errorf("expected logging.level info, got %s", cfg.Logging.Level)
}
},
},
{
name: "environment-specific config (dev)",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigEnv, "dev")
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigEnv)
_ = os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
// Create configs directory in temp
tmpDir := t.TempDir()
configsDir := filepath.Join(tmpDir, "configs")
if err := os.MkdirAll(configsDir, 0755); err != nil {
t.Fatalf("failed to create configs dir: %v", err)
}
// Create dev config
devConfigFile := filepath.Join(configsDir, "config.dev.yaml")
content := `
server:
address: ":8080"
read_timeout: "15s"
write_timeout: "15s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 1
pool_size: 5
min_idle_conns: 2
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "debug"
development: true
app_log:
filename: "logs/app.log"
max_size: 50
max_backups: 10
max_age: 7
compress: false
access_log:
filename: "logs/access.log"
max_size: 100
max_backups: 30
max_age: 30
compress: false
middleware:
enable_rate_limiter: false
rate_limiter:
max: 50
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(devConfigFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create dev config file: %v", err)
}
// Change to tmpDir so relative path works
originalWd, _ := os.Getwd()
_ = os.Chdir(tmpDir)
t.Cleanup(func() { _ = os.Chdir(originalWd) })
return devConfigFile
},
wantErr: false,
validateFunc: func(t *testing.T, cfg *Config) {
if cfg.Server.Address != ":8080" {
t.Errorf("expected server.address :8080, got %s", cfg.Server.Address)
}
if cfg.Redis.DB != 1 {
t.Errorf("expected redis.db 1, got %d", cfg.Redis.DB)
}
if cfg.Logging.Level != "debug" {
t.Errorf("expected logging.level debug, got %s", cfg.Logging.Level)
}
name: "missing database host",
cfg: &Config{
Database: DatabaseConfig{
User: "user",
Password: "pass",
DBName: "db",
},
Redis: RedisConfig{Address: "localhost"},
JWT: JWTConfig{SecretKey: "12345678901234567890123456789012"},
},
wantErr: true,
},
{
name: "invalid YAML syntax",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigPath, "")
_ = os.Setenv(constants.EnvConfigEnv, "")
name: "missing redis address",
cfg: &Config{
Database: DatabaseConfig{
Host: "localhost",
User: "user",
Password: "pass",
DBName: "db",
},
Redis: RedisConfig{},
JWT: JWTConfig{SecretKey: "12345678901234567890123456789012"},
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigPath)
_ = os.Unsetenv(constants.EnvConfigEnv)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
invalid yaml syntax here!!!
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
_ = os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
wantErr: true,
},
{
name: "validation error - invalid server address",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigPath, "")
name: "missing jwt secret",
cfg: &Config{
Database: DatabaseConfig{
Host: "localhost",
User: "user",
Password: "pass",
DBName: "db",
},
Redis: RedisConfig{Address: "localhost"},
JWT: JWTConfig{},
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ""
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
_ = os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
},
{
name: "validation error - timeout out of range",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigPath, "")
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "1s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
_ = os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
},
{
name: "validation error - invalid redis port",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigPath, "")
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 99999
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
_ = os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset viper for each test
viper.Reset()
// Setup environment
if tt.setupEnv != nil {
tt.setupEnv()
}
// Create config file
if tt.createConfig != nil {
tt.createConfig(t)
}
// Cleanup after test
if tt.cleanupEnv != nil {
defer tt.cleanupEnv()
}
// Load config
cfg, err := Load()
// Check error expectation
err := tt.cfg.ValidateRequired()
if (err != nil) != tt.wantErr {
t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr)
return
}
// Validate config if no error expected
if !tt.wantErr && tt.validateFunc != nil {
tt.validateFunc(t, cfg)
t.Errorf("ValidateRequired() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
// TestReload tests the config reload functionality
func TestReload(t *testing.T) {
// Reset viper
viper.Reset()
func setRequiredEnvVars(t *testing.T) {
t.Helper()
os.Setenv("JUNHONG_DATABASE_HOST", "localhost")
os.Setenv("JUNHONG_DATABASE_USER", "testuser")
os.Setenv("JUNHONG_DATABASE_PASSWORD", "testpass")
os.Setenv("JUNHONG_DATABASE_DBNAME", "testdb")
os.Setenv("JUNHONG_REDIS_ADDRESS", "localhost")
os.Setenv("JUNHONG_JWT_SECRET_KEY", "12345678901234567890123456789012")
}
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
// Initial config
initialContent := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "info"
development: false
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(initialContent), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
func clearEnvVars(t *testing.T) {
t.Helper()
envVars := []string{
"JUNHONG_DATABASE_HOST",
"JUNHONG_DATABASE_PORT",
"JUNHONG_DATABASE_USER",
"JUNHONG_DATABASE_PASSWORD",
"JUNHONG_DATABASE_DBNAME",
"JUNHONG_REDIS_ADDRESS",
"JUNHONG_REDIS_PORT",
"JUNHONG_REDIS_PASSWORD",
"JUNHONG_JWT_SECRET_KEY",
"JUNHONG_SERVER_ADDRESS",
"JUNHONG_LOGGING_LEVEL",
}
// Set config path
_ = os.Setenv(constants.EnvConfigPath, configFile)
defer func() { _ = os.Unsetenv(constants.EnvConfigPath) }()
// Load initial config
cfg, err := Load()
if err != nil {
t.Fatalf("failed to load initial config: %v", err)
}
// Verify initial values
if cfg.Logging.Level != "info" {
t.Errorf("expected initial logging.level info, got %s", cfg.Logging.Level)
}
if cfg.Server.Address != ":3000" {
t.Errorf("expected initial server.address :3000, got %s", cfg.Server.Address)
}
// Modify config file
updatedContent := `
server:
address: ":8080"
read_timeout: "15s"
write_timeout: "15s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 20
min_idle_conns: 10
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "debug"
development: true
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: false
enable_rate_limiter: true
rate_limiter:
max: 200
expiration: "2m"
storage: "redis"
`
if err := os.WriteFile(configFile, []byte(updatedContent), 0644); err != nil {
t.Fatalf("failed to update config file: %v", err)
}
// Reload config
newCfg, err := Reload()
if err != nil {
t.Fatalf("failed to reload config: %v", err)
}
// Verify updated values
if newCfg.Logging.Level != "debug" {
t.Errorf("expected updated logging.level debug, got %s", newCfg.Logging.Level)
}
if newCfg.Server.Address != ":8080" {
t.Errorf("expected updated server.address :8080, got %s", newCfg.Server.Address)
}
if newCfg.Redis.PoolSize != 20 {
t.Errorf("expected updated redis.pool_size 20, got %d", newCfg.Redis.PoolSize)
}
if newCfg.Middleware.EnableRateLimiter != true {
t.Errorf("expected updated enable_rate_limiter true, got %v", newCfg.Middleware.EnableRateLimiter)
}
// Verify global config was updated
globalCfg := Get()
if globalCfg.Logging.Level != "debug" {
t.Errorf("expected global config updated, got logging.level %s", globalCfg.Logging.Level)
for _, v := range envVars {
os.Unsetenv(v)
}
}
// TestGetConfigPath tests the GetConfigPath function
func TestGetConfigPath(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
_ = os.Setenv(constants.EnvConfigPath, configFile)
defer func() { _ = os.Unsetenv(constants.EnvConfigPath) }()
// Load config
_, err := Load()
if err != nil {
t.Fatalf("failed to load config: %v", err)
}
// Get config path
path := GetConfigPath()
if path == "" {
t.Error("expected non-empty config path")
}
// Verify it's an absolute path
if !filepath.IsAbs(path) {
t.Errorf("expected absolute path, got %s", path)
}
func containsString(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && (s[:len(substr)] == substr || containsString(s[1:], substr)))
}

View File

@@ -1,43 +0,0 @@
package config
import (
"context"
"github.com/fsnotify/fsnotify"
"github.com/spf13/viper"
"go.uber.org/zap"
)
// Watch 监听配置文件变化
// 运行直到上下文被取消
func Watch(ctx context.Context, logger *zap.Logger) {
viper.WatchConfig()
viper.OnConfigChange(func(e fsnotify.Event) {
select {
case <-ctx.Done():
return // 如果上下文被取消则停止处理
default:
logger.Info("配置文件已更改", zap.String("file", e.Name))
// 尝试重新加载
newConfig, err := Reload()
if err != nil {
logger.Error("重新加载配置失败,保留先前配置",
zap.Error(err),
zap.String("file", e.Name),
)
return
}
logger.Info("配置重新加载成功",
zap.String("file", e.Name),
zap.String("server_address", newConfig.Server.Address),
zap.String("log_level", newConfig.Logging.Level),
)
}
})
// 阻塞直到上下文被取消
<-ctx.Done()
logger.Info("配置监听器已停止")
}

View File

@@ -1,422 +0,0 @@
package config
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/spf13/viper"
"go.uber.org/zap"
"go.uber.org/zap/zaptest"
)
// TestWatch tests the config hot reload watcher
func TestWatch(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
// Initial config
initialContent := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "info"
development: false
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(initialContent), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
// Set config path
_ = os.Setenv(constants.EnvConfigPath, configFile)
defer func() { _ = os.Unsetenv(constants.EnvConfigPath) }()
// Load initial config
cfg, err := Load()
if err != nil {
t.Fatalf("failed to load initial config: %v", err)
}
// Verify initial values
if cfg.Logging.Level != "info" {
t.Fatalf("expected initial logging.level info, got %s", cfg.Logging.Level)
}
// Create logger for testing
logger := zaptest.NewLogger(t)
// Start watcher in goroutine with context
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go Watch(ctx, logger)
// Give watcher time to initialize
time.Sleep(100 * time.Millisecond)
// Modify config file to trigger hot reload
updatedContent := `
server:
address: ":8080"
read_timeout: "15s"
write_timeout: "15s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 20
min_idle_conns: 10
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "debug"
development: true
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: false
enable_rate_limiter: true
rate_limiter:
max: 200
expiration: "2m"
storage: "redis"
`
if err := os.WriteFile(configFile, []byte(updatedContent), 0644); err != nil {
t.Fatalf("failed to update config file: %v", err)
}
// Wait for watcher to detect and process changes (spec requires detection within 5 seconds)
// We use a more aggressive timeout for testing
time.Sleep(2 * time.Second)
// Verify config was reloaded
reloadedCfg := Get()
if reloadedCfg.Logging.Level != "debug" {
t.Errorf("expected config hot reload, got logging.level %s instead of debug", reloadedCfg.Logging.Level)
}
if reloadedCfg.Server.Address != ":8080" {
t.Errorf("expected config hot reload, got server.address %s instead of :8080", reloadedCfg.Server.Address)
}
if reloadedCfg.Redis.PoolSize != 20 {
t.Errorf("expected config hot reload, got redis.pool_size %d instead of 20", reloadedCfg.Redis.PoolSize)
}
// Cancel context to stop watcher
cancel()
// Give watcher time to shut down gracefully
time.Sleep(100 * time.Millisecond)
}
// TestWatch_InvalidConfigRejected tests that invalid config changes are rejected
func TestWatch_InvalidConfigRejected(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
// Initial valid config
validContent := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "info"
development: false
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(validContent), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
// Set config path
_ = os.Setenv(constants.EnvConfigPath, configFile)
defer func() { _ = os.Unsetenv(constants.EnvConfigPath) }()
// Load initial config
cfg, err := Load()
if err != nil {
t.Fatalf("failed to load initial config: %v", err)
}
initialLevel := cfg.Logging.Level
if initialLevel != "info" {
t.Fatalf("expected initial logging.level info, got %s", initialLevel)
}
// Create logger for testing
logger := zaptest.NewLogger(t)
// Start watcher
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go Watch(ctx, logger)
// Give watcher time to initialize
time.Sleep(100 * time.Millisecond)
// Write INVALID config (malformed YAML)
invalidContent := `
server:
address: ":3000"
invalid yaml syntax here!!!
`
if err := os.WriteFile(configFile, []byte(invalidContent), 0644); err != nil {
t.Fatalf("failed to write invalid config: %v", err)
}
// Wait for watcher to detect changes
time.Sleep(2 * time.Second)
// Verify config was NOT changed (should keep previous valid config)
currentCfg := Get()
if currentCfg.Logging.Level != initialLevel {
t.Errorf("expected config to remain unchanged after invalid update, got logging.level %s instead of %s", currentCfg.Logging.Level, initialLevel)
}
// Restore valid config
if err := os.WriteFile(configFile, []byte(validContent), 0644); err != nil {
t.Fatalf("failed to restore valid config: %v", err)
}
time.Sleep(500 * time.Millisecond)
// Now write config with validation error (timeout out of range)
invalidValidationContent := `
server:
address: ":3000"
read_timeout: "1s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "debug"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(invalidValidationContent), 0644); err != nil {
t.Fatalf("failed to write config with validation error: %v", err)
}
// Wait for watcher to detect changes
time.Sleep(2 * time.Second)
// Verify config was NOT changed (validation should have failed)
finalCfg := Get()
if finalCfg.Logging.Level != initialLevel {
t.Errorf("expected config to remain unchanged after validation error, got logging.level %s instead of %s", finalCfg.Logging.Level, initialLevel)
}
if finalCfg.Server.ReadTimeout == 1*time.Second {
t.Error("expected config to remain unchanged, but read_timeout was updated to invalid value")
}
// Cancel context
cancel()
time.Sleep(100 * time.Millisecond)
}
// TestWatch_ContextCancellation tests graceful shutdown on context cancellation
func TestWatch_ContextCancellation(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
_ = os.Setenv(constants.EnvConfigPath, configFile)
defer func() { _ = os.Unsetenv(constants.EnvConfigPath) }()
// Load config
_, err := Load()
if err != nil {
t.Fatalf("failed to load config: %v", err)
}
// Create logger
logger := zap.NewNop() // Use no-op logger for this test
// Start watcher with context
ctx, cancel := context.WithCancel(context.Background())
done := make(chan bool)
go func() {
Watch(ctx, logger)
done <- true
}()
// Give watcher time to start
time.Sleep(100 * time.Millisecond)
// Cancel context (simulate graceful shutdown)
cancel()
// Wait for watcher to stop (should happen quickly)
select {
case <-done:
// Watcher stopped successfully
case <-time.After(2 * time.Second):
t.Error("watcher did not stop within timeout after context cancellation")
}
}