package integration

import (
	"fmt"
	"io"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/break/junhong_cmp_fiber/internal/middleware"
	"github.com/break/junhong_cmp_fiber/pkg/logger"
	"github.com/break/junhong_cmp_fiber/pkg/response"
	"github.com/gofiber/fiber/v2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// setupRateLimiterTestApp builds a Fiber app guarded by middleware.RateLimiter
// and exposing a single GET /api/v1/test route that replies via response.Success.
//
// limit is the request budget passed to the limiter and expiration its window.
// The loggers are (re)initialised on every call, writing to logs/app_test.log and
// logs/access_test.log relative to the working directory. The limiter is created
// with nil storage (in-memory), so each call presumably yields an app with its own
// independent limiter state — NOTE(review): confirm against middleware.RateLimiter.
func setupRateLimiterTestApp(t *testing.T, limit int, expiration time.Duration) *fiber.App {
	t.Helper()

	appLogConfig := logger.LogRotationConfig{
		Filename:   "logs/app_test.log",
		MaxSize:    10,
		MaxBackups: 3,
		MaxAge:     7,
		Compress:   false,
	}
	accessLogConfig := logger.LogRotationConfig{
		Filename:   "logs/access_test.log",
		MaxSize:    10,
		MaxBackups: 3,
		MaxAge:     7,
		Compress:   false,
	}
	if err := logger.InitLoggers("info", false, appLogConfig, accessLogConfig); err != nil {
		t.Fatalf("failed to initialize logger: %v", err)
	}

	app := fiber.New()

	// nil storage = in-memory limiter state.
	app.Use(middleware.RateLimiter(limit, expiration, nil))

	app.Get("/api/v1/test", func(c *fiber.Ctx) error {
		return response.Success(c, fiber.Map{
			"message": "success",
		})
	})

	return app
}

// rateLimiterTestRequest sends GET /api/v1/test to app with the X-Forwarded-For
// header set to ip, reads and closes the response body, and returns the status
// code together with the body text.
//
// It uses require (t.FailNow) on transport errors, so it must only be called from
// the test goroutine, never from goroutines spawned by the test.
//
// NOTE(review): fiber.New() above uses the default config (no ProxyHeader), so
// whether the limiter actually keys on X-Forwarded-For depends on
// middleware.RateLimiter's key generator — confirm before relying on per-IP keys.
func rateLimiterTestRequest(t *testing.T, app *fiber.App, ip string) (int, string) {
	t.Helper()

	req := httptest.NewRequest(fiber.MethodGet, "/api/v1/test", nil)
	req.Header.Set("X-Forwarded-For", ip)

	resp, err := app.Test(req, -1) // -1 disables app.Test's timeout
	require.NoError(t, err)
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)

	return resp.StatusCode, string(body)
}

// TestRateLimiter_LimitExceeded verifies that once a client has used its whole
// request budget, the next request is rejected with 429 and the project's
// "too many requests" error payload.
func TestRateLimiter_LimitExceeded(t *testing.T) {
	// A low limit keeps the test quick.
	limit := 5
	expiration := time.Minute
	app := setupRateLimiterTestApp(t, limit, expiration)

	const ip = "192.168.1.100" // every request simulates the same client

	for i := 1; i <= limit; i++ {
		status, _ := rateLimiterTestRequest(t, app, ip)
		assert.Equal(t, fiber.StatusOK, status, "Request %d should succeed", i)
	}

	// The request after the budget is spent must be rejected.
	status, body := rateLimiterTestRequest(t, app, ip)
	assert.Equal(t, fiber.StatusTooManyRequests, status, "Request should be rate limited")

	t.Logf("Rate limit response: %s", body)

	// Error code 1003 is the project's "too many requests" code.
	assert.Contains(t, body, `"code":1003`, "Response should have too many requests error code")
	// The message is Chinese: "请求过于频繁" ("requests are too frequent").
	assert.Contains(t, body, "请求过于频繁", "Response should have rate limit message")
}

// TestRateLimiter_ResetAfterExpiration verifies that a rate-limited client is
// admitted again once the limiter window has elapsed.
func TestRateLimiter_ResetAfterExpiration(t *testing.T) {
	// A short window keeps the mandatory sleep small.
	limit := 3
	expiration := 2 * time.Second
	app := setupRateLimiterTestApp(t, limit, expiration)

	const ip = "192.168.1.101"

	for i := 1; i <= limit; i++ {
		status, _ := rateLimiterTestRequest(t, app, ip)
		assert.Equal(t, fiber.StatusOK, status, "Request %d should succeed", i)
	}

	status, _ := rateLimiterTestRequest(t, app, ip)
	assert.Equal(t, fiber.StatusTooManyRequests, status, "Request should be rate limited")

	// Sleep slightly past the window so the counter is guaranteed to have reset.
	t.Log("Waiting for rate limit window to reset...")
	time.Sleep(expiration + 500*time.Millisecond)

	status, body := rateLimiterTestRequest(t, app, ip)
	assert.Equal(t, fiber.StatusOK, status, "Request should succeed after rate limit reset")
	assert.Contains(t, body, `"code":0`, "Response should be successful after reset")
}

// TestRateLimiter_PerIPRateLimiting checks that each of several client IPs can
// spend a full budget and is then limited.
//
// NOTE(review): every subtest builds a fresh app, so limiter state is never shared
// between IPs; this exercises the limit for each IP in isolation but does not by
// itself prove that different IPs get separate counters within one app.
func TestRateLimiter_PerIPRateLimiting(t *testing.T) {
	limit := 5
	expiration := time.Minute

	ips := []string{
		"192.168.1.10",
		"192.168.1.20",
		"192.168.1.30",
	}

	for _, ip := range ips {
		// t.Run executes synchronously, so capturing the loop variable is safe
		// on every Go version.
		t.Run(fmt.Sprintf("IP_%s", ip), func(t *testing.T) {
			// Fresh app per IP so no limiter state leaks between subtests.
			freshApp := setupRateLimiterTestApp(t, limit, expiration)

			for i := 1; i <= limit; i++ {
				status, _ := rateLimiterTestRequest(t, freshApp, ip)
				assert.Equal(t, fiber.StatusOK, status, "IP %s request %d should succeed", ip, i)
			}

			status, _ := rateLimiterTestRequest(t, freshApp, ip)
			assert.Equal(t, fiber.StatusTooManyRequests, status, "IP %s should be rate limited", ip)
		})
	}
}

// TestRateLimiter_ConcurrentRequests fires more concurrent requests from one
// client than the limit allows and expects exactly limit of them to succeed and
// the remainder to be rejected with 429.
func TestRateLimiter_ConcurrentRequests(t *testing.T) {
	limit := 10
	expiration := time.Minute
	app := setupRateLimiterTestApp(t, limit, expiration)

	concurrentRequests := 15
	// Buffered to the request count so no sender ever blocks; receiving exactly
	// concurrentRequests values below waits for every goroutine to finish.
	results := make(chan int, concurrentRequests)

	for i := 0; i < concurrentRequests; i++ {
		go func() {
			req := httptest.NewRequest(fiber.MethodGet, "/api/v1/test", nil)
			req.Header.Set("X-Forwarded-For", "192.168.1.200")

			resp, err := app.Test(req, -1)
			if err != nil {
				// t.Errorf (unlike require / t.FailNow) may be called from a
				// non-test goroutine; status 0 is counted as neither outcome.
				t.Errorf("concurrent request failed: %v", err)
				results <- 0
				return
			}
			defer resp.Body.Close()
			results <- resp.StatusCode
		}()
	}

	var successCount, rateLimitedCount int
	for i := 0; i < concurrentRequests; i++ {
		switch <-results {
		case fiber.StatusOK:
			successCount++
		case fiber.StatusTooManyRequests:
			rateLimitedCount++
		}
	}

	t.Logf("Concurrent requests: %d success, %d rate limited", successCount, rateLimitedCount)

	assert.Equal(t, limit, successCount, "Should have exactly max successful requests")
	assert.Equal(t, concurrentRequests-limit, rateLimitedCount, "Remaining requests should be rate limited")
}

// TestRateLimiter_DifferentLimits checks, for several limit values, that one
// client can make exactly limit requests before being rejected with 429.
func TestRateLimiter_DifferentLimits(t *testing.T) {
	tests := []struct {
		name       string
		limit      int
		expiration time.Duration
	}{
		{
			name:       "low_limit",
			limit:      2,
			expiration: time.Minute,
		},
		{
			name:       "medium_limit",
			limit:      10,
			expiration: time.Minute,
		},
		{
			name:       "high_limit",
			limit:      100,
			expiration: time.Minute,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			app := setupRateLimiterTestApp(t, tt.limit, tt.expiration)

			// All requests, including the over-limit probe, must come from the
			// same client; otherwise a per-client limiter would treat the probe
			// as a brand-new client and never return 429.
			const ip = "192.168.1.50"

			for i := 1; i <= tt.limit; i++ {
				status, _ := rateLimiterTestRequest(t, app, ip)
				assert.Equal(t, fiber.StatusOK, status, "Request %d should succeed", i)
			}

			status, _ := rateLimiterTestRequest(t, app, ip)
			assert.Equal(t, fiber.StatusTooManyRequests, status, "Should be rate limited after %d requests", tt.limit)
		})
	}
}

// TestRateLimiter_ShortWindow verifies that with a one-second window a client
// that hit the limit can spend a full new budget after the window elapses.
func TestRateLimiter_ShortWindow(t *testing.T) {
	limit := 3
	expiration := time.Second
	app := setupRateLimiterTestApp(t, limit, expiration)

	const ip = "192.168.1.250"

	for i := 1; i <= limit; i++ {
		status, _ := rateLimiterTestRequest(t, app, ip)
		assert.Equal(t, fiber.StatusOK, status)
	}

	status, _ := rateLimiterTestRequest(t, app, ip)
	assert.Equal(t, fiber.StatusTooManyRequests, status)

	// Sleep slightly past the window so the counter has reset.
	time.Sleep(expiration + 200*time.Millisecond)

	for i := 1; i <= limit; i++ {
		status, _ := rateLimiterTestRequest(t, app, ip)
		assert.Equal(t, fiber.StatusOK, status, "Request %d should succeed after window reset", i)
	}
}