mirror of
https://github.com/Dvorinka/swingmusic-extended.git
synced 2026-06-04 20:43:04 +00:00
start: rewrite the database layer using a freaking ORM
+ start ditching in-mem stores + move main db table to a new name + experiments!
This commit is contained in:
@@ -0,0 +1,232 @@
|
||||
import json
|
||||
from pprint import pprint
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Boolean,
|
||||
Integer,
|
||||
Row,
|
||||
String,
|
||||
Tuple,
|
||||
create_engine,
|
||||
insert,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.orm import (
|
||||
Mapped,
|
||||
mapped_column,
|
||||
DeclarativeBase,
|
||||
MappedAsDataclass,
|
||||
sessionmaker,
|
||||
)
|
||||
|
||||
from app.models import Track as TrackModel
|
||||
from app.models import Album as AlbumModel
|
||||
from app.utils.remove_duplicates import remove_duplicates
|
||||
|
||||
|
||||
fullpath = "/home/cwilvx/temp/swingmusic/swing.db"
|
||||
engine = create_engine(f"sqlite+pysqlite:///{fullpath}", echo=False)
|
||||
|
||||
|
||||
def todict(track: Any):
|
||||
return track._asdict()
|
||||
|
||||
|
||||
def todicts(tracks: list[Any]):
|
||||
return [todict(track) for track in tracks]
|
||||
|
||||
|
||||
class DbManager:
|
||||
def __init__(self):
|
||||
self.engine = create_engine(f"sqlite+pysqlite:///{fullpath}", echo=True)
|
||||
self.conn = self.engine.connect()
|
||||
|
||||
def __enter__(self):
|
||||
return self.conn.execution_options(preserve_rowcount=True)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.conn.commit()
|
||||
self.conn.close()
|
||||
|
||||
|
||||
class Base(MappedAsDataclass, DeclarativeBase):
|
||||
@classmethod
|
||||
def insert_many(cls, items: list[dict[str, Any]]):
|
||||
"""
|
||||
Inserts multiple items into the database.
|
||||
"""
|
||||
with DbManager() as conn:
|
||||
conn.execute(insert(cls).values(items))
|
||||
|
||||
@classmethod
|
||||
def insert_one(cls, item: dict[str, Any]):
|
||||
"""
|
||||
Inserts a single item into the database.
|
||||
"""
|
||||
return cls.insert_many([item])
|
||||
|
||||
@classmethod
|
||||
def get_all(cls):
|
||||
"""
|
||||
Returns all the items from the database.
|
||||
"""
|
||||
with DbManager() as conn:
|
||||
result = conn.execute(select(cls))
|
||||
return result.fetchall()
|
||||
|
||||
|
||||
class ArtistTable(Base):
|
||||
__tablename__ = "artist"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
albumcount: Mapped[int] = mapped_column(Integer())
|
||||
artisthash: Mapped[str] = mapped_column(String(), unique=True, index=True)
|
||||
created_date: Mapped[int] = mapped_column(Integer())
|
||||
date: Mapped[int] = mapped_column(Integer())
|
||||
duration: Mapped[int] = mapped_column(Integer())
|
||||
genres: Mapped[str] = mapped_column(JSON())
|
||||
name: Mapped[str] = mapped_column(String(), index=True)
|
||||
trackcount: Mapped[int] = mapped_column(Integer())
|
||||
is_favorite: Mapped[Optional[bool]] = mapped_column(Boolean())
|
||||
|
||||
@classmethod
|
||||
def get_all(cls, start: int, limit: int):
|
||||
with DbManager() as conn:
|
||||
result = conn.execute(select(cls).offset(start).limit(limit))
|
||||
return albums_to_dataclasses(result.fetchall())
|
||||
|
||||
|
||||
class AlbumTable(Base):
|
||||
__tablename__ = "album"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
albumartists: Mapped[list[dict[str, str]]] = mapped_column(JSON(), index=True)
|
||||
albumhash: Mapped[str] = mapped_column(String(), unique=True, index=True)
|
||||
base_title: Mapped[str] = mapped_column(String())
|
||||
color: Mapped[Optional[str]] = mapped_column(String())
|
||||
created_date: Mapped[int] = mapped_column(Integer())
|
||||
date: Mapped[int] = mapped_column(Integer())
|
||||
duration: Mapped[int] = mapped_column(Integer())
|
||||
genres: Mapped[str] = mapped_column(JSON())
|
||||
og_title: Mapped[str] = mapped_column(String())
|
||||
title: Mapped[str] = mapped_column(String())
|
||||
trackcount: Mapped[int] = mapped_column(Integer())
|
||||
|
||||
@classmethod
|
||||
def get_album_by_albumhash(cls, hash: str):
|
||||
with DbManager() as conn:
|
||||
result = conn.execute(
|
||||
select(AlbumTable).where(AlbumTable.albumhash == hash)
|
||||
)
|
||||
album = result.fetchone()
|
||||
|
||||
if album:
|
||||
return album_to_dataclass(album)
|
||||
|
||||
@classmethod
|
||||
def get_all(cls, start: int, limit: int):
|
||||
with DbManager() as conn:
|
||||
result = conn.execute(select(AlbumTable).offset(start).limit(limit))
|
||||
return albums_to_dataclasses(result.fetchall())
|
||||
|
||||
@classmethod
|
||||
def get_albums_by_artisthashes(cls, artisthashes: list[dict[str, str]]):
|
||||
with DbManager() as conn:
|
||||
albums: list[AlbumModel] = []
|
||||
|
||||
for artist in artisthashes:
|
||||
result = conn.execute(
|
||||
# NOTE: The artist dict keys need to in the same order they appear in the db for this to work!
|
||||
select(AlbumTable).where(AlbumTable.albumartists.contains(artist))
|
||||
)
|
||||
albums.extend(albums_to_dataclasses(result.fetchall()))
|
||||
|
||||
print(albums)
|
||||
return albums
|
||||
|
||||
@classmethod
|
||||
def get_albums_by_base_title(cls, base_title: str):
|
||||
with DbManager() as conn:
|
||||
result = conn.execute(
|
||||
select(AlbumTable).where(AlbumTable.base_title == base_title)
|
||||
)
|
||||
return albums_to_dataclasses(result.fetchall())
|
||||
|
||||
|
||||
class TrackTable(Base):
|
||||
__tablename__ = "track"
|
||||
|
||||
id: Mapped[int] = mapped_column(init=False, primary_key=True)
|
||||
album: Mapped[str] = mapped_column(String())
|
||||
albumartists: Mapped[list[dict[str, str]]] = mapped_column(JSON())
|
||||
albumhash: Mapped[str] = mapped_column(String(), index=True)
|
||||
artists: Mapped[list[dict[str, str]]] = mapped_column(JSON(), index=True)
|
||||
bitrate: Mapped[int] = mapped_column(Integer())
|
||||
copyright: Mapped[Optional[str]] = mapped_column(String())
|
||||
date: Mapped[int] = mapped_column(Integer())
|
||||
disc: Mapped[int] = mapped_column(Integer())
|
||||
duration: Mapped[int] = mapped_column(Integer())
|
||||
filepath: Mapped[str] = mapped_column(String(), unique=True)
|
||||
folder: Mapped[str] = mapped_column(String(), index=True)
|
||||
genre: Mapped[Optional[list[dict[str, str]]]] = mapped_column(JSON())
|
||||
last_mod: Mapped[float] = mapped_column(Integer())
|
||||
og_album: Mapped[str] = mapped_column(String())
|
||||
og_title: Mapped[str] = mapped_column(String())
|
||||
title: Mapped[str] = mapped_column(String())
|
||||
track: Mapped[int] = mapped_column(Integer())
|
||||
trackhash: Mapped[str] = mapped_column(String(), index=True)
|
||||
|
||||
@classmethod
|
||||
def get_tracks_by_filepaths(cls, filepaths: list[str]):
|
||||
print(filepaths[0])
|
||||
with DbManager() as conn:
|
||||
result = conn.execute(
|
||||
select(TrackTable).where(TrackTable.filepath.in_(filepaths))
|
||||
)
|
||||
return [dict(r) for r in result.mappings().fetchall()]
|
||||
|
||||
@classmethod
|
||||
def count_tracks_containing_paths(cls, paths: list[str]):
|
||||
results: list[dict[str, int | str]] = []
|
||||
|
||||
with DbManager() as conn:
|
||||
for path in paths:
|
||||
result = conn.execute(
|
||||
select(TrackTable).where(TrackTable.filepath.contains(path))
|
||||
)
|
||||
results.append({"path": path, "trackcount": result.all().__len__()})
|
||||
|
||||
return results
|
||||
|
||||
@classmethod
|
||||
def get_tracks_by_albumhash(cls, albumhash: str):
|
||||
with DbManager() as conn:
|
||||
result = conn.execute(
|
||||
select(TrackTable).where(TrackTable.albumhash == albumhash)
|
||||
)
|
||||
tracks = tracks_to_dataclasses(result.fetchall())
|
||||
return remove_duplicates(tracks, is_album_tracks=True)
|
||||
|
||||
|
||||
# SECTION: HELPER FUNCTIONS
|
||||
|
||||
|
||||
def album_to_dataclass(album: Row[AlbumTable]):
|
||||
return AlbumModel(**album._asdict())
|
||||
|
||||
|
||||
def albums_to_dataclasses(albums: list[Row[AlbumTable]]):
|
||||
return [album_to_dataclass(album) for album in albums]
|
||||
|
||||
|
||||
def track_to_dataclass(track: Row[TrackTable]):
|
||||
return TrackModel(**track._asdict())
|
||||
|
||||
|
||||
def tracks_to_dataclasses(tracks: list[Row[TrackTable]]):
|
||||
return [track_to_dataclass(track) for track in tracks]
|
||||
|
||||
|
||||
Base().metadata.create_all(engine)
|
||||
|
||||
@@ -75,7 +75,6 @@ class SQLiteAuthMethods:
|
||||
{', '.join([f"{key} = :{key}" for key in keys if key != 'id'])}
|
||||
WHERE id = :id
|
||||
"""
|
||||
print(sql, user)
|
||||
|
||||
with SQLiteManager(userdata_db=True) as cur:
|
||||
cur.execute(sql, user)
|
||||
@@ -140,7 +139,6 @@ class SQLiteAuthMethods:
|
||||
Delete a user by username.
|
||||
"""
|
||||
sql = "DELETE FROM users WHERE id = ?"
|
||||
print("deleting user: ", username)
|
||||
with SQLiteManager(userdata_db=True) as cur:
|
||||
cur.execute(sql, (3,))
|
||||
cur.close()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from flask_jwt_extended import current_user
|
||||
from app.db.sqlite.utils import SQLiteManager
|
||||
from app.models.logger import TrackLog as TrackLog
|
||||
from app.utils.auth import get_current_userid
|
||||
|
||||
|
||||
class SQLiteTrackLogger:
|
||||
@@ -10,6 +11,7 @@ class SQLiteTrackLogger:
|
||||
Inserts a track play record into the database
|
||||
"""
|
||||
|
||||
userid = get_current_userid()
|
||||
with SQLiteManager(userdata_db=True) as cur:
|
||||
sql = """INSERT OR REPLACE INTO track_logger(
|
||||
trackhash,
|
||||
@@ -21,7 +23,7 @@ class SQLiteTrackLogger:
|
||||
"""
|
||||
|
||||
cur.execute(
|
||||
sql, (trackhash, duration, timestamp, source, current_user["id"])
|
||||
sql, (trackhash, duration, timestamp, source, userid)
|
||||
)
|
||||
lastrowid = cur.lastrowid
|
||||
|
||||
@@ -34,7 +36,8 @@ class SQLiteTrackLogger:
|
||||
"""
|
||||
|
||||
with SQLiteManager(userdata_db=True) as cur:
|
||||
sql = f"""SELECT * FROM track_logger WHERE userid = {current_user['id']} ORDER BY timestamp DESC"""
|
||||
userid = get_current_userid()
|
||||
sql = f"""SELECT * FROM track_logger WHERE userid = {userid} ORDER BY timestamp DESC"""
|
||||
|
||||
cur.execute(sql)
|
||||
rows = cur.fetchall()
|
||||
|
||||
@@ -60,7 +60,15 @@ class SQLitePlaylistMethods:
|
||||
@staticmethod
|
||||
def get_all_playlists():
|
||||
with SQLiteManager(userdata_db=True) as cur:
|
||||
cur.execute(f"SELECT * FROM playlists WHERE userid = {current_user['id']}")
|
||||
userid = 1
|
||||
|
||||
try:
|
||||
userid = current_user["id"]
|
||||
except RuntimeError:
|
||||
# Catch this error raised during migration execution
|
||||
pass
|
||||
|
||||
cur.execute(f"SELECT * FROM playlists WHERE userid = {userid}")
|
||||
playlists = cur.fetchall()
|
||||
cur.close()
|
||||
|
||||
@@ -92,7 +100,15 @@ class SQLitePlaylistMethods:
|
||||
Adds a string item to a json dumped list using a playlist id and field name.
|
||||
Takes the playlist ID, a field name, an item to add to the field.
|
||||
"""
|
||||
sql = f"SELECT {field} FROM playlists WHERE id = ? and userid = {current_user['id']}"
|
||||
userid = 1
|
||||
|
||||
try:
|
||||
userid = current_user["id"]
|
||||
except RuntimeError:
|
||||
# Catch this error raised during migration execution
|
||||
pass
|
||||
|
||||
sql = f"SELECT {field} FROM playlists WHERE id = ? and userid = {userid}"
|
||||
|
||||
with SQLiteManager(userdata_db=True) as cur:
|
||||
cur.execute(sql, (playlist_id,))
|
||||
@@ -173,10 +189,17 @@ class SQLitePlaylistMethods:
|
||||
"""
|
||||
|
||||
sql = """UPDATE playlists SET trackhashes = ? WHERE id = ?"""
|
||||
userid = 1
|
||||
|
||||
try:
|
||||
userid = current_user["id"]
|
||||
except RuntimeError:
|
||||
# Catch this error raised during migration execution
|
||||
pass
|
||||
|
||||
with SQLiteManager(userdata_db=True) as cur:
|
||||
cur.execute(
|
||||
f"SELECT trackhashes FROM playlists WHERE id = ? and userid = {current_user['id']}",
|
||||
f"SELECT trackhashes FROM playlists WHERE id = ? and userid = {userid}",
|
||||
(playlistid,),
|
||||
)
|
||||
data = cur.fetchone()
|
||||
@@ -185,17 +208,20 @@ class SQLitePlaylistMethods:
|
||||
return
|
||||
|
||||
trackhashes: list[str] = json.loads(data[0])
|
||||
to_remove = []
|
||||
|
||||
for track in tracks:
|
||||
# {
|
||||
# trackhash: str;
|
||||
# index: int;
|
||||
# }
|
||||
|
||||
index = trackhashes.index(track["trackhash"])
|
||||
|
||||
if index == track["index"]:
|
||||
trackhashes.remove(track["trackhash"])
|
||||
to_remove.append(track["trackhash"])
|
||||
|
||||
for trackhash in to_remove:
|
||||
trackhashes.remove(trackhash)
|
||||
|
||||
cur.execute(sql, (json.dumps(trackhashes), playlistid))
|
||||
|
||||
|
||||
@@ -98,6 +98,20 @@ class SQLiteTrackMethods:
|
||||
return tuple_to_track(row)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_track_by_albumhash(albumhash: str):
|
||||
"""
|
||||
Gets a track using its albumhash. Returns a Track object or None.
|
||||
"""
|
||||
with SQLiteManager() as cur:
|
||||
cur.execute("SELECT * FROM tracks WHERE albumhash=?", (albumhash,))
|
||||
row = cur.fetchone()
|
||||
|
||||
if row is not None:
|
||||
return tuple_to_track(row)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def remove_tracks_by_filepaths(filepaths: str | set[str]):
|
||||
|
||||
Reference in New Issue
Block a user