mirror of
https://github.com/Dvorinka/SpotifyRecAlg.git
synced 2026-06-03 20:13:03 +00:00
235 lines
7.4 KiB
Python
235 lines
7.4 KiB
Python
"""
|
|
Rate Limiting using DragonflyDB.
|
|
|
|
Provides distributed rate limiting using DragonflyDB's atomic INCR command
|
|
with automatic key expiration. This is more efficient than in-memory rate
|
|
limiting for distributed deployments and provides persistence across restarts.
|
|
"""
|
|
|
|
import logging
import math
import time

from swingmusic.db.dragonfly_client import get_dragonfly_client
|
|
|
|
# Module-level logger named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RateLimiter:
    """
    Fixed-window rate limiter backed by DragonflyDB.

    Each ``(action, identifier)`` pair gets one counter per window, stored
    under ``ratelimit:{action}:{identifier}:{window_index}`` where
    ``window_index = int(now // window_seconds)``. Counters are bumped with a
    pipelined INCR + EXPIRE NX, so every server instance talking to the same
    DragonflyDB shares the same counts.

    The limiter fails open: if DragonflyDB is unavailable or a command raises,
    requests are allowed and counts read as 0.
    """

    def __init__(self):
        self._client = None

    @property
    def client(self):
        """Return the DragonflyDB client wrapper, fetching it on first access."""
        if self._client is None:
            self._client = get_dragonfly_client()
        return self._client

    def _get_key(self, identifier: str, action: str) -> str:
        """Get the un-windowed Redis key for a rate limit counter.

        Kept for compatibility; the windowed methods in this class use
        ``_get_window_key`` instead.
        """
        return f"ratelimit:{action}:{identifier}"

    def _get_window_key(
        self,
        identifier: str,
        action: str,
        window: int,
        now: "float | None" = None,
    ) -> str:
        """Get the Redis key for the fixed window containing ``now``.

        Args:
            identifier: Unique identifier (e.g., user ID, IP address).
            action: The action being rate limited.
            window: Window length in seconds; must be non-zero.
            now: Unix timestamp to bucket; defaults to ``time.time()``.
        """
        if now is None:
            now = time.time()
        current_window = int(now // window)
        return f"ratelimit:{action}:{identifier}:{current_window}"

    @staticmethod
    def _seconds_until_window_end(window: int, now: float) -> int:
        """Return whole seconds (at least 1) until the window containing ``now`` ends."""
        window_end = (int(now // window) + 1) * window
        return max(1, math.ceil(window_end - now))

    @staticmethod
    def _escape_glob(text: str) -> str:
        """Backslash-escape Redis glob metacharacters so ``text`` matches literally."""
        return "".join("\\" + ch if ch in "*?[]\\" else ch for ch in text)

    @staticmethod
    def _is_window_key(key, prefix: str) -> bool:
        """Return True if ``key`` is exactly ``prefix`` followed by a window index.

        ``key`` may be ``str`` or ``bytes``, depending on whether the client
        decodes responses.
        """
        if isinstance(key, bytes):
            key = key.decode("utf-8", errors="replace")
        suffix = key[len(prefix):]
        return key.startswith(prefix) and suffix.isascii() and suffix.isdigit()

    def _incr_window(self, key: str, window_seconds: int) -> int:
        """INCR ``key`` and give it a TTL if it has none; return the new count."""
        pipe = self.client.client.pipeline()
        pipe.incr(key)
        # NX: only set an expiry when the key has none (i.e. on the first
        # increment), so later requests don't keep extending it. The TTL is
        # at least as long as the rest of the window, so the key outlives it.
        pipe.expire(key, window_seconds, nx=True)
        results = pipe.execute()
        return results[0]

    def is_allowed(
        self, identifier: str, action: str, max_requests: int, window_seconds: int = 60
    ) -> tuple[bool, int, int]:
        """
        Count a request and check it against the rate limit.

        Uses a fixed-window counter in DragonflyDB. Every call increments the
        counter for the current window, including calls that are rejected.

        Args:
            identifier: Unique identifier (e.g., user ID, IP address)
            action: The action being rate limited (e.g., "login", "download")
            max_requests: Maximum number of requests allowed in the window
            window_seconds: Time window in seconds

        Returns:
            Tuple of (is_allowed, current_count, retry_after_seconds).
            ``retry_after_seconds`` is 0 when allowed; otherwise it is the
            number of seconds until the current window ends and a fresh
            counter starts. ``(True, 0, 0)`` is returned when DragonflyDB is
            unavailable or the check fails (fail open).
        """
        if not self.client.is_available():
            # If DragonflyDB is not available, allow the request
            return True, 0, 0

        try:
            # Use one timestamp for both the key and retry_after so they
            # refer to the same window.
            now = time.time()
            key = self._get_window_key(identifier, action, window_seconds, now=now)
            current_count = self._incr_window(key, window_seconds)
        except Exception as e:
            logger.error("Rate limit check failed: %s", e)
            # On error, allow the request
            return True, 0, 0

        if current_count <= max_requests:
            return True, current_count, 0

        # The counter key is per window, so the caller may retry as soon as
        # the next window starts -- not when this key's TTL (set at the first
        # request of the window) runs out.
        retry_after = self._seconds_until_window_end(window_seconds, now)
        return False, current_count, retry_after

    def increment(self, identifier: str, action: str, window_seconds: int = 60) -> int:
        """
        Increment the counter for an action without checking the limit.

        Useful for tracking usage without enforcing limits.

        Returns:
            The new counter value, or 0 if DragonflyDB is unavailable or the
            increment fails.
        """
        if not self.client.is_available():
            return 0

        try:
            key = self._get_window_key(identifier, action, window_seconds)
            return self._incr_window(key, window_seconds)
        except Exception as e:
            logger.error("Rate limit increment failed: %s", e)
            return 0

    def get_count(self, identifier: str, action: str, window_seconds: int = 60) -> int:
        """Get the current count for an action in the current window.

        Returns 0 when there is no counter yet, DragonflyDB is unavailable,
        or the lookup fails.
        """
        if not self.client.is_available():
            return 0

        try:
            key = self._get_window_key(identifier, action, window_seconds)
            # Read through the same raw client the INCR pipeline writes with,
            # so reads and writes agree on the key.
            value = self.client.client.get(key)
            return int(value) if value else 0
        except Exception as e:
            logger.error("Rate limit count lookup failed: %s", e)
            return 0

    def reset(self, identifier: str, action: str) -> bool:
        """Reset every window's counter for an identifier and action.

        Returns True on success (including when there was nothing to delete),
        False if DragonflyDB is unavailable or the deletion fails.
        """
        if not self.client.is_available():
            return False

        try:
            raw = self.client.client
            prefix = f"ratelimit:{action}:{identifier}:"
            # Escape the prefix so glob characters in the identifier cannot
            # match other identifiers' keys, then keep only exact
            # prefix + window-index keys (rejects e.g. "a:b:..." for "a").
            # SCAN instead of KEYS avoids one blocking full-keyspace walk.
            pattern = self._escape_glob(prefix) + "*"
            keys = {
                key
                for key in raw.scan_iter(match=pattern, count=1000)
                if self._is_window_key(key, prefix)
            }

            if keys:
                raw.delete(*keys)

            return True
        except Exception as e:
            logger.error("Rate limit reset failed: %s", e)
            return False

    def get_remaining(
        self, identifier: str, action: str, max_requests: int, window_seconds: int = 60
    ) -> int:
        """Get the number of remaining requests allowed in the current window."""
        current = self.get_count(identifier, action, window_seconds)
        return max(0, max_requests - current)
|
|
|
|
|
|
class LoginRateLimiter(RateLimiter):
    """
    Rate limiter dedicated to login attempts.

    All attempts share the ``"login"`` action and a ``WINDOW_SECONDS`` window,
    allowing at most ``MAX_ATTEMPTS`` per identifier per window.

    NOTE(review): ``check_login`` goes through ``is_allowed``, which already
    increments the same counter that ``record_failed_login`` increments, so a
    caller doing both for one failed attempt counts it twice -- confirm the
    intended call pattern.
    """

    # Defaults: 10 attempts per 60-second window.
    MAX_ATTEMPTS: int = 10
    WINDOW_SECONDS: int = 60

    def check_login(self, identifier: str) -> tuple[bool, int, int]:
        """Count this attempt and report whether ``identifier`` may log in.

        Returns the ``(is_allowed, current_count, retry_after_seconds)`` tuple
        produced by ``is_allowed``.
        """
        limit, window = self.MAX_ATTEMPTS, self.WINDOW_SECONDS
        return self.is_allowed(identifier, "login", limit, window)

    def record_failed_login(self, identifier: str) -> int:
        """Add one failed attempt for ``identifier`` and return the new count."""
        return self.increment(identifier, "login", window_seconds=self.WINDOW_SECONDS)

    def clear_failed_logins(self, identifier: str) -> bool:
        """Drop all login counters for ``identifier`` (e.g. after a successful login)."""
        return self.reset(identifier, action="login")
|
|
|
|
|
|
class DownloadRateLimiter(RateLimiter):
    """
    Rate limiter dedicated to downloads, keyed by user ID.

    User IDs are stringified and counted under the ``"download"`` action with
    a ``WINDOW_SECONDS`` window and a cap of ``MAX_DOWNLOADS``.

    NOTE(review): ``check_download`` already increments the counter (via
    ``is_allowed``), so also calling ``record_download`` for the same download
    counts it twice -- confirm the intended call pattern.
    """

    # Defaults: 100 downloads per hour.
    MAX_DOWNLOADS: int = 100
    WINDOW_SECONDS: int = 3600

    def check_download(self, user_id: int) -> tuple[bool, int, int]:
        """Count this download and report whether ``user_id`` is within the limit.

        Returns the ``(is_allowed, current_count, retry_after_seconds)`` tuple
        produced by ``is_allowed``.
        """
        user_key = str(user_id)
        limit, window = self.MAX_DOWNLOADS, self.WINDOW_SECONDS
        return self.is_allowed(user_key, "download", limit, window)

    def record_download(self, user_id: int) -> int:
        """Add one download for ``user_id`` and return the new count."""
        user_key = str(user_id)
        return self.increment(user_key, "download", window_seconds=self.WINDOW_SECONDS)
|
|
|
|
|
|
class APIRateLimiter(RateLimiter):
    """
    Rate limiter for general API endpoints.

    Requests are counted under the ``"api"`` action, allowing at most
    ``MAX_REQUESTS`` per identifier per ``WINDOW_SECONDS`` window.
    """

    # Defaults: 100 requests per minute per identifier.
    MAX_REQUESTS: int = 100
    WINDOW_SECONDS: int = 60

    def check_api_request(self, identifier: str) -> tuple[bool, int, int]:
        """Count this API request and report whether ``identifier`` is within the limit.

        Returns the ``(is_allowed, current_count, retry_after_seconds)`` tuple
        produced by ``is_allowed``.
        """
        limit, window = self.MAX_REQUESTS, self.WINDOW_SECONDS
        return self.is_allowed(identifier, "api", limit, window)
|
|
|
|
|
|
# Module-wide limiter singletons. Each one fetches its DragonflyDB client on
# first use through the ``client`` property, so creating them here does no I/O.
rate_limiter: RateLimiter = RateLimiter()
login_rate_limiter: LoginRateLimiter = LoginRateLimiter()
download_rate_limiter: DownloadRateLimiter = DownloadRateLimiter()
api_rate_limiter: APIRateLimiter = APIRateLimiter()
|
|
|
|
|
|
def get_rate_limiter() -> RateLimiter:
    """Return the shared, module-level ``RateLimiter`` singleton."""
    return rate_limiter
|
|
|
|
|
|
def get_login_rate_limiter() -> LoginRateLimiter:
    """Return the shared, module-level ``LoginRateLimiter`` singleton."""
    return login_rate_limiter
|
|
|
|
|
|
def get_download_rate_limiter() -> DownloadRateLimiter:
    """Return the shared, module-level ``DownloadRateLimiter`` singleton."""
    return download_rate_limiter
|
|
|
|
|
|
def get_api_rate_limiter() -> APIRateLimiter:
    """Return the shared, module-level ``APIRateLimiter`` singleton."""
    return api_rate_limiter
|