package middleware import ( "net/http" "net/http/httptest" "testing" "time" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" ) func TestRateLimiter(t *testing.T) { limiter := NewRateLimiter(5, time.Second) // Should allow first 5 requests for i := 0; i < 5; i++ { assert.True(t, limiter.Allow("test-key"), "Request %d should be allowed", i+1) } // Should block 6th request assert.False(t, limiter.Allow("test-key"), "6th request should be blocked") // Wait for window to reset time.Sleep(time.Second + 100*time.Millisecond) // Should allow requests again assert.True(t, limiter.Allow("test-key"), "Request after reset should be allowed") } func TestRateLimitMiddleware(t *testing.T) { gin.SetMode(gin.TestMode) router := gin.New() router.Use(RateLimit(2, time.Second)) router.GET("/test", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "success"}) }) // First request should succeed req1 := httptest.NewRequest(http.MethodGet, "/test", nil) w1 := httptest.NewRecorder() router.ServeHTTP(w1, req1) assert.Equal(t, http.StatusOK, w1.Code) // Second request should succeed req2 := httptest.NewRequest(http.MethodGet, "/test", nil) w2 := httptest.NewRecorder() router.ServeHTTP(w2, req2) assert.Equal(t, http.StatusOK, w2.Code) // Third request should be rate limited req3 := httptest.NewRequest(http.MethodGet, "/test", nil) w3 := httptest.NewRecorder() router.ServeHTTP(w3, req3) assert.Equal(t, http.StatusTooManyRequests, w3.Code) } func TestRateLimitByUser(t *testing.T) { gin.SetMode(gin.TestMode) router := gin.New() router.Use(func(c *gin.Context) { c.Set("user_id", "user123") c.Next() }) router.Use(RateLimitByUser(2, time.Second)) router.GET("/test", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "success"}) }) // First two requests should succeed for i := 0; i < 2; i++ { req := httptest.NewRequest(http.MethodGet, "/test", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) } // Third request should be rate limited req := httptest.NewRequest(http.MethodGet, "/test", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) assert.Equal(t, http.StatusTooManyRequests, w.Code) }