Files
junhong_cmp_fiber/pkg/config/loader_test.go
huang eaa70ac255 feat: 实现 RBAC 权限系统和数据权限控制 (004-rbac-data-permission)
主要功能:
- 实现完整的 RBAC 权限系统(账号、角色、权限的多对多关联)
- 基于 owner_id + shop_id 的自动数据权限过滤
- 使用 PostgreSQL WITH RECURSIVE 查询下级账号
- Redis 缓存优化下级账号查询性能(30分钟过期)
- 支持多租户数据隔离和层级权限管理

技术实现:
- 新增 Account、Role、Permission 模型及关联关系表
- 实现 GORM Scopes 自动应用数据权限过滤
- 添加数据库迁移脚本(000002_rbac_data_permission、000003_add_owner_id_shop_id)
- 完善错误码定义(1010-1027 为 RBAC 相关错误)
- 重构 main.go 采用函数拆分提高可读性

测试覆盖:
- 添加 Account、Role、Permission 的集成测试
- 添加数据权限过滤的单元测试和集成测试
- 添加下级账号查询和缓存的单元测试
- 添加 API 回归测试确保向后兼容

文档更新:
- 更新 README.md 添加 RBAC 功能说明
- 更新 CLAUDE.md 添加技术栈和开发原则
- 添加 docs/004-rbac-data-permission/ 功能总结和使用指南

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-18 16:44:06 +08:00

662 lines
14 KiB
Go

package config
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/spf13/viper"
)
// TestLoad tests the config loading functionality
func TestLoad(t *testing.T) {
tests := []struct {
name string
setupEnv func()
cleanupEnv func()
createConfig func(t *testing.T) string
wantErr bool
validateFunc func(t *testing.T, cfg *Config)
}{
{
name: "valid default config",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigPath, "")
_ = os.Setenv(constants.EnvConfigEnv, "")
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigPath)
_ = os.Unsetenv(constants.EnvConfigEnv)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 0
pool_size: 10
min_idle_conns: 5
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "info"
development: false
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
enable_rate_limiter: false
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
// Set as default config path
_ = os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: false,
validateFunc: func(t *testing.T, cfg *Config) {
if cfg.Server.Address != ":3000" {
t.Errorf("expected server.address :3000, got %s", cfg.Server.Address)
}
if cfg.Server.ReadTimeout != 10*time.Second {
t.Errorf("expected read_timeout 10s, got %v", cfg.Server.ReadTimeout)
}
if cfg.Redis.Address != "localhost" {
t.Errorf("expected redis.address localhost, got %s", cfg.Redis.Address)
}
if cfg.Redis.Port != 6379 {
t.Errorf("expected redis.port 6379, got %d", cfg.Redis.Port)
}
if cfg.Redis.PoolSize != 10 {
t.Errorf("expected redis.pool_size 10, got %d", cfg.Redis.PoolSize)
}
if cfg.Logging.Level != "info" {
t.Errorf("expected logging.level info, got %s", cfg.Logging.Level)
}
if cfg.Middleware.EnableAuth != true {
t.Errorf("expected enable_auth true, got %v", cfg.Middleware.EnableAuth)
}
},
},
{
name: "environment-specific config (dev)",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigEnv, "dev")
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigEnv)
_ = os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
// Create configs directory in temp
tmpDir := t.TempDir()
configsDir := filepath.Join(tmpDir, "configs")
if err := os.MkdirAll(configsDir, 0755); err != nil {
t.Fatalf("failed to create configs dir: %v", err)
}
// Create dev config
devConfigFile := filepath.Join(configsDir, "config.dev.yaml")
content := `
server:
address: ":8080"
read_timeout: "15s"
write_timeout: "15s"
shutdown_timeout: "30s"
prefork: false
redis:
address: "localhost"
port: 6379
password: ""
db: 1
pool_size: 5
min_idle_conns: 2
dial_timeout: "5s"
read_timeout: "3s"
write_timeout: "3s"
logging:
level: "debug"
development: true
app_log:
filename: "logs/app.log"
max_size: 50
max_backups: 10
max_age: 7
compress: false
access_log:
filename: "logs/access.log"
max_size: 100
max_backups: 30
max_age: 30
compress: false
middleware:
enable_auth: false
enable_rate_limiter: false
rate_limiter:
max: 50
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(devConfigFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create dev config file: %v", err)
}
// Change to tmpDir so relative path works
originalWd, _ := os.Getwd()
_ = os.Chdir(tmpDir)
t.Cleanup(func() { _ = os.Chdir(originalWd) })
return devConfigFile
},
wantErr: false,
validateFunc: func(t *testing.T, cfg *Config) {
if cfg.Server.Address != ":8080" {
t.Errorf("expected server.address :8080, got %s", cfg.Server.Address)
}
if cfg.Redis.DB != 1 {
t.Errorf("expected redis.db 1, got %d", cfg.Redis.DB)
}
if cfg.Logging.Level != "debug" {
t.Errorf("expected logging.level debug, got %s", cfg.Logging.Level)
}
if cfg.Middleware.EnableAuth != false {
t.Errorf("expected enable_auth false, got %v", cfg.Middleware.EnableAuth)
}
},
},
{
name: "invalid YAML syntax",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigPath, "")
_ = os.Setenv(constants.EnvConfigEnv, "")
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigPath)
_ = os.Unsetenv(constants.EnvConfigEnv)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
invalid yaml syntax here!!!
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
_ = os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
},
{
name: "validation error - invalid server address",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigPath, "")
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ""
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
_ = os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
},
{
name: "validation error - timeout out of range",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigPath, "")
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "1s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 6379
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
_ = os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
},
{
name: "validation error - invalid redis port",
setupEnv: func() {
_ = os.Setenv(constants.EnvConfigPath, "")
},
cleanupEnv: func() {
_ = os.Unsetenv(constants.EnvConfigPath)
},
createConfig: func(t *testing.T) string {
t.Helper()
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "config.yaml")
content := `
server:
address: ":3000"
read_timeout: "10s"
write_timeout: "10s"
shutdown_timeout: "30s"
redis:
address: "localhost"
port: 99999
db: 0
pool_size: 10
min_idle_conns: 5
logging:
level: "info"
app_log:
filename: "logs/app.log"
max_size: 100
max_backups: 30
max_age: 30
compress: true
access_log:
filename: "logs/access.log"
max_size: 500
max_backups: 90
max_age: 90
compress: true
middleware:
enable_auth: true
rate_limiter:
max: 100
expiration: "1m"
storage: "memory"
`
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
_ = os.Setenv(constants.EnvConfigPath, configFile)
return configFile
},
wantErr: true,
validateFunc: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset viper for each test
viper.Reset()
// Setup environment
if tt.setupEnv != nil {
tt.setupEnv()
}
// Create config file
if tt.createConfig != nil {
tt.createConfig(t)
}
// Cleanup after test
if tt.cleanupEnv != nil {
defer tt.cleanupEnv()
}
// Load config
cfg, err := Load()
// Check error expectation
if (err != nil) != tt.wantErr {
t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr)
return
}
// Validate config if no error expected
if !tt.wantErr && tt.validateFunc != nil {
tt.validateFunc(t, cfg)
}
})
}
}
// TestReload loads a configuration from a temporary file, overwrites that
// file with different values, calls Reload, and checks that both the value
// returned by Reload and the one returned by Get reflect the new content.
//
// NOTE(review): the nesting of app_log/access_log under logging and of
// rate_limiter under middleware in the fixtures is reconstructed from the
// section order — confirm against the mapstructure tags on Config.
func TestReload(t *testing.T) {
	// Reset viper
	viper.Reset()

	// Restore the ambient value of the config-path variable when the test
	// ends instead of unconditionally unsetting it.
	previousPath, pathWasSet := os.LookupEnv(constants.EnvConfigPath)
	t.Cleanup(func() {
		// Best effort: nothing useful can be done if restoring fails during
		// teardown.
		if pathWasSet {
			_ = os.Setenv(constants.EnvConfigPath, previousPath)
			return
		}
		_ = os.Unsetenv(constants.EnvConfigPath)
	})

	configFile := filepath.Join(t.TempDir(), "config.yaml")

	// Initial config
	initialContent := `
server:
  address: ":3000"
  read_timeout: "10s"
  write_timeout: "10s"
  shutdown_timeout: "30s"
  prefork: false
redis:
  address: "localhost"
  port: 6379
  password: ""
  db: 0
  pool_size: 10
  min_idle_conns: 5
  dial_timeout: "5s"
  read_timeout: "3s"
  write_timeout: "3s"
logging:
  level: "info"
  development: false
  app_log:
    filename: "logs/app.log"
    max_size: 100
    max_backups: 30
    max_age: 30
    compress: true
  access_log:
    filename: "logs/access.log"
    max_size: 500
    max_backups: 90
    max_age: 90
    compress: true
middleware:
  enable_auth: true
  enable_rate_limiter: false
  rate_limiter:
    max: 100
    expiration: "1m"
    storage: "memory"
`
	if err := os.WriteFile(configFile, []byte(initialContent), 0644); err != nil {
		t.Fatalf("failed to create config file: %v", err)
	}
	if err := os.Setenv(constants.EnvConfigPath, configFile); err != nil {
		t.Fatalf("failed to set %s: %v", constants.EnvConfigPath, err)
	}

	// Load initial config
	cfg, err := Load()
	if err != nil {
		t.Fatalf("failed to load initial config: %v", err)
	}

	// Verify initial values
	if cfg.Logging.Level != "info" {
		t.Errorf("expected initial logging.level info, got %s", cfg.Logging.Level)
	}
	if cfg.Server.Address != ":3000" {
		t.Errorf("expected initial server.address :3000, got %s", cfg.Server.Address)
	}

	// Overwrite the same file so Reload has something new to pick up.
	updatedContent := `
server:
  address: ":8080"
  read_timeout: "15s"
  write_timeout: "15s"
  shutdown_timeout: "30s"
  prefork: false
redis:
  address: "localhost"
  port: 6379
  password: ""
  db: 0
  pool_size: 20
  min_idle_conns: 10
  dial_timeout: "5s"
  read_timeout: "3s"
  write_timeout: "3s"
logging:
  level: "debug"
  development: true
  app_log:
    filename: "logs/app.log"
    max_size: 100
    max_backups: 30
    max_age: 30
    compress: true
  access_log:
    filename: "logs/access.log"
    max_size: 500
    max_backups: 90
    max_age: 90
    compress: true
middleware:
  enable_auth: false
  enable_rate_limiter: true
  rate_limiter:
    max: 200
    expiration: "2m"
    storage: "redis"
`
	if err := os.WriteFile(configFile, []byte(updatedContent), 0644); err != nil {
		t.Fatalf("failed to update config file: %v", err)
	}

	// Reload config
	newCfg, err := Reload()
	if err != nil {
		t.Fatalf("failed to reload config: %v", err)
	}

	// Verify updated values
	if newCfg.Logging.Level != "debug" {
		t.Errorf("expected updated logging.level debug, got %s", newCfg.Logging.Level)
	}
	if newCfg.Server.Address != ":8080" {
		t.Errorf("expected updated server.address :8080, got %s", newCfg.Server.Address)
	}
	if newCfg.Redis.PoolSize != 20 {
		t.Errorf("expected updated redis.pool_size 20, got %d", newCfg.Redis.PoolSize)
	}
	if newCfg.Middleware.EnableAuth {
		t.Errorf("expected updated enable_auth false, got %v", newCfg.Middleware.EnableAuth)
	}
	if !newCfg.Middleware.EnableRateLimiter {
		t.Errorf("expected updated enable_rate_limiter true, got %v", newCfg.Middleware.EnableRateLimiter)
	}

	// Verify global config was updated
	globalCfg := Get()
	if globalCfg.Logging.Level != "debug" {
		t.Errorf("expected global config updated, got logging.level %s", globalCfg.Logging.Level)
	}
}
// TestGetConfigPath loads a configuration from a temporary file addressed via
// constants.EnvConfigPath and checks that GetConfigPath afterwards returns a
// non-empty, absolute path.
//
// NOTE(review): the nesting of app_log/access_log under logging and of
// rate_limiter under middleware in the fixture is reconstructed from the
// section order — confirm against the mapstructure tags on Config.
func TestGetConfigPath(t *testing.T) {
	// Reset viper
	viper.Reset()

	// Restore the ambient value of the config-path variable when the test
	// ends instead of unconditionally unsetting it.
	previousPath, pathWasSet := os.LookupEnv(constants.EnvConfigPath)
	t.Cleanup(func() {
		// Best effort: nothing useful can be done if restoring fails during
		// teardown.
		if pathWasSet {
			_ = os.Setenv(constants.EnvConfigPath, previousPath)
			return
		}
		_ = os.Unsetenv(constants.EnvConfigPath)
	})

	configFile := filepath.Join(t.TempDir(), "config.yaml")
	content := `
server:
  address: ":3000"
  read_timeout: "10s"
  write_timeout: "10s"
  shutdown_timeout: "30s"
redis:
  address: "localhost"
  port: 6379
  db: 0
  pool_size: 10
  min_idle_conns: 5
logging:
  level: "info"
  app_log:
    filename: "logs/app.log"
    max_size: 100
    max_backups: 30
    max_age: 30
    compress: true
  access_log:
    filename: "logs/access.log"
    max_size: 500
    max_backups: 90
    max_age: 90
    compress: true
middleware:
  enable_auth: true
  rate_limiter:
    max: 100
    expiration: "1m"
    storage: "memory"
`
	if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
		t.Fatalf("failed to create config file: %v", err)
	}
	if err := os.Setenv(constants.EnvConfigPath, configFile); err != nil {
		t.Fatalf("failed to set %s: %v", constants.EnvConfigPath, err)
	}

	// Load config
	if _, err := Load(); err != nil {
		t.Fatalf("failed to load config: %v", err)
	}

	// Get config path
	path := GetConfigPath()
	if path == "" {
		t.Error("expected non-empty config path")
	}

	// Verify it's an absolute path
	if !filepath.IsAbs(path) {
		t.Errorf("expected absolute path, got %s", path)
	}
}