Files
Primora/apps/backend/internal/middleware/ratelimit.go
T
2026-04-10 12:03:31 +02:00

152 lines
3.1 KiB
Go

package middleware
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// RateLimiter is a per-client token bucket rate limiter. Every client key
// owns a bucket that holds at most rate tokens and is refilled at rate tokens
// per window; a request is admitted only if it can take a token.
type RateLimiter struct {
	mu       sync.RWMutex        // guards the visitors map
	visitors map[string]*visitor // bucket state, keyed by client identifier
	window   time.Duration       // period over which rate tokens are restored
	rate     int                 // bucket capacity and tokens restored per window
}
// visitor is the token bucket state tracked for one client key.
type visitor struct {
	mu       sync.Mutex // guards tokens and lastSeen
	lastSeen time.Time  // reference time for refilling; also drives stale-entry eviction
	tokens   int        // tokens currently available to spend
}
// NewRateLimiter returns a RateLimiter whose per-client buckets hold up to
// rate tokens and are refilled at rate tokens per window. It also starts the
// background goroutine that periodically evicts idle clients.
//
// NOTE(review): cleanupVisitors loops forever with no stop signal, so the
// goroutine started here is never stopped; create limiters once (e.g. at
// startup) rather than per request.
func NewRateLimiter(rate int, window time.Duration) *RateLimiter {
	limiter := new(RateLimiter)
	limiter.rate = rate
	limiter.window = window
	limiter.visitors = make(map[string]*visitor)

	// Idle visitors are purged in the background, once a minute.
	go limiter.cleanupVisitors()

	return limiter
}
// Middleware returns a Gin handler that throttles requests per client IP, as
// reported by c.ClientIP. When the client's bucket has a token, the rest of
// the handler chain runs; otherwise the request is aborted with
// 429 Too Many Requests and a JSON error body.
func (rl *RateLimiter) Middleware() gin.HandlerFunc {
	reject := func(c *gin.Context) {
		c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "Rate limit exceeded"})
	}

	return func(c *gin.Context) {
		if rl.allow(c.ClientIP()) {
			c.Next()
			return
		}
		reject(c)
	}
}
// allow reports whether a request from the client identified by ip (any
// string key works) may proceed, consuming one token from that client's
// bucket when it does.
//
// Each bucket holds at most rl.rate tokens and refills continuously at
// rl.rate tokens per rl.window. v.lastSeen marks the instant up to which
// refill time has already been credited. It is advanced only by the time that
// produced whole tokens, so fractional progress carries over between calls.
// Previously the count was computed as elapsed/window*rate, which truncated
// to zero for any gap shorter than a window. Combined with resetting lastSeen
// on every call, that meant a client retrying more often than once per window
// never got a token back.
func (rl *RateLimiter) allow(ip string) bool {
	rl.mu.Lock()
	v, exists := rl.visitors[ip]
	if !exists {
		v = &visitor{
			tokens:   rl.rate,
			lastSeen: time.Now(),
		}
		rl.visitors[ip] = v
	}
	rl.mu.Unlock()

	v.mu.Lock()
	defer v.mu.Unlock()

	now := time.Now()
	elapsed := now.Sub(v.lastSeen)

	switch {
	case elapsed >= rl.window:
		// At least a full window has passed, so the bucket is full whatever
		// its previous level. Handling this first keeps elapsed*rate below
		// from overflowing after long idle periods, and avoids dividing by a
		// zero window.
		v.tokens = rl.rate
		v.lastSeen = now
	case elapsed > 0:
		// Multiply before dividing so partial windows still earn tokens.
		// Here elapsed < window, so the product stays below window*rate.
		earned := int(elapsed * time.Duration(rl.rate) / rl.window)
		if earned > 0 {
			v.tokens += earned
			// Credit only the time those whole tokens represent; the
			// remainder counts toward the next token.
			v.lastSeen = v.lastSeen.Add(time.Duration(earned) * rl.window / time.Duration(rl.rate))
		}
		if v.tokens >= rl.rate {
			// A full bucket does not bank extra time.
			v.tokens = rl.rate
			v.lastSeen = now
		}
	}

	if v.tokens <= 0 {
		return false
	}
	v.tokens--
	return true
}
// cleanupVisitors keeps the visitors map from growing without bound. Once a
// minute it evicts every visitor whose lastSeen is more than two windows in
// the past. Such a visitor would have been refilled to a full bucket on its
// next request anyway, so recreating it on demand gives the same result.
//
// The loop never returns: the ticker channel is never closed and there is no
// stop signal, so this goroutine lives as long as the process.
func (rl *RateLimiter) cleanupVisitors() {
	ticker := time.NewTicker(time.Minute)
	defer ticker.Stop()

	maxIdle := rl.window * 2

	sweep := func() {
		rl.mu.Lock()
		defer rl.mu.Unlock()

		for key, v := range rl.visitors {
			v.mu.Lock()
			idle := time.Since(v.lastSeen)
			v.mu.Unlock()

			if idle > maxIdle {
				delete(rl.visitors, key)
			}
		}
	}

	for {
		<-ticker.C
		sweep()
	}
}
// KeyRateLimiter rate limits requests by a caller-supplied key (for example a
// user ID or API key) instead of by client IP.
//
// Each distinct key gets its own token bucket of rate tokens per window. All
// buckets live in one shared RateLimiter, so a single cleanup goroutine
// services every key and idle keys are evicted by it. Previously, each new key
// built its own RateLimiter. That started a goroutine that never exits for
// every distinct key, and the map holding those limiters was never pruned.
type KeyRateLimiter struct {
	limiter *RateLimiter // token buckets, one per key
}

// NewKeyRateLimiter creates a key-based rate limiter that allows up to rate
// requests per window for each distinct key. It calls NewRateLimiter once, so
// exactly one background cleanup goroutine is started regardless of how many
// keys are later seen.
func NewKeyRateLimiter(rate int, window time.Duration) *KeyRateLimiter {
	return &KeyRateLimiter{
		limiter: NewRateLimiter(rate, window),
	}
}

// Middleware returns a Gin middleware that rate limits requests by the key
// that keyFunc extracts from the request context.
//
// Requests for which keyFunc returns "" are passed through without being
// counted. Requests over the limit are aborted with 429 Too Many Requests and
// a JSON error body. Because all keys share one limiter, there is no
// per-key lazy creation and therefore no check-then-create race between
// concurrent first requests for the same key.
func (krl *KeyRateLimiter) Middleware(keyFunc func(*gin.Context) string) gin.HandlerFunc {
	return func(c *gin.Context) {
		key := keyFunc(c)
		if key == "" {
			c.Next()
			return
		}
		if !krl.limiter.allow(key) {
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "Rate limit exceeded for this resource"})
			return
		}
		c.Next()
	}
}