主要功能: - 实现完整的 RBAC 权限系统(账号、角色、权限的多对多关联) - 基于 owner_id + shop_id 的自动数据权限过滤 - 使用 PostgreSQL WITH RECURSIVE 查询下级账号 - Redis 缓存优化下级账号查询性能(30分钟过期) - 支持多租户数据隔离和层级权限管理 技术实现: - 新增 Account、Role、Permission 模型及关联关系表 - 实现 GORM Scopes 自动应用数据权限过滤 - 添加数据库迁移脚本(000002_rbac_data_permission、000003_add_owner_id_shop_id) - 完善错误码定义(1010-1027 为 RBAC 相关错误) - 重构 main.go 采用函数拆分提高可读性 测试覆盖: - 添加 Account、Role、Permission 的集成测试 - 添加数据权限过滤的单元测试和集成测试 - 添加下级账号查询和缓存的单元测试 - 添加 API 回归测试确保向后兼容 文档更新: - 更新 README.md 添加 RBAC 功能说明 - 更新 CLAUDE.md 添加技术栈和开发原则 - 添加 docs/004-rbac-data-permission/ 功能总结和使用指南 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
359 lines
9.1 KiB
Go
359 lines
9.1 KiB
Go
package errors
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"testing"
|
||
|
||
"github.com/gofiber/fiber/v2"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// TestSafeErrorHandler 测试 SafeErrorHandler 基本功能
|
||
func TestSafeErrorHandler(t *testing.T) {
|
||
logger, _ := zap.NewProduction()
|
||
defer func() { _ = logger.Sync() }()
|
||
handler := SafeErrorHandler(logger)
|
||
|
||
tests := []struct {
|
||
name string
|
||
err error
|
||
expectedStatus int
|
||
expectedCode int
|
||
}{
|
||
{
|
||
name: "AppError 参数验证失败",
|
||
err: New(CodeInvalidParam, "用户名不能为空"),
|
||
expectedStatus: 400,
|
||
expectedCode: CodeInvalidParam,
|
||
},
|
||
{
|
||
name: "AppError 缺失令牌",
|
||
err: New(CodeMissingToken, ""),
|
||
expectedStatus: 401,
|
||
expectedCode: CodeMissingToken,
|
||
},
|
||
{
|
||
name: "AppError 资源未找到",
|
||
err: New(CodeNotFound, "用户不存在"),
|
||
expectedStatus: 404,
|
||
expectedCode: CodeNotFound,
|
||
},
|
||
{
|
||
name: "AppError 数据库错误",
|
||
err: New(CodeDatabaseError, "连接失败"),
|
||
expectedStatus: 500,
|
||
expectedCode: CodeDatabaseError,
|
||
},
|
||
{
|
||
name: "fiber.Error 400",
|
||
err: fiber.NewError(400, "Bad Request"),
|
||
expectedStatus: 400,
|
||
expectedCode: CodeInvalidParam,
|
||
},
|
||
{
|
||
name: "fiber.Error 404",
|
||
err: fiber.NewError(404, "Not Found"),
|
||
expectedStatus: 404,
|
||
expectedCode: CodeNotFound,
|
||
},
|
||
{
|
||
name: "标准 error",
|
||
err: errors.New("standard error"),
|
||
expectedStatus: 500,
|
||
expectedCode: CodeInternalError,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
app := fiber.New(fiber.Config{
|
||
ErrorHandler: handler,
|
||
})
|
||
|
||
app.Get("/test", func(c *fiber.Ctx) error {
|
||
return tt.err
|
||
})
|
||
|
||
// 不实际发起 HTTP 请求,仅验证 handler 不会 panic
|
||
// 实际的集成测试在 tests/integration/ 中进行
|
||
if handler == nil {
|
||
t.Error("SafeErrorHandler returned nil")
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestAppErrorMethods 测试 AppError 的方法
|
||
func TestAppErrorMethods(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
err *AppError
|
||
expectedError string
|
||
expectedHTTPStatus int
|
||
expectedCode int
|
||
}{
|
||
{
|
||
name: "基本 AppError",
|
||
err: New(CodeInvalidParam, "参数错误"),
|
||
expectedError: "参数错误",
|
||
expectedHTTPStatus: 400,
|
||
expectedCode: CodeInvalidParam,
|
||
},
|
||
{
|
||
name: "带自定义 HTTP 状态码",
|
||
err: New(CodeNotFound, "用户不存在").WithHTTPStatus(404),
|
||
expectedError: "用户不存在",
|
||
expectedHTTPStatus: 404,
|
||
expectedCode: CodeNotFound,
|
||
},
|
||
{
|
||
name: "空消息使用默认",
|
||
err: New(CodeDatabaseError, ""),
|
||
expectedError: "数据库错误",
|
||
expectedHTTPStatus: 500,
|
||
expectedCode: CodeDatabaseError,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
// 测试 Error() 方法
|
||
if tt.err.Error() != tt.expectedError {
|
||
t.Errorf("Error() = %q, expected %q", tt.err.Error(), tt.expectedError)
|
||
}
|
||
|
||
// 测试 Code 字段
|
||
if tt.err.Code != tt.expectedCode {
|
||
t.Errorf("Code = %d, expected %d", tt.err.Code, tt.expectedCode)
|
||
}
|
||
|
||
// 测试 HTTPStatus 字段
|
||
if tt.err.HTTPStatus != tt.expectedHTTPStatus {
|
||
t.Errorf("HTTPStatus = %d, expected %d", tt.err.HTTPStatus, tt.expectedHTTPStatus)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestAppErrorUnwrap 测试错误链支持
|
||
func TestAppErrorUnwrap(t *testing.T) {
|
||
originalErr := errors.New("database connection failed")
|
||
appErr := Wrap(CodeDatabaseError, "", originalErr)
|
||
|
||
// 测试 Unwrap
|
||
unwrapped := appErr.Unwrap()
|
||
if unwrapped != originalErr {
|
||
t.Errorf("Unwrap() = %v, expected %v", unwrapped, originalErr)
|
||
}
|
||
|
||
// 测试 errors.Is
|
||
if !errors.Is(appErr, originalErr) {
|
||
t.Error("errors.Is failed to identify wrapped error")
|
||
}
|
||
}
|
||
|
||
// BenchmarkSafeErrorHandler 基准测试错误处理性能
|
||
func BenchmarkSafeErrorHandler(b *testing.B) {
|
||
logger, _ := zap.NewProduction()
|
||
defer func() { _ = logger.Sync() }()
|
||
_ = SafeErrorHandler(logger) // 避免未使用变量警告
|
||
|
||
testErrors := []error{
|
||
New(CodeInvalidParam, "参数错误"),
|
||
New(CodeDatabaseError, "数据库错误"),
|
||
fiber.NewError(404, "Not Found"),
|
||
errors.New("standard error"),
|
||
}
|
||
|
||
b.ResetTimer()
|
||
for i := 0; i < b.N; i++ {
|
||
err := testErrors[i%len(testErrors)]
|
||
_ = err // 避免未使用变量警告
|
||
// 注意:这里无法直接调用 handler,因为它需要 Fiber Context
|
||
// 实际性能测试应该在集成测试中进行
|
||
}
|
||
}
|
||
|
||
// TestNewWithValidation 测试创建 AppError 时的参数验证
|
||
func TestNewWithValidation(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
code int
|
||
message string
|
||
expectPanic bool
|
||
}{
|
||
{
|
||
name: "有效的错误码和消息",
|
||
code: CodeInvalidParam,
|
||
message: "自定义消息",
|
||
expectPanic: false,
|
||
},
|
||
{
|
||
name: "有效的错误码,空消息",
|
||
code: CodeDatabaseError,
|
||
message: "",
|
||
expectPanic: false,
|
||
},
|
||
{
|
||
name: "未知错误码",
|
||
code: 9999,
|
||
message: "未知错误",
|
||
expectPanic: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
defer func() {
|
||
r := recover()
|
||
if (r != nil) != tt.expectPanic {
|
||
t.Errorf("New() panic = %v, expectPanic = %v", r != nil, tt.expectPanic)
|
||
}
|
||
}()
|
||
|
||
err := New(tt.code, tt.message)
|
||
if err == nil {
|
||
t.Error("New() returned nil")
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestWrapError 测试包装错误功能
|
||
func TestWrapError(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
originalErr error
|
||
code int
|
||
message string
|
||
expectedMessage string
|
||
}{
|
||
{
|
||
name: "包装标准错误",
|
||
originalErr: errors.New("connection timeout"),
|
||
code: CodeTimeout,
|
||
message: "",
|
||
expectedMessage: "请求超时: connection timeout",
|
||
},
|
||
{
|
||
name: "包装带自定义消息",
|
||
originalErr: errors.New("SQL error"),
|
||
code: CodeDatabaseError,
|
||
message: "用户表查询失败",
|
||
expectedMessage: "用户表查询失败: SQL error",
|
||
},
|
||
{
|
||
name: "包装 nil 错误",
|
||
originalErr: nil,
|
||
code: CodeInternalError,
|
||
message: "意外的 nil 错误",
|
||
expectedMessage: "意外的 nil 错误",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
err := Wrap(tt.code, tt.message, tt.originalErr)
|
||
|
||
if err.Error() != tt.expectedMessage {
|
||
t.Errorf("Wrap().Error() = %q, expected %q", err.Error(), tt.expectedMessage)
|
||
}
|
||
|
||
if err.Code != tt.code {
|
||
t.Errorf("Wrap().Code = %d, expected %d", err.Code, tt.code)
|
||
}
|
||
|
||
if tt.originalErr != nil {
|
||
unwrapped := err.Unwrap()
|
||
if unwrapped != tt.originalErr {
|
||
t.Errorf("Wrap().Unwrap() = %v, expected %v", unwrapped, tt.originalErr)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestErrorMessageSanitization 测试错误消息脱敏
|
||
func TestErrorMessageSanitization(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
code int
|
||
message string
|
||
shouldBeSanitized bool
|
||
expectedForClient string
|
||
}{
|
||
{
|
||
name: "客户端错误保留消息",
|
||
code: CodeInvalidParam,
|
||
message: "用户名长度必须在 3-20 之间",
|
||
shouldBeSanitized: false,
|
||
expectedForClient: "用户名长度必须在 3-20 之间",
|
||
},
|
||
{
|
||
name: "服务端错误脱敏",
|
||
code: CodeDatabaseError,
|
||
message: "pq: relation 'users' does not exist",
|
||
shouldBeSanitized: true,
|
||
expectedForClient: "数据库错误", // 应该返回通用消息
|
||
},
|
||
{
|
||
name: "内部错误脱敏",
|
||
code: CodeInternalError,
|
||
message: "panic: runtime error: invalid memory address",
|
||
shouldBeSanitized: true,
|
||
expectedForClient: "内部服务器错误",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
// 这个测试逻辑应该在 handler.go 的 handleError 中实现
|
||
// 这里仅验证逻辑概念
|
||
|
||
var clientMessage string
|
||
if tt.shouldBeSanitized {
|
||
// 服务端错误使用默认消息
|
||
clientMessage = GetMessage(tt.code, "zh-CN")
|
||
} else {
|
||
// 客户端错误保留原始消息
|
||
clientMessage = tt.message
|
||
}
|
||
|
||
if clientMessage != tt.expectedForClient {
|
||
t.Errorf("Client message = %q, expected %q", clientMessage, tt.expectedForClient)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestConcurrentErrorHandling 测试并发场景下的错误处理
|
||
func TestConcurrentErrorHandling(t *testing.T) {
|
||
logger, _ := zap.NewProduction()
|
||
defer func() { _ = logger.Sync() }()
|
||
handler := SafeErrorHandler(logger)
|
||
if handler == nil {
|
||
t.Fatal("SafeErrorHandler returned nil")
|
||
}
|
||
|
||
// 并发创建错误
|
||
errChan := make(chan error, 100)
|
||
for i := 0; i < 100; i++ {
|
||
go func(idx int) {
|
||
code := CodeInvalidParam
|
||
if idx%2 == 0 {
|
||
code = CodeDatabaseError
|
||
}
|
||
errChan <- New(code, fmt.Sprintf("错误 #%d", idx))
|
||
}(i)
|
||
}
|
||
|
||
// 验证所有错误都能正确创建
|
||
for i := 0; i < 100; i++ {
|
||
err := <-errChan
|
||
if err == nil {
|
||
t.Errorf("Goroutine %d returned nil error", i)
|
||
}
|
||
}
|
||
}
|