supplement mixes using other remote similar albums and artist data

This commit is contained in:
cwilvx
2024-11-01 12:23:41 +03:00
parent eb4c65de83
commit 1fdd5ba4d1
8 changed files with 171 additions and 31 deletions
+1 -1
View File
@@ -20,7 +20,7 @@ def get_track_mix():
""" """
mixes = MixesPlugin() mixes = MixesPlugin()
track = TrackStore.trackhashmap["9eeee292264ad01b"].get_best() track = TrackStore.trackhashmap["9eeee292264ad01b"].get_best()
tracks = mixes.get_track_mix(track) tracks = mixes.get_track_mix([track])
return { return {
"total": len(tracks), "total": len(tracks),
+6
View File
@@ -7,6 +7,7 @@ class Mixes(CronJob):
""" """
This cron job creates mixes displayed on the homepage. This cron job creates mixes displayed on the homepage.
""" """
def __init__(self): def __init__(self):
super().__init__("mixes", 5) super().__init__("mixes", 5)
@@ -16,6 +17,11 @@ class Mixes(CronJob):
""" """
print("⭐⭐⭐⭐ Mixes cron job running") print("⭐⭐⭐⭐ Mixes cron job running")
mixes = MixesPlugin() mixes = MixesPlugin()
if not mixes.enabled:
return
artist_mixes = mixes.create_artist_mixes() artist_mixes = mixes.create_artist_mixes()
if artist_mixes:
HomepageStore.set_artist_mixes(artist_mixes) HomepageStore.set_artist_mixes(artist_mixes)
+4
View File
@@ -36,6 +36,7 @@ class Album:
image: str = "" image: str = ""
versions: list[str] = dataclasses.field(default_factory=list) versions: list[str] = dataclasses.field(default_factory=list)
fav_userids: list[int] = dataclasses.field(default_factory=list) fav_userids: list[int] = dataclasses.field(default_factory=list)
weakhash: str = ""
@property @property
def is_favorite(self): def is_favorite(self):
@@ -54,6 +55,9 @@ class Album:
def __post_init__(self): def __post_init__(self):
self.image = self.albumhash + ".webp" self.image = self.albumhash + ".webp"
self.populate_versions() self.populate_versions()
self.weakhash = create_hash(
self.og_title, ",".join(a["name"] for a in self.albumartists)
)
def populate_versions(self): def populate_versions(self):
_, self.versions = get_base_title_and_versions(self.og_title, get_versions=True) _, self.versions = get_base_title_and_versions(self.og_title, get_versions=True)
+110 -20
View File
@@ -1,5 +1,8 @@
import datetime
import json import json
import random
import string import string
import time
import requests import requests
from urllib.parse import quote from urllib.parse import quote
from PIL import Image from PIL import Image
@@ -11,9 +14,10 @@ from app.models.mix import Mix
from app.models.track import Track from app.models.track import Track
from app.plugins import Plugin, plugin_method from app.plugins import Plugin, plugin_method
from app.settings import Paths from app.settings import Paths
from app.store.albums import AlbumStore
from app.store.artists import ArtistStore from app.store.artists import ArtistStore
from app.store.tracks import TrackStore from app.store.tracks import TrackStore
from app.utils.dates import get_date_range from app.utils.dates import get_date_range, get_duration_ago
from app.utils.mixes import balance_mix from app.utils.mixes import balance_mix
from app.utils.remove_duplicates import remove_duplicates from app.utils.remove_duplicates import remove_duplicates
from app.utils.stats import get_artists_in_period from app.utils.stats import get_artists_in_period
@@ -31,7 +35,18 @@ class MixesPlugin(Plugin):
def __init__(self): def __init__(self):
super().__init__("mixes", "Mixes") super().__init__("mixes", "Mixes")
self.server = "https://smcloud.mungaist.com" self.server = "https://smcloud.mungaist.com"
self.set_active(True)
server_online = self.ping_server()
self.set_active(server_online)
def ping_server(self):
try:
requests.get(self.server, timeout=10)
except requests.exceptions.ConnectionError:
print("Failed to connect to the recommendation server")
return False
return True
@plugin_method @plugin_method
def get_track_mix(self, tracks: list[Track], with_help: bool = False): def get_track_mix(self, tracks: list[Track], with_help: bool = False):
@@ -57,13 +72,20 @@ class MixesPlugin(Plugin):
for track in tracks for track in tracks
] ]
response = requests.post( try:
f"{self.server}/radio", response = requests.post(f"{self.server}/radio", json=queries, timeout=10)
json=queries, except requests.exceptions.ConnectionError:
) print("Failed to connect to recommendation server")
return []
try:
results = response.json() results = response.json()
except json.JSONDecodeError:
print("Failed to decode JSON response from recommendation server")
return []
# artisthashes = results["artists"] # artisthashes = results["artists"]
# albumhashes = results["albums"]
trackhashes: list[str] = results["tracks"] trackhashes: list[str] = results["tracks"]
trackmatches = TrackStore.get_flat_list() trackmatches = TrackStore.get_flat_list()
@@ -82,6 +104,16 @@ class MixesPlugin(Plugin):
# sort by trackhash order # sort by trackhash order
trackmatches = sorted(trackmatches, key=lambda x: trackhashes.index(x.weakhash)) trackmatches = sorted(trackmatches, key=lambda x: trackhashes.index(x.weakhash))
if len(trackmatches) < self.TRACK_MIX_LENGTH:
filler_tracks = self.fallback_create_artist_mix(
artist=tracks[0].artists[0],
similar_artists=results["artists"],
similar_albums=results["albums"],
omit_trackhashes={t.weakhash for t in trackmatches},
limit=self.TRACK_MIX_LENGTH - len(trackmatches),
)
trackmatches.extend(filler_tracks)
# try to balance the mix # try to balance the mix
trackmatches = balance_mix(trackmatches) trackmatches = balance_mix(trackmatches)
return trackmatches return trackmatches
@@ -104,9 +136,9 @@ class MixesPlugin(Plugin):
indexed = set() indexed = set()
today_start, today_end = get_date_range(duration="day") today_start, today_end = get_date_range(duration="day")
last_2_days_start, last_2_days_end = get_date_range(duration="day", units_ago=2) last_2_days_start = get_duration_ago("day", 2)
last_7_days_start, last_7_days_end = get_date_range(duration="week") last_7_days_start = get_duration_ago("week")
last_1_month_start, last_1_month_end = get_date_range(duration="month") last_1_month_start = get_duration_ago("month")
artists = { artists = {
"today": { "today": {
@@ -116,17 +148,17 @@ class MixesPlugin(Plugin):
}, },
"last_2_days": { "last_2_days": {
"max": 2, "max": 2,
"artists": get_artists_in_period(last_2_days_start, last_2_days_end), "artists": get_artists_in_period(last_2_days_start, time.time()),
"created": 0, "created": 0,
}, },
"last_7_days": { "last_7_days": {
"max": 3, "max": 3,
"artists": get_artists_in_period(last_7_days_start, last_7_days_end), "artists": get_artists_in_period(last_7_days_start, time.time()),
"created": 0, "created": 0,
}, },
"last_1_month": { "last_1_month": {
"max": 2, "max": 2,
"artists": get_artists_in_period(last_1_month_start, last_1_month_end), "artists": get_artists_in_period(last_1_month_start, time.time()),
"created": 0, "created": 0,
}, },
} }
@@ -219,7 +251,10 @@ class MixesPlugin(Plugin):
) )
def download_artist_image(self, artist: Artist): def download_artist_image(self, artist: Artist):
try:
res = requests.get(f"{self.server}/image?artist={artist.name}") res = requests.get(f"{self.server}/image?artist={artist.name}")
except requests.exceptions.ConnectionError:
return None
if res.status_code == 200: if res.status_code == 200:
# save to file # save to file
@@ -247,8 +282,9 @@ class MixesPlugin(Plugin):
def fallback_create_artist_mix( def fallback_create_artist_mix(
self, self,
artist: dict[str, str], artist: dict[str, str],
similar_albums: list[str],
similar_artists: list[str], similar_artists: list[str],
trackhashes: set[str], omit_trackhashes: set[str],
limit: int, limit: int,
): ):
""" """
@@ -264,16 +300,70 @@ class MixesPlugin(Plugin):
:param trackhashes: A set of trackhashes to omit from the new tracklist. :param trackhashes: A set of trackhashes to omit from the new tracklist.
:param limit: The maximum number of tracks to select. :param limit: The maximum number of tracks to select.
""" """
artists = similar_artists
if len(similar_artists) == 0: mixtracks = []
local_similar_artists = SimilarArtistTable.get_by_hash(artist["artisthash"]) albummatches = (
a
for a in AlbumStore.albummap.values()
if a.album.weakhash in similar_albums
)
if local_similar_artists: for match in albummatches:
artists = [a.artisthash for a in local_similar_artists.similar_artists] if len(mixtracks) >= limit:
print(f"Filled up to {limit} tracks with album tracks")
return mixtracks
if len(artists) == 0: albumtracks = [
return [] t
for t in TrackStore.get_tracks_by_trackhashes(match.trackhashes)
if t.weakhash not in omit_trackhashes
]
if len(albumtracks) == 0:
continue
sample = random.sample(albumtracks, k=1)
mixtracks.extend(sample)
print(
f"Supplement: album track {sample[0].title} from ALBUM: {match.album.og_title}"
)
artistmatches = (
a
for a in ArtistStore.artistmap.values()
if a.artist.artisthash in similar_artists
)
for match in artistmatches:
if len(mixtracks) >= limit:
print(f"Filled up to {limit} tracks with artist tracks")
return mixtracks
artisttracks = [
t
for t in TrackStore.get_tracks_by_trackhashes(match.trackhashes)
if t.weakhash not in omit_trackhashes
]
if len(artisttracks) == 0:
continue
sample = random.sample(artisttracks, k=1)
mixtracks.extend(sample)
print(
f"Supplement: track {sample[0].title} from ARTIST: {match.artist.name}"
)
return mixtracks
# if len(similar_artists) == 0:
# local_similar_artists = SimilarArtistTable.get_by_hash(artist["artisthash"])
# if local_similar_artists:
# artists = [a.artisthash for a in local_similar_artists.similar_artists]
# if len(artists) == 0:
# return []
# CHECKPOINT: I'M TIRED AF AND I NEED TO SLEEP # CHECKPOINT: I'M TIRED AF AND I NEED TO SLEEP
# The plan: # The plan:
+1
View File
@@ -38,6 +38,7 @@ def serialize_for_card(album: Album):
"extra", "extra",
"id", "id",
"lastplayed", "lastplayed",
"weakhash",
} }
return album_serializer(album, props_to_remove) return album_serializer(album, props_to_remove)
+40 -1
View File
@@ -71,6 +71,9 @@ def get_date_range(duration: str, units_ago: int = 0):
Returns a tuple of dates representing the start and end of a given duration. Returns a tuple of dates representing the start and end of a given duration.
""" """
date_range = None date_range = None
seconds_ago = 0
if duration != "alltime":
seconds_ago = ( seconds_ago = (
pendulum.now() - pendulum.now().subtract().start_of(duration) pendulum.now() - pendulum.now().subtract().start_of(duration)
).total_seconds() * units_ago ).total_seconds() * units_ago
@@ -83,7 +86,9 @@ def get_date_range(duration: str, units_ago: int = 0):
.subtract(seconds=seconds_ago) .subtract(seconds=seconds_ago)
.start_of(duration) .start_of(duration)
.timestamp(), .timestamp(),
pendulum.now().end_of(duration).timestamp(), pendulum.now()
# .end_of(duration)
.timestamp(),
) )
case "alltime": case "alltime":
date_range = (0, pendulum.now().timestamp()) date_range = (0, pendulum.now().timestamp())
@@ -93,6 +98,40 @@ def get_date_range(duration: str, units_ago: int = 0):
return (int(date_range[0]), int(date_range[1])) return (int(date_range[0]), int(date_range[1]))
def get_duration_ago(duration: str, units_ago: int = 1) -> int:
"""
Returns the start of the last duration.
"""
seconds_in_day = 24 * 60 * 60
now = pendulum.now()
match duration:
case "day":
return int(
now.subtract(seconds=seconds_in_day * units_ago).timestamp()
)
case "week":
return int(
now
.subtract(seconds=seconds_in_day * 7 * units_ago)
.timestamp()
)
case "month":
return int(
now
.subtract(seconds=seconds_in_day * 30 * units_ago)
.timestamp()
)
case "year":
return int(
now
.subtract(seconds=seconds_in_day * 365 * units_ago)
.timestamp()
)
case _:
raise ValueError(f"Invalid duration: {duration}")
def get_duration_in_seconds(duration: str) -> int: def get_duration_in_seconds(duration: str) -> int:
""" """
Returns the number of seconds in a given duration. Returns the number of seconds in a given duration.
+1 -1
View File
@@ -11,7 +11,7 @@ from app.store.tracks import TrackStore
from app.utils.dates import seconds_to_time_string from app.utils.dates import seconds_to_time_string
def get_artists_in_period(start_time: int, end_time: int): def get_artists_in_period(start_time: int | float, end_time: int | float):
scrobbles = ScrobbleTable.get_all_in_period(start_time, end_time) scrobbles = ScrobbleTable.get_all_in_period(start_time, end_time)
artists: Any = defaultdict(lambda: {"playcount": 0, "playduration": 0}) artists: Any = defaultdict(lambda: {"playcount": 0, "playduration": 0})
+1 -1
View File
@@ -102,7 +102,7 @@ whitelisted_routes = {
"/auth/refresh", "/auth/refresh",
"/docs", "/docs",
} }
blacklist_extensions = {".webp"}.union(getClientFilesExtensions()) blacklist_extensions = {".webp", ".jpg"}.union(getClientFilesExtensions())
def skipAuthAction(): def skipAuthAction():