first commit

This commit is contained in:
Tomas Dvorak
2026-04-13 17:46:58 +02:00
commit 6e8fedf534
234 changed files with 53808 additions and 0 deletions
+66
View File
@@ -0,0 +1,66 @@
from typing import Any
from sqlalchemy import (
delete,
func,
insert,
select,
)
from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass
from swingmusic.db.engine import DbEngine
class Base(MappedAsDataclass, DeclarativeBase):
"""
Base class for all database models.
It has methods common to all tables. eg. `insert_one`, `insert_many`, `remove_all`, `remove_one`, `all`, `count`.
"""
@classmethod
def execute(cls, stmt: Any, commit: bool = False):
with DbEngine.manager(commit=commit) as session:
result = session.execute(stmt.execution_options(yield_per=100))
if commit:
session.commit()
yield result
@classmethod
def insert_many(cls, items: list[dict[str, Any]]):
"""
Inserts multiple items into the database.
"""
return next(cls.execute(insert(cls).values(items), commit=True))
@classmethod
def insert_one(cls, item: dict[str, Any]):
"""
Inserts a single item into the database.
"""
return cls.insert_many([item])
@classmethod
def remove_all(cls):
return next(cls.execute(delete(cls), commit=True))
@classmethod
def remove_one(cls, id: int):
return next(cls.execute(delete(cls).where(cls.id == id), commit=True))
@classmethod
def all(cls):
return next(cls.execute(select(cls).execution_options(yield_per=100)))
@classmethod
def count(cls):
return next(cls.execute(select(func.count()).select_from(cls))).scalar()
def create_all_tables():
"""
Creates all the tables that build on the Base class.
"""
Base().metadata.create_all(DbEngine.engine)
+385
View File
@@ -0,0 +1,385 @@
"""
Native DragonflyDB Client for SwingMusic
Integrated as a native database service like SQLite, providing:
- Ultra-fast caching for all services
- Session management
- User preferences
- Temporary data storage
- Real-time features
"""
import json
import logging
import os
from typing import Any
logger = logging.getLogger(__name__)
class DragonflyDBClient:
"""
Native DragonflyDB client integrated into SwingMusic
Provides Redis-compatible operations with automatic fallback
"""
def __init__(
self,
host: str | None = None,
port: int | None = None,
db: int | None = None,
):
self.host = host or os.environ.get("DRAGONFLYDB_HOST", "localhost")
self.port = port or int(os.environ.get("DRAGONFLYDB_PORT", "6379"))
self.db = db if db is not None else int(os.environ.get("DRAGONFLYDB_DB", "0"))
self.client = None
self.available = False
self._connect()
def _connect(self):
"""Connect to DragonflyDB with fallback handling"""
try:
import redis
self.client = redis.Redis(
host=self.host,
port=self.port,
db=self.db,
decode_responses=True,
socket_connect_timeout=2,
socket_timeout=2,
retry_on_timeout=True,
health_check_interval=30,
)
# Test connection
self.client.ping()
self.available = True
logger.info(f"✅ DragonflyDB connected at {self.host}:{self.port}")
except ImportError:
logger.warning("❌ Redis library not installed, DragonflyDB unavailable")
self.available = False
except Exception as e:
logger.warning(f"❌ DragonflyDB connection failed: {e}")
self.available = False
def is_available(self) -> bool:
"""Check if DragonflyDB is available"""
if not self.available or not self.client:
self._connect()
if not self.available or not self.client:
return False
try:
self.client.ping()
return True
except Exception:
self.available = False
return False
def set(self, key: str, value: Any, ttl: int | None = None) -> bool:
"""Set a key-value pair with optional TTL"""
if not self.is_available():
return False
try:
serialized_value = (
json.dumps(value) if not isinstance(value, str) else value
)
if ttl:
return self.client.setex(key, ttl, serialized_value)
else:
return self.client.set(key, serialized_value)
except Exception as e:
logger.debug(f"DragonflyDB set failed: {e}")
return False
def get(self, key: str) -> Any | None:
"""Get a value by key"""
if not self.is_available():
return None
try:
value = self.client.get(key)
if value is None:
return None
# Try to deserialize as JSON
try:
return json.loads(value)
except json.JSONDecodeError:
return value
except Exception as e:
logger.debug(f"DragonflyDB get failed: {e}")
return None
def delete(self, key: str) -> bool:
"""Delete a key"""
if not self.is_available():
return False
try:
return bool(self.client.delete(key))
except Exception as e:
logger.debug(f"DragonflyDB delete failed: {e}")
return False
def exists(self, key: str) -> bool:
"""Check if key exists"""
if not self.is_available():
return False
try:
return bool(self.client.exists(key))
except Exception as e:
logger.debug(f"DragonflyDB exists failed: {e}")
return False
def expire(self, key: str, ttl: int) -> bool:
"""Set TTL for existing key"""
if not self.is_available():
return False
try:
return bool(self.client.expire(key, ttl))
except Exception as e:
logger.debug(f"DragonflyDB expire failed: {e}")
return False
def ttl(self, key: str) -> int:
"""Get TTL for key"""
if not self.is_available():
return -1
try:
return self.client.ttl(key)
except Exception as e:
logger.debug(f"DragonflyDB ttl failed: {e}")
return -1
def keys(self, pattern: str = "*") -> list[str]:
"""Get keys matching pattern"""
if not self.is_available():
return []
try:
return self.client.keys(pattern)
except Exception as e:
logger.debug(f"DragonflyDB keys failed: {e}")
return []
def incr(self, key: str, amount: int = 1) -> int:
"""Increment value by amount"""
if not self.is_available():
return 0
try:
return self.client.incr(key, amount)
except Exception as e:
logger.debug(f"DragonflyDB incr failed: {e}")
return 0
def lpush(self, key: str, *values) -> int:
"""Push values to left of list"""
if not self.is_available():
return 0
try:
return self.client.lpush(key, *values)
except Exception as e:
logger.debug(f"DragonflyDB lpush failed: {e}")
return 0
def rpop(self, key: str) -> str | None:
"""Pop value from right of list"""
if not self.is_available():
return None
try:
return self.client.rpop(key)
except Exception as e:
logger.debug(f"DragonflyDB rpop failed: {e}")
return None
def lrange(self, key: str, start: int, end: int) -> list[str]:
"""Get range of list elements"""
if not self.is_available():
return []
try:
return self.client.lrange(key, start, end)
except Exception as e:
logger.debug(f"DragonflyDB lrange failed: {e}")
return []
def llen(self, key: str) -> int:
"""Get length of list"""
if not self.is_available():
return 0
try:
return self.client.llen(key)
except Exception as e:
logger.debug(f"DragonflyDB llen failed: {e}")
return 0
def lrem(self, key: str, count: int, value: str) -> int:
"""Remove elements from list"""
if not self.is_available():
return 0
try:
return self.client.lrem(key, count, value)
except Exception as e:
logger.debug(f"DragonflyDB lrem failed: {e}")
return 0
def ltrim(self, key: str, start: int, end: int) -> bool:
"""Trim list to range"""
if not self.is_available():
return False
try:
return self.client.ltrim(key, start, end)
except Exception as e:
logger.debug(f"DragonflyDB ltrim failed: {e}")
return False
def flushdb(self) -> bool:
"""Clear all keys in current database"""
if not self.is_available():
return False
try:
return self.client.flushdb()
except Exception as e:
logger.debug(f"DragonflyDB flushdb failed: {e}")
return False
def ping(self) -> bool:
"""Ping DragonflyDB."""
if not self.is_available():
return False
try:
return bool(self.client.ping())
except Exception as e:
logger.debug(f"DragonflyDB ping failed: {e}")
self.available = False
return False
def info(self) -> dict[str, Any]:
"""Get DragonflyDB server info"""
if not self.is_available():
return {}
try:
info = self.client.info()
return {
"version": info.get("redis_version", "unknown"),
"used_memory": info.get("used_memory", 0),
"used_memory_human": info.get("used_memory_human", "0B"),
"connected_clients": info.get("connected_clients", 0),
"total_commands_processed": info.get("total_commands_processed", 0),
"keyspace_hits": info.get("keyspace_hits", 0),
"keyspace_misses": info.get("keyspace_misses", 0),
"uptime_in_seconds": info.get("uptime_in_seconds", 0),
}
except Exception as e:
logger.debug(f"DragonflyDB info failed: {e}")
return {}
def close(self):
"""Close DragonflyDB connection"""
if self.client:
try:
self.client.close()
logger.info("DragonflyDB connection closed")
except Exception:
pass
# Global DragonflyDB instance (like SQLite)
_dragonfly_client: DragonflyDBClient | None = None
def get_dragonfly_client() -> DragonflyDBClient:
"""Get the global DragonflyDB client instance"""
global _dragonfly_client
if _dragonfly_client is None:
_dragonfly_client = DragonflyDBClient()
return _dragonfly_client
def init_dragonfly_if_available() -> bool:
"""Initialize DragonflyDB if available"""
client = get_dragonfly_client()
return client.is_available()
class DragonflyCache:
"""High-level cache interface using DragonflyDB"""
def __init__(self, prefix: str = "swingmusic"):
self.client = get_dragonfly_client()
self.prefix = prefix
def _make_key(self, key: str) -> str:
"""Create namespaced key"""
return f"{self.prefix}:{key}"
def set(self, key: str, value: Any, ttl_hours: int = 12) -> bool:
"""Set cache value with TTL in hours"""
ttl_seconds = ttl_hours * 3600
return self.client.set(self._make_key(key), value, ttl_seconds)
def get(self, key: str) -> Any | None:
"""Get cache value"""
return self.client.get(self._make_key(key))
def delete(self, key: str) -> bool:
"""Delete cache value"""
return self.client.delete(self._make_key(key))
def exists(self, key: str) -> bool:
"""Check if cache value exists"""
return self.client.exists(self._make_key(key))
def clear_all(self) -> bool:
"""Clear all SwingMusic cache entries"""
if not self.client.is_available():
return False
keys = self.client.keys(f"{self.prefix}:*")
if keys:
return self.client.client.delete(*keys) > 0
return True
# Native cache instances for different purposes
spotify_cache = DragonflyCache("spotify")
session_cache = DragonflyCache("session")
user_cache = DragonflyCache("user")
temp_cache = DragonflyCache("temp")
def get_spotify_cache() -> DragonflyCache:
"""Get Spotify metadata cache"""
return spotify_cache
def get_session_cache() -> DragonflyCache:
"""Get user session cache"""
return session_cache
def get_user_cache() -> DragonflyCache:
"""Get user preferences cache"""
return user_cache
def get_temp_cache() -> DragonflyCache:
"""Get temporary data cache"""
return temp_cache
+434
View File
@@ -0,0 +1,434 @@
"""
Extended DragonflyDB Client for SwingMusic
Comprehensive caching system with 15+ cache services for:
- Track metadata and persistence
- User sessions and preferences
- Mobile offline synchronization
- Real-time features and analytics
- Background job processing
- Search and recommendations
"""
import json
import logging
from typing import Any
from swingmusic.db.dragonfly_client import DragonflyCache, get_dragonfly_client
logger = logging.getLogger(__name__)
class ExtendedDragonflyServices:
"""
Extended DragonflyDB services for complete SwingMusic integration
"""
def __init__(self):
self.client = get_dragonfly_client()
# Core performance caches
self.track_cache = DragonflyCache("tracks")
self.artist_cache = DragonflyCache("artists")
self.album_cache = DragonflyCache("albums")
# User experience caches
self.session_cache = DragonflyCache("sessions")
self.user_cache = DragonflyCache("users")
self.search_cache = DragonflyCache("search")
self.homepage_cache = DragonflyCache("homepage")
# Mobile and offline caches
self.mobile_cache = DragonflyCache("mobile")
self.sync_cache = DragonflyCache("sync")
self.progress_cache = DragonflyCache("progress")
self.playlist_cache = DragonflyCache("playlists")
# Real-time feature caches
self.playcount_cache = DragonflyCache("playcounts")
self.recent_cache = DragonflyCache("recent")
self.favorite_cache = DragonflyCache("favorites")
self.recommendation_cache = DragonflyCache("recommendations")
# Background processing caches
self.job_cache = DragonflyCache("jobs")
self.lyrics_cache = DragonflyCache("lyrics")
self.index_cache = DragonflyCache("index")
self.temp_cache = DragonflyCache("temp")
logger.info("Extended DragonflyDB services initialized")
class TrackCacheService:
"""High-performance track caching with persistence"""
def __init__(self):
self.cache = DragonflyCache("tracks")
def get_track(self, trackhash: str) -> dict[str, Any] | None:
"""Get track data from cache"""
return self.cache.get(f"track:{trackhash}")
def set_track(
self, trackhash: str, track_data: dict[str, Any], ttl_hours: int = 24
):
"""Cache track data"""
return self.cache.set(f"track:{trackhash}", track_data, ttl_hours)
def get_track_batch(self, trackhashes: list[str]) -> dict[str, Any]:
"""Get multiple tracks from cache"""
results = {}
for trackhash in trackhashes:
track = self.get_track(trackhash)
if track:
results[trackhash] = track
return results
def set_track_batch(self, tracks: dict[str, dict[str, Any]], ttl_hours: int = 24):
"""Cache multiple tracks"""
success_count = 0
for trackhash, track_data in tracks.items():
if self.set_track(trackhash, track_data, ttl_hours):
success_count += 1
return success_count
def invalidate_track(self, trackhash: str):
"""Remove track from cache"""
return self.cache.delete(f"track:{trackhash}")
def get_stats(self) -> dict[str, Any]:
"""Get track cache statistics"""
keys = self.cache.client.keys("tracks:track:*")
return {
"total_tracks": len(keys),
"memory_usage": self.cache.client.info().get(
"used_memory_human", "Unknown"
),
}
class UserSessionService:
"""Ultra-fast user session management"""
def __init__(self):
self.cache = DragonflyCache("sessions")
# Backward compatibility for older call sites.
self.session_cache = self.cache
def create_session(
self, session_token: str, user_data: dict[str, Any], ttl_hours: int = 24
):
"""Create user session"""
return self.cache.set(f"session:{session_token}", user_data, ttl_hours)
def set_user_session(
self, userid: int, user_data: dict[str, Any], ttl_seconds: int = 24 * 3600
):
"""Store latest session payload by user id for quick lookups."""
ttl_hours = max(1, int(ttl_seconds // 3600))
return self.cache.set(f"user_session:{userid}", user_data, ttl_hours)
def get_user_session(self, userid: int) -> dict[str, Any] | None:
"""Get latest session payload for a user id."""
return self.cache.get(f"user_session:{userid}")
def get_session(self, session_token: str) -> dict[str, Any] | None:
"""Get user session"""
return self.cache.get(f"session:{session_token}")
def refresh_session(self, session_token: str, ttl_hours: int = 24):
"""Refresh session TTL"""
return self.cache.expire(f"session:{session_token}", ttl_hours * 3600)
def invalidate_session(self, session_token: str):
"""Invalidate user session"""
return self.cache.delete(f"session:{session_token}")
def invalidate_user_session(self, userid: int):
"""Invalidate latest session payload for a user id."""
return self.cache.delete(f"user_session:{userid}")
def get_user_sessions(self, userid: int) -> list[str]:
"""Get all active sessions for user"""
pattern = "session:*"
keys = self.cache.client.keys(pattern)
user_sessions = []
for key in keys:
session_data = self.cache.get(key.replace("session:", ""))
if session_data and session_data.get("userid") == userid:
user_sessions.append(key)
return user_sessions
class MobileSyncService:
"""Reliable mobile offline synchronization"""
def __init__(self):
self.cache = DragonflyCache("mobile")
def queue_sync_action(self, userid: int, action: dict[str, Any]):
"""Queue a sync action for mobile device"""
queue_key = f"sync_queue:user:{userid}"
return self.cache.client.lpush(queue_key, json.dumps(action))
def get_sync_actions(self, userid: int, count: int = 10) -> list[dict[str, Any]]:
"""Get pending sync actions for user"""
queue_key = f"sync_queue:user:{userid}"
actions_data = self.cache.client.lrange(queue_key, 0, count - 1)
actions = []
for action_data in actions_data:
try:
actions.append(json.loads(action_data))
except json.JSONDecodeError:
continue
return actions
def mark_sync_completed(self, userid: int, action_id: str):
"""Mark sync action as completed"""
# Remove from queue
queue_key = f"sync_queue:user:{userid}"
return self.cache.client.lrem(queue_key, 1, action_id)
def set_sync_state(self, userid: int, device_id: str, state: dict[str, Any]):
"""Set device sync state"""
state_key = f"sync_state:user:{userid}:device:{device_id}"
return self.cache.set(state_key, state, ttl_hours=24)
def get_sync_state(self, userid: int, device_id: str) -> dict[str, Any] | None:
"""Get device sync state"""
state_key = f"sync_state:user:{userid}:device:{device_id}"
return self.cache.get(state_key)
class RealTimeFeaturesService:
"""Real-time features like play counts and favorites"""
def __init__(self):
self.playcount_cache = DragonflyCache("playcounts")
self.recent_cache = DragonflyCache("recent")
self.favorite_cache = DragonflyCache("favorites")
def increment_playcount(self, trackhash: str, userid: int | None = None):
"""Increment track play count"""
key = f"plays:{trackhash}"
if userid:
key = f"plays:user:{userid}:track:{trackhash}"
return self.playcount_cache.client.incr(key)
def get_playcount(self, trackhash: str, userid: int | None = None) -> int:
"""Get track play count"""
key = f"plays:{trackhash}"
if userid:
key = f"plays:user:{userid}:track:{trackhash}"
count = self.playcount_cache.client.get(key)
return int(count) if count else 0
def add_to_recently_played(self, userid: int, trackhash: str, limit: int = 50):
"""Add track to recently played list"""
key = f"recent:user:{userid}"
# Add to beginning of list
self.recent_cache.client.lpush(key, trackhash)
# Remove duplicates
self.recent_cache.client.lrem(key, 1, trackhash)
# Add back to beginning
self.recent_cache.client.lpush(key, trackhash)
# Limit list size
self.recent_cache.client.ltrim(key, 0, limit - 1)
# Set TTL
self.recent_cache.client.expire(key, 7 * 24 * 3600) # 7 days
def get_recently_played(self, userid: int, limit: int = 50) -> list[str]:
"""Get recently played tracks for user"""
key = f"recent:user:{userid}"
return self.recent_cache.client.lrange(key, 0, limit - 1)
def toggle_favorite(self, userid: int, trackhash: str) -> bool:
"""Toggle favorite status for track"""
key = f"fav:user:{userid}:track:{trackhash}"
current = self.favorite_cache.client.get(key)
if current:
# Remove favorite
self.favorite_cache.client.delete(key)
return False
else:
# Add favorite
self.favorite_cache.client.set(key, True, ttl_hours=24 * 30) # 30 days
return True
def is_favorite(self, userid: int, trackhash: str) -> bool:
"""Check if track is favorited by user"""
key = f"fav:user:{userid}:track:{trackhash}"
return bool(self.favorite_cache.client.get(key))
def get_user_favorites(self, userid: int) -> list[str]:
"""Get all favorite tracks for user"""
pattern = f"fav:user:{userid}:track:*"
keys = self.favorite_cache.client.keys(pattern)
favorites = []
for key in keys:
trackhash = key.split(":")[-1]
favorites.append(trackhash)
return favorites
class SearchCacheService:
"""High-performance search results caching"""
def __init__(self):
self.cache = DragonflyCache("search")
def cache_search_results(
self, query: str, results: dict[str, Any], ttl_hours: int = 6
):
"""Cache search results"""
query_hash = hash(query) # Simple hash for key
return self.cache.set(f"results:{query_hash}", results, ttl_hours)
def get_search_results(self, query: str) -> dict[str, Any] | None:
"""Get cached search results"""
query_hash = hash(query)
return self.cache.get(f"results:{query_hash}")
def cache_suggestions(
self, query_type: str, suggestions: list[str], ttl_hours: int = 12
):
"""Cache search suggestions"""
return self.cache.set(f"suggestions:{query_type}", suggestions, ttl_hours)
def get_suggestions(self, query_type: str) -> list[str]:
"""Get cached search suggestions"""
suggestions = self.cache.get(f"suggestions:{query_type}")
return suggestions if suggestions else []
def invalidate_search_cache(self, pattern: str = "*"):
"""Invalidate search cache"""
keys = self.cache.client.keys(f"search:{pattern}")
if keys:
return self.cache.client.delete(*keys)
return True
class JobQueueService:
"""High-performance background job processing"""
def __init__(self):
self.cache = DragonflyCache("jobs")
def enqueue_job(self, queue: str, job_data: dict[str, Any]):
"""Add job to queue"""
job_json = json.dumps(job_data)
return self.cache.client.lpush(f"queue:{queue}", job_json)
def dequeue_job(self, queue: str) -> dict[str, Any] | None:
"""Get next job from queue"""
job_json = self.cache.client.rpop(f"queue:{queue}")
if job_json:
try:
return json.loads(job_json)
except json.JSONDecodeError:
return None
return None
def get_queue_size(self, queue: str) -> int:
"""Get number of jobs in queue"""
return self.cache.client.llen(f"queue:{queue}")
def peek_jobs(self, queue: str, count: int = 10) -> list[dict[str, Any]]:
"""Peek at jobs in queue without removing them"""
jobs_data = self.cache.client.lrange(f"queue:{queue}", 0, count - 1)
jobs = []
for job_data in jobs_data:
try:
jobs.append(json.loads(job_data))
except json.JSONDecodeError:
continue
return jobs
def clear_queue(self, queue: str):
"""Clear all jobs from queue"""
return self.cache.client.delete(f"queue:{queue}")
# Global service instances
_track_cache_service: TrackCacheService | None = None
_user_session_service: UserSessionService | None = None
_mobile_sync_service: MobileSyncService | None = None
_realtime_service: RealTimeFeaturesService | None = None
_search_cache_service: SearchCacheService | None = None
_job_queue_service: JobQueueService | None = None
def get_track_cache_service() -> TrackCacheService:
"""Get track cache service instance"""
global _track_cache_service
if _track_cache_service is None:
_track_cache_service = TrackCacheService()
return _track_cache_service
def get_user_session_service() -> UserSessionService:
"""Get user session service instance"""
global _user_session_service
if _user_session_service is None:
_user_session_service = UserSessionService()
return _user_session_service
def get_mobile_sync_service() -> MobileSyncService:
"""Get mobile sync service instance"""
global _mobile_sync_service
if _mobile_sync_service is None:
_mobile_sync_service = MobileSyncService()
return _mobile_sync_service
def get_realtime_service() -> RealTimeFeaturesService:
"""Get real-time features service instance"""
global _realtime_service
if _realtime_service is None:
_realtime_service = RealTimeFeaturesService()
return _realtime_service
def get_search_cache_service() -> SearchCacheService:
"""Get search cache service instance"""
global _search_cache_service
if _search_cache_service is None:
_search_cache_service = SearchCacheService()
return _search_cache_service
def get_job_queue_service() -> JobQueueService:
"""Get job queue service instance"""
global _job_queue_service
if _job_queue_service is None:
_job_queue_service = JobQueueService()
return _job_queue_service
def get_all_dragonfly_services() -> dict[str, Any]:
"""Get all DragonflyDB services for monitoring"""
return {
"track_cache": get_track_cache_service(),
"user_sessions": get_user_session_service(),
"mobile_sync": get_mobile_sync_service(),
"realtime": get_realtime_service(),
"search_cache": get_search_cache_service(),
"job_queue": get_job_queue_service(),
}
+79
View File
@@ -0,0 +1,79 @@
from contextlib import contextmanager
from sqlalchemy import Engine, create_engine, event
from sqlalchemy.orm import sessionmaker
from swingmusic.settings import Paths
@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.execute("PRAGMA cache_size=10000")
cursor.execute("PRAGMA foreign_keys=ON")
cursor.execute("PRAGMA temp_store=FILE")
cursor.execute("PRAGMA mmap_size=0")
cursor.close()
class classproperty(property):
"""
A class property decorator.
"""
def __get__(self, owner_self, owner_cls):
if self.fget:
return self.fget(owner_cls)
class DbEngine:
"""
The database engine instance.
"""
_engine: Engine | None = None
@classproperty
def engine(cls) -> Engine:
if not cls._engine:
cls._engine = create_engine(
f"sqlite+pysqlite:///{Paths().app_db_path}",
echo=False,
max_overflow=20,
pool_size=10,
)
return cls._engine
@classmethod
@contextmanager
def manager(cls, commit: bool = False):
"""
This context manager manages access to the database.
When the context manager is entered, it returns a session object that can be used to execute SQL statements.
If the `commit` parameter is set to `True`, the context manager will commit the transaction when it exits.
"""
Session = sessionmaker(cls.engine)
try:
with Session() as session:
yield session
if commit:
session.commit()
# yield session.execution_options(preserve_rowcount=True, yield_per=100)
# yield conn.execution_options(preserve_rowcount=True, yield_per=100)
except Exception as e:
session.rollback()
raise e
finally:
if commit:
session.commit()
session.close()
# del conn
# cls.engine.clear_compiled_cache()
+79
View File
@@ -0,0 +1,79 @@
from typing import Any
from sqlalchemy import JSON, Integer, String, delete, select
from sqlalchemy.orm import Mapped, mapped_column
from swingmusic.config import UserConfig
from swingmusic.db import Base
from swingmusic.db.engine import DbEngine
from swingmusic.db.utils import track_to_dataclass, tracks_to_dataclasses
class TrackTable(Base):
__tablename__ = "track"
id: Mapped[int] = mapped_column(init=False, primary_key=True)
album: Mapped[str] = mapped_column(String())
albumartists: Mapped[str] = mapped_column(String())
albumhash: Mapped[str] = mapped_column(String(), index=True)
artists: Mapped[str] = mapped_column(String())
bitrate: Mapped[int] = mapped_column(Integer())
copyright: Mapped[str | None] = mapped_column(String())
date: Mapped[int] = mapped_column(Integer(), nullable=True)
disc: Mapped[int] = mapped_column(Integer())
duration: Mapped[int] = mapped_column(Integer())
filepath: Mapped[str] = mapped_column(String(), index=True, unique=True)
folder: Mapped[str] = mapped_column(String(), index=True)
genres: Mapped[str | None] = mapped_column(String())
last_mod: Mapped[float] = mapped_column(Integer())
title: Mapped[str] = mapped_column(String())
track: Mapped[int] = mapped_column(Integer())
trackhash: Mapped[str] = mapped_column(String(), index=True)
lastplayed: Mapped[int] = mapped_column(Integer(), default=0)
playcount: Mapped[int] = mapped_column(Integer(), default=0)
playduration: Mapped[int] = mapped_column(Integer(), default=0)
extra: Mapped[dict[str, Any] | None] = mapped_column(JSON(), default_factory=dict)
@classmethod
def get_all(cls):
with DbEngine.manager() as conn:
config = UserConfig()
result = conn.execute(select(cls).execution_options(yield_per=100))
for i in result.scalars():
d = i.__dict__
del d["_sa_instance_state"]
yield track_to_dataclass(d, config)
@classmethod
def get_tracks_by_filepaths(cls, filepaths: list[str]):
with DbEngine.manager() as conn:
result = conn.execute(
select(TrackTable)
.where(TrackTable.filepath.in_(filepaths))
.order_by(TrackTable.last_mod)
)
return tracks_to_dataclasses(result.fetchall())
@classmethod
def get_tracks_in_path(cls, path: str):
with DbEngine.manager() as conn:
result = conn.execute(
select(TrackTable)
.where(TrackTable.filepath.contains(path))
.order_by(TrackTable.last_mod)
)
clean = []
for row in result.fetchall():
d = row[0].__dict__
del d["_sa_instance_state"]
clean.append(d)
return tracks_to_dataclasses(clean)
@classmethod
def remove_tracks_by_filepaths(cls, filepaths: set[str]):
with DbEngine.manager(commit=True) as conn:
conn.execute(delete(TrackTable).where(TrackTable.filepath.in_(filepaths)))
+33
View File
@@ -0,0 +1,33 @@
from sqlalchemy import Integer, insert, select, update
from sqlalchemy.orm import Mapped, mapped_column
from swingmusic.db import Base
from swingmusic.db.engine import DbEngine
class MigrationTable(Base):
__tablename__ = "dbmigration"
id: Mapped[int] = mapped_column(primary_key=True)
version: Mapped[int] = mapped_column(Integer())
@classmethod
def set_version(cls, version: int):
with DbEngine.manager(commit=True) as conn:
result = conn.execute(
update(cls).where(cls.id == 1).values(version=version)
)
if result.rowcount == 0:
conn.execute(insert(cls).values(id=1, version=version))
@classmethod
def get_version(cls):
with DbEngine.manager() as conn:
result = conn.execute(select(cls.version).where(cls.id == 1))
result = result.fetchone()
if result:
return result[0]
return -1
+745
View File
@@ -0,0 +1,745 @@
from __future__ import annotations
import secrets
import time
from typing import Any
from sqlalchemy import (
JSON,
Boolean,
Float,
ForeignKey,
Integer,
String,
UniqueConstraint,
and_,
delete,
select,
update,
)
from sqlalchemy.orm import Mapped, mapped_column
from swingmusic.db import Base
class LibraryFileTable(Base):
__tablename__ = "library_file"
id: Mapped[int] = mapped_column(primary_key=True)
trackhash: Mapped[str] = mapped_column(String(), unique=True, index=True)
filepath: Mapped[str] = mapped_column(String(), unique=True, index=True)
codec: Mapped[str] = mapped_column(String(), default="unknown")
quality: Mapped[str] = mapped_column(String(), default="unknown")
bitrate: Mapped[int] = mapped_column(Integer(), default=0)
source: Mapped[str] = mapped_column(String(), default="local")
checksum: Mapped[str | None] = mapped_column(String(), nullable=True, default=None)
created_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
updated_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
extra: Mapped[dict[str, Any]] = mapped_column(JSON(), default_factory=dict)
@classmethod
def get_by_trackhash(cls, trackhash: str):
result = cls.execute(select(cls).where(cls.trackhash == trackhash))
return next(result).scalar()
@classmethod
def upsert_from_local_track(
cls,
*,
trackhash: str,
filepath: str,
bitrate: int,
codec: str,
quality: str,
source: str = "local",
):
now = int(time.time())
row = cls.get_by_trackhash(trackhash)
if row:
next(
cls.execute(
update(cls)
.where(cls.id == row.id)
.values(
filepath=filepath,
bitrate=bitrate,
codec=codec,
quality=quality,
source=source,
updated_at=now,
),
commit=True,
)
)
return cls.get_by_trackhash(trackhash)
cls.insert_one(
{
"trackhash": trackhash,
"filepath": filepath,
"bitrate": bitrate,
"codec": codec,
"quality": quality,
"source": source,
"created_at": now,
"updated_at": now,
"extra": {},
}
)
return cls.get_by_trackhash(trackhash)
class DownloadJobTable(Base):
__tablename__ = "download_job"
id: Mapped[int] = mapped_column(primary_key=True)
userid: Mapped[int] = mapped_column(
Integer(), ForeignKey("user.id", ondelete="cascade"), index=True
)
trackhash: Mapped[str | None] = mapped_column(
String(), nullable=True, index=True, default=None
)
title: Mapped[str | None] = mapped_column(String(), nullable=True, default=None)
artist: Mapped[str | None] = mapped_column(String(), nullable=True, default=None)
album: Mapped[str | None] = mapped_column(String(), nullable=True, default=None)
item_type: Mapped[str] = mapped_column(String(), default="track")
source_url: Mapped[str | None] = mapped_column(
String(), nullable=True, index=True, default=None
)
source: Mapped[str] = mapped_column(String(), default="spotify", index=True)
provider: Mapped[str] = mapped_column(String(), default="spotify")
codec: Mapped[str] = mapped_column(String(), default="mp3")
quality: Mapped[str] = mapped_column(String(), default="high")
target_path: Mapped[str | None] = mapped_column(
String(), nullable=True, default=None
)
state: Mapped[str] = mapped_column(String(), default="queued", index=True)
progress: Mapped[float] = mapped_column(Float(), default=0.0)
error: Mapped[str | None] = mapped_column(String(), nullable=True, default=None)
retry_count: Mapped[int] = mapped_column(Integer(), default=0)
payload: Mapped[dict[str, Any]] = mapped_column(JSON(), default_factory=dict)
created_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
updated_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
started_at: Mapped[int | None] = mapped_column(
Integer(), nullable=True, default=None
)
finished_at: Mapped[int | None] = mapped_column(
Integer(), nullable=True, default=None
)
@classmethod
def enqueue(cls, payload: dict[str, Any]):
now = int(time.time())
values = {
"created_at": now,
"updated_at": now,
"state": "queued",
"progress": 0.0,
**payload,
}
result = cls.insert_one(values)
return result.lastrowid
@classmethod
def get_by_id(cls, job_id: int):
result = cls.execute(select(cls).where(cls.id == job_id))
return next(result).scalar()
@classmethod
def get_queued_job(cls):
result = cls.execute(
select(cls)
.where(cls.state == "queued")
.order_by(cls.created_at.asc())
.limit(1)
)
return next(result).scalar()
@classmethod
def update_job(cls, job_id: int, values: dict[str, Any]):
values = {**values, "updated_at": int(time.time())}
return next(
cls.execute(update(cls).where(cls.id == job_id).values(values), commit=True)
)
@classmethod
def list_for_user(cls, userid: int, states: list[str] | set[str] | None = None):
query = select(cls).where(cls.userid == userid).order_by(cls.created_at.desc())
if states:
query = query.where(cls.state.in_(list(states)))
result = cls.execute(query)
return list(next(result).scalars())
@classmethod
def delete_for_user(
cls, userid: int, states: list[str] | set[str] | None = None
) -> int:
statement = delete(cls).where(cls.userid == userid)
if states:
statement = statement.where(cls.state.in_(list(states)))
result = next(cls.execute(statement, commit=True))
return int(result.rowcount or 0)
class TrackedPlaylistTable(Base):
__tablename__ = "tracked_playlist"
__table_args__ = (
UniqueConstraint(
"userid",
"service",
"playlist_id",
name="uq_tracked_playlist_user_service",
),
)
id: Mapped[int] = mapped_column(primary_key=True)
userid: Mapped[int] = mapped_column(
Integer(), ForeignKey("user.id", ondelete="cascade"), index=True
)
source_url: Mapped[str] = mapped_column(String(), index=True)
playlist_id: Mapped[str] = mapped_column(String(), index=True)
service: Mapped[str] = mapped_column(String(), default="spotify", index=True)
title: Mapped[str | None] = mapped_column(String(), nullable=True, default=None)
owner_name: Mapped[str | None] = mapped_column(
String(), nullable=True, default=None
)
quality: Mapped[str] = mapped_column(String(), default="lossless")
codec: Mapped[str] = mapped_column(String(), default="flac")
auto_sync: Mapped[bool] = mapped_column(Boolean(), default=True, index=True)
sync_interval_seconds: Mapped[int] = mapped_column(Integer(), default=900)
next_sync_at: Mapped[int] = mapped_column(
Integer(), default=lambda: int(time.time())
)
last_sync_at: Mapped[int | None] = mapped_column(
Integer(), nullable=True, default=None
)
status: Mapped[str] = mapped_column(String(), default="active", index=True)
snapshot_track_ids: Mapped[list[str]] = mapped_column(JSON(), default_factory=list)
snapshot_hash: Mapped[str | None] = mapped_column(
String(), nullable=True, default=None
)
last_result: Mapped[dict[str, Any]] = mapped_column(JSON(), default_factory=dict)
last_error: Mapped[str | None] = mapped_column(
String(), nullable=True, default=None
)
created_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
updated_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
extra: Mapped[dict[str, Any]] = mapped_column(JSON(), default_factory=dict)
@classmethod
def get_by_id(cls, tracked_id: int, userid: int | None = None):
statement = select(cls).where(cls.id == tracked_id)
if userid is not None:
statement = statement.where(cls.userid == userid)
result = cls.execute(statement)
return next(result).scalar()
@classmethod
def get_by_source(cls, *, userid: int, service: str, playlist_id: str):
result = cls.execute(
select(cls).where(
and_(
cls.userid == userid,
cls.service == service,
cls.playlist_id == playlist_id,
)
)
)
return next(result).scalar()
@classmethod
def list_for_user(cls, userid: int, include_deleted: bool = False):
statement = (
select(cls).where(cls.userid == userid).order_by(cls.created_at.desc())
)
if not include_deleted:
statement = statement.where(cls.status != "deleted")
result = cls.execute(statement)
return list(next(result).scalars())
@classmethod
def upsert(
cls,
*,
userid: int,
service: str,
playlist_id: str,
source_url: str,
values: dict[str, Any] | None = None,
):
now = int(time.time())
row = cls.get_by_source(userid=userid, service=service, playlist_id=playlist_id)
payload: dict[str, Any] = {
"userid": userid,
"service": service,
"playlist_id": playlist_id,
"source_url": source_url,
}
if values:
payload.update(values)
if row:
next(
cls.execute(
update(cls)
.where(cls.id == row.id)
.values(
{
**payload,
"updated_at": now,
}
),
commit=True,
)
)
return cls.get_by_id(row.id)
cls.insert_one(
{
**payload,
"status": payload.get("status", "active"),
"auto_sync": bool(payload.get("auto_sync", True)),
"sync_interval_seconds": int(payload.get("sync_interval_seconds", 900)),
"next_sync_at": int(payload.get("next_sync_at", now)),
"created_at": now,
"updated_at": now,
"snapshot_track_ids": payload.get("snapshot_track_ids", []),
"last_result": payload.get("last_result", {}),
"extra": payload.get("extra", {}),
}
)
return cls.get_by_source(
userid=userid, service=service, playlist_id=playlist_id
)
@classmethod
def update_row(cls, tracked_id: int, values: dict[str, Any]):
next(
cls.execute(
update(cls)
.where(cls.id == tracked_id)
.values({**values, "updated_at": int(time.time())}),
commit=True,
)
)
return cls.get_by_id(tracked_id)
@classmethod
def due_for_sync(cls, *, now_ts: int | None = None, limit: int = 50):
now_ts = int(now_ts or time.time())
result = cls.execute(
select(cls)
.where(cls.auto_sync.is_(True))
.where(cls.status.in_(["active", "failed", "syncing"]))
.where(cls.next_sync_at <= now_ts)
.order_by(cls.next_sync_at.asc())
.limit(limit)
)
return list(next(result).scalars())
class UserLibraryTrackTable(Base):
__tablename__ = "user_library_track"
__table_args__ = (
UniqueConstraint("userid", "trackhash", name="uq_user_track_projection"),
)
id: Mapped[int] = mapped_column(primary_key=True)
userid: Mapped[int] = mapped_column(
Integer(), ForeignKey("user.id", ondelete="cascade"), index=True
)
trackhash: Mapped[str] = mapped_column(String(), index=True)
file_id: Mapped[int] = mapped_column(
Integer(),
ForeignKey("library_file.id", ondelete="set null"),
nullable=True,
index=True,
)
status: Mapped[str] = mapped_column(String(), default="missing", index=True)
source_url: Mapped[str | None] = mapped_column(
String(), nullable=True, default=None
)
download_job_id: Mapped[int | None] = mapped_column(
Integer(),
ForeignKey("download_job.id", ondelete="set null"),
nullable=True,
default=None,
)
error: Mapped[str | None] = mapped_column(String(), nullable=True, default=None)
created_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
updated_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
extra: Mapped[dict[str, Any]] = mapped_column(JSON(), default_factory=dict)
@classmethod
def get_user_track(cls, userid: int, trackhash: str):
result = cls.execute(
select(cls).where(and_(cls.userid == userid, cls.trackhash == trackhash))
)
return next(result).scalar()
@classmethod
def upsert_status(
cls,
*,
userid: int,
trackhash: str,
status: str,
file_id: int | None = None,
download_job_id: int | None = None,
source_url: str | None = None,
error: str | None = None,
extra: dict[str, Any] | None = None,
):
now = int(time.time())
row = cls.get_user_track(userid, trackhash)
values: dict[str, Any] = {
"status": status,
"updated_at": now,
"file_id": file_id,
"download_job_id": download_job_id,
"source_url": source_url,
"error": error,
}
if extra is not None:
values["extra"] = extra
if row:
next(
cls.execute(
update(cls).where(cls.id == row.id).values(values), commit=True
)
)
return cls.get_user_track(userid, trackhash)
cls.insert_one(
{
"userid": userid,
"trackhash": trackhash,
"status": status,
"file_id": file_id,
"download_job_id": download_job_id,
"source_url": source_url,
"error": error,
"created_at": now,
"updated_at": now,
"extra": extra or {},
}
)
return cls.get_user_track(userid, trackhash)
@classmethod
def get_status_map(cls, userid: int, trackhashes: set[str] | list[str]):
if not trackhashes:
return {}
result = cls.execute(
select(cls).where(
and_(cls.userid == userid, cls.trackhash.in_(set(trackhashes)))
)
)
rows = list(next(result).scalars())
return {row.trackhash: row for row in rows}
class UserRootDirOwnershipTable(Base):
__tablename__ = "user_root_dir_ownership"
__table_args__ = (UniqueConstraint("userid", "path", name="uq_user_root_path"),)
id: Mapped[int] = mapped_column(primary_key=True)
userid: Mapped[int] = mapped_column(
Integer(), ForeignKey("user.id", ondelete="cascade"), index=True
)
path: Mapped[str] = mapped_column(String(), index=True)
created_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
@classmethod
def assign_paths(cls, userid: int, paths: list[str]):
existing_result = cls.execute(select(cls.path).where(cls.userid == userid))
existing = {row[0] for row in next(existing_result).all()}
for path in paths:
if path in existing:
continue
cls.insert_one(
{"userid": userid, "path": path, "created_at": int(time.time())}
)
@classmethod
def get_paths(cls, userid: int) -> list[str]:
result = cls.execute(select(cls.path).where(cls.userid == userid))
paths = [row for row in next(result).scalars().all() if row]
return list(dict.fromkeys(paths))
@classmethod
def replace_paths(cls, userid: int, paths: list[str]):
cleaned = [path.strip() for path in paths if path and path.strip()]
cleaned = list(dict.fromkeys(cleaned))
next(cls.execute(delete(cls).where(cls.userid == userid), commit=True))
if not cleaned:
return
now = int(time.time())
cls.insert_many(
[{"userid": userid, "path": path, "created_at": now} for path in cleaned]
)
class SetupStateTable(Base):
__tablename__ = "setup_state"
id: Mapped[int] = mapped_column(primary_key=True)
owner_userid: Mapped[int | None] = mapped_column(
Integer(),
ForeignKey("user.id", ondelete="set null"),
nullable=True,
default=None,
)
primary_music_dir: Mapped[str | None] = mapped_column(
String(), nullable=True, default=None
)
setup_completed: Mapped[bool] = mapped_column(Boolean(), default=False)
index_state: Mapped[str] = mapped_column(String(), default="idle")
index_progress: Mapped[float] = mapped_column(Float(), default=0.0)
index_message: Mapped[str | None] = mapped_column(
String(), nullable=True, default=None
)
created_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
updated_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
extra: Mapped[dict[str, Any]] = mapped_column(JSON(), default_factory=dict)
@classmethod
def get_singleton(cls):
result = cls.execute(select(cls).where(cls.id == 1))
return next(result).scalar()
@classmethod
def ensure_singleton(cls):
row = cls.get_singleton()
if row:
return row
cls.insert_one(
{
"id": 1,
"setup_completed": False,
"index_state": "idle",
"index_progress": 0.0,
"created_at": int(time.time()),
"updated_at": int(time.time()),
"extra": {},
}
)
return cls.get_singleton()
@classmethod
def update_state(cls, values: dict[str, Any]):
cls.ensure_singleton()
next(
cls.execute(
update(cls)
.where(cls.id == 1)
.values(
{
**values,
"updated_at": int(time.time()),
}
),
commit=True,
)
)
return cls.get_singleton()
@classmethod
def mark_index_progress(
cls,
*,
state: str,
progress: float,
message: str | None = None,
extra: dict[str, Any] | None = None,
):
values: dict[str, Any] = {
"index_state": state,
"index_progress": max(0.0, min(float(progress), 100.0)),
"index_message": message,
}
if extra is not None:
values["extra"] = extra
return cls.update_state(values)
class LyricsStatusTable(Base):
__tablename__ = "lyrics_status"
id: Mapped[int] = mapped_column(primary_key=True)
trackhash: Mapped[str] = mapped_column(String(), unique=True, index=True)
filepath: Mapped[str | None] = mapped_column(
String(), nullable=True, index=True, default=None
)
status: Mapped[str] = mapped_column(String(), default="pending", index=True)
source: Mapped[str | None] = mapped_column(String(), nullable=True, default=None)
has_embedded: Mapped[bool] = mapped_column(Boolean(), default=False)
has_lrc: Mapped[bool] = mapped_column(Boolean(), default=False)
last_error: Mapped[str | None] = mapped_column(
String(), nullable=True, default=None
)
attempts: Mapped[int] = mapped_column(Integer(), default=0)
created_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
updated_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
extra: Mapped[dict[str, Any]] = mapped_column(JSON(), default_factory=dict)
@classmethod
def get_by_trackhash(cls, trackhash: str):
result = cls.execute(select(cls).where(cls.trackhash == trackhash))
return next(result).scalar()
@classmethod
def upsert(
cls,
*,
trackhash: str,
filepath: str | None = None,
status: str,
source: str | None = None,
has_embedded: bool | None = None,
has_lrc: bool | None = None,
last_error: str | None = None,
extra: dict[str, Any] | None = None,
increment_attempt: bool = False,
):
now = int(time.time())
row = cls.get_by_trackhash(trackhash)
values: dict[str, Any] = {
"status": status,
"source": source,
"last_error": last_error,
"updated_at": now,
}
if filepath is not None:
values["filepath"] = filepath
if has_embedded is not None:
values["has_embedded"] = bool(has_embedded)
if has_lrc is not None:
values["has_lrc"] = bool(has_lrc)
if extra is not None:
values["extra"] = extra
if row:
if increment_attempt:
values["attempts"] = int(row.attempts or 0) + 1
next(
cls.execute(
update(cls).where(cls.id == row.id).values(values), commit=True
)
)
return cls.get_by_trackhash(trackhash)
cls.insert_one(
{
"trackhash": trackhash,
"filepath": filepath,
"status": status,
"source": source,
"has_embedded": bool(has_embedded)
if has_embedded is not None
else False,
"has_lrc": bool(has_lrc) if has_lrc is not None else False,
"last_error": last_error,
"attempts": 1 if increment_attempt else 0,
"created_at": now,
"updated_at": now,
"extra": extra or {},
}
)
return cls.get_by_trackhash(trackhash)
class InviteTokenTable(Base):
__tablename__ = "invite_token"
id: Mapped[int] = mapped_column(primary_key=True)
token: Mapped[str] = mapped_column(String(), unique=True, index=True)
created_by: Mapped[int | None] = mapped_column(
Integer(),
ForeignKey("user.id", ondelete="set null"),
nullable=True,
default=None,
)
used_by: Mapped[int | None] = mapped_column(
Integer(),
ForeignKey("user.id", ondelete="set null"),
nullable=True,
default=None,
)
roles: Mapped[list[str]] = mapped_column(JSON(), default_factory=lambda: ["user"])
active: Mapped[bool] = mapped_column(Boolean(), default=True)
expires_at: Mapped[int | None] = mapped_column(
Integer(), nullable=True, default=None
)
created_at: Mapped[int] = mapped_column(Integer(), default=lambda: int(time.time()))
used_at: Mapped[int | None] = mapped_column(Integer(), nullable=True, default=None)
extra: Mapped[dict[str, Any]] = mapped_column(JSON(), default_factory=dict)
@classmethod
def create_token(
cls,
*,
created_by: int | None,
roles: list[str] | None = None,
expires_in_seconds: int = 7 * 24 * 3600,
extra: dict[str, Any] | None = None,
):
token = secrets.token_urlsafe(24)
now = int(time.time())
expires_at = now + expires_in_seconds if expires_in_seconds > 0 else None
cls.insert_one(
{
"token": token,
"created_by": created_by,
"roles": roles or ["user"],
"active": True,
"expires_at": expires_at,
"created_at": now,
"extra": extra or {},
}
)
result = cls.execute(select(cls).where(cls.token == token))
return next(result).scalar()
@classmethod
def get_valid_token(cls, token: str):
now = int(time.time())
result = cls.execute(select(cls).where(cls.token == token))
row = next(result).scalar()
if not row or not row.active:
return None
if row.expires_at is not None and row.expires_at < now:
cls.consume_token(token, used_by=None, deactivate_only=True)
return None
return row
@classmethod
def consume_token(
cls, token: str, used_by: int | None, deactivate_only: bool = False
):
values: dict[str, Any] = {"active": False, "used_at": int(time.time())}
if not deactivate_only:
values["used_by"] = used_by
next(
cls.execute(
update(cls).where(cls.token == token).values(values), commit=True
)
)
File diff suppressed because it is too large Load Diff
+3
View File
@@ -0,0 +1,3 @@
"""
This module contains the functions to interact with the SQLite database.
"""
+31
View File
@@ -0,0 +1,31 @@
"""
Reads and saves the latest database migrations version.
"""
from swingmusic.db.sqlite.utils import SQLiteManager
class MigrationManager:
@staticmethod
def get_index() -> int:
"""
Returns the latest databases migrations index.
"""
sql = "SELECT * FROM dbmigrations"
with SQLiteManager() as cur:
cur.execute(sql)
ver = int(cur.fetchone()[1])
cur.close()
return ver
# 👇 Setters 👇
@staticmethod
def set_index(version: int):
"""
Updates the databases migrations index.
"""
sql = "UPDATE dbmigrations SET version = ? WHERE id = 1"
with SQLiteManager() as cur:
cur.execute(sql, (version,))
cur.close()
+118
View File
@@ -0,0 +1,118 @@
"""
Helper functions for use with the SQLite database.
"""
import sqlite3
import time
from sqlite3 import Connection, Cursor
from swingmusic import settings
from swingmusic.models import Album, Playlist, Track
def tuple_to_track(track: tuple):
"""
Takes a tuple and returns a Track object
"""
return Track(*track[1:]) # rowid is removed from the tuple
def tuples_to_tracks(tracks: list[tuple]):
"""
Takes a list of tuples and returns a generator that yields a Track object for each tuple
"""
for track in tracks:
yield tuple_to_track(track)
def tuple_to_album(album: tuple):
"""
Takes a tuple and returns an Album object
"""
return Album(*album[1:]) # rowid is removed from the tuple
def tuples_to_albums(albums: list[tuple]):
"""
Takes a list of tuples and returns a generator that yields an album object for each tuple
"""
for album in albums:
yield tuple_to_album(album)
def tuple_to_playlist(playlist: tuple):
"""
Takes a tuple and returns a Playlist object
"""
return Playlist(*playlist)
def tuples_to_playlists(playlists: list[tuple]):
"""
Takes a list of tuples and returns a list of Playlist objects
"""
for playlist in playlists:
yield tuple_to_playlist(playlist)
class SQLiteManager:
"""
This is a context manager that handles the connection and cursor
for you. It also commits and closes the connection when you're done.
"""
def __init__(
self,
conn: Connection | None = None,
userdata_db=False,
test_db_path: str = None,
) -> None:
"""
When a connection is passed in, don't close the connection, because it's
a connection to the search database [in memory db].
"""
self.conn = conn
self.CLOSE_CONN = True
self.userdata_db = userdata_db
self.test_db_path = test_db_path
if conn:
self.conn = conn
self.CLOSE_CONN = False
def __enter__(self) -> Cursor:
if self.conn is not None:
cur = self.conn.cursor()
cur.execute("PRAGMA foreign_keys = ON")
return cur
db_path = self.test_db_path or settings.Paths().app_db_path
if self.userdata_db:
db_path = settings.Paths().userdata_db_path
self.conn = sqlite3.connect(
db_path,
timeout=15,
)
cur = self.conn.cursor()
cur.execute("PRAGMA foreign_keys = ON")
return cur
def __exit__(self, exc_type, exc_value, exc_traceback):
trial_count = 0
while trial_count < 10:
try:
self.conn.commit()
if self.CLOSE_CONN:
self.conn.close()
return
except sqlite3.OperationalError:
trial_count += 1
time.sleep(3)
self.conn.close()
+835
View File
@@ -0,0 +1,835 @@
import datetime
from collections.abc import Iterable
from dataclasses import asdict
from typing import Any, Literal
from sqlalchemy import (
JSON,
Boolean,
ForeignKey,
Integer,
String,
and_,
delete,
func,
insert,
select,
update,
)
from sqlalchemy.orm import Mapped, mapped_column
from swingmusic.db import Base
from swingmusic.db.engine import DbEngine
from swingmusic.db.utils import (
favorite_to_dataclass,
favorites_to_dataclass,
playlist_to_dataclass,
plugin_to_dataclass,
similar_artist_to_dataclass,
tracklog_to_dataclass,
user_to_dataclass,
)
from swingmusic.models.mix import Mix
from swingmusic.utils.auth import get_current_userid, hash_password
class UserTable(Base):
__tablename__ = "user"
id: Mapped[int] = mapped_column(primary_key=True)
image: Mapped[str] = mapped_column(String(), nullable=True)
password: Mapped[str] = mapped_column(String())
username: Mapped[str] = mapped_column(String(), index=True)
roles: Mapped[list[str]] = mapped_column(JSON(), default_factory=lambda: [])
password_change_required: Mapped[bool] = mapped_column(Boolean(), default=False)
extra: Mapped[dict[str, Any]] = mapped_column(
JSON(), nullable=True, default_factory=dict
)
@classmethod
def get_all(cls):
result = cls.execute(select(cls))
for i in next(result).scalars():
yield user_to_dataclass(i)
@classmethod
def insert_default_user(cls):
user = {
"username": "admin",
"password": hash_password("admin"),
"roles": ["admin"],
"password_change_required": True,
}
return cls.insert_one(user)
@classmethod
def insert_guest_user(cls):
user = {
"username": "guest",
"password": hash_password("guest"),
"roles": ["guest"],
"password_change_required": True,
}
return cls.insert_one(user)
@classmethod
def get_by_id(cls, id: int):
result = cls.execute(select(cls).where(cls.id == id))
res = next(result).scalar()
if res:
return user_to_dataclass(res)
@classmethod
def get_by_username(cls, username: str):
res = cls.execute(select(cls).where(cls.username == username))
res = next(res).scalar()
if res:
return user_to_dataclass(res)
@classmethod
def update_one(cls, user: dict[str, Any]):
return next(
cls.execute(
update(cls).where(cls.id == user["id"]).values(user), commit=True
)
)
@classmethod
def remove_by_username(cls, username: str):
return next(
cls.execute(delete(cls).where(cls.username == username), commit=True)
)
class PluginTable(Base):
__tablename__ = "plugin"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(), unique=True)
active: Mapped[bool] = mapped_column(Boolean())
settings: Mapped[dict[str, Any]] = mapped_column(JSON())
extra: Mapped[dict[str, Any]] = mapped_column(JSON(), nullable=True)
@classmethod
def get_all(cls):
result = cls.execute(select(cls))
for i in next(result).scalars():
yield plugin_to_dataclass(i)
@classmethod
def activate(cls, name: str, value: bool):
return next(
cls.execute(
update(cls).where(cls.name == name).values(active=value), commit=True
)
)
@classmethod
def get_by_name(cls, name: str):
result = cls.execute(select(cls).where(cls.name == name))
res = next(result).scalar()
if res:
return plugin_to_dataclass(res)
@classmethod
def update_settings(cls, name: str, settings: dict[str, Any]):
return next(
cls.execute(
update(cls).where(cls.name == name).values(settings=settings),
commit=True,
)
)
class SimilarArtistTable(Base):
__tablename__ = "notlastfm_similar_artists"
id: Mapped[int] = mapped_column(Integer(), primary_key=True)
artisthash: Mapped[str] = mapped_column(String(), index=True)
similar_artists: Mapped[dict[str, str]] = mapped_column(JSON())
@classmethod
def get_all(cls):
result = cls.execute(select(cls).execution_options(yield_per=100))
for i in next(result).scalars():
yield similar_artist_to_dataclass(i)
@classmethod
def exists(cls, artisthash: str):
"""
Check whether an artisthash exists in the database.
"""
with DbEngine.manager() as conn:
result = conn.execute(
select(cls.artisthash)
.where(cls.artisthash == artisthash)
.execution_options(yield_per=100)
)
return len(result.scalars().all()) > 0
@classmethod
def get_by_hash(cls, artisthash: str):
"""
Get a single artist by hash.
"""
result = cls.execute(select(cls).where(cls.artisthash == artisthash))
res = next(result).scalar()
if res:
return similar_artist_to_dataclass(res)
class FavoritesTable(Base):
__tablename__ = "favorite"
id: Mapped[int] = mapped_column(primary_key=True)
hash: Mapped[str] = mapped_column(String(), unique=True)
type: Mapped[str] = mapped_column(String(), index=True)
timestamp: Mapped[int] = mapped_column(Integer(), index=True)
userid: Mapped[int] = mapped_column(
Integer(), ForeignKey("user.id", ondelete="cascade"), default=1, index=True
)
extra: Mapped[dict[str, Any]] = mapped_column(
JSON(), nullable=True, default_factory=dict
)
@classmethod
def _normalize_item_hash(cls, raw_hash: str, item_type: str) -> str:
"""
Normalize legacy and scoped favorite hash formats to plain item hash.
Accepted formats:
- <hash>
- <type>_<hash>
- u<userid>:<type>_<hash>
"""
normalized = str(raw_hash or "").strip()
item_type = str(item_type or "").strip()
if normalized.startswith("u") and ":" in normalized:
user_prefix, remainder = normalized.split(":", 1)
if user_prefix[1:].isdigit():
normalized = remainder
type_prefix = f"{item_type}_"
if item_type and normalized.startswith(type_prefix):
normalized = normalized[len(type_prefix) :]
return normalized
@classmethod
def _hash_candidates(
cls,
*,
hash_value: str,
item_type: str,
userid: int | None = None,
) -> set[str]:
canonical = cls._normalize_item_hash(hash_value, item_type)
candidates = {canonical}
if item_type:
candidates.add(f"{item_type}_{canonical}")
if userid is not None:
candidates.add(f"u{int(userid)}:{item_type}_{canonical}")
return {candidate for candidate in candidates if candidate}
@classmethod
def get_all(cls, with_user: bool = True):
with DbEngine.manager() as conn:
if with_user:
result = conn.execute(
select(cls).where(cls.userid == get_current_userid())
)
else:
result = conn.execute(select(cls))
for i in result.scalars():
yield favorite_to_dataclass(i)
@classmethod
def insert_item(cls, item: dict[str, Any]):
item_type = str(item.get("type") or "").strip()
canonical_hash = cls._normalize_item_hash(item.get("hash", ""), item_type)
userid = int(item.get("userid") or get_current_userid())
if cls.check_exists(canonical_hash, item_type, userid=userid):
return None
# Scope favorites per user while keeping backward compatibility
# with legacy `type_hash` entries.
item["hash"] = f"u{userid}:{item_type}_{canonical_hash}"
if item.get("timestamp") is None:
item["timestamp"] = int(datetime.datetime.now().timestamp())
item["userid"] = userid
return next(cls.execute(insert(cls).values(item), commit=True))
@classmethod
def remove_item(cls, item: dict[str, Any]):
userid = int(item.get("userid") or get_current_userid())
candidates = cls._hash_candidates(
hash_value=str(item.get("hash") or ""),
item_type=str(item.get("type") or ""),
userid=userid,
)
return next(
cls.execute(
delete(cls).where(and_(cls.userid == userid, cls.hash.in_(candidates))),
commit=True,
)
)
@classmethod
def check_exists(cls, hash: str, type: str, userid: int | None = None):
userid = int(userid or get_current_userid())
candidates = cls._hash_candidates(
hash_value=hash,
item_type=type,
userid=userid,
)
result = cls.execute(
select(cls).where(and_(cls.userid == userid, cls.hash.in_(candidates)))
)
return next(result).scalar() is not None
@classmethod
def get_by_hash(cls, hash: str, type: str, userid: int | None = None):
userid = int(userid or get_current_userid())
candidates = cls._hash_candidates(
hash_value=hash,
item_type=type,
userid=userid,
)
result = cls.execute(
select(cls).where(and_(cls.userid == userid, cls.hash.in_(candidates)))
)
return next(result).scalars().all()
@classmethod
def get_all_of_type(cls, type: str, start: int, limit: int):
result = cls.execute(
select(cls)
# .select_from(join(table, cls, field == cls.hash))
.where(and_(cls.type == type, cls.userid == get_current_userid()))
.order_by(cls.timestamp.desc())
.offset(start)
# INFO: If start is 0, fetch all so we can get the total count
.limit(limit if start != 0 else None)
)
res = next(result).scalars().all()
if start == 0:
# if limit == -1, return all
if limit == -1:
limit = len(res)
return res[:limit], len(res)
return res, -1
@classmethod
def get_fav_tracks(cls, start: int, limit: int):
result, total = cls.get_all_of_type("track", start, limit)
return favorites_to_dataclass(result), total
@classmethod
def get_fav_albums(cls, start: int, limit: int):
result, total = cls.get_all_of_type("album", start, limit)
return favorites_to_dataclass(result), total
@classmethod
def get_fav_artists(cls, start: int, limit: int):
result, total = cls.get_all_of_type("artist", start, limit)
return favorites_to_dataclass(result), total
@classmethod
def count_favs_in_period(cls, start_time: int, end_time: int):
result = cls.execute(
select(func.count(cls.id))
.where(cls.userid == get_current_userid())
.where(and_(cls.timestamp >= start_time, cls.timestamp <= end_time))
)
res = next(result).scalar()
if res:
return res
return 0
@classmethod
def count_tracks(cls):
result = cls.execute(
select(func.count(cls.id)).where(
and_(cls.type == "track", cls.userid == get_current_userid())
)
)
return next(result).scalar()
@classmethod
def get_last_trackhash(cls):
result = cls.execute(
select(cls.hash)
.where(and_(cls.type == "track", cls.userid == get_current_userid()))
.order_by(cls.timestamp.desc())
)
db_hash = next(result).scalar()
if not db_hash:
return None
return cls._normalize_item_hash(db_hash, "track")
class ScrobbleTable(Base):
__tablename__ = "scrobble"
id: Mapped[int] = mapped_column(primary_key=True)
trackhash: Mapped[str] = mapped_column(String(), index=True)
duration: Mapped[int] = mapped_column(Integer())
timestamp: Mapped[int] = mapped_column(Integer())
source: Mapped[str] = mapped_column(String())
userid: Mapped[int] = mapped_column(
Integer(), ForeignKey("user.id", ondelete="cascade"), index=True
)
extra: Mapped[dict[str, Any]] = mapped_column(
JSON(), nullable=True, default_factory=dict
)
@classmethod
def add(cls, item: dict[str, Any]):
if item.get("userid") is None:
item["userid"] = get_current_userid()
return cls.insert_one(item)
@classmethod
def get_all(cls, start: int, limit: int | None = None, userid: int | None = None):
result = cls.execute(
select(cls)
.where(cls.userid == (userid if userid else get_current_userid()))
.order_by(cls.timestamp.desc())
.offset(start)
.limit(limit)
.execution_options(yield_per=100)
)
for i in next(result).scalars():
yield tracklog_to_dataclass(i)
@classmethod
def get_all_in_period(cls, start_time: int, end_time: int, userid: int | None):
# UserId will be None if function is called from the API
# In that case, we use the request userid
if userid is None:
userid = get_current_userid()
result = cls.execute(
select(cls)
.where(cls.userid == userid)
.where(and_(cls.timestamp >= start_time, cls.timestamp <= end_time))
.order_by(cls.timestamp.desc())
.execution_options(yield_per=100)
)
for i in next(result).scalars():
yield tracklog_to_dataclass(i)
@classmethod
def get_last_entry(cls, userid: int):
result = cls.execute(
select(cls).where(cls.userid == userid).order_by(cls.timestamp.desc())
)
res = next(result).scalar()
if res:
return tracklog_to_dataclass(res)
class PlaylistTable(Base):
__tablename__ = "playlist"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(), index=True)
last_updated: Mapped[int] = mapped_column(Integer())
image: Mapped[str] = mapped_column(String(), nullable=True)
userid: Mapped[int] = mapped_column(
Integer(), ForeignKey("user.id", ondelete="cascade")
)
settings: Mapped[dict[str, Any]] = mapped_column(JSON())
trackhashes: Mapped[list[str]] = mapped_column(JSON(), default_factory=list)
extra: Mapped[dict[str, Any]] = mapped_column(
JSON(), nullable=True, default_factory=dict
)
@classmethod
def get_all(cls, current_user: bool = True):
if current_user:
result = cls.execute(
select(cls)
.where(cls.userid == get_current_userid())
.execution_options(yield_per=100)
)
else:
result = cls.execute(select(cls).execution_options(yield_per=100))
for i in next(result).scalars():
yield playlist_to_dataclass(i)
@classmethod
def add_one(cls, playlist: dict[str, Any]):
playlist["userid"] = get_current_userid()
result = cls.insert_one(playlist)
return result.lastrowid
@classmethod
def check_exists_by_name(cls, name: str):
result = cls.execute(
select(cls).where((cls.name == name) & (cls.userid == get_current_userid()))
)
return next(result).scalar() is not None
@classmethod
def append_to_playlist(cls, id: int, trackhashes: list[str]):
dbtrackhashes = cls.get_trackhashes(id) or []
trackhashes = list(set(dbtrackhashes).union(set(trackhashes)))
return next(
cls.execute(
update(cls)
.where((cls.id == id) & (cls.userid == get_current_userid()))
.values(trackhashes=trackhashes),
commit=True,
)
)
@classmethod
def get_trackhashes(cls, id: int):
result = cls.execute(
select(cls.trackhashes).where(
(cls.id == id) & (cls.userid == get_current_userid())
)
)
return next(result).scalar()
@classmethod
def remove_from_playlist(cls, id: int, trackhashes: list[dict[str, Any]]):
# INFO: Get db trackhashes
dbtrackhashes = cls.get_trackhashes(id)
if dbtrackhashes:
for item in trackhashes:
if dbtrackhashes.index(item["trackhash"]) == item["index"]:
dbtrackhashes.remove(item["trackhash"])
return next(
cls.execute(
update(cls)
.where((cls.id == id) & (cls.userid == get_current_userid()))
.values(trackhashes=dbtrackhashes),
commit=True,
)
)
@classmethod
def get_by_id(cls, id: int):
result = cls.execute(
select(cls).where((cls.id == id) & (cls.userid == get_current_userid()))
)
result = next(result).scalar()
if result:
return playlist_to_dataclass(result)
@classmethod
def update_one(cls, id: int, playlist: dict[str, Any]):
return next(
cls.execute(
update(cls)
.where((cls.id == id) & (cls.userid == get_current_userid()))
.values(playlist),
commit=True,
)
)
@classmethod
def update_settings(cls, id: int, settings: dict[str, Any]):
return next(
cls.execute(
update(cls)
.where((cls.id == id) & (cls.userid == get_current_userid()))
.values(settings=settings),
commit=True,
)
)
@classmethod
def remove_image(cls, id: int):
return next(
cls.execute(
update(cls)
.where((cls.id == id) & (cls.userid == get_current_userid()))
.values(image=None),
commit=True,
)
)
class LibDataTable(Base):
__tablename__ = "artistdata"
id: Mapped[int] = mapped_column(primary_key=True)
itemhash: Mapped[str] = mapped_column(String(), unique=True, index=True)
itemtype: Mapped[str] = mapped_column(String())
color: Mapped[str] = mapped_column(String(), nullable=True)
bio: Mapped[str] = mapped_column(String(), nullable=True)
info: Mapped[dict[str, Any]] = mapped_column(JSON(), nullable=True)
extra: Mapped[dict[str, Any]] = mapped_column(
JSON(), nullable=True, default_factory=dict
)
@classmethod
def update_one(cls, hash: str, data: dict[str, Any]):
return next(
cls.execute(
update(cls).where(cls.itemhash == hash).values(data), commit=True
)
)
@classmethod
def find_one(cls, hash: str, type: Literal["album", "artist"]):
result = cls.execute(
select(cls).where((cls.itemhash == type + hash) & (cls.itemtype == type))
)
return next(result).scalar()
@classmethod
def get_all_colors(cls, type: str) -> Iterable[dict[str, str]]:
result = cls.execute(select(cls).where(cls.itemtype == type))
for i in next(result).scalars():
yield {"itemhash": i.itemhash.replace(type, ""), "color": i.color}
class MixTable(Base):
__tablename__ = "mix"
id: Mapped[int] = mapped_column(primary_key=True)
mixid: Mapped[str] = mapped_column(String(), index=True)
title: Mapped[str] = mapped_column(String())
description: Mapped[str] = mapped_column(String())
timestamp: Mapped[int] = mapped_column(Integer())
sourcehash: Mapped[str] = mapped_column(String(), unique=True, index=True)
userid: Mapped[int] = mapped_column(
Integer(), ForeignKey("user.id", ondelete="cascade"), index=True
)
saved: Mapped[bool] = mapped_column(Boolean(), default=False)
tracks: Mapped[list[str]] = mapped_column(JSON(), default_factory=list)
extra: Mapped[dict[str, Any]] = mapped_column(
JSON(), nullable=True, default_factory=dict
)
@classmethod
def get_all(cls, with_userid: bool = False):
if with_userid:
result = cls.execute(
select(cls)
.where(cls.userid == get_current_userid())
.order_by(cls.timestamp.desc())
)
else:
result = cls.execute(select(cls).order_by(cls.timestamp.desc()))
for i in next(result).scalars():
yield Mix.mix_to_dataclass(i)
@classmethod
def get_by_sourcehash(cls, sourcehash: str):
result = cls.execute(
select(cls).where(
and_(cls.sourcehash == sourcehash, cls.userid == get_current_userid())
)
)
res = next(result).scalar()
if res:
return Mix.mix_to_dataclass(res)
@classmethod
def get_by_mixid(cls, mixid: str):
result = cls.execute(
select(cls).where(
and_(cls.mixid == mixid, cls.userid == get_current_userid())
)
)
res = next(result).scalar()
if res:
return Mix.mix_to_dataclass(res)
@classmethod
def insert_one(cls, mix: Mix):
mixdict = asdict(mix)
mixdict["mixid"] = mix.id
del mixdict["id"]
return next(cls.execute(insert(cls).values(mixdict), commit=True))
@classmethod
def update_one(cls, mixid: str, mix: Mix):
mixdict = asdict(mix)
mixdict["mixid"] = mix.id
del mixdict["id"]
return next(
cls.execute(
update(cls)
.where(
and_(
cls.mixid == mixid,
cls.sourcehash == mix.sourcehash,
cls.userid == get_current_userid(),
)
)
.values(mixdict),
commit=True,
)
)
@classmethod
def save_artist_mix(cls, sourcehash: str):
"""
Toggles the saved status of an artist mix.
"""
mix = cls.get_by_sourcehash(sourcehash)
if not mix:
return False
mix.saved = not mix.saved
cls.update_one(mix.id, mix)
return mix.saved
@classmethod
def get_saved_track_mixes(cls):
"""
Return all mixes that have the extra.trackmix_saved set to True.
"""
result = cls.execute(
select(cls).where(
and_(cls.extra.c.trackmix_saved, cls.userid == get_current_userid())
)
)
# return Mix.mixes_to_dataclasses(result.fetchall())
for i in next(result).scalars():
yield Mix.mix_to_dataclass(i)
@classmethod
def save_track_mix(cls, sourcehash: str):
"""
Toggles the property extra.trackmix_saved to True.
"""
mix = cls.get_by_sourcehash(sourcehash)
if not mix:
return False
mix.extra["trackmix_saved"] = not mix.extra.get("trackmix_saved", False)
cls.update_one(mix.id, mix)
return mix.extra["trackmix_saved"]
class CollectionTable(Base):
# INFO: table name was kept as page to avoid breaking existing data
__tablename__ = "page"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(), index=True)
userid: Mapped[int] = mapped_column(
Integer(), ForeignKey("user.id", ondelete="cascade"), index=True
)
items: Mapped[list[dict[str, Any]]] = mapped_column(JSON(), default_factory=list)
extra: Mapped[dict[str, Any]] = mapped_column(
JSON(), nullable=True, default_factory=dict
)
@classmethod
def to_dict(cls, entry: Any) -> dict[str, Any]:
d = entry.__dict__
del d["_sa_instance_state"]
return d
@classmethod
def get_all(cls):
result = cls.execute(select(cls).where(cls.userid == get_current_userid()))
for i in next(result).scalars():
yield cls.to_dict(i)
@classmethod
def get_by_id(cls, id: int):
result = cls.execute(
select(cls).where(and_(cls.id == id, cls.userid == get_current_userid()))
)
res = next(result).scalar()
if res:
return cls.to_dict(res)
@classmethod
def delete_by_id(cls, id: int):
return next(
cls.execute(
delete(cls).where(
and_(cls.id == id, cls.userid == get_current_userid())
),
commit=True,
)
)
@classmethod
def update_items(cls, id: int, items: list[dict[str, Any]]):
return next(
cls.execute(
update(cls)
.where(and_(cls.id == id, cls.userid == get_current_userid()))
.values(items=items),
commit=True,
)
)
@classmethod
def update_one(cls, payload: dict[str, Any]):
return next(
cls.execute(
update(cls)
.where(
and_(cls.id == payload["id"], cls.userid == get_current_userid())
)
.values(payload),
commit=True,
)
)
+100
View File
@@ -0,0 +1,100 @@
from typing import Any
from swingmusic.config import UserConfig
from swingmusic.models import Album as AlbumModel
from swingmusic.models import Artist as ArtistModel
from swingmusic.models import Track as TrackModel
from swingmusic.models.favorite import Favorite
from swingmusic.models.lastfm import SimilarArtist
from swingmusic.models.logger import TrackLog
from swingmusic.models.playlist import Playlist
from swingmusic.models.plugins import Plugin
from swingmusic.models.user import User
def row_to_dict(row: Any):
d = row.__dict__
del d["_sa_instance_state"]
return d
def track_to_dataclass(track: dict, config: UserConfig):
return TrackModel(**track, config=config)
def tracks_to_dataclasses(tracks: Any):
return [track_to_dataclass(track, UserConfig()) for track in tracks]
def album_to_dataclass(album: Any):
return AlbumModel(**album._asdict())
def albums_to_dataclasses(albums: Any):
return [album_to_dataclass(album) for album in albums]
def artist_to_dataclass(artist: Any):
return ArtistModel(**artist._asdict())
def artists_to_dataclasses(artists: Any):
return [artist_to_dataclass(artist) for artist in artists]
# SECTION: User data helpers
def similar_artist_to_dataclass(entry: Any):
entry_dict = row_to_dict(entry)
del entry_dict["id"]
return SimilarArtist(**entry_dict)
def similar_artists_to_dataclass(entries: Any):
return [similar_artist_to_dataclass(entry) for entry in entries]
def favorite_to_dataclass(entry: Any):
entry_dict = row_to_dict(entry)
del entry_dict["id"]
return Favorite(**entry_dict)
def favorites_to_dataclass(entries: Any):
return [favorite_to_dataclass(entry) for entry in entries]
def user_to_dataclass(entry: Any):
return User(**row_to_dict(entry))
# def user_to_dataclasses(entries: Any):
# return [user_to_dataclass(entry) for entry in entries]
def plugin_to_dataclass(entry: Any):
entry_dict = row_to_dict(entry)
del entry_dict["id"]
return Plugin(**entry_dict)
def plugin_to_dataclasses(entries: Any):
return [plugin_to_dataclass(entry) for entry in entries]
def tracklog_to_dataclass(entry: Any):
return TrackLog(**row_to_dict(entry))
def tracklog_to_dataclasses(entries: Any):
return [tracklog_to_dataclass(entry) for entry in entries]
def playlist_to_dataclass(entry: Any):
entry_dict = row_to_dict(entry)
return Playlist(**entry_dict)
def playlists_to_dataclasses(entries: Any):
return [playlist_to_dataclass(entry) for entry in entries]