port search to stores

+ fix favorites
This commit is contained in:
cwilvx
2024-07-27 21:44:33 +03:00
parent 5d32536758
commit b0e904c84f
25 changed files with 428 additions and 666 deletions
+44 -63
View File
@@ -18,6 +18,7 @@ from sqlalchemy import (
from sqlalchemy.orm import Mapped, mapped_column
from app.db.engine import DbEngine
from app.db.utils import (
albums_to_dataclasses,
artists_to_dataclasses,
@@ -27,14 +28,13 @@ from app.db.utils import (
plugin_to_dataclasses,
similar_artist_to_dataclass,
similar_artists_to_dataclass,
tracklog_to_dataclass,
tracklog_to_dataclasses,
tracks_to_dataclasses,
user_to_dataclass,
user_to_dataclasses,
)
from app.db import Base, DbManager
from app.db import Base
from app.utils.auth import get_current_userid, hash_password
@@ -77,7 +77,7 @@ class UserTable(Base):
@classmethod
def get_by_id(cls, id: int):
with DbManager() as conn:
with DbEngine.manager() as conn:
result = conn.execute(select(cls).where(cls.id == id))
res = result.fetchone()
@@ -86,7 +86,7 @@ class UserTable(Base):
@classmethod
def get_by_username(cls, username: str):
with DbManager() as conn:
with DbEngine.manager() as conn:
result = conn.execute(select(cls).where(cls.username == username))
res = result.fetchone()
@@ -95,7 +95,7 @@ class UserTable(Base):
@classmethod
def update_one(cls, user: dict[str, Any]):
with DbManager(commit=True) as conn:
with DbEngine.manager(commit=True) as conn:
conn.execute(update(cls).where(cls.id == user["id"]).values(user))
@classmethod
@@ -126,7 +126,7 @@ class SimilarArtistTable(Base):
@classmethod
def get_all(cls):
with DbManager() as conn:
with DbEngine.manager() as conn:
result = conn.execute(select(cls))
return similar_artists_to_dataclass(result.fetchall())
@@ -136,7 +136,7 @@ class SimilarArtistTable(Base):
Check whether an artisthash exists in the database.
"""
with DbManager() as conn:
with DbEngine.manager() as conn:
result = conn.execute(
select(cls.artisthash).where(cls.artisthash == artisthash)
)
@@ -148,7 +148,7 @@ class SimilarArtistTable(Base):
Get a single artist by hash.
"""
with DbManager() as conn:
with DbEngine.manager() as conn:
result = conn.execute(select(cls).where(cls.artisthash == artisthash))
result = result.fetchone()
@@ -160,7 +160,7 @@ class FavoritesTable(Base):
__tablename__ = "favorite"
id: Mapped[int] = mapped_column(primary_key=True)
hash: Mapped[str] = mapped_column(String())
hash: Mapped[str] = mapped_column(String(), unique=True)
type: Mapped[str] = mapped_column(String(), index=True)
timestamp: Mapped[int] = mapped_column(Integer(), index=True)
userid: Mapped[int] = mapped_column(
@@ -172,7 +172,7 @@ class FavoritesTable(Base):
@classmethod
def get_all(cls):
with DbManager() as conn:
with DbEngine.manager() as conn:
result = conn.execute(select(cls))
return favorites_to_dataclass(result.fetchall())
@@ -181,12 +181,12 @@ class FavoritesTable(Base):
item["timestamp"] = int(datetime.datetime.now().timestamp())
item["userid"] = get_current_userid()
with DbManager(commit=True) as conn:
with DbEngine.manager(commit=True) as conn:
conn.execute(insert(cls).values(item))
@classmethod
def remove_item(cls, item: dict[str, Any]):
with DbManager(commit=True) as conn:
with DbEngine.manager(commit=True) as conn:
conn.execute(
delete(cls).where(
(cls.hash == item["hash"]) & (cls.type == item["type"])
@@ -199,12 +199,13 @@ class FavoritesTable(Base):
return result.fetchone() is not None
@classmethod
def get_all_of_type(cls, table: Any, field: Any, type: str, start: int, limit: int):
def get_all_of_type(cls, type: str, start: int, limit: int):
result = cls.execute(
select(table)
.select_from(join(table, cls, field == cls.hash))
.where(and_(cls.type == type, cls.userid == get_current_userid()))
.offset(start)
select(cls)
# .select_from(join(table, cls, field == cls.hash))
.where(and_(cls.type == type, cls.userid == get_current_userid())).offset(
start
)
# INFO: If start is 0, fetch all so we can get the total count
.limit(limit if start != 0 else None)
)
@@ -218,30 +219,18 @@ class FavoritesTable(Base):
@classmethod
def get_fav_tracks(cls, start: int, limit: int):
from .libdata import TrackTable
result, total = cls.get_all_of_type(
TrackTable, TrackTable.trackhash, "track", start, limit
)
return tracks_to_dataclasses(result), total
result, total = cls.get_all_of_type("track", start, limit)
return favorites_to_dataclass(result), total
@classmethod
def get_fav_albums(cls, start: int, limit: int):
from .libdata import AlbumTable
result, total = cls.get_all_of_type(
AlbumTable, AlbumTable.albumhash, "album", start, limit
)
return albums_to_dataclasses(result), total
result, total = cls.get_all_of_type("album", start, limit)
return favorites_to_dataclass(result), total
@classmethod
def get_fav_artists(cls, start: int, limit: int):
from .libdata import ArtistTable
result, total = cls.get_all_of_type(
ArtistTable, ArtistTable.artisthash, "artist", start, limit
)
return artists_to_dataclasses(result), total
result, total = cls.get_all_of_type("artist", start, limit)
return favorites_to_dataclass(result), total
class ScrobbleTable(Base):
@@ -265,7 +254,7 @@ class ScrobbleTable(Base):
return cls.insert_one(item)
@classmethod
def get_all(cls, start: int, limit: int):
def get_all(cls, start: int, limit: int | None):
result = cls.execute(
select(cls)
.where(cls.userid == get_current_userid())
@@ -325,7 +314,7 @@ class PlaylistTable(Base):
)
@classmethod
def get_trackhashes(cls, id: int) -> list[str]:
def get_trackhashes(cls, id: int):
result = cls.execute(
select(cls.trackhashes).where(
(cls.id == id) & (cls.userid == get_current_userid())
@@ -388,32 +377,24 @@ class PlaylistTable(Base):
)
# class PlaylistTrackTable(Base):
# __tablename__ = "playlisttrack"
class ArtistData(Base):
__tablename__ = "artistdata"
# id: Mapped[int] = mapped_column(primary_key=True)
# trackhash: Mapped[str] = mapped_column(String(), index=True)
# playlistid: Mapped[int] = mapped_column(
# Integer(), ForeignKey("playlist.id", ondelete="cascade")
# )
# index: Mapped[int] = mapped_column(Integer())
# userid: Mapped[int] = mapped_column(
# Integer(), ForeignKey("user.id", ondelete="cascade")
# )
id: Mapped[int] = mapped_column(primary_key=True)
artisthash: Mapped[str] = mapped_column(String(), index=True)
color: Mapped[str] = mapped_column(String(), nullable=True)
bio: Mapped[str] = mapped_column(String(), nullable=True)
info: Mapped[dict[str, Any]] = mapped_column(JSON(), nullable=True)
extra: Mapped[dict[str, Any]] = mapped_column(
JSON(), nullable=True, default_factory=dict
)
# @classmethod
# def count_by_playlist()
@classmethod
def find_one(cls, artisthash: str):
result = cls.execute(select(cls).where(cls.artisthash == artisthash))
return result.fetchone()
# @classmethod
# def insert_many(cls, playlistid: int, trackhashes: list[str]):
# userid = get_current_userid()
# items = [
# {
# "index": index,
# "userid": userid,
# "trackhash": trackhash,
# "playlistid": playlistid,
# }
# for index, trackhash in enumerate(trackhashes)
# ]
# return cls.execute(insert(cls).values(items), commit=True)
@classmethod
def get_all_colors(cls) -> dict[str, str]:
result = cls.execute(select(cls.artisthash, cls.color))
return dict(result.fetchall())