feat: 添加环境变量管理工具和部署配置改版
All checks were successful
构建并部署到测试环境(无 SSH) / build-and-deploy (push) Successful in 5m33s
All checks were successful
构建并部署到测试环境(无 SSH) / build-and-deploy (push) Successful in 5m33s
主要改动: - 新增交互式环境配置脚本 (scripts/setup-env.sh) - 新增本地启动快捷脚本 (scripts/run-local.sh) - 新增环境变量模板文件 (.env.example) - 部署模式改版:使用嵌入式配置 + 环境变量覆盖 - 添加对象存储功能支持 - 改进 IoT 卡片导入任务 - 优化 OpenAPI 文档生成 - 删除旧的配置文件,改用嵌入式默认配置
This commit is contained in:
69
pkg/bootstrap/directories.go
Normal file
69
pkg/bootstrap/directories.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/config"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type DirectoryResult struct {
|
||||
TempDir string
|
||||
AppLogDir string
|
||||
AccessLogDir string
|
||||
Fallbacks []string
|
||||
}
|
||||
|
||||
func EnsureDirectories(cfg *config.Config, logger *zap.Logger) (*DirectoryResult, error) {
|
||||
result := &DirectoryResult{}
|
||||
|
||||
directories := []struct {
|
||||
path string
|
||||
configKey string
|
||||
resultPtr *string
|
||||
}{
|
||||
{cfg.Storage.TempDir, "storage.temp_dir", &result.TempDir},
|
||||
{filepath.Dir(cfg.Logging.AppLog.Filename), "logging.app_log.filename", &result.AppLogDir},
|
||||
{filepath.Dir(cfg.Logging.AccessLog.Filename), "logging.access_log.filename", &result.AccessLogDir},
|
||||
}
|
||||
|
||||
for _, dir := range directories {
|
||||
if dir.path == "" || dir.path == "." {
|
||||
continue
|
||||
}
|
||||
|
||||
actualPath, fallback, err := ensureDirectory(dir.path, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建目录 %s (%s) 失败: %w", dir.path, dir.configKey, err)
|
||||
}
|
||||
|
||||
*dir.resultPtr = actualPath
|
||||
if fallback {
|
||||
result.Fallbacks = append(result.Fallbacks, actualPath)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func ensureDirectory(path string, logger *zap.Logger) (actualPath string, fallback bool, err error) {
|
||||
if err := os.MkdirAll(path, 0755); err != nil {
|
||||
if os.IsPermission(err) {
|
||||
fallbackPath := filepath.Join(os.TempDir(), "junhong", filepath.Base(path))
|
||||
if mkErr := os.MkdirAll(fallbackPath, 0755); mkErr != nil {
|
||||
return "", false, fmt.Errorf("原路径 %s 权限不足,降级路径 %s 也创建失败: %w", path, fallbackPath, mkErr)
|
||||
}
|
||||
if logger != nil {
|
||||
logger.Warn("目录权限不足,使用降级路径",
|
||||
zap.String("original", path),
|
||||
zap.String("fallback", fallbackPath),
|
||||
)
|
||||
}
|
||||
return fallbackPath, true, nil
|
||||
}
|
||||
return "", false, err
|
||||
}
|
||||
return path, false, nil
|
||||
}
|
||||
100
pkg/bootstrap/directories_test.go
Normal file
100
pkg/bootstrap/directories_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/config"
|
||||
)
|
||||
|
||||
func TestEnsureDirectories_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Storage: config.StorageConfig{
|
||||
TempDir: filepath.Join(tmpDir, "storage"),
|
||||
},
|
||||
Logging: config.LoggingConfig{
|
||||
AppLog: config.LogRotationConfig{Filename: filepath.Join(tmpDir, "logs", "app.log")},
|
||||
AccessLog: config.LogRotationConfig{Filename: filepath.Join(tmpDir, "logs", "access.log")},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := EnsureDirectories(cfg, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("EnsureDirectories() 失败: %v", err)
|
||||
}
|
||||
|
||||
if result.TempDir != cfg.Storage.TempDir {
|
||||
t.Errorf("TempDir 期望 %s, 实际 %s", cfg.Storage.TempDir, result.TempDir)
|
||||
}
|
||||
if result.AppLogDir != filepath.Join(tmpDir, "logs") {
|
||||
t.Errorf("AppLogDir 期望 %s, 实际 %s", filepath.Join(tmpDir, "logs"), result.AppLogDir)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(result.TempDir); os.IsNotExist(err) {
|
||||
t.Error("TempDir 目录未创建")
|
||||
}
|
||||
if _, err := os.Stat(result.AppLogDir); os.IsNotExist(err) {
|
||||
t.Error("AppLogDir 目录未创建")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureDirectories_ExistingDirs(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storageDir := filepath.Join(tmpDir, "storage")
|
||||
os.MkdirAll(storageDir, 0755)
|
||||
|
||||
cfg := &config.Config{
|
||||
Storage: config.StorageConfig{TempDir: storageDir},
|
||||
Logging: config.LoggingConfig{
|
||||
AppLog: config.LogRotationConfig{Filename: filepath.Join(tmpDir, "logs", "app.log")},
|
||||
AccessLog: config.LogRotationConfig{Filename: filepath.Join(tmpDir, "logs", "access.log")},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := EnsureDirectories(cfg, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("EnsureDirectories() 失败: %v", err)
|
||||
}
|
||||
|
||||
if result.TempDir != storageDir {
|
||||
t.Errorf("已存在目录应返回原路径")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureDirectories_EmptyPaths(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Storage: config.StorageConfig{TempDir: ""},
|
||||
Logging: config.LoggingConfig{
|
||||
AppLog: config.LogRotationConfig{Filename: ""},
|
||||
AccessLog: config.LogRotationConfig{Filename: ""},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := EnsureDirectories(cfg, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("EnsureDirectories() 空路径时不应失败: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Fallbacks) != 0 {
|
||||
t.Error("空路径不应产生降级")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureDirectory_Fallback(t *testing.T) {
|
||||
path, fallback, err := ensureDirectory("/root/no_permission_dir_test_"+t.Name(), nil)
|
||||
if err != nil {
|
||||
if os.Getuid() == 0 {
|
||||
t.Skip("以 root 身份运行,跳过权限测试")
|
||||
}
|
||||
t.Skip("无法测试权限降级场景")
|
||||
}
|
||||
|
||||
if fallback {
|
||||
if !filepath.HasPrefix(path, os.TempDir()) {
|
||||
t.Errorf("降级路径应在临时目录下,实际: %s", path)
|
||||
}
|
||||
os.RemoveAll(path)
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package config
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
@@ -21,6 +22,7 @@ type Config struct {
|
||||
SMS SMSConfig `mapstructure:"sms"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
DefaultAdmin DefaultAdminConfig `mapstructure:"default_admin"`
|
||||
Storage StorageConfig `mapstructure:"storage"`
|
||||
}
|
||||
|
||||
// ServerConfig HTTP 服务器配置
|
||||
@@ -120,7 +122,61 @@ type DefaultAdminConfig struct {
|
||||
Phone string `mapstructure:"phone"`
|
||||
}
|
||||
|
||||
// Validate 验证配置值
|
||||
// StorageConfig 对象存储配置
|
||||
type StorageConfig struct {
|
||||
Provider string `mapstructure:"provider"` // 存储提供商:s3
|
||||
S3 S3Config `mapstructure:"s3"` // S3 兼容存储配置
|
||||
Presign PresignConfig `mapstructure:"presign"` // 预签名 URL 配置
|
||||
TempDir string `mapstructure:"temp_dir"` // 临时文件目录
|
||||
}
|
||||
|
||||
// S3Config S3 兼容存储配置
|
||||
type S3Config struct {
|
||||
Endpoint string `mapstructure:"endpoint"` // 服务端点(如:http://obs-helf.cucloud.cn)
|
||||
Region string `mapstructure:"region"` // 区域(如:cn-langfang-2)
|
||||
Bucket string `mapstructure:"bucket"` // 存储桶名称
|
||||
AccessKeyID string `mapstructure:"access_key_id"` // 访问密钥 ID
|
||||
SecretAccessKey string `mapstructure:"secret_access_key"` // 访问密钥
|
||||
UseSSL bool `mapstructure:"use_ssl"` // 是否使用 SSL
|
||||
PathStyle bool `mapstructure:"path_style"` // 是否使用路径风格(兼容性)
|
||||
}
|
||||
|
||||
// PresignConfig 预签名 URL 配置
|
||||
type PresignConfig struct {
|
||||
UploadExpires time.Duration `mapstructure:"upload_expires"` // 上传 URL 有效期(默认:15m)
|
||||
DownloadExpires time.Duration `mapstructure:"download_expires"` // 下载 URL 有效期(默认:24h)
|
||||
}
|
||||
|
||||
type requiredField struct {
|
||||
value string
|
||||
name string
|
||||
envName string
|
||||
}
|
||||
|
||||
func (c *Config) ValidateRequired() error {
|
||||
fields := []requiredField{
|
||||
{c.Database.Host, "database.host", "JUNHONG_DATABASE_HOST"},
|
||||
{c.Database.User, "database.user", "JUNHONG_DATABASE_USER"},
|
||||
{c.Database.Password, "database.password", "JUNHONG_DATABASE_PASSWORD"},
|
||||
{c.Database.DBName, "database.dbname", "JUNHONG_DATABASE_DBNAME"},
|
||||
{c.Redis.Address, "redis.address", "JUNHONG_REDIS_ADDRESS"},
|
||||
{c.JWT.SecretKey, "jwt.secret_key", "JUNHONG_JWT_SECRET_KEY"},
|
||||
}
|
||||
|
||||
var missing []string
|
||||
for _, f := range fields {
|
||||
if f.value == "" {
|
||||
missing = append(missing, fmt.Sprintf(" - %s (环境变量: %s)", f.name, f.envName))
|
||||
}
|
||||
}
|
||||
|
||||
if len(missing) > 0 {
|
||||
return fmt.Errorf("缺少必填配置项:\n%s", strings.Join(missing, "\n"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
// 服务器验证
|
||||
if c.Server.Address == "" {
|
||||
@@ -184,28 +240,24 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("invalid configuration: middleware.rate_limiter.storage: invalid storage type (current value: %s, expected: memory or redis)", c.Middleware.RateLimiter.Storage)
|
||||
}
|
||||
|
||||
// 短信服务验证
|
||||
if c.SMS.GatewayURL == "" {
|
||||
return fmt.Errorf("invalid configuration: sms.gateway_url: must be non-empty (current value: empty)")
|
||||
}
|
||||
if c.SMS.Username == "" {
|
||||
return fmt.Errorf("invalid configuration: sms.username: must be non-empty (current value: empty)")
|
||||
}
|
||||
if c.SMS.Password == "" {
|
||||
return fmt.Errorf("invalid configuration: sms.password: must be non-empty (current value: empty)")
|
||||
}
|
||||
if c.SMS.Signature == "" {
|
||||
return fmt.Errorf("invalid configuration: sms.signature: must be non-empty (current value: empty)")
|
||||
}
|
||||
if c.SMS.Timeout < 5*time.Second || c.SMS.Timeout > 60*time.Second {
|
||||
return fmt.Errorf("invalid configuration: sms.timeout: duration out of range (current value: %s, expected: 5s-60s)", c.SMS.Timeout)
|
||||
// 短信服务验证(可选,配置 GatewayURL 时才验证其他字段)
|
||||
if c.SMS.GatewayURL != "" {
|
||||
if c.SMS.Username == "" {
|
||||
return fmt.Errorf("invalid configuration: sms.username: must be non-empty when gateway_url is configured")
|
||||
}
|
||||
if c.SMS.Password == "" {
|
||||
return fmt.Errorf("invalid configuration: sms.password: must be non-empty when gateway_url is configured")
|
||||
}
|
||||
if c.SMS.Signature == "" {
|
||||
return fmt.Errorf("invalid configuration: sms.signature: must be non-empty when gateway_url is configured")
|
||||
}
|
||||
if c.SMS.Timeout > 0 && (c.SMS.Timeout < 5*time.Second || c.SMS.Timeout > 60*time.Second) {
|
||||
return fmt.Errorf("invalid configuration: sms.timeout: duration out of range (current value: %s, expected: 5s-60s)", c.SMS.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// JWT 验证
|
||||
if c.JWT.SecretKey == "" {
|
||||
return fmt.Errorf("invalid configuration: jwt.secret_key: must be non-empty (current value: empty)")
|
||||
}
|
||||
if len(c.JWT.SecretKey) < 32 {
|
||||
// JWT 验证(SecretKey 必填验证在 ValidateRequired 中处理)
|
||||
if len(c.JWT.SecretKey) > 0 && len(c.JWT.SecretKey) < 32 {
|
||||
return fmt.Errorf("invalid configuration: jwt.secret_key: secret key too short (current length: %d, expected: >= 32)", len(c.JWT.SecretKey))
|
||||
}
|
||||
if c.JWT.TokenDuration < 1*time.Hour || c.JWT.TokenDuration > 720*time.Hour {
|
||||
|
||||
@@ -51,6 +51,11 @@ func TestConfig_Validate(t *testing.T) {
|
||||
Storage: "memory",
|
||||
},
|
||||
},
|
||||
JWT: JWTConfig{
|
||||
TokenDuration: 24 * time.Hour,
|
||||
AccessTokenTTL: 24 * time.Hour,
|
||||
RefreshTokenTTL: 168 * time.Hour,
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
@@ -582,6 +587,11 @@ func TestSet(t *testing.T) {
|
||||
MaxSize: 500,
|
||||
},
|
||||
},
|
||||
JWT: JWTConfig{
|
||||
TokenDuration: 24 * time.Hour,
|
||||
AccessTokenTTL: 24 * time.Hour,
|
||||
RefreshTokenTTL: 168 * time.Hour,
|
||||
},
|
||||
}
|
||||
|
||||
err := Set(validCfg)
|
||||
|
||||
106
pkg/config/defaults/config.yaml
Normal file
106
pkg/config/defaults/config.yaml
Normal file
@@ -0,0 +1,106 @@
|
||||
# 默认配置文件(嵌入二进制)
|
||||
# 敏感配置和必填配置为空,必须通过环境变量设置
|
||||
# 环境变量格式:JUNHONG_{SECTION}_{KEY}
|
||||
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "30s"
|
||||
write_timeout: "30s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
# 数据库配置(必填项需通过环境变量设置)
|
||||
database:
|
||||
host: "" # 必填:JUNHONG_DATABASE_HOST
|
||||
port: 5432
|
||||
user: "" # 必填:JUNHONG_DATABASE_USER
|
||||
password: "" # 必填:JUNHONG_DATABASE_PASSWORD(敏感)
|
||||
dbname: "" # 必填:JUNHONG_DATABASE_DBNAME
|
||||
sslmode: "disable"
|
||||
max_open_conns: 25
|
||||
max_idle_conns: 10
|
||||
conn_max_lifetime: "5m"
|
||||
|
||||
# Redis 配置(必填项需通过环境变量设置)
|
||||
redis:
|
||||
address: "" # 必填:JUNHONG_REDIS_ADDRESS
|
||||
port: 6379
|
||||
password: "" # 可选:JUNHONG_REDIS_PASSWORD(敏感)
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
# 对象存储配置
|
||||
storage:
|
||||
provider: "s3"
|
||||
temp_dir: "/tmp/junhong-storage"
|
||||
s3:
|
||||
endpoint: "" # 可选:JUNHONG_STORAGE_S3_ENDPOINT
|
||||
region: "" # 可选:JUNHONG_STORAGE_S3_REGION
|
||||
bucket: "" # 可选:JUNHONG_STORAGE_S3_BUCKET
|
||||
access_key_id: "" # 可选:JUNHONG_STORAGE_S3_ACCESS_KEY_ID(敏感)
|
||||
secret_access_key: "" # 可选:JUNHONG_STORAGE_S3_SECRET_ACCESS_KEY(敏感)
|
||||
use_ssl: false
|
||||
path_style: true
|
||||
presign:
|
||||
upload_expires: "15m"
|
||||
download_expires: "24h"
|
||||
|
||||
# 日志配置
|
||||
logging:
|
||||
level: "info"
|
||||
development: false
|
||||
app_log:
|
||||
filename: "/app/logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 3
|
||||
max_age: 7
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "/app/logs/access.log"
|
||||
max_size: 100
|
||||
max_backups: 3
|
||||
max_age: 7
|
||||
compress: true
|
||||
|
||||
# 任务队列配置
|
||||
queue:
|
||||
concurrency: 10
|
||||
queues:
|
||||
critical: 6
|
||||
default: 3
|
||||
low: 1
|
||||
retry_max: 5
|
||||
timeout: "10m"
|
||||
|
||||
# JWT 配置(必填项需通过环境变量设置)
|
||||
jwt:
|
||||
secret_key: "" # 必填:JUNHONG_JWT_SECRET_KEY(敏感)
|
||||
token_duration: "24h"
|
||||
access_token_ttl: "24h"
|
||||
refresh_token_ttl: "168h"
|
||||
|
||||
# 中间件配置
|
||||
middleware:
|
||||
enable_rate_limiter: false
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
|
||||
# 短信服务配置
|
||||
sms:
|
||||
gateway_url: "" # 可选:JUNHONG_SMS_GATEWAY_URL
|
||||
username: "" # 可选:JUNHONG_SMS_USERNAME
|
||||
password: "" # 可选:JUNHONG_SMS_PASSWORD(敏感)
|
||||
signature: "" # 可选:JUNHONG_SMS_SIGNATURE
|
||||
timeout: "10s"
|
||||
|
||||
# 默认超级管理员配置(可选)
|
||||
default_admin:
|
||||
username: ""
|
||||
password: ""
|
||||
phone: ""
|
||||
17
pkg/config/embedded.go
Normal file
17
pkg/config/embedded.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"embed"
|
||||
)
|
||||
|
||||
//go:embed defaults/config.yaml
|
||||
var defaultConfigFS embed.FS
|
||||
|
||||
func getEmbeddedConfig() (*bytes.Reader, error) {
|
||||
data, err := defaultConfigFS.ReadFile("defaults/config.yaml")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
@@ -2,92 +2,120 @@ package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// Load 从文件和环境变量加载配置
|
||||
const envPrefix = "JUNHONG"
|
||||
|
||||
func Load() (*Config, error) {
|
||||
// 确定配置路径
|
||||
configPath := os.Getenv(constants.EnvConfigPath)
|
||||
if configPath == "" {
|
||||
configPath = constants.DefaultConfigPath
|
||||
}
|
||||
v := viper.New()
|
||||
|
||||
// 检查环境特定配置(dev, staging, prod)
|
||||
configEnv := os.Getenv(constants.EnvConfigEnv)
|
||||
if configEnv != "" {
|
||||
// 优先尝试环境特定配置
|
||||
envConfigPath := fmt.Sprintf("configs/config.%s.yaml", configEnv)
|
||||
if _, err := os.Stat(envConfigPath); err == nil {
|
||||
configPath = envConfigPath
|
||||
}
|
||||
}
|
||||
|
||||
// 设置 Viper
|
||||
viper.SetConfigFile(configPath)
|
||||
viper.SetConfigType("yaml")
|
||||
|
||||
// 启用环境变量覆盖
|
||||
viper.AutomaticEnv()
|
||||
viper.SetEnvPrefix("APP")
|
||||
|
||||
// 读取配置文件
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
// 反序列化到 Config 结构体
|
||||
cfg := &Config{}
|
||||
if err := viper.Unmarshal(cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
// 验证配置
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 设为全局配置
|
||||
globalConfig.Store(cfg)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// Reload 重新加载当前配置文件
|
||||
func Reload() (*Config, error) {
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
return nil, fmt.Errorf("failed to reload config: %w", err)
|
||||
}
|
||||
|
||||
cfg := &Config{}
|
||||
if err := viper.Unmarshal(cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
// 设置前验证
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 原子交换
|
||||
globalConfig.Store(cfg)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// GetConfigPath 返回当前已加载配置文件的绝对路径
|
||||
func GetConfigPath() string {
|
||||
configFile := viper.ConfigFileUsed()
|
||||
if configFile == "" {
|
||||
return ""
|
||||
}
|
||||
absPath, err := filepath.Abs(configFile)
|
||||
embeddedReader, err := getEmbeddedConfig()
|
||||
if err != nil {
|
||||
return configFile
|
||||
return nil, fmt.Errorf("读取嵌入配置失败: %w", err)
|
||||
}
|
||||
|
||||
v.SetConfigType("yaml")
|
||||
if err := v.ReadConfig(embeddedReader); err != nil {
|
||||
return nil, fmt.Errorf("解析嵌入配置失败: %w", err)
|
||||
}
|
||||
|
||||
v.SetEnvPrefix(envPrefix)
|
||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
v.AutomaticEnv()
|
||||
|
||||
bindEnvVariables(v)
|
||||
|
||||
cfg := &Config{}
|
||||
if err := v.Unmarshal(cfg); err != nil {
|
||||
return nil, fmt.Errorf("反序列化配置失败: %w", err)
|
||||
}
|
||||
|
||||
if err := cfg.ValidateRequired(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
globalConfig.Store(cfg)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func bindEnvVariables(v *viper.Viper) {
|
||||
bindings := []string{
|
||||
"server.address",
|
||||
"server.read_timeout",
|
||||
"server.write_timeout",
|
||||
"server.shutdown_timeout",
|
||||
"server.prefork",
|
||||
"database.host",
|
||||
"database.port",
|
||||
"database.user",
|
||||
"database.password",
|
||||
"database.dbname",
|
||||
"database.sslmode",
|
||||
"database.max_open_conns",
|
||||
"database.max_idle_conns",
|
||||
"database.conn_max_lifetime",
|
||||
"redis.address",
|
||||
"redis.port",
|
||||
"redis.password",
|
||||
"redis.db",
|
||||
"redis.pool_size",
|
||||
"redis.min_idle_conns",
|
||||
"redis.dial_timeout",
|
||||
"redis.read_timeout",
|
||||
"redis.write_timeout",
|
||||
"storage.provider",
|
||||
"storage.temp_dir",
|
||||
"storage.s3.endpoint",
|
||||
"storage.s3.region",
|
||||
"storage.s3.bucket",
|
||||
"storage.s3.access_key_id",
|
||||
"storage.s3.secret_access_key",
|
||||
"storage.s3.use_ssl",
|
||||
"storage.s3.path_style",
|
||||
"storage.presign.upload_expires",
|
||||
"storage.presign.download_expires",
|
||||
"logging.level",
|
||||
"logging.development",
|
||||
"logging.app_log.filename",
|
||||
"logging.app_log.max_size",
|
||||
"logging.app_log.max_backups",
|
||||
"logging.app_log.max_age",
|
||||
"logging.app_log.compress",
|
||||
"logging.access_log.filename",
|
||||
"logging.access_log.max_size",
|
||||
"logging.access_log.max_backups",
|
||||
"logging.access_log.max_age",
|
||||
"logging.access_log.compress",
|
||||
"queue.concurrency",
|
||||
"queue.retry_max",
|
||||
"queue.timeout",
|
||||
"jwt.secret_key",
|
||||
"jwt.token_duration",
|
||||
"jwt.access_token_ttl",
|
||||
"jwt.refresh_token_ttl",
|
||||
"middleware.enable_rate_limiter",
|
||||
"middleware.rate_limiter.max",
|
||||
"middleware.rate_limiter.expiration",
|
||||
"middleware.rate_limiter.storage",
|
||||
"sms.gateway_url",
|
||||
"sms.username",
|
||||
"sms.password",
|
||||
"sms.signature",
|
||||
"sms.timeout",
|
||||
"default_admin.username",
|
||||
"default_admin.password",
|
||||
"default_admin.phone",
|
||||
}
|
||||
|
||||
for _, key := range bindings {
|
||||
_ = v.BindEnv(key)
|
||||
}
|
||||
return absPath
|
||||
}
|
||||
|
||||
@@ -2,650 +2,219 @@ package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// TestLoad tests the config loading functionality
|
||||
func TestLoad(t *testing.T) {
|
||||
func TestLoad_EmbeddedConfig(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
setRequiredEnvVars(t)
|
||||
defer clearEnvVars(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() 失败: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Server.Address != ":3000" {
|
||||
t.Errorf("server.address 期望 :3000, 实际 %s", cfg.Server.Address)
|
||||
}
|
||||
if cfg.Server.ReadTimeout != 30*time.Second {
|
||||
t.Errorf("server.read_timeout 期望 30s, 实际 %v", cfg.Server.ReadTimeout)
|
||||
}
|
||||
if cfg.Logging.Level != "info" {
|
||||
t.Errorf("logging.level 期望 info, 实际 %s", cfg.Logging.Level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_EnvOverride(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
setRequiredEnvVars(t)
|
||||
defer clearEnvVars(t)
|
||||
|
||||
os.Setenv("JUNHONG_SERVER_ADDRESS", ":8080")
|
||||
os.Setenv("JUNHONG_LOGGING_LEVEL", "debug")
|
||||
defer func() {
|
||||
os.Unsetenv("JUNHONG_SERVER_ADDRESS")
|
||||
os.Unsetenv("JUNHONG_LOGGING_LEVEL")
|
||||
}()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() 失败: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Server.Address != ":8080" {
|
||||
t.Errorf("server.address 期望 :8080, 实际 %s", cfg.Server.Address)
|
||||
}
|
||||
if cfg.Logging.Level != "debug" {
|
||||
t.Errorf("logging.level 期望 debug, 实际 %s", cfg.Logging.Level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_MissingRequired(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
defer clearEnvVars(t)
|
||||
|
||||
_, err := Load()
|
||||
if err == nil {
|
||||
t.Fatal("Load() 缺少必填配置时应返回错误")
|
||||
}
|
||||
|
||||
expectedFields := []string{"database.host", "database.user", "database.password", "database.dbname", "redis.address", "jwt.secret_key"}
|
||||
for _, field := range expectedFields {
|
||||
if !containsString(err.Error(), field) {
|
||||
t.Errorf("错误信息应包含 %q, 实际: %s", field, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_PartialRequired(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
defer clearEnvVars(t)
|
||||
|
||||
os.Setenv("JUNHONG_DATABASE_HOST", "localhost")
|
||||
os.Setenv("JUNHONG_DATABASE_USER", "user")
|
||||
|
||||
_, err := Load()
|
||||
if err == nil {
|
||||
t.Fatal("Load() 部分必填配置缺失时应返回错误")
|
||||
}
|
||||
|
||||
if containsString(err.Error(), "database.host") {
|
||||
t.Error("database.host 已设置,不应在错误信息中")
|
||||
}
|
||||
if containsString(err.Error(), "database.user") {
|
||||
t.Error("database.user 已设置,不应在错误信息中")
|
||||
}
|
||||
if !containsString(err.Error(), "database.password") {
|
||||
t.Error("database.password 未设置,应在错误信息中")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_GlobalConfig(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
setRequiredEnvVars(t)
|
||||
defer clearEnvVars(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() 失败: %v", err)
|
||||
}
|
||||
|
||||
globalCfg := Get()
|
||||
if globalCfg == nil {
|
||||
t.Fatal("Get() 返回 nil")
|
||||
}
|
||||
|
||||
if globalCfg.Server.Address != cfg.Server.Address {
|
||||
t.Errorf("全局配置与返回配置不一致")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRequired(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupEnv func()
|
||||
cleanupEnv func()
|
||||
createConfig func(t *testing.T) string
|
||||
wantErr bool
|
||||
validateFunc func(t *testing.T, cfg *Config)
|
||||
name string
|
||||
cfg *Config
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid default config",
|
||||
setupEnv: func() {
|
||||
_ = os.Setenv(constants.EnvConfigPath, "")
|
||||
_ = os.Setenv(constants.EnvConfigEnv, "")
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
_ = os.Unsetenv(constants.EnvConfigPath)
|
||||
_ = os.Unsetenv(constants.EnvConfigEnv)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
development: false
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
enable_rate_limiter: false
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
// Set as default config path
|
||||
_ = os.Setenv(constants.EnvConfigPath, configFile)
|
||||
return configFile
|
||||
name: "all required set",
|
||||
cfg: &Config{
|
||||
Database: DatabaseConfig{
|
||||
Host: "localhost",
|
||||
User: "user",
|
||||
Password: "pass",
|
||||
DBName: "db",
|
||||
},
|
||||
Redis: RedisConfig{Address: "localhost"},
|
||||
JWT: JWTConfig{SecretKey: "12345678901234567890123456789012"},
|
||||
},
|
||||
wantErr: false,
|
||||
validateFunc: func(t *testing.T, cfg *Config) {
|
||||
if cfg.Server.Address != ":3000" {
|
||||
t.Errorf("expected server.address :3000, got %s", cfg.Server.Address)
|
||||
}
|
||||
if cfg.Server.ReadTimeout != 10*time.Second {
|
||||
t.Errorf("expected read_timeout 10s, got %v", cfg.Server.ReadTimeout)
|
||||
}
|
||||
if cfg.Redis.Address != "localhost" {
|
||||
t.Errorf("expected redis.address localhost, got %s", cfg.Redis.Address)
|
||||
}
|
||||
if cfg.Redis.Port != 6379 {
|
||||
t.Errorf("expected redis.port 6379, got %d", cfg.Redis.Port)
|
||||
}
|
||||
if cfg.Redis.PoolSize != 10 {
|
||||
t.Errorf("expected redis.pool_size 10, got %d", cfg.Redis.PoolSize)
|
||||
}
|
||||
if cfg.Logging.Level != "info" {
|
||||
t.Errorf("expected logging.level info, got %s", cfg.Logging.Level)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "environment-specific config (dev)",
|
||||
setupEnv: func() {
|
||||
_ = os.Setenv(constants.EnvConfigEnv, "dev")
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
_ = os.Unsetenv(constants.EnvConfigEnv)
|
||||
_ = os.Unsetenv(constants.EnvConfigPath)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
// Create configs directory in temp
|
||||
tmpDir := t.TempDir()
|
||||
configsDir := filepath.Join(tmpDir, "configs")
|
||||
if err := os.MkdirAll(configsDir, 0755); err != nil {
|
||||
t.Fatalf("failed to create configs dir: %v", err)
|
||||
}
|
||||
|
||||
// Create dev config
|
||||
devConfigFile := filepath.Join(configsDir, "config.dev.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ":8080"
|
||||
read_timeout: "15s"
|
||||
write_timeout: "15s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 1
|
||||
pool_size: 5
|
||||
min_idle_conns: 2
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "debug"
|
||||
development: true
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 50
|
||||
max_backups: 10
|
||||
max_age: 7
|
||||
compress: false
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: false
|
||||
|
||||
middleware:
|
||||
enable_rate_limiter: false
|
||||
rate_limiter:
|
||||
max: 50
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(devConfigFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create dev config file: %v", err)
|
||||
}
|
||||
|
||||
// Change to tmpDir so relative path works
|
||||
originalWd, _ := os.Getwd()
|
||||
_ = os.Chdir(tmpDir)
|
||||
t.Cleanup(func() { _ = os.Chdir(originalWd) })
|
||||
|
||||
return devConfigFile
|
||||
},
|
||||
wantErr: false,
|
||||
validateFunc: func(t *testing.T, cfg *Config) {
|
||||
if cfg.Server.Address != ":8080" {
|
||||
t.Errorf("expected server.address :8080, got %s", cfg.Server.Address)
|
||||
}
|
||||
if cfg.Redis.DB != 1 {
|
||||
t.Errorf("expected redis.db 1, got %d", cfg.Redis.DB)
|
||||
}
|
||||
if cfg.Logging.Level != "debug" {
|
||||
t.Errorf("expected logging.level debug, got %s", cfg.Logging.Level)
|
||||
}
|
||||
name: "missing database host",
|
||||
cfg: &Config{
|
||||
Database: DatabaseConfig{
|
||||
User: "user",
|
||||
Password: "pass",
|
||||
DBName: "db",
|
||||
},
|
||||
Redis: RedisConfig{Address: "localhost"},
|
||||
JWT: JWTConfig{SecretKey: "12345678901234567890123456789012"},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid YAML syntax",
|
||||
setupEnv: func() {
|
||||
_ = os.Setenv(constants.EnvConfigPath, "")
|
||||
_ = os.Setenv(constants.EnvConfigEnv, "")
|
||||
name: "missing redis address",
|
||||
cfg: &Config{
|
||||
Database: DatabaseConfig{
|
||||
Host: "localhost",
|
||||
User: "user",
|
||||
Password: "pass",
|
||||
DBName: "db",
|
||||
},
|
||||
Redis: RedisConfig{},
|
||||
JWT: JWTConfig{SecretKey: "12345678901234567890123456789012"},
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
_ = os.Unsetenv(constants.EnvConfigPath)
|
||||
_ = os.Unsetenv(constants.EnvConfigEnv)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
invalid yaml syntax here!!!
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
_ = os.Setenv(constants.EnvConfigPath, configFile)
|
||||
return configFile
|
||||
},
|
||||
wantErr: true,
|
||||
validateFunc: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "validation error - invalid server address",
|
||||
setupEnv: func() {
|
||||
_ = os.Setenv(constants.EnvConfigPath, "")
|
||||
name: "missing jwt secret",
|
||||
cfg: &Config{
|
||||
Database: DatabaseConfig{
|
||||
Host: "localhost",
|
||||
User: "user",
|
||||
Password: "pass",
|
||||
DBName: "db",
|
||||
},
|
||||
Redis: RedisConfig{Address: "localhost"},
|
||||
JWT: JWTConfig{},
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
_ = os.Unsetenv(constants.EnvConfigPath)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ""
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
_ = os.Setenv(constants.EnvConfigPath, configFile)
|
||||
return configFile
|
||||
},
|
||||
wantErr: true,
|
||||
validateFunc: nil,
|
||||
},
|
||||
{
|
||||
name: "validation error - timeout out of range",
|
||||
setupEnv: func() {
|
||||
_ = os.Setenv(constants.EnvConfigPath, "")
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
_ = os.Unsetenv(constants.EnvConfigPath)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "1s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
_ = os.Setenv(constants.EnvConfigPath, configFile)
|
||||
return configFile
|
||||
},
|
||||
wantErr: true,
|
||||
validateFunc: nil,
|
||||
},
|
||||
{
|
||||
name: "validation error - invalid redis port",
|
||||
setupEnv: func() {
|
||||
_ = os.Setenv(constants.EnvConfigPath, "")
|
||||
},
|
||||
cleanupEnv: func() {
|
||||
_ = os.Unsetenv(constants.EnvConfigPath)
|
||||
},
|
||||
createConfig: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 99999
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
_ = os.Setenv(constants.EnvConfigPath, configFile)
|
||||
return configFile
|
||||
},
|
||||
wantErr: true,
|
||||
validateFunc: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset viper for each test
|
||||
viper.Reset()
|
||||
|
||||
// Setup environment
|
||||
if tt.setupEnv != nil {
|
||||
tt.setupEnv()
|
||||
}
|
||||
|
||||
// Create config file
|
||||
if tt.createConfig != nil {
|
||||
tt.createConfig(t)
|
||||
}
|
||||
|
||||
// Cleanup after test
|
||||
if tt.cleanupEnv != nil {
|
||||
defer tt.cleanupEnv()
|
||||
}
|
||||
|
||||
// Load config
|
||||
cfg, err := Load()
|
||||
|
||||
// Check error expectation
|
||||
err := tt.cfg.ValidateRequired()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate config if no error expected
|
||||
if !tt.wantErr && tt.validateFunc != nil {
|
||||
tt.validateFunc(t, cfg)
|
||||
t.Errorf("ValidateRequired() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestReload tests the config reload functionality
|
||||
func TestReload(t *testing.T) {
|
||||
// Reset viper
|
||||
viper.Reset()
|
||||
func setRequiredEnvVars(t *testing.T) {
|
||||
t.Helper()
|
||||
os.Setenv("JUNHONG_DATABASE_HOST", "localhost")
|
||||
os.Setenv("JUNHONG_DATABASE_USER", "testuser")
|
||||
os.Setenv("JUNHONG_DATABASE_PASSWORD", "testpass")
|
||||
os.Setenv("JUNHONG_DATABASE_DBNAME", "testdb")
|
||||
os.Setenv("JUNHONG_REDIS_ADDRESS", "localhost")
|
||||
os.Setenv("JUNHONG_JWT_SECRET_KEY", "12345678901234567890123456789012")
|
||||
}
|
||||
|
||||
// Create temp config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
// Initial config
|
||||
initialContent := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
development: false
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
enable_rate_limiter: false
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(initialContent), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
func clearEnvVars(t *testing.T) {
|
||||
t.Helper()
|
||||
envVars := []string{
|
||||
"JUNHONG_DATABASE_HOST",
|
||||
"JUNHONG_DATABASE_PORT",
|
||||
"JUNHONG_DATABASE_USER",
|
||||
"JUNHONG_DATABASE_PASSWORD",
|
||||
"JUNHONG_DATABASE_DBNAME",
|
||||
"JUNHONG_REDIS_ADDRESS",
|
||||
"JUNHONG_REDIS_PORT",
|
||||
"JUNHONG_REDIS_PASSWORD",
|
||||
"JUNHONG_JWT_SECRET_KEY",
|
||||
"JUNHONG_SERVER_ADDRESS",
|
||||
"JUNHONG_LOGGING_LEVEL",
|
||||
}
|
||||
|
||||
// Set config path
|
||||
_ = os.Setenv(constants.EnvConfigPath, configFile)
|
||||
defer func() { _ = os.Unsetenv(constants.EnvConfigPath) }()
|
||||
|
||||
// Load initial config
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load initial config: %v", err)
|
||||
}
|
||||
|
||||
// Verify initial values
|
||||
if cfg.Logging.Level != "info" {
|
||||
t.Errorf("expected initial logging.level info, got %s", cfg.Logging.Level)
|
||||
}
|
||||
if cfg.Server.Address != ":3000" {
|
||||
t.Errorf("expected initial server.address :3000, got %s", cfg.Server.Address)
|
||||
}
|
||||
|
||||
// Modify config file
|
||||
updatedContent := `
|
||||
server:
|
||||
address: ":8080"
|
||||
read_timeout: "15s"
|
||||
write_timeout: "15s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 20
|
||||
min_idle_conns: 10
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "debug"
|
||||
development: true
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: false
|
||||
enable_rate_limiter: true
|
||||
rate_limiter:
|
||||
max: 200
|
||||
expiration: "2m"
|
||||
storage: "redis"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(updatedContent), 0644); err != nil {
|
||||
t.Fatalf("failed to update config file: %v", err)
|
||||
}
|
||||
|
||||
// Reload config
|
||||
newCfg, err := Reload()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reload config: %v", err)
|
||||
}
|
||||
|
||||
// Verify updated values
|
||||
if newCfg.Logging.Level != "debug" {
|
||||
t.Errorf("expected updated logging.level debug, got %s", newCfg.Logging.Level)
|
||||
}
|
||||
if newCfg.Server.Address != ":8080" {
|
||||
t.Errorf("expected updated server.address :8080, got %s", newCfg.Server.Address)
|
||||
}
|
||||
if newCfg.Redis.PoolSize != 20 {
|
||||
t.Errorf("expected updated redis.pool_size 20, got %d", newCfg.Redis.PoolSize)
|
||||
}
|
||||
if newCfg.Middleware.EnableRateLimiter != true {
|
||||
t.Errorf("expected updated enable_rate_limiter true, got %v", newCfg.Middleware.EnableRateLimiter)
|
||||
}
|
||||
|
||||
// Verify global config was updated
|
||||
globalCfg := Get()
|
||||
if globalCfg.Logging.Level != "debug" {
|
||||
t.Errorf("expected global config updated, got logging.level %s", globalCfg.Logging.Level)
|
||||
for _, v := range envVars {
|
||||
os.Unsetenv(v)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetConfigPath tests the GetConfigPath function
|
||||
func TestGetConfigPath(t *testing.T) {
|
||||
// Reset viper
|
||||
viper.Reset()
|
||||
|
||||
// Create temp config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
|
||||
_ = os.Setenv(constants.EnvConfigPath, configFile)
|
||||
defer func() { _ = os.Unsetenv(constants.EnvConfigPath) }()
|
||||
|
||||
// Load config
|
||||
_, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Get config path
|
||||
path := GetConfigPath()
|
||||
if path == "" {
|
||||
t.Error("expected non-empty config path")
|
||||
}
|
||||
|
||||
// Verify it's an absolute path
|
||||
if !filepath.IsAbs(path) {
|
||||
t.Errorf("expected absolute path, got %s", path)
|
||||
}
|
||||
func containsString(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && (s[:len(substr)] == substr || containsString(s[1:], substr)))
|
||||
}
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Watch 监听配置文件变化
|
||||
// 运行直到上下文被取消
|
||||
func Watch(ctx context.Context, logger *zap.Logger) {
|
||||
viper.WatchConfig()
|
||||
viper.OnConfigChange(func(e fsnotify.Event) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return // 如果上下文被取消则停止处理
|
||||
default:
|
||||
logger.Info("配置文件已更改", zap.String("file", e.Name))
|
||||
|
||||
// 尝试重新加载
|
||||
newConfig, err := Reload()
|
||||
if err != nil {
|
||||
logger.Error("重新加载配置失败,保留先前配置",
|
||||
zap.Error(err),
|
||||
zap.String("file", e.Name),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("配置重新加载成功",
|
||||
zap.String("file", e.Name),
|
||||
zap.String("server_address", newConfig.Server.Address),
|
||||
zap.String("log_level", newConfig.Logging.Level),
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
// 阻塞直到上下文被取消
|
||||
<-ctx.Done()
|
||||
logger.Info("配置监听器已停止")
|
||||
}
|
||||
@@ -1,422 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestWatch tests the config hot reload watcher
|
||||
func TestWatch(t *testing.T) {
|
||||
// Reset viper
|
||||
viper.Reset()
|
||||
|
||||
// Create temp config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
// Initial config
|
||||
initialContent := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
development: false
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
enable_rate_limiter: false
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(initialContent), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
|
||||
// Set config path
|
||||
_ = os.Setenv(constants.EnvConfigPath, configFile)
|
||||
defer func() { _ = os.Unsetenv(constants.EnvConfigPath) }()
|
||||
|
||||
// Load initial config
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load initial config: %v", err)
|
||||
}
|
||||
|
||||
// Verify initial values
|
||||
if cfg.Logging.Level != "info" {
|
||||
t.Fatalf("expected initial logging.level info, got %s", cfg.Logging.Level)
|
||||
}
|
||||
|
||||
// Create logger for testing
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
// Start watcher in goroutine with context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go Watch(ctx, logger)
|
||||
|
||||
// Give watcher time to initialize
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Modify config file to trigger hot reload
|
||||
updatedContent := `
|
||||
server:
|
||||
address: ":8080"
|
||||
read_timeout: "15s"
|
||||
write_timeout: "15s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 20
|
||||
min_idle_conns: 10
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "debug"
|
||||
development: true
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: false
|
||||
enable_rate_limiter: true
|
||||
rate_limiter:
|
||||
max: 200
|
||||
expiration: "2m"
|
||||
storage: "redis"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(updatedContent), 0644); err != nil {
|
||||
t.Fatalf("failed to update config file: %v", err)
|
||||
}
|
||||
|
||||
// Wait for watcher to detect and process changes (spec requires detection within 5 seconds)
|
||||
// We use a more aggressive timeout for testing
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Verify config was reloaded
|
||||
reloadedCfg := Get()
|
||||
if reloadedCfg.Logging.Level != "debug" {
|
||||
t.Errorf("expected config hot reload, got logging.level %s instead of debug", reloadedCfg.Logging.Level)
|
||||
}
|
||||
if reloadedCfg.Server.Address != ":8080" {
|
||||
t.Errorf("expected config hot reload, got server.address %s instead of :8080", reloadedCfg.Server.Address)
|
||||
}
|
||||
if reloadedCfg.Redis.PoolSize != 20 {
|
||||
t.Errorf("expected config hot reload, got redis.pool_size %d instead of 20", reloadedCfg.Redis.PoolSize)
|
||||
}
|
||||
|
||||
// Cancel context to stop watcher
|
||||
cancel()
|
||||
|
||||
// Give watcher time to shut down gracefully
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestWatch_InvalidConfigRejected tests that invalid config changes are rejected
|
||||
func TestWatch_InvalidConfigRejected(t *testing.T) {
|
||||
// Reset viper
|
||||
viper.Reset()
|
||||
|
||||
// Create temp config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
// Initial valid config
|
||||
validContent := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
prefork: false
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
dial_timeout: "5s"
|
||||
read_timeout: "3s"
|
||||
write_timeout: "3s"
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
development: false
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
enable_rate_limiter: false
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(validContent), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
|
||||
// Set config path
|
||||
_ = os.Setenv(constants.EnvConfigPath, configFile)
|
||||
defer func() { _ = os.Unsetenv(constants.EnvConfigPath) }()
|
||||
|
||||
// Load initial config
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load initial config: %v", err)
|
||||
}
|
||||
|
||||
initialLevel := cfg.Logging.Level
|
||||
if initialLevel != "info" {
|
||||
t.Fatalf("expected initial logging.level info, got %s", initialLevel)
|
||||
}
|
||||
|
||||
// Create logger for testing
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
// Start watcher
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go Watch(ctx, logger)
|
||||
|
||||
// Give watcher time to initialize
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Write INVALID config (malformed YAML)
|
||||
invalidContent := `
|
||||
server:
|
||||
address: ":3000"
|
||||
invalid yaml syntax here!!!
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(invalidContent), 0644); err != nil {
|
||||
t.Fatalf("failed to write invalid config: %v", err)
|
||||
}
|
||||
|
||||
// Wait for watcher to detect changes
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Verify config was NOT changed (should keep previous valid config)
|
||||
currentCfg := Get()
|
||||
if currentCfg.Logging.Level != initialLevel {
|
||||
t.Errorf("expected config to remain unchanged after invalid update, got logging.level %s instead of %s", currentCfg.Logging.Level, initialLevel)
|
||||
}
|
||||
|
||||
// Restore valid config
|
||||
if err := os.WriteFile(configFile, []byte(validContent), 0644); err != nil {
|
||||
t.Fatalf("failed to restore valid config: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Now write config with validation error (timeout out of range)
|
||||
invalidValidationContent := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "1s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "debug"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(invalidValidationContent), 0644); err != nil {
|
||||
t.Fatalf("failed to write config with validation error: %v", err)
|
||||
}
|
||||
|
||||
// Wait for watcher to detect changes
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Verify config was NOT changed (validation should have failed)
|
||||
finalCfg := Get()
|
||||
if finalCfg.Logging.Level != initialLevel {
|
||||
t.Errorf("expected config to remain unchanged after validation error, got logging.level %s instead of %s", finalCfg.Logging.Level, initialLevel)
|
||||
}
|
||||
if finalCfg.Server.ReadTimeout == 1*time.Second {
|
||||
t.Error("expected config to remain unchanged, but read_timeout was updated to invalid value")
|
||||
}
|
||||
|
||||
// Cancel context
|
||||
cancel()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestWatch_ContextCancellation tests graceful shutdown on context cancellation
|
||||
func TestWatch_ContextCancellation(t *testing.T) {
|
||||
// Reset viper
|
||||
viper.Reset()
|
||||
|
||||
// Create temp config file
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
content := `
|
||||
server:
|
||||
address: ":3000"
|
||||
read_timeout: "10s"
|
||||
write_timeout: "10s"
|
||||
shutdown_timeout: "30s"
|
||||
|
||||
redis:
|
||||
address: "localhost"
|
||||
port: 6379
|
||||
db: 0
|
||||
pool_size: 10
|
||||
min_idle_conns: 5
|
||||
|
||||
logging:
|
||||
level: "info"
|
||||
app_log:
|
||||
filename: "logs/app.log"
|
||||
max_size: 100
|
||||
max_backups: 30
|
||||
max_age: 30
|
||||
compress: true
|
||||
access_log:
|
||||
filename: "logs/access.log"
|
||||
max_size: 500
|
||||
max_backups: 90
|
||||
max_age: 90
|
||||
compress: true
|
||||
|
||||
middleware:
|
||||
enable_auth: true
|
||||
rate_limiter:
|
||||
max: 100
|
||||
expiration: "1m"
|
||||
storage: "memory"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("failed to create config file: %v", err)
|
||||
}
|
||||
|
||||
_ = os.Setenv(constants.EnvConfigPath, configFile)
|
||||
defer func() { _ = os.Unsetenv(constants.EnvConfigPath) }()
|
||||
|
||||
// Load config
|
||||
_, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Create logger
|
||||
logger := zap.NewNop() // Use no-op logger for this test
|
||||
|
||||
// Start watcher with context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
Watch(ctx, logger)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Give watcher time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Cancel context (simulate graceful shutdown)
|
||||
cancel()
|
||||
|
||||
// Wait for watcher to stop (should happen quickly)
|
||||
select {
|
||||
case <-done:
|
||||
// Watcher stopped successfully
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("watcher did not stop within timeout after context cancellation")
|
||||
}
|
||||
}
|
||||
@@ -68,6 +68,14 @@ const (
|
||||
CodeCannotAllocateToSelf = 1075 // 不能分配给自己
|
||||
CodeCannotRecallFromSelf = 1076 // 不能从自己回收
|
||||
|
||||
// 对象存储相关错误 (1090-1099)
|
||||
CodeStorageNotConfigured = 1090 // 对象存储服务未配置
|
||||
CodeStorageUploadFailed = 1091 // 文件上传失败
|
||||
CodeStorageDownloadFailed = 1092 // 文件下载失败
|
||||
CodeStorageFileNotFound = 1093 // 文件不存在
|
||||
CodeStorageInvalidPurpose = 1094 // 不支持的文件用途
|
||||
CodeStorageInvalidFileType = 1095 // 不支持的文件类型
|
||||
|
||||
// 服务端错误 (2000-2999) -> 5xx HTTP 状态码
|
||||
CodeInternalError = 2001 // 内部服务器错误
|
||||
CodeDatabaseError = 2002 // 数据库错误
|
||||
@@ -130,6 +138,12 @@ var allErrorCodes = []int{
|
||||
CodeNotDirectSubordinate,
|
||||
CodeCannotAllocateToSelf,
|
||||
CodeCannotRecallFromSelf,
|
||||
CodeStorageNotConfigured,
|
||||
CodeStorageUploadFailed,
|
||||
CodeStorageDownloadFailed,
|
||||
CodeStorageFileNotFound,
|
||||
CodeStorageInvalidPurpose,
|
||||
CodeStorageInvalidFileType,
|
||||
CodeInternalError,
|
||||
CodeDatabaseError,
|
||||
CodeRedisError,
|
||||
@@ -194,6 +208,12 @@ var errorMessages = map[int]string{
|
||||
CodeNotDirectSubordinate: "只能操作直属下级店铺",
|
||||
CodeCannotAllocateToSelf: "不能分配给自己",
|
||||
CodeCannotRecallFromSelf: "不能从自己回收",
|
||||
CodeStorageNotConfigured: "对象存储服务未配置",
|
||||
CodeStorageUploadFailed: "文件上传失败",
|
||||
CodeStorageDownloadFailed: "文件下载失败",
|
||||
CodeStorageFileNotFound: "文件不存在",
|
||||
CodeStorageInvalidPurpose: "不支持的文件用途",
|
||||
CodeStorageInvalidFileType: "不支持的文件类型",
|
||||
CodeInvalidCredentials: "用户名或密码错误",
|
||||
CodeAccountLocked: "账号已锁定",
|
||||
CodePasswordExpired: "密码已过期",
|
||||
|
||||
@@ -4,6 +4,9 @@ import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/swaggest/openapi-go/openapi3"
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -85,26 +88,37 @@ func (g *Generator) addErrorResponseSchema() {
|
||||
g.Reflector.Spec.ComponentsEns().SchemasEns().WithMapOfSchemaOrRefValuesItem("ErrorResponse", errorSchema)
|
||||
}
|
||||
|
||||
// ptrString 返回字符串指针
|
||||
func ptrString(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// FileUploadField 定义文件上传字段
|
||||
type FileUploadField struct {
|
||||
Name string
|
||||
Description string
|
||||
Required bool
|
||||
}
|
||||
|
||||
// AddOperation 向 OpenAPI 规范中添加一个操作
|
||||
// 参数:
|
||||
// - method: HTTP 方法(GET, POST, PUT, DELETE 等)
|
||||
// - path: API 路径
|
||||
// - summary: 操作摘要
|
||||
// - description: 详细说明,支持 Markdown 语法(可为空)
|
||||
// - input: 请求参数结构体(可为 nil)
|
||||
// - output: 响应结构体(可为 nil)
|
||||
// - tags: 标签列表
|
||||
// - requiresAuth: 是否需要认证
|
||||
func (g *Generator) AddOperation(method, path, summary string, input interface{}, output interface{}, requiresAuth bool, tags ...string) {
|
||||
func (g *Generator) AddOperation(method, path, summary, description string, input interface{}, output interface{}, requiresAuth bool, tags ...string) {
|
||||
op := openapi3.Operation{
|
||||
Summary: &summary,
|
||||
Tags: tags,
|
||||
}
|
||||
|
||||
if description != "" {
|
||||
op.Description = &description
|
||||
}
|
||||
|
||||
// 反射输入 (请求参数/Body)
|
||||
if input != nil {
|
||||
// SetRequest 根据结构体标签自动检测 Body、Query 或 Path 参数
|
||||
@@ -134,6 +148,166 @@ func (g *Generator) AddOperation(method, path, summary string, input interface{}
|
||||
}
|
||||
}
|
||||
|
||||
// AddMultipartOperation 添加支持文件上传的 multipart/form-data 操作
|
||||
func (g *Generator) AddMultipartOperation(method, path, summary, description string, input interface{}, output interface{}, requiresAuth bool, fileFields []FileUploadField, tags ...string) {
|
||||
op := openapi3.Operation{
|
||||
Summary: &summary,
|
||||
Tags: tags,
|
||||
}
|
||||
|
||||
if description != "" {
|
||||
op.Description = &description
|
||||
}
|
||||
|
||||
objectType := openapi3.SchemaType("object")
|
||||
stringType := openapi3.SchemaType("string")
|
||||
integerType := openapi3.SchemaType("integer")
|
||||
binaryFormat := "binary"
|
||||
|
||||
properties := make(map[string]openapi3.SchemaOrRef)
|
||||
var requiredFields []string
|
||||
|
||||
for _, f := range fileFields {
|
||||
properties[f.Name] = openapi3.SchemaOrRef{
|
||||
Schema: &openapi3.Schema{
|
||||
Type: &stringType,
|
||||
Format: &binaryFormat,
|
||||
Description: ptrString(f.Description),
|
||||
},
|
||||
}
|
||||
if f.Required {
|
||||
requiredFields = append(requiredFields, f.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if input != nil {
|
||||
formFields := parseFormFields(input)
|
||||
for _, field := range formFields {
|
||||
var schemaType *openapi3.SchemaType
|
||||
switch field.Type {
|
||||
case "integer":
|
||||
schemaType = &integerType
|
||||
default:
|
||||
schemaType = &stringType
|
||||
}
|
||||
schema := &openapi3.Schema{
|
||||
Type: schemaType,
|
||||
Description: ptrString(field.Description),
|
||||
}
|
||||
if field.Min != nil {
|
||||
schema.Minimum = field.Min
|
||||
}
|
||||
if field.MaxLength != nil {
|
||||
schema.MaxLength = field.MaxLength
|
||||
}
|
||||
properties[field.Name] = openapi3.SchemaOrRef{Schema: schema}
|
||||
if field.Required {
|
||||
requiredFields = append(requiredFields, field.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
op.RequestBody = &openapi3.RequestBodyOrRef{
|
||||
RequestBody: &openapi3.RequestBody{
|
||||
Required: ptrBool(true),
|
||||
Content: map[string]openapi3.MediaType{
|
||||
"multipart/form-data": {
|
||||
Schema: &openapi3.SchemaOrRef{
|
||||
Schema: &openapi3.Schema{
|
||||
Type: &objectType,
|
||||
Properties: properties,
|
||||
Required: requiredFields,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if output != nil {
|
||||
if err := g.Reflector.SetJSONResponse(&op, output, 200); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
if requiresAuth {
|
||||
g.addSecurityRequirement(&op)
|
||||
}
|
||||
|
||||
g.addStandardErrorResponses(&op, requiresAuth)
|
||||
|
||||
if err := g.Reflector.Spec.AddOperation(method, path, op); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func ptrBool(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
||||
type formFieldInfo struct {
|
||||
Name string
|
||||
Type string
|
||||
Description string
|
||||
Required bool
|
||||
Min *float64
|
||||
MaxLength *int64
|
||||
}
|
||||
|
||||
func parseFormFields(input interface{}) []formFieldInfo {
|
||||
var fields []formFieldInfo
|
||||
t := reflect.TypeOf(input)
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
if t.Kind() != reflect.Struct {
|
||||
return fields
|
||||
}
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
|
||||
formTag := field.Tag.Get("form")
|
||||
if formTag == "" || formTag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
info := formFieldInfo{
|
||||
Name: formTag,
|
||||
Description: field.Tag.Get("description"),
|
||||
}
|
||||
|
||||
switch field.Type.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
info.Type = "integer"
|
||||
default:
|
||||
info.Type = "string"
|
||||
}
|
||||
|
||||
validateTag := field.Tag.Get("validate")
|
||||
if strings.Contains(validateTag, "required") {
|
||||
info.Required = true
|
||||
}
|
||||
|
||||
if minStr := field.Tag.Get("minimum"); minStr != "" {
|
||||
if min, err := strconv.ParseFloat(minStr, 64); err == nil {
|
||||
info.Min = &min
|
||||
}
|
||||
}
|
||||
|
||||
if maxLenStr := field.Tag.Get("maxLength"); maxLenStr != "" {
|
||||
if maxLen, err := strconv.ParseInt(maxLenStr, 10, 64); err == nil {
|
||||
info.MaxLength = &maxLen
|
||||
}
|
||||
}
|
||||
|
||||
fields = append(fields, info)
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
// addSecurityRequirement 为操作添加认证要求
|
||||
func (g *Generator) addSecurityRequirement(op *openapi3.Operation) {
|
||||
op.Security = []map[string][]string{
|
||||
|
||||
@@ -9,23 +9,24 @@ import (
|
||||
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
|
||||
"github.com/break/junhong_cmp_fiber/internal/task"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/constants"
|
||||
"github.com/break/junhong_cmp_fiber/pkg/storage"
|
||||
)
|
||||
|
||||
// Handler 任务处理器注册
|
||||
type Handler struct {
|
||||
mux *asynq.ServeMux
|
||||
logger *zap.Logger
|
||||
db *gorm.DB
|
||||
redis *redis.Client
|
||||
mux *asynq.ServeMux
|
||||
logger *zap.Logger
|
||||
db *gorm.DB
|
||||
redis *redis.Client
|
||||
storage *storage.Service
|
||||
}
|
||||
|
||||
// NewHandler 创建任务处理器
|
||||
func NewHandler(db *gorm.DB, redis *redis.Client, logger *zap.Logger) *Handler {
|
||||
func NewHandler(db *gorm.DB, redis *redis.Client, storageSvc *storage.Service, logger *zap.Logger) *Handler {
|
||||
return &Handler{
|
||||
mux: asynq.NewServeMux(),
|
||||
logger: logger,
|
||||
db: db,
|
||||
redis: redis,
|
||||
mux: asynq.NewServeMux(),
|
||||
logger: logger,
|
||||
db: db,
|
||||
redis: redis,
|
||||
storage: storageSvc,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +54,7 @@ func (h *Handler) RegisterHandlers() *asynq.ServeMux {
|
||||
func (h *Handler) registerIotCardImportHandler() {
|
||||
importTaskStore := postgres.NewIotCardImportTaskStore(h.db, h.redis)
|
||||
iotCardStore := postgres.NewIotCardStore(h.db, h.redis)
|
||||
iotCardImportHandler := task.NewIotCardImportHandler(h.db, h.redis, importTaskStore, iotCardStore, h.logger)
|
||||
iotCardImportHandler := task.NewIotCardImportHandler(h.db, h.redis, importTaskStore, iotCardStore, h.storage, h.logger)
|
||||
|
||||
h.mux.HandleFunc(constants.TaskTypeIotCardImport, iotCardImportHandler.HandleIotCardImport)
|
||||
h.logger.Info("注册 IoT 卡导入任务处理器", zap.String("task_type", constants.TaskTypeIotCardImport))
|
||||
|
||||
184
pkg/storage/s3.go
Normal file
184
pkg/storage/s3.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/config"
|
||||
)
|
||||
|
||||
type S3Provider struct {
|
||||
client *s3.S3
|
||||
uploader *s3manager.Uploader
|
||||
bucket string
|
||||
tempDir string
|
||||
}
|
||||
|
||||
func NewS3Provider(cfg *config.StorageConfig) (*S3Provider, error) {
|
||||
if cfg.S3.Endpoint == "" || cfg.S3.Bucket == "" {
|
||||
return nil, fmt.Errorf("S3 配置不完整:endpoint 和 bucket 必填")
|
||||
}
|
||||
|
||||
if cfg.S3.AccessKeyID == "" || cfg.S3.SecretAccessKey == "" {
|
||||
return nil, fmt.Errorf("S3 凭证未配置:access_key_id 和 secret_access_key 必填")
|
||||
}
|
||||
|
||||
sess, err := session.NewSession(&aws.Config{
|
||||
Endpoint: aws.String(cfg.S3.Endpoint),
|
||||
Region: aws.String(cfg.S3.Region),
|
||||
Credentials: credentials.NewStaticCredentials(cfg.S3.AccessKeyID, cfg.S3.SecretAccessKey, ""),
|
||||
DisableSSL: aws.Bool(!cfg.S3.UseSSL),
|
||||
S3ForcePathStyle: aws.Bool(cfg.S3.PathStyle),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 S3 session 失败: %w", err)
|
||||
}
|
||||
|
||||
tempDir := cfg.TempDir
|
||||
if tempDir == "" {
|
||||
tempDir = "/tmp/junhong-storage"
|
||||
}
|
||||
|
||||
return &S3Provider{
|
||||
client: s3.New(sess),
|
||||
uploader: s3manager.NewUploader(sess),
|
||||
bucket: cfg.S3.Bucket,
|
||||
tempDir: tempDir,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *S3Provider) Upload(ctx context.Context, key string, reader io.Reader, contentType string) error {
|
||||
input := &s3manager.UploadInput{
|
||||
Bucket: aws.String(p.bucket),
|
||||
Key: aws.String(key),
|
||||
Body: reader,
|
||||
ContentType: aws.String(contentType),
|
||||
}
|
||||
|
||||
_, err := p.uploader.UploadWithContext(ctx, input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("上传文件失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *S3Provider) Download(ctx context.Context, key string, writer io.Writer) error {
|
||||
input := &s3.GetObjectInput{
|
||||
Bucket: aws.String(p.bucket),
|
||||
Key: aws.String(key),
|
||||
}
|
||||
|
||||
result, err := p.client.GetObjectWithContext(ctx, input)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "NoSuchKey") {
|
||||
return fmt.Errorf("文件不存在: %s", key)
|
||||
}
|
||||
return fmt.Errorf("下载文件失败: %w", err)
|
||||
}
|
||||
defer result.Body.Close()
|
||||
|
||||
_, err = io.Copy(writer, result.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("写入文件内容失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *S3Provider) DownloadToTemp(ctx context.Context, key string) (string, func(), error) {
|
||||
ext := filepath.Ext(key)
|
||||
if ext == "" {
|
||||
ext = ".tmp"
|
||||
}
|
||||
|
||||
tempFile, err := os.CreateTemp(p.tempDir, "download-*"+ext)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("创建临时文件失败: %w", err)
|
||||
}
|
||||
tempPath := tempFile.Name()
|
||||
|
||||
cleanup := func() {
|
||||
_ = os.Remove(tempPath)
|
||||
}
|
||||
|
||||
if err := p.Download(ctx, key, tempFile); err != nil {
|
||||
tempFile.Close()
|
||||
cleanup()
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if err := tempFile.Close(); err != nil {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("关闭临时文件失败: %w", err)
|
||||
}
|
||||
|
||||
return tempPath, cleanup, nil
|
||||
}
|
||||
|
||||
func (p *S3Provider) Delete(ctx context.Context, key string) error {
|
||||
input := &s3.DeleteObjectInput{
|
||||
Bucket: aws.String(p.bucket),
|
||||
Key: aws.String(key),
|
||||
}
|
||||
|
||||
_, err := p.client.DeleteObjectWithContext(ctx, input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除文件失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *S3Provider) Exists(ctx context.Context, key string) (bool, error) {
|
||||
input := &s3.HeadObjectInput{
|
||||
Bucket: aws.String(p.bucket),
|
||||
Key: aws.String(key),
|
||||
}
|
||||
|
||||
_, err := p.client.HeadObjectWithContext(ctx, input)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "NotFound") || strings.Contains(err.Error(), "404") {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("检查文件存在性失败: %w", err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *S3Provider) GetUploadURL(ctx context.Context, key string, contentType string, expires time.Duration) (string, error) {
|
||||
input := &s3.PutObjectInput{
|
||||
Bucket: aws.String(p.bucket),
|
||||
Key: aws.String(key),
|
||||
ContentType: aws.String(contentType),
|
||||
}
|
||||
|
||||
req, _ := p.client.PutObjectRequest(input)
|
||||
url, err := req.Presign(expires)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("生成上传预签名 URL 失败: %w", err)
|
||||
}
|
||||
return url, nil
|
||||
}
|
||||
|
||||
func (p *S3Provider) GetDownloadURL(ctx context.Context, key string, expires time.Duration) (string, error) {
|
||||
input := &s3.GetObjectInput{
|
||||
Bucket: aws.String(p.bucket),
|
||||
Key: aws.String(key),
|
||||
}
|
||||
|
||||
req, _ := p.client.GetObjectRequest(input)
|
||||
url, err := req.Presign(expires)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("生成下载预签名 URL 失败: %w", err)
|
||||
}
|
||||
return url, nil
|
||||
}
|
||||
112
pkg/storage/service.go
Normal file
112
pkg/storage/service.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/break/junhong_cmp_fiber/pkg/config"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
provider Provider
|
||||
config *config.StorageConfig
|
||||
}
|
||||
|
||||
func NewService(provider Provider, cfg *config.StorageConfig) *Service {
|
||||
return &Service{
|
||||
provider: provider,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) GenerateFileKey(purpose, fileName string) (string, error) {
|
||||
mapping, ok := PurposeMappings[purpose]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("不支持的文件用途: %s", purpose)
|
||||
}
|
||||
|
||||
ext := filepath.Ext(fileName)
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
id := uuid.New().String()
|
||||
|
||||
key := fmt.Sprintf("%s/%04d/%02d/%02d/%s%s",
|
||||
mapping.Prefix,
|
||||
now.Year(),
|
||||
now.Month(),
|
||||
now.Day(),
|
||||
id,
|
||||
strings.ToLower(ext),
|
||||
)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (s *Service) GetUploadURL(ctx context.Context, purpose, fileName, contentType string) (*PresignResult, error) {
|
||||
fileKey, err := s.GenerateFileKey(purpose, fileName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if contentType == "" {
|
||||
if mapping, ok := PurposeMappings[purpose]; ok && mapping.ContentType != "" {
|
||||
contentType = mapping.ContentType
|
||||
} else {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
}
|
||||
|
||||
expires := s.config.Presign.UploadExpires
|
||||
if expires == 0 {
|
||||
expires = 15 * time.Minute
|
||||
}
|
||||
|
||||
url, err := s.provider.GetUploadURL(ctx, fileKey, contentType, expires)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PresignResult{
|
||||
URL: url,
|
||||
FileKey: fileKey,
|
||||
ExpiresIn: int(expires.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) GetDownloadURL(ctx context.Context, fileKey string) (*PresignResult, error) {
|
||||
expires := s.config.Presign.DownloadExpires
|
||||
if expires == 0 {
|
||||
expires = 24 * time.Hour
|
||||
}
|
||||
|
||||
url, err := s.provider.GetDownloadURL(ctx, fileKey, expires)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PresignResult{
|
||||
URL: url,
|
||||
FileKey: fileKey,
|
||||
ExpiresIn: int(expires.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) DownloadToTemp(ctx context.Context, fileKey string) (string, func(), error) {
|
||||
return s.provider.DownloadToTemp(ctx, fileKey)
|
||||
}
|
||||
|
||||
func (s *Service) Provider() Provider {
|
||||
return s.provider
|
||||
}
|
||||
|
||||
func (s *Service) Bucket() string {
|
||||
return s.config.S3.Bucket
|
||||
}
|
||||
17
pkg/storage/storage.go
Normal file
17
pkg/storage/storage.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Provider interface {
|
||||
Upload(ctx context.Context, key string, reader io.Reader, contentType string) error
|
||||
Download(ctx context.Context, key string, writer io.Writer) error
|
||||
DownloadToTemp(ctx context.Context, key string) (localPath string, cleanup func(), err error)
|
||||
Delete(ctx context.Context, key string) error
|
||||
Exists(ctx context.Context, key string) (bool, error)
|
||||
GetUploadURL(ctx context.Context, key string, contentType string, expires time.Duration) (string, error)
|
||||
GetDownloadURL(ctx context.Context, key string, expires time.Duration) (string, error)
|
||||
}
|
||||
18
pkg/storage/types.go
Normal file
18
pkg/storage/types.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package storage
|
||||
|
||||
type PresignResult struct {
|
||||
URL string `json:"url"`
|
||||
FileKey string `json:"file_key"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
type PurposeMapping struct {
|
||||
Prefix string
|
||||
ContentType string
|
||||
}
|
||||
|
||||
var PurposeMappings = map[string]PurposeMapping{
|
||||
"iot_import": {Prefix: "imports", ContentType: "text/csv"},
|
||||
"export": {Prefix: "exports", ContentType: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"},
|
||||
"attachment": {Prefix: "attachments", ContentType: ""},
|
||||
}
|
||||
@@ -2,33 +2,49 @@ package utils
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CardInfo 卡信息(ICCID + MSISDN)
|
||||
type CardInfo struct {
|
||||
ICCID string
|
||||
MSISDN string
|
||||
}
|
||||
|
||||
// CSVParseResult CSV 解析结果
|
||||
type CSVParseResult struct {
|
||||
ICCIDs []string
|
||||
Cards []CardInfo
|
||||
TotalCount int
|
||||
ParseErrors []CSVParseError
|
||||
}
|
||||
|
||||
// CSVParseError CSV 解析错误
|
||||
type CSVParseError struct {
|
||||
Line int
|
||||
ICCID string
|
||||
MSISDN string
|
||||
Reason string
|
||||
}
|
||||
|
||||
func ParseICCIDFromCSV(reader io.Reader) (*CSVParseResult, error) {
|
||||
// ErrInvalidCSVFormat CSV 格式错误
|
||||
var ErrInvalidCSVFormat = errors.New("CSV 文件格式错误:缺少 MSISDN 列,文件必须包含 ICCID 和 MSISDN 两列")
|
||||
|
||||
// ParseCardCSV 解析包含 ICCID 和 MSISDN 两列的 CSV 文件
|
||||
func ParseCardCSV(reader io.Reader) (*CSVParseResult, error) {
|
||||
csvReader := csv.NewReader(reader)
|
||||
csvReader.FieldsPerRecord = -1
|
||||
csvReader.TrimLeadingSpace = true
|
||||
|
||||
result := &CSVParseResult{
|
||||
ICCIDs: make([]string, 0),
|
||||
Cards: make([]CardInfo, 0),
|
||||
ParseErrors: make([]CSVParseError, 0),
|
||||
}
|
||||
|
||||
lineNum := 0
|
||||
headerSkipped := false
|
||||
|
||||
for {
|
||||
record, err := csvReader.Read()
|
||||
if err == io.EOF {
|
||||
@@ -48,23 +64,69 @@ func ParseICCIDFromCSV(reader io.Reader) (*CSVParseResult, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
iccid := strings.TrimSpace(record[0])
|
||||
if iccid == "" {
|
||||
if len(record) < 2 {
|
||||
if lineNum == 1 && !headerSkipped {
|
||||
firstCol := strings.TrimSpace(record[0])
|
||||
if isICCIDHeader(firstCol) {
|
||||
return nil, ErrInvalidCSVFormat
|
||||
}
|
||||
}
|
||||
result.ParseErrors = append(result.ParseErrors, CSVParseError{
|
||||
Line: lineNum,
|
||||
ICCID: strings.TrimSpace(record[0]),
|
||||
Reason: "列数不足:缺少 MSISDN 列",
|
||||
})
|
||||
result.TotalCount++
|
||||
continue
|
||||
}
|
||||
|
||||
if lineNum == 1 && isHeader(iccid) {
|
||||
iccid := strings.TrimSpace(record[0])
|
||||
msisdn := strings.TrimSpace(record[1])
|
||||
|
||||
if lineNum == 1 && !headerSkipped && isHeader(iccid, msisdn) {
|
||||
headerSkipped = true
|
||||
continue
|
||||
}
|
||||
|
||||
result.TotalCount++
|
||||
result.ICCIDs = append(result.ICCIDs, iccid)
|
||||
|
||||
if iccid == "" {
|
||||
result.ParseErrors = append(result.ParseErrors, CSVParseError{
|
||||
Line: lineNum,
|
||||
MSISDN: msisdn,
|
||||
Reason: "ICCID 不能为空",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if msisdn == "" {
|
||||
result.ParseErrors = append(result.ParseErrors, CSVParseError{
|
||||
Line: lineNum,
|
||||
ICCID: iccid,
|
||||
Reason: "MSISDN 不能为空",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
result.Cards = append(result.Cards, CardInfo{
|
||||
ICCID: iccid,
|
||||
MSISDN: msisdn,
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func isHeader(value string) bool {
|
||||
func isICCIDHeader(value string) bool {
|
||||
lower := strings.ToLower(value)
|
||||
return lower == "iccid" || lower == "卡号" || lower == "号码"
|
||||
}
|
||||
|
||||
func isMSISDNHeader(value string) bool {
|
||||
lower := strings.ToLower(value)
|
||||
return lower == "msisdn" || lower == "接入号" || lower == "手机号" || lower == "电话" || lower == "号码"
|
||||
}
|
||||
|
||||
func isHeader(col1, col2 string) bool {
|
||||
return isICCIDHeader(col1) && isMSISDNHeader(col2)
|
||||
}
|
||||
|
||||
@@ -8,89 +8,133 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseICCIDFromCSV(t *testing.T) {
|
||||
func TestParseCardCSV(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
csvContent string
|
||||
wantICCIDs []string
|
||||
wantCards []CardInfo
|
||||
wantTotalCount int
|
||||
wantErrorCount int
|
||||
wantError error
|
||||
}{
|
||||
{
|
||||
name: "单列ICCID无表头",
|
||||
csvContent: "89860012345678901234\n89860012345678901235\n89860012345678901236",
|
||||
wantICCIDs: []string{"89860012345678901234", "89860012345678901235", "89860012345678901236"},
|
||||
wantTotalCount: 3,
|
||||
wantErrorCount: 0,
|
||||
},
|
||||
{
|
||||
name: "单列ICCID有表头-iccid",
|
||||
csvContent: "iccid\n89860012345678901234\n89860012345678901235",
|
||||
wantICCIDs: []string{"89860012345678901234", "89860012345678901235"},
|
||||
name: "标准双列无表头",
|
||||
csvContent: "89860012345678901234,13800000001\n89860012345678901235,13800000002",
|
||||
wantCards: []CardInfo{
|
||||
{ICCID: "89860012345678901234", MSISDN: "13800000001"},
|
||||
{ICCID: "89860012345678901235", MSISDN: "13800000002"},
|
||||
},
|
||||
wantTotalCount: 2,
|
||||
wantErrorCount: 0,
|
||||
},
|
||||
{
|
||||
name: "单列ICCID有表头-ICCID大写",
|
||||
csvContent: "ICCID\n89860012345678901234",
|
||||
wantICCIDs: []string{"89860012345678901234"},
|
||||
name: "标准双列有表头-英文",
|
||||
csvContent: "iccid,msisdn\n89860012345678901234,13800000001\n89860012345678901235,13800000002",
|
||||
wantCards: []CardInfo{
|
||||
{ICCID: "89860012345678901234", MSISDN: "13800000001"},
|
||||
{ICCID: "89860012345678901235", MSISDN: "13800000002"},
|
||||
},
|
||||
wantTotalCount: 2,
|
||||
wantErrorCount: 0,
|
||||
},
|
||||
{
|
||||
name: "标准双列有表头-中文",
|
||||
csvContent: "卡号,接入号\n89860012345678901234,13800000001",
|
||||
wantCards: []CardInfo{
|
||||
{ICCID: "89860012345678901234", MSISDN: "13800000001"},
|
||||
},
|
||||
wantTotalCount: 1,
|
||||
wantErrorCount: 0,
|
||||
},
|
||||
{
|
||||
name: "单列ICCID有表头-卡号",
|
||||
csvContent: "卡号\n89860012345678901234",
|
||||
wantICCIDs: []string{"89860012345678901234"},
|
||||
name: "标准双列有表头-手机号",
|
||||
csvContent: "ICCID,手机号\n89860012345678901234,13800000001",
|
||||
wantCards: []CardInfo{
|
||||
{ICCID: "89860012345678901234", MSISDN: "13800000001"},
|
||||
},
|
||||
wantTotalCount: 1,
|
||||
wantErrorCount: 0,
|
||||
},
|
||||
{
|
||||
name: "单列ICCID有表头-号码",
|
||||
csvContent: "号码\n89860012345678901234",
|
||||
wantICCIDs: []string{"89860012345678901234"},
|
||||
wantTotalCount: 1,
|
||||
name: "单列CSV格式拒绝-有表头",
|
||||
csvContent: "iccid\n89860012345678901234",
|
||||
wantCards: nil,
|
||||
wantTotalCount: 0,
|
||||
wantErrorCount: 0,
|
||||
wantError: ErrInvalidCSVFormat,
|
||||
},
|
||||
{
|
||||
name: "单列CSV格式-无表头记录错误",
|
||||
csvContent: "89860012345678901234\n89860012345678901235",
|
||||
wantCards: []CardInfo{},
|
||||
wantTotalCount: 2,
|
||||
wantErrorCount: 2,
|
||||
},
|
||||
{
|
||||
name: "MSISDN为空记录失败",
|
||||
csvContent: "iccid,msisdn\n89860012345678901234,13800000001\n89860012345678901235,",
|
||||
wantCards: []CardInfo{{ICCID: "89860012345678901234", MSISDN: "13800000001"}},
|
||||
wantTotalCount: 2,
|
||||
wantErrorCount: 1,
|
||||
},
|
||||
{
|
||||
name: "ICCID为空记录失败",
|
||||
csvContent: "iccid,msisdn\n89860012345678901234,13800000001\n,13800000002",
|
||||
wantCards: []CardInfo{{ICCID: "89860012345678901234", MSISDN: "13800000001"}},
|
||||
wantTotalCount: 2,
|
||||
wantErrorCount: 1,
|
||||
},
|
||||
{
|
||||
name: "空文件",
|
||||
csvContent: "",
|
||||
wantICCIDs: []string{},
|
||||
wantCards: []CardInfo{},
|
||||
wantTotalCount: 0,
|
||||
wantErrorCount: 0,
|
||||
},
|
||||
{
|
||||
name: "只有表头",
|
||||
csvContent: "iccid",
|
||||
wantICCIDs: []string{},
|
||||
csvContent: "iccid,msisdn",
|
||||
wantCards: []CardInfo{},
|
||||
wantTotalCount: 0,
|
||||
wantErrorCount: 0,
|
||||
},
|
||||
{
|
||||
name: "包含空行",
|
||||
csvContent: "89860012345678901234\n\n89860012345678901235\n \n89860012345678901236",
|
||||
wantICCIDs: []string{"89860012345678901234", "89860012345678901235", "89860012345678901236"},
|
||||
wantTotalCount: 3,
|
||||
wantErrorCount: 0,
|
||||
},
|
||||
{
|
||||
name: "ICCID前后有空格",
|
||||
csvContent: " 89860012345678901234 \n89860012345678901235",
|
||||
wantICCIDs: []string{"89860012345678901234", "89860012345678901235"},
|
||||
name: "包含空行",
|
||||
csvContent: "89860012345678901234,13800000001\n\n89860012345678901235,13800000002",
|
||||
wantCards: []CardInfo{
|
||||
{ICCID: "89860012345678901234", MSISDN: "13800000001"},
|
||||
{ICCID: "89860012345678901235", MSISDN: "13800000002"},
|
||||
},
|
||||
wantTotalCount: 2,
|
||||
wantErrorCount: 0,
|
||||
},
|
||||
{
|
||||
name: "多列CSV只取第一列",
|
||||
csvContent: "89860012345678901234,额外数据,更多数据\n89860012345678901235,忽略,忽略",
|
||||
wantICCIDs: []string{"89860012345678901234", "89860012345678901235"},
|
||||
name: "ICCID和MSISDN前后有空格",
|
||||
csvContent: " 89860012345678901234 , 13800000001 ",
|
||||
wantCards: []CardInfo{
|
||||
{ICCID: "89860012345678901234", MSISDN: "13800000001"},
|
||||
},
|
||||
wantTotalCount: 1,
|
||||
wantErrorCount: 0,
|
||||
},
|
||||
{
|
||||
name: "多于两列只取前两列",
|
||||
csvContent: "89860012345678901234,13800000001,额外数据\n89860012345678901235,13800000002,忽略",
|
||||
wantCards: []CardInfo{
|
||||
{ICCID: "89860012345678901234", MSISDN: "13800000001"},
|
||||
{ICCID: "89860012345678901235", MSISDN: "13800000002"},
|
||||
},
|
||||
wantTotalCount: 2,
|
||||
wantErrorCount: 0,
|
||||
},
|
||||
{
|
||||
name: "Windows换行符CRLF",
|
||||
csvContent: "89860012345678901234\r\n89860012345678901235\r\n89860012345678901236",
|
||||
wantICCIDs: []string{"89860012345678901234", "89860012345678901235", "89860012345678901236"},
|
||||
wantTotalCount: 3,
|
||||
name: "Windows换行符CRLF",
|
||||
csvContent: "89860012345678901234,13800000001\r\n89860012345678901235,13800000002",
|
||||
wantCards: []CardInfo{
|
||||
{ICCID: "89860012345678901234", MSISDN: "13800000001"},
|
||||
{ICCID: "89860012345678901235", MSISDN: "13800000002"},
|
||||
},
|
||||
wantTotalCount: 2,
|
||||
wantErrorCount: 0,
|
||||
},
|
||||
}
|
||||
@@ -98,34 +142,78 @@ func TestParseICCIDFromCSV(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := strings.NewReader(tt.csvContent)
|
||||
result, err := ParseICCIDFromCSV(reader)
|
||||
result, err := ParseCardCSV(reader)
|
||||
|
||||
if tt.wantError != nil {
|
||||
require.ErrorIs(t, err, tt.wantError)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantICCIDs, result.ICCIDs, "ICCIDs 不匹配")
|
||||
assert.Equal(t, tt.wantCards, result.Cards, "Cards 不匹配")
|
||||
assert.Equal(t, tt.wantTotalCount, result.TotalCount, "TotalCount 不匹配")
|
||||
assert.Equal(t, tt.wantErrorCount, len(result.ParseErrors), "ParseErrors 数量不匹配")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCardCSV_ErrorDetails(t *testing.T) {
|
||||
t.Run("MSISDN为空时记录详细错误", func(t *testing.T) {
|
||||
csvContent := "iccid,msisdn\n89860012345678901234,"
|
||||
reader := strings.NewReader(csvContent)
|
||||
result, err := ParseCardCSV(reader)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.ParseErrors, 1)
|
||||
assert.Equal(t, 2, result.ParseErrors[0].Line)
|
||||
assert.Equal(t, "89860012345678901234", result.ParseErrors[0].ICCID)
|
||||
assert.Equal(t, "MSISDN 不能为空", result.ParseErrors[0].Reason)
|
||||
})
|
||||
|
||||
t.Run("ICCID为空时记录详细错误", func(t *testing.T) {
|
||||
csvContent := "iccid,msisdn\n,13800000001"
|
||||
reader := strings.NewReader(csvContent)
|
||||
result, err := ParseCardCSV(reader)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.ParseErrors, 1)
|
||||
assert.Equal(t, 2, result.ParseErrors[0].Line)
|
||||
assert.Equal(t, "13800000001", result.ParseErrors[0].MSISDN)
|
||||
assert.Equal(t, "ICCID 不能为空", result.ParseErrors[0].Reason)
|
||||
})
|
||||
|
||||
t.Run("列数不足时记录详细错误", func(t *testing.T) {
|
||||
csvContent := "89860012345678901234"
|
||||
reader := strings.NewReader(csvContent)
|
||||
result, err := ParseCardCSV(reader)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.ParseErrors, 1)
|
||||
assert.Equal(t, 1, result.ParseErrors[0].Line)
|
||||
assert.Equal(t, "89860012345678901234", result.ParseErrors[0].ICCID)
|
||||
assert.Contains(t, result.ParseErrors[0].Reason, "列数不足")
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
value string
|
||||
col1 string
|
||||
col2 string
|
||||
expected bool
|
||||
}{
|
||||
{"iccid", true},
|
||||
{"ICCID", true},
|
||||
{"Iccid", true},
|
||||
{"卡号", true},
|
||||
{"号码", true},
|
||||
{"89860012345678901234", false},
|
||||
{"", false},
|
||||
{"id", false},
|
||||
{"card", false},
|
||||
{"iccid", "msisdn", true},
|
||||
{"ICCID", "MSISDN", true},
|
||||
{"卡号", "接入号", true},
|
||||
{"号码", "手机号", true},
|
||||
{"iccid", "电话", true},
|
||||
{"89860012345678901234", "13800000001", false},
|
||||
{"iccid", "", false},
|
||||
{"", "msisdn", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.value, func(t *testing.T) {
|
||||
result := isHeader(tt.value)
|
||||
t.Run(tt.col1+"_"+tt.col2, func(t *testing.T) {
|
||||
result := isHeader(tt.col1, tt.col2)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user