mirror of
https://github.com/Dvorinka/SpotifyRecAlg.git
synced 2026-06-05 04:53:02 +00:00
first commit
This commit is contained in:
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
Rate Limiting using DragonflyDB.
|
||||
|
||||
Provides distributed rate limiting using DragonflyDB's atomic INCR command
|
||||
with automatic key expiration. This is more efficient than in-memory rate
|
||||
limiting for distributed deployments and provides persistence across restarts.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from swingmusic.db.dragonfly_client import get_dragonfly_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
Token bucket / sliding window rate limiter using DragonflyDB.
|
||||
|
||||
Uses atomic Redis operations (INCR, EXPIRE) to implement rate limiting
|
||||
that works across multiple server instances.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
if self._client is None:
|
||||
self._client = get_dragonfly_client()
|
||||
return self._client
|
||||
|
||||
def _get_key(self, identifier: str, action: str) -> str:
|
||||
"""Get the Redis key for a rate limit counter."""
|
||||
return f"ratelimit:{action}:{identifier}"
|
||||
|
||||
def _get_window_key(self, identifier: str, action: str, window: int) -> str:
|
||||
"""Get the Redis key for a sliding window rate limit."""
|
||||
current_window = int(time.time() // window)
|
||||
return f"ratelimit:{action}:{identifier}:{current_window}"
|
||||
|
||||
def is_allowed(
|
||||
self, identifier: str, action: str, max_requests: int, window_seconds: int = 60
|
||||
) -> tuple[bool, int, int]:
|
||||
"""
|
||||
Check if a request is allowed under the rate limit.
|
||||
|
||||
Uses a sliding window algorithm with DragonflyDB.
|
||||
|
||||
Args:
|
||||
identifier: Unique identifier (e.g., user ID, IP address)
|
||||
action: The action being rate limited (e.g., "login", "download")
|
||||
max_requests: Maximum number of requests allowed in the window
|
||||
window_seconds: Time window in seconds
|
||||
|
||||
Returns:
|
||||
Tuple of (is_allowed, current_count, retry_after_seconds)
|
||||
"""
|
||||
if not self.client.is_available():
|
||||
# If DragonflyDB is not available, allow the request
|
||||
return True, 0, 0
|
||||
|
||||
try:
|
||||
key = self._get_window_key(identifier, action, window_seconds)
|
||||
|
||||
# Use pipeline for atomic operations
|
||||
pipe = self.client.client.pipeline()
|
||||
|
||||
# Increment counter
|
||||
pipe.incr(key)
|
||||
|
||||
# Set expiry on first request (only if key is new)
|
||||
pipe.expire(key, window_seconds, nx=True)
|
||||
|
||||
results = pipe.execute()
|
||||
current_count = results[0]
|
||||
|
||||
if current_count <= max_requests:
|
||||
return True, current_count, 0
|
||||
else:
|
||||
# Calculate retry after
|
||||
ttl = self.client.client.ttl(key)
|
||||
retry_after = max(1, ttl) if ttl > 0 else window_seconds
|
||||
return False, current_count, retry_after
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Rate limit check failed: {e}")
|
||||
# On error, allow the request
|
||||
return True, 0, 0
|
||||
|
||||
def increment(self, identifier: str, action: str, window_seconds: int = 60) -> int:
|
||||
"""
|
||||
Increment the counter for an action without checking the limit.
|
||||
|
||||
Useful for tracking usage without enforcing limits.
|
||||
|
||||
Returns:
|
||||
The new counter value
|
||||
"""
|
||||
if not self.client.is_available():
|
||||
return 0
|
||||
|
||||
try:
|
||||
key = self._get_window_key(identifier, action, window_seconds)
|
||||
|
||||
pipe = self.client.client.pipeline()
|
||||
pipe.incr(key)
|
||||
pipe.expire(key, window_seconds, nx=True)
|
||||
|
||||
results = pipe.execute()
|
||||
return results[0]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Rate limit increment failed: {e}")
|
||||
return 0
|
||||
|
||||
def get_count(self, identifier: str, action: str, window_seconds: int = 60) -> int:
|
||||
"""Get the current count for an action in the current window."""
|
||||
if not self.client.is_available():
|
||||
return 0
|
||||
|
||||
try:
|
||||
key = self._get_window_key(identifier, action, window_seconds)
|
||||
value = self.client.get(key)
|
||||
return int(value) if value else 0
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def reset(self, identifier: str, action: str) -> bool:
|
||||
"""Reset the rate limit counter for an identifier and action."""
|
||||
if not self.client.is_available():
|
||||
return False
|
||||
|
||||
try:
|
||||
# Delete all windows for this identifier/action
|
||||
pattern = f"ratelimit:{action}:{identifier}:*"
|
||||
keys = self.client.client.keys(pattern)
|
||||
|
||||
if keys:
|
||||
self.client.client.delete(*keys)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Rate limit reset failed: {e}")
|
||||
return False
|
||||
|
||||
def get_remaining(
|
||||
self, identifier: str, action: str, max_requests: int, window_seconds: int = 60
|
||||
) -> int:
|
||||
"""Get the number of remaining requests allowed."""
|
||||
current = self.get_count(identifier, action, window_seconds)
|
||||
return max(0, max_requests - current)
|
||||
|
||||
|
||||
class LoginRateLimiter(RateLimiter):
|
||||
"""Rate limiter specifically for login attempts."""
|
||||
|
||||
# Default: 10 login attempts per minute
|
||||
MAX_ATTEMPTS = 10
|
||||
WINDOW_SECONDS = 60
|
||||
|
||||
def check_login(self, identifier: str) -> tuple[bool, int, int]:
|
||||
"""Check if login is allowed for the given identifier."""
|
||||
return self.is_allowed(
|
||||
identifier, "login", self.MAX_ATTEMPTS, self.WINDOW_SECONDS
|
||||
)
|
||||
|
||||
def record_failed_login(self, identifier: str) -> int:
|
||||
"""Record a failed login attempt."""
|
||||
return self.increment(identifier, "login", self.WINDOW_SECONDS)
|
||||
|
||||
def clear_failed_logins(self, identifier: str) -> bool:
|
||||
"""Clear failed login attempts after successful login."""
|
||||
return self.reset(identifier, "login")
|
||||
|
||||
|
||||
class DownloadRateLimiter(RateLimiter):
|
||||
"""Rate limiter specifically for downloads."""
|
||||
|
||||
# Default: 100 downloads per hour
|
||||
MAX_DOWNLOADS = 100
|
||||
WINDOW_SECONDS = 3600
|
||||
|
||||
def check_download(self, user_id: int) -> tuple[bool, int, int]:
|
||||
"""Check if download is allowed for the given user."""
|
||||
return self.is_allowed(
|
||||
str(user_id), "download", self.MAX_DOWNLOADS, self.WINDOW_SECONDS
|
||||
)
|
||||
|
||||
def record_download(self, user_id: int) -> int:
|
||||
"""Record a download."""
|
||||
return self.increment(str(user_id), "download", self.WINDOW_SECONDS)
|
||||
|
||||
|
||||
class APIRateLimiter(RateLimiter):
|
||||
"""Rate limiter for general API endpoints."""
|
||||
|
||||
# Default: 100 requests per minute per user
|
||||
MAX_REQUESTS = 100
|
||||
WINDOW_SECONDS = 60
|
||||
|
||||
def check_api_request(self, identifier: str) -> tuple[bool, int, int]:
|
||||
"""Check if API request is allowed."""
|
||||
return self.is_allowed(
|
||||
identifier, "api", self.MAX_REQUESTS, self.WINDOW_SECONDS
|
||||
)
|
||||
|
||||
|
||||
# Global instances
|
||||
rate_limiter = RateLimiter()
|
||||
login_rate_limiter = LoginRateLimiter()
|
||||
download_rate_limiter = DownloadRateLimiter()
|
||||
api_rate_limiter = APIRateLimiter()
|
||||
|
||||
|
||||
def get_rate_limiter() -> RateLimiter:
|
||||
"""Get the global rate limiter instance."""
|
||||
return rate_limiter
|
||||
|
||||
|
||||
def get_login_rate_limiter() -> LoginRateLimiter:
|
||||
"""Get the global login rate limiter instance."""
|
||||
return login_rate_limiter
|
||||
|
||||
|
||||
def get_download_rate_limiter() -> DownloadRateLimiter:
|
||||
"""Get the global download rate limiter instance."""
|
||||
return download_rate_limiter
|
||||
|
||||
|
||||
def get_api_rate_limiter() -> APIRateLimiter:
|
||||
"""Get the global API rate limiter instance."""
|
||||
return api_rate_limiter
|
||||
Reference in New Issue
Block a user