主要功能: - 实现完整的 RBAC 权限系统(账号、角色、权限的多对多关联) - 基于 owner_id + shop_id 的自动数据权限过滤 - 使用 PostgreSQL WITH RECURSIVE 查询下级账号 - Redis 缓存优化下级账号查询性能(30分钟过期) - 支持多租户数据隔离和层级权限管理 技术实现: - 新增 Account、Role、Permission 模型及关联关系表 - 实现 GORM Scopes 自动应用数据权限过滤 - 添加数据库迁移脚本(000002_rbac_data_permission、000003_add_owner_id_shop_id) - 完善错误码定义(1010-1027 为 RBAC 相关错误) - 重构 main.go 采用函数拆分提高可读性 测试覆盖: - 添加 Account、Role、Permission 的集成测试 - 添加数据权限过滤的单元测试和集成测试 - 添加下级账号查询和缓存的单元测试 - 添加 API 回归测试确保向后兼容 文档更新: - 更新 README.md 添加 RBAC 功能说明 - 更新 CLAUDE.md 添加技术栈和开发原则 - 添加 docs/004-rbac-data-permission/ 功能总结和使用指南 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
478 lines
12 KiB
Go
478 lines
12 KiB
Go
package response
|
||
|
||
import (
|
||
"io"
|
||
"net/http/httptest"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/break/junhong_cmp_fiber/pkg/errors"
|
||
"github.com/bytedance/sonic"
|
||
"github.com/gofiber/fiber/v2"
|
||
)
|
||
|
||
// TestSuccess 测试成功响应(T034)
|
||
func TestSuccess(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
data any
|
||
}{
|
||
{
|
||
name: "success with string data",
|
||
data: "test data",
|
||
},
|
||
{
|
||
name: "success with map data",
|
||
data: map[string]any{
|
||
"id": 123,
|
||
"name": "test",
|
||
},
|
||
},
|
||
{
|
||
name: "success with slice data",
|
||
data: []string{"item1", "item2", "item3"},
|
||
},
|
||
{
|
||
name: "success with struct data",
|
||
data: struct {
|
||
ID int `json:"id"`
|
||
Name string `json:"name"`
|
||
}{
|
||
ID: 456,
|
||
Name: "test struct",
|
||
},
|
||
},
|
||
{
|
||
name: "success with nil data",
|
||
data: nil,
|
||
},
|
||
{
|
||
name: "success with empty map",
|
||
data: map[string]any{},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
app := fiber.New()
|
||
app.Get("/test", func(c *fiber.Ctx) error {
|
||
return Success(c, tt.data)
|
||
})
|
||
|
||
req := httptest.NewRequest("GET", "/test", nil)
|
||
resp, err := app.Test(req)
|
||
if err != nil {
|
||
t.Fatalf("Failed to execute request: %v", err)
|
||
}
|
||
defer func() { _ = resp.Body.Close() }()
|
||
|
||
// 验证 HTTP 状态码
|
||
if resp.StatusCode != 200 {
|
||
t.Errorf("Expected status code 200, got %d", resp.StatusCode)
|
||
}
|
||
|
||
// 验证响应头
|
||
if resp.Header.Get("Content-Type") != "application/json" {
|
||
t.Errorf("Expected Content-Type application/json, got %s", resp.Header.Get("Content-Type"))
|
||
}
|
||
|
||
// 解析响应体
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
t.Fatalf("Failed to read response body: %v", err)
|
||
}
|
||
|
||
var response Response
|
||
if err := sonic.Unmarshal(body, &response); err != nil {
|
||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||
}
|
||
|
||
// 验证响应结构
|
||
if response.Code != errors.CodeSuccess {
|
||
t.Errorf("Expected code %d, got %d", errors.CodeSuccess, response.Code)
|
||
}
|
||
|
||
if response.Message != "success" {
|
||
t.Errorf("Expected message 'success', got '%s'", response.Message)
|
||
}
|
||
|
||
// 验证时间戳格式 RFC3339
|
||
if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil {
|
||
t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp)
|
||
}
|
||
|
||
// 验证数据字段(如果不是 nil)
|
||
if tt.data != nil {
|
||
if response.Data == nil {
|
||
t.Error("Expected data field to be non-nil")
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestError 测试错误响应(T035)
|
||
func TestError(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
httpStatus int
|
||
code int
|
||
message string
|
||
}{
|
||
{
|
||
name: "internal server error",
|
||
httpStatus: 500,
|
||
code: errors.CodeInternalError,
|
||
message: "Internal server error occurred",
|
||
},
|
||
{
|
||
name: "missing token error",
|
||
httpStatus: 401,
|
||
code: errors.CodeMissingToken,
|
||
message: "Authentication token is missing",
|
||
},
|
||
{
|
||
name: "invalid token error",
|
||
httpStatus: 401,
|
||
code: errors.CodeInvalidToken,
|
||
message: "Token is invalid or expired",
|
||
},
|
||
{
|
||
name: "rate limit error",
|
||
httpStatus: 429,
|
||
code: errors.CodeTooManyRequests,
|
||
message: "Too many requests, please try again later",
|
||
},
|
||
{
|
||
name: "service unavailable error",
|
||
httpStatus: 503,
|
||
code: errors.CodeAuthServiceUnavailable,
|
||
message: "Authentication service is currently unavailable",
|
||
},
|
||
{
|
||
name: "bad request error",
|
||
httpStatus: 400,
|
||
code: 2000,
|
||
message: "Invalid request parameters",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
app := fiber.New()
|
||
app.Get("/test", func(c *fiber.Ctx) error {
|
||
return Error(c, tt.httpStatus, tt.code, tt.message)
|
||
})
|
||
|
||
req := httptest.NewRequest("GET", "/test", nil)
|
||
resp, err := app.Test(req)
|
||
if err != nil {
|
||
t.Fatalf("Failed to execute request: %v", err)
|
||
}
|
||
defer func() { _ = resp.Body.Close() }()
|
||
|
||
// 验证 HTTP 状态码
|
||
if resp.StatusCode != tt.httpStatus {
|
||
t.Errorf("Expected status code %d, got %d", tt.httpStatus, resp.StatusCode)
|
||
}
|
||
|
||
// 验证响应头
|
||
if resp.Header.Get("Content-Type") != "application/json" {
|
||
t.Errorf("Expected Content-Type application/json, got %s", resp.Header.Get("Content-Type"))
|
||
}
|
||
|
||
// 解析响应体
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
t.Fatalf("Failed to read response body: %v", err)
|
||
}
|
||
|
||
var response Response
|
||
if err := sonic.Unmarshal(body, &response); err != nil {
|
||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||
}
|
||
|
||
// 验证响应结构
|
||
if response.Code != tt.code {
|
||
t.Errorf("Expected code %d, got %d", tt.code, response.Code)
|
||
}
|
||
|
||
if response.Message != tt.message {
|
||
t.Errorf("Expected message '%s', got '%s'", tt.message, response.Message)
|
||
}
|
||
|
||
if response.Data != nil {
|
||
t.Errorf("Expected data to be nil in error response, got %v", response.Data)
|
||
}
|
||
|
||
// 验证时间戳格式 RFC3339
|
||
if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil {
|
||
t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestSuccessWithMessage 测试带自定义消息的成功响应(T034)
|
||
func TestSuccessWithMessage(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
data any
|
||
message string
|
||
}{
|
||
{
|
||
name: "custom success message",
|
||
data: map[string]any{
|
||
"user_id": 123,
|
||
},
|
||
message: "User created successfully",
|
||
},
|
||
{
|
||
name: "empty custom message",
|
||
data: "test data",
|
||
message: "",
|
||
},
|
||
{
|
||
name: "chinese message",
|
||
data: map[string]string{
|
||
"status": "ok",
|
||
},
|
||
message: "操作成功",
|
||
},
|
||
{
|
||
name: "long message",
|
||
data: nil,
|
||
message: "This is a very long success message that describes in detail what happened during the operation",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
app := fiber.New()
|
||
app.Get("/test", func(c *fiber.Ctx) error {
|
||
return SuccessWithMessage(c, tt.data, tt.message)
|
||
})
|
||
|
||
req := httptest.NewRequest("GET", "/test", nil)
|
||
resp, err := app.Test(req)
|
||
if err != nil {
|
||
t.Fatalf("Failed to execute request: %v", err)
|
||
}
|
||
defer func() { _ = resp.Body.Close() }()
|
||
|
||
// 验证 HTTP 状态码(默认 200)
|
||
if resp.StatusCode != 200 {
|
||
t.Errorf("Expected status code 200, got %d", resp.StatusCode)
|
||
}
|
||
|
||
// 解析响应体
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
t.Fatalf("Failed to read response body: %v", err)
|
||
}
|
||
|
||
var response Response
|
||
if err := sonic.Unmarshal(body, &response); err != nil {
|
||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||
}
|
||
|
||
// 验证响应结构
|
||
if response.Code != errors.CodeSuccess {
|
||
t.Errorf("Expected code %d, got %d", errors.CodeSuccess, response.Code)
|
||
}
|
||
|
||
if response.Message != tt.message {
|
||
t.Errorf("Expected message '%s', got '%s'", tt.message, response.Message)
|
||
}
|
||
|
||
// 验证时间戳格式 RFC3339
|
||
if _, err := time.Parse(time.RFC3339, response.Timestamp); err != nil {
|
||
t.Errorf("Timestamp is not in RFC3339 format: %s", response.Timestamp)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestResponseSerialization 测试响应序列化(T036)
|
||
func TestResponseSerialization(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
response Response
|
||
}{
|
||
{
|
||
name: "complete response",
|
||
response: Response{
|
||
Code: 0,
|
||
Data: map[string]any{"key": "value"},
|
||
Message: "success",
|
||
Timestamp: time.Now().Format(time.RFC3339),
|
||
},
|
||
},
|
||
{
|
||
name: "response with nil data",
|
||
response: Response{
|
||
Code: 1000,
|
||
Data: nil,
|
||
Message: "error",
|
||
Timestamp: time.Now().Format(time.RFC3339),
|
||
},
|
||
},
|
||
{
|
||
name: "response with nested data",
|
||
response: Response{
|
||
Code: 0,
|
||
Data: map[string]any{
|
||
"user": map[string]any{
|
||
"id": 123,
|
||
"name": "test",
|
||
"tags": []string{"tag1", "tag2"},
|
||
},
|
||
},
|
||
Message: "success",
|
||
Timestamp: time.Now().Format(time.RFC3339),
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
// 序列化
|
||
data, err := sonic.Marshal(tt.response)
|
||
if err != nil {
|
||
t.Fatalf("Failed to marshal response: %v", err)
|
||
}
|
||
|
||
// 反序列化
|
||
var deserialized Response
|
||
if err := sonic.Unmarshal(data, &deserialized); err != nil {
|
||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||
}
|
||
|
||
// 验证字段
|
||
if deserialized.Code != tt.response.Code {
|
||
t.Errorf("Code mismatch: expected %d, got %d", tt.response.Code, deserialized.Code)
|
||
}
|
||
|
||
if deserialized.Message != tt.response.Message {
|
||
t.Errorf("Message mismatch: expected '%s', got '%s'", tt.response.Message, deserialized.Message)
|
||
}
|
||
|
||
if deserialized.Timestamp != tt.response.Timestamp {
|
||
t.Errorf("Timestamp mismatch: expected '%s', got '%s'", tt.response.Timestamp, deserialized.Timestamp)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestResponseStructFields 测试响应结构字段(T036)
|
||
func TestResponseStructFields(t *testing.T) {
|
||
response := Response{
|
||
Code: 0,
|
||
Data: "test",
|
||
Message: "success",
|
||
Timestamp: time.Now().Format(time.RFC3339),
|
||
}
|
||
|
||
data, err := sonic.Marshal(response)
|
||
if err != nil {
|
||
t.Fatalf("Failed to marshal response: %v", err)
|
||
}
|
||
|
||
// 解析为 map 以检查 JSON 键
|
||
var jsonMap map[string]any
|
||
if err := sonic.Unmarshal(data, &jsonMap); err != nil {
|
||
t.Fatalf("Failed to unmarshal to map: %v", err)
|
||
}
|
||
|
||
// 验证所有必需字段都存在
|
||
requiredFields := []string{"code", "data", "msg", "timestamp"}
|
||
for _, field := range requiredFields {
|
||
if _, exists := jsonMap[field]; !exists {
|
||
t.Errorf("Required field '%s' is missing in JSON response", field)
|
||
}
|
||
}
|
||
|
||
// 验证字段类型
|
||
if _, ok := jsonMap["code"].(float64); !ok {
|
||
t.Error("Field 'code' should be a number")
|
||
}
|
||
|
||
if _, ok := jsonMap["msg"].(string); !ok {
|
||
t.Error("Field 'msg' should be a string")
|
||
}
|
||
|
||
if _, ok := jsonMap["timestamp"].(string); !ok {
|
||
t.Error("Field 'timestamp' should be a string")
|
||
}
|
||
}
|
||
|
||
// TestMultipleResponses 测试多个连续响应(T036)
|
||
func TestMultipleResponses(t *testing.T) {
|
||
app := fiber.New()
|
||
|
||
callCount := 0
|
||
app.Get("/test", func(c *fiber.Ctx) error {
|
||
callCount++
|
||
if callCount%2 == 0 {
|
||
return Success(c, map[string]int{"count": callCount})
|
||
}
|
||
return Error(c, 500, errors.CodeInternalError, "error occurred")
|
||
})
|
||
|
||
// 发送多个请求
|
||
for i := 1; i <= 5; i++ {
|
||
req := httptest.NewRequest("GET", "/test", nil)
|
||
resp, err := app.Test(req)
|
||
if err != nil {
|
||
t.Fatalf("Request %d failed: %v", i, err)
|
||
}
|
||
|
||
body, _ := io.ReadAll(resp.Body)
|
||
_ = resp.Body.Close()
|
||
|
||
var response Response
|
||
if err := sonic.Unmarshal(body, &response); err != nil {
|
||
t.Fatalf("Request %d: failed to unmarshal response: %v", i, err)
|
||
}
|
||
|
||
// 验证每个响应都有时间戳
|
||
if response.Timestamp == "" {
|
||
t.Errorf("Request %d: timestamp should not be empty", i)
|
||
}
|
||
}
|
||
}
|
||
|
||
// TestTimestampFormat 测试时间戳格式(T036)
|
||
func TestTimestampFormat(t *testing.T) {
|
||
app := fiber.New()
|
||
app.Get("/test", func(c *fiber.Ctx) error {
|
||
return Success(c, nil)
|
||
})
|
||
|
||
req := httptest.NewRequest("GET", "/test", nil)
|
||
resp, err := app.Test(req)
|
||
if err != nil {
|
||
t.Fatalf("Failed to execute request: %v", err)
|
||
}
|
||
defer func() { _ = resp.Body.Close() }()
|
||
|
||
body, _ := io.ReadAll(resp.Body)
|
||
var response Response
|
||
if err := sonic.Unmarshal(body, &response); err != nil {
|
||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||
}
|
||
|
||
// 验证是 RFC3339 格式
|
||
parsedTime, err := time.Parse(time.RFC3339, response.Timestamp)
|
||
if err != nil {
|
||
t.Fatalf("Timestamp is not in RFC3339 format: %s, error: %v", response.Timestamp, err)
|
||
}
|
||
|
||
// 验证时间戳是最近的(应该在最近 1 秒内)
|
||
now := time.Now()
|
||
diff := now.Sub(parsedTime)
|
||
if diff < 0 || diff > time.Second {
|
||
t.Errorf("Timestamp seems incorrect: %s (diff from now: %v)", response.Timestamp, diff)
|
||
}
|
||
}
|