start: rewrite the database layer using a freaking ORM

+ start ditching in-mem stores
+ move main db table to a new name
+ experiments!
This commit is contained in:
cwilvx
2024-06-24 00:26:47 +03:00
parent c3472a865a
commit c42ec4dcde
27 changed files with 1399 additions and 397 deletions
+232
View File
@@ -0,0 +1,232 @@
import json
from pprint import pprint
from typing import Any, Optional
from sqlalchemy import (
JSON,
Boolean,
Integer,
Row,
String,
Tuple,
create_engine,
insert,
select,
)
from sqlalchemy.orm import (
Mapped,
mapped_column,
DeclarativeBase,
MappedAsDataclass,
sessionmaker,
)
from app.models import Track as TrackModel
from app.models import Album as AlbumModel
from app.utils.remove_duplicates import remove_duplicates
fullpath = "/home/cwilvx/temp/swingmusic/swing.db"
engine = create_engine(f"sqlite+pysqlite:///{fullpath}", echo=False)
def todict(track: Any):
return track._asdict()
def todicts(tracks: list[Any]):
return [todict(track) for track in tracks]
class DbManager:
def __init__(self):
self.engine = create_engine(f"sqlite+pysqlite:///{fullpath}", echo=True)
self.conn = self.engine.connect()
def __enter__(self):
return self.conn.execution_options(preserve_rowcount=True)
def __exit__(self, exc_type, exc_val, exc_tb):
self.conn.commit()
self.conn.close()
class Base(MappedAsDataclass, DeclarativeBase):
@classmethod
def insert_many(cls, items: list[dict[str, Any]]):
"""
Inserts multiple items into the database.
"""
with DbManager() as conn:
conn.execute(insert(cls).values(items))
@classmethod
def insert_one(cls, item: dict[str, Any]):
"""
Inserts a single item into the database.
"""
return cls.insert_many([item])
@classmethod
def get_all(cls):
"""
Returns all the items from the database.
"""
with DbManager() as conn:
result = conn.execute(select(cls))
return result.fetchall()
class ArtistTable(Base):
__tablename__ = "artist"
id: Mapped[int] = mapped_column(primary_key=True)
albumcount: Mapped[int] = mapped_column(Integer())
artisthash: Mapped[str] = mapped_column(String(), unique=True, index=True)
created_date: Mapped[int] = mapped_column(Integer())
date: Mapped[int] = mapped_column(Integer())
duration: Mapped[int] = mapped_column(Integer())
genres: Mapped[str] = mapped_column(JSON())
name: Mapped[str] = mapped_column(String(), index=True)
trackcount: Mapped[int] = mapped_column(Integer())
is_favorite: Mapped[Optional[bool]] = mapped_column(Boolean())
@classmethod
def get_all(cls, start: int, limit: int):
with DbManager() as conn:
result = conn.execute(select(cls).offset(start).limit(limit))
return albums_to_dataclasses(result.fetchall())
class AlbumTable(Base):
__tablename__ = "album"
id: Mapped[int] = mapped_column(primary_key=True)
albumartists: Mapped[list[dict[str, str]]] = mapped_column(JSON(), index=True)
albumhash: Mapped[str] = mapped_column(String(), unique=True, index=True)
base_title: Mapped[str] = mapped_column(String())
color: Mapped[Optional[str]] = mapped_column(String())
created_date: Mapped[int] = mapped_column(Integer())
date: Mapped[int] = mapped_column(Integer())
duration: Mapped[int] = mapped_column(Integer())
genres: Mapped[str] = mapped_column(JSON())
og_title: Mapped[str] = mapped_column(String())
title: Mapped[str] = mapped_column(String())
trackcount: Mapped[int] = mapped_column(Integer())
@classmethod
def get_album_by_albumhash(cls, hash: str):
with DbManager() as conn:
result = conn.execute(
select(AlbumTable).where(AlbumTable.albumhash == hash)
)
album = result.fetchone()
if album:
return album_to_dataclass(album)
@classmethod
def get_all(cls, start: int, limit: int):
with DbManager() as conn:
result = conn.execute(select(AlbumTable).offset(start).limit(limit))
return albums_to_dataclasses(result.fetchall())
@classmethod
def get_albums_by_artisthashes(cls, artisthashes: list[dict[str, str]]):
with DbManager() as conn:
albums: list[AlbumModel] = []
for artist in artisthashes:
result = conn.execute(
# NOTE: The artist dict keys need to in the same order they appear in the db for this to work!
select(AlbumTable).where(AlbumTable.albumartists.contains(artist))
)
albums.extend(albums_to_dataclasses(result.fetchall()))
print(albums)
return albums
@classmethod
def get_albums_by_base_title(cls, base_title: str):
with DbManager() as conn:
result = conn.execute(
select(AlbumTable).where(AlbumTable.base_title == base_title)
)
return albums_to_dataclasses(result.fetchall())
class TrackTable(Base):
__tablename__ = "track"
id: Mapped[int] = mapped_column(init=False, primary_key=True)
album: Mapped[str] = mapped_column(String())
albumartists: Mapped[list[dict[str, str]]] = mapped_column(JSON())
albumhash: Mapped[str] = mapped_column(String(), index=True)
artists: Mapped[list[dict[str, str]]] = mapped_column(JSON(), index=True)
bitrate: Mapped[int] = mapped_column(Integer())
copyright: Mapped[Optional[str]] = mapped_column(String())
date: Mapped[int] = mapped_column(Integer())
disc: Mapped[int] = mapped_column(Integer())
duration: Mapped[int] = mapped_column(Integer())
filepath: Mapped[str] = mapped_column(String(), unique=True)
folder: Mapped[str] = mapped_column(String(), index=True)
genre: Mapped[Optional[list[dict[str, str]]]] = mapped_column(JSON())
last_mod: Mapped[float] = mapped_column(Integer())
og_album: Mapped[str] = mapped_column(String())
og_title: Mapped[str] = mapped_column(String())
title: Mapped[str] = mapped_column(String())
track: Mapped[int] = mapped_column(Integer())
trackhash: Mapped[str] = mapped_column(String(), index=True)
@classmethod
def get_tracks_by_filepaths(cls, filepaths: list[str]):
print(filepaths[0])
with DbManager() as conn:
result = conn.execute(
select(TrackTable).where(TrackTable.filepath.in_(filepaths))
)
return [dict(r) for r in result.mappings().fetchall()]
@classmethod
def count_tracks_containing_paths(cls, paths: list[str]):
results: list[dict[str, int | str]] = []
with DbManager() as conn:
for path in paths:
result = conn.execute(
select(TrackTable).where(TrackTable.filepath.contains(path))
)
results.append({"path": path, "trackcount": result.all().__len__()})
return results
@classmethod
def get_tracks_by_albumhash(cls, albumhash: str):
with DbManager() as conn:
result = conn.execute(
select(TrackTable).where(TrackTable.albumhash == albumhash)
)
tracks = tracks_to_dataclasses(result.fetchall())
return remove_duplicates(tracks, is_album_tracks=True)
# SECTION: HELPER FUNCTIONS
def album_to_dataclass(album: Row[AlbumTable]):
return AlbumModel(**album._asdict())
def albums_to_dataclasses(albums: list[Row[AlbumTable]]):
return [album_to_dataclass(album) for album in albums]
def track_to_dataclass(track: Row[TrackTable]):
return TrackModel(**track._asdict())
def tracks_to_dataclasses(tracks: list[Row[TrackTable]]):
return [track_to_dataclass(track) for track in tracks]
Base().metadata.create_all(engine)
-2
View File
@@ -75,7 +75,6 @@ class SQLiteAuthMethods:
{', '.join([f"{key} = :{key}" for key in keys if key != 'id'])}
WHERE id = :id
"""
print(sql, user)
with SQLiteManager(userdata_db=True) as cur:
cur.execute(sql, user)
@@ -140,7 +139,6 @@ class SQLiteAuthMethods:
Delete a user by username.
"""
sql = "DELETE FROM users WHERE id = ?"
print("deleting user: ", username)
with SQLiteManager(userdata_db=True) as cur:
cur.execute(sql, (3,))
cur.close()
+5 -2
View File
@@ -1,6 +1,7 @@
from flask_jwt_extended import current_user
from app.db.sqlite.utils import SQLiteManager
from app.models.logger import TrackLog as TrackLog
from app.utils.auth import get_current_userid
class SQLiteTrackLogger:
@@ -10,6 +11,7 @@ class SQLiteTrackLogger:
Inserts a track play record into the database
"""
userid = get_current_userid()
with SQLiteManager(userdata_db=True) as cur:
sql = """INSERT OR REPLACE INTO track_logger(
trackhash,
@@ -21,7 +23,7 @@ class SQLiteTrackLogger:
"""
cur.execute(
sql, (trackhash, duration, timestamp, source, current_user["id"])
sql, (trackhash, duration, timestamp, source, userid)
)
lastrowid = cur.lastrowid
@@ -34,7 +36,8 @@ class SQLiteTrackLogger:
"""
with SQLiteManager(userdata_db=True) as cur:
sql = f"""SELECT * FROM track_logger WHERE userid = {current_user['id']} ORDER BY timestamp DESC"""
userid = get_current_userid()
sql = f"""SELECT * FROM track_logger WHERE userid = {userid} ORDER BY timestamp DESC"""
cur.execute(sql)
rows = cur.fetchall()
+31 -5
View File
@@ -60,7 +60,15 @@ class SQLitePlaylistMethods:
@staticmethod
def get_all_playlists():
with SQLiteManager(userdata_db=True) as cur:
cur.execute(f"SELECT * FROM playlists WHERE userid = {current_user['id']}")
userid = 1
try:
userid = current_user["id"]
except RuntimeError:
# Catch this error raised during migration execution
pass
cur.execute(f"SELECT * FROM playlists WHERE userid = {userid}")
playlists = cur.fetchall()
cur.close()
@@ -92,7 +100,15 @@ class SQLitePlaylistMethods:
Adds a string item to a json dumped list using a playlist id and field name.
Takes the playlist ID, a field name, an item to add to the field.
"""
sql = f"SELECT {field} FROM playlists WHERE id = ? and userid = {current_user['id']}"
userid = 1
try:
userid = current_user["id"]
except RuntimeError:
# Catch this error raised during migration execution
pass
sql = f"SELECT {field} FROM playlists WHERE id = ? and userid = {userid}"
with SQLiteManager(userdata_db=True) as cur:
cur.execute(sql, (playlist_id,))
@@ -173,10 +189,17 @@ class SQLitePlaylistMethods:
"""
sql = """UPDATE playlists SET trackhashes = ? WHERE id = ?"""
userid = 1
try:
userid = current_user["id"]
except RuntimeError:
# Catch this error raised during migration execution
pass
with SQLiteManager(userdata_db=True) as cur:
cur.execute(
f"SELECT trackhashes FROM playlists WHERE id = ? and userid = {current_user['id']}",
f"SELECT trackhashes FROM playlists WHERE id = ? and userid = {userid}",
(playlistid,),
)
data = cur.fetchone()
@@ -185,17 +208,20 @@ class SQLitePlaylistMethods:
return
trackhashes: list[str] = json.loads(data[0])
to_remove = []
for track in tracks:
# {
# trackhash: str;
# index: int;
# }
index = trackhashes.index(track["trackhash"])
if index == track["index"]:
trackhashes.remove(track["trackhash"])
to_remove.append(track["trackhash"])
for trackhash in to_remove:
trackhashes.remove(trackhash)
cur.execute(sql, (json.dumps(trackhashes), playlistid))
+14
View File
@@ -98,6 +98,20 @@ class SQLiteTrackMethods:
return tuple_to_track(row)
return None
@staticmethod
def get_track_by_albumhash(albumhash: str):
"""
Gets a track using its albumhash. Returns a Track object or None.
"""
with SQLiteManager() as cur:
cur.execute("SELECT * FROM tracks WHERE albumhash=?", (albumhash,))
row = cur.fetchone()
if row is not None:
return tuple_to_track(row)
return None
@staticmethod
def remove_tracks_by_filepaths(filepaths: str | set[str]):