diff --git a/app/api/plugins/mixes.py b/app/api/plugins/mixes.py index e513ac6b..f79e3762 100644 --- a/app/api/plugins/mixes.py +++ b/app/api/plugins/mixes.py @@ -31,7 +31,7 @@ def get_track_mix(): @api.post("/artist") def get_artist_mix(): mixes = MixesPlugin() - return mixes.create_artist_mixes() + # return mixes.create_artist_mixes() # tracks = mixes.get_artist_mix("09306be8039b98ad") # return { diff --git a/app/crons/mixes.py b/app/crons/mixes.py index db8c2773..08274896 100644 --- a/app/crons/mixes.py +++ b/app/crons/mixes.py @@ -8,7 +8,7 @@ class Mixes(CronJob): """ name: str = "mixes" - hours: int = 1 + hours: int = 6 def __init__(self): super().__init__() diff --git a/app/db/userdata.py b/app/db/userdata.py index f529413d..8bb35d31 100644 --- a/app/db/userdata.py +++ b/app/db/userdata.py @@ -483,6 +483,7 @@ class MixTable(Base): timestamp: Mapped[int] = mapped_column(Integer()) sourcehash: Mapped[str] = mapped_column(String(), unique=True, index=True) tracks: Mapped[list[str]] = mapped_column(JSON(), default_factory=list) + saved: Mapped[bool] = mapped_column(Boolean(), default=False) extra: Mapped[dict[str, Any]] = mapped_column( JSON(), nullable=True, default_factory=dict ) @@ -495,7 +496,10 @@ class MixTable(Base): @classmethod def get_by_sourcehash(cls, sourcehash: str): result = cls.execute(select(cls).where(cls.sourcehash == sourcehash)) - return Mix.mix_to_dataclass(result.fetchone()) + + res = result.fetchone() + if res: + return Mix.mix_to_dataclass(res) @classmethod def insert_one(cls, mix: Mix): diff --git a/app/models/mix.py b/app/models/mix.py index 5255b5cd..091d4616 100644 --- a/app/models/mix.py +++ b/app/models/mix.py @@ -24,8 +24,7 @@ class Mix: saved: bool = False def to_full_dict(self): - # Limit track mix to 30 tracks - tracks = TrackStore.get_tracks_by_trackhashes(self.tracks) + tracks = TrackStore.get_tracks_by_trackhashes(self.tracks)[:40] serialized_tracks = serialize_tracks(tracks) _dict = asdict(self) @@ -37,12 +36,18 @@ class Mix: _dict["duration"] = seconds_to_time_string(sum(t.duration for t in tracks)) _dict["trackcount"] = len(tracks) + del _dict["extra"]["albums"] + del _dict["extra"]["artists"] + return _dict def to_dict(self): item = asdict(self) del item["tracks"] + del item["extra"]["albums"] + del item["extra"]["artists"] + return item @classmethod diff --git a/app/plugins/mixes.py b/app/plugins/mixes.py index d0b57e38..69d60e03 100644 --- a/app/plugins/mixes.py +++ b/app/plugins/mixes.py @@ -8,7 +8,7 @@ import requests from urllib.parse import quote from PIL import Image -from app.db.userdata import SimilarArtistTable +from app.db.userdata import MixTable, SimilarArtistTable from app.lib.colorlib import get_image_colors from app.models.artist import Artist from app.models.mix import Mix @@ -25,6 +25,14 @@ from app.utils.remove_duplicates import remove_duplicates from app.utils.stats import get_artists_in_period +class MixAlreadyExists(Exception): + """ + Raised when a mix with the same sourcehash already exists. + """ + + pass + + class MixesPlugin(Plugin): MAX_TRACKS_TO_FETCH = 5 TRACK_MIX_LENGTH = 30 @@ -42,13 +50,22 @@ class MixesPlugin(Plugin): self.set_active(server_online) def ping_server(self): - try: - requests.get(self.server, timeout=10) - except requests.exceptions.ConnectionError: - print("Failed to connect to the recommendation server") - return False + max_retries = 3 + retry_delay = 2 # seconds - return True + for attempt in range(max_retries): + try: + requests.get(self.server, timeout=10) + return True + except Exception as e: + print( + f"Failed to connect to the recommendation server (attempt {attempt + 1}/{max_retries})" + ) + if attempt < max_retries - 1: + time.sleep(retry_delay) + continue + + return False @plugin_method def get_track_mix(self, tracks: list[Track], with_help: bool = False): @@ -73,16 +90,16 @@ class MixesPlugin(Plugin): ] try: - response = requests.post(f"{self.server}/radio", json=queries, timeout=10) - except requests.exceptions.ConnectionError: + response = requests.post(f"{self.server}/radio", json=queries, timeout=30) + except (requests.exceptions.ConnectionError, requests.exceptions.ReadTimeout): print("Failed to connect to recommendation server") - return [] + return [], [], [] try: results = response.json() except json.JSONDecodeError: print("Failed to decode JSON response from recommendation server") - return [] + return [], [], [] trackhashes: list[str] = results["tracks"] @@ -119,7 +136,7 @@ class MixesPlugin(Plugin): # try to balance the mix trackmatches = balance_mix(trackmatches) - return trackmatches + return trackmatches, results["albums"], results["albums"] @plugin_method def get_artist_mix(self, artisthash: str): @@ -136,10 +153,11 @@ class MixesPlugin(Plugin): sourcetracks = tracks[: self.MAX_TRACKS_TO_FETCH] sourcehash = create_hash(*[t.trackhash for t in sourcetracks]) - # TODO: Check if we already have this mix in the - # database and return that instead + if MixTable.get_by_sourcehash(sourcehash): + raise MixAlreadyExists() - return (self.get_track_mix(tracks[: self.MAX_TRACKS_TO_FETCH]), sourcehash) + tracks, albums, artists = self.get_track_mix(tracks[: self.MAX_TRACKS_TO_FETCH]) + return (tracks, albums, artists, sourcehash) @plugin_method def create_artist_mixes(self, userid: int): @@ -239,25 +257,36 @@ class MixesPlugin(Plugin): """ Given an artist dict, creates an artist mix. """ - _artist = ArtistStore.get_artist_by_hash(artist["artisthash"]) + _artist = ArtistStore.artistmap.get(artist["artisthash"]) if not _artist: return None - mix_tracks, sourcehash = self.get_artist_mix(artist["artisthash"]) + tracks = TrackStore.get_tracks_by_trackhashes(_artist.trackhashes) + tracks = sorted(tracks, key=lambda x: x.playduration, reverse=True) + sourcetracks = tracks[: self.MAX_TRACKS_TO_FETCH] + sourcehash = create_hash(*[t.trackhash for t in sourcetracks]) + + db_mix = MixTable.get_by_sourcehash(sourcehash) + if db_mix: + print(f"🔍 Found existing mix for {_artist.artist.name}") + print(db_mix.title) + return db_mix + + mix_tracks, albums, artists = self.get_track_mix(sourcetracks) if len(mix_tracks) < self.MIN_TRACK_MIX_LENGTH: return None # try downloading artist image - mix_image = {"image": _artist.image, "color": _artist.color} - downloaded_img_color = self.download_artist_image(_artist) + mix_image = {"image": _artist.artist.image, "color": _artist.artist.color} + downloaded_img_color = self.download_artist_image(_artist.artist) if downloaded_img_color: - mix_image["image"] = f"{_artist.artisthash}.jpg" + mix_image["image"] = f"{_artist.artist.artisthash}.jpg" mix_image["color"] = downloaded_img_color[0] - return Mix( + mix = Mix( # the a prefix indicates that this is an artist mix id=f"a{artist['artisthash']}", title=artist["artist"] + " Radio", @@ -268,10 +297,18 @@ class MixesPlugin(Plugin): "type": "artist", "artisthash": artist["artisthash"], "image": mix_image, + # NOTE: Save the similar albums and artists + # Related to the source tracks that were used to create the mix + # Will be useful when generating other homepage entries + "albums": albums, + "artists": artists, }, timestamp=int(time.time()), ) + MixTable.insert_one(mix) + return mix + def download_artist_image(self, artist: Artist): try: res = requests.get(f"{self.server}/image?artist={artist.name}") @@ -332,7 +369,6 @@ class MixesPlugin(Plugin): for match in albummatches: if len(mixtracks) >= limit: - print(f"Filled up to {limit} tracks with album tracks") return mixtracks albumtracks = [ @@ -346,9 +382,6 @@ class MixesPlugin(Plugin): sample = random.sample(albumtracks, k=1) mixtracks.extend(sample) - print( - f"Supplement: album track {sample[0].title} from ALBUM: {match.album.og_title}" - ) artistmatches = ( a @@ -358,7 +391,6 @@ class MixesPlugin(Plugin): for match in artistmatches: if len(mixtracks) >= limit: - print(f"Filled up to {limit} tracks with artist tracks") return mixtracks artisttracks = [ @@ -372,11 +404,15 @@ class MixesPlugin(Plugin): sample = random.sample(artisttracks, k=1) mixtracks.extend(sample) - print( - f"Supplement: track {sample[0].title} from ARTIST: {match.artist.name}" - ) return mixtracks def get_mix_from_lastfm_data(self, artisthash: str, limit: int): + """ + Creates a mix from the locally available lastfm similar artists data. + + The resulting mix is definitely expected to be of low quality. + + TODO: Implement this! + """ pass