Files
junhong_cmp_fiber/tests/integration/auth_test.go
huang d66323487b refactor: align framework cleanup with new bootstrap flow
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-11-19 12:47:25 +08:00

444 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package integration
import (
"context"
"io"
"net/http/httptest"
"testing"
"time"
"github.com/break/junhong_cmp_fiber/pkg/constants"
"github.com/break/junhong_cmp_fiber/pkg/errors"
"github.com/break/junhong_cmp_fiber/pkg/logger"
"github.com/break/junhong_cmp_fiber/pkg/middleware"
"github.com/break/junhong_cmp_fiber/pkg/response"
"github.com/break/junhong_cmp_fiber/pkg/validator"
"github.com/gofiber/fiber/v2"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// setupAuthTestApp assembles the Fiber application shared by the auth
// integration tests. It initializes the loggers, installs a middleware that
// pins a fixed request ID, installs the auth middleware backed by a Redis
// token validator, and registers a single protected route at /api/v1/test.
// The calling test is failed immediately if the loggers cannot be initialized.
func setupAuthTestApp(t *testing.T, rdb *redis.Client) *fiber.App {
	t.Helper()

	// The app and access loggers use identical rotation settings; only the
	// target file differs.
	rotation := func(path string) logger.LogRotationConfig {
		return logger.LogRotationConfig{
			Filename:   path,
			MaxSize:    10,
			MaxBackups: 3,
			MaxAge:     7,
			Compress:   false,
		}
	}
	initErr := logger.InitLoggers("info", false, rotation("logs/app_test.log"), rotation("logs/access_test.log"))
	if initErr != nil {
		t.Fatalf("failed to initialize logger: %v", initErr)
	}

	app := fiber.New()

	// Stamp every request with a fixed request ID.
	app.Use(func(c *fiber.Ctx) error {
		c.Locals(constants.ContextKeyRequestID, "test-request-id-123")
		return c.Next()
	})

	// Authentication: the token is checked against Redis, but the value the
	// validator returns is discarded.
	redisValidator := validator.NewTokenValidator(rdb, logger.GetAppLogger())
	checkToken := func(token string) (uint, int, uint, error) {
		if _, err := redisValidator.Validate(token); err != nil {
			return 0, 0, 0, err
		}
		// Simplified for tests: userID is fixed to 1 and userType to a regular
		// user. NOTE(review): the meaning of the third return value is not
		// visible here — confirm against middleware.AuthConfig.
		return 1, 0, 0, nil
	}
	app.Use(middleware.Auth(middleware.AuthConfig{TokenValidator: checkToken}))

	// Protected route that echoes back whatever is stored under the user ID
	// context key.
	app.Get("/api/v1/test", func(c *fiber.Ctx) error {
		payload := fiber.Map{
			"message": "protected resource",
			"user_id": c.Locals(constants.ContextKeyUserID),
		}
		return response.Success(c, payload)
	})

	// Note: user routes have moved to instance methods, so the test route above
	// is sufficient for these integration tests. Tests of the real user routes
	// should go through the full initialization in cmd/api/main.go.
	return app
}
// TestKeyAuthMiddleware_ValidToken verifies that a request carrying a token
// that exists in Redis passes the auth middleware and reaches the protected
// route. The test is skipped when Redis does not answer a ping on
// localhost:6379, and it flushes Redis DB 1 when it finishes.
func TestKeyAuthMiddleware_ValidToken(t *testing.T) {
	// Setup Redis client
	rdb := redis.NewClient(&redis.Options{
		Addr: "localhost:6379",
		DB:   1, // Use test database
	})
	defer func() { _ = rdb.Close() }()

	// Check Redis availability
	ctx := context.Background()
	if err := rdb.Ping(ctx).Err(); err != nil {
		t.Skip("Redis not available, skipping integration test")
	}

	// Clean up test data
	defer rdb.FlushDB(ctx)

	// Register the token in Redis so the validator accepts it.
	testToken := "test-valid-token-12345"
	testUserID := "user-789"
	err := rdb.Set(ctx, constants.RedisAuthTokenKey(testToken), testUserID, 1*time.Hour).Err()
	require.NoError(t, err, "Failed to set test token in Redis")

	// Create test app
	app := setupAuthTestApp(t, rdb)

	// Create request with valid token
	req := httptest.NewRequest("GET", "/api/v1/test", nil)
	req.Header.Set("token", testToken)

	// Execute request
	resp, err := app.Test(req, -1)
	require.NoError(t, err)
	defer func() { _ = resp.Body.Close() }()

	// Assertions
	assert.Equal(t, 200, resp.StatusCode, "Expected status 200 for valid token")

	// Parse response body
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	got := string(body)
	t.Logf("Response body: %s", got)

	// The stub validator in setupAuthTestApp discards the value stored in Redis
	// and always reports user ID 1, so the route echoes 1 rather than the
	// string stored under the token key.
	assert.Contains(t, got, `"user_id":1`, "Response should contain the user ID set by the auth middleware")
	assert.Contains(t, got, `"code":0`, "Response should have success code")
}
// TestKeyAuthMiddleware_MissingToken checks that a request sent without a
// token header is rejected with HTTP 401, error code 1001 and the
// missing-token message. The test is skipped when Redis does not answer a
// ping on localhost:6379.
func TestKeyAuthMiddleware_MissingToken(t *testing.T) {
	client := redis.NewClient(&redis.Options{
		Addr: "localhost:6379",
		DB:   1,
	})
	defer func() { _ = client.Close() }()

	// Skip rather than fail when no Redis server is reachable.
	ctx := context.Background()
	if pingErr := client.Ping(ctx).Err(); pingErr != nil {
		t.Skip("Redis not available, skipping integration test")
	}

	app := setupAuthTestApp(t, client)

	// No token header is set on this request.
	resp, err := app.Test(httptest.NewRequest("GET", "/api/v1/test", nil), -1)
	require.NoError(t, err)
	defer func() { _ = resp.Body.Close() }()

	assert.Equal(t, 401, resp.StatusCode, "Expected status 401 for missing token")

	raw, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	got := string(raw)
	t.Logf("Response body: %s", got)

	// Error code 1001 with the Chinese message for "missing authentication token".
	assert.Contains(t, got, `"code":1001`, "Response should have missing token error code")
	assert.Contains(t, got, "缺失认证令牌", "Response should have missing token message")
}
// TestKeyAuthMiddleware_InvalidToken checks that a token which was never
// stored in Redis is rejected with HTTP 401, error code 1002 and the
// invalid-token message. The test is skipped when Redis does not answer a
// ping on localhost:6379, and it flushes Redis DB 1 when it finishes.
func TestKeyAuthMiddleware_InvalidToken(t *testing.T) {
	client := redis.NewClient(&redis.Options{
		Addr: "localhost:6379",
		DB:   1,
	})
	defer func() { _ = client.Close() }()

	// Skip rather than fail when no Redis server is reachable.
	ctx := context.Background()
	if pingErr := client.Ping(ctx).Err(); pingErr != nil {
		t.Skip("Redis not available, skipping integration test")
	}

	// Clean up test data
	defer client.FlushDB(ctx)

	app := setupAuthTestApp(t, client)

	// This token is deliberately absent from Redis.
	request := httptest.NewRequest("GET", "/api/v1/test", nil)
	request.Header.Set("token", "invalid-token-xyz")

	resp, err := app.Test(request, -1)
	require.NoError(t, err)
	defer func() { _ = resp.Body.Close() }()

	assert.Equal(t, 401, resp.StatusCode, "Expected status 401 for invalid token")

	raw, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	got := string(raw)
	t.Logf("Response body: %s", got)

	// Error code 1002 with the Chinese message for "token invalid or expired".
	assert.Contains(t, got, `"code":1002`, "Response should have invalid token error code")
	assert.Contains(t, got, "令牌无效或已过期", "Response should have invalid token message")
}
// TestKeyAuthMiddleware_ExpiredToken stores a token with a one-second TTL,
// waits two seconds, and then checks that the request is rejected with HTTP
// 401 and error code 1002. The test is skipped when Redis does not answer a
// ping on localhost:6379, and it flushes Redis DB 1 when it finishes.
func TestKeyAuthMiddleware_ExpiredToken(t *testing.T) {
	client := redis.NewClient(&redis.Options{
		Addr: "localhost:6379",
		DB:   1,
	})
	defer func() { _ = client.Close() }()

	// Skip rather than fail when no Redis server is reachable.
	ctx := context.Background()
	if pingErr := client.Ping(ctx).Err(); pingErr != nil {
		t.Skip("Redis not available, skipping integration test")
	}

	// Clean up test data
	defer client.FlushDB(ctx)

	// Store the token with a short TTL, then wait past it.
	const (
		expiringToken = "test-expired-token-999"
		storedUserID  = "user-999"
	)
	setErr := client.Set(ctx, constants.RedisAuthTokenKey(expiringToken), storedUserID, 1*time.Second).Err()
	require.NoError(t, setErr, "Failed to set test token in Redis")
	time.Sleep(2 * time.Second)

	app := setupAuthTestApp(t, client)

	request := httptest.NewRequest("GET", "/api/v1/test", nil)
	request.Header.Set("token", expiringToken)

	resp, err := app.Test(request, -1)
	require.NoError(t, err)
	defer func() { _ = resp.Body.Close() }()

	assert.Equal(t, 401, resp.StatusCode, "Expected status 401 for expired token")

	raw, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	got := string(raw)
	t.Logf("Response body: %s", got)

	// An expired token is expected to be reported with the invalid-token code.
	assert.Contains(t, got, `"code":1002`, "Response should have invalid token error code")
}
// TestKeyAuthMiddleware_RedisDown checks the fail-closed path: with a Redis
// client pointed at localhost:9999 (where no server is expected), any token
// must be answered with HTTP 503, error code 1004 and the
// service-unavailable message.
func TestKeyAuthMiddleware_RedisDown(t *testing.T) {
	// Short timeouts keep the failed connection attempt quick.
	unreachable := redis.NewClient(&redis.Options{
		Addr:        "localhost:9999", // Invalid port
		DialTimeout: 100 * time.Millisecond,
		ReadTimeout: 100 * time.Millisecond,
	})
	defer func() { _ = unreachable.Close() }()

	app := setupAuthTestApp(t, unreachable)

	// The token value is irrelevant because Redis cannot be consulted.
	request := httptest.NewRequest("GET", "/api/v1/test", nil)
	request.Header.Set("token", "any-token")

	resp, err := app.Test(request, -1)
	require.NoError(t, err)
	defer func() { _ = resp.Body.Close() }()

	// Fail closed: the request must not be let through.
	assert.Equal(t, 503, resp.StatusCode, "Expected status 503 when Redis is unavailable")

	raw, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	got := string(raw)
	t.Logf("Response body: %s", got)

	// Error code 1004 with the Chinese message for "authentication service unavailable".
	assert.Contains(t, got, `"code":1004`, "Response should have service unavailable error code")
	assert.Contains(t, got, "认证服务不可用", "Response should have service unavailable message")
}
// TestKeyAuthMiddleware_UserIDPropagation verifies that the user ID reported
// by the token validator is available to downstream handlers as a uint under
// constants.ContextKeyUserID. The test is skipped when Redis does not answer a
// ping on localhost:6379, and it flushes Redis DB 1 when it finishes.
func TestKeyAuthMiddleware_UserIDPropagation(t *testing.T) {
	// stubUserID is the ID the stub validator below reports for every valid
	// token, and therefore the value the handler is expected to observe.
	const stubUserID uint = 1

	// Setup Redis client
	rdb := redis.NewClient(&redis.Options{
		Addr: "localhost:6379",
		DB:   1,
	})
	defer func() { _ = rdb.Close() }()

	// Check Redis availability
	ctx := context.Background()
	if err := rdb.Ping(ctx).Err(); err != nil {
		t.Skip("Redis not available, skipping integration test")
	}

	// Clean up test data
	defer rdb.FlushDB(ctx)

	// Register the token in Redis. The stored value only has to exist; the
	// stub validator does not forward it.
	testToken := "test-propagation-token"
	testUserID := "user-propagation-123"
	err := rdb.Set(ctx, constants.RedisAuthTokenKey(testToken), testUserID, 1*time.Hour).Err()
	require.NoError(t, err)

	// Initialize logger
	appLogConfig := logger.LogRotationConfig{
		Filename:   "logs/app_test.log",
		MaxSize:    10,
		MaxBackups: 3,
		MaxAge:     7,
		Compress:   false,
	}
	accessLogConfig := logger.LogRotationConfig{
		Filename:   "logs/access_test.log",
		MaxSize:    10,
		MaxBackups: 3,
		MaxAge:     7,
		Compress:   false,
	}
	if err := logger.InitLoggers("info", false, appLogConfig, accessLogConfig); err != nil {
		t.Fatalf("failed to initialize logger: %v", err)
	}

	app := fiber.New()

	// Add request ID middleware
	app.Use(func(c *fiber.Ctx) error {
		c.Locals(constants.ContextKeyRequestID, "test-request-id")
		return c.Next()
	})

	// Add authentication middleware
	tokenValidator := validator.NewTokenValidator(rdb, logger.GetAppLogger())
	app.Use(middleware.Auth(middleware.AuthConfig{
		TokenValidator: func(token string) (uint, int, uint, error) {
			if _, err := tokenValidator.Validate(token); err != nil {
				return 0, 0, 0, err
			}
			// Simplified for tests: every valid token maps to stubUserID with
			// userType set to a regular user.
			return stubUserID, 0, 0, nil
		},
	}))

	// Test route that records the user ID it finds in the request context.
	var capturedUserID uint
	app.Get("/api/v1/check-user", func(c *fiber.Ctx) error {
		userID, ok := c.Locals(constants.ContextKeyUserID).(uint)
		if !ok {
			return errors.New(errors.CodeInternalError, "User ID not found in context")
		}
		capturedUserID = userID
		return response.Success(c, fiber.Map{
			"user_id": userID,
		})
	})

	// Create request
	req := httptest.NewRequest("GET", "/api/v1/check-user", nil)
	req.Header.Set("token", testToken)

	// Execute request
	resp, err := app.Test(req, -1)
	require.NoError(t, err)
	defer func() { _ = resp.Body.Close() }()

	// Assertions. The expected value must have the same type (uint) as the
	// captured one: assert.Equal never treats a string and a uint as equal.
	assert.Equal(t, 200, resp.StatusCode)
	assert.Equal(t, stubUserID, capturedUserID, "User ID should be propagated to handler")
}
// TestKeyAuthMiddleware_MultipleRequests stores several tokens in Redis and
// verifies, one subtest per token, that each of them is accepted by the same
// app instance. The test is skipped when Redis does not answer a ping on
// localhost:6379, and it flushes Redis DB 1 when it finishes.
func TestKeyAuthMiddleware_MultipleRequests(t *testing.T) {
	// Setup Redis client
	rdb := redis.NewClient(&redis.Options{
		Addr: "localhost:6379",
		DB:   1,
	})
	defer func() { _ = rdb.Close() }()

	// Check Redis availability
	ctx := context.Background()
	if err := rdb.Ping(ctx).Err(); err != nil {
		t.Skip("Redis not available, skipping integration test")
	}

	// Clean up test data
	defer rdb.FlushDB(ctx)

	// Token -> value stored in Redis. The stored value is used only to label
	// the subtests.
	tokens := map[string]string{
		"token-user-1": "user-001",
		"token-user-2": "user-002",
		"token-user-3": "user-003",
	}
	for token, userID := range tokens {
		err := rdb.Set(ctx, constants.RedisAuthTokenKey(token), userID, 1*time.Hour).Err()
		require.NoError(t, err)
	}

	// Create test app
	app := setupAuthTestApp(t, rdb)

	// Test each token
	for token, storedUserID := range tokens {
		t.Run("token_"+storedUserID, func(t *testing.T) {
			req := httptest.NewRequest("GET", "/api/v1/test", nil)
			req.Header.Set("token", token)

			resp, err := app.Test(req, -1)
			require.NoError(t, err)
			defer func() { _ = resp.Body.Close() }()

			assert.Equal(t, 200, resp.StatusCode)

			body, err := io.ReadAll(resp.Body)
			require.NoError(t, err)

			// The stub validator in setupAuthTestApp discards the value stored
			// in Redis and always reports user ID 1, so every token yields the
			// same user_id in the response.
			assert.Contains(t, string(body), `"user_id":1`)
		})
	}
}