save mixes to the db

This commit is contained in:
cwilvx
2024-11-17 21:38:51 +03:00
parent ef4ecc2499
commit dd2bb16a8c
5 changed files with 79 additions and 34 deletions
+1 -1
View File
@@ -31,7 +31,7 @@ def get_track_mix():
@api.post("/artist") @api.post("/artist")
def get_artist_mix(): def get_artist_mix():
mixes = MixesPlugin() mixes = MixesPlugin()
return mixes.create_artist_mixes() # return mixes.create_artist_mixes()
# tracks = mixes.get_artist_mix("09306be8039b98ad") # tracks = mixes.get_artist_mix("09306be8039b98ad")
# return { # return {
+1 -1
View File
@@ -8,7 +8,7 @@ class Mixes(CronJob):
""" """
name: str = "mixes" name: str = "mixes"
hours: int = 1 hours: int = 6
def __init__(self): def __init__(self):
super().__init__() super().__init__()
+5 -1
View File
@@ -483,6 +483,7 @@ class MixTable(Base):
timestamp: Mapped[int] = mapped_column(Integer()) timestamp: Mapped[int] = mapped_column(Integer())
sourcehash: Mapped[str] = mapped_column(String(), unique=True, index=True) sourcehash: Mapped[str] = mapped_column(String(), unique=True, index=True)
tracks: Mapped[list[str]] = mapped_column(JSON(), default_factory=list) tracks: Mapped[list[str]] = mapped_column(JSON(), default_factory=list)
saved: Mapped[bool] = mapped_column(Boolean(), default=False)
extra: Mapped[dict[str, Any]] = mapped_column( extra: Mapped[dict[str, Any]] = mapped_column(
JSON(), nullable=True, default_factory=dict JSON(), nullable=True, default_factory=dict
) )
@@ -495,7 +496,10 @@ class MixTable(Base):
@classmethod @classmethod
def get_by_sourcehash(cls, sourcehash: str): def get_by_sourcehash(cls, sourcehash: str):
result = cls.execute(select(cls).where(cls.sourcehash == sourcehash)) result = cls.execute(select(cls).where(cls.sourcehash == sourcehash))
return Mix.mix_to_dataclass(result.fetchone())
res = result.fetchone()
if res:
return Mix.mix_to_dataclass(res)
@classmethod @classmethod
def insert_one(cls, mix: Mix): def insert_one(cls, mix: Mix):
+7 -2
View File
@@ -24,8 +24,7 @@ class Mix:
saved: bool = False saved: bool = False
def to_full_dict(self): def to_full_dict(self):
# Limit track mix to 30 tracks tracks = TrackStore.get_tracks_by_trackhashes(self.tracks)[:40]
tracks = TrackStore.get_tracks_by_trackhashes(self.tracks)
serialized_tracks = serialize_tracks(tracks) serialized_tracks = serialize_tracks(tracks)
_dict = asdict(self) _dict = asdict(self)
@@ -37,12 +36,18 @@ class Mix:
_dict["duration"] = seconds_to_time_string(sum(t.duration for t in tracks)) _dict["duration"] = seconds_to_time_string(sum(t.duration for t in tracks))
_dict["trackcount"] = len(tracks) _dict["trackcount"] = len(tracks)
del _dict["extra"]["albums"]
del _dict["extra"]["artists"]
return _dict return _dict
def to_dict(self): def to_dict(self):
item = asdict(self) item = asdict(self)
del item["tracks"] del item["tracks"]
del item["extra"]["albums"]
del item["extra"]["artists"]
return item return item
@classmethod @classmethod
+65 -29
View File
@@ -8,7 +8,7 @@ import requests
from urllib.parse import quote from urllib.parse import quote
from PIL import Image from PIL import Image
from app.db.userdata import SimilarArtistTable from app.db.userdata import MixTable, SimilarArtistTable
from app.lib.colorlib import get_image_colors from app.lib.colorlib import get_image_colors
from app.models.artist import Artist from app.models.artist import Artist
from app.models.mix import Mix from app.models.mix import Mix
@@ -25,6 +25,14 @@ from app.utils.remove_duplicates import remove_duplicates
from app.utils.stats import get_artists_in_period from app.utils.stats import get_artists_in_period
class MixAlreadyExists(Exception):
"""
Raised when a mix with the same sourcehash already exists.
"""
pass
class MixesPlugin(Plugin): class MixesPlugin(Plugin):
MAX_TRACKS_TO_FETCH = 5 MAX_TRACKS_TO_FETCH = 5
TRACK_MIX_LENGTH = 30 TRACK_MIX_LENGTH = 30
@@ -42,13 +50,22 @@ class MixesPlugin(Plugin):
self.set_active(server_online) self.set_active(server_online)
def ping_server(self): def ping_server(self):
try: max_retries = 3
requests.get(self.server, timeout=10) retry_delay = 2 # seconds
except requests.exceptions.ConnectionError:
print("Failed to connect to the recommendation server")
return False
return True for attempt in range(max_retries):
try:
requests.get(self.server, timeout=10)
return True
except Exception as e:
print(
f"Failed to connect to the recommendation server (attempt {attempt + 1}/{max_retries})"
)
if attempt < max_retries - 1:
time.sleep(retry_delay)
continue
return False
@plugin_method @plugin_method
def get_track_mix(self, tracks: list[Track], with_help: bool = False): def get_track_mix(self, tracks: list[Track], with_help: bool = False):
@@ -73,16 +90,16 @@ class MixesPlugin(Plugin):
] ]
try: try:
response = requests.post(f"{self.server}/radio", json=queries, timeout=10) response = requests.post(f"{self.server}/radio", json=queries, timeout=30)
except requests.exceptions.ConnectionError: except (requests.exceptions.ConnectionError, requests.exceptions.ReadTimeout):
print("Failed to connect to recommendation server") print("Failed to connect to recommendation server")
return [] return [], [], []
try: try:
results = response.json() results = response.json()
except json.JSONDecodeError: except json.JSONDecodeError:
print("Failed to decode JSON response from recommendation server") print("Failed to decode JSON response from recommendation server")
return [] return [], [], []
trackhashes: list[str] = results["tracks"] trackhashes: list[str] = results["tracks"]
@@ -119,7 +136,7 @@ class MixesPlugin(Plugin):
# try to balance the mix # try to balance the mix
trackmatches = balance_mix(trackmatches) trackmatches = balance_mix(trackmatches)
return trackmatches return trackmatches, results["albums"], results["albums"]
@plugin_method @plugin_method
def get_artist_mix(self, artisthash: str): def get_artist_mix(self, artisthash: str):
@@ -136,10 +153,11 @@ class MixesPlugin(Plugin):
sourcetracks = tracks[: self.MAX_TRACKS_TO_FETCH] sourcetracks = tracks[: self.MAX_TRACKS_TO_FETCH]
sourcehash = create_hash(*[t.trackhash for t in sourcetracks]) sourcehash = create_hash(*[t.trackhash for t in sourcetracks])
# TODO: Check if we already have this mix in the if MixTable.get_by_sourcehash(sourcehash):
# database and return that instead raise MixAlreadyExists()
return (self.get_track_mix(tracks[: self.MAX_TRACKS_TO_FETCH]), sourcehash) tracks, albums, artists = self.get_track_mix(tracks[: self.MAX_TRACKS_TO_FETCH])
return (tracks, albums, artists, sourcehash)
@plugin_method @plugin_method
def create_artist_mixes(self, userid: int): def create_artist_mixes(self, userid: int):
@@ -239,25 +257,36 @@ class MixesPlugin(Plugin):
""" """
Given an artist dict, creates an artist mix. Given an artist dict, creates an artist mix.
""" """
_artist = ArtistStore.get_artist_by_hash(artist["artisthash"]) _artist = ArtistStore.artistmap.get(artist["artisthash"])
if not _artist: if not _artist:
return None return None
mix_tracks, sourcehash = self.get_artist_mix(artist["artisthash"]) tracks = TrackStore.get_tracks_by_trackhashes(_artist.trackhashes)
tracks = sorted(tracks, key=lambda x: x.playduration, reverse=True)
sourcetracks = tracks[: self.MAX_TRACKS_TO_FETCH]
sourcehash = create_hash(*[t.trackhash for t in sourcetracks])
db_mix = MixTable.get_by_sourcehash(sourcehash)
if db_mix:
print(f"🔍 Found existing mix for {_artist.artist.name}")
print(db_mix.title)
return db_mix
mix_tracks, albums, artists = self.get_track_mix(sourcetracks)
if len(mix_tracks) < self.MIN_TRACK_MIX_LENGTH: if len(mix_tracks) < self.MIN_TRACK_MIX_LENGTH:
return None return None
# try downloading artist image # try downloading artist image
mix_image = {"image": _artist.image, "color": _artist.color} mix_image = {"image": _artist.artist.image, "color": _artist.artist.color}
downloaded_img_color = self.download_artist_image(_artist) downloaded_img_color = self.download_artist_image(_artist.artist)
if downloaded_img_color: if downloaded_img_color:
mix_image["image"] = f"{_artist.artisthash}.jpg" mix_image["image"] = f"{_artist.artist.artisthash}.jpg"
mix_image["color"] = downloaded_img_color[0] mix_image["color"] = downloaded_img_color[0]
return Mix( mix = Mix(
# the a prefix indicates that this is an artist mix # the a prefix indicates that this is an artist mix
id=f"a{artist['artisthash']}", id=f"a{artist['artisthash']}",
title=artist["artist"] + " Radio", title=artist["artist"] + " Radio",
@@ -268,10 +297,18 @@ class MixesPlugin(Plugin):
"type": "artist", "type": "artist",
"artisthash": artist["artisthash"], "artisthash": artist["artisthash"],
"image": mix_image, "image": mix_image,
# NOTE: Save the similar albums and artists
# Related to the source tracks that were used to create the mix
# Will be useful when generating other homepage entries
"albums": albums,
"artists": artists,
}, },
timestamp=int(time.time()), timestamp=int(time.time()),
) )
MixTable.insert_one(mix)
return mix
def download_artist_image(self, artist: Artist): def download_artist_image(self, artist: Artist):
try: try:
res = requests.get(f"{self.server}/image?artist={artist.name}") res = requests.get(f"{self.server}/image?artist={artist.name}")
@@ -332,7 +369,6 @@ class MixesPlugin(Plugin):
for match in albummatches: for match in albummatches:
if len(mixtracks) >= limit: if len(mixtracks) >= limit:
print(f"Filled up to {limit} tracks with album tracks")
return mixtracks return mixtracks
albumtracks = [ albumtracks = [
@@ -346,9 +382,6 @@ class MixesPlugin(Plugin):
sample = random.sample(albumtracks, k=1) sample = random.sample(albumtracks, k=1)
mixtracks.extend(sample) mixtracks.extend(sample)
print(
f"Supplement: album track {sample[0].title} from ALBUM: {match.album.og_title}"
)
artistmatches = ( artistmatches = (
a a
@@ -358,7 +391,6 @@ class MixesPlugin(Plugin):
for match in artistmatches: for match in artistmatches:
if len(mixtracks) >= limit: if len(mixtracks) >= limit:
print(f"Filled up to {limit} tracks with artist tracks")
return mixtracks return mixtracks
artisttracks = [ artisttracks = [
@@ -372,11 +404,15 @@ class MixesPlugin(Plugin):
sample = random.sample(artisttracks, k=1) sample = random.sample(artisttracks, k=1)
mixtracks.extend(sample) mixtracks.extend(sample)
print(
f"Supplement: track {sample[0].title} from ARTIST: {match.artist.name}"
)
return mixtracks return mixtracks
def get_mix_from_lastfm_data(self, artisthash: str, limit: int): def get_mix_from_lastfm_data(self, artisthash: str, limit: int):
"""
Creates a mix from the locally available lastfm similar artists data.
The resulting mix is definitely expected to be of low quality.
TODO: Implement this!
"""
pass pass