Files
swingmusic-extended/swingmusic/db/userdata.py
T
2025-06-17 09:47:41 +02:00

753 lines
22 KiB
Python

from dataclasses import asdict
import datetime
import json
from typing import Any, Iterable, Literal
from sqlalchemy import (
JSON,
Boolean,
ForeignKey,
Integer,
String,
and_,
delete,
func,
insert,
select,
update,
)
from sqlalchemy.orm import Mapped, mapped_column
from swingmusic.db.engine import DbEngine
from swingmusic.db.utils import (
favorite_to_dataclass,
favorites_to_dataclass,
playlist_to_dataclass,
plugin_to_dataclass,
similar_artist_to_dataclass,
tracklog_to_dataclass,
user_to_dataclass,
)
from swingmusic.db import Base
from swingmusic.models.mix import Mix
from swingmusic.utils.auth import get_current_userid, hash_password
class UserTable(Base):
    """ORM mapping for the ``user`` table, with helpers to look up and edit users."""

    __tablename__ = "user"

    id: Mapped[int] = mapped_column(primary_key=True)
    image: Mapped[str] = mapped_column(String(), nullable=True)
    password: Mapped[str] = mapped_column(String())
    username: Mapped[str] = mapped_column(String(), index=True)
    roles: Mapped[list[str]] = mapped_column(JSON(), default_factory=list)
    extra: Mapped[dict[str, Any]] = mapped_column(
        JSON(), nullable=True, default_factory=dict
    )

    @classmethod
    def get_all(cls):
        """Lazily yield every user converted to a dataclass."""
        pending = cls.execute(select(cls))
        rows = next(pending).scalars()
        yield from (user_to_dataclass(row) for row in rows)

    @classmethod
    def _insert_seed_user(cls, username: str, password: str, role: str):
        """Insert a user with a hashed ``password`` and the single role ``role``."""
        payload = {
            "username": username,
            "password": hash_password(password),
            "roles": [role],
        }
        return cls.insert_one(payload)

    @classmethod
    def insert_default_user(cls):
        """Create the built-in ``admin`` account (password "admin", role "admin")."""
        return cls._insert_seed_user("admin", "admin", "admin")

    @classmethod
    def insert_guest_user(cls):
        """Create the built-in ``guest`` account (password "guest", role "guest")."""
        return cls._insert_seed_user("guest", "guest", "guest")

    @classmethod
    def get_by_id(cls, id: int):
        """Return the user whose primary key is ``id`` as a dataclass, or None."""
        pending = cls.execute(select(cls).where(cls.id == id))
        row = next(pending).scalar()
        return user_to_dataclass(row) if row else None

    @classmethod
    def get_by_username(cls, username: str):
        """Return the first user named ``username`` as a dataclass, or None."""
        pending = cls.execute(select(cls).where(cls.username == username))
        row = next(pending).scalar()
        return user_to_dataclass(row) if row else None

    @classmethod
    def update_one(cls, user: dict[str, Any]):
        """Write every key of ``user`` onto the row whose id is ``user["id"]``."""
        statement = update(cls).where(cls.id == user["id"]).values(user)
        return next(cls.execute(statement, commit=True))

    @classmethod
    def remove_by_username(cls, username: str):
        """Delete every user row named ``username`` and commit."""
        statement = delete(cls).where(cls.username == username)
        return next(cls.execute(statement, commit=True))
class PluginTable(Base):
    """ORM mapping for the ``plugin`` table: one row per plugin, unique by name."""

    __tablename__ = "plugin"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(), unique=True)
    active: Mapped[bool] = mapped_column(Boolean())
    settings: Mapped[dict[str, Any]] = mapped_column(JSON())
    extra: Mapped[dict[str, Any]] = mapped_column(JSON(), nullable=True)

    @classmethod
    def get_all(cls):
        """Lazily yield every plugin converted to a dataclass."""
        pending = cls.execute(select(cls))
        rows = next(pending).scalars()
        yield from (plugin_to_dataclass(row) for row in rows)

    @classmethod
    def activate(cls, name: str, value: bool):
        """Set the ``active`` flag of plugin ``name`` to ``value`` and commit."""
        statement = update(cls).where(cls.name == name).values(active=value)
        return next(cls.execute(statement, commit=True))

    @classmethod
    def get_by_name(cls, name: str):
        """Return plugin ``name`` as a dataclass, or None when it is not stored."""
        pending = cls.execute(select(cls).where(cls.name == name))
        row = next(pending).scalar()
        return plugin_to_dataclass(row) if row else None

    @classmethod
    def update_settings(cls, name: str, settings: dict[str, Any]):
        """Replace the ``settings`` JSON of plugin ``name`` and commit."""
        statement = update(cls).where(cls.name == name).values(settings=settings)
        return next(cls.execute(statement, commit=True))
class SimilarArtistTable(Base):
    """Similar-artist data keyed by artist hash (table ``notlastfm_similar_artists``)."""

    __tablename__ = "notlastfm_similar_artists"

    id: Mapped[int] = mapped_column(Integer(), primary_key=True)
    # Indexed but not declared unique.
    artisthash: Mapped[str] = mapped_column(String(), index=True)
    similar_artists: Mapped[dict[str, str]] = mapped_column(JSON())

    @classmethod
    def get_all(cls):
        """Lazily yield every entry as a dataclass, streaming 100 rows at a time."""
        result = cls.execute(select(cls).execution_options(yield_per=100))
        for i in next(result).scalars():
            yield similar_artist_to_dataclass(i)

    @classmethod
    def exists(cls, artisthash: str) -> bool:
        """
        Check whether an artisthash exists in the database.

        :param artisthash: The artist hash to look up.
        :return: True if at least one row has this artisthash.
        """
        with DbEngine.manager() as conn:
            # LIMIT 1: existence only needs a single row, not every match.
            result = conn.execute(
                select(cls.artisthash).where(cls.artisthash == artisthash).limit(1)
            )
            return result.first() is not None

    @classmethod
    def get_by_hash(cls, artisthash: str):
        """
        Get a single artist by hash.

        :return: The first matching entry as a dataclass, or None if there is none.
        """
        result = cls.execute(select(cls).where(cls.artisthash == artisthash))
        res = next(result).scalar()
        if res:
            return similar_artist_to_dataclass(res)
        return None
class FavoritesTable(Base):
    """
    Favorited tracks, albums and artists (table ``favorite``).

    ``insert_item`` stores ``hash`` as ``"<type>_<hash>"``. The lookup helpers match
    both that form and the bare hash (presumably for rows written before the
    prefix was introduced).
    """

    __tablename__ = "favorite"

    id: Mapped[int] = mapped_column(primary_key=True)
    # NOTE(review): unique across the whole table, not per user, so two users
    # cannot favorite the same item - confirm this is intended.
    hash: Mapped[str] = mapped_column(String(), unique=True)
    type: Mapped[str] = mapped_column(String(), index=True)
    timestamp: Mapped[int] = mapped_column(Integer(), index=True)
    userid: Mapped[int] = mapped_column(
        Integer(), ForeignKey("user.id", ondelete="cascade"), default=1, index=True
    )
    extra: Mapped[dict[str, Any]] = mapped_column(
        JSON(), nullable=True, default_factory=dict
    )

    @classmethod
    def get_all(cls, with_user: bool = False):
        """
        Yield favorites as dataclasses.

        :param with_user: When True, only yield the current user's favorites.
        """
        with DbEngine.manager() as conn:
            if with_user:
                result = conn.execute(
                    select(cls).where(cls.userid == get_current_userid())
                )
            else:
                result = conn.execute(select(cls))

            for i in result.scalars():
                yield favorite_to_dataclass(i)

    @classmethod
    def insert_item(cls, item: dict[str, Any]):
        """
        Insert a favorite, filling in defaults for missing fields.

        :param item: Must contain ``hash`` and ``type``. ``timestamp`` defaults to
            now and ``userid`` to the current user. The caller's dict is not
            modified.
        :return: The result object yielded by ``cls.execute``.
        """
        # Work on a copy: prefixing the caller's dict in place meant that reusing
        # it (e.g. calling twice) stored a double prefix like "track_track_<hash>".
        item = dict(item)

        # guard against hash collisions for different item types
        item["hash"] = f"{item['type']}_{item['hash']}"

        if item.get("timestamp") is None:
            print("No timestamp found, using current timestamp")
            item["timestamp"] = int(datetime.datetime.now().timestamp())

        if item.get("userid") is None:
            print("No userid found, using current userid")
            item["userid"] = get_current_userid()

        return next(cls.execute(insert(cls).values(item), commit=True))

    @classmethod
    def remove_item(cls, item: dict[str, Any]):
        """Delete the favorite matching ``item`` in either bare or type-prefixed form."""
        return next(
            cls.execute(
                delete(cls).where(
                    (cls.hash == item["hash"])
                    | (cls.hash == f"{item['type']}_{item['hash']}")
                ),
                commit=True,
            )
        )

    @classmethod
    def check_exists(cls, hash: str, type: str):
        """Return True if a favorite with ``hash`` (bare or type-prefixed) exists."""
        result = cls.execute(
            select(cls).where((cls.hash == hash) | (cls.hash == f"{type}_{hash}"))
        )
        return next(result).scalar() is not None

    @classmethod
    def get_by_hash(cls, hash: str, type: str):
        """Return all ORM rows whose hash matches ``hash`` bare or type-prefixed."""
        result = cls.execute(
            select(cls).where((cls.hash == hash) | (cls.hash == f"{type}_{hash}"))
        )
        return next(result).scalars().all()

    @classmethod
    def get_all_of_type(cls, type: str, start: int, limit: int):
        """
        Page through the current user's favorites of ``type``, newest first.

        When ``start`` is 0 every row is fetched so the total can be reported, and
        ``(rows[:limit], total)`` is returned (``limit == -1`` means all rows).
        Otherwise ``(page, -1)`` is returned.
        """
        result = cls.execute(
            select(cls)
            .where(and_(cls.type == type, cls.userid == get_current_userid()))
            .order_by(cls.timestamp.desc())
            .offset(start)
            # INFO: If start is 0, fetch all so we can get the total count.
            # NOTE(review): with start != 0 and limit == -1 this relies on the
            # backend (e.g. SQLite) treating a negative LIMIT as "no limit".
            .limit(limit if start != 0 else None)
        )
        res = next(result).scalars().all()

        if start == 0:
            # if limit == -1, return all
            if limit == -1:
                limit = len(res)

            return res[:limit], len(res)

        return res, -1

    @classmethod
    def get_fav_tracks(cls, start: int, limit: int):
        """Return ``(favorite track dataclasses, total)``, as in ``get_all_of_type``."""
        result, total = cls.get_all_of_type("track", start, limit)
        return favorites_to_dataclass(result), total

    @classmethod
    def get_fav_albums(cls, start: int, limit: int):
        """Return ``(favorite album dataclasses, total)``, as in ``get_all_of_type``."""
        result, total = cls.get_all_of_type("album", start, limit)
        return favorites_to_dataclass(result), total

    @classmethod
    def get_fav_artists(cls, start: int, limit: int):
        """Return ``(favorite artist dataclasses, total)``, as in ``get_all_of_type``."""
        result, total = cls.get_all_of_type("artist", start, limit)
        return favorites_to_dataclass(result), total

    @classmethod
    def count_favs_in_period(cls, start_time: int, end_time: int):
        """Count the current user's favorites with start_time <= timestamp <= end_time."""
        result = cls.execute(
            select(func.count(cls.id))
            .where(cls.userid == get_current_userid())
            .where(and_(cls.timestamp >= start_time, cls.timestamp <= end_time))
        )
        res = next(result).scalar()
        return res or 0

    @classmethod
    def count_tracks(cls):
        """Count favorited tracks across all users."""
        result = cls.execute(select(func.count(cls.id)).where(cls.type == "track"))
        return next(result).scalar()

    @classmethod
    def get_last_trackhash(cls):
        """
        Return the trackhash of the most recently favorited track, or None.

        ``insert_item`` stores track hashes as ``"track_<hash>"``. That prefix is
        stripped here so callers receive the bare trackhash.
        """
        result = cls.execute(
            select(cls.hash)
            .where(cls.type == "track")
            .order_by(cls.timestamp.desc())
            .limit(1)
        )
        trackhash = next(result).scalar()
        if trackhash is None:
            return None

        # NOTE(review): assumes genuine trackhashes never begin with "track_".
        return trackhash.removeprefix("track_")
class ScrobbleTable(Base):
    """Scrobble log (table ``scrobble``); every entry is scoped to a ``userid``."""

    __tablename__ = "scrobble"

    id: Mapped[int] = mapped_column(primary_key=True)
    trackhash: Mapped[str] = mapped_column(String(), index=True)
    duration: Mapped[int] = mapped_column(Integer())
    timestamp: Mapped[int] = mapped_column(Integer())
    source: Mapped[str] = mapped_column(String())
    userid: Mapped[int] = mapped_column(
        Integer(), ForeignKey("user.id", ondelete="cascade"), index=True
    )
    extra: Mapped[dict[str, Any]] = mapped_column(
        JSON(), nullable=True, default_factory=dict
    )

    @classmethod
    def add(cls, item: dict[str, Any]):
        """
        Insert a scrobble.

        If ``item`` has no ``userid``, the current user's id is written into
        ``item`` (in place) before inserting.
        """
        if item.get("userid") is None:
            print("No userid found, using current userid")
            item["userid"] = get_current_userid()

        return cls.insert_one(item)

    @classmethod
    def get_all(cls, start: int, limit: int | None = None, userid: int | None = None):
        """
        Yield a user's scrobbles as dataclasses, newest first.

        :param start: Number of rows to skip.
        :param limit: Maximum number of rows; None means no limit.
        :param userid: User whose scrobbles to read; None means the current user.
        """
        # Explicit None check, matching get_all_in_period, instead of truthiness.
        if userid is None:
            userid = get_current_userid()

        result = cls.execute(
            select(cls)
            .where(cls.userid == userid)
            .order_by(cls.timestamp.desc())
            .offset(start)
            .limit(limit)
            .execution_options(yield_per=100)
        )

        for i in next(result).scalars():
            yield tracklog_to_dataclass(i)

    @classmethod
    def get_all_in_period(
        cls, start_time: int, end_time: int, userid: int | None = None
    ):
        """
        Yield scrobbles with start_time <= timestamp <= end_time, newest first.

        :param userid: User whose scrobbles to read; None means the current user.
        """
        # UserId will be None if function is called from the API
        # In that case, we use the request userid
        if userid is None:
            userid = get_current_userid()

        result = cls.execute(
            select(cls)
            .where(cls.userid == userid)
            .where(and_(cls.timestamp >= start_time, cls.timestamp <= end_time))
            .order_by(cls.timestamp.desc())
            .execution_options(yield_per=100)
        )

        for i in next(result).scalars():
            yield tracklog_to_dataclass(i)

    @classmethod
    def get_last_entry(cls, userid: int):
        """Return the newest scrobble of ``userid`` as a dataclass, or None."""
        result = cls.execute(
            select(cls)
            .where(cls.userid == userid)
            .order_by(cls.timestamp.desc())
            .limit(1)
        )
        res = next(result).scalar()

        if res:
            return tracklog_to_dataclass(res)
        return None
class PlaylistTable(Base):
    """
    User playlists (table ``playlist``).

    Reads and writes other than ``get_all(current_user=False)`` are restricted to
    playlists owned by the current user.
    """

    __tablename__ = "playlist"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(), index=True)
    last_updated: Mapped[int] = mapped_column(Integer())
    image: Mapped[str] = mapped_column(String(), nullable=True)
    userid: Mapped[int] = mapped_column(
        Integer(), ForeignKey("user.id", ondelete="cascade")
    )
    settings: Mapped[dict[str, Any]] = mapped_column(JSON())
    trackhashes: Mapped[list[str]] = mapped_column(JSON(), default_factory=list)
    extra: Mapped[dict[str, Any]] = mapped_column(
        JSON(), nullable=True, default_factory=dict
    )

    @classmethod
    def get_all(cls, current_user: bool = True):
        """
        Yield playlists as dataclasses, streaming 100 rows at a time.

        :param current_user: When True, only the current user's playlists.
        """
        if current_user:
            result = cls.execute(
                select(cls)
                .where(cls.userid == get_current_userid())
                .execution_options(yield_per=100)
            )
        else:
            result = cls.execute(select(cls).execution_options(yield_per=100))

        for i in next(result).scalars():
            yield playlist_to_dataclass(i)

    @classmethod
    def add_one(cls, playlist: dict[str, Any]):
        """
        Insert a playlist owned by the current user.

        ``playlist["userid"]`` is overwritten in place with the current user's id.

        :return: ``lastrowid`` of the insert result (the new playlist's id).
        """
        playlist["userid"] = get_current_userid()
        result = cls.insert_one(playlist)
        return result.lastrowid

    @classmethod
    def check_exists_by_name(cls, name: str):
        """Return True if the current user already has a playlist called ``name``."""
        result = cls.execute(
            select(cls).where((cls.name == name) & (cls.userid == get_current_userid()))
        )
        return next(result).scalar() is not None

    @classmethod
    def append_to_playlist(cls, id: int, trackhashes: list[str]):
        """Append ``trackhashes`` to the end of playlist ``id`` and commit."""
        dbtrackhashes = cls.get_trackhashes(id)

        if not dbtrackhashes:
            dbtrackhashes = []

        return next(
            cls.execute(
                update(cls)
                .where((cls.id == id) & (cls.userid == get_current_userid()))
                .values(trackhashes=dbtrackhashes + trackhashes),
                commit=True,
            )
        )

    @classmethod
    def get_trackhashes(cls, id: int):
        """
        Return the trackhash list of playlist ``id``.

        Returns None when no playlist with that id is owned by the current user.
        """
        result = cls.execute(
            select(cls.trackhashes).where(
                (cls.id == id) & (cls.userid == get_current_userid())
            )
        )
        return next(result).scalar()

    @classmethod
    def remove_from_playlist(cls, id: int, trackhashes: list[dict[str, Any]]):
        """
        Remove specific entries from playlist ``id`` and commit.

        :param trackhashes: Entries to remove. Each is a dict with ``trackhash`` and
            ``index`` keys. ``index`` is treated as the entry's position in the
            stored list before any removal (NOTE(review): confirm this matches
            what the client sends).

        An entry is only removed if the stored list really holds ``trackhash`` at
        ``index``. Entries that don't match are skipped.
        """
        # INFO: Get db trackhashes
        dbtrackhashes = cls.get_trackhashes(id)

        if dbtrackhashes:
            # Resolve every position against the ORIGINAL list and drop by
            # position. The previous list.index()/list.remove() approach only
            # saw first occurrences (so duplicates could not be removed), raised
            # ValueError for missing hashes, and shifted positions after each
            # removal so multi-removals skipped items.
            positions_to_drop: set[int] = set()
            for item in trackhashes:
                position = item["index"]
                if (
                    isinstance(position, int)
                    and 0 <= position < len(dbtrackhashes)
                    and dbtrackhashes[position] == item["trackhash"]
                ):
                    positions_to_drop.add(position)

            dbtrackhashes = [
                trackhash
                for position, trackhash in enumerate(dbtrackhashes)
                if position not in positions_to_drop
            ]

        return next(
            cls.execute(
                update(cls)
                .where((cls.id == id) & (cls.userid == get_current_userid()))
                .values(trackhashes=dbtrackhashes),
                commit=True,
            )
        )

    @classmethod
    def get_by_id(cls, id: int):
        """Return the current user's playlist ``id`` as a dataclass, or None."""
        result = cls.execute(
            select(cls).where((cls.id == id) & (cls.userid == get_current_userid()))
        )
        result = next(result).scalar()

        if result:
            return playlist_to_dataclass(result)
        return None

    @classmethod
    def update_one(cls, id: int, playlist: dict[str, Any]):
        """Write the keys of ``playlist`` onto the current user's playlist ``id``."""
        return next(
            cls.execute(
                update(cls)
                .where((cls.id == id) & (cls.userid == get_current_userid()))
                .values(playlist),
                commit=True,
            )
        )

    @classmethod
    def update_settings(cls, id: int, settings: dict[str, Any]):
        """Replace the settings JSON of the current user's playlist ``id``."""
        return next(
            cls.execute(
                update(cls)
                .where((cls.id == id) & (cls.userid == get_current_userid()))
                .values(settings=settings),
                commit=True,
            )
        )

    @classmethod
    def remove_image(cls, id: int):
        """Clear (set to NULL) the image of the current user's playlist ``id``."""
        return next(
            cls.execute(
                update(cls)
                .where((cls.id == id) & (cls.userid == get_current_userid()))
                .values(image=None),
                commit=True,
            )
        )
class LibDataTable(Base):
    """
    Extra library data for albums and artists (table ``artistdata``).

    ``find_one`` looks rows up by ``itemhash == itemtype + hash``, i.e. the stored
    itemhash carries the item type as a prefix.
    """

    __tablename__ = "artistdata"

    id: Mapped[int] = mapped_column(primary_key=True)
    itemhash: Mapped[str] = mapped_column(String(), unique=True, index=True)
    itemtype: Mapped[str] = mapped_column(String())
    color: Mapped[str] = mapped_column(String(), nullable=True)
    bio: Mapped[str] = mapped_column(String(), nullable=True)
    info: Mapped[dict[str, Any]] = mapped_column(JSON(), nullable=True)
    extra: Mapped[dict[str, Any]] = mapped_column(
        JSON(), nullable=True, default_factory=dict
    )

    @classmethod
    def update_one(cls, hash: str, data: dict[str, Any]):
        """
        Write the keys of ``data`` onto the row whose itemhash equals ``hash``.

        NOTE(review): unlike ``find_one`` no type prefix is added here. Callers
        presumably pass the already-prefixed itemhash; confirm.
        """
        return next(
            cls.execute(
                update(cls).where(cls.itemhash == hash).values(data), commit=True
            )
        )

    @classmethod
    def find_one(cls, hash: str, type: Literal["album", "artist"]):
        """Return the raw ORM row for ``type`` + ``hash``, or None if not stored."""
        result = cls.execute(
            select(cls).where((cls.itemhash == type + hash) & (cls.itemtype == type))
        )
        return next(result).scalar()

    @classmethod
    def get_all_colors(cls, type: str) -> Iterable[dict[str, str]]:
        """Yield ``{"itemhash": <hash without type prefix>, "color": ...}`` per row."""
        result = cls.execute(select(cls).where(cls.itemtype == type))

        for i in next(result).scalars():
            # removeprefix strips only the leading type; replace() also deleted
            # any later occurrence of the type string inside the hash.
            yield {"itemhash": i.itemhash.removeprefix(type), "color": i.color}
class MixTable(Base):
    """Generated mixes (table ``mix``); ``mixid`` stores the Mix dataclass id."""

    __tablename__ = "mix"

    id: Mapped[int] = mapped_column(primary_key=True)
    mixid: Mapped[str] = mapped_column(String(), index=True)
    title: Mapped[str] = mapped_column(String())
    description: Mapped[str] = mapped_column(String())
    timestamp: Mapped[int] = mapped_column(Integer())
    sourcehash: Mapped[str] = mapped_column(String(), unique=True, index=True)
    userid: Mapped[int] = mapped_column(
        Integer(), ForeignKey("user.id", ondelete="cascade"), index=True
    )
    saved: Mapped[bool] = mapped_column(Boolean(), default=False)
    tracks: Mapped[list[str]] = mapped_column(JSON(), default_factory=list)
    extra: Mapped[dict[str, Any]] = mapped_column(
        JSON(), nullable=True, default_factory=dict
    )

    @classmethod
    def get_all(cls, with_userid: bool = False):
        """
        Yield mixes as dataclasses, newest first.

        :param with_userid: When True, only the current user's mixes.
        """
        if with_userid:
            result = cls.execute(
                select(cls)
                .where(cls.userid == get_current_userid())
                .order_by(cls.timestamp.desc())
            )
        else:
            result = cls.execute(select(cls).order_by(cls.timestamp.desc()))

        for i in next(result).scalars():
            yield Mix.mix_to_dataclass(i)

    @classmethod
    def get_by_sourcehash(cls, sourcehash: str):
        """Return the mix built from ``sourcehash`` as a dataclass, or None."""
        result = cls.execute(select(cls).where(cls.sourcehash == sourcehash))
        res = next(result).scalar()

        if res:
            return Mix.mix_to_dataclass(res)
        return None

    @classmethod
    def get_by_mixid(cls, mixid: str):
        """Return the first mix whose ``mixid`` matches as a dataclass, or None."""
        result = cls.execute(select(cls).where(cls.mixid == mixid))
        res = next(result).scalar()

        if res:
            return Mix.mix_to_dataclass(res)
        return None

    @classmethod
    def insert_one(cls, mix: Mix):
        """
        Insert ``mix`` and commit.

        The dataclass ``id`` goes into the ``mixid`` column. It is removed from the
        row values so the table's own primary key is left to the database.
        """
        mixdict = asdict(mix)
        mixdict["mixid"] = mix.id
        del mixdict["id"]

        return next(cls.execute(insert(cls).values(mixdict), commit=True))

    @classmethod
    def update_one(cls, mixid: str, mix: Mix):
        """
        Overwrite a stored mix with ``mix`` and commit.

        The update only applies when the row's mixid equals ``mixid``, its
        sourcehash equals ``mix.sourcehash``, and it belongs to the current user.
        """
        mixdict = asdict(mix)
        mixdict["mixid"] = mix.id
        del mixdict["id"]

        return next(
            cls.execute(
                update(cls)
                .where(
                    and_(
                        cls.mixid == mixid,
                        cls.sourcehash == mix.sourcehash,
                        cls.userid == get_current_userid(),
                    )
                )
                .values(mixdict),
                commit=True,
            )
        )

    @classmethod
    def save_artist_mix(cls, sourcehash: str):
        """
        Toggles the saved status of an artist mix.

        :return: The new ``saved`` value, or False if no mix has this sourcehash.
        """
        mix = cls.get_by_sourcehash(sourcehash)

        if not mix:
            return False

        mix.saved = not mix.saved
        cls.update_one(mix.id, mix)
        return mix.saved

    @classmethod
    def get_saved_track_mixes(cls):
        """
        Yield all mixes whose ``extra["trackmix_saved"]`` JSON value is true.
        """
        # A mapped JSON attribute has no ``.c`` collection (that exists on tables
        # and subqueries), so ``cls.extra.c.trackmix_saved`` raised AttributeError.
        # Index into the JSON document and compare the element as a boolean.
        result = cls.execute(
            select(cls).where(
                cls.extra["trackmix_saved"].as_boolean() == True  # noqa: E712
            )
        )

        for i in next(result).scalars():
            yield Mix.mix_to_dataclass(i)

    @classmethod
    def save_track_mix(cls, sourcehash: str):
        """
        Toggles the property extra.trackmix_saved.

        :return: The new ``trackmix_saved`` value, or False if no mix has this
            sourcehash.
        """
        mix = cls.get_by_sourcehash(sourcehash)

        if not mix:
            return False

        if mix.extra is None:
            # NOTE(review): ``extra`` is a nullable column; guard in case the
            # dataclass carries a NULL through instead of an empty dict.
            mix.extra = {}

        mix.extra["trackmix_saved"] = not mix.extra.get("trackmix_saved", False)
        cls.update_one(mix.id, mix)
        return mix.extra["trackmix_saved"]
class CollectionTable(Base):
    """User collections; every query is scoped to the current user."""

    # INFO: table name was kept as page to avoid breaking existing data
    __tablename__ = "page"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(), index=True)
    userid: Mapped[int] = mapped_column(
        Integer(), ForeignKey("user.id", ondelete="cascade"), index=True
    )
    items: Mapped[list[dict[str, Any]]] = mapped_column(JSON(), default_factory=list)
    extra: Mapped[dict[str, Any]] = mapped_column(
        JSON(), nullable=True, default_factory=dict
    )

    @classmethod
    def to_dict(cls, entry: Any) -> dict[str, Any]:
        """
        Return a plain dict of the attributes loaded on ORM instance ``entry``.

        The instance's ``__dict__`` is copied rather than edited. Deleting
        ``_sa_instance_state`` from the live instance stripped its ORM state and
        raised KeyError if the same instance was converted a second time.
        """
        d = dict(entry.__dict__)
        d.pop("_sa_instance_state", None)
        return d

    @classmethod
    def get_all(cls):
        """Yield the current user's collections as plain dicts."""
        result = cls.execute(select(cls).where(cls.userid == get_current_userid()))

        for i in next(result).scalars():
            yield cls.to_dict(i)

    @classmethod
    def get_by_id(cls, id: int):
        """Return the current user's collection ``id`` as a dict, or None."""
        result = cls.execute(
            select(cls).where(and_(cls.id == id, cls.userid == get_current_userid()))
        )
        res = next(result).scalar()

        if res:
            return cls.to_dict(res)
        return None

    @classmethod
    def delete_by_id(cls, id: int):
        """Delete the current user's collection ``id`` and commit."""
        return next(
            cls.execute(
                delete(cls).where(
                    and_(cls.id == id, cls.userid == get_current_userid())
                ),
                commit=True,
            )
        )

    @classmethod
    def update_items(cls, id: int, items: list[dict[str, Any]]):
        """Replace the ``items`` list of the current user's collection ``id``."""
        return next(
            cls.execute(
                update(cls)
                .where(and_(cls.id == id, cls.userid == get_current_userid()))
                .values(items=items),
                commit=True,
            )
        )

    @classmethod
    def update_one(cls, payload: dict[str, Any]):
        """Write the keys of ``payload`` onto the current user's collection ``payload["id"]``."""
        return next(
            cls.execute(
                update(cls)
                .where(
                    and_(cls.id == payload["id"], cls.userid == get_current_userid())
                )
                .values(payload),
                commit=True,
            )
        )