diff --git a/app/api/home/__init__.py b/app/api/home/__init__.py index 4c5a865e..1179c555 100644 --- a/app/api/home/__init__.py +++ b/app/api/home/__init__.py @@ -1,6 +1,6 @@ -from flask_jwt_extended import current_user from flask_openapi3 import Tag from flask_openapi3 import APIBlueprint +from pydantic import BaseModel, Field from app.api.apischemas import GenericLimitSchema from app.lib.home.recentlyadded import get_recently_added_items @@ -27,13 +27,14 @@ def get_recent_plays(query: GenericLimitSchema): return {"items": get_recently_played(query.limit)} +class HomepageItem(BaseModel): + limit: int = Field( + default=9, description="The max number of items per group to return" + ) + + @api.get("/") -def homepage_items(): +def homepage_items(query: HomepageItem): return { - "artist_mixes": { - "title": "Artist mixes for you", - "description": "Based on artists you have been listening to", - "items": HomepageStore.get_artist_mixes(), - "extra": {}, - }, + "artist_mixes": HomepageStore.get_mixes("artist_mixes", limit=query.limit), } diff --git a/app/crons/mixes.py b/app/crons/mixes.py index e897775f..884cfa87 100644 --- a/app/crons/mixes.py +++ b/app/crons/mixes.py @@ -1,4 +1,5 @@ from app.crons.cron import CronJob +from app.lib.recipes import ArtistMixes from app.plugins.mixes import MixesPlugin from app.store.homepage import HomepageStore @@ -16,12 +17,14 @@ class Mixes(CronJob): Creates the artist mixes """ print("⭐⭐⭐⭐ Mixes cron job running") - mixes = MixesPlugin() + ArtistMixes().run() + # mixes = MixesPlugin() - if not mixes.enabled: - return + # if not mixes.enabled: + # return - artist_mixes = mixes.create_artist_mixes() - if artist_mixes: - HomepageStore.set_artist_mixes(artist_mixes) + # artist_mixes = mixes.create_artist_mixes() + + # if artist_mixes: + # HomepageStore.set_artist_mixes(artist_mixes) diff --git a/app/db/libdata.py b/app/db/libdata.py index 6e68d6d7..3d40c73f 100644 --- a/app/db/libdata.py +++ b/app/db/libdata.py @@ -8,92 +8,6 @@ from sqlalchemy.orm import Mapped, mapped_column from typing import Any, Optional -# def create_all(): -# """ -# Create all the tables defined in this file. - -# NOTE: We need this function because the MasterBase does not collect -# the tables defined here (as they are grand-children of the MasterBase) -# """ -# Base.metadata.create_all(DbEngine.engine) - - -# class Base(MasterBase, DeclarativeBase): -# pass -# @classmethod -# def get_all_hashes(cls, create_date: int | None = None): -# with DbEngine.manager() as conn: -# if create_date: -# if cls.__tablename__ == "track": -# stmt = select(TrackTable.trackhash).where( -# cls.last_mod < create_date -# ) -# elif cls.__tablename__ == "album": -# stmt = select(AlbumTable.albumhash).where( -# cls.created_date < create_date -# ) -# elif cls.__tablename__ == "artist": -# stmt = select(ArtistTable.artisthash).where( -# cls.created_date < create_date -# ) -# else: -# if cls.__tablename__ == "track": -# stmt = select(TrackTable.trackhash) -# elif cls.__tablename__ == "album": -# stmt = select(AlbumTable.albumhash) -# elif cls.__tablename__ == "artist": -# stmt = select(ArtistTable.artisthash) - -# result = conn.execute(stmt) -# return {row[0] for row in result.fetchall()} - -# @classmethod -# def set_is_favorite(cls, hash: str, is_favorite: bool): -# """ -# Set the 'is_favorite' flag for a specific hash. - -# Args: -# hash (str): The hash value. -# is_favorite (bool): The value of the 'is_favorite' flag. -# """ -# with DbEngine.manager(commit=True) as conn: -# if cls.__tablename__ == "track": -# stmt = ( -# update(cls) -# .where(TrackTable.trackhash == hash) -# .values(is_favorite=is_favorite) -# ) -# elif cls.__tablename__ == "album": -# stmt = ( -# update(cls) -# .where(AlbumTable.albumhash == hash) -# .values(is_favorite=is_favorite) -# ) -# elif cls.__tablename__ == "artist": -# stmt = ( -# update(cls) -# .where(ArtistTable.artisthash == hash) -# .values(is_favorite=is_favorite) -# ) - -# conn.execute(stmt) - -# @classmethod -# def increment_scrobblecount( -# cls, table: Any, field: Any, hash: str, duration: int, timestamp: int -# ): -# cls.execute( -# update(table) -# .where(field == hash) -# .values( -# playcount=table.playcount + 1, -# playduration=table.playduration + duration, -# lastplayed=timestamp, -# ), -# commit=True, -# ) - - class TrackTable(Base): __tablename__ = "track" @@ -101,7 +15,6 @@ class TrackTable(Base): album: Mapped[str] = mapped_column(String()) albumartists: Mapped[str] = mapped_column(String()) albumhash: Mapped[str] = mapped_column(String(), index=True) - # artisthashes: Mapped[list[str]] = mapped_column(JSON(), index=True) artists: Mapped[str] = mapped_column(String()) bitrate: Mapped[int] = mapped_column(Integer()) copyright: Mapped[Optional[str]] = mapped_column(String()) @@ -110,11 +23,8 @@ class TrackTable(Base): duration: Mapped[int] = mapped_column(Integer()) filepath: Mapped[str] = mapped_column(String(), index=True, unique=True) folder: Mapped[str] = mapped_column(String(), index=True) - # genrehashes: Mapped[list[str]] = mapped_column(JSON(), index=True) genres: Mapped[Optional[str]] = mapped_column(String()) last_mod: Mapped[float] = mapped_column(Integer()) - # og_album: Mapped[str] = mapped_column(String()) - # og_title: Mapped[str] = mapped_column(String()) title: Mapped[str] = mapped_column(String()) track: Mapped[int] = mapped_column(Integer()) trackhash: Mapped[str] = mapped_column(String(), index=True) @@ -141,45 +51,6 @@ class TrackTable(Base): ) return tracks_to_dataclasses(result.fetchall()) - # @classmethod - # def get_tracks_by_albumhash(cls, albumhash: str): - # with DbEngine.manager() as conn: - # result = conn.execute( - # select(TrackTable).where(TrackTable.albumhash == albumhash) - # ) - # tracks = tracks_to_dataclasses(result.fetchall()) - # return remove_duplicates(tracks, is_album_tracks=True) - - # @classmethod - # def get_track_by_trackhash(cls, hash: str, filepath: str = ""): - # with DbEngine.manager() as conn: - # if filepath: - # result = conn.execute( - # select(TrackTable) - # .where( - # (TrackTable.trackhash == hash) - # & (TrackTable.filepath == filepath), - # ) - # .order_by(TrackTable.bitrate.desc()) - # ) - # else: - # result = conn.execute( - # select(TrackTable).where(TrackTable.trackhash == hash) - # ) - - # track = result.fetchone() - - # if track: - # return track_to_dataclass(track) - - # @classmethod - # def get_tracks_by_artisthash(cls, artisthash: str): - # with DbEngine.manager() as conn: - # result = conn.execute( - # select(TrackTable).where(TrackTable.artists.contains(artisthash)) - # ) - # return tracks_to_dataclasses(result.fetchall()) - @classmethod def get_tracks_in_path(cls, path: str): with DbEngine.manager() as conn: @@ -190,230 +61,7 @@ class TrackTable(Base): ) return tracks_to_dataclasses(result.fetchall()) - # @classmethod - # def get_tracks_by_trackhashes(cls, hashes: Iterable[str], limit: int | None = None): - # with DbEngine.manager() as conn: - # result = conn.execute( - # select(TrackTable) - # .where(TrackTable.trackhash.in_(hashes)) - # .group_by(TrackTable.trackhash) - # .limit(limit) - # ) - # tracks = tracks_to_dataclasses(result.fetchall()) - - # # order the tracks in the same order as the hashes - # if type(hashes) == list: - # return sorted(tracks, key=lambda x: hashes.index(x.trackhash)) - - # return tracks - - # @classmethod - # def get_recently_added(cls, start: int, limit: int): - # with DbEngine.manager() as conn: - # result = conn.execute( - # select(TrackTable) - # .order_by(TrackTable.last_mod.desc()) - # .offset(start) - # .limit(limit) - # ) - - # return tracks_to_dataclasses(result.fetchall()) - - @classmethod - # def get_recently_played(cls, limit: int): - # result = cls.execute( - # select(cls) - # .group_by(cls.trackhash) - # .order_by(cls.lastplayed.desc()) - # .limit(limit) - # ) - # return tracks_to_dataclasses(result.fetchall()) - @classmethod def remove_tracks_by_filepaths(cls, filepaths: set[str]): with DbEngine.manager(commit=True) as conn: conn.execute(delete(TrackTable).where(TrackTable.filepath.in_(filepaths))) - - # @classmethod - # def increment_playcount(cls, trackhash: str, duration: int, timestamp: int): - # cls.increment_scrobblecount( - # TrackTable, TrackTable.trackhash, trackhash, duration, timestamp - # ) - - # @classmethod - # def update_artist_separators(cls, separators: set[str]): - # tracks = cls.get_all() - - # with DbEngine.manager(commit=True) as conn: - # for track in tracks: - # track.split_artists(separators) - # conn.execute( - # update(cls) - # .where(cls.trackhash == track.trackhash) - # .values(artists=track.artists, artisthashes=track.artisthashes) - # ) - - -# class AlbumTable(Base): -# __tablename__ = "album" - -# id: Mapped[int] = mapped_column(primary_key=True) -# albumartists: Mapped[list[dict[str, str]]] = mapped_column(JSON(), index=True) -# artisthashes: Mapped[list[str]] = mapped_column(JSON(), index=True) -# albumhash: Mapped[str] = mapped_column(String(), unique=True, index=True) -# base_title: Mapped[str] = mapped_column(String()) -# color: Mapped[Optional[str]] = mapped_column(String()) -# created_date: Mapped[int] = mapped_column(Integer()) -# date: Mapped[int] = mapped_column(Integer()) -# duration: Mapped[int] = mapped_column(Integer()) -# genrehashes: Mapped[list[str]] = mapped_column(JSON(), nullable=True, index=True) -# genres: Mapped[str] = mapped_column(JSON()) -# og_title: Mapped[str] = mapped_column(String()) -# title: Mapped[str] = mapped_column(String()) -# trackcount: Mapped[int] = mapped_column(Integer()) -# lastplayed: Mapped[int] = mapped_column(Integer(), default=0) -# playcount: Mapped[int] = mapped_column(Integer(), default=0) -# playduration: Mapped[int] = mapped_column(Integer(), default=0) -# extra: Mapped[Optional[dict[str, Any]]] = mapped_column( -# JSON(), default_factory=dict -# ) - -# @classmethod -# def get_all(cls): -# with DbEngine.manager() as conn: -# result = conn.execute(select(AlbumTable)) -# all = result.fetchall() -# return albums_to_dataclasses(all) - -# @classmethod -# def get_album_by_albumhash(cls, hash: str): -# with DbEngine.manager() as conn: -# result = conn.execute( -# select(AlbumTable).where(AlbumTable.albumhash == hash) -# ) -# album = result.fetchone() - -# if album: -# return album_to_dataclass(album) - -# @classmethod -# def get_albums_by_albumhashes(cls, hashes: Iterable[str], limit: int | None = None): -# with DbEngine.manager() as conn: -# result = conn.execute( -# select(AlbumTable).where(AlbumTable.albumhash.in_(hashes)).limit(limit) -# ) -# albums = albums_to_dataclasses(result.fetchall()) - -# # order the albums in the same order as the hashes -# if type(hashes) == list: -# return sorted(albums, key=lambda x: hashes.index(x.albumhash)) - -# return albums - -# @classmethod -# def get_albums_by_artisthashes(cls, artisthashes: list[str]): -# with DbEngine.manager() as conn: -# albums: dict[str, list[AlbumModel]] = {} - -# for artist in artisthashes: -# result = conn.execute( -# select(AlbumTable).where(AlbumTable.artisthashes.contains(artist)) -# ) -# albums[artist] = albums_to_dataclasses(result.fetchall()) - -# return albums - -# @classmethod -# def get_albums_by_base_title(cls, base_title: str): -# with DbEngine.manager() as conn: -# result = conn.execute( -# select(AlbumTable).where(AlbumTable.base_title == base_title) -# ) -# return albums_to_dataclasses(result.fetchall()) - -# @classmethod -# def get_albums_by_artisthash(cls, artisthash: str): -# with DbEngine.manager() as conn: -# result = conn.execute( -# select(AlbumTable).where(AlbumTable.artisthashes.contains(artisthash)) -# ) -# return albums_to_dataclasses(result.all()) - -# @classmethod -# def increment_playcount(cls, albumhash: str, duration: int, timestamp: int): -# return cls.increment_scrobblecount( -# AlbumTable, AlbumTable.albumhash, albumhash, duration, timestamp -# ) - - -# class ArtistTable(Base): -# __tablename__ = "artist" - -# id: Mapped[int] = mapped_column(primary_key=True) -# albumcount: Mapped[int] = mapped_column(Integer()) -# artisthash: Mapped[str] = mapped_column(String(), unique=True, index=True) -# created_date: Mapped[int] = mapped_column(Integer()) -# date: Mapped[int] = mapped_column(Integer()) -# duration: Mapped[int] = mapped_column(Integer()) -# genrehashes: Mapped[list[str]] = mapped_column(JSON(), nullable=True, index=True) -# genres: Mapped[str] = mapped_column(JSON()) -# name: Mapped[str] = mapped_column(String(), index=True) -# trackcount: Mapped[int] = mapped_column(Integer()) -# lastplayed: Mapped[int] = mapped_column(Integer(), default=0) -# playcount: Mapped[int] = mapped_column(Integer(), default=0) -# playduration: Mapped[int] = mapped_column(Integer(), default=0) -# extra: Mapped[Optional[dict[str, Any]]] = mapped_column( -# JSON(), default_factory=dict -# ) - -# @classmethod -# def get_all(cls): -# with DbEngine.manager() as conn: -# result = conn.execute(select(cls)) -# all = result.fetchall() -# return artists_to_dataclasses(all) - -# @classmethod -# def get_artist_by_hash(cls, artisthash: str): -# with DbEngine.manager() as conn: -# result = conn.execute( -# select(ArtistTable).where(ArtistTable.artisthash == artisthash) -# ) -# return artist_to_dataclass(result.fetchone()) - -# @classmethod -# def get_artisthashes_not_in(cls, artisthashes: list[str]): -# with DbEngine.manager() as conn: -# result = conn.execute( -# select(ArtistTable.artisthash, ArtistTable.name).where( -# ~ArtistTable.artisthash.in_(artisthashes) -# ) -# ) -# return [{"artisthash": row[0], "name": row[1]} for row in result.fetchall()] - -# @classmethod -# def get_artists_by_artisthashes( -# cls, hashes: Iterable[str], limit: int | None = None -# ): -# with DbEngine.manager() as conn: -# result = conn.execute( -# select(ArtistTable) -# .where(ArtistTable.artisthash.in_(hashes)) -# .limit(limit) -# ) -# return artists_to_dataclasses(result.fetchall()) - -# @classmethod -# def increment_playcount( -# cls, artisthashes: list[str], duration: int, timestamp: int -# ): -# cls.execute( -# update(cls) -# .where(ArtistTable.artisthash.in_(artisthashes)) -# .values( -# playcount=ArtistTable.playcount + 1, -# playduration=ArtistTable.playduration + duration, -# lastplayed=timestamp, -# ), -# commit=True, -# ) diff --git a/app/db/userdata.py b/app/db/userdata.py index a4a3f48c..2ec438d8 100644 --- a/app/db/userdata.py +++ b/app/db/userdata.py @@ -1,3 +1,4 @@ +from dataclasses import asdict import datetime from typing import Any, Literal from sqlalchemy import ( @@ -31,6 +32,7 @@ from app.db.utils import ( ) from app.db import Base +from app.models.mix import Mix from app.utils.auth import get_current_userid, hash_password @@ -223,9 +225,7 @@ class FavoritesTable(Base): # .select_from(join(table, cls, field == cls.hash)) .where(and_(cls.type == type, cls.userid == get_current_userid())) .order_by(cls.timestamp.desc()) - .offset( - start - ) + .offset(start) # INFO: If start is 0, fetch all so we can get the total count .limit(limit if start != 0 else None) ) @@ -305,10 +305,15 @@ class ScrobbleTable(Base): return tracklog_to_dataclasses(result.fetchall()) @classmethod - def get_all_in_period(cls, start_time: int, end_time: int): + def get_all_in_period(cls, start_time: int, end_time: int, userid: int | None): + # UserId will be None if function is called from the API + # In that case, we use the request userid + if userid is None: + userid = get_current_userid() + result = cls.execute( select(cls) - .where(cls.userid == get_current_userid()) + .where(cls.userid == userid) .where(and_(cls.timestamp >= start_time, cls.timestamp <= end_time)) .order_by(cls.timestamp.desc()) ) @@ -458,3 +463,36 @@ class LibDataTable(Base): select(cls.itemhash, cls.color).where(cls.itemtype == type) ) return [{"itemhash": r[0], "color": r[1]} for r in result.fetchall()] + + +class MixTable(Base): + __tablename__ = "mix" + + id: Mapped[int] = mapped_column(primary_key=True) + mixid: Mapped[str] = mapped_column(String(), index=True) + title: Mapped[str] = mapped_column(String()) + description: Mapped[str] = mapped_column(String()) + timestamp: Mapped[int] = mapped_column(Integer()) + sourcehash: Mapped[str] = mapped_column(String(), unique=True, index=True) + tracks: Mapped[list[str]] = mapped_column(JSON(), default_factory=list) + extra: Mapped[dict[str, Any]] = mapped_column( + JSON(), nullable=True, default_factory=dict + ) + + @classmethod + def get_all(cls): + result = cls.execute(select(cls)) + return Mix.mixes_to_dataclasses(result.fetchall()) + + @classmethod + def get_by_sourcehash(cls, sourcehash: str): + result = cls.execute(select(cls).where(cls.sourcehash == sourcehash)) + return Mix.mix_to_dataclass(result.fetchone()) + + @classmethod + def insert_one(cls, mix: Mix): + mixdict = asdict(mix) + mixdict["mixid"] = mix.id + del mixdict["id"] + + return cls.execute(insert(cls).values(mixdict), commit=True) diff --git a/app/lib/playlistlib.py b/app/lib/playlistlib.py index 4757b323..e0767adc 100644 --- a/app/lib/playlistlib.py +++ b/app/lib/playlistlib.py @@ -10,7 +10,6 @@ from typing import Any from PIL import Image, ImageSequence from app import settings -from app.db.libdata import TrackTable from app.models.track import Track from app.store.albums import AlbumStore from app.store.tracks import TrackStore diff --git a/app/lib/recipes/__init__.py b/app/lib/recipes/__init__.py new file mode 100644 index 00000000..9065eb2e --- /dev/null +++ b/app/lib/recipes/__init__.py @@ -0,0 +1,66 @@ +""" +Recipes are a way to create mixes. +""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, List + +from app.db.userdata import UserTable +from app.models.mix import Mix +from app.plugins.mixes import MixesPlugin +from app.store.homepage import HomepageStore + + +class HomepageRoutine(ABC): + """ + A routine creates a row of homepage items. + """ + + title: str + description: str + + items: List[Mix] + extra: Dict[str, Any] + + @property + @abstractmethod + def is_valid(self) -> bool: ... + + def __init__(self) -> None: + if not self.is_valid: + return + + self.items = self.run() + + @abstractmethod + def run(self) -> List[Mix]: + """ + Creates the homepage items and saves them to the + homepage store if self.is_valid is true. + """ + ... + + +class ArtistMixes(HomepageRoutine): + items: List[Mix] = [] + extra: Dict[str, Any] = {} + store_key = "artist_mixes" + + @property + def is_valid(self): + return MixesPlugin().enabled + + def run(self): + users = UserTable.get_all() + + for user in users: + mix = MixesPlugin() + mixes = mix.create_artist_mixes(user.id) + + if not mixes: + continue + + HomepageStore.set_mixes(mixes, mixkey=self.store_key, userid=user.id) + + def __init__(self) -> None: + super().__init__() diff --git a/app/models/mix.py b/app/models/mix.py index 0eb4efd2..5255b5cd 100644 --- a/app/models/mix.py +++ b/app/models/mix.py @@ -1,5 +1,6 @@ -from dataclasses import asdict, dataclass, field import time +from dataclasses import asdict, dataclass, field +from typing import Any from app.lib.playlistlib import get_first_4_images from app.serializers.track import serialize_tracks @@ -13,12 +14,17 @@ class Mix: title: str description: str tracks: list[str] + sourcehash: str + """ + A hash of the tracks used to generate the mix. + """ timestamp: int = field(default_factory=lambda: int(time.time())) extra: dict = field(default_factory=dict) saved: bool = False def to_full_dict(self): + # Limit track mix to 30 tracks tracks = TrackStore.get_tracks_by_trackhashes(self.tracks) serialized_tracks = serialize_tracks(tracks) @@ -34,7 +40,19 @@ class Mix: return _dict def to_dict(self): - item = self.to_full_dict() + item = asdict(self) del item["tracks"] return item + + @classmethod + def mix_to_dataclass(cls, entry: Any): + entry_dict = entry._asdict() + entry_dict["id"] = entry_dict["mixid"] + del entry_dict["mixid"] + + return Mix(**entry_dict) + + @classmethod + def mixes_to_dataclasses(cls, entries: Any): + return [cls.mix_to_dataclass(entry) for entry in entries] diff --git a/app/plugins/mixes.py b/app/plugins/mixes.py index 03e37121..d0b57e38 100644 --- a/app/plugins/mixes.py +++ b/app/plugins/mixes.py @@ -1,5 +1,6 @@ import datetime import json +from pprint import pprint import random import string import time @@ -18,6 +19,7 @@ from app.store.albums import AlbumStore from app.store.artists import ArtistStore from app.store.tracks import TrackStore from app.utils.dates import get_date_range, get_duration_ago +from app.utils.hashing import create_hash from app.utils.mixes import balance_mix from app.utils.remove_duplicates import remove_duplicates from app.utils.stats import get_artists_in_period @@ -25,7 +27,7 @@ from app.utils.stats import get_artists_in_period class MixesPlugin(Plugin): MAX_TRACKS_TO_FETCH = 5 - TRACK_MIX_LENGTH = 50 + TRACK_MIX_LENGTH = 30 MIN_TRACK_MIX_LENGTH = 15 MIN_DAY_LISTEN_DURATION = 3 * 60 # 3 minutes @@ -60,8 +62,6 @@ class MixesPlugin(Plugin): :param with_help: Whether to include the help flag in the query. The flag tells the server to find more data using other tracks from the same album. - - """ queries = [ { @@ -84,8 +84,6 @@ class MixesPlugin(Plugin): print("Failed to decode JSON response from recommendation server") return [] - # artisthashes = results["artists"] - # albumhashes = results["albums"] trackhashes: list[str] = results["tracks"] trackmatches = TrackStore.get_flat_list() @@ -104,13 +102,18 @@ class MixesPlugin(Plugin): # sort by trackhash order trackmatches = sorted(trackmatches, key=lambda x: trackhashes.index(x.weakhash)) - if len(trackmatches) < self.TRACK_MIX_LENGTH: + # if the mix is short, try to fill it up with tracks + # from album and artist data from the cloud! + + # Create as many filler tracks as possible + # Then the mix length will be controlled in the Mix model + # if len(trackmatches) < self.TRACK_MIX_LENGTH: + if True: filler_tracks = self.fallback_create_artist_mix( - artist=tracks[0].artists[0], similar_artists=results["artists"], similar_albums=results["albums"], omit_trackhashes={t.weakhash for t in trackmatches}, - limit=self.TRACK_MIX_LENGTH - len(trackmatches), + # limit=self.TRACK_MIX_LENGTH - len(trackmatches), ) trackmatches.extend(filler_tracks) @@ -123,15 +126,26 @@ class MixesPlugin(Plugin): """ Given an artisthash, creates an artist mix using the self.MAX_TRACKS_TO_FETCH most listened to tracks. + + Returns a tuple of the mix and the sourcehash. """ artist = ArtistStore.artistmap[artisthash] tracks = TrackStore.get_tracks_by_trackhashes(artist.trackhashes) tracks = sorted(tracks, key=lambda x: x.playduration, reverse=True) - return self.get_track_mix(tracks[: self.MAX_TRACKS_TO_FETCH]) + sourcetracks = tracks[: self.MAX_TRACKS_TO_FETCH] + sourcehash = create_hash(*[t.trackhash for t in sourcetracks]) + + # TODO: Check if we already have this mix in the + # database and return that instead + + return (self.get_track_mix(tracks[: self.MAX_TRACKS_TO_FETCH]), sourcehash) @plugin_method - def create_artist_mixes(self): + def create_artist_mixes(self, userid: int): + """ + Creates artist mixes for a given userid. + """ mixes: list[Mix] = [] indexed = set() @@ -143,22 +157,28 @@ class MixesPlugin(Plugin): artists = { "today": { "max": 3, - "artists": get_artists_in_period(today_start, today_end), + "artists": get_artists_in_period(today_start, today_end, userid), "created": 0, }, "last_2_days": { "max": 2, - "artists": get_artists_in_period(last_2_days_start, time.time()), + "artists": get_artists_in_period( + last_2_days_start, time.time(), userid + ), "created": 0, }, "last_7_days": { "max": 3, - "artists": get_artists_in_period(last_7_days_start, time.time()), + "artists": get_artists_in_period( + last_7_days_start, time.time(), userid + ), "created": 0, }, "last_1_month": { "max": 2, - "artists": get_artists_in_period(last_1_month_start, time.time()), + "artists": get_artists_in_period( + last_1_month_start, time.time(), userid + ), "created": 0, }, } @@ -224,7 +244,7 @@ class MixesPlugin(Plugin): if not _artist: return None - mix_tracks = self.get_artist_mix(artist["artisthash"]) + mix_tracks, sourcehash = self.get_artist_mix(artist["artisthash"]) if len(mix_tracks) < self.MIN_TRACK_MIX_LENGTH: return None @@ -243,11 +263,13 @@ class MixesPlugin(Plugin): title=artist["artist"] + " Radio", description=self.get_mix_description(mix_tracks, artist["artisthash"]), tracks=[t.trackhash for t in mix_tracks], + sourcehash=sourcehash, extra={ "type": "artist", "artisthash": artist["artisthash"], "image": mix_image, }, + timestamp=int(time.time()), ) def download_artist_image(self, artist: Artist): @@ -281,23 +303,23 @@ class MixesPlugin(Plugin): def fallback_create_artist_mix( self, - artist: dict[str, str], + # artist: dict[str, str], similar_albums: list[str], similar_artists: list[str], omit_trackhashes: set[str], - limit: int, + limit: int = 99, ): """ - Creates an artist mix by selecting random tracks from similar artists. + Creates an artist mix by selecting random tracks from similar albums and artists. This is used when: - The Swing Music recommendation server is down. - The artist has less than self.MIN_TRACK_MIX_LENGTH tracks from the cloud mix. - When we need to dilute the mix to balance the artist distribution. - :param artist: The artist to create a mix for. - :param similar_artists: A list of similar artists to select tracks from. If not provided, we try reading from the local database. When we exhaust the passed list, we also try reading from the local database. - :param trackhashes: A set of trackhashes to omit from the new tracklist. + :param similar_albums: A list of similar album weakhashes to select tracks from. + :param similar_artists: A list of similar artist hashes to select tracks from. + :param omit_trackhashes: A set of trackhashes to omit from the new tracklist. :param limit: The maximum number of tracks to select. """ @@ -356,20 +378,5 @@ class MixesPlugin(Plugin): return mixtracks - # if len(similar_artists) == 0: - # local_similar_artists = SimilarArtistTable.get_by_hash(artist["artisthash"]) - - # if local_similar_artists: - # artists = [a.artisthash for a in local_similar_artists.similar_artists] - - # if len(artists) == 0: - # return [] - - # CHECKPOINT: I'M TIRED AF AND I NEED TO SLEEP - # The plan: - # Figure out which artists we should skip for the new tracklist - # these would be artists with a large number of tracks in the mix already - - # Since the artisthashes are ordered by similarity score, we iterate from the start - # and go forward collecting tracks that aren't in the mix yet. - # + def get_mix_from_lastfm_data(self, artisthash: str, limit: int): + pass diff --git a/app/store/albums.py b/app/store/albums.py index ee4572ac..ed1860f1 100644 --- a/app/store/albums.py +++ b/app/store/albums.py @@ -1,20 +1,14 @@ -from itertools import groupby -import json -from pprint import pprint import random from typing import Iterable from app.lib.tagger import create_albums from app.models import Album, Track from app.store.artists import ArtistStore -from app.utils import flatten from app.utils.auth import get_current_userid from app.utils.customlist import CustomList -from app.utils.remove_duplicates import remove_duplicates from ..utils.hashing import create_hash from .tracks import TrackStore -from app.utils.progressbar import tqdm ALBUM_LOAD_KEY = "" diff --git a/app/store/homepage.py b/app/store/homepage.py index 49fbf032..02aa9ed3 100644 --- a/app/store/homepage.py +++ b/app/store/homepage.py @@ -1,30 +1,102 @@ +from abc import ABC +from dataclasses import dataclass +from typing import Any from app.models.mix import Mix -from app.store.tracks import TrackStore from app.utils.auth import get_current_userid +@dataclass +class HomepageEntry(ABC): + """ + Base class for all homepage entries. + + items is a dict of userid to a dict of stuff. + """ + + title: str + description: str + items: dict[int, dict[str, Any]] + + def __init__(self, title: str, description: str): + self.title = title + self.description = description + + def get_items(self, userid: int): + """ + Return usable items for the homepage. + """ + ... + + +@dataclass +class MixHomepageEntry(HomepageEntry): + """ + A homepage entry for mixes. + self.items is a dict of userid to a dict of mixid to mix. + """ + + items: dict[int, dict[str, Mix]] + + def __init__(self, title: str, description: str): + super().__init__(title, description) + self.items = {} + + def get_items(self, userid: int, limit: int | None = None): + items = [] + + for mix in self.items.get(userid, {}).values(): + if limit and len(items) >= limit: + break + + items.append( + { + "type": "mix", + "item": mix.to_dict(), + } + ) + + return { + "title": self.title, + "description": self.description, + "items": items, + } + + class HomepageStore: + """ + Stores the homepage items. + """ + entries = { - "artist_mixes": {}, + "artist_mixes": MixHomepageEntry( + title="Artist mixes for you", + description="Based on artists you have been listening to", + ), } @classmethod - def set_artist_mixes(cls, mixes: list[Mix], userid: int = 1): + def set_mixes(cls, mixes: list[Mix], mixkey: str, userid: int | None = None): idmap = {mix.id[1:]: mix for mix in mixes} - cls.entries["artist_mixes"][userid] = idmap + cls.entries[mixkey].items[userid or get_current_userid()] = idmap @classmethod - def get_artist_mixes(cls): - return [ - { - "type": "mix", - "item": mix.to_dict(), - } - for mix in cls.entries["artist_mixes"] - .get(get_current_userid(), {}) - .values() - ] + def get_mixes(cls, mixkey: str, limit: int | None = 9): + return cls.entries[mixkey].get_items(get_current_userid(), limit) @classmethod - def get_mix(cls, mixtype: str, mixid: str): - return cls.entries[mixtype].get(get_current_userid(), {}).get(mixid).to_full_dict() + def get_mix(cls, mixkey: str, mixid: str): + mix = cls.entries[mixkey].items.get(get_current_userid(), {}).get(mixid) + return mix.to_full_dict() if mix else None + + @classmethod + def get_mix_by_sourcehash(cls, sourcehash: str): + return next( + ( + mix + for mix in cls.entries["artist_mixes"] + .items.get(get_current_userid(), {}) + .values() + if mix.sourcehash == sourcehash + ), + None, + ) diff --git a/app/utils/stats.py b/app/utils/stats.py index 86df3eb4..0ed38eb1 100644 --- a/app/utils/stats.py +++ b/app/utils/stats.py @@ -11,8 +11,10 @@ from app.store.tracks import TrackStore from app.utils.dates import seconds_to_time_string -def get_artists_in_period(start_time: int | float, end_time: int | float): - scrobbles = ScrobbleTable.get_all_in_period(start_time, end_time) +def get_artists_in_period( + start_time: int | float, end_time: int | float, userid: int | None = None +): + scrobbles = ScrobbleTable.get_all_in_period(start_time, end_time, userid) artists: Any = defaultdict(lambda: {"playcount": 0, "playduration": 0}) for scrobble in scrobbles: @@ -33,8 +35,8 @@ def get_artists_in_period(start_time: int | float, end_time: int | float): return sorted(artists, key=lambda x: x["playduration"], reverse=True) -def get_albums_in_period(start_time: int, end_time: int): - scrobbles = ScrobbleTable.get_all_in_period(start_time, end_time) +def get_albums_in_period(start_time: int, end_time: int, userid: int | None = None): + scrobbles = ScrobbleTable.get_all_in_period(start_time, end_time, userid) albums: dict[str, Album] = {} for scrobble in scrobbles: @@ -60,8 +62,8 @@ def get_albums_in_period(start_time: int, end_time: int): return list(albums.values()) -def get_tracks_in_period(start_time: int, end_time: int): - scrobbles = ScrobbleTable.get_all_in_period(start_time, end_time) +def get_tracks_in_period(start_time: int, end_time: int, userid: int | None = None): + scrobbles = ScrobbleTable.get_all_in_period(start_time, end_time, userid) tracks: dict[str, Track] = {} duration = 0 @@ -160,12 +162,14 @@ def calculate_scrobble_trend(current_scrobbles: int, previous_scrobbles: int) -> ) -def calculate_new_artists(current_artists: List[dict[str, Any]], timestamp: int): +def calculate_new_artists( + current_artists: List[dict[str, Any]], timestamp: int, userid: int | None = None +): """ Calculate the number of new artists based on the current and all previous scrobbles. """ current_artists_set = set(artist["artisthash"] for artist in current_artists) - all_records = ScrobbleTable.get_all_in_period(0, timestamp) + all_records = ScrobbleTable.get_all_in_period(0, timestamp, userid) trackhashes = set(record.trackhash for record in all_records) previous_artists_set = set()