implement playlist store

This commit is contained in:
cwilvx
2024-07-15 20:11:18 +03:00
parent 88a72763df
commit c8c21dc01a
5 changed files with 127 additions and 81 deletions
+3
View File
@@ -61,3 +61,6 @@
- Paginate the following endpoints:
1. Folder tracks
2. Playlist tracks
- When you update a playlist, update the store as well!
+71 -80
View File
@@ -12,6 +12,7 @@ from flask_openapi3 import Tag
from flask_openapi3 import APIBlueprint, FileStorage
from app import models
from app.api.apischemas import GenericLimitSchema
from app.db.libdata import TrackTable
from app.db.userdata import PlaylistTable
from app.lib import playlistlib
@@ -20,6 +21,9 @@ from app.lib.home.recentlyadded import get_recently_added_playlist
from app.lib.home.recentlyplayed import get_recently_played_playlist
from app.models.playlist import Playlist
from app.serializers.playlist import serialize_for_card
from app.serializers.track import serialize_tracks
from app.store.playlists import PlaylistStore
from app.store.tracks import TrackStore
from app.utils.dates import create_new_date, date_string_to_time_passed
from app.utils.remove_duplicates import remove_duplicates
from app.settings import Paths
@@ -28,35 +32,6 @@ tag = Tag(name="Playlists", description="Get and manage playlists")
api = APIBlueprint("playlists", __name__, url_prefix="/playlists", abp_tags=[tag])
class SendAllPlaylistsQuery(BaseModel):
no_images: bool = Field(False, description="Whether to include images")
@api.get("")
def send_all_playlists(query: SendAllPlaylistsQuery):
"""
Gets all the playlists.
"""
playlists = PlaylistTable.get_all()
playlists = list(playlists)
for playlist in playlists:
if not query.no_images:
playlist.images = playlistlib.get_first_4_images(
trackhashes=playlist.trackhashes
)
playlist.images = [img["image"] for img in playlist.images]
playlist.clear_lists()
playlists.sort(
key=lambda p: datetime.strptime(p.last_updated, "%Y-%m-%d %H:%M:%S"),
reverse=True,
)
return {"data": playlists}
def insert_playlist(name: str, image: str = None):
playlist = {
"image": image,
@@ -79,31 +54,6 @@ def insert_playlist(name: str, image: str = None):
return None
class CreatePlaylistBody(BaseModel):
name: str = Field(..., description="The name of the playlist")
@api.post("/new")
def create_playlist(body: CreatePlaylistBody):
"""
New playlist
Creates a new playlist. Accepts POST method with a JSON body.
"""
# existing_playlist_count = PL.count_playlist_by_name(body.name)
exists = PlaylistTable.check_exists_by_name(body.name)
if exists:
return {"error": "Playlist already exists"}, 409
playlist = insert_playlist(body.name)
if playlist is None:
return {"error": "Playlist could not be created"}, 500
return {"playlist": playlist}, 201
def get_path_trackhashes(path: str):
"""
Returns a list of trackhashes in a folder.
@@ -130,9 +80,63 @@ def get_artist_trackhashes(artisthash: str):
return [t.trackhash for t in tracks]
def format_custom_playlist(playlist: models.Playlist, tracks: list[models.Track]):
duration = sum(t.duration for t in tracks)
playlist.duration = duration
return {
"info": serialize_for_card(playlist),
"tracks": tracks,
}
class SendAllPlaylistsQuery(BaseModel):
no_images: bool = Field(False, description="Whether to include images")
@api.get("")
def send_all_playlists(query: SendAllPlaylistsQuery):
"""
Gets all the playlists.
"""
playlists = PlaylistStore.get_flat_list()
playlists.sort(
key=lambda p: datetime.strptime(p.last_updated, "%Y-%m-%d %H:%M:%S"),
reverse=True,
)
return {"data": playlists}
class CreatePlaylistBody(BaseModel):
name: str = Field(..., description="The name of the playlist")
@api.post("/new")
def create_playlist(body: CreatePlaylistBody):
"""
New playlist
Creates a new playlist. Accepts POST method with a JSON body.
"""
exists = PlaylistTable.check_exists_by_name(body.name)
if exists:
return {"error": "Playlist already exists"}, 409
playlist = insert_playlist(body.name)
if playlist is None:
return {"error": "Playlist could not be created"}, 500
PlaylistStore.add_playlist(playlist)
return {"playlist": playlist}, 201
class PlaylistIDPath(BaseModel):
# INFO: playlistid string examples: "recentlyadded"
playlistid: int | str = Field(..., description="The ID of the playlist")
playlistid: str = Field(..., description="The ID of the playlist")
class AddItemToPlaylistBody(BaseModel):
@@ -170,19 +174,9 @@ def add_item_to_playlist(path: PlaylistIDPath, body: AddItemToPlaylistBody):
return {"msg": "Done"}, 200
class GetPlaylistQuery(BaseModel):
class GetPlaylistQuery(GenericLimitSchema):
no_tracks: bool = Field(False, description="Whether to include tracks")
def format_custom_playlist(playlist: models.Playlist, tracks: list[models.Track]):
duration = sum(t.duration for t in tracks)
playlist.duration = duration
return {
"info": serialize_for_card(playlist),
"tracks": tracks,
}
start: int = Field(0, description="The start index of the tracks")
@api.get("/<playlistid>")
@@ -206,24 +200,22 @@ def get_playlist(path: PlaylistIDPath, query: GetPlaylistQuery):
playlist, tracks = handler()
return format_custom_playlist(playlist, tracks)
playlist = PlaylistTable.get_by_id(playlistid)
entry = PlaylistStore.playlistmap.get(playlistid)
if playlist is None:
if entry is None:
return {"msg": "Playlist not found"}, 404
tracks = TrackTable.get_tracks_by_trackhashes(playlist.trackhashes)
playlist = entry.playlist
tracks = PlaylistStore.get_playlist_tracks(playlistid, query.start, query.limit)
duration = sum(t.duration for t in tracks)
playlist.last_updated = date_string_to_time_passed(playlist.last_updated)
playlist._last_updated = date_string_to_time_passed(playlist.last_updated)
playlist.duration = duration
if not playlist.has_image:
playlist.images = playlistlib.get_first_4_images(tracks)
playlist.clear_lists()
return {"info": playlist, "tracks": tracks if not no_tracks else []}
return {
"info": playlist,
"tracks": serialize_tracks(tracks) if not no_tracks else [],
}
class UpdatePlaylistForm(BaseModel):
@@ -340,7 +332,7 @@ def remove_playlist(path: PlaylistIDPath):
Delete playlist
"""
PlaylistTable.remove_one(path.playlistid)
PlaylistStore.playlistmap.pop(path.playlistid, None)
return {"msg": "Done"}, 200
@@ -426,7 +418,6 @@ def save_item_as_playlist(body: SavePlaylistAsItemBody):
img, str(playlist.id), "image/webp", filename=filename
)
# PL.add_tracks_to_playlist(playlist.id, trackhashes)
PlaylistTable.append_to_playlist(playlist.id, trackhashes)
playlist.count = len(trackhashes)
+1 -1
View File
@@ -1,5 +1,4 @@
import dataclasses
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any
@@ -20,6 +19,7 @@ class Playlist:
trackhashes: list[str] = dataclasses.field(default_factory=list)
extra: dict[str, Any] = dataclasses.field(default_factory=dict)
_last_updated: str = ""
userid: int | None = None
thumb: str = ""
count: int = 0
+2
View File
@@ -8,6 +8,7 @@ from app.setup.sqlite import run_migrations, setup_sqlite
from app.store.albums import AlbumStore
from app.store.artists import ArtistStore
from app.store.folder import FolderStore
from app.store.playlists import PlaylistStore
from app.store.tracks import TrackStore
from app.utils.generators import get_random_str
from app.config import UserConfig
@@ -49,4 +50,5 @@ def load_into_mem():
TrackStore.load_all_tracks(get_random_str())
AlbumStore.load_albums('a')
ArtistStore.load_artists('a')
PlaylistStore.load_playlists()
FolderStore.load_filepaths()
+50
View File
@@ -0,0 +1,50 @@
from app.db.userdata import PlaylistTable
from app.lib.playlistlib import get_first_4_images
from app.models.playlist import Playlist
from app.store.tracks import TrackStore
class PlaylistEntry:
def __init__(self, playlist: Playlist) -> None:
self.playlist = playlist
self.trackhashes: list[str] = playlist.trackhashes
self.playlist.clear_lists()
if not playlist.has_image:
self.playlist.images = get_first_4_images(
TrackStore.get_tracks_by_trackhashes(self.trackhashes)
)
class PlaylistStore:
playlistmap: dict[str, PlaylistEntry] = {}
@classmethod
def load_playlists(cls):
"""
Loads all playlists into the store.
"""
cls.playlistmap = {str(p.id): PlaylistEntry(p) for p in PlaylistTable.get_all()}
print(cls.playlistmap)
@classmethod
def get_playlist_tracks(cls, playlist_id: str, start: int, limit: int):
"""
Returns the trackhashes for a playlist.
"""
entry = cls.playlistmap.get(playlist_id)
if entry is None:
return []
return TrackStore.get_tracks_by_trackhashes(
entry.trackhashes[start : start + limit]
)
@classmethod
def get_flat_list(cls):
return [p.playlist for p in cls.playlistmap.values()]
@classmethod
def add_playlist(cls, playlist: Playlist):
cls.playlistmap[str(playlist.id)] = PlaylistEntry(playlist)