feat: 添加环境变量管理工具和部署配置改版
All checks were successful
构建并部署到测试环境(无 SSH) / build-and-deploy (push) Successful in 5m33s

主要改动:
- 新增交互式环境配置脚本 (scripts/setup-env.sh)
- 新增本地启动快捷脚本 (scripts/run-local.sh)
- 新增环境变量模板文件 (.env.example)
- 部署模式改版:使用嵌入式配置 + 环境变量覆盖
- 添加对象存储功能支持
- 改进 IoT 卡片导入任务
- 优化 OpenAPI 文档生成
- 删除旧的配置文件,改用嵌入式默认配置
This commit is contained in:
2026-01-26 10:28:29 +08:00
parent 194078674a
commit 45aa7deb87
94 changed files with 6532 additions and 1967 deletions

View File

@@ -0,0 +1,69 @@
package bootstrap
import (
"fmt"
"os"
"path/filepath"
"github.com/break/junhong_cmp_fiber/pkg/config"
"go.uber.org/zap"
)
type DirectoryResult struct {
TempDir string
AppLogDir string
AccessLogDir string
Fallbacks []string
}
func EnsureDirectories(cfg *config.Config, logger *zap.Logger) (*DirectoryResult, error) {
result := &DirectoryResult{}
directories := []struct {
path string
configKey string
resultPtr *string
}{
{cfg.Storage.TempDir, "storage.temp_dir", &result.TempDir},
{filepath.Dir(cfg.Logging.AppLog.Filename), "logging.app_log.filename", &result.AppLogDir},
{filepath.Dir(cfg.Logging.AccessLog.Filename), "logging.access_log.filename", &result.AccessLogDir},
}
for _, dir := range directories {
if dir.path == "" || dir.path == "." {
continue
}
actualPath, fallback, err := ensureDirectory(dir.path, logger)
if err != nil {
return nil, fmt.Errorf("创建目录 %s (%s) 失败: %w", dir.path, dir.configKey, err)
}
*dir.resultPtr = actualPath
if fallback {
result.Fallbacks = append(result.Fallbacks, actualPath)
}
}
return result, nil
}
func ensureDirectory(path string, logger *zap.Logger) (actualPath string, fallback bool, err error) {
if err := os.MkdirAll(path, 0755); err != nil {
if os.IsPermission(err) {
fallbackPath := filepath.Join(os.TempDir(), "junhong", filepath.Base(path))
if mkErr := os.MkdirAll(fallbackPath, 0755); mkErr != nil {
return "", false, fmt.Errorf("原路径 %s 权限不足,降级路径 %s 也创建失败: %w", path, fallbackPath, mkErr)
}
if logger != nil {
logger.Warn("目录权限不足,使用降级路径",
zap.String("original", path),
zap.String("fallback", fallbackPath),
)
}
return fallbackPath, true, nil
}
return "", false, err
}
return path, false, nil
}

View File

@@ -0,0 +1,100 @@
package bootstrap
import (
"os"
"path/filepath"
"testing"
"github.com/break/junhong_cmp_fiber/pkg/config"
)
func TestEnsureDirectories_Success(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Storage: config.StorageConfig{
TempDir: filepath.Join(tmpDir, "storage"),
},
Logging: config.LoggingConfig{
AppLog: config.LogRotationConfig{Filename: filepath.Join(tmpDir, "logs", "app.log")},
AccessLog: config.LogRotationConfig{Filename: filepath.Join(tmpDir, "logs", "access.log")},
},
}
result, err := EnsureDirectories(cfg, nil)
if err != nil {
t.Fatalf("EnsureDirectories() 失败: %v", err)
}
if result.TempDir != cfg.Storage.TempDir {
t.Errorf("TempDir 期望 %s, 实际 %s", cfg.Storage.TempDir, result.TempDir)
}
if result.AppLogDir != filepath.Join(tmpDir, "logs") {
t.Errorf("AppLogDir 期望 %s, 实际 %s", filepath.Join(tmpDir, "logs"), result.AppLogDir)
}
if _, err := os.Stat(result.TempDir); os.IsNotExist(err) {
t.Error("TempDir 目录未创建")
}
if _, err := os.Stat(result.AppLogDir); os.IsNotExist(err) {
t.Error("AppLogDir 目录未创建")
}
}
func TestEnsureDirectories_ExistingDirs(t *testing.T) {
tmpDir := t.TempDir()
storageDir := filepath.Join(tmpDir, "storage")
os.MkdirAll(storageDir, 0755)
cfg := &config.Config{
Storage: config.StorageConfig{TempDir: storageDir},
Logging: config.LoggingConfig{
AppLog: config.LogRotationConfig{Filename: filepath.Join(tmpDir, "logs", "app.log")},
AccessLog: config.LogRotationConfig{Filename: filepath.Join(tmpDir, "logs", "access.log")},
},
}
result, err := EnsureDirectories(cfg, nil)
if err != nil {
t.Fatalf("EnsureDirectories() 失败: %v", err)
}
if result.TempDir != storageDir {
t.Errorf("已存在目录应返回原路径")
}
}
func TestEnsureDirectories_EmptyPaths(t *testing.T) {
cfg := &config.Config{
Storage: config.StorageConfig{TempDir: ""},
Logging: config.LoggingConfig{
AppLog: config.LogRotationConfig{Filename: ""},
AccessLog: config.LogRotationConfig{Filename: ""},
},
}
result, err := EnsureDirectories(cfg, nil)
if err != nil {
t.Fatalf("EnsureDirectories() 空路径时不应失败: %v", err)
}
if len(result.Fallbacks) != 0 {
t.Error("空路径不应产生降级")
}
}
func TestEnsureDirectory_Fallback(t *testing.T) {
path, fallback, err := ensureDirectory("/root/no_permission_dir_test_"+t.Name(), nil)
if err != nil {
if os.Getuid() == 0 {
t.Skip("以 root 身份运行,跳过权限测试")
}
t.Skip("无法测试权限降级场景")
}
if fallback {
if !filepath.HasPrefix(path, os.TempDir()) {
t.Errorf("降级路径应在临时目录下,实际: %s", path)
}
os.RemoveAll(path)
}
}

View File

@@ -3,6 +3,7 @@ package config
import (
"errors"
"fmt"
"strings"
"sync/atomic"
"time"
)
@@ -21,6 +22,7 @@ type Config struct {
SMS SMSConfig `mapstructure:"sms"`
JWT JWTConfig `mapstructure:"jwt"`
DefaultAdmin DefaultAdminConfig `mapstructure:"default_admin"`
Storage StorageConfig `mapstructure:"storage"`
}
// ServerConfig HTTP 服务器配置
@@ -120,7 +122,61 @@ type DefaultAdminConfig struct {
Phone string `mapstructure:"phone"`
}
// Validate 验证配置
// StorageConfig 对象存储配置
type StorageConfig struct {
Provider string `mapstructure:"provider"` // 存储提供商s3
S3 S3Config `mapstructure:"s3"` // S3 兼容存储配置
Presign PresignConfig `mapstructure:"presign"` // 预签名 URL 配置
TempDir string `mapstructure:"temp_dir"` // 临时文件目录
}
// S3Config S3 兼容存储配置
type S3Config struct {
Endpoint string `mapstructure:"endpoint"` // 服务端点http://obs-helf.cucloud.cn
Region string `mapstructure:"region"` // 区域cn-langfang-2
Bucket string `mapstructure:"bucket"` // 存储桶名称
AccessKeyID string `mapstructure:"access_key_id"` // 访问密钥 ID
SecretAccessKey string `mapstructure:"secret_access_key"` // 访问密钥
UseSSL bool `mapstructure:"use_ssl"` // 是否使用 SSL
PathStyle bool `mapstructure:"path_style"` // 是否使用路径风格(兼容性)
}
// PresignConfig 预签名 URL 配置
type PresignConfig struct {
UploadExpires time.Duration `mapstructure:"upload_expires"` // 上传 URL 有效期默认15m
DownloadExpires time.Duration `mapstructure:"download_expires"` // 下载 URL 有效期默认24h
}
type requiredField struct {
value string
name string
envName string
}
func (c *Config) ValidateRequired() error {
fields := []requiredField{
{c.Database.Host, "database.host", "JUNHONG_DATABASE_HOST"},
{c.Database.User, "database.user", "JUNHONG_DATABASE_USER"},
{c.Database.Password, "database.password", "JUNHONG_DATABASE_PASSWORD"},
{c.Database.DBName, "database.dbname", "JUNHONG_DATABASE_DBNAME"},
{c.Redis.Address, "redis.address", "JUNHONG_REDIS_ADDRESS"},
{c.JWT.SecretKey, "jwt.secret_key", "JUNHONG_JWT_SECRET_KEY"},
}
var missing []string
for _, f := range fields {
if f.value == "" {
missing = append(missing, fmt.Sprintf(" - %s (环境变量: %s)", f.name, f.envName))
}
}
if len(missing) > 0 {
return fmt.Errorf("缺少必填配置项:\n%s", strings.Join(missing, "\n"))
}
return nil
}
func (c *Config) Validate() error {
// 服务器验证
if c.Server.Address == "" {
@@ -184,28 +240,24 @@ func (c *Config) Validate() error {
return fmt.Errorf("invalid configuration: middleware.rate_limiter.storage: invalid storage type (current value: %s, expected: memory or redis)", c.Middleware.RateLimiter.Storage)
}
// 短信服务验证
if c.SMS.GatewayURL == "" {
return fmt.Errorf("invalid configuration: sms.gateway_url: must be non-empty (current value: empty)")
}
if c.SMS.Username == "" {
return fmt.Errorf("invalid configuration: sms.username: must be non-empty (current value: empty)")
}
if c.SMS.Password == "" {
return fmt.Errorf("invalid configuration: sms.password: must be non-empty (current value: empty)")
}
if c.SMS.Signature == "" {
return fmt.Errorf("invalid configuration: sms.signature: must be non-empty (current value: empty)")
}
if c.SMS.Timeout < 5*time.Second || c.SMS.Timeout > 60*time.Second {
return fmt.Errorf("invalid configuration: sms.timeout: duration out of range (current value: %s, expected: 5s-60s)", c.SMS.Timeout)
// 短信服务验证(可选,配置 GatewayURL 时才验证其他字段)
if c.SMS.GatewayURL != "" {
if c.SMS.Username == "" {
return fmt.Errorf("invalid configuration: sms.username: must be non-empty when gateway_url is configured")
}
if c.SMS.Password == "" {
return fmt.Errorf("invalid configuration: sms.password: must be non-empty when gateway_url is configured")
}
if c.SMS.Signature == "" {
return fmt.Errorf("invalid configuration: sms.signature: must be non-empty when gateway_url is configured")
}
if c.SMS.Timeout > 0 && (c.SMS.Timeout < 5*time.Second || c.SMS.Timeout > 60*time.Second) {
return fmt.Errorf("invalid configuration: sms.timeout: duration out of range (current value: %s, expected: 5s-60s)", c.SMS.Timeout)
}
}
// JWT 验证
if c.JWT.SecretKey == "" {
return fmt.Errorf("invalid configuration: jwt.secret_key: must be non-empty (current value: empty)")
}
if len(c.JWT.SecretKey) < 32 {
// JWT 验证SecretKey 必填验证在 ValidateRequired 中处理)
if len(c.JWT.SecretKey) > 0 && len(c.JWT.SecretKey) < 32 {
return fmt.Errorf("invalid configuration: jwt.secret_key: secret key too short (current length: %d, expected: >= 32)", len(c.JWT.SecretKey))
}
if c.JWT.TokenDuration < 1*time.Hour || c.JWT.TokenDuration > 720*time.Hour {

View File

@@ -51,6 +51,11 @@ func TestConfig_Validate(t *testing.T) {
Storage: "memory",
},
},
JWT: JWTConfig{
TokenDuration: 24 * time.Hour,
AccessTokenTTL: 24 * time.Hour,
RefreshTokenTTL: 168 * time.Hour,
},
},
wantErr: false,
},
@@ -582,6 +587,11 @@ func TestSet(t *testing.T) {
MaxSize: 500,
},
},
JWT: JWTConfig{
TokenDuration: 24 * time.Hour,
AccessTokenTTL: 24 * time.Hour,
RefreshTokenTTL: 168 * time.Hour,
},
}
err := Set(validCfg)

View File

@@ -0,0 +1,106 @@
# 默认配置文件(嵌入二进制)
# 敏感配置和必填配置为空,必须通过环境变量设置
# 环境变量格式JUNHONG_{SECTION}_{KEY}
server:
address: ":3000"
read_timeout: "30s"
write_timeout: "30s"
shutdown_timeout: "30s"
prefork: false
# 数据库配置(必填项需通过环境变量设置)
database:
host: "" # 必填JUNHONG_DATABASE_HOST
port: 5432
user: "" # 必填JUNHONG_DATABASE_USER
password: "" # 必填JUNHONG_DATABASE_PASSWORD敏感
dbname: "" # 必填JUNHONG_DATABASE_DBNAME
sslmode: "disable"
max_open_conns: 25
max_idle_conns: 10
conn_max_lifetime: "5m"
# Redis 配置(必填项需通过环境变量设置)
redis:
address: "" # 必填JUNHONG_REDIS_ADDRESS
port: 6379
password: "" # 可选JUNHONG_REDIS_PASSWORD敏感
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
# 对象存储配置
storage:
provider: "s3"
temp_dir: "/tmp/junhong-storage"
s3:
endpoint: "" # 可选JUNHONG_STORAGE_S3_ENDPOINT
region: "" # 可选JUNHONG_STORAGE_S3_REGION
bucket: "" # 可选JUNHONG_STORAGE_S3_BUCKET
access_key_id: "" # 可选JUNHONG_STORAGE_S3_ACCESS_KEY_ID敏感
secret_access_key: "" # 可选JUNHONG_STORAGE_S3_SECRET_ACCESS_KEY敏感
use_ssl: false
path_style: true
presign:
upload_expires: "15m"
download_expires: "24h"
# 日志配置
logging:
level: "info"
development: false
app_log:
filename: "/app/logs/app.log"
max_size: 100
max_backups: 3
max_age: 7
compress: true
access_log:
filename: "/app/logs/access.log"
max_size: 100
max_backups: 3
max_age: 7
compress: true
# 任务队列配置
queue:
concurrency: 10
queues:
critical: 6
default: 3
low: 1
retry_max: 5
timeout: "10m"
# JWT 配置(必填项需通过环境变量设置)
jwt:
secret_key: "" # 必填JUNHONG_JWT_SECRET_KEY敏感
token_duration: "24h"
access_token_ttl: "24h"
refresh_token_ttl: "168h"
# 中间件配置
middleware:
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
# 短信服务配置
sms:
gateway_url: "" # 可选JUNHONG_SMS_GATEWAY_URL
username: "" # 可选JUNHONG_SMS_USERNAME
password: "" # 可选JUNHONG_SMS_PASSWORD敏感
signature: "" # 可选JUNHONG_SMS_SIGNATURE
timeout: "10s"
# 默认超级管理员配置(可选)
default_admin:
username: ""
password: ""
phone: ""

17
pkg/config/embedded.go Normal file
View File

@@ -0,0 +1,17 @@
package config
import (
"bytes"
"embed"
)
//go:embed defaults/config.yaml
var defaultConfigFS embed.FS
func getEmbeddedConfig() (*bytes.Reader, error) {
data, err := defaultConfigFS.ReadFile("defaults/config.yaml")
if err != nil {
return nil, err
}
return bytes.NewReader(data), nil
}

View File

@@ -2,92 +2,120 @@ package config
import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/spf13/viper"
)
// Load 从文件和环境变量加载配置
const envPrefix = "JUNHONG"
func Load() (*Config, error) {
// 确定配置路径
configPath := os.Getenv(constants.EnvConfigPath)
if configPath == "" {
configPath = constants.DefaultConfigPath
}
v := viper.New()
// 检查环境特定配置dev, staging, prod
configEnv := os.Getenv(constants.EnvConfigEnv)
if configEnv != "" {
// 优先尝试环境特定配置
envConfigPath := fmt.Sprintf("configs/config.%s.yaml", configEnv)
if _, err := os.Stat(envConfigPath); err == nil {
configPath = envConfigPath
}
}
// 设置 Viper
viper.SetConfigFile(configPath)
viper.SetConfigType("yaml")
// 启用环境变量覆盖
viper.AutomaticEnv()
viper.SetEnvPrefix("APP")
// 读取配置文件
if err := viper.ReadInConfig(); err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
// 反序列化到 Config 结构体
cfg := &Config{}
if err := viper.Unmarshal(cfg); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
// 验证配置
if err := cfg.Validate(); err != nil {
return nil, err
}
// 设为全局配置
globalConfig.Store(cfg)
return cfg, nil
}
// Reload 重新加载当前配置文件
func Reload() (*Config, error) {
if err := viper.ReadInConfig(); err != nil {
return nil, fmt.Errorf("failed to reload config: %w", err)
}
cfg := &Config{}
if err := viper.Unmarshal(cfg); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
// 设置前验证
if err := cfg.Validate(); err != nil {
return nil, err
}
// 原子交换
globalConfig.Store(cfg)
return cfg, nil
}
// GetConfigPath 返回当前已加载配置文件的绝对路径
func GetConfigPath() string {
configFile := viper.ConfigFileUsed()
if configFile == "" {
return ""
}
absPath, err := filepath.Abs(configFile)
embeddedReader, err := getEmbeddedConfig()
if err != nil {
return configFile
return nil, fmt.Errorf("读取嵌入配置失败: %w", err)
}
v.SetConfigType("yaml")
if err := v.ReadConfig(embeddedReader); err != nil {
return nil, fmt.Errorf("解析嵌入配置失败: %w", err)
}
v.SetEnvPrefix(envPrefix)
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.AutomaticEnv()
bindEnvVariables(v)
cfg := &Config{}
if err := v.Unmarshal(cfg); err != nil {
return nil, fmt.Errorf("反序列化配置失败: %w", err)
}
if err := cfg.ValidateRequired(); err != nil {
return nil, err
}
if err := cfg.Validate(); err != nil {
return nil, err
}
globalConfig.Store(cfg)
return cfg, nil
}
func bindEnvVariables(v *viper.Viper) {
bindings := []string{
"server.address",
"server.read_timeout",
"server.write_timeout",
"server.shutdown_timeout",
"server.prefork",
"database.host",
"database.port",
"database.user",
"database.password",
"database.dbname",
"database.sslmode",
"database.max_open_conns",
"database.max_idle_conns",
"database.conn_max_lifetime",
"redis.address",
"redis.port",
"redis.password",
"redis.db",
"redis.pool_size",
"redis.min_idle_conns",
"redis.dial_timeout",
"redis.read_timeout",
"redis.write_timeout",
"storage.provider",
"storage.temp_dir",
"storage.s3.endpoint",
"storage.s3.region",
"storage.s3.bucket",
"storage.s3.access_key_id",
"storage.s3.secret_access_key",
"storage.s3.use_ssl",
"storage.s3.path_style",
"storage.presign.upload_expires",
"storage.presign.download_expires",
"logging.level",
"logging.development",
"logging.app_log.filename",
"logging.app_log.max_size",
"logging.app_log.max_backups",
"logging.app_log.max_age",
"logging.app_log.compress",
"logging.access_log.filename",
"logging.access_log.max_size",
"logging.access_log.max_backups",
"logging.access_log.max_age",
"logging.access_log.compress",
"queue.concurrency",
"queue.retry_max",
"queue.timeout",
"jwt.secret_key",
"jwt.token_duration",
"jwt.access_token_ttl",
"jwt.refresh_token_ttl",
"middleware.enable_rate_limiter",
"middleware.rate_limiter.max",
"middleware.rate_limiter.expiration",
"middleware.rate_limiter.storage",
"sms.gateway_url",
"sms.username",
"sms.password",
"sms.signature",
"sms.timeout",
"default_admin.username",
"default_admin.password",
"default_admin.phone",
}
for _, key := range bindings {
_ = v.BindEnv(key)
}
return absPath
}

View File

@@ -2,650 +2,219 @@ package config
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/spf13/viper"
)
// TestLoad tests the config loading functionality
func TestLoad(t *testing.T) {
func TestLoad_EmbeddedConfig(t *testing.T) {
clearEnvVars(t)
setRequiredEnvVars(t)
defer clearEnvVars(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() 失败: %v", err)
}
if cfg.Server.Address != ":3000" {
t.Errorf("server.address 期望 :3000, 实际 %s", cfg.Server.Address)
}
if cfg.Server.ReadTimeout != 30*time.Second {
t.Errorf("server.read_timeout 期望 30s, 实际 %v", cfg.Server.ReadTimeout)
}
if cfg.Logging.Level != "info" {
t.Errorf("logging.level 期望 info, 实际 %s", cfg.Logging.Level)
}
}
func TestLoad_EnvOverride(t *testing.T) {
clearEnvVars(t)
setRequiredEnvVars(t)
defer clearEnvVars(t)
os.Setenv("JUNHONG_SERVER_ADDRESS", ":8080")
os.Setenv("JUNHONG_LOGGING_LEVEL", "debug")
defer func() {
os.Unsetenv("JUNHONG_SERVER_ADDRESS")
os.Unsetenv("JUNHONG_LOGGING_LEVEL")
}()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() 失败: %v", err)
}
if cfg.Server.Address != ":8080" {
t.Errorf("server.address 期望 :8080, 实际 %s", cfg.Server.Address)
}
if cfg.Logging.Level != "debug" {
t.Errorf("logging.level 期望 debug, 实际 %s", cfg.Logging.Level)
}
}
func TestLoad_MissingRequired(t *testing.T) {
clearEnvVars(t)
defer clearEnvVars(t)
_, err := Load()
if err == nil {
t.Fatal("Load() 缺少必填配置时应返回错误")
}
expectedFields := []string{"database.host", "database.user", "database.password", "database.dbname", "redis.address", "jwt.secret_key"}
for _, field := range expectedFields {
if !containsString(err.Error(), field) {
t.Errorf("错误信息应包含 %q, 实际: %s", field, err.Error())
}
}
}
func TestLoad_PartialRequired(t *testing.T) {
clearEnvVars(t)
defer clearEnvVars(t)
os.Setenv("JUNHONG_DATABASE_HOST", "localhost")
os.Setenv("JUNHONG_DATABASE_USER", "user")
_, err := Load()
if err == nil {
t.Fatal("Load() 部分必填配置缺失时应返回错误")
}
if containsString(err.Error(), "database.host") {
t.Error("database.host 已设置,不应在错误信息中")
}
if containsString(err.Error(), "database.user") {
t.Error("database.user 已设置,不应在错误信息中")
}
if !containsString(err.Error(), "database.password") {
t.Error("database.password 未设置,应在错误信息中")
}
}
func TestLoad_GlobalConfig(t *testing.T) {
clearEnvVars(t)
setRequiredEnvVars(t)
defer clearEnvVars(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() 失败: %v", err)
}
globalCfg := Get()
if globalCfg == nil {
t.Fatal("Get() 返回 nil")
}
if globalCfg.Server.Address != cfg.Server.Address {
t.Errorf("全局配置与返回配置不一致")
}
}
func TestValidateRequired(t *testing.T) {
tests := []struct {
name string
setupEnv func()
cleanupEnv func()
createConfig func(t *testing.T) string
wantErr bool
validateFunc func(t *testing.T, cfg *Config)
name string
cfg *Config
wantErr bool
}{
{
name: "valid default config",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigPath, "")
_ = os.Setenv(constants.EnvConfigEnv, "")
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigPath)
_ = os.Unsetenv(constants.EnvConfigEnv)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "info"
development: false
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
// Set as default config path
_ = os.Setenv(constants.EnvConfigPath, configFile)
return configFile
name: "all required set",
cfg: &Config{
Database: DatabaseConfig{
Host: "localhost",
User: "user",
Password: "pass",
DBName: "db",
},
Redis: RedisConfig{Address: "localhost"},
JWT: JWTConfig{SecretKey: "12345678901234567890123456789012"},
},
wantErr: false,
validateFunc: func(t *testing.T, cfg *Config) {
if cfg.Server.Address != ":3000" {
t.Errorf("expected server.address :3000, got %s", cfg.Server.Address)
}
if cfg.Server.ReadTimeout != 10*time.Second {
t.Errorf("expected read_timeout 10s, got %v", cfg.Server.ReadTimeout)
}
if cfg.Redis.Address != "localhost" {
t.Errorf("expected redis.address localhost, got %s", cfg.Redis.Address)
}
if cfg.Redis.Port != 6379 {
t.Errorf("expected redis.port 6379, got %d", cfg.Redis.Port)
}
if cfg.Redis.PoolSize != 10 {
t.Errorf("expected redis.pool_size 10, got %d", cfg.Redis.PoolSize)
}
if cfg.Logging.Level != "info" {
t.Errorf("expected logging.level info, got %s", cfg.Logging.Level)
}
},
},
{
name: "environment-specific config (dev)",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigEnv, "dev")
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigEnv)
_ = os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
// Create configs directory in temp
tmpDir := t.TempDir()
configsDir := filepath.Join(tmpDir, "configs")
if err := os.MkdirAll(configsDir, 0755); err != nil {
t.Fatalf("failed to create configs dir: %v", err)
}
// Create dev config
devConfigFile := filepath.Join(configsDir, "config.dev.yaml")
content := `
server:
address: ":8080"
read_timeout: "15s"
write_timeout: "15s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 1
pool_size: 5
min_idle_conns: 2
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "debug"
development: true
app_log:
filename: "logs/app.log"
max_size: 50
max_backups: 10
max_age: 7
compress: false
access_log:
filename: "logs/access.log"
max_size: 100
max_backups: 30
max_age: 30
compress: false
middleware:
enable_rate_limiter: false
rate_limiter:
max: 50
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(devConfigFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create dev config file: %v", err)
}
// Change to tmpDir so relative path works
originalWd, _ := os.Getwd()
_ = os.Chdir(tmpDir)
t.Cleanup(func() { _ = os.Chdir(originalWd) })
return devConfigFile
},
wantErr: false,
validateFunc: func(t *testing.T, cfg *Config) {
if cfg.Server.Address != ":8080" {
t.Errorf("expected server.address :8080, got %s", cfg.Server.Address)
}
if cfg.Redis.DB != 1 {
t.Errorf("expected redis.db 1, got %d", cfg.Redis.DB)
}
if cfg.Logging.Level != "debug" {
t.Errorf("expected logging.level debug, got %s", cfg.Logging.Level)
}
name: "missing database host",
cfg: &Config{
Database: DatabaseConfig{
User: "user",
Password: "pass",
DBName: "db",
},
Redis: RedisConfig{Address: "localhost"},
JWT: JWTConfig{SecretKey: "12345678901234567890123456789012"},
},
wantErr: true,
},
{
name: "invalid YAML syntax",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigPath, "")
_ = os.Setenv(constants.EnvConfigEnv, "")
name: "missing redis address",
cfg: &Config{
Database: DatabaseConfig{
Host: "localhost",
User: "user",
Password: "pass",
DBName: "db",
},
Redis: RedisConfig{},
JWT: JWTConfig{SecretKey: "12345678901234567890123456789012"},
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigPath)
_ = os.Unsetenv(constants.EnvConfigEnv)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
invalid yaml syntax here!!!
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
_ = os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
wantErr: true,
},
{
name: "validation error - invalid server address",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigPath, "")
name: "missing jwt secret",
cfg: &Config{
Database: DatabaseConfig{
Host: "localhost",
User: "user",
Password: "pass",
DBName: "db",
},
Redis: RedisConfig{Address: "localhost"},
JWT: JWTConfig{},
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ""
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
_ = os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
},
{
name: "validation error - timeout out of range",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigPath, "")
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "1s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
_ = os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
},
{
name: "validation error - invalid redis port",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigPath, "")
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 99999
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
_ = os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset viper for each test
viper.Reset()
// Setup environment
if tt.setupEnv != nil {
tt.setupEnv()
}
// Create config file
if tt.createConfig != nil {
tt.createConfig(t)
}
// Cleanup after test
if tt.cleanupEnv != nil {
defer tt.cleanupEnv()
}
// Load config
cfg, err := Load()
// Check error expectation
err := tt.cfg.ValidateRequired()
if (err != nil) != tt.wantErr {
t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr)
return
}
// Validate config if no error expected
if !tt.wantErr && tt.validateFunc != nil {
tt.validateFunc(t, cfg)
t.Errorf("ValidateRequired() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
// TestReload tests the config reload functionality
func TestReload(t *testing.T) {
// Reset viper
viper.Reset()
func setRequiredEnvVars(t *testing.T) {
t.Helper()
os.Setenv("JUNHONG_DATABASE_HOST", "localhost")
os.Setenv("JUNHONG_DATABASE_USER", "testuser")
os.Setenv("JUNHONG_DATABASE_PASSWORD", "testpass")
os.Setenv("JUNHONG_DATABASE_DBNAME", "testdb")
os.Setenv("JUNHONG_REDIS_ADDRESS", "localhost")
os.Setenv("JUNHONG_JWT_SECRET_KEY", "12345678901234567890123456789012")
}
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
// Initial config
initialContent := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "info"
development: false
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(initialContent), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
func clearEnvVars(t *testing.T) {
t.Helper()
envVars := []string{
"JUNHONG_DATABASE_HOST",
"JUNHONG_DATABASE_PORT",
"JUNHONG_DATABASE_USER",
"JUNHONG_DATABASE_PASSWORD",
"JUNHONG_DATABASE_DBNAME",
"JUNHONG_REDIS_ADDRESS",
"JUNHONG_REDIS_PORT",
"JUNHONG_REDIS_PASSWORD",
"JUNHONG_JWT_SECRET_KEY",
"JUNHONG_SERVER_ADDRESS",
"JUNHONG_LOGGING_LEVEL",
}
// Set config path
_ = os.Setenv(constants.EnvConfigPath, configFile)
defer func() { _ = os.Unsetenv(constants.EnvConfigPath) }()
// Load initial config
cfg, err := Load()
if err != nil {
t.Fatalf("failed to load initial config: %v", err)
}
// Verify initial values
if cfg.Logging.Level != "info" {
t.Errorf("expected initial logging.level info, got %s", cfg.Logging.Level)
}
if cfg.Server.Address != ":3000" {
t.Errorf("expected initial server.address :3000, got %s", cfg.Server.Address)
}
// Modify config file
updatedContent := `
server:
address: ":8080"
read_timeout: "15s"
write_timeout: "15s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 20
min_idle_conns: 10
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "debug"
development: true
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: false
enable_rate_limiter: true
rate_limiter:
max: 200
expiration: "2m"
storage: "redis"
`
if err := os.WriteFile(configFile, []byte(updatedContent), 0644); err != nil {
t.Fatalf("failed to update config file: %v", err)
}
// Reload config
newCfg, err := Reload()
if err != nil {
t.Fatalf("failed to reload config: %v", err)
}
// Verify updated values
if newCfg.Logging.Level != "debug" {
t.Errorf("expected updated logging.level debug, got %s", newCfg.Logging.Level)
}
if newCfg.Server.Address != ":8080" {
t.Errorf("expected updated server.address :8080, got %s", newCfg.Server.Address)
}
if newCfg.Redis.PoolSize != 20 {
t.Errorf("expected updated redis.pool_size 20, got %d", newCfg.Redis.PoolSize)
}
if newCfg.Middleware.EnableRateLimiter != true {
t.Errorf("expected updated enable_rate_limiter true, got %v", newCfg.Middleware.EnableRateLimiter)
}
// Verify global config was updated
globalCfg := Get()
if globalCfg.Logging.Level != "debug" {
t.Errorf("expected global config updated, got logging.level %s", globalCfg.Logging.Level)
for _, v := range envVars {
os.Unsetenv(v)
}
}
// TestGetConfigPath tests the GetConfigPath function
func TestGetConfigPath(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
_ = os.Setenv(constants.EnvConfigPath, configFile)
defer func() { _ = os.Unsetenv(constants.EnvConfigPath) }()
// Load config
_, err := Load()
if err != nil {
t.Fatalf("failed to load config: %v", err)
}
// Get config path
path := GetConfigPath()
if path == "" {
t.Error("expected non-empty config path")
}
// Verify it's an absolute path
if !filepath.IsAbs(path) {
t.Errorf("expected absolute path, got %s", path)
}
func containsString(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && (s[:len(substr)] == substr || containsString(s[1:], substr)))
}

View File

@@ -1,43 +0,0 @@
package config
import (
"context"
"github.com/fsnotify/fsnotify"
"github.com/spf13/viper"
"go.uber.org/zap"
)
// Watch 监听配置文件变化
// 运行直到上下文被取消
func Watch(ctx context.Context, logger *zap.Logger) {
viper.WatchConfig()
viper.OnConfigChange(func(e fsnotify.Event) {
select {
case <-ctx.Done():
return // 如果上下文被取消则停止处理
default:
logger.Info("配置文件已更改", zap.String("file", e.Name))
// 尝试重新加载
newConfig, err := Reload()
if err != nil {
logger.Error("重新加载配置失败,保留先前配置",
zap.Error(err),
zap.String("file", e.Name),
)
return
}
logger.Info("配置重新加载成功",
zap.String("file", e.Name),
zap.String("server_address", newConfig.Server.Address),
zap.String("log_level", newConfig.Logging.Level),
)
}
})
// 阻塞直到上下文被取消
<-ctx.Done()
logger.Info("配置监听器已停止")
}

View File

@@ -1,422 +0,0 @@
package config
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/spf13/viper"
"go.uber.org/zap"
"go.uber.org/zap/zaptest"
)
// TestWatch tests the config hot reload watcher
func TestWatch(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
// Initial config
initialContent := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "info"
development: false
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(initialContent), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
// Set config path
_ = os.Setenv(constants.EnvConfigPath, configFile)
defer func() { _ = os.Unsetenv(constants.EnvConfigPath) }()
// Load initial config
cfg, err := Load()
if err != nil {
t.Fatalf("failed to load initial config: %v", err)
}
// Verify initial values
if cfg.Logging.Level != "info" {
t.Fatalf("expected initial logging.level info, got %s", cfg.Logging.Level)
}
// Create logger for testing
logger := zaptest.NewLogger(t)
// Start watcher in goroutine with context
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go Watch(ctx, logger)
// Give watcher time to initialize
time.Sleep(100 * time.Millisecond)
// Modify config file to trigger hot reload
updatedContent := `
server:
address: ":8080"
read_timeout: "15s"
write_timeout: "15s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 20
min_idle_conns: 10
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "debug"
development: true
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: false
enable_rate_limiter: true
rate_limiter:
max: 200
expiration: "2m"
storage: "redis"
`
if err := os.WriteFile(configFile, []byte(updatedContent), 0644); err != nil {
t.Fatalf("failed to update config file: %v", err)
}
// Wait for watcher to detect and process changes (spec requires detection within 5 seconds)
// We use a more aggressive timeout for testing
time.Sleep(2 * time.Second)
// Verify config was reloaded
reloadedCfg := Get()
if reloadedCfg.Logging.Level != "debug" {
t.Errorf("expected config hot reload, got logging.level %s instead of debug", reloadedCfg.Logging.Level)
}
if reloadedCfg.Server.Address != ":8080" {
t.Errorf("expected config hot reload, got server.address %s instead of :8080", reloadedCfg.Server.Address)
}
if reloadedCfg.Redis.PoolSize != 20 {
t.Errorf("expected config hot reload, got redis.pool_size %d instead of 20", reloadedCfg.Redis.PoolSize)
}
// Cancel context to stop watcher
cancel()
// Give watcher time to shut down gracefully
time.Sleep(100 * time.Millisecond)
}
// TestWatch_InvalidConfigRejected tests that invalid config changes are rejected
func TestWatch_InvalidConfigRejected(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
// Initial valid config
validContent := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "info"
development: false
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(validContent), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
// Set config path
_ = os.Setenv(constants.EnvConfigPath, configFile)
defer func() { _ = os.Unsetenv(constants.EnvConfigPath) }()
// Load initial config
cfg, err := Load()
if err != nil {
t.Fatalf("failed to load initial config: %v", err)
}
initialLevel := cfg.Logging.Level
if initialLevel != "info" {
t.Fatalf("expected initial logging.level info, got %s", initialLevel)
}
// Create logger for testing
logger := zaptest.NewLogger(t)
// Start watcher
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go Watch(ctx, logger)
// Give watcher time to initialize
time.Sleep(100 * time.Millisecond)
// Write INVALID config (malformed YAML)
invalidContent := `
server:
address: ":3000"
invalid yaml syntax here!!!
`
if err := os.WriteFile(configFile, []byte(invalidContent), 0644); err != nil {
t.Fatalf("failed to write invalid config: %v", err)
}
// Wait for watcher to detect changes
time.Sleep(2 * time.Second)
// Verify config was NOT changed (should keep previous valid config)
currentCfg := Get()
if currentCfg.Logging.Level != initialLevel {
t.Errorf("expected config to remain unchanged after invalid update, got logging.level %s instead of %s", currentCfg.Logging.Level, initialLevel)
}
// Restore valid config
if err := os.WriteFile(configFile, []byte(validContent), 0644); err != nil {
t.Fatalf("failed to restore valid config: %v", err)
}
time.Sleep(500 * time.Millisecond)
// Now write config with validation error (timeout out of range)
invalidValidationContent := `
server:
address: ":3000"
read_timeout: "1s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "debug"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(invalidValidationContent), 0644); err != nil {
t.Fatalf("failed to write config with validation error: %v", err)
}
// Wait for watcher to detect changes
time.Sleep(2 * time.Second)
// Verify config was NOT changed (validation should have failed)
finalCfg := Get()
if finalCfg.Logging.Level != initialLevel {
t.Errorf("expected config to remain unchanged after validation error, got logging.level %s instead of %s", finalCfg.Logging.Level, initialLevel)
}
if finalCfg.Server.ReadTimeout == 1*time.Second {
t.Error("expected config to remain unchanged, but read_timeout was updated to invalid value")
}
// Cancel context
cancel()
time.Sleep(100 * time.Millisecond)
}
// TestWatch_ContextCancellation tests graceful shutdown on context cancellation
func TestWatch_ContextCancellation(t *testing.T) {
// Reset viper
viper.Reset()
// Create temp config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
_ = os.Setenv(constants.EnvConfigPath, configFile)
defer func() { _ = os.Unsetenv(constants.EnvConfigPath) }()
// Load config
_, err := Load()
if err != nil {
t.Fatalf("failed to load config: %v", err)
}
// Create logger
logger := zap.NewNop() // Use no-op logger for this test
// Start watcher with context
ctx, cancel := context.WithCancel(context.Background())
done := make(chan bool)
go func() {
Watch(ctx, logger)
done <- true
}()
// Give watcher time to start
time.Sleep(100 * time.Millisecond)
// Cancel context (simulate graceful shutdown)
cancel()
// Wait for watcher to stop (should happen quickly)
select {
case <-done:
// Watcher stopped successfully
case <-time.After(2 * time.Second):
t.Error("watcher did not stop within timeout after context cancellation")
}
}

View File

@@ -68,6 +68,14 @@ const (
CodeCannotAllocateToSelf = 1075 // 不能分配给自己
CodeCannotRecallFromSelf = 1076 // 不能从自己回收
// 对象存储相关错误 (1090-1099)
CodeStorageNotConfigured = 1090 // 对象存储服务未配置
CodeStorageUploadFailed = 1091 // 文件上传失败
CodeStorageDownloadFailed = 1092 // 文件下载失败
CodeStorageFileNotFound = 1093 // 文件不存在
CodeStorageInvalidPurpose = 1094 // 不支持的文件用途
CodeStorageInvalidFileType = 1095 // 不支持的文件类型
// 服务端错误 (2000-2999) -> 5xx HTTP 状态码
CodeInternalError = 2001 // 内部服务器错误
CodeDatabaseError = 2002 // 数据库错误
@@ -130,6 +138,12 @@ var allErrorCodes = []int{
CodeNotDirectSubordinate,
CodeCannotAllocateToSelf,
CodeCannotRecallFromSelf,
CodeStorageNotConfigured,
CodeStorageUploadFailed,
CodeStorageDownloadFailed,
CodeStorageFileNotFound,
CodeStorageInvalidPurpose,
CodeStorageInvalidFileType,
CodeInternalError,
CodeDatabaseError,
CodeRedisError,
@@ -194,6 +208,12 @@ var errorMessages = map[int]string{
CodeNotDirectSubordinate: "只能操作直属下级店铺",
CodeCannotAllocateToSelf: "不能分配给自己",
CodeCannotRecallFromSelf: "不能从自己回收",
CodeStorageNotConfigured: "对象存储服务未配置",
CodeStorageUploadFailed: "文件上传失败",
CodeStorageDownloadFailed: "文件下载失败",
CodeStorageFileNotFound: "文件不存在",
CodeStorageInvalidPurpose: "不支持的文件用途",
CodeStorageInvalidFileType: "不支持的文件类型",
CodeInvalidCredentials: "用户名或密码错误",
CodeAccountLocked: "账号已锁定",
CodePasswordExpired: "密码已过期",

View File

@@ -4,6 +4,9 @@ import (
"encoding/json"
"os"
"path/filepath"
"reflect"
"strconv"
"strings"
"github.com/swaggest/openapi-go/openapi3"
"gopkg.in/yaml.v3"
@@ -85,26 +88,37 @@ func (g *Generator) addErrorResponseSchema() {
g.Reflector.Spec.ComponentsEns().SchemasEns().WithMapOfSchemaOrRefValuesItem("ErrorResponse", errorSchema)
}
// ptrString 返回字符串指针
func ptrString(s string) *string {
return &s
}
// FileUploadField 定义文件上传字段
type FileUploadField struct {
Name string
Description string
Required bool
}
// AddOperation 向 OpenAPI 规范中添加一个操作
// 参数:
// - method: HTTP 方法GET, POST, PUT, DELETE 等)
// - path: API 路径
// - summary: 操作摘要
// - description: 详细说明,支持 Markdown 语法(可为空)
// - input: 请求参数结构体(可为 nil
// - output: 响应结构体(可为 nil
// - tags: 标签列表
// - requiresAuth: 是否需要认证
func (g *Generator) AddOperation(method, path, summary string, input interface{}, output interface{}, requiresAuth bool, tags ...string) {
func (g *Generator) AddOperation(method, path, summary, description string, input interface{}, output interface{}, requiresAuth bool, tags ...string) {
op := openapi3.Operation{
Summary: &summary,
Tags: tags,
}
if description != "" {
op.Description = &description
}
// 反射输入 (请求参数/Body)
if input != nil {
// SetRequest 根据结构体标签自动检测 Body、Query 或 Path 参数
@@ -134,6 +148,166 @@ func (g *Generator) AddOperation(method, path, summary string, input interface{}
}
}
// AddMultipartOperation 添加支持文件上传的 multipart/form-data 操作
func (g *Generator) AddMultipartOperation(method, path, summary, description string, input interface{}, output interface{}, requiresAuth bool, fileFields []FileUploadField, tags ...string) {
op := openapi3.Operation{
Summary: &summary,
Tags: tags,
}
if description != "" {
op.Description = &description
}
objectType := openapi3.SchemaType("object")
stringType := openapi3.SchemaType("string")
integerType := openapi3.SchemaType("integer")
binaryFormat := "binary"
properties := make(map[string]openapi3.SchemaOrRef)
var requiredFields []string
for _, f := range fileFields {
properties[f.Name] = openapi3.SchemaOrRef{
Schema: &openapi3.Schema{
Type: &stringType,
Format: &binaryFormat,
Description: ptrString(f.Description),
},
}
if f.Required {
requiredFields = append(requiredFields, f.Name)
}
}
if input != nil {
formFields := parseFormFields(input)
for _, field := range formFields {
var schemaType *openapi3.SchemaType
switch field.Type {
case "integer":
schemaType = &integerType
default:
schemaType = &stringType
}
schema := &openapi3.Schema{
Type: schemaType,
Description: ptrString(field.Description),
}
if field.Min != nil {
schema.Minimum = field.Min
}
if field.MaxLength != nil {
schema.MaxLength = field.MaxLength
}
properties[field.Name] = openapi3.SchemaOrRef{Schema: schema}
if field.Required {
requiredFields = append(requiredFields, field.Name)
}
}
}
op.RequestBody = &openapi3.RequestBodyOrRef{
RequestBody: &openapi3.RequestBody{
Required: ptrBool(true),
Content: map[string]openapi3.MediaType{
"multipart/form-data": {
Schema: &openapi3.SchemaOrRef{
Schema: &openapi3.Schema{
Type: &objectType,
Properties: properties,
Required: requiredFields,
},
},
},
},
},
}
if output != nil {
if err := g.Reflector.SetJSONResponse(&op, output, 200); err != nil {
panic(err)
}
}
if requiresAuth {
g.addSecurityRequirement(&op)
}
g.addStandardErrorResponses(&op, requiresAuth)
if err := g.Reflector.Spec.AddOperation(method, path, op); err != nil {
panic(err)
}
}
func ptrBool(b bool) *bool {
return &b
}
type formFieldInfo struct {
Name string
Type string
Description string
Required bool
Min *float64
MaxLength *int64
}
func parseFormFields(input interface{}) []formFieldInfo {
var fields []formFieldInfo
t := reflect.TypeOf(input)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return fields
}
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
formTag := field.Tag.Get("form")
if formTag == "" || formTag == "-" {
continue
}
info := formFieldInfo{
Name: formTag,
Description: field.Tag.Get("description"),
}
switch field.Type.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
info.Type = "integer"
default:
info.Type = "string"
}
validateTag := field.Tag.Get("validate")
if strings.Contains(validateTag, "required") {
info.Required = true
}
if minStr := field.Tag.Get("minimum"); minStr != "" {
if min, err := strconv.ParseFloat(minStr, 64); err == nil {
info.Min = &min
}
}
if maxLenStr := field.Tag.Get("maxLength"); maxLenStr != "" {
if maxLen, err := strconv.ParseInt(maxLenStr, 10, 64); err == nil {
info.MaxLength = &maxLen
}
}
fields = append(fields, info)
}
return fields
}
// addSecurityRequirement 为操作添加认证要求
func (g *Generator) addSecurityRequirement(op *openapi3.Operation) {
op.Security = []map[string][]string{

View File

@@ -9,23 +9,24 @@ import (
"github.com/break/junhong_cmp_fiber/internal/store/postgres"
"github.com/break/junhong_cmp_fiber/internal/task"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/storage"
)
// Handler 任务处理器注册
type Handler struct {
mux *asynq.ServeMux
logger *zap.Logger
db *gorm.DB
redis *redis.Client
mux *asynq.ServeMux
logger *zap.Logger
db *gorm.DB
redis *redis.Client
storage *storage.Service
}
// NewHandler 创建任务处理器
func NewHandler(db *gorm.DB, redis *redis.Client, logger *zap.Logger) *Handler {
func NewHandler(db *gorm.DB, redis *redis.Client, storageSvc *storage.Service, logger *zap.Logger) *Handler {
return &Handler{
mux: asynq.NewServeMux(),
logger: logger,
db: db,
redis: redis,
mux: asynq.NewServeMux(),
logger: logger,
db: db,
redis: redis,
storage: storageSvc,
}
}
@@ -53,7 +54,7 @@ func (h *Handler) RegisterHandlers() *asynq.ServeMux {
func (h *Handler) registerIotCardImportHandler() {
importTaskStore := postgres.NewIotCardImportTaskStore(h.db, h.redis)
iotCardStore := postgres.NewIotCardStore(h.db, h.redis)
iotCardImportHandler := task.NewIotCardImportHandler(h.db, h.redis, importTaskStore, iotCardStore, h.logger)
iotCardImportHandler := task.NewIotCardImportHandler(h.db, h.redis, importTaskStore, iotCardStore, h.storage, h.logger)
h.mux.HandleFunc(constants.TaskTypeIotCardImport, iotCardImportHandler.HandleIotCardImport)
h.logger.Info("注册 IoT 卡导入任务处理器", zap.String("task_type", constants.TaskTypeIotCardImport))

184
pkg/storage/s3.go Normal file
View File

@@ -0,0 +1,184 @@
package storage
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/break/junhong_cmp_fiber/pkg/config"
)
type S3Provider struct {
client *s3.S3
uploader *s3manager.Uploader
bucket string
tempDir string
}
func NewS3Provider(cfg *config.StorageConfig) (*S3Provider, error) {
if cfg.S3.Endpoint == "" || cfg.S3.Bucket == "" {
return nil, fmt.Errorf("S3 配置不完整endpoint 和 bucket 必填")
}
if cfg.S3.AccessKeyID == "" || cfg.S3.SecretAccessKey == "" {
return nil, fmt.Errorf("S3 凭证未配置access_key_id 和 secret_access_key 必填")
}
sess, err := session.NewSession(&aws.Config{
Endpoint: aws.String(cfg.S3.Endpoint),
Region: aws.String(cfg.S3.Region),
Credentials: credentials.NewStaticCredentials(cfg.S3.AccessKeyID, cfg.S3.SecretAccessKey, ""),
DisableSSL: aws.Bool(!cfg.S3.UseSSL),
S3ForcePathStyle: aws.Bool(cfg.S3.PathStyle),
})
if err != nil {
return nil, fmt.Errorf("创建 S3 session 失败: %w", err)
}
tempDir := cfg.TempDir
if tempDir == "" {
tempDir = "/tmp/junhong-storage"
}
return &S3Provider{
client: s3.New(sess),
uploader: s3manager.NewUploader(sess),
bucket: cfg.S3.Bucket,
tempDir: tempDir,
}, nil
}
func (p *S3Provider) Upload(ctx context.Context, key string, reader io.Reader, contentType string) error {
input := &s3manager.UploadInput{
Bucket: aws.String(p.bucket),
Key: aws.String(key),
Body: reader,
ContentType: aws.String(contentType),
}
_, err := p.uploader.UploadWithContext(ctx, input)
if err != nil {
return fmt.Errorf("上传文件失败: %w", err)
}
return nil
}
func (p *S3Provider) Download(ctx context.Context, key string, writer io.Writer) error {
input := &s3.GetObjectInput{
Bucket: aws.String(p.bucket),
Key: aws.String(key),
}
result, err := p.client.GetObjectWithContext(ctx, input)
if err != nil {
if strings.Contains(err.Error(), "NoSuchKey") {
return fmt.Errorf("文件不存在: %s", key)
}
return fmt.Errorf("下载文件失败: %w", err)
}
defer result.Body.Close()
_, err = io.Copy(writer, result.Body)
if err != nil {
return fmt.Errorf("写入文件内容失败: %w", err)
}
return nil
}
func (p *S3Provider) DownloadToTemp(ctx context.Context, key string) (string, func(), error) {
ext := filepath.Ext(key)
if ext == "" {
ext = ".tmp"
}
tempFile, err := os.CreateTemp(p.tempDir, "download-*"+ext)
if err != nil {
return "", nil, fmt.Errorf("创建临时文件失败: %w", err)
}
tempPath := tempFile.Name()
cleanup := func() {
_ = os.Remove(tempPath)
}
if err := p.Download(ctx, key, tempFile); err != nil {
tempFile.Close()
cleanup()
return "", nil, err
}
if err := tempFile.Close(); err != nil {
cleanup()
return "", nil, fmt.Errorf("关闭临时文件失败: %w", err)
}
return tempPath, cleanup, nil
}
func (p *S3Provider) Delete(ctx context.Context, key string) error {
input := &s3.DeleteObjectInput{
Bucket: aws.String(p.bucket),
Key: aws.String(key),
}
_, err := p.client.DeleteObjectWithContext(ctx, input)
if err != nil {
return fmt.Errorf("删除文件失败: %w", err)
}
return nil
}
func (p *S3Provider) Exists(ctx context.Context, key string) (bool, error) {
input := &s3.HeadObjectInput{
Bucket: aws.String(p.bucket),
Key: aws.String(key),
}
_, err := p.client.HeadObjectWithContext(ctx, input)
if err != nil {
if strings.Contains(err.Error(), "NotFound") || strings.Contains(err.Error(), "404") {
return false, nil
}
return false, fmt.Errorf("检查文件存在性失败: %w", err)
}
return true, nil
}
func (p *S3Provider) GetUploadURL(ctx context.Context, key string, contentType string, expires time.Duration) (string, error) {
input := &s3.PutObjectInput{
Bucket: aws.String(p.bucket),
Key: aws.String(key),
ContentType: aws.String(contentType),
}
req, _ := p.client.PutObjectRequest(input)
url, err := req.Presign(expires)
if err != nil {
return "", fmt.Errorf("生成上传预签名 URL 失败: %w", err)
}
return url, nil
}
func (p *S3Provider) GetDownloadURL(ctx context.Context, key string, expires time.Duration) (string, error) {
input := &s3.GetObjectInput{
Bucket: aws.String(p.bucket),
Key: aws.String(key),
}
req, _ := p.client.GetObjectRequest(input)
url, err := req.Presign(expires)
if err != nil {
return "", fmt.Errorf("生成下载预签名 URL 失败: %w", err)
}
return url, nil
}

112
pkg/storage/service.go Normal file
View File

@@ -0,0 +1,112 @@
package storage
import (
"context"
"fmt"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"github.com/break/junhong_cmp_fiber/pkg/config"
)
type Service struct {
provider Provider
config *config.StorageConfig
}
func NewService(provider Provider, cfg *config.StorageConfig) *Service {
return &Service{
provider: provider,
config: cfg,
}
}
func (s *Service) GenerateFileKey(purpose, fileName string) (string, error) {
mapping, ok := PurposeMappings[purpose]
if !ok {
return "", fmt.Errorf("不支持的文件用途: %s", purpose)
}
ext := filepath.Ext(fileName)
if ext == "" {
ext = ".bin"
}
now := time.Now()
id := uuid.New().String()
key := fmt.Sprintf("%s/%04d/%02d/%02d/%s%s",
mapping.Prefix,
now.Year(),
now.Month(),
now.Day(),
id,
strings.ToLower(ext),
)
return key, nil
}
func (s *Service) GetUploadURL(ctx context.Context, purpose, fileName, contentType string) (*PresignResult, error) {
fileKey, err := s.GenerateFileKey(purpose, fileName)
if err != nil {
return nil, err
}
if contentType == "" {
if mapping, ok := PurposeMappings[purpose]; ok && mapping.ContentType != "" {
contentType = mapping.ContentType
} else {
contentType = "application/octet-stream"
}
}
expires := s.config.Presign.UploadExpires
if expires == 0 {
expires = 15 * time.Minute
}
url, err := s.provider.GetUploadURL(ctx, fileKey, contentType, expires)
if err != nil {
return nil, err
}
return &PresignResult{
URL: url,
FileKey: fileKey,
ExpiresIn: int(expires.Seconds()),
}, nil
}
func (s *Service) GetDownloadURL(ctx context.Context, fileKey string) (*PresignResult, error) {
expires := s.config.Presign.DownloadExpires
if expires == 0 {
expires = 24 * time.Hour
}
url, err := s.provider.GetDownloadURL(ctx, fileKey, expires)
if err != nil {
return nil, err
}
return &PresignResult{
URL: url,
FileKey: fileKey,
ExpiresIn: int(expires.Seconds()),
}, nil
}
func (s *Service) DownloadToTemp(ctx context.Context, fileKey string) (string, func(), error) {
return s.provider.DownloadToTemp(ctx, fileKey)
}
func (s *Service) Provider() Provider {
return s.provider
}
func (s *Service) Bucket() string {
return s.config.S3.Bucket
}

17
pkg/storage/storage.go Normal file
View File

@@ -0,0 +1,17 @@
package storage
import (
"context"
"io"
"time"
)
type Provider interface {
Upload(ctx context.Context, key string, reader io.Reader, contentType string) error
Download(ctx context.Context, key string, writer io.Writer) error
DownloadToTemp(ctx context.Context, key string) (localPath string, cleanup func(), err error)
Delete(ctx context.Context, key string) error
Exists(ctx context.Context, key string) (bool, error)
GetUploadURL(ctx context.Context, key string, contentType string, expires time.Duration) (string, error)
GetDownloadURL(ctx context.Context, key string, expires time.Duration) (string, error)
}

18
pkg/storage/types.go Normal file
View File

@@ -0,0 +1,18 @@
package storage
type PresignResult struct {
URL string `json:"url"`
FileKey string `json:"file_key"`
ExpiresIn int `json:"expires_in"`
}
type PurposeMapping struct {
Prefix string
ContentType string
}
var PurposeMappings = map[string]PurposeMapping{
"iot_import": {Prefix: "imports", ContentType: "text/csv"},
"export": {Prefix: "exports", ContentType: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"},
"attachment": {Prefix: "attachments", ContentType: ""},
}

View File

@@ -2,33 +2,49 @@ package utils
import (
"encoding/csv"
"errors"
"io"
"strings"
)
// CardInfo 卡信息ICCID + MSISDN
type CardInfo struct {
ICCID string
MSISDN string
}
// CSVParseResult CSV 解析结果
type CSVParseResult struct {
ICCIDs []string
Cards []CardInfo
TotalCount int
ParseErrors []CSVParseError
}
// CSVParseError CSV 解析错误
type CSVParseError struct {
Line int
ICCID string
MSISDN string
Reason string
}
func ParseICCIDFromCSV(reader io.Reader) (*CSVParseResult, error) {
// ErrInvalidCSVFormat CSV 格式错误
var ErrInvalidCSVFormat = errors.New("CSV 文件格式错误:缺少 MSISDN 列,文件必须包含 ICCID 和 MSISDN 两列")
// ParseCardCSV 解析包含 ICCID 和 MSISDN 两列的 CSV 文件
func ParseCardCSV(reader io.Reader) (*CSVParseResult, error) {
csvReader := csv.NewReader(reader)
csvReader.FieldsPerRecord = -1
csvReader.TrimLeadingSpace = true
result := &CSVParseResult{
ICCIDs: make([]string, 0),
Cards: make([]CardInfo, 0),
ParseErrors: make([]CSVParseError, 0),
}
lineNum := 0
headerSkipped := false
for {
record, err := csvReader.Read()
if err == io.EOF {
@@ -48,23 +64,69 @@ func ParseICCIDFromCSV(reader io.Reader) (*CSVParseResult, error) {
continue
}
iccid := strings.TrimSpace(record[0])
if iccid == "" {
if len(record) < 2 {
if lineNum == 1 && !headerSkipped {
firstCol := strings.TrimSpace(record[0])
if isICCIDHeader(firstCol) {
return nil, ErrInvalidCSVFormat
}
}
result.ParseErrors = append(result.ParseErrors, CSVParseError{
Line: lineNum,
ICCID: strings.TrimSpace(record[0]),
Reason: "列数不足:缺少 MSISDN 列",
})
result.TotalCount++
continue
}
if lineNum == 1 && isHeader(iccid) {
iccid := strings.TrimSpace(record[0])
msisdn := strings.TrimSpace(record[1])
if lineNum == 1 && !headerSkipped && isHeader(iccid, msisdn) {
headerSkipped = true
continue
}
result.TotalCount++
result.ICCIDs = append(result.ICCIDs, iccid)
if iccid == "" {
result.ParseErrors = append(result.ParseErrors, CSVParseError{
Line: lineNum,
MSISDN: msisdn,
Reason: "ICCID 不能为空",
})
continue
}
if msisdn == "" {
result.ParseErrors = append(result.ParseErrors, CSVParseError{
Line: lineNum,
ICCID: iccid,
Reason: "MSISDN 不能为空",
})
continue
}
result.Cards = append(result.Cards, CardInfo{
ICCID: iccid,
MSISDN: msisdn,
})
}
return result, nil
}
func isHeader(value string) bool {
func isICCIDHeader(value string) bool {
lower := strings.ToLower(value)
return lower == "iccid" || lower == "卡号" || lower == "号码"
}
func isMSISDNHeader(value string) bool {
lower := strings.ToLower(value)
return lower == "msisdn" || lower == "接入号" || lower == "手机号" || lower == "电话" || lower == "号码"
}
func isHeader(col1, col2 string) bool {
return isICCIDHeader(col1) && isMSISDNHeader(col2)
}

View File

@@ -8,89 +8,133 @@ import (
"github.com/stretchr/testify/require"
)
func TestParseICCIDFromCSV(t *testing.T) {
func TestParseCardCSV(t *testing.T) {
tests := []struct {
name string
csvContent string
wantICCIDs []string
wantCards []CardInfo
wantTotalCount int
wantErrorCount int
wantError error
}{
{
name: "单列ICCID无表头",
csvContent: "89860012345678901234\n89860012345678901235\n89860012345678901236",
wantICCIDs: []string{"89860012345678901234", "89860012345678901235", "89860012345678901236"},
wantTotalCount: 3,
wantErrorCount: 0,
},
{
name: "单列ICCID有表头-iccid",
csvContent: "iccid\n89860012345678901234\n89860012345678901235",
wantICCIDs: []string{"89860012345678901234", "89860012345678901235"},
name: "标准双列无表头",
csvContent: "89860012345678901234,13800000001\n89860012345678901235,13800000002",
wantCards: []CardInfo{
{ICCID: "89860012345678901234", MSISDN: "13800000001"},
{ICCID: "89860012345678901235", MSISDN: "13800000002"},
},
wantTotalCount: 2,
wantErrorCount: 0,
},
{
name: "单列ICCID有表头-ICCID大写",
csvContent: "ICCID\n89860012345678901234",
wantICCIDs: []string{"89860012345678901234"},
name: "标准双列有表头-英文",
csvContent: "iccid,msisdn\n89860012345678901234,13800000001\n89860012345678901235,13800000002",
wantCards: []CardInfo{
{ICCID: "89860012345678901234", MSISDN: "13800000001"},
{ICCID: "89860012345678901235", MSISDN: "13800000002"},
},
wantTotalCount: 2,
wantErrorCount: 0,
},
{
name: "标准双列有表头-中文",
csvContent: "卡号,接入号\n89860012345678901234,13800000001",
wantCards: []CardInfo{
{ICCID: "89860012345678901234", MSISDN: "13800000001"},
},
wantTotalCount: 1,
wantErrorCount: 0,
},
{
name: "单列ICCID有表头-号",
csvContent: "卡号\n89860012345678901234",
wantICCIDs: []string{"89860012345678901234"},
name: "标准双列有表头-手机号",
csvContent: "ICCID,手机号\n89860012345678901234,13800000001",
wantCards: []CardInfo{
{ICCID: "89860012345678901234", MSISDN: "13800000001"},
},
wantTotalCount: 1,
wantErrorCount: 0,
},
{
name: "单列ICCID有表头-号码",
csvContent: "号码\n89860012345678901234",
wantICCIDs: []string{"89860012345678901234"},
wantTotalCount: 1,
name: "单列CSV格式拒绝-有表头",
csvContent: "iccid\n89860012345678901234",
wantCards: nil,
wantTotalCount: 0,
wantErrorCount: 0,
wantError: ErrInvalidCSVFormat,
},
{
name: "单列CSV格式-无表头记录错误",
csvContent: "89860012345678901234\n89860012345678901235",
wantCards: []CardInfo{},
wantTotalCount: 2,
wantErrorCount: 2,
},
{
name: "MSISDN为空记录失败",
csvContent: "iccid,msisdn\n89860012345678901234,13800000001\n89860012345678901235,",
wantCards: []CardInfo{{ICCID: "89860012345678901234", MSISDN: "13800000001"}},
wantTotalCount: 2,
wantErrorCount: 1,
},
{
name: "ICCID为空记录失败",
csvContent: "iccid,msisdn\n89860012345678901234,13800000001\n,13800000002",
wantCards: []CardInfo{{ICCID: "89860012345678901234", MSISDN: "13800000001"}},
wantTotalCount: 2,
wantErrorCount: 1,
},
{
name: "空文件",
csvContent: "",
wantICCIDs: []string{},
wantCards: []CardInfo{},
wantTotalCount: 0,
wantErrorCount: 0,
},
{
name: "只有表头",
csvContent: "iccid",
wantICCIDs: []string{},
csvContent: "iccid,msisdn",
wantCards: []CardInfo{},
wantTotalCount: 0,
wantErrorCount: 0,
},
{
name: "包含空行",
csvContent: "89860012345678901234\n\n89860012345678901235\n \n89860012345678901236",
wantICCIDs: []string{"89860012345678901234", "89860012345678901235", "89860012345678901236"},
wantTotalCount: 3,
wantErrorCount: 0,
},
{
name: "ICCID前后有空格",
csvContent: " 89860012345678901234 \n89860012345678901235",
wantICCIDs: []string{"89860012345678901234", "89860012345678901235"},
name: "包含空行",
csvContent: "89860012345678901234,13800000001\n\n89860012345678901235,13800000002",
wantCards: []CardInfo{
{ICCID: "89860012345678901234", MSISDN: "13800000001"},
{ICCID: "89860012345678901235", MSISDN: "13800000002"},
},
wantTotalCount: 2,
wantErrorCount: 0,
},
{
name: "多列CSV只取第一列",
csvContent: "89860012345678901234,额外数据,更多数据\n89860012345678901235,忽略,忽略",
wantICCIDs: []string{"89860012345678901234", "89860012345678901235"},
name: "ICCID和MSISDN前后有空格",
csvContent: " 89860012345678901234 , 13800000001 ",
wantCards: []CardInfo{
{ICCID: "89860012345678901234", MSISDN: "13800000001"},
},
wantTotalCount: 1,
wantErrorCount: 0,
},
{
name: "多于两列只取前两列",
csvContent: "89860012345678901234,13800000001,额外数据\n89860012345678901235,13800000002,忽略",
wantCards: []CardInfo{
{ICCID: "89860012345678901234", MSISDN: "13800000001"},
{ICCID: "89860012345678901235", MSISDN: "13800000002"},
},
wantTotalCount: 2,
wantErrorCount: 0,
},
{
name: "Windows换行符CRLF",
csvContent: "89860012345678901234\r\n89860012345678901235\r\n89860012345678901236",
wantICCIDs: []string{"89860012345678901234", "89860012345678901235", "89860012345678901236"},
wantTotalCount: 3,
name: "Windows换行符CRLF",
csvContent: "89860012345678901234,13800000001\r\n89860012345678901235,13800000002",
wantCards: []CardInfo{
{ICCID: "89860012345678901234", MSISDN: "13800000001"},
{ICCID: "89860012345678901235", MSISDN: "13800000002"},
},
wantTotalCount: 2,
wantErrorCount: 0,
},
}
@@ -98,34 +142,78 @@ func TestParseICCIDFromCSV(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := strings.NewReader(tt.csvContent)
result, err := ParseICCIDFromCSV(reader)
result, err := ParseCardCSV(reader)
if tt.wantError != nil {
require.ErrorIs(t, err, tt.wantError)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantICCIDs, result.ICCIDs, "ICCIDs 不匹配")
assert.Equal(t, tt.wantCards, result.Cards, "Cards 不匹配")
assert.Equal(t, tt.wantTotalCount, result.TotalCount, "TotalCount 不匹配")
assert.Equal(t, tt.wantErrorCount, len(result.ParseErrors), "ParseErrors 数量不匹配")
})
}
}
func TestParseCardCSV_ErrorDetails(t *testing.T) {
t.Run("MSISDN为空时记录详细错误", func(t *testing.T) {
csvContent := "iccid,msisdn\n89860012345678901234,"
reader := strings.NewReader(csvContent)
result, err := ParseCardCSV(reader)
require.NoError(t, err)
require.Len(t, result.ParseErrors, 1)
assert.Equal(t, 2, result.ParseErrors[0].Line)
assert.Equal(t, "89860012345678901234", result.ParseErrors[0].ICCID)
assert.Equal(t, "MSISDN 不能为空", result.ParseErrors[0].Reason)
})
t.Run("ICCID为空时记录详细错误", func(t *testing.T) {
csvContent := "iccid,msisdn\n,13800000001"
reader := strings.NewReader(csvContent)
result, err := ParseCardCSV(reader)
require.NoError(t, err)
require.Len(t, result.ParseErrors, 1)
assert.Equal(t, 2, result.ParseErrors[0].Line)
assert.Equal(t, "13800000001", result.ParseErrors[0].MSISDN)
assert.Equal(t, "ICCID 不能为空", result.ParseErrors[0].Reason)
})
t.Run("列数不足时记录详细错误", func(t *testing.T) {
csvContent := "89860012345678901234"
reader := strings.NewReader(csvContent)
result, err := ParseCardCSV(reader)
require.NoError(t, err)
require.Len(t, result.ParseErrors, 1)
assert.Equal(t, 1, result.ParseErrors[0].Line)
assert.Equal(t, "89860012345678901234", result.ParseErrors[0].ICCID)
assert.Contains(t, result.ParseErrors[0].Reason, "列数不足")
})
}
func TestIsHeader(t *testing.T) {
tests := []struct {
value string
col1 string
col2 string
expected bool
}{
{"iccid", true},
{"ICCID", true},
{"Iccid", true},
{"号", true},
{"号码", true},
{"89860012345678901234", false},
{"", false},
{"id", false},
{"card", false},
{"iccid", "msisdn", true},
{"ICCID", "MSISDN", true},
{"卡号", "接入号", true},
{"号码", "手机号", true},
{"iccid", "电话", true},
{"89860012345678901234", "13800000001", false},
{"iccid", "", false},
{"", "msisdn", false},
}
for _, tt := range tests {
t.Run(tt.value, func(t *testing.T) {
result := isHeader(tt.value)
t.Run(tt.col1+"_"+tt.col2, func(t *testing.T) {
result := isHeader(tt.col1, tt.col2)
assert.Equal(t, tt.expected, result)
})
}