Files
Containr/app/backend/internal/middleware/ratelimit.go
T
2026-04-10 12:02:36 +02:00

168 lines
3.2 KiB
Go

package middleware
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// RateLimiter is a per-key fixed-window rate limiter. Each key may make at
// most rate requests; further requests are rejected until more than window
// has elapsed since the key's current window started, at which point the
// allowance is refilled in full. (Despite the "token" vocabulary used
// internally, this is not a token bucket: tokens are not replenished
// gradually.)
//
// A background goroutine periodically evicts keys whose window started more
// than cleanupInterval ago, so the key map does not grow without bound. Call
// Stop to terminate that goroutine when the limiter is no longer needed.
type RateLimiter struct {
	mu              sync.RWMutex
	buckets         map[string]*bucket
	rate            int           // requests allowed per window
	window          time.Duration // length of each fixed window
	cleanupInterval time.Duration // sweep period, and idle age after which a key is evicted

	stop     chan struct{} // closed by Stop to end the cleanup goroutine
	stopOnce sync.Once     // makes Stop idempotent
}

// bucket tracks the remaining allowance for a single key.
type bucket struct {
	tokens    int       // requests still allowed in the current window
	lastReset time.Time // start of the current window
	mu        sync.Mutex
}

// fallbackCleanupInterval is the sweep period used when window is not
// positive, because time.NewTicker panics on a non-positive period.
const fallbackCleanupInterval = time.Minute

// NewRateLimiter creates a limiter that allows rate requests per key per
// window and starts its background cleanup goroutine. A rate of zero or less
// rejects every request.
func NewRateLimiter(rate int, window time.Duration) *RateLimiter {
	interval := 2 * window
	switch {
	case window <= 0:
		// Such a window refills every bucket on each call anyway; only the
		// eviction of idle keys matters, so any positive period works.
		interval = fallbackCleanupInterval
	case interval < window:
		// 2*window overflowed time.Duration. window itself is still a safe
		// idle age: any bucket older than window is refilled on next use.
		interval = window
	}

	rl := &RateLimiter{
		buckets:         make(map[string]*bucket),
		rate:            rate,
		window:          window,
		cleanupInterval: interval,
		stop:            make(chan struct{}),
	}
	go rl.cleanup()
	return rl
}

// Stop terminates the background cleanup goroutine. It is safe to call more
// than once. The limiter keeps answering Allow afterwards, but idle keys are
// no longer evicted.
func (rl *RateLimiter) Stop() {
	rl.stopOnce.Do(func() { close(rl.stop) })
}

// Allow reports whether a request for key is permitted, consuming one unit of
// the key's allowance if so. It is safe for concurrent use.
func (rl *RateLimiter) Allow(key string) bool {
	// Fast path: the key already exists. Keep the read lock while the bucket
	// is updated, so cleanup (which needs the write lock) cannot evict it in
	// between and silently discard this request's consumption.
	rl.mu.RLock()
	if b, ok := rl.buckets[key]; ok {
		allowed := b.take(rl.rate, rl.window)
		rl.mu.RUnlock()
		return allowed
	}
	rl.mu.RUnlock()

	// Slow path: re-check under the write lock so concurrent first requests
	// for the same key share one bucket instead of each installing a fresh,
	// full one (which would admit more than rate requests).
	rl.mu.Lock()
	defer rl.mu.Unlock()
	b, ok := rl.buckets[key]
	if !ok {
		b = &bucket{tokens: rl.rate, lastReset: time.Now()}
		rl.buckets[key] = b
	}
	return b.take(rl.rate, rl.window)
}

// take refills b when more than window has passed since its last reset. It
// then consumes one token and reports whether one was available.
func (b *bucket) take(rate int, window time.Duration) bool {
	b.mu.Lock()
	defer b.mu.Unlock()
	if now := time.Now(); now.Sub(b.lastReset) > window {
		b.tokens = rate
		b.lastReset = now
	}
	if b.tokens <= 0 {
		return false
	}
	b.tokens--
	return true
}

// cleanup runs evictIdle every cleanupInterval until Stop is called.
func (rl *RateLimiter) cleanup() {
	ticker := time.NewTicker(rl.cleanupInterval)
	defer ticker.Stop()
	for {
		select {
		case <-rl.stop:
			return
		case <-ticker.C:
			rl.evictIdle()
		}
	}
}

// evictIdle removes every bucket whose window started more than
// cleanupInterval ago.
func (rl *RateLimiter) evictIdle() {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	now := time.Now()
	for key, b := range rl.buckets {
		b.mu.Lock()
		idle := now.Sub(b.lastReset) > rl.cleanupInterval
		b.mu.Unlock()
		if idle {
			delete(rl.buckets, key) // deleting during range is safe in Go
		}
	}
}
// RateLimit returns a middleware that allows at most rate requests per window
// for each client, using its own limiter whose cleanup goroutine is never
// stopped by this middleware.
//
// Clients are identified by "user:<id>" when the gin context holds a
// non-empty string under "user_id", and by c.ClientIP() otherwise.
// NOTE(review): "user_id" is only present if an earlier middleware set it —
// presumably the auth middleware; verify it runs before this one.
//
// Requests over the limit receive HTTP 429 with a JSON error body, and the
// remaining handler chain is aborted.
func RateLimit(rate int, window time.Duration) gin.HandlerFunc {
	limiter := NewRateLimiter(rate, window)
	return func(c *gin.Context) {
		clientKey := c.ClientIP()
		if raw, found := c.Get("user_id"); found {
			if id, _ := raw.(string); id != "" {
				clientKey = "user:" + id
			}
		}

		if limiter.Allow(clientKey) {
			c.Next()
			return
		}
		c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
			"error":   "rate_limit_exceeded",
			"message": "Too many requests. Please try again later.",
		})
	}
}
// RateLimitByIP returns a middleware that allows at most rate requests per
// window from each client IP, as reported by c.ClientIP(). Requests over the
// limit receive HTTP 429 with a JSON error body, and the remaining handler
// chain is aborted.
func RateLimitByIP(rate int, window time.Duration) gin.HandlerFunc {
	limiter := NewRateLimiter(rate, window)
	return func(c *gin.Context) {
		if limiter.Allow(c.ClientIP()) {
			c.Next()
			return
		}
		c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
			"error":   "rate_limit_exceeded",
			"message": "Too many requests from your IP. Please try again later.",
		})
	}
}
// RateLimitByUser returns a middleware that allows at most rate requests per
// window for each user, keyed as "user:<id>" from the string stored under
// "user_id" in the gin context.
//
// Requests without a non-empty string "user_id" pass through unlimited.
// Requests over the limit receive HTTP 429 with a JSON error body, and the
// remaining handler chain is aborted.
func RateLimitByUser(rate int, window time.Duration) gin.HandlerFunc {
	limiter := NewRateLimiter(rate, window)
	return func(c *gin.Context) {
		// A missing key yields nil, and a failed assertion yields "", so both
		// cases fall into the unlimited branch below.
		raw, _ := c.Get("user_id")
		id, _ := raw.(string)

		if id == "" || limiter.Allow("user:"+id) {
			c.Next()
			return
		}
		c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
			"error":   "rate_limit_exceeded",
			"message": "Too many requests. Please slow down.",
		})
	}
}