mirror of
https://github.com/Dvorinka/SpotifyRecAlg.git
synced 2026-06-04 12:33:03 +00:00
first commit
This commit is contained in:
@@ -0,0 +1,66 @@
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import (
|
||||
delete,
|
||||
func,
|
||||
insert,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass
|
||||
|
||||
from swingmusic.db.engine import DbEngine
|
||||
|
||||
|
||||
class Base(MappedAsDataclass, DeclarativeBase):
|
||||
"""
|
||||
Base class for all database models.
|
||||
|
||||
It has methods common to all tables. eg. `insert_one`, `insert_many`, `remove_all`, `remove_one`, `all`, `count`.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def execute(cls, stmt: Any, commit: bool = False):
|
||||
with DbEngine.manager(commit=commit) as session:
|
||||
result = session.execute(stmt.execution_options(yield_per=100))
|
||||
|
||||
if commit:
|
||||
session.commit()
|
||||
|
||||
yield result
|
||||
|
||||
@classmethod
|
||||
def insert_many(cls, items: list[dict[str, Any]]):
|
||||
"""
|
||||
Inserts multiple items into the database.
|
||||
"""
|
||||
return next(cls.execute(insert(cls).values(items), commit=True))
|
||||
|
||||
@classmethod
|
||||
def insert_one(cls, item: dict[str, Any]):
|
||||
"""
|
||||
Inserts a single item into the database.
|
||||
"""
|
||||
return cls.insert_many([item])
|
||||
|
||||
@classmethod
|
||||
def remove_all(cls):
|
||||
return next(cls.execute(delete(cls), commit=True))
|
||||
|
||||
@classmethod
|
||||
def remove_one(cls, id: int):
|
||||
return next(cls.execute(delete(cls).where(cls.id == id), commit=True))
|
||||
|
||||
@classmethod
|
||||
def all(cls):
|
||||
return next(cls.execute(select(cls).execution_options(yield_per=100)))
|
||||
|
||||
@classmethod
|
||||
def count(cls):
|
||||
return next(cls.execute(select(func.count()).select_from(cls))).scalar()
|
||||
|
||||
|
||||
def create_all_tables():
|
||||
"""
|
||||
Creates all the tables that build on the Base class.
|
||||
"""
|
||||
Base().metadata.create_all(DbEngine.engine)
|
||||
@@ -0,0 +1,385 @@
|
||||
"""
|
||||
Native DragonflyDB Client for SwingMusic
|
||||
|
||||
Integrated as a native database service like SQLite, providing:
|
||||
- Ultra-fast caching for all services
|
||||
- Session management
|
||||
- User preferences
|
||||
- Temporary data storage
|
||||
- Real-time features
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DragonflyDBClient:
|
||||
"""
|
||||
Native DragonflyDB client integrated into SwingMusic
|
||||
Provides Redis-compatible operations with automatic fallback
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str | None = None,
|
||||
port: int | None = None,
|
||||
db: int | None = None,
|
||||
):
|
||||
self.host = host or os.environ.get("DRAGONFLYDB_HOST", "localhost")
|
||||
self.port = port or int(os.environ.get("DRAGONFLYDB_PORT", "6379"))
|
||||
self.db = db if db is not None else int(os.environ.get("DRAGONFLYDB_DB", "0"))
|
||||
self.client = None
|
||||
self.available = False
|
||||
self._connect()
|
||||
|
||||
def _connect(self):
|
||||
"""Connect to DragonflyDB with fallback handling"""
|
||||
try:
|
||||
import redis
|
||||
|
||||
self.client = redis.Redis(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
db=self.db,
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=2,
|
||||
socket_timeout=2,
|
||||
retry_on_timeout=True,
|
||||
health_check_interval=30,
|
||||
)
|
||||
|
||||
# Test connection
|
||||
self.client.ping()
|
||||
self.available = True
|
||||
logger.info(f"✅ DragonflyDB connected at {self.host}:{self.port}")
|
||||
|
||||
except ImportError:
|
||||
logger.warning("❌ Redis library not installed, DragonflyDB unavailable")
|
||||
self.available = False
|
||||
except Exception as e:
|
||||
logger.warning(f"❌ DragonflyDB connection failed: {e}")
|
||||
self.available = False
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if DragonflyDB is available"""
|
||||
if not self.available or not self.client:
|
||||
self._connect()
|
||||
if not self.available or not self.client:
|
||||
return False
|
||||
|
||||
try:
|
||||
self.client.ping()
|
||||
return True
|
||||
except Exception:
|
||||
self.available = False
|
||||
return False
|
||||
|
||||
def set(self, key: str, value: Any, ttl: int | None = None) -> bool:
|
||||
"""Set a key-value pair with optional TTL"""
|
||||
if not self.is_available():
|
||||
return False
|
||||
|
||||
try:
|
||||
serialized_value = (
|
||||
json.dumps(value) if not isinstance(value, str) else value
|
||||
)
|
||||
|
||||
if ttl:
|
||||
return self.client.setex(key, ttl, serialized_value)
|
||||
else:
|
||||
return self.client.set(key, serialized_value)
|
||||
except Exception as e:
|
||||
logger.debug(f"DragonflyDB set failed: {e}")
|
||||
return False
|
||||
|
||||
def get(self, key: str) -> Any | None:
|
||||
"""Get a value by key"""
|
||||
if not self.is_available():
|
||||
return None
|
||||
|
||||
try:
|
||||
value = self.client.get(key)
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Try to deserialize as JSON
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
except Exception as e:
|
||||
logger.debug(f"DragonflyDB get failed: {e}")
|
||||
return None
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
"""Delete a key"""
|
||||
if not self.is_available():
|
||||
return False
|
||||
|
||||
try:
|
||||
return bool(self.client.delete(key))
|
||||
except Exception as e:
|
||||
logger.debug(f"DragonflyDB delete failed: {e}")
|
||||
return False
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
"""Check if key exists"""
|
||||
if not self.is_available():
|
||||
return False
|
||||
|
||||
try:
|
||||
return bool(self.client.exists(key))
|
||||
except Exception as e:
|
||||
logger.debug(f"DragonflyDB exists failed: {e}")
|
||||
return False
|
||||
|
||||
def expire(self, key: str, ttl: int) -> bool:
|
||||
"""Set TTL for existing key"""
|
||||
if not self.is_available():
|
||||
return False
|
||||
|
||||
try:
|
||||
return bool(self.client.expire(key, ttl))
|
||||
except Exception as e:
|
||||
logger.debug(f"DragonflyDB expire failed: {e}")
|
||||
return False
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
"""Get TTL for key"""
|
||||
if not self.is_available():
|
||||
return -1
|
||||
|
||||
try:
|
||||
return self.client.ttl(key)
|
||||
except Exception as e:
|
||||
logger.debug(f"DragonflyDB ttl failed: {e}")
|
||||
return -1
|
||||
|
||||
def keys(self, pattern: str = "*") -> list[str]:
|
||||
"""Get keys matching pattern"""
|
||||
if not self.is_available():
|
||||
return []
|
||||
|
||||
try:
|
||||
return self.client.keys(pattern)
|
||||
except Exception as e:
|
||||
logger.debug(f"DragonflyDB keys failed: {e}")
|
||||
return []
|
||||
|
||||
def incr(self, key: str, amount: int = 1) -> int:
|
||||
"""Increment value by amount"""
|
||||
if not self.is_available():
|
||||
return 0
|
||||
|
||||
try:
|
||||
return self.client.incr(key, amount)
|
||||
except Exception as e:
|
||||
logger.debug(f"DragonflyDB incr failed: {e}")
|
||||
return 0
|
||||
|
||||
def lpush(self, key: str, *values) -> int:
|
||||
"""Push values to left of list"""
|
||||
if not self.is_available():
|
||||
return 0
|
||||
|
||||
try:
|
||||
return self.client.lpush(key, *values)
|
||||
except Exception as e:
|
||||
logger.debug(f"DragonflyDB lpush failed: {e}")
|
||||
return 0
|
||||
|
||||
def rpop(self, key: str) -> str | None:
|
||||
"""Pop value from right of list"""
|
||||
if not self.is_available():
|
||||
return None
|
||||
|
||||
try:
|
||||
return self.client.rpop(key)
|
||||
except Exception as e:
|
||||
logger.debug(f"DragonflyDB rpop failed: {e}")
|
||||
return None
|
||||
|
||||
def lrange(self, key: str, start: int, end: int) -> list[str]:
|
||||
"""Get range of list elements"""
|
||||
if not self.is_available():
|
||||
return []
|
||||
|
||||
try:
|
||||
return self.client.lrange(key, start, end)
|
||||
except Exception as e:
|
||||
logger.debug(f"DragonflyDB lrange failed: {e}")
|
||||
return []
|
||||
|
||||
def llen(self, key: str) -> int:
|
||||
"""Get length of list"""
|
||||
if not self.is_available():
|
||||
return 0
|
||||
|
||||
try:
|
||||
return self.client.llen(key)
|
||||
except Exception as e:
|
||||
logger.debug(f"DragonflyDB llen failed: {e}")
|
||||
return 0
|
||||
|
||||
def lrem(self, key: str, count: int, value: str) -> int:
|
||||
"""Remove elements from list"""
|
||||
if not self.is_available():
|
||||
return 0
|
||||
|
||||
try:
|
||||
return self.client.lrem(key, count, value)
|
||||
except Exception as e:
|
||||
logger.debug(f"DragonflyDB lrem failed: {e}")
|
||||
return 0
|
||||
|
||||
def ltrim(self, key: str, start: int, end: int) -> bool:
|
||||
"""Trim list to range"""
|
||||
if not self.is_available():
|
||||
return False
|
||||
|
||||
try:
|
||||
return self.client.ltrim(key, start, end)
|
||||
except Exception as e:
|
||||
logger.debug(f"DragonflyDB ltrim failed: {e}")
|
||||
return False
|
||||
|
||||
def flushdb(self) -> bool:
|
||||
"""Clear all keys in current database"""
|
||||
if not self.is_available():
|
||||
return False
|
||||
|
||||
try:
|
||||
return self.client.flushdb()
|
||||
except Exception as e:
|
||||
logger.debug(f"DragonflyDB flushdb failed: {e}")
|
||||
return False
|
||||
|
||||
def ping(self) -> bool:
|
||||
"""Ping DragonflyDB."""
|
||||
if not self.is_available():
|
||||
return False
|
||||
|
||||
try:
|
||||
return bool(self.client.ping())
|
||||
except Exception as e:
|
||||
logger.debug(f"DragonflyDB ping failed: {e}")
|
||||
self.available = False
|
||||
return False
|
||||
|
||||
def info(self) -> dict[str, Any]:
|
||||
"""Get DragonflyDB server info"""
|
||||
if not self.is_available():
|
||||
return {}
|
||||
|
||||
try:
|
||||
info = self.client.info()
|
||||
return {
|
||||
"version": info.get("redis_version", "unknown"),
|
||||
"used_memory": info.get("used_memory", 0),
|
||||
"used_memory_human": info.get("used_memory_human", "0B"),
|
||||
"connected_clients": info.get("connected_clients", 0),
|
||||
"total_commands_processed": info.get("total_commands_processed", 0),
|
||||
"keyspace_hits": info.get("keyspace_hits", 0),
|
||||
"keyspace_misses": info.get("keyspace_misses", 0),
|
||||
"uptime_in_seconds": info.get("uptime_in_seconds", 0),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.debug(f"DragonflyDB info failed: {e}")
|
||||
return {}
|
||||
|
||||
def close(self):
|
||||
"""Close DragonflyDB connection"""
|
||||
if self.client:
|
||||
try:
|
||||
self.client.close()
|
||||
logger.info("DragonflyDB connection closed")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# Global DragonflyDB instance (like SQLite)
|
||||
_dragonfly_client: DragonflyDBClient | None = None
|
||||
|
||||
|
||||
def get_dragonfly_client() -> DragonflyDBClient:
|
||||
"""Get the global DragonflyDB client instance"""
|
||||
global _dragonfly_client
|
||||
if _dragonfly_client is None:
|
||||
_dragonfly_client = DragonflyDBClient()
|
||||
return _dragonfly_client
|
||||
|
||||
|
||||
def init_dragonfly_if_available() -> bool:
|
||||
"""Initialize DragonflyDB if available"""
|
||||
client = get_dragonfly_client()
|
||||
return client.is_available()
|
||||
|
||||
|
||||
class DragonflyCache:
|
||||
"""High-level cache interface using DragonflyDB"""
|
||||
|
||||
def __init__(self, prefix: str = "swingmusic"):
|
||||
self.client = get_dragonfly_client()
|
||||
self.prefix = prefix
|
||||
|
||||
def _make_key(self, key: str) -> str:
|
||||
"""Create namespaced key"""
|
||||
return f"{self.prefix}:{key}"
|
||||
|
||||
def set(self, key: str, value: Any, ttl_hours: int = 12) -> bool:
|
||||
"""Set cache value with TTL in hours"""
|
||||
ttl_seconds = ttl_hours * 3600
|
||||
return self.client.set(self._make_key(key), value, ttl_seconds)
|
||||
|
||||
def get(self, key: str) -> Any | None:
|
||||
"""Get cache value"""
|
||||
return self.client.get(self._make_key(key))
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
"""Delete cache value"""
|
||||
return self.client.delete(self._make_key(key))
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
"""Check if cache value exists"""
|
||||
return self.client.exists(self._make_key(key))
|
||||
|
||||
def clear_all(self) -> bool:
|
||||
"""Clear all SwingMusic cache entries"""
|
||||
if not self.client.is_available():
|
||||
return False
|
||||
|
||||
keys = self.client.keys(f"{self.prefix}:*")
|
||||
if keys:
|
||||
return self.client.client.delete(*keys) > 0
|
||||
return True
|
||||
|
||||
|
||||
# Native cache instances for different purposes
|
||||
spotify_cache = DragonflyCache("spotify")
|
||||
session_cache = DragonflyCache("session")
|
||||
user_cache = DragonflyCache("user")
|
||||
temp_cache = DragonflyCache("temp")
|
||||
|
||||
|
||||
def get_spotify_cache() -> DragonflyCache:
|
||||
"""Get Spotify metadata cache"""
|
||||
return spotify_cache
|
||||
|
||||
|
||||
def get_session_cache() -> DragonflyCache:
|
||||
"""Get user session cache"""
|
||||
return session_cache
|
||||
|
||||
|
||||
def get_user_cache() -> DragonflyCache:
|
||||
"""Get user preferences cache"""
|
||||
return user_cache
|
||||
|
||||
|
||||
def get_temp_cache() -> DragonflyCache:
|
||||
"""Get temporary data cache"""
|
||||
return temp_cache
|
||||
@@ -0,0 +1,434 @@
|
||||
"""
|
||||
Extended DragonflyDB Client for SwingMusic
|
||||
|
||||
Comprehensive caching system with 15+ cache services for:
|
||||
- Track metadata and persistence
|
||||
- User sessions and preferences
|
||||
- Mobile offline synchronization
|
||||
- Real-time features and analytics
|
||||
- Background job processing
|
||||
- Search and recommendations
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from swingmusic.db.dragonfly_client import DragonflyCache, get_dragonfly_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExtendedDragonflyServices:
|
||||
"""
|
||||
Extended DragonflyDB services for complete SwingMusic integration
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.client = get_dragonfly_client()
|
||||
|
||||
# Core performance caches
|
||||
self.track_cache = DragonflyCache("tracks")
|
||||
self.artist_cache = DragonflyCache("artists")
|
||||
self.album_cache = DragonflyCache("albums")
|
||||
|
||||
# User experience caches
|
||||
self.session_cache = DragonflyCache("sessions")
|
||||
self.user_cache = DragonflyCache("users")
|
||||
self.search_cache = DragonflyCache("search")
|
||||
self.homepage_cache = DragonflyCache("homepage")
|
||||
|
||||
# Mobile and offline caches
|
||||
self.mobile_cache = DragonflyCache("mobile")
|
||||
self.sync_cache = DragonflyCache("sync")
|
||||
self.progress_cache = DragonflyCache("progress")
|
||||
self.playlist_cache = DragonflyCache("playlists")
|
||||
|
||||
# Real-time feature caches
|
||||
self.playcount_cache = DragonflyCache("playcounts")
|
||||
self.recent_cache = DragonflyCache("recent")
|
||||
self.favorite_cache = DragonflyCache("favorites")
|
||||
self.recommendation_cache = DragonflyCache("recommendations")
|
||||
|
||||
# Background processing caches
|
||||
self.job_cache = DragonflyCache("jobs")
|
||||
self.lyrics_cache = DragonflyCache("lyrics")
|
||||
self.index_cache = DragonflyCache("index")
|
||||
self.temp_cache = DragonflyCache("temp")
|
||||
|
||||
logger.info("Extended DragonflyDB services initialized")
|
||||
|
||||
|
||||
class TrackCacheService:
|
||||
"""High-performance track caching with persistence"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache = DragonflyCache("tracks")
|
||||
|
||||
def get_track(self, trackhash: str) -> dict[str, Any] | None:
|
||||
"""Get track data from cache"""
|
||||
return self.cache.get(f"track:{trackhash}")
|
||||
|
||||
def set_track(
|
||||
self, trackhash: str, track_data: dict[str, Any], ttl_hours: int = 24
|
||||
):
|
||||
"""Cache track data"""
|
||||
return self.cache.set(f"track:{trackhash}", track_data, ttl_hours)
|
||||
|
||||
def get_track_batch(self, trackhashes: list[str]) -> dict[str, Any]:
|
||||
"""Get multiple tracks from cache"""
|
||||
results = {}
|
||||
for trackhash in trackhashes:
|
||||
track = self.get_track(trackhash)
|
||||
if track:
|
||||
results[trackhash] = track
|
||||
return results
|
||||
|
||||
def set_track_batch(self, tracks: dict[str, dict[str, Any]], ttl_hours: int = 24):
|
||||
"""Cache multiple tracks"""
|
||||
success_count = 0
|
||||
for trackhash, track_data in tracks.items():
|
||||
if self.set_track(trackhash, track_data, ttl_hours):
|
||||
success_count += 1
|
||||
return success_count
|
||||
|
||||
def invalidate_track(self, trackhash: str):
|
||||
"""Remove track from cache"""
|
||||
return self.cache.delete(f"track:{trackhash}")
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""Get track cache statistics"""
|
||||
keys = self.cache.client.keys("tracks:track:*")
|
||||
return {
|
||||
"total_tracks": len(keys),
|
||||
"memory_usage": self.cache.client.info().get(
|
||||
"used_memory_human", "Unknown"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class UserSessionService:
|
||||
"""Ultra-fast user session management"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache = DragonflyCache("sessions")
|
||||
# Backward compatibility for older call sites.
|
||||
self.session_cache = self.cache
|
||||
|
||||
def create_session(
|
||||
self, session_token: str, user_data: dict[str, Any], ttl_hours: int = 24
|
||||
):
|
||||
"""Create user session"""
|
||||
return self.cache.set(f"session:{session_token}", user_data, ttl_hours)
|
||||
|
||||
def set_user_session(
|
||||
self, userid: int, user_data: dict[str, Any], ttl_seconds: int = 24 * 3600
|
||||
):
|
||||
"""Store latest session payload by user id for quick lookups."""
|
||||
ttl_hours = max(1, int(ttl_seconds // 3600))
|
||||
return self.cache.set(f"user_session:{userid}", user_data, ttl_hours)
|
||||
|
||||
def get_user_session(self, userid: int) -> dict[str, Any] | None:
|
||||
"""Get latest session payload for a user id."""
|
||||
return self.cache.get(f"user_session:{userid}")
|
||||
|
||||
def get_session(self, session_token: str) -> dict[str, Any] | None:
|
||||
"""Get user session"""
|
||||
return self.cache.get(f"session:{session_token}")
|
||||
|
||||
def refresh_session(self, session_token: str, ttl_hours: int = 24):
|
||||
"""Refresh session TTL"""
|
||||
return self.cache.expire(f"session:{session_token}", ttl_hours * 3600)
|
||||
|
||||
def invalidate_session(self, session_token: str):
|
||||
"""Invalidate user session"""
|
||||
return self.cache.delete(f"session:{session_token}")
|
||||
|
||||
def invalidate_user_session(self, userid: int):
|
||||
"""Invalidate latest session payload for a user id."""
|
||||
return self.cache.delete(f"user_session:{userid}")
|
||||
|
||||
def get_user_sessions(self, userid: int) -> list[str]:
|
||||
"""Get all active sessions for user"""
|
||||
pattern = "session:*"
|
||||
keys = self.cache.client.keys(pattern)
|
||||
user_sessions = []
|
||||
|
||||
for key in keys:
|
||||
session_data = self.cache.get(key.replace("session:", ""))
|
||||
if session_data and session_data.get("userid") == userid:
|
||||
user_sessions.append(key)
|
||||
|
||||
return user_sessions
|
||||
|
||||
|
||||
class MobileSyncService:
|
||||
"""Reliable mobile offline synchronization"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache = DragonflyCache("mobile")
|
||||
|
||||
def queue_sync_action(self, userid: int, action: dict[str, Any]):
|
||||
"""Queue a sync action for mobile device"""
|
||||
queue_key = f"sync_queue:user:{userid}"
|
||||
return self.cache.client.lpush(queue_key, json.dumps(action))
|
||||
|
||||
def get_sync_actions(self, userid: int, count: int = 10) -> list[dict[str, Any]]:
|
||||
"""Get pending sync actions for user"""
|
||||
queue_key = f"sync_queue:user:{userid}"
|
||||
actions_data = self.cache.client.lrange(queue_key, 0, count - 1)
|
||||
|
||||
actions = []
|
||||
for action_data in actions_data:
|
||||
try:
|
||||
actions.append(json.loads(action_data))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
return actions
|
||||
|
||||
def mark_sync_completed(self, userid: int, action_id: str):
|
||||
"""Mark sync action as completed"""
|
||||
# Remove from queue
|
||||
queue_key = f"sync_queue:user:{userid}"
|
||||
return self.cache.client.lrem(queue_key, 1, action_id)
|
||||
|
||||
def set_sync_state(self, userid: int, device_id: str, state: dict[str, Any]):
|
||||
"""Set device sync state"""
|
||||
state_key = f"sync_state:user:{userid}:device:{device_id}"
|
||||
return self.cache.set(state_key, state, ttl_hours=24)
|
||||
|
||||
def get_sync_state(self, userid: int, device_id: str) -> dict[str, Any] | None:
|
||||
"""Get device sync state"""
|
||||
state_key = f"sync_state:user:{userid}:device:{device_id}"
|
||||
return self.cache.get(state_key)
|
||||
|
||||
|
||||
class RealTimeFeaturesService:
|
||||
"""Real-time features like play counts and favorites"""
|
||||
|
||||
def __init__(self):
|
||||
self.playcount_cache = DragonflyCache("playcounts")
|
||||
self.recent_cache = DragonflyCache("recent")
|
||||
self.favorite_cache = DragonflyCache("favorites")
|
||||
|
||||
def increment_playcount(self, trackhash: str, userid: int | None = None):
|
||||
"""Increment track play count"""
|
||||
key = f"plays:{trackhash}"
|
||||
if userid:
|
||||
key = f"plays:user:{userid}:track:{trackhash}"
|
||||
|
||||
return self.playcount_cache.client.incr(key)
|
||||
|
||||
def get_playcount(self, trackhash: str, userid: int | None = None) -> int:
|
||||
"""Get track play count"""
|
||||
key = f"plays:{trackhash}"
|
||||
if userid:
|
||||
key = f"plays:user:{userid}:track:{trackhash}"
|
||||
|
||||
count = self.playcount_cache.client.get(key)
|
||||
return int(count) if count else 0
|
||||
|
||||
def add_to_recently_played(self, userid: int, trackhash: str, limit: int = 50):
|
||||
"""Add track to recently played list"""
|
||||
key = f"recent:user:{userid}"
|
||||
|
||||
# Add to beginning of list
|
||||
self.recent_cache.client.lpush(key, trackhash)
|
||||
|
||||
# Remove duplicates
|
||||
self.recent_cache.client.lrem(key, 1, trackhash)
|
||||
|
||||
# Add back to beginning
|
||||
self.recent_cache.client.lpush(key, trackhash)
|
||||
|
||||
# Limit list size
|
||||
self.recent_cache.client.ltrim(key, 0, limit - 1)
|
||||
|
||||
# Set TTL
|
||||
self.recent_cache.client.expire(key, 7 * 24 * 3600) # 7 days
|
||||
|
||||
def get_recently_played(self, userid: int, limit: int = 50) -> list[str]:
|
||||
"""Get recently played tracks for user"""
|
||||
key = f"recent:user:{userid}"
|
||||
return self.recent_cache.client.lrange(key, 0, limit - 1)
|
||||
|
||||
def toggle_favorite(self, userid: int, trackhash: str) -> bool:
|
||||
"""Toggle favorite status for track"""
|
||||
key = f"fav:user:{userid}:track:{trackhash}"
|
||||
|
||||
current = self.favorite_cache.client.get(key)
|
||||
if current:
|
||||
# Remove favorite
|
||||
self.favorite_cache.client.delete(key)
|
||||
return False
|
||||
else:
|
||||
# Add favorite
|
||||
self.favorite_cache.client.set(key, True, ttl_hours=24 * 30) # 30 days
|
||||
return True
|
||||
|
||||
def is_favorite(self, userid: int, trackhash: str) -> bool:
|
||||
"""Check if track is favorited by user"""
|
||||
key = f"fav:user:{userid}:track:{trackhash}"
|
||||
return bool(self.favorite_cache.client.get(key))
|
||||
|
||||
def get_user_favorites(self, userid: int) -> list[str]:
|
||||
"""Get all favorite tracks for user"""
|
||||
pattern = f"fav:user:{userid}:track:*"
|
||||
keys = self.favorite_cache.client.keys(pattern)
|
||||
|
||||
favorites = []
|
||||
for key in keys:
|
||||
trackhash = key.split(":")[-1]
|
||||
favorites.append(trackhash)
|
||||
|
||||
return favorites
|
||||
|
||||
|
||||
class SearchCacheService:
|
||||
"""High-performance search results caching"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache = DragonflyCache("search")
|
||||
|
||||
def cache_search_results(
|
||||
self, query: str, results: dict[str, Any], ttl_hours: int = 6
|
||||
):
|
||||
"""Cache search results"""
|
||||
query_hash = hash(query) # Simple hash for key
|
||||
return self.cache.set(f"results:{query_hash}", results, ttl_hours)
|
||||
|
||||
def get_search_results(self, query: str) -> dict[str, Any] | None:
|
||||
"""Get cached search results"""
|
||||
query_hash = hash(query)
|
||||
return self.cache.get(f"results:{query_hash}")
|
||||
|
||||
def cache_suggestions(
|
||||
self, query_type: str, suggestions: list[str], ttl_hours: int = 12
|
||||
):
|
||||
"""Cache search suggestions"""
|
||||
return self.cache.set(f"suggestions:{query_type}", suggestions, ttl_hours)
|
||||
|
||||
def get_suggestions(self, query_type: str) -> list[str]:
|
||||
"""Get cached search suggestions"""
|
||||
suggestions = self.cache.get(f"suggestions:{query_type}")
|
||||
return suggestions if suggestions else []
|
||||
|
||||
def invalidate_search_cache(self, pattern: str = "*"):
|
||||
"""Invalidate search cache"""
|
||||
keys = self.cache.client.keys(f"search:{pattern}")
|
||||
if keys:
|
||||
return self.cache.client.delete(*keys)
|
||||
return True
|
||||
|
||||
|
||||
class JobQueueService:
|
||||
"""High-performance background job processing"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache = DragonflyCache("jobs")
|
||||
|
||||
def enqueue_job(self, queue: str, job_data: dict[str, Any]):
|
||||
"""Add job to queue"""
|
||||
job_json = json.dumps(job_data)
|
||||
return self.cache.client.lpush(f"queue:{queue}", job_json)
|
||||
|
||||
def dequeue_job(self, queue: str) -> dict[str, Any] | None:
|
||||
"""Get next job from queue"""
|
||||
job_json = self.cache.client.rpop(f"queue:{queue}")
|
||||
if job_json:
|
||||
try:
|
||||
return json.loads(job_json)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def get_queue_size(self, queue: str) -> int:
|
||||
"""Get number of jobs in queue"""
|
||||
return self.cache.client.llen(f"queue:{queue}")
|
||||
|
||||
def peek_jobs(self, queue: str, count: int = 10) -> list[dict[str, Any]]:
|
||||
"""Peek at jobs in queue without removing them"""
|
||||
jobs_data = self.cache.client.lrange(f"queue:{queue}", 0, count - 1)
|
||||
|
||||
jobs = []
|
||||
for job_data in jobs_data:
|
||||
try:
|
||||
jobs.append(json.loads(job_data))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
return jobs
|
||||
|
||||
def clear_queue(self, queue: str):
|
||||
"""Clear all jobs from queue"""
|
||||
return self.cache.client.delete(f"queue:{queue}")
|
||||
|
||||
|
||||
# Global service instances
|
||||
_track_cache_service: TrackCacheService | None = None
|
||||
_user_session_service: UserSessionService | None = None
|
||||
_mobile_sync_service: MobileSyncService | None = None
|
||||
_realtime_service: RealTimeFeaturesService | None = None
|
||||
_search_cache_service: SearchCacheService | None = None
|
||||
_job_queue_service: JobQueueService | None = None
|
||||
|
||||
|
||||
def get_track_cache_service() -> TrackCacheService:
|
||||
"""Get track cache service instance"""
|
||||
global _track_cache_service
|
||||
if _track_cache_service is None:
|
||||
_track_cache_service = TrackCacheService()
|
||||
return _track_cache_service
|
||||
|
||||
|
||||
def get_user_session_service() -> UserSessionService:
|
||||
"""Get user session service instance"""
|
||||
global _user_session_service
|
||||
if _user_session_service is None:
|
||||
_user_session_service = UserSessionService()
|
||||
return _user_session_service
|
||||
|
||||
|
||||
def get_mobile_sync_service() -> MobileSyncService:
|
||||
"""Get mobile sync service instance"""
|
||||
global _mobile_sync_service
|
||||
if _mobile_sync_service is None:
|
||||
_mobile_sync_service = MobileSyncService()
|
||||
return _mobile_sync_service
|
||||
|
||||
|
||||
def get_realtime_service() -> RealTimeFeaturesService:
|
||||
"""Get real-time features service instance"""
|
||||
global _realtime_service
|
||||
if _realtime_service is None:
|
||||
_realtime_service = RealTimeFeaturesService()
|
||||
return _realtime_service
|
||||
|
||||
|
||||
def get_search_cache_service() -> SearchCacheService:
|
||||
"""Get search cache service instance"""
|
||||
global _search_cache_service
|
||||
if _search_cache_service is None:
|
||||
_search_cache_service = SearchCacheService()
|
||||
return _search_cache_service
|
||||
|
||||
|
||||
def get_job_queue_service() -> JobQueueService:
|
||||
"""Get job queue service instance"""
|
||||
global _job_queue_service
|
||||
if _job_queue_service is None:
|
||||
_job_queue_service = JobQueueService()
|
||||
return _job_queue_service
|
||||
|
||||
|
||||
def get_all_dragonfly_services() -> dict[str, Any]:
|
||||
"""Get all DragonflyDB services for monitoring"""
|
||||
return {
|
||||
"track_cache": get_track_cache_service(),
|
||||
"user_sessions": get_user_session_service(),
|
||||
"mobile_sync": get_mobile_sync_service(),
|
||||
"realtime": get_realtime_service(),
|
||||
"search_cache": get_search_cache_service(),
|
||||
"job_queue": get_job_queue_service(),
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
from sqlalchemy import Engine, create_engine, event
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from swingmusic.settings import Paths
|
||||
|
||||
|
||||
@event.listens_for(Engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||
cursor.execute("PRAGMA cache_size=10000")
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.execute("PRAGMA temp_store=FILE")
|
||||
cursor.execute("PRAGMA mmap_size=0")
|
||||
cursor.close()
|
||||
|
||||
|
||||
class classproperty(property):
|
||||
"""
|
||||
A class property decorator.
|
||||
"""
|
||||
|
||||
def __get__(self, owner_self, owner_cls):
|
||||
if self.fget:
|
||||
return self.fget(owner_cls)
|
||||
|
||||
|
||||
class DbEngine:
|
||||
"""
|
||||
The database engine instance.
|
||||
"""
|
||||
|
||||
_engine: Engine | None = None
|
||||
|
||||
@classproperty
|
||||
def engine(cls) -> Engine:
|
||||
if not cls._engine:
|
||||
cls._engine = create_engine(
|
||||
f"sqlite+pysqlite:///{Paths().app_db_path}",
|
||||
echo=False,
|
||||
max_overflow=20,
|
||||
pool_size=10,
|
||||
)
|
||||
|
||||
return cls._engine
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def manager(cls, commit: bool = False):
|
||||
"""
|
||||
This context manager manages access to the database.
|
||||
|
||||
When the context manager is entered, it returns a session object that can be used to execute SQL statements.
|
||||
|
||||
If the `commit` parameter is set to `True`, the context manager will commit the transaction when it exits.
|
||||
"""
|
||||
Session = sessionmaker(cls.engine)
|
||||
|
||||
try:
|
||||
with Session() as session:
|
||||
yield session
|
||||
|
||||
if commit:
|
||||
session.commit()
|
||||
# yield session.execution_options(preserve_rowcount=True, yield_per=100)
|
||||
# yield conn.execution_options(preserve_rowcount=True, yield_per=100)
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise e
|
||||
finally:
|
||||
if commit:
|
||||
session.commit()
|
||||
|
||||
session.close()
|
||||
# del conn
|
||||
# cls.engine.clear_compiled_cache()
|
||||
@@ -0,0 +1,79 @@
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import JSON, Integer, String, delete, select
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from swingmusic.config import UserConfig
|
||||
from swingmusic.db import Base
|
||||
from swingmusic.db.engine import DbEngine
|
||||
from swingmusic.db.utils import track_to_dataclass, tracks_to_dataclasses
|
||||
|
||||
|
||||
class TrackTable(Base):
|
||||
__tablename__ = "track"
|
||||
|
||||
id: Mapped[int] = mapped_column(init=False, primary_key=True)
|
||||
album: Mapped[str] = mapped_column(String())
|
||||
albumartists: Mapped[str] = mapped_column(String())
|
||||
albumhash: Mapped[str] = mapped_column(String(), index=True)
|
||||
artists: Mapped[str] = mapped_column(String())
|
||||
bitrate: Mapped[int] = mapped_column(Integer())
|
||||
copyright: Mapped[str | None] = mapped_column(String())
|
||||
date: Mapped[int] = mapped_column(Integer(), nullable=True)
|
||||
disc: Mapped[int] = mapped_column(Integer())
|
||||
duration: Mapped[int] = mapped_column(Integer())
|
||||
filepath: Mapped[str] = mapped_column(String(), index=True, unique=True)
|
||||
folder: Mapped[str] = mapped_column(String(), index=True)
|
||||
genres: Mapped[str | None] = mapped_column(String())
|
||||
last_mod: Mapped[float] = mapped_column(Integer())
|
||||
title: Mapped[str] = mapped_column(String())
|
||||
track: Mapped[int] = mapped_column(Integer())
|
||||
trackhash: Mapped[str] = mapped_column(String(), index=True)
|
||||
lastplayed: Mapped[int] = mapped_column(Integer(), default=0)
|
||||
playcount: Mapped[int] = mapped_column(Integer(), default=0)
|
||||
playduration: Mapped[int] = mapped_column(Integer(), default=0)
|
||||
extra: Mapped[dict[str, Any] | None] = mapped_column(JSON(), default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def get_all(cls):
|
||||
with DbEngine.manager() as conn:
|
||||
config = UserConfig()
|
||||
result = conn.execute(select(cls).execution_options(yield_per=100))
|
||||
|
||||
for i in result.scalars():
|
||||
d = i.__dict__
|
||||
del d["_sa_instance_state"]
|
||||
|
||||
yield track_to_dataclass(d, config)
|
||||
|
||||
@classmethod
|
||||
def get_tracks_by_filepaths(cls, filepaths: list[str]):
|
||||
with DbEngine.manager() as conn:
|
||||
result = conn.execute(
|
||||
select(TrackTable)
|
||||
.where(TrackTable.filepath.in_(filepaths))
|
||||
.order_by(TrackTable.last_mod)
|
||||
)
|
||||
return tracks_to_dataclasses(result.fetchall())
|
||||
|
||||
@classmethod
|
||||
def get_tracks_in_path(cls, path: str):
|
||||
with DbEngine.manager() as conn:
|
||||
result = conn.execute(
|
||||
select(TrackTable)
|
||||
.where(TrackTable.filepath.contains(path))
|
||||
.order_by(TrackTable.last_mod)
|
||||
)
|
||||
|
||||
clean = []
|
||||
for row in result.fetchall():
|
||||
d = row[0].__dict__
|
||||
del d["_sa_instance_state"]
|
||||
clean.append(d)
|
||||
|
||||
return tracks_to_dataclasses(clean)
|
||||
|
||||
@classmethod
|
||||
def remove_tracks_by_filepaths(cls, filepaths: set[str]):
|
||||
with DbEngine.manager(commit=True) as conn:
|
||||
conn.execute(delete(TrackTable).where(TrackTable.filepath.in_(filepaths)))
|
||||
@@ -0,0 +1,33 @@
|
||||
from sqlalchemy import Integer, insert, select, update
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from swingmusic.db import Base
|
||||
from swingmusic.db.engine import DbEngine
|
||||
|
||||
|
||||
class MigrationTable(Base):
|
||||
__tablename__ = "dbmigration"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
version: Mapped[int] = mapped_column(Integer())
|
||||
|
||||
@classmethod
|
||||
def set_version(cls, version: int):
|
||||
with DbEngine.manager(commit=True) as conn:
|
||||
result = conn.execute(
|
||||
update(cls).where(cls.id == 1).values(version=version)
|
||||
)
|
||||
|
||||
if result.rowcount == 0:
|
||||
conn.execute(insert(cls).values(id=1, version=version))
|
||||
|
||||
@classmethod
|
||||
def get_version(cls):
|
||||
with DbEngine.manager() as conn:
|
||||
result = conn.execute(select(cls.version).where(cls.id == 1))
|
||||
result = result.fetchone()
|
||||
|
||||
if result:
|
||||
return result[0]
|
||||
|
||||
return -1
|
||||
@@ -0,0 +1,745 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Boolean,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
and_,
|
||||
delete,
|
||||
select,
|
||||
update,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from swingmusic.db import Base
|
||||
|
||||
|
||||
class LibraryFileTable(Base):
|
||||
__tablename__ = "library_file"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
trackhash: Mapped[str] = mapped_column(String(), unique=True, index=True)
|
||||
filepath: Mapped[str] = mapped_column(String(), unique=True, index=True)
|
||||
codec: Mapped[str] = mapped_column(String(), default="unknown")
|
||||
quality: Mapped[str] = mapped_column(String(), default="unknown")
|
||||
bitrate: Mapped[int] = mapped_column(Integer(), default=0)
|
||||
source: Mapped[str] = mapped_column(String(), default="local")
|
||||
checksum: Mapped[str | None] = mapped_column(String(), nullable=True, default=None)
|
||||
created_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
|
||||
updated_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
|
||||
extra: Mapped[dict[str, Any]] = mapped_column(JSON(), default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def get_by_trackhash(cls, trackhash: str):
|
||||
result = cls.execute(select(cls).where(cls.trackhash == trackhash))
|
||||
return next(result).scalar()
|
||||
|
||||
@classmethod
|
||||
def upsert_from_local_track(
|
||||
cls,
|
||||
*,
|
||||
trackhash: str,
|
||||
filepath: str,
|
||||
bitrate: int,
|
||||
codec: str,
|
||||
quality: str,
|
||||
source: str = "local",
|
||||
):
|
||||
now = int(time.time())
|
||||
row = cls.get_by_trackhash(trackhash)
|
||||
|
||||
if row:
|
||||
next(
|
||||
cls.execute(
|
||||
update(cls)
|
||||
.where(cls.id == row.id)
|
||||
.values(
|
||||
filepath=filepath,
|
||||
bitrate=bitrate,
|
||||
codec=codec,
|
||||
quality=quality,
|
||||
source=source,
|
||||
updated_at=now,
|
||||
),
|
||||
commit=True,
|
||||
)
|
||||
)
|
||||
return cls.get_by_trackhash(trackhash)
|
||||
|
||||
cls.insert_one(
|
||||
{
|
||||
"trackhash": trackhash,
|
||||
"filepath": filepath,
|
||||
"bitrate": bitrate,
|
||||
"codec": codec,
|
||||
"quality": quality,
|
||||
"source": source,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"extra": {},
|
||||
}
|
||||
)
|
||||
return cls.get_by_trackhash(trackhash)
|
||||
|
||||
|
||||
class DownloadJobTable(Base):
|
||||
__tablename__ = "download_job"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
userid: Mapped[int] = mapped_column(
|
||||
Integer(), ForeignKey("user.id", ondelete="cascade"), index=True
|
||||
)
|
||||
trackhash: Mapped[str | None] = mapped_column(
|
||||
String(), nullable=True, index=True, default=None
|
||||
)
|
||||
title: Mapped[str | None] = mapped_column(String(), nullable=True, default=None)
|
||||
artist: Mapped[str | None] = mapped_column(String(), nullable=True, default=None)
|
||||
album: Mapped[str | None] = mapped_column(String(), nullable=True, default=None)
|
||||
item_type: Mapped[str] = mapped_column(String(), default="track")
|
||||
source_url: Mapped[str | None] = mapped_column(
|
||||
String(), nullable=True, index=True, default=None
|
||||
)
|
||||
source: Mapped[str] = mapped_column(String(), default="spotify", index=True)
|
||||
provider: Mapped[str] = mapped_column(String(), default="spotify")
|
||||
codec: Mapped[str] = mapped_column(String(), default="mp3")
|
||||
quality: Mapped[str] = mapped_column(String(), default="high")
|
||||
target_path: Mapped[str | None] = mapped_column(
|
||||
String(), nullable=True, default=None
|
||||
)
|
||||
state: Mapped[str] = mapped_column(String(), default="queued", index=True)
|
||||
progress: Mapped[float] = mapped_column(Float(), default=0.0)
|
||||
error: Mapped[str | None] = mapped_column(String(), nullable=True, default=None)
|
||||
retry_count: Mapped[int] = mapped_column(Integer(), default=0)
|
||||
payload: Mapped[dict[str, Any]] = mapped_column(JSON(), default_factory=dict)
|
||||
created_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
|
||||
updated_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
|
||||
started_at: Mapped[int | None] = mapped_column(
|
||||
Integer(), nullable=True, default=None
|
||||
)
|
||||
finished_at: Mapped[int | None] = mapped_column(
|
||||
Integer(), nullable=True, default=None
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def enqueue(cls, payload: dict[str, Any]):
|
||||
now = int(time.time())
|
||||
values = {
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"state": "queued",
|
||||
"progress": 0.0,
|
||||
**payload,
|
||||
}
|
||||
result = cls.insert_one(values)
|
||||
return result.lastrowid
|
||||
|
||||
@classmethod
|
||||
def get_by_id(cls, job_id: int):
|
||||
result = cls.execute(select(cls).where(cls.id == job_id))
|
||||
return next(result).scalar()
|
||||
|
||||
@classmethod
|
||||
def get_queued_job(cls):
|
||||
result = cls.execute(
|
||||
select(cls)
|
||||
.where(cls.state == "queued")
|
||||
.order_by(cls.created_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
return next(result).scalar()
|
||||
|
||||
@classmethod
|
||||
def update_job(cls, job_id: int, values: dict[str, Any]):
|
||||
values = {**values, "updated_at": int(time.time())}
|
||||
return next(
|
||||
cls.execute(update(cls).where(cls.id == job_id).values(values), commit=True)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_for_user(cls, userid: int, states: list[str] | set[str] | None = None):
|
||||
query = select(cls).where(cls.userid == userid).order_by(cls.created_at.desc())
|
||||
if states:
|
||||
query = query.where(cls.state.in_(list(states)))
|
||||
|
||||
result = cls.execute(query)
|
||||
return list(next(result).scalars())
|
||||
|
||||
@classmethod
|
||||
def delete_for_user(
|
||||
cls, userid: int, states: list[str] | set[str] | None = None
|
||||
) -> int:
|
||||
statement = delete(cls).where(cls.userid == userid)
|
||||
if states:
|
||||
statement = statement.where(cls.state.in_(list(states)))
|
||||
|
||||
result = next(cls.execute(statement, commit=True))
|
||||
return int(result.rowcount or 0)
|
||||
|
||||
|
||||
class TrackedPlaylistTable(Base):
|
||||
__tablename__ = "tracked_playlist"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"userid",
|
||||
"service",
|
||||
"playlist_id",
|
||||
name="uq_tracked_playlist_user_service",
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
userid: Mapped[int] = mapped_column(
|
||||
Integer(), ForeignKey("user.id", ondelete="cascade"), index=True
|
||||
)
|
||||
source_url: Mapped[str] = mapped_column(String(), index=True)
|
||||
playlist_id: Mapped[str] = mapped_column(String(), index=True)
|
||||
service: Mapped[str] = mapped_column(String(), default="spotify", index=True)
|
||||
title: Mapped[str | None] = mapped_column(String(), nullable=True, default=None)
|
||||
owner_name: Mapped[str | None] = mapped_column(
|
||||
String(), nullable=True, default=None
|
||||
)
|
||||
quality: Mapped[str] = mapped_column(String(), default="lossless")
|
||||
codec: Mapped[str] = mapped_column(String(), default="flac")
|
||||
auto_sync: Mapped[bool] = mapped_column(Boolean(), default=True, index=True)
|
||||
sync_interval_seconds: Mapped[int] = mapped_column(Integer(), default=900)
|
||||
next_sync_at: Mapped[int] = mapped_column(
|
||||
Integer(), default=lambda: int(time.time())
|
||||
)
|
||||
last_sync_at: Mapped[int | None] = mapped_column(
|
||||
Integer(), nullable=True, default=None
|
||||
)
|
||||
status: Mapped[str] = mapped_column(String(), default="active", index=True)
|
||||
snapshot_track_ids: Mapped[list[str]] = mapped_column(JSON(), default_factory=list)
|
||||
snapshot_hash: Mapped[str | None] = mapped_column(
|
||||
String(), nullable=True, default=None
|
||||
)
|
||||
last_result: Mapped[dict[str, Any]] = mapped_column(JSON(), default_factory=dict)
|
||||
last_error: Mapped[str | None] = mapped_column(
|
||||
String(), nullable=True, default=None
|
||||
)
|
||||
created_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
|
||||
updated_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
|
||||
extra: Mapped[dict[str, Any]] = mapped_column(JSON(), default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def get_by_id(cls, tracked_id: int, userid: int | None = None):
|
||||
statement = select(cls).where(cls.id == tracked_id)
|
||||
if userid is not None:
|
||||
statement = statement.where(cls.userid == userid)
|
||||
result = cls.execute(statement)
|
||||
return next(result).scalar()
|
||||
|
||||
@classmethod
|
||||
def get_by_source(cls, *, userid: int, service: str, playlist_id: str):
|
||||
result = cls.execute(
|
||||
select(cls).where(
|
||||
and_(
|
||||
cls.userid == userid,
|
||||
cls.service == service,
|
||||
cls.playlist_id == playlist_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
return next(result).scalar()
|
||||
|
||||
@classmethod
|
||||
def list_for_user(cls, userid: int, include_deleted: bool = False):
|
||||
statement = (
|
||||
select(cls).where(cls.userid == userid).order_by(cls.created_at.desc())
|
||||
)
|
||||
if not include_deleted:
|
||||
statement = statement.where(cls.status != "deleted")
|
||||
result = cls.execute(statement)
|
||||
return list(next(result).scalars())
|
||||
|
||||
@classmethod
|
||||
def upsert(
|
||||
cls,
|
||||
*,
|
||||
userid: int,
|
||||
service: str,
|
||||
playlist_id: str,
|
||||
source_url: str,
|
||||
values: dict[str, Any] | None = None,
|
||||
):
|
||||
now = int(time.time())
|
||||
row = cls.get_by_source(userid=userid, service=service, playlist_id=playlist_id)
|
||||
payload: dict[str, Any] = {
|
||||
"userid": userid,
|
||||
"service": service,
|
||||
"playlist_id": playlist_id,
|
||||
"source_url": source_url,
|
||||
}
|
||||
if values:
|
||||
payload.update(values)
|
||||
|
||||
if row:
|
||||
next(
|
||||
cls.execute(
|
||||
update(cls)
|
||||
.where(cls.id == row.id)
|
||||
.values(
|
||||
{
|
||||
**payload,
|
||||
"updated_at": now,
|
||||
}
|
||||
),
|
||||
commit=True,
|
||||
)
|
||||
)
|
||||
return cls.get_by_id(row.id)
|
||||
|
||||
cls.insert_one(
|
||||
{
|
||||
**payload,
|
||||
"status": payload.get("status", "active"),
|
||||
"auto_sync": bool(payload.get("auto_sync", True)),
|
||||
"sync_interval_seconds": int(payload.get("sync_interval_seconds", 900)),
|
||||
"next_sync_at": int(payload.get("next_sync_at", now)),
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"snapshot_track_ids": payload.get("snapshot_track_ids", []),
|
||||
"last_result": payload.get("last_result", {}),
|
||||
"extra": payload.get("extra", {}),
|
||||
}
|
||||
)
|
||||
return cls.get_by_source(
|
||||
userid=userid, service=service, playlist_id=playlist_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def update_row(cls, tracked_id: int, values: dict[str, Any]):
|
||||
next(
|
||||
cls.execute(
|
||||
update(cls)
|
||||
.where(cls.id == tracked_id)
|
||||
.values({**values, "updated_at": int(time.time())}),
|
||||
commit=True,
|
||||
)
|
||||
)
|
||||
return cls.get_by_id(tracked_id)
|
||||
|
||||
@classmethod
|
||||
def due_for_sync(cls, *, now_ts: int | None = None, limit: int = 50):
|
||||
now_ts = int(now_ts or time.time())
|
||||
result = cls.execute(
|
||||
select(cls)
|
||||
.where(cls.auto_sync.is_(True))
|
||||
.where(cls.status.in_(["active", "failed", "syncing"]))
|
||||
.where(cls.next_sync_at <= now_ts)
|
||||
.order_by(cls.next_sync_at.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
return list(next(result).scalars())
|
||||
|
||||
|
||||
class UserLibraryTrackTable(Base):
|
||||
__tablename__ = "user_library_track"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("userid", "trackhash", name="uq_user_track_projection"),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
userid: Mapped[int] = mapped_column(
|
||||
Integer(), ForeignKey("user.id", ondelete="cascade"), index=True
|
||||
)
|
||||
trackhash: Mapped[str] = mapped_column(String(), index=True)
|
||||
file_id: Mapped[int] = mapped_column(
|
||||
Integer(),
|
||||
ForeignKey("library_file.id", ondelete="set null"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
status: Mapped[str] = mapped_column(String(), default="missing", index=True)
|
||||
source_url: Mapped[str | None] = mapped_column(
|
||||
String(), nullable=True, default=None
|
||||
)
|
||||
download_job_id: Mapped[int | None] = mapped_column(
|
||||
Integer(),
|
||||
ForeignKey("download_job.id", ondelete="set null"),
|
||||
nullable=True,
|
||||
default=None,
|
||||
)
|
||||
error: Mapped[str | None] = mapped_column(String(), nullable=True, default=None)
|
||||
created_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
|
||||
updated_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
|
||||
extra: Mapped[dict[str, Any]] = mapped_column(JSON(), default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def get_user_track(cls, userid: int, trackhash: str):
|
||||
result = cls.execute(
|
||||
select(cls).where(and_(cls.userid == userid, cls.trackhash == trackhash))
|
||||
)
|
||||
return next(result).scalar()
|
||||
|
||||
@classmethod
|
||||
def upsert_status(
|
||||
cls,
|
||||
*,
|
||||
userid: int,
|
||||
trackhash: str,
|
||||
status: str,
|
||||
file_id: int | None = None,
|
||||
download_job_id: int | None = None,
|
||||
source_url: str | None = None,
|
||||
error: str | None = None,
|
||||
extra: dict[str, Any] | None = None,
|
||||
):
|
||||
now = int(time.time())
|
||||
row = cls.get_user_track(userid, trackhash)
|
||||
|
||||
values: dict[str, Any] = {
|
||||
"status": status,
|
||||
"updated_at": now,
|
||||
"file_id": file_id,
|
||||
"download_job_id": download_job_id,
|
||||
"source_url": source_url,
|
||||
"error": error,
|
||||
}
|
||||
|
||||
if extra is not None:
|
||||
values["extra"] = extra
|
||||
|
||||
if row:
|
||||
next(
|
||||
cls.execute(
|
||||
update(cls).where(cls.id == row.id).values(values), commit=True
|
||||
)
|
||||
)
|
||||
return cls.get_user_track(userid, trackhash)
|
||||
|
||||
cls.insert_one(
|
||||
{
|
||||
"userid": userid,
|
||||
"trackhash": trackhash,
|
||||
"status": status,
|
||||
"file_id": file_id,
|
||||
"download_job_id": download_job_id,
|
||||
"source_url": source_url,
|
||||
"error": error,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"extra": extra or {},
|
||||
}
|
||||
)
|
||||
return cls.get_user_track(userid, trackhash)
|
||||
|
||||
@classmethod
|
||||
def get_status_map(cls, userid: int, trackhashes: set[str] | list[str]):
|
||||
if not trackhashes:
|
||||
return {}
|
||||
|
||||
result = cls.execute(
|
||||
select(cls).where(
|
||||
and_(cls.userid == userid, cls.trackhash.in_(set(trackhashes)))
|
||||
)
|
||||
)
|
||||
rows = list(next(result).scalars())
|
||||
return {row.trackhash: row for row in rows}
|
||||
|
||||
|
||||
class UserRootDirOwnershipTable(Base):
|
||||
__tablename__ = "user_root_dir_ownership"
|
||||
__table_args__ = (UniqueConstraint("userid", "path", name="uq_user_root_path"),)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
userid: Mapped[int] = mapped_column(
|
||||
Integer(), ForeignKey("user.id", ondelete="cascade"), index=True
|
||||
)
|
||||
path: Mapped[str] = mapped_column(String(), index=True)
|
||||
created_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
|
||||
|
||||
@classmethod
|
||||
def assign_paths(cls, userid: int, paths: list[str]):
|
||||
existing_result = cls.execute(select(cls.path).where(cls.userid == userid))
|
||||
existing = {row[0] for row in next(existing_result).all()}
|
||||
|
||||
for path in paths:
|
||||
if path in existing:
|
||||
continue
|
||||
cls.insert_one(
|
||||
{"userid": userid, "path": path, "created_at": int(time.time())}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_paths(cls, userid: int) -> list[str]:
|
||||
result = cls.execute(select(cls.path).where(cls.userid == userid))
|
||||
paths = [row for row in next(result).scalars().all() if row]
|
||||
return list(dict.fromkeys(paths))
|
||||
|
||||
@classmethod
|
||||
def replace_paths(cls, userid: int, paths: list[str]):
|
||||
cleaned = [path.strip() for path in paths if path and path.strip()]
|
||||
cleaned = list(dict.fromkeys(cleaned))
|
||||
|
||||
next(cls.execute(delete(cls).where(cls.userid == userid), commit=True))
|
||||
if not cleaned:
|
||||
return
|
||||
|
||||
now = int(time.time())
|
||||
cls.insert_many(
|
||||
[{"userid": userid, "path": path, "created_at": now} for path in cleaned]
|
||||
)
|
||||
|
||||
|
||||
class SetupStateTable(Base):
|
||||
__tablename__ = "setup_state"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
owner_userid: Mapped[int | None] = mapped_column(
|
||||
Integer(),
|
||||
ForeignKey("user.id", ondelete="set null"),
|
||||
nullable=True,
|
||||
default=None,
|
||||
)
|
||||
primary_music_dir: Mapped[str | None] = mapped_column(
|
||||
String(), nullable=True, default=None
|
||||
)
|
||||
setup_completed: Mapped[bool] = mapped_column(Boolean(), default=False)
|
||||
index_state: Mapped[str] = mapped_column(String(), default="idle")
|
||||
index_progress: Mapped[float] = mapped_column(Float(), default=0.0)
|
||||
index_message: Mapped[str | None] = mapped_column(
|
||||
String(), nullable=True, default=None
|
||||
)
|
||||
created_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
|
||||
updated_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
|
||||
extra: Mapped[dict[str, Any]] = mapped_column(JSON(), default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def get_singleton(cls):
|
||||
result = cls.execute(select(cls).where(cls.id == 1))
|
||||
return next(result).scalar()
|
||||
|
||||
@classmethod
|
||||
def ensure_singleton(cls):
|
||||
row = cls.get_singleton()
|
||||
if row:
|
||||
return row
|
||||
|
||||
cls.insert_one(
|
||||
{
|
||||
"id": 1,
|
||||
"setup_completed": False,
|
||||
"index_state": "idle",
|
||||
"index_progress": 0.0,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
"extra": {},
|
||||
}
|
||||
)
|
||||
return cls.get_singleton()
|
||||
|
||||
@classmethod
|
||||
def update_state(cls, values: dict[str, Any]):
|
||||
cls.ensure_singleton()
|
||||
next(
|
||||
cls.execute(
|
||||
update(cls)
|
||||
.where(cls.id == 1)
|
||||
.values(
|
||||
{
|
||||
**values,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
),
|
||||
commit=True,
|
||||
)
|
||||
)
|
||||
return cls.get_singleton()
|
||||
|
||||
@classmethod
|
||||
def mark_index_progress(
|
||||
cls,
|
||||
*,
|
||||
state: str,
|
||||
progress: float,
|
||||
message: str | None = None,
|
||||
extra: dict[str, Any] | None = None,
|
||||
):
|
||||
values: dict[str, Any] = {
|
||||
"index_state": state,
|
||||
"index_progress": max(0.0, min(float(progress), 100.0)),
|
||||
"index_message": message,
|
||||
}
|
||||
if extra is not None:
|
||||
values["extra"] = extra
|
||||
return cls.update_state(values)
|
||||
|
||||
|
||||
class LyricsStatusTable(Base):
|
||||
__tablename__ = "lyrics_status"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
trackhash: Mapped[str] = mapped_column(String(), unique=True, index=True)
|
||||
filepath: Mapped[str | None] = mapped_column(
|
||||
String(), nullable=True, index=True, default=None
|
||||
)
|
||||
status: Mapped[str] = mapped_column(String(), default="pending", index=True)
|
||||
source: Mapped[str | None] = mapped_column(String(), nullable=True, default=None)
|
||||
has_embedded: Mapped[bool] = mapped_column(Boolean(), default=False)
|
||||
has_lrc: Mapped[bool] = mapped_column(Boolean(), default=False)
|
||||
last_error: Mapped[str | None] = mapped_column(
|
||||
String(), nullable=True, default=None
|
||||
)
|
||||
attempts: Mapped[int] = mapped_column(Integer(), default=0)
|
||||
created_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
|
||||
updated_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
|
||||
extra: Mapped[dict[str, Any]] = mapped_column(JSON(), default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def get_by_trackhash(cls, trackhash: str):
|
||||
result = cls.execute(select(cls).where(cls.trackhash == trackhash))
|
||||
return next(result).scalar()
|
||||
|
||||
@classmethod
|
||||
def upsert(
|
||||
cls,
|
||||
*,
|
||||
trackhash: str,
|
||||
filepath: str | None = None,
|
||||
status: str,
|
||||
source: str | None = None,
|
||||
has_embedded: bool | None = None,
|
||||
has_lrc: bool | None = None,
|
||||
last_error: str | None = None,
|
||||
extra: dict[str, Any] | None = None,
|
||||
increment_attempt: bool = False,
|
||||
):
|
||||
now = int(time.time())
|
||||
row = cls.get_by_trackhash(trackhash)
|
||||
values: dict[str, Any] = {
|
||||
"status": status,
|
||||
"source": source,
|
||||
"last_error": last_error,
|
||||
"updated_at": now,
|
||||
}
|
||||
|
||||
if filepath is not None:
|
||||
values["filepath"] = filepath
|
||||
if has_embedded is not None:
|
||||
values["has_embedded"] = bool(has_embedded)
|
||||
if has_lrc is not None:
|
||||
values["has_lrc"] = bool(has_lrc)
|
||||
if extra is not None:
|
||||
values["extra"] = extra
|
||||
|
||||
if row:
|
||||
if increment_attempt:
|
||||
values["attempts"] = int(row.attempts or 0) + 1
|
||||
next(
|
||||
cls.execute(
|
||||
update(cls).where(cls.id == row.id).values(values), commit=True
|
||||
)
|
||||
)
|
||||
return cls.get_by_trackhash(trackhash)
|
||||
|
||||
cls.insert_one(
|
||||
{
|
||||
"trackhash": trackhash,
|
||||
"filepath": filepath,
|
||||
"status": status,
|
||||
"source": source,
|
||||
"has_embedded": bool(has_embedded)
|
||||
if has_embedded is not None
|
||||
else False,
|
||||
"has_lrc": bool(has_lrc) if has_lrc is not None else False,
|
||||
"last_error": last_error,
|
||||
"attempts": 1 if increment_attempt else 0,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"extra": extra or {},
|
||||
}
|
||||
)
|
||||
return cls.get_by_trackhash(trackhash)
|
||||
|
||||
|
||||
class InviteTokenTable(Base):
|
||||
__tablename__ = "invite_token"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
token: Mapped[str] = mapped_column(String(), unique=True, index=True)
|
||||
created_by: Mapped[int | None] = mapped_column(
|
||||
Integer(),
|
||||
ForeignKey("user.id", ondelete="set null"),
|
||||
nullable=True,
|
||||
default=None,
|
||||
)
|
||||
used_by: Mapped[int | None] = mapped_column(
|
||||
Integer(),
|
||||
ForeignKey("user.id", ondelete="set null"),
|
||||
nullable=True,
|
||||
default=None,
|
||||
)
|
||||
roles: Mapped[list[str]] = mapped_column(JSON(), default_factory=lambda: ["user"])
|
||||
active: Mapped[bool] = mapped_column(Boolean(), default=True)
|
||||
expires_at: Mapped[int | None] = mapped_column(
|
||||
Integer(), nullable=True, default=None
|
||||
)
|
||||
created_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
|
||||
used_at: Mapped[int | None] = mapped_column(Integer(), nullable=True, default=None)
|
||||
extra: Mapped[dict[str, Any]] = mapped_column(JSON(), default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def create_token(
|
||||
cls,
|
||||
*,
|
||||
created_by: int | None,
|
||||
roles: list[str] | None = None,
|
||||
expires_in_seconds: int = 7 * 24 * 3600,
|
||||
extra: dict[str, Any] | None = None,
|
||||
):
|
||||
token = secrets.token_urlsafe(24)
|
||||
now = int(time.time())
|
||||
expires_at = now + expires_in_seconds if expires_in_seconds > 0 else None
|
||||
|
||||
cls.insert_one(
|
||||
{
|
||||
"token": token,
|
||||
"created_by": created_by,
|
||||
"roles": roles or ["user"],
|
||||
"active": True,
|
||||
"expires_at": expires_at,
|
||||
"created_at": now,
|
||||
"extra": extra or {},
|
||||
}
|
||||
)
|
||||
|
||||
result = cls.execute(select(cls).where(cls.token == token))
|
||||
return next(result).scalar()
|
||||
|
||||
@classmethod
|
||||
def get_valid_token(cls, token: str):
|
||||
now = int(time.time())
|
||||
result = cls.execute(select(cls).where(cls.token == token))
|
||||
row = next(result).scalar()
|
||||
|
||||
if not row or not row.active:
|
||||
return None
|
||||
|
||||
if row.expires_at is not None and row.expires_at < now:
|
||||
cls.consume_token(token, used_by=None, deactivate_only=True)
|
||||
return None
|
||||
|
||||
return row
|
||||
|
||||
@classmethod
|
||||
def consume_token(
|
||||
cls, token: str, used_by: int | None, deactivate_only: bool = False
|
||||
):
|
||||
values: dict[str, Any] = {"active": False, "used_at": int(time.time())}
|
||||
if not deactivate_only:
|
||||
values["used_by"] = used_by
|
||||
|
||||
next(
|
||||
cls.execute(
|
||||
update(cls).where(cls.token == token).values(values), commit=True
|
||||
)
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
This module contains the functions to interact with the SQLite database.
|
||||
"""
|
||||
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Reads and saves the latest database migrations version.
|
||||
"""
|
||||
|
||||
from swingmusic.db.sqlite.utils import SQLiteManager
|
||||
|
||||
|
||||
class MigrationManager:
|
||||
@staticmethod
|
||||
def get_index() -> int:
|
||||
"""
|
||||
Returns the latest databases migrations index.
|
||||
"""
|
||||
sql = "SELECT * FROM dbmigrations"
|
||||
with SQLiteManager() as cur:
|
||||
cur.execute(sql)
|
||||
ver = int(cur.fetchone()[1])
|
||||
cur.close()
|
||||
|
||||
return ver
|
||||
|
||||
# 👇 Setters 👇
|
||||
@staticmethod
|
||||
def set_index(version: int):
|
||||
"""
|
||||
Updates the databases migrations index.
|
||||
"""
|
||||
sql = "UPDATE dbmigrations SET version = ? WHERE id = 1"
|
||||
with SQLiteManager() as cur:
|
||||
cur.execute(sql, (version,))
|
||||
cur.close()
|
||||
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
Helper functions for use with the SQLite database.
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import time
|
||||
from sqlite3 import Connection, Cursor
|
||||
|
||||
from swingmusic import settings
|
||||
from swingmusic.models import Album, Playlist, Track
|
||||
|
||||
|
||||
def tuple_to_track(track: tuple):
|
||||
"""
|
||||
Takes a tuple and returns a Track object
|
||||
"""
|
||||
return Track(*track[1:]) # rowid is removed from the tuple
|
||||
|
||||
|
||||
def tuples_to_tracks(tracks: list[tuple]):
|
||||
"""
|
||||
Takes a list of tuples and returns a generator that yields a Track object for each tuple
|
||||
"""
|
||||
for track in tracks:
|
||||
yield tuple_to_track(track)
|
||||
|
||||
|
||||
def tuple_to_album(album: tuple):
|
||||
"""
|
||||
Takes a tuple and returns an Album object
|
||||
"""
|
||||
return Album(*album[1:]) # rowid is removed from the tuple
|
||||
|
||||
|
||||
def tuples_to_albums(albums: list[tuple]):
|
||||
"""
|
||||
Takes a list of tuples and returns a generator that yields an album object for each tuple
|
||||
"""
|
||||
for album in albums:
|
||||
yield tuple_to_album(album)
|
||||
|
||||
|
||||
def tuple_to_playlist(playlist: tuple):
|
||||
"""
|
||||
Takes a tuple and returns a Playlist object
|
||||
"""
|
||||
return Playlist(*playlist)
|
||||
|
||||
|
||||
def tuples_to_playlists(playlists: list[tuple]):
|
||||
"""
|
||||
Takes a list of tuples and returns a list of Playlist objects
|
||||
"""
|
||||
for playlist in playlists:
|
||||
yield tuple_to_playlist(playlist)
|
||||
|
||||
|
||||
class SQLiteManager:
|
||||
"""
|
||||
This is a context manager that handles the connection and cursor
|
||||
for you. It also commits and closes the connection when you're done.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conn: Connection | None = None,
|
||||
userdata_db=False,
|
||||
test_db_path: str = None,
|
||||
) -> None:
|
||||
"""
|
||||
When a connection is passed in, don't close the connection, because it's
|
||||
a connection to the search database [in memory db].
|
||||
"""
|
||||
self.conn = conn
|
||||
self.CLOSE_CONN = True
|
||||
self.userdata_db = userdata_db
|
||||
self.test_db_path = test_db_path
|
||||
|
||||
if conn:
|
||||
self.conn = conn
|
||||
self.CLOSE_CONN = False
|
||||
|
||||
def __enter__(self) -> Cursor:
|
||||
if self.conn is not None:
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("PRAGMA foreign_keys = ON")
|
||||
return cur
|
||||
|
||||
db_path = self.test_db_path or settings.Paths().app_db_path
|
||||
|
||||
if self.userdata_db:
|
||||
db_path = settings.Paths().userdata_db_path
|
||||
|
||||
self.conn = sqlite3.connect(
|
||||
db_path,
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("PRAGMA foreign_keys = ON")
|
||||
return cur
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
trial_count = 0
|
||||
|
||||
while trial_count < 10:
|
||||
try:
|
||||
self.conn.commit()
|
||||
|
||||
if self.CLOSE_CONN:
|
||||
self.conn.close()
|
||||
|
||||
return
|
||||
except sqlite3.OperationalError:
|
||||
trial_count += 1
|
||||
time.sleep(3)
|
||||
|
||||
self.conn.close()
|
||||
@@ -0,0 +1,835 @@
|
||||
import datetime
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Literal
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Boolean,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
and_,
|
||||
delete,
|
||||
func,
|
||||
insert,
|
||||
select,
|
||||
update,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from swingmusic.db import Base
|
||||
from swingmusic.db.engine import DbEngine
|
||||
from swingmusic.db.utils import (
|
||||
favorite_to_dataclass,
|
||||
favorites_to_dataclass,
|
||||
playlist_to_dataclass,
|
||||
plugin_to_dataclass,
|
||||
similar_artist_to_dataclass,
|
||||
tracklog_to_dataclass,
|
||||
user_to_dataclass,
|
||||
)
|
||||
from swingmusic.models.mix import Mix
|
||||
from swingmusic.utils.auth import get_current_userid, hash_password
|
||||
|
||||
|
||||
class UserTable(Base):
|
||||
__tablename__ = "user"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
image: Mapped[str] = mapped_column(String(), nullable=True)
|
||||
password: Mapped[str] = mapped_column(String())
|
||||
username: Mapped[str] = mapped_column(String(), index=True)
|
||||
roles: Mapped[list[str]] = mapped_column(JSON(), default_factory=lambda: [])
|
||||
password_change_required: Mapped[bool] = mapped_column(Boolean(), default=False)
|
||||
extra: Mapped[dict[str, Any]] = mapped_column(
|
||||
JSON(), nullable=True, default_factory=dict
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_all(cls):
|
||||
result = cls.execute(select(cls))
|
||||
|
||||
for i in next(result).scalars():
|
||||
yield user_to_dataclass(i)
|
||||
|
||||
@classmethod
|
||||
def insert_default_user(cls):
|
||||
user = {
|
||||
"username": "admin",
|
||||
"password": hash_password("admin"),
|
||||
"roles": ["admin"],
|
||||
"password_change_required": True,
|
||||
}
|
||||
|
||||
return cls.insert_one(user)
|
||||
|
||||
@classmethod
|
||||
def insert_guest_user(cls):
|
||||
user = {
|
||||
"username": "guest",
|
||||
"password": hash_password("guest"),
|
||||
"roles": ["guest"],
|
||||
"password_change_required": True,
|
||||
}
|
||||
|
||||
return cls.insert_one(user)
|
||||
|
||||
@classmethod
|
||||
def get_by_id(cls, id: int):
|
||||
result = cls.execute(select(cls).where(cls.id == id))
|
||||
res = next(result).scalar()
|
||||
|
||||
if res:
|
||||
return user_to_dataclass(res)
|
||||
|
||||
@classmethod
|
||||
def get_by_username(cls, username: str):
|
||||
res = cls.execute(select(cls).where(cls.username == username))
|
||||
res = next(res).scalar()
|
||||
|
||||
if res:
|
||||
return user_to_dataclass(res)
|
||||
|
||||
@classmethod
|
||||
def update_one(cls, user: dict[str, Any]):
|
||||
return next(
|
||||
cls.execute(
|
||||
update(cls).where(cls.id == user["id"]).values(user), commit=True
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def remove_by_username(cls, username: str):
|
||||
return next(
|
||||
cls.execute(delete(cls).where(cls.username == username), commit=True)
|
||||
)
|
||||
|
||||
|
||||
class PluginTable(Base):
|
||||
__tablename__ = "plugin"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(), unique=True)
|
||||
active: Mapped[bool] = mapped_column(Boolean())
|
||||
settings: Mapped[dict[str, Any]] = mapped_column(JSON())
|
||||
extra: Mapped[dict[str, Any]] = mapped_column(JSON(), nullable=True)
|
||||
|
||||
@classmethod
|
||||
def get_all(cls):
|
||||
result = cls.execute(select(cls))
|
||||
|
||||
for i in next(result).scalars():
|
||||
yield plugin_to_dataclass(i)
|
||||
|
||||
@classmethod
|
||||
def activate(cls, name: str, value: bool):
|
||||
return next(
|
||||
cls.execute(
|
||||
update(cls).where(cls.name == name).values(active=value), commit=True
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_by_name(cls, name: str):
|
||||
result = cls.execute(select(cls).where(cls.name == name))
|
||||
res = next(result).scalar()
|
||||
|
||||
if res:
|
||||
return plugin_to_dataclass(res)
|
||||
|
||||
@classmethod
|
||||
def update_settings(cls, name: str, settings: dict[str, Any]):
|
||||
return next(
|
||||
cls.execute(
|
||||
update(cls).where(cls.name == name).values(settings=settings),
|
||||
commit=True,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class SimilarArtistTable(Base):
|
||||
__tablename__ = "notlastfm_similar_artists"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer(), primary_key=True)
|
||||
artisthash: Mapped[str] = mapped_column(String(), index=True)
|
||||
similar_artists: Mapped[dict[str, str]] = mapped_column(JSON())
|
||||
|
||||
@classmethod
|
||||
def get_all(cls):
|
||||
result = cls.execute(select(cls).execution_options(yield_per=100))
|
||||
|
||||
for i in next(result).scalars():
|
||||
yield similar_artist_to_dataclass(i)
|
||||
|
||||
@classmethod
|
||||
def exists(cls, artisthash: str):
|
||||
"""
|
||||
Check whether an artisthash exists in the database.
|
||||
"""
|
||||
|
||||
with DbEngine.manager() as conn:
|
||||
result = conn.execute(
|
||||
select(cls.artisthash)
|
||||
.where(cls.artisthash == artisthash)
|
||||
.execution_options(yield_per=100)
|
||||
)
|
||||
|
||||
return len(result.scalars().all()) > 0
|
||||
|
||||
@classmethod
|
||||
def get_by_hash(cls, artisthash: str):
|
||||
"""
|
||||
Get a single artist by hash.
|
||||
"""
|
||||
result = cls.execute(select(cls).where(cls.artisthash == artisthash))
|
||||
res = next(result).scalar()
|
||||
|
||||
if res:
|
||||
return similar_artist_to_dataclass(res)
|
||||
|
||||
|
||||
class FavoritesTable(Base):
|
||||
__tablename__ = "favorite"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
hash: Mapped[str] = mapped_column(String(), unique=True)
|
||||
type: Mapped[str] = mapped_column(String(), index=True)
|
||||
timestamp: Mapped[int] = mapped_column(Integer(), index=True)
|
||||
userid: Mapped[int] = mapped_column(
|
||||
Integer(), ForeignKey("user.id", ondelete="cascade"), default=1, index=True
|
||||
)
|
||||
extra: Mapped[dict[str, Any]] = mapped_column(
|
||||
JSON(), nullable=True, default_factory=dict
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _normalize_item_hash(cls, raw_hash: str, item_type: str) -> str:
|
||||
"""
|
||||
Normalize legacy and scoped favorite hash formats to plain item hash.
|
||||
Accepted formats:
|
||||
- <hash>
|
||||
- <type>_<hash>
|
||||
- u<userid>:<type>_<hash>
|
||||
"""
|
||||
normalized = str(raw_hash or "").strip()
|
||||
item_type = str(item_type or "").strip()
|
||||
|
||||
if normalized.startswith("u") and ":" in normalized:
|
||||
user_prefix, remainder = normalized.split(":", 1)
|
||||
if user_prefix[1:].isdigit():
|
||||
normalized = remainder
|
||||
|
||||
type_prefix = f"{item_type}_"
|
||||
if item_type and normalized.startswith(type_prefix):
|
||||
normalized = normalized[len(type_prefix) :]
|
||||
|
||||
return normalized
|
||||
|
||||
@classmethod
|
||||
def _hash_candidates(
|
||||
cls,
|
||||
*,
|
||||
hash_value: str,
|
||||
item_type: str,
|
||||
userid: int | None = None,
|
||||
) -> set[str]:
|
||||
canonical = cls._normalize_item_hash(hash_value, item_type)
|
||||
candidates = {canonical}
|
||||
|
||||
if item_type:
|
||||
candidates.add(f"{item_type}_{canonical}")
|
||||
if userid is not None:
|
||||
candidates.add(f"u{int(userid)}:{item_type}_{canonical}")
|
||||
|
||||
return {candidate for candidate in candidates if candidate}
|
||||
|
||||
@classmethod
|
||||
def get_all(cls, with_user: bool = True):
|
||||
with DbEngine.manager() as conn:
|
||||
if with_user:
|
||||
result = conn.execute(
|
||||
select(cls).where(cls.userid == get_current_userid())
|
||||
)
|
||||
else:
|
||||
result = conn.execute(select(cls))
|
||||
|
||||
for i in result.scalars():
|
||||
yield favorite_to_dataclass(i)
|
||||
|
||||
@classmethod
|
||||
def insert_item(cls, item: dict[str, Any]):
|
||||
item_type = str(item.get("type") or "").strip()
|
||||
canonical_hash = cls._normalize_item_hash(item.get("hash", ""), item_type)
|
||||
userid = int(item.get("userid") or get_current_userid())
|
||||
|
||||
if cls.check_exists(canonical_hash, item_type, userid=userid):
|
||||
return None
|
||||
|
||||
# Scope favorites per user while keeping backward compatibility
|
||||
# with legacy `type_hash` entries.
|
||||
item["hash"] = f"u{userid}:{item_type}_{canonical_hash}"
|
||||
|
||||
if item.get("timestamp") is None:
|
||||
item["timestamp"] = int(datetime.datetime.now().timestamp())
|
||||
|
||||
item["userid"] = userid
|
||||
|
||||
return next(cls.execute(insert(cls).values(item), commit=True))
|
||||
|
||||
@classmethod
|
||||
def remove_item(cls, item: dict[str, Any]):
|
||||
userid = int(item.get("userid") or get_current_userid())
|
||||
candidates = cls._hash_candidates(
|
||||
hash_value=str(item.get("hash") or ""),
|
||||
item_type=str(item.get("type") or ""),
|
||||
userid=userid,
|
||||
)
|
||||
return next(
|
||||
cls.execute(
|
||||
delete(cls).where(and_(cls.userid == userid, cls.hash.in_(candidates))),
|
||||
commit=True,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def check_exists(cls, hash: str, type: str, userid: int | None = None):
|
||||
userid = int(userid or get_current_userid())
|
||||
candidates = cls._hash_candidates(
|
||||
hash_value=hash,
|
||||
item_type=type,
|
||||
userid=userid,
|
||||
)
|
||||
result = cls.execute(
|
||||
select(cls).where(and_(cls.userid == userid, cls.hash.in_(candidates)))
|
||||
)
|
||||
|
||||
return next(result).scalar() is not None
|
||||
|
||||
@classmethod
|
||||
def get_by_hash(cls, hash: str, type: str, userid: int | None = None):
|
||||
userid = int(userid or get_current_userid())
|
||||
candidates = cls._hash_candidates(
|
||||
hash_value=hash,
|
||||
item_type=type,
|
||||
userid=userid,
|
||||
)
|
||||
result = cls.execute(
|
||||
select(cls).where(and_(cls.userid == userid, cls.hash.in_(candidates)))
|
||||
)
|
||||
|
||||
return next(result).scalars().all()
|
||||
|
||||
@classmethod
|
||||
def get_all_of_type(cls, type: str, start: int, limit: int):
|
||||
result = cls.execute(
|
||||
select(cls)
|
||||
# .select_from(join(table, cls, field == cls.hash))
|
||||
.where(and_(cls.type == type, cls.userid == get_current_userid()))
|
||||
.order_by(cls.timestamp.desc())
|
||||
.offset(start)
|
||||
# INFO: If start is 0, fetch all so we can get the total count
|
||||
.limit(limit if start != 0 else None)
|
||||
)
|
||||
|
||||
res = next(result).scalars().all()
|
||||
|
||||
if start == 0:
|
||||
# if limit == -1, return all
|
||||
if limit == -1:
|
||||
limit = len(res)
|
||||
|
||||
return res[:limit], len(res)
|
||||
|
||||
return res, -1
|
||||
|
||||
@classmethod
|
||||
def get_fav_tracks(cls, start: int, limit: int):
|
||||
result, total = cls.get_all_of_type("track", start, limit)
|
||||
return favorites_to_dataclass(result), total
|
||||
|
||||
@classmethod
|
||||
def get_fav_albums(cls, start: int, limit: int):
|
||||
result, total = cls.get_all_of_type("album", start, limit)
|
||||
return favorites_to_dataclass(result), total
|
||||
|
||||
@classmethod
|
||||
def get_fav_artists(cls, start: int, limit: int):
|
||||
result, total = cls.get_all_of_type("artist", start, limit)
|
||||
return favorites_to_dataclass(result), total
|
||||
|
||||
@classmethod
|
||||
def count_favs_in_period(cls, start_time: int, end_time: int):
|
||||
result = cls.execute(
|
||||
select(func.count(cls.id))
|
||||
.where(cls.userid == get_current_userid())
|
||||
.where(and_(cls.timestamp >= start_time, cls.timestamp <= end_time))
|
||||
)
|
||||
|
||||
res = next(result).scalar()
|
||||
|
||||
if res:
|
||||
return res
|
||||
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def count_tracks(cls):
|
||||
result = cls.execute(
|
||||
select(func.count(cls.id)).where(
|
||||
and_(cls.type == "track", cls.userid == get_current_userid())
|
||||
)
|
||||
)
|
||||
|
||||
return next(result).scalar()
|
||||
|
||||
@classmethod
|
||||
def get_last_trackhash(cls):
|
||||
result = cls.execute(
|
||||
select(cls.hash)
|
||||
.where(and_(cls.type == "track", cls.userid == get_current_userid()))
|
||||
.order_by(cls.timestamp.desc())
|
||||
)
|
||||
|
||||
db_hash = next(result).scalar()
|
||||
if not db_hash:
|
||||
return None
|
||||
|
||||
return cls._normalize_item_hash(db_hash, "track")
|
||||
|
||||
|
||||
class ScrobbleTable(Base):
|
||||
__tablename__ = "scrobble"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
trackhash: Mapped[str] = mapped_column(String(), index=True)
|
||||
duration: Mapped[int] = mapped_column(Integer())
|
||||
timestamp: Mapped[int] = mapped_column(Integer())
|
||||
source: Mapped[str] = mapped_column(String())
|
||||
userid: Mapped[int] = mapped_column(
|
||||
Integer(), ForeignKey("user.id", ondelete="cascade"), index=True
|
||||
)
|
||||
extra: Mapped[dict[str, Any]] = mapped_column(
|
||||
JSON(), nullable=True, default_factory=dict
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def add(cls, item: dict[str, Any]):
|
||||
if item.get("userid") is None:
|
||||
item["userid"] = get_current_userid()
|
||||
|
||||
return cls.insert_one(item)
|
||||
|
||||
@classmethod
|
||||
def get_all(cls, start: int, limit: int | None = None, userid: int | None = None):
|
||||
result = cls.execute(
|
||||
select(cls)
|
||||
.where(cls.userid == (userid if userid else get_current_userid()))
|
||||
.order_by(cls.timestamp.desc())
|
||||
.offset(start)
|
||||
.limit(limit)
|
||||
.execution_options(yield_per=100)
|
||||
)
|
||||
|
||||
for i in next(result).scalars():
|
||||
yield tracklog_to_dataclass(i)
|
||||
|
||||
@classmethod
|
||||
def get_all_in_period(cls, start_time: int, end_time: int, userid: int | None):
|
||||
# UserId will be None if function is called from the API
|
||||
# In that case, we use the request userid
|
||||
if userid is None:
|
||||
userid = get_current_userid()
|
||||
|
||||
result = cls.execute(
|
||||
select(cls)
|
||||
.where(cls.userid == userid)
|
||||
.where(and_(cls.timestamp >= start_time, cls.timestamp <= end_time))
|
||||
.order_by(cls.timestamp.desc())
|
||||
.execution_options(yield_per=100)
|
||||
)
|
||||
|
||||
for i in next(result).scalars():
|
||||
yield tracklog_to_dataclass(i)
|
||||
|
||||
@classmethod
|
||||
def get_last_entry(cls, userid: int):
|
||||
result = cls.execute(
|
||||
select(cls).where(cls.userid == userid).order_by(cls.timestamp.desc())
|
||||
)
|
||||
res = next(result).scalar()
|
||||
|
||||
if res:
|
||||
return tracklog_to_dataclass(res)
|
||||
|
||||
|
||||
class PlaylistTable(Base):
|
||||
__tablename__ = "playlist"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(), index=True)
|
||||
last_updated: Mapped[int] = mapped_column(Integer())
|
||||
image: Mapped[str] = mapped_column(String(), nullable=True)
|
||||
userid: Mapped[int] = mapped_column(
|
||||
Integer(), ForeignKey("user.id", ondelete="cascade")
|
||||
)
|
||||
settings: Mapped[dict[str, Any]] = mapped_column(JSON())
|
||||
trackhashes: Mapped[list[str]] = mapped_column(JSON(), default_factory=list)
|
||||
extra: Mapped[dict[str, Any]] = mapped_column(
|
||||
JSON(), nullable=True, default_factory=dict
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_all(cls, current_user: bool = True):
|
||||
if current_user:
|
||||
result = cls.execute(
|
||||
select(cls)
|
||||
.where(cls.userid == get_current_userid())
|
||||
.execution_options(yield_per=100)
|
||||
)
|
||||
else:
|
||||
result = cls.execute(select(cls).execution_options(yield_per=100))
|
||||
|
||||
for i in next(result).scalars():
|
||||
yield playlist_to_dataclass(i)
|
||||
|
||||
@classmethod
|
||||
def add_one(cls, playlist: dict[str, Any]):
|
||||
playlist["userid"] = get_current_userid()
|
||||
result = cls.insert_one(playlist)
|
||||
|
||||
return result.lastrowid
|
||||
|
||||
@classmethod
|
||||
def check_exists_by_name(cls, name: str):
|
||||
result = cls.execute(
|
||||
select(cls).where((cls.name == name) & (cls.userid == get_current_userid()))
|
||||
)
|
||||
return next(result).scalar() is not None
|
||||
|
||||
@classmethod
|
||||
def append_to_playlist(cls, id: int, trackhashes: list[str]):
|
||||
dbtrackhashes = cls.get_trackhashes(id) or []
|
||||
trackhashes = list(set(dbtrackhashes).union(set(trackhashes)))
|
||||
|
||||
return next(
|
||||
cls.execute(
|
||||
update(cls)
|
||||
.where((cls.id == id) & (cls.userid == get_current_userid()))
|
||||
.values(trackhashes=trackhashes),
|
||||
commit=True,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_trackhashes(cls, id: int):
|
||||
result = cls.execute(
|
||||
select(cls.trackhashes).where(
|
||||
(cls.id == id) & (cls.userid == get_current_userid())
|
||||
)
|
||||
)
|
||||
return next(result).scalar()
|
||||
|
||||
@classmethod
|
||||
def remove_from_playlist(cls, id: int, trackhashes: list[dict[str, Any]]):
|
||||
# INFO: Get db trackhashes
|
||||
dbtrackhashes = cls.get_trackhashes(id)
|
||||
if dbtrackhashes:
|
||||
for item in trackhashes:
|
||||
if dbtrackhashes.index(item["trackhash"]) == item["index"]:
|
||||
dbtrackhashes.remove(item["trackhash"])
|
||||
|
||||
return next(
|
||||
cls.execute(
|
||||
update(cls)
|
||||
.where((cls.id == id) & (cls.userid == get_current_userid()))
|
||||
.values(trackhashes=dbtrackhashes),
|
||||
commit=True,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_by_id(cls, id: int):
|
||||
result = cls.execute(
|
||||
select(cls).where((cls.id == id) & (cls.userid == get_current_userid()))
|
||||
)
|
||||
result = next(result).scalar()
|
||||
|
||||
if result:
|
||||
return playlist_to_dataclass(result)
|
||||
|
||||
@classmethod
|
||||
def update_one(cls, id: int, playlist: dict[str, Any]):
|
||||
return next(
|
||||
cls.execute(
|
||||
update(cls)
|
||||
.where((cls.id == id) & (cls.userid == get_current_userid()))
|
||||
.values(playlist),
|
||||
commit=True,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def update_settings(cls, id: int, settings: dict[str, Any]):
|
||||
return next(
|
||||
cls.execute(
|
||||
update(cls)
|
||||
.where((cls.id == id) & (cls.userid == get_current_userid()))
|
||||
.values(settings=settings),
|
||||
commit=True,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def remove_image(cls, id: int):
|
||||
return next(
|
||||
cls.execute(
|
||||
update(cls)
|
||||
.where((cls.id == id) & (cls.userid == get_current_userid()))
|
||||
.values(image=None),
|
||||
commit=True,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class LibDataTable(Base):
|
||||
__tablename__ = "artistdata"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
itemhash: Mapped[str] = mapped_column(String(), unique=True, index=True)
|
||||
itemtype: Mapped[str] = mapped_column(String())
|
||||
color: Mapped[str] = mapped_column(String(), nullable=True)
|
||||
bio: Mapped[str] = mapped_column(String(), nullable=True)
|
||||
info: Mapped[dict[str, Any]] = mapped_column(JSON(), nullable=True)
|
||||
extra: Mapped[dict[str, Any]] = mapped_column(
|
||||
JSON(), nullable=True, default_factory=dict
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def update_one(cls, hash: str, data: dict[str, Any]):
|
||||
return next(
|
||||
cls.execute(
|
||||
update(cls).where(cls.itemhash == hash).values(data), commit=True
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def find_one(cls, hash: str, type: Literal["album", "artist"]):
|
||||
result = cls.execute(
|
||||
select(cls).where((cls.itemhash == type + hash) & (cls.itemtype == type))
|
||||
)
|
||||
return next(result).scalar()
|
||||
|
||||
@classmethod
|
||||
def get_all_colors(cls, type: str) -> Iterable[dict[str, str]]:
|
||||
result = cls.execute(select(cls).where(cls.itemtype == type))
|
||||
|
||||
for i in next(result).scalars():
|
||||
yield {"itemhash": i.itemhash.replace(type, ""), "color": i.color}
|
||||
|
||||
|
||||
class MixTable(Base):
|
||||
__tablename__ = "mix"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
mixid: Mapped[str] = mapped_column(String(), index=True)
|
||||
title: Mapped[str] = mapped_column(String())
|
||||
description: Mapped[str] = mapped_column(String())
|
||||
timestamp: Mapped[int] = mapped_column(Integer())
|
||||
sourcehash: Mapped[str] = mapped_column(String(), unique=True, index=True)
|
||||
userid: Mapped[int] = mapped_column(
|
||||
Integer(), ForeignKey("user.id", ondelete="cascade"), index=True
|
||||
)
|
||||
saved: Mapped[bool] = mapped_column(Boolean(), default=False)
|
||||
tracks: Mapped[list[str]] = mapped_column(JSON(), default_factory=list)
|
||||
extra: Mapped[dict[str, Any]] = mapped_column(
|
||||
JSON(), nullable=True, default_factory=dict
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_all(cls, with_userid: bool = False):
|
||||
if with_userid:
|
||||
result = cls.execute(
|
||||
select(cls)
|
||||
.where(cls.userid == get_current_userid())
|
||||
.order_by(cls.timestamp.desc())
|
||||
)
|
||||
else:
|
||||
result = cls.execute(select(cls).order_by(cls.timestamp.desc()))
|
||||
|
||||
for i in next(result).scalars():
|
||||
yield Mix.mix_to_dataclass(i)
|
||||
|
||||
@classmethod
|
||||
def get_by_sourcehash(cls, sourcehash: str):
|
||||
result = cls.execute(
|
||||
select(cls).where(
|
||||
and_(cls.sourcehash == sourcehash, cls.userid == get_current_userid())
|
||||
)
|
||||
)
|
||||
|
||||
res = next(result).scalar()
|
||||
|
||||
if res:
|
||||
return Mix.mix_to_dataclass(res)
|
||||
|
||||
@classmethod
|
||||
def get_by_mixid(cls, mixid: str):
|
||||
result = cls.execute(
|
||||
select(cls).where(
|
||||
and_(cls.mixid == mixid, cls.userid == get_current_userid())
|
||||
)
|
||||
)
|
||||
res = next(result).scalar()
|
||||
|
||||
if res:
|
||||
return Mix.mix_to_dataclass(res)
|
||||
|
||||
@classmethod
|
||||
def insert_one(cls, mix: Mix):
|
||||
mixdict = asdict(mix)
|
||||
mixdict["mixid"] = mix.id
|
||||
del mixdict["id"]
|
||||
|
||||
return next(cls.execute(insert(cls).values(mixdict), commit=True))
|
||||
|
||||
@classmethod
|
||||
def update_one(cls, mixid: str, mix: Mix):
|
||||
mixdict = asdict(mix)
|
||||
mixdict["mixid"] = mix.id
|
||||
del mixdict["id"]
|
||||
|
||||
return next(
|
||||
cls.execute(
|
||||
update(cls)
|
||||
.where(
|
||||
and_(
|
||||
cls.mixid == mixid,
|
||||
cls.sourcehash == mix.sourcehash,
|
||||
cls.userid == get_current_userid(),
|
||||
)
|
||||
)
|
||||
.values(mixdict),
|
||||
commit=True,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def save_artist_mix(cls, sourcehash: str):
|
||||
"""
|
||||
Toggles the saved status of an artist mix.
|
||||
"""
|
||||
|
||||
mix = cls.get_by_sourcehash(sourcehash)
|
||||
|
||||
if not mix:
|
||||
return False
|
||||
|
||||
mix.saved = not mix.saved
|
||||
cls.update_one(mix.id, mix)
|
||||
|
||||
return mix.saved
|
||||
|
||||
@classmethod
|
||||
def get_saved_track_mixes(cls):
|
||||
"""
|
||||
Return all mixes that have the extra.trackmix_saved set to True.
|
||||
"""
|
||||
|
||||
result = cls.execute(
|
||||
select(cls).where(
|
||||
and_(cls.extra.c.trackmix_saved, cls.userid == get_current_userid())
|
||||
)
|
||||
)
|
||||
# return Mix.mixes_to_dataclasses(result.fetchall())
|
||||
|
||||
for i in next(result).scalars():
|
||||
yield Mix.mix_to_dataclass(i)
|
||||
|
||||
@classmethod
|
||||
def save_track_mix(cls, sourcehash: str):
|
||||
"""
|
||||
Toggles the property extra.trackmix_saved to True.
|
||||
"""
|
||||
|
||||
mix = cls.get_by_sourcehash(sourcehash)
|
||||
if not mix:
|
||||
return False
|
||||
|
||||
mix.extra["trackmix_saved"] = not mix.extra.get("trackmix_saved", False)
|
||||
cls.update_one(mix.id, mix)
|
||||
|
||||
return mix.extra["trackmix_saved"]
|
||||
|
||||
|
||||
class CollectionTable(Base):
|
||||
# INFO: table name was kept as page to avoid breaking existing data
|
||||
__tablename__ = "page"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(), index=True)
|
||||
userid: Mapped[int] = mapped_column(
|
||||
Integer(), ForeignKey("user.id", ondelete="cascade"), index=True
|
||||
)
|
||||
items: Mapped[list[dict[str, Any]]] = mapped_column(JSON(), default_factory=list)
|
||||
extra: Mapped[dict[str, Any]] = mapped_column(
|
||||
JSON(), nullable=True, default_factory=dict
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def to_dict(cls, entry: Any) -> dict[str, Any]:
|
||||
d = entry.__dict__
|
||||
del d["_sa_instance_state"]
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def get_all(cls):
|
||||
result = cls.execute(select(cls).where(cls.userid == get_current_userid()))
|
||||
|
||||
for i in next(result).scalars():
|
||||
yield cls.to_dict(i)
|
||||
|
||||
@classmethod
|
||||
def get_by_id(cls, id: int):
|
||||
result = cls.execute(
|
||||
select(cls).where(and_(cls.id == id, cls.userid == get_current_userid()))
|
||||
)
|
||||
res = next(result).scalar()
|
||||
|
||||
if res:
|
||||
return cls.to_dict(res)
|
||||
|
||||
@classmethod
|
||||
def delete_by_id(cls, id: int):
|
||||
return next(
|
||||
cls.execute(
|
||||
delete(cls).where(
|
||||
and_(cls.id == id, cls.userid == get_current_userid())
|
||||
),
|
||||
commit=True,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def update_items(cls, id: int, items: list[dict[str, Any]]):
|
||||
return next(
|
||||
cls.execute(
|
||||
update(cls)
|
||||
.where(and_(cls.id == id, cls.userid == get_current_userid()))
|
||||
.values(items=items),
|
||||
commit=True,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def update_one(cls, payload: dict[str, Any]):
|
||||
return next(
|
||||
cls.execute(
|
||||
update(cls)
|
||||
.where(
|
||||
and_(cls.id == payload["id"], cls.userid == get_current_userid())
|
||||
)
|
||||
.values(payload),
|
||||
commit=True,
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,100 @@
|
||||
from typing import Any
|
||||
|
||||
from swingmusic.config import UserConfig
|
||||
from swingmusic.models import Album as AlbumModel
|
||||
from swingmusic.models import Artist as ArtistModel
|
||||
from swingmusic.models import Track as TrackModel
|
||||
from swingmusic.models.favorite import Favorite
|
||||
from swingmusic.models.lastfm import SimilarArtist
|
||||
from swingmusic.models.logger import TrackLog
|
||||
from swingmusic.models.playlist import Playlist
|
||||
from swingmusic.models.plugins import Plugin
|
||||
from swingmusic.models.user import User
|
||||
|
||||
|
||||
def row_to_dict(row: Any):
|
||||
d = row.__dict__
|
||||
del d["_sa_instance_state"]
|
||||
return d
|
||||
|
||||
|
||||
def track_to_dataclass(track: dict, config: UserConfig):
|
||||
return TrackModel(**track, config=config)
|
||||
|
||||
|
||||
def tracks_to_dataclasses(tracks: Any):
|
||||
return [track_to_dataclass(track, UserConfig()) for track in tracks]
|
||||
|
||||
|
||||
def album_to_dataclass(album: Any):
|
||||
return AlbumModel(**album._asdict())
|
||||
|
||||
|
||||
def albums_to_dataclasses(albums: Any):
|
||||
return [album_to_dataclass(album) for album in albums]
|
||||
|
||||
|
||||
def artist_to_dataclass(artist: Any):
|
||||
return ArtistModel(**artist._asdict())
|
||||
|
||||
|
||||
def artists_to_dataclasses(artists: Any):
|
||||
return [artist_to_dataclass(artist) for artist in artists]
|
||||
|
||||
|
||||
# SECTION: User data helpers
|
||||
def similar_artist_to_dataclass(entry: Any):
|
||||
entry_dict = row_to_dict(entry)
|
||||
del entry_dict["id"]
|
||||
|
||||
return SimilarArtist(**entry_dict)
|
||||
|
||||
|
||||
def similar_artists_to_dataclass(entries: Any):
|
||||
return [similar_artist_to_dataclass(entry) for entry in entries]
|
||||
|
||||
|
||||
def favorite_to_dataclass(entry: Any):
|
||||
entry_dict = row_to_dict(entry)
|
||||
del entry_dict["id"]
|
||||
|
||||
return Favorite(**entry_dict)
|
||||
|
||||
|
||||
def favorites_to_dataclass(entries: Any):
|
||||
return [favorite_to_dataclass(entry) for entry in entries]
|
||||
|
||||
|
||||
def user_to_dataclass(entry: Any):
|
||||
return User(**row_to_dict(entry))
|
||||
|
||||
|
||||
# def user_to_dataclasses(entries: Any):
|
||||
# return [user_to_dataclass(entry) for entry in entries]
|
||||
|
||||
|
||||
def plugin_to_dataclass(entry: Any):
|
||||
entry_dict = row_to_dict(entry)
|
||||
del entry_dict["id"]
|
||||
return Plugin(**entry_dict)
|
||||
|
||||
|
||||
def plugin_to_dataclasses(entries: Any):
|
||||
return [plugin_to_dataclass(entry) for entry in entries]
|
||||
|
||||
|
||||
def tracklog_to_dataclass(entry: Any):
|
||||
return TrackLog(**row_to_dict(entry))
|
||||
|
||||
|
||||
def tracklog_to_dataclasses(entries: Any):
|
||||
return [tracklog_to_dataclass(entry) for entry in entries]
|
||||
|
||||
|
||||
def playlist_to_dataclass(entry: Any):
|
||||
entry_dict = row_to_dict(entry)
|
||||
return Playlist(**entry_dict)
|
||||
|
||||
|
||||
def playlists_to_dataclasses(entries: Any):
|
||||
return [playlist_to_dataclass(entry) for entry in entries]
|
||||
Reference in New Issue
Block a user