migrate homepage items to homepage routine

+ add Mix db model
This commit is contained in:
cwilvx
2024-11-10 19:38:51 +03:00
parent 38d08f07bb
commit 498d0688b0
11 changed files with 292 additions and 442 deletions
+9 -8
View File
@@ -1,6 +1,6 @@
from flask_jwt_extended import current_user
from flask_openapi3 import Tag
from flask_openapi3 import APIBlueprint
from pydantic import BaseModel, Field
from app.api.apischemas import GenericLimitSchema
from app.lib.home.recentlyadded import get_recently_added_items
@@ -27,13 +27,14 @@ def get_recent_plays(query: GenericLimitSchema):
return {"items": get_recently_played(query.limit)}
class HomepageItem(BaseModel):
limit: int = Field(
default=9, description="The max number of items per group to return"
)
@api.get("/")
def homepage_items():
def homepage_items(query: HomepageItem):
return {
"artist_mixes": {
"title": "Artist mixes for you",
"description": "Based on artists you have been listening to",
"items": HomepageStore.get_artist_mixes(),
"extra": {},
},
"artist_mixes": HomepageStore.get_mixes("artist_mixes", limit=query.limit),
}
+9 -6
View File
@@ -1,4 +1,5 @@
from app.crons.cron import CronJob
from app.lib.recipes import ArtistMixes
from app.plugins.mixes import MixesPlugin
from app.store.homepage import HomepageStore
@@ -16,12 +17,14 @@ class Mixes(CronJob):
Creates the artist mixes
"""
print("⭐⭐⭐⭐ Mixes cron job running")
mixes = MixesPlugin()
ArtistMixes().run()
# mixes = MixesPlugin()
if not mixes.enabled:
return
# if not mixes.enabled:
# return
artist_mixes = mixes.create_artist_mixes()
if artist_mixes:
HomepageStore.set_artist_mixes(artist_mixes)
# artist_mixes = mixes.create_artist_mixes()
# if artist_mixes:
# HomepageStore.set_artist_mixes(artist_mixes)
-352
View File
@@ -8,92 +8,6 @@ from sqlalchemy.orm import Mapped, mapped_column
from typing import Any, Optional
# def create_all():
# """
# Create all the tables defined in this file.
# NOTE: We need this function because the MasterBase does not collect
# the tables defined here (as they are grand-children of the MasterBase)
# """
# Base.metadata.create_all(DbEngine.engine)
# class Base(MasterBase, DeclarativeBase):
# pass
# @classmethod
# def get_all_hashes(cls, create_date: int | None = None):
# with DbEngine.manager() as conn:
# if create_date:
# if cls.__tablename__ == "track":
# stmt = select(TrackTable.trackhash).where(
# cls.last_mod < create_date
# )
# elif cls.__tablename__ == "album":
# stmt = select(AlbumTable.albumhash).where(
# cls.created_date < create_date
# )
# elif cls.__tablename__ == "artist":
# stmt = select(ArtistTable.artisthash).where(
# cls.created_date < create_date
# )
# else:
# if cls.__tablename__ == "track":
# stmt = select(TrackTable.trackhash)
# elif cls.__tablename__ == "album":
# stmt = select(AlbumTable.albumhash)
# elif cls.__tablename__ == "artist":
# stmt = select(ArtistTable.artisthash)
# result = conn.execute(stmt)
# return {row[0] for row in result.fetchall()}
# @classmethod
# def set_is_favorite(cls, hash: str, is_favorite: bool):
# """
# Set the 'is_favorite' flag for a specific hash.
# Args:
# hash (str): The hash value.
# is_favorite (bool): The value of the 'is_favorite' flag.
# """
# with DbEngine.manager(commit=True) as conn:
# if cls.__tablename__ == "track":
# stmt = (
# update(cls)
# .where(TrackTable.trackhash == hash)
# .values(is_favorite=is_favorite)
# )
# elif cls.__tablename__ == "album":
# stmt = (
# update(cls)
# .where(AlbumTable.albumhash == hash)
# .values(is_favorite=is_favorite)
# )
# elif cls.__tablename__ == "artist":
# stmt = (
# update(cls)
# .where(ArtistTable.artisthash == hash)
# .values(is_favorite=is_favorite)
# )
# conn.execute(stmt)
# @classmethod
# def increment_scrobblecount(
# cls, table: Any, field: Any, hash: str, duration: int, timestamp: int
# ):
# cls.execute(
# update(table)
# .where(field == hash)
# .values(
# playcount=table.playcount + 1,
# playduration=table.playduration + duration,
# lastplayed=timestamp,
# ),
# commit=True,
# )
class TrackTable(Base):
__tablename__ = "track"
@@ -101,7 +15,6 @@ class TrackTable(Base):
album: Mapped[str] = mapped_column(String())
albumartists: Mapped[str] = mapped_column(String())
albumhash: Mapped[str] = mapped_column(String(), index=True)
# artisthashes: Mapped[list[str]] = mapped_column(JSON(), index=True)
artists: Mapped[str] = mapped_column(String())
bitrate: Mapped[int] = mapped_column(Integer())
copyright: Mapped[Optional[str]] = mapped_column(String())
@@ -110,11 +23,8 @@ class TrackTable(Base):
duration: Mapped[int] = mapped_column(Integer())
filepath: Mapped[str] = mapped_column(String(), index=True, unique=True)
folder: Mapped[str] = mapped_column(String(), index=True)
# genrehashes: Mapped[list[str]] = mapped_column(JSON(), index=True)
genres: Mapped[Optional[str]] = mapped_column(String())
last_mod: Mapped[float] = mapped_column(Integer())
# og_album: Mapped[str] = mapped_column(String())
# og_title: Mapped[str] = mapped_column(String())
title: Mapped[str] = mapped_column(String())
track: Mapped[int] = mapped_column(Integer())
trackhash: Mapped[str] = mapped_column(String(), index=True)
@@ -141,45 +51,6 @@ class TrackTable(Base):
)
return tracks_to_dataclasses(result.fetchall())
# @classmethod
# def get_tracks_by_albumhash(cls, albumhash: str):
# with DbEngine.manager() as conn:
# result = conn.execute(
# select(TrackTable).where(TrackTable.albumhash == albumhash)
# )
# tracks = tracks_to_dataclasses(result.fetchall())
# return remove_duplicates(tracks, is_album_tracks=True)
# @classmethod
# def get_track_by_trackhash(cls, hash: str, filepath: str = ""):
# with DbEngine.manager() as conn:
# if filepath:
# result = conn.execute(
# select(TrackTable)
# .where(
# (TrackTable.trackhash == hash)
# & (TrackTable.filepath == filepath),
# )
# .order_by(TrackTable.bitrate.desc())
# )
# else:
# result = conn.execute(
# select(TrackTable).where(TrackTable.trackhash == hash)
# )
# track = result.fetchone()
# if track:
# return track_to_dataclass(track)
# @classmethod
# def get_tracks_by_artisthash(cls, artisthash: str):
# with DbEngine.manager() as conn:
# result = conn.execute(
# select(TrackTable).where(TrackTable.artists.contains(artisthash))
# )
# return tracks_to_dataclasses(result.fetchall())
@classmethod
def get_tracks_in_path(cls, path: str):
with DbEngine.manager() as conn:
@@ -190,230 +61,7 @@ class TrackTable(Base):
)
return tracks_to_dataclasses(result.fetchall())
# @classmethod
# def get_tracks_by_trackhashes(cls, hashes: Iterable[str], limit: int | None = None):
# with DbEngine.manager() as conn:
# result = conn.execute(
# select(TrackTable)
# .where(TrackTable.trackhash.in_(hashes))
# .group_by(TrackTable.trackhash)
# .limit(limit)
# )
# tracks = tracks_to_dataclasses(result.fetchall())
# # order the tracks in the same order as the hashes
# if type(hashes) == list:
# return sorted(tracks, key=lambda x: hashes.index(x.trackhash))
# return tracks
# @classmethod
# def get_recently_added(cls, start: int, limit: int):
# with DbEngine.manager() as conn:
# result = conn.execute(
# select(TrackTable)
# .order_by(TrackTable.last_mod.desc())
# .offset(start)
# .limit(limit)
# )
# return tracks_to_dataclasses(result.fetchall())
@classmethod
# def get_recently_played(cls, limit: int):
# result = cls.execute(
# select(cls)
# .group_by(cls.trackhash)
# .order_by(cls.lastplayed.desc())
# .limit(limit)
# )
# return tracks_to_dataclasses(result.fetchall())
@classmethod
def remove_tracks_by_filepaths(cls, filepaths: set[str]):
with DbEngine.manager(commit=True) as conn:
conn.execute(delete(TrackTable).where(TrackTable.filepath.in_(filepaths)))
# @classmethod
# def increment_playcount(cls, trackhash: str, duration: int, timestamp: int):
# cls.increment_scrobblecount(
# TrackTable, TrackTable.trackhash, trackhash, duration, timestamp
# )
# @classmethod
# def update_artist_separators(cls, separators: set[str]):
# tracks = cls.get_all()
# with DbEngine.manager(commit=True) as conn:
# for track in tracks:
# track.split_artists(separators)
# conn.execute(
# update(cls)
# .where(cls.trackhash == track.trackhash)
# .values(artists=track.artists, artisthashes=track.artisthashes)
# )
# class AlbumTable(Base):
# __tablename__ = "album"
# id: Mapped[int] = mapped_column(primary_key=True)
# albumartists: Mapped[list[dict[str, str]]] = mapped_column(JSON(), index=True)
# artisthashes: Mapped[list[str]] = mapped_column(JSON(), index=True)
# albumhash: Mapped[str] = mapped_column(String(), unique=True, index=True)
# base_title: Mapped[str] = mapped_column(String())
# color: Mapped[Optional[str]] = mapped_column(String())
# created_date: Mapped[int] = mapped_column(Integer())
# date: Mapped[int] = mapped_column(Integer())
# duration: Mapped[int] = mapped_column(Integer())
# genrehashes: Mapped[list[str]] = mapped_column(JSON(), nullable=True, index=True)
# genres: Mapped[str] = mapped_column(JSON())
# og_title: Mapped[str] = mapped_column(String())
# title: Mapped[str] = mapped_column(String())
# trackcount: Mapped[int] = mapped_column(Integer())
# lastplayed: Mapped[int] = mapped_column(Integer(), default=0)
# playcount: Mapped[int] = mapped_column(Integer(), default=0)
# playduration: Mapped[int] = mapped_column(Integer(), default=0)
# extra: Mapped[Optional[dict[str, Any]]] = mapped_column(
# JSON(), default_factory=dict
# )
# @classmethod
# def get_all(cls):
# with DbEngine.manager() as conn:
# result = conn.execute(select(AlbumTable))
# all = result.fetchall()
# return albums_to_dataclasses(all)
# @classmethod
# def get_album_by_albumhash(cls, hash: str):
# with DbEngine.manager() as conn:
# result = conn.execute(
# select(AlbumTable).where(AlbumTable.albumhash == hash)
# )
# album = result.fetchone()
# if album:
# return album_to_dataclass(album)
# @classmethod
# def get_albums_by_albumhashes(cls, hashes: Iterable[str], limit: int | None = None):
# with DbEngine.manager() as conn:
# result = conn.execute(
# select(AlbumTable).where(AlbumTable.albumhash.in_(hashes)).limit(limit)
# )
# albums = albums_to_dataclasses(result.fetchall())
# # order the albums in the same order as the hashes
# if type(hashes) == list:
# return sorted(albums, key=lambda x: hashes.index(x.albumhash))
# return albums
# @classmethod
# def get_albums_by_artisthashes(cls, artisthashes: list[str]):
# with DbEngine.manager() as conn:
# albums: dict[str, list[AlbumModel]] = {}
# for artist in artisthashes:
# result = conn.execute(
# select(AlbumTable).where(AlbumTable.artisthashes.contains(artist))
# )
# albums[artist] = albums_to_dataclasses(result.fetchall())
# return albums
# @classmethod
# def get_albums_by_base_title(cls, base_title: str):
# with DbEngine.manager() as conn:
# result = conn.execute(
# select(AlbumTable).where(AlbumTable.base_title == base_title)
# )
# return albums_to_dataclasses(result.fetchall())
# @classmethod
# def get_albums_by_artisthash(cls, artisthash: str):
# with DbEngine.manager() as conn:
# result = conn.execute(
# select(AlbumTable).where(AlbumTable.artisthashes.contains(artisthash))
# )
# return albums_to_dataclasses(result.all())
# @classmethod
# def increment_playcount(cls, albumhash: str, duration: int, timestamp: int):
# return cls.increment_scrobblecount(
# AlbumTable, AlbumTable.albumhash, albumhash, duration, timestamp
# )
# class ArtistTable(Base):
# __tablename__ = "artist"
# id: Mapped[int] = mapped_column(primary_key=True)
# albumcount: Mapped[int] = mapped_column(Integer())
# artisthash: Mapped[str] = mapped_column(String(), unique=True, index=True)
# created_date: Mapped[int] = mapped_column(Integer())
# date: Mapped[int] = mapped_column(Integer())
# duration: Mapped[int] = mapped_column(Integer())
# genrehashes: Mapped[list[str]] = mapped_column(JSON(), nullable=True, index=True)
# genres: Mapped[str] = mapped_column(JSON())
# name: Mapped[str] = mapped_column(String(), index=True)
# trackcount: Mapped[int] = mapped_column(Integer())
# lastplayed: Mapped[int] = mapped_column(Integer(), default=0)
# playcount: Mapped[int] = mapped_column(Integer(), default=0)
# playduration: Mapped[int] = mapped_column(Integer(), default=0)
# extra: Mapped[Optional[dict[str, Any]]] = mapped_column(
# JSON(), default_factory=dict
# )
# @classmethod
# def get_all(cls):
# with DbEngine.manager() as conn:
# result = conn.execute(select(cls))
# all = result.fetchall()
# return artists_to_dataclasses(all)
# @classmethod
# def get_artist_by_hash(cls, artisthash: str):
# with DbEngine.manager() as conn:
# result = conn.execute(
# select(ArtistTable).where(ArtistTable.artisthash == artisthash)
# )
# return artist_to_dataclass(result.fetchone())
# @classmethod
# def get_artisthashes_not_in(cls, artisthashes: list[str]):
# with DbEngine.manager() as conn:
# result = conn.execute(
# select(ArtistTable.artisthash, ArtistTable.name).where(
# ~ArtistTable.artisthash.in_(artisthashes)
# )
# )
# return [{"artisthash": row[0], "name": row[1]} for row in result.fetchall()]
# @classmethod
# def get_artists_by_artisthashes(
# cls, hashes: Iterable[str], limit: int | None = None
# ):
# with DbEngine.manager() as conn:
# result = conn.execute(
# select(ArtistTable)
# .where(ArtistTable.artisthash.in_(hashes))
# .limit(limit)
# )
# return artists_to_dataclasses(result.fetchall())
# @classmethod
# def increment_playcount(
# cls, artisthashes: list[str], duration: int, timestamp: int
# ):
# cls.execute(
# update(cls)
# .where(ArtistTable.artisthash.in_(artisthashes))
# .values(
# playcount=ArtistTable.playcount + 1,
# playduration=ArtistTable.playduration + duration,
# lastplayed=timestamp,
# ),
# commit=True,
# )
+43 -5
View File
@@ -1,3 +1,4 @@
from dataclasses import asdict
import datetime
from typing import Any, Literal
from sqlalchemy import (
@@ -31,6 +32,7 @@ from app.db.utils import (
)
from app.db import Base
from app.models.mix import Mix
from app.utils.auth import get_current_userid, hash_password
@@ -223,9 +225,7 @@ class FavoritesTable(Base):
# .select_from(join(table, cls, field == cls.hash))
.where(and_(cls.type == type, cls.userid == get_current_userid()))
.order_by(cls.timestamp.desc())
.offset(
start
)
.offset(start)
# INFO: If start is 0, fetch all so we can get the total count
.limit(limit if start != 0 else None)
)
@@ -305,10 +305,15 @@ class ScrobbleTable(Base):
return tracklog_to_dataclasses(result.fetchall())
@classmethod
def get_all_in_period(cls, start_time: int, end_time: int):
def get_all_in_period(cls, start_time: int, end_time: int, userid: int | None):
# UserId will be None if function is called from the API
# In that case, we use the request userid
if userid is None:
userid = get_current_userid()
result = cls.execute(
select(cls)
.where(cls.userid == get_current_userid())
.where(cls.userid == userid)
.where(and_(cls.timestamp >= start_time, cls.timestamp <= end_time))
.order_by(cls.timestamp.desc())
)
@@ -458,3 +463,36 @@ class LibDataTable(Base):
select(cls.itemhash, cls.color).where(cls.itemtype == type)
)
return [{"itemhash": r[0], "color": r[1]} for r in result.fetchall()]
class MixTable(Base):
__tablename__ = "mix"
id: Mapped[int] = mapped_column(primary_key=True)
mixid: Mapped[str] = mapped_column(String(), index=True)
title: Mapped[str] = mapped_column(String())
description: Mapped[str] = mapped_column(String())
timestamp: Mapped[int] = mapped_column(Integer())
sourcehash: Mapped[str] = mapped_column(String(), unique=True, index=True)
tracks: Mapped[list[str]] = mapped_column(JSON(), default_factory=list)
extra: Mapped[dict[str, Any]] = mapped_column(
JSON(), nullable=True, default_factory=dict
)
@classmethod
def get_all(cls):
result = cls.execute(select(cls))
return Mix.mixes_to_dataclasses(result.fetchall())
@classmethod
def get_by_sourcehash(cls, sourcehash: str):
result = cls.execute(select(cls).where(cls.sourcehash == sourcehash))
return Mix.mix_to_dataclass(result.fetchone())
@classmethod
def insert_one(cls, mix: Mix):
mixdict = asdict(mix)
mixdict["mixid"] = mix.id
del mixdict["id"]
return cls.execute(insert(cls).values(mixdict), commit=True)
-1
View File
@@ -10,7 +10,6 @@ from typing import Any
from PIL import Image, ImageSequence
from app import settings
from app.db.libdata import TrackTable
from app.models.track import Track
from app.store.albums import AlbumStore
from app.store.tracks import TrackStore
+66
View File
@@ -0,0 +1,66 @@
"""
Recipes are a way to create mixes.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from app.db.userdata import UserTable
from app.models.mix import Mix
from app.plugins.mixes import MixesPlugin
from app.store.homepage import HomepageStore
class HomepageRoutine(ABC):
"""
A routine creates a row of homepage items.
"""
title: str
description: str
items: List[Mix]
extra: Dict[str, Any]
@property
@abstractmethod
def is_valid(self) -> bool: ...
def __init__(self) -> None:
if not self.is_valid:
return
self.items = self.run()
@abstractmethod
def run(self) -> List[Mix]:
"""
Creates the homepage items and saves them to the
homepage store if self.is_valid is true.
"""
...
class ArtistMixes(HomepageRoutine):
items: List[Mix] = []
extra: Dict[str, Any] = {}
store_key = "artist_mixes"
@property
def is_valid(self):
return MixesPlugin().enabled
def run(self):
users = UserTable.get_all()
for user in users:
mix = MixesPlugin()
mixes = mix.create_artist_mixes(user.id)
if not mixes:
continue
HomepageStore.set_mixes(mixes, mixkey=self.store_key, userid=user.id)
def __init__(self) -> None:
super().__init__()
+20 -2
View File
@@ -1,5 +1,6 @@
from dataclasses import asdict, dataclass, field
import time
from dataclasses import asdict, dataclass, field
from typing import Any
from app.lib.playlistlib import get_first_4_images
from app.serializers.track import serialize_tracks
@@ -13,12 +14,17 @@ class Mix:
title: str
description: str
tracks: list[str]
sourcehash: str
"""
A hash of the tracks used to generate the mix.
"""
timestamp: int = field(default_factory=lambda: int(time.time()))
extra: dict = field(default_factory=dict)
saved: bool = False
def to_full_dict(self):
# Limit track mix to 30 tracks
tracks = TrackStore.get_tracks_by_trackhashes(self.tracks)
serialized_tracks = serialize_tracks(tracks)
@@ -34,7 +40,19 @@ class Mix:
return _dict
def to_dict(self):
item = self.to_full_dict()
item = asdict(self)
del item["tracks"]
return item
@classmethod
def mix_to_dataclass(cls, entry: Any):
entry_dict = entry._asdict()
entry_dict["id"] = entry_dict["mixid"]
del entry_dict["mixid"]
return Mix(**entry_dict)
@classmethod
def mixes_to_dataclasses(cls, entries: Any):
return [cls.mix_to_dataclass(entry) for entry in entries]
+45 -38
View File
@@ -1,5 +1,6 @@
import datetime
import json
from pprint import pprint
import random
import string
import time
@@ -18,6 +19,7 @@ from app.store.albums import AlbumStore
from app.store.artists import ArtistStore
from app.store.tracks import TrackStore
from app.utils.dates import get_date_range, get_duration_ago
from app.utils.hashing import create_hash
from app.utils.mixes import balance_mix
from app.utils.remove_duplicates import remove_duplicates
from app.utils.stats import get_artists_in_period
@@ -25,7 +27,7 @@ from app.utils.stats import get_artists_in_period
class MixesPlugin(Plugin):
MAX_TRACKS_TO_FETCH = 5
TRACK_MIX_LENGTH = 50
TRACK_MIX_LENGTH = 30
MIN_TRACK_MIX_LENGTH = 15
MIN_DAY_LISTEN_DURATION = 3 * 60 # 3 minutes
@@ -60,8 +62,6 @@ class MixesPlugin(Plugin):
:param with_help: Whether to include the help flag in the query.
The flag tells the server to find more data using other tracks from the same album.
"""
queries = [
{
@@ -84,8 +84,6 @@ class MixesPlugin(Plugin):
print("Failed to decode JSON response from recommendation server")
return []
# artisthashes = results["artists"]
# albumhashes = results["albums"]
trackhashes: list[str] = results["tracks"]
trackmatches = TrackStore.get_flat_list()
@@ -104,13 +102,18 @@ class MixesPlugin(Plugin):
# sort by trackhash order
trackmatches = sorted(trackmatches, key=lambda x: trackhashes.index(x.weakhash))
if len(trackmatches) < self.TRACK_MIX_LENGTH:
# if the mix is short, try to fill it up with tracks
# from album and artist data from the cloud!
# Create as many filler tracks as possible
# Then the mix length will be controlled in the Mix model
# if len(trackmatches) < self.TRACK_MIX_LENGTH:
if True:
filler_tracks = self.fallback_create_artist_mix(
artist=tracks[0].artists[0],
similar_artists=results["artists"],
similar_albums=results["albums"],
omit_trackhashes={t.weakhash for t in trackmatches},
limit=self.TRACK_MIX_LENGTH - len(trackmatches),
# limit=self.TRACK_MIX_LENGTH - len(trackmatches),
)
trackmatches.extend(filler_tracks)
@@ -123,15 +126,26 @@ class MixesPlugin(Plugin):
"""
Given an artisthash, creates an artist mix using the
self.MAX_TRACKS_TO_FETCH most listened to tracks.
Returns a tuple of the mix and the sourcehash.
"""
artist = ArtistStore.artistmap[artisthash]
tracks = TrackStore.get_tracks_by_trackhashes(artist.trackhashes)
tracks = sorted(tracks, key=lambda x: x.playduration, reverse=True)
return self.get_track_mix(tracks[: self.MAX_TRACKS_TO_FETCH])
sourcetracks = tracks[: self.MAX_TRACKS_TO_FETCH]
sourcehash = create_hash(*[t.trackhash for t in sourcetracks])
# TODO: Check if we already have this mix in the
# database and return that instead
return (self.get_track_mix(tracks[: self.MAX_TRACKS_TO_FETCH]), sourcehash)
@plugin_method
def create_artist_mixes(self):
def create_artist_mixes(self, userid: int):
"""
Creates artist mixes for a given userid.
"""
mixes: list[Mix] = []
indexed = set()
@@ -143,22 +157,28 @@ class MixesPlugin(Plugin):
artists = {
"today": {
"max": 3,
"artists": get_artists_in_period(today_start, today_end),
"artists": get_artists_in_period(today_start, today_end, userid),
"created": 0,
},
"last_2_days": {
"max": 2,
"artists": get_artists_in_period(last_2_days_start, time.time()),
"artists": get_artists_in_period(
last_2_days_start, time.time(), userid
),
"created": 0,
},
"last_7_days": {
"max": 3,
"artists": get_artists_in_period(last_7_days_start, time.time()),
"artists": get_artists_in_period(
last_7_days_start, time.time(), userid
),
"created": 0,
},
"last_1_month": {
"max": 2,
"artists": get_artists_in_period(last_1_month_start, time.time()),
"artists": get_artists_in_period(
last_1_month_start, time.time(), userid
),
"created": 0,
},
}
@@ -224,7 +244,7 @@ class MixesPlugin(Plugin):
if not _artist:
return None
mix_tracks = self.get_artist_mix(artist["artisthash"])
mix_tracks, sourcehash = self.get_artist_mix(artist["artisthash"])
if len(mix_tracks) < self.MIN_TRACK_MIX_LENGTH:
return None
@@ -243,11 +263,13 @@ class MixesPlugin(Plugin):
title=artist["artist"] + " Radio",
description=self.get_mix_description(mix_tracks, artist["artisthash"]),
tracks=[t.trackhash for t in mix_tracks],
sourcehash=sourcehash,
extra={
"type": "artist",
"artisthash": artist["artisthash"],
"image": mix_image,
},
timestamp=int(time.time()),
)
def download_artist_image(self, artist: Artist):
@@ -281,23 +303,23 @@ class MixesPlugin(Plugin):
def fallback_create_artist_mix(
self,
artist: dict[str, str],
# artist: dict[str, str],
similar_albums: list[str],
similar_artists: list[str],
omit_trackhashes: set[str],
limit: int,
limit: int = 99,
):
"""
Creates an artist mix by selecting random tracks from similar artists.
Creates an artist mix by selecting random tracks from similar albums and artists.
This is used when:
- The Swing Music recommendation server is down.
- The artist has less than self.MIN_TRACK_MIX_LENGTH tracks from the cloud mix.
- When we need to dilute the mix to balance the artist distribution.
:param artist: The artist to create a mix for.
:param similar_artists: A list of similar artists to select tracks from. If not provided, we try reading from the local database. When we exhaust the passed list, we also try reading from the local database.
:param trackhashes: A set of trackhashes to omit from the new tracklist.
:param similar_albums: A list of similar album weakhashes to select tracks from.
:param similar_artists: A list of similar artist hashes to select tracks from.
:param omit_trackhashes: A set of trackhashes to omit from the new tracklist.
:param limit: The maximum number of tracks to select.
"""
@@ -356,20 +378,5 @@ class MixesPlugin(Plugin):
return mixtracks
# if len(similar_artists) == 0:
# local_similar_artists = SimilarArtistTable.get_by_hash(artist["artisthash"])
# if local_similar_artists:
# artists = [a.artisthash for a in local_similar_artists.similar_artists]
# if len(artists) == 0:
# return []
# CHECKPOINT: I'M TIRED AF AND I NEED TO SLEEP
# The plan:
# Figure out which artists we should skip for the new tracklist
# these would be artists with a large number of tracks in the mix already
# Since the artisthashes are ordered by similarity score, we iterate from the start
# and go forward collecting tracks that aren't in the mix yet.
#
def get_mix_from_lastfm_data(self, artisthash: str, limit: int):
pass
-6
View File
@@ -1,20 +1,14 @@
from itertools import groupby
import json
from pprint import pprint
import random
from typing import Iterable
from app.lib.tagger import create_albums
from app.models import Album, Track
from app.store.artists import ArtistStore
from app.utils import flatten
from app.utils.auth import get_current_userid
from app.utils.customlist import CustomList
from app.utils.remove_duplicates import remove_duplicates
from ..utils.hashing import create_hash
from .tracks import TrackStore
from app.utils.progressbar import tqdm
ALBUM_LOAD_KEY = ""
+88 -16
View File
@@ -1,30 +1,102 @@
from abc import ABC
from dataclasses import dataclass
from typing import Any
from app.models.mix import Mix
from app.store.tracks import TrackStore
from app.utils.auth import get_current_userid
@dataclass
class HomepageEntry(ABC):
"""
Base class for all homepage entries.
items is a dict of userid to a dict of stuff.
"""
title: str
description: str
items: dict[int, dict[str, Any]]
def __init__(self, title: str, description: str):
self.title = title
self.description = description
def get_items(self, userid: int):
"""
Return usable items for the homepage.
"""
...
@dataclass
class MixHomepageEntry(HomepageEntry):
"""
A homepage entry for mixes.
self.items is a dict of userid to a dict of mixid to mix.
"""
items: dict[int, dict[str, Mix]]
def __init__(self, title: str, description: str):
super().__init__(title, description)
self.items = {}
def get_items(self, userid: int, limit: int | None = None):
items = []
for mix in self.items.get(userid, {}).values():
if limit and len(items) >= limit:
break
items.append(
{
"type": "mix",
"item": mix.to_dict(),
}
)
return {
"title": self.title,
"description": self.description,
"items": items,
}
class HomepageStore:
"""
Stores the homepage items.
"""
entries = {
"artist_mixes": {},
"artist_mixes": MixHomepageEntry(
title="Artist mixes for you",
description="Based on artists you have been listening to",
),
}
@classmethod
def set_artist_mixes(cls, mixes: list[Mix], userid: int = 1):
def set_mixes(cls, mixes: list[Mix], mixkey: str, userid: int | None = None):
idmap = {mix.id[1:]: mix for mix in mixes}
cls.entries["artist_mixes"][userid] = idmap
cls.entries[mixkey].items[userid or get_current_userid()] = idmap
@classmethod
def get_artist_mixes(cls):
return [
{
"type": "mix",
"item": mix.to_dict(),
}
for mix in cls.entries["artist_mixes"]
.get(get_current_userid(), {})
.values()
]
def get_mixes(cls, mixkey: str, limit: int | None = 9):
return cls.entries[mixkey].get_items(get_current_userid(), limit)
@classmethod
def get_mix(cls, mixtype: str, mixid: str):
return cls.entries[mixtype].get(get_current_userid(), {}).get(mixid).to_full_dict()
def get_mix(cls, mixkey: str, mixid: str):
mix = cls.entries[mixkey].items.get(get_current_userid(), {}).get(mixid)
return mix.to_full_dict() if mix else None
@classmethod
def get_mix_by_sourcehash(cls, sourcehash: str):
return next(
(
mix
for mix in cls.entries["artist_mixes"]
.items.get(get_current_userid(), {})
.values()
if mix.sourcehash == sourcehash
),
None,
)
+12 -8
View File
@@ -11,8 +11,10 @@ from app.store.tracks import TrackStore
from app.utils.dates import seconds_to_time_string
def get_artists_in_period(start_time: int | float, end_time: int | float):
scrobbles = ScrobbleTable.get_all_in_period(start_time, end_time)
def get_artists_in_period(
start_time: int | float, end_time: int | float, userid: int | None = None
):
scrobbles = ScrobbleTable.get_all_in_period(start_time, end_time, userid)
artists: Any = defaultdict(lambda: {"playcount": 0, "playduration": 0})
for scrobble in scrobbles:
@@ -33,8 +35,8 @@ def get_artists_in_period(start_time: int | float, end_time: int | float):
return sorted(artists, key=lambda x: x["playduration"], reverse=True)
def get_albums_in_period(start_time: int, end_time: int):
scrobbles = ScrobbleTable.get_all_in_period(start_time, end_time)
def get_albums_in_period(start_time: int, end_time: int, userid: int | None = None):
scrobbles = ScrobbleTable.get_all_in_period(start_time, end_time, userid)
albums: dict[str, Album] = {}
for scrobble in scrobbles:
@@ -60,8 +62,8 @@ def get_albums_in_period(start_time: int, end_time: int):
return list(albums.values())
def get_tracks_in_period(start_time: int, end_time: int):
scrobbles = ScrobbleTable.get_all_in_period(start_time, end_time)
def get_tracks_in_period(start_time: int, end_time: int, userid: int | None = None):
scrobbles = ScrobbleTable.get_all_in_period(start_time, end_time, userid)
tracks: dict[str, Track] = {}
duration = 0
@@ -160,12 +162,14 @@ def calculate_scrobble_trend(current_scrobbles: int, previous_scrobbles: int) ->
)
def calculate_new_artists(current_artists: List[dict[str, Any]], timestamp: int):
def calculate_new_artists(
current_artists: List[dict[str, Any]], timestamp: int, userid: int | None = None
):
"""
Calculate the number of new artists based on the current and all previous scrobbles.
"""
current_artists_set = set(artist["artisthash"] for artist in current_artists)
all_records = ScrobbleTable.get_all_in_period(0, timestamp)
all_records = ScrobbleTable.get_all_in_period(0, timestamp, userid)
trackhashes = set(record.trackhash for record in all_records)
previous_artists_set = set()