diff --git a/app/db/libdata.py b/app/db/libdata.py index c5f76a59..6e68d6d7 100644 --- a/app/db/libdata.py +++ b/app/db/libdata.py @@ -1,107 +1,97 @@ -from app.db import ( - Base as MasterBase, -) -from app.db.utils import ( - album_to_dataclass, - albums_to_dataclasses, - artist_to_dataclass, - artists_to_dataclasses, - track_to_dataclass, - tracks_to_dataclasses, -) -from app.models import Album as AlbumModel -from app.utils.remove_duplicates import remove_duplicates +from app.db import Base +from app.db.utils import tracks_to_dataclasses from app.db.engine import DbEngine -from sqlalchemy import JSON, Integer, String, delete, select, update -from sqlalchemy.orm import Mapped, mapped_column, DeclarativeBase +from sqlalchemy import JSON, Integer, String, delete, select +from sqlalchemy.orm import Mapped, mapped_column -from typing import Any, Iterable, Optional +from typing import Any, Optional -def create_all(): - """ - Create all the tables defined in this file. +# def create_all(): +# """ +# Create all the tables defined in this file. - NOTE: We need this function because the MasterBase does not collect - the tables defined here (as they are grand-children of the MasterBase) - """ - Base.metadata.create_all(DbEngine.engine) +# NOTE: We need this function because the MasterBase does not collect +# the tables defined here (as they are grand-children of the MasterBase) +# """ +# Base.metadata.create_all(DbEngine.engine) -class Base(MasterBase, DeclarativeBase): - @classmethod - def get_all_hashes(cls, create_date: int | None = None): - with DbEngine.manager() as conn: - if create_date: - if cls.__tablename__ == "track": - stmt = select(TrackTable.trackhash).where( - cls.last_mod < create_date - ) - elif cls.__tablename__ == "album": - stmt = select(AlbumTable.albumhash).where( - cls.created_date < create_date - ) - elif cls.__tablename__ == "artist": - stmt = select(ArtistTable.artisthash).where( - cls.created_date < create_date - ) - else: - if cls.__tablename__ == "track": - stmt = select(TrackTable.trackhash) - elif cls.__tablename__ == "album": - stmt = select(AlbumTable.albumhash) - elif cls.__tablename__ == "artist": - stmt = select(ArtistTable.artisthash) +# class Base(MasterBase, DeclarativeBase): +# pass +# @classmethod +# def get_all_hashes(cls, create_date: int | None = None): +# with DbEngine.manager() as conn: +# if create_date: +# if cls.__tablename__ == "track": +# stmt = select(TrackTable.trackhash).where( +# cls.last_mod < create_date +# ) +# elif cls.__tablename__ == "album": +# stmt = select(AlbumTable.albumhash).where( +# cls.created_date < create_date +# ) +# elif cls.__tablename__ == "artist": +# stmt = select(ArtistTable.artisthash).where( +# cls.created_date < create_date +# ) +# else: +# if cls.__tablename__ == "track": +# stmt = select(TrackTable.trackhash) +# elif cls.__tablename__ == "album": +# stmt = select(AlbumTable.albumhash) +# elif cls.__tablename__ == "artist": +# stmt = select(ArtistTable.artisthash) - result = conn.execute(stmt) - return {row[0] for row in result.fetchall()} +# result = conn.execute(stmt) +# return {row[0] for row in result.fetchall()} - @classmethod - def set_is_favorite(cls, hash: str, is_favorite: bool): - """ - Set the 'is_favorite' flag for a specific hash. +# @classmethod +# def set_is_favorite(cls, hash: str, is_favorite: bool): +# """ +# Set the 'is_favorite' flag for a specific hash. - Args: - hash (str): The hash value. - is_favorite (bool): The value of the 'is_favorite' flag. - """ - with DbEngine.manager(commit=True) as conn: - if cls.__tablename__ == "track": - stmt = ( - update(cls) - .where(TrackTable.trackhash == hash) - .values(is_favorite=is_favorite) - ) - elif cls.__tablename__ == "album": - stmt = ( - update(cls) - .where(AlbumTable.albumhash == hash) - .values(is_favorite=is_favorite) - ) - elif cls.__tablename__ == "artist": - stmt = ( - update(cls) - .where(ArtistTable.artisthash == hash) - .values(is_favorite=is_favorite) - ) +# Args: +# hash (str): The hash value. +# is_favorite (bool): The value of the 'is_favorite' flag. +# """ +# with DbEngine.manager(commit=True) as conn: +# if cls.__tablename__ == "track": +# stmt = ( +# update(cls) +# .where(TrackTable.trackhash == hash) +# .values(is_favorite=is_favorite) +# ) +# elif cls.__tablename__ == "album": +# stmt = ( +# update(cls) +# .where(AlbumTable.albumhash == hash) +# .values(is_favorite=is_favorite) +# ) +# elif cls.__tablename__ == "artist": +# stmt = ( +# update(cls) +# .where(ArtistTable.artisthash == hash) +# .values(is_favorite=is_favorite) +# ) - conn.execute(stmt) +# conn.execute(stmt) - @classmethod - def increment_scrobblecount( - cls, table: Any, field: Any, hash: str, duration: int, timestamp: int - ): - cls.execute( - update(table) - .where(field == hash) - .values( - playcount=table.playcount + 1, - playduration=table.playduration + duration, - lastplayed=timestamp, - ), - commit=True, - ) +# @classmethod +# def increment_scrobblecount( +# cls, table: Any, field: Any, hash: str, duration: int, timestamp: int +# ): +# cls.execute( +# update(table) +# .where(field == hash) +# .values( +# playcount=table.playcount + 1, +# playduration=table.playduration + duration, +# lastplayed=timestamp, +# ), +# commit=True, +# ) class TrackTable(Base): @@ -151,44 +141,44 @@ class TrackTable(Base): ) return tracks_to_dataclasses(result.fetchall()) - @classmethod - def get_tracks_by_albumhash(cls, albumhash: str): - with DbEngine.manager() as conn: - result = conn.execute( - select(TrackTable).where(TrackTable.albumhash == albumhash) - ) - tracks = tracks_to_dataclasses(result.fetchall()) - return remove_duplicates(tracks, is_album_tracks=True) + # @classmethod + # def get_tracks_by_albumhash(cls, albumhash: str): + # with DbEngine.manager() as conn: + # result = conn.execute( + # select(TrackTable).where(TrackTable.albumhash == albumhash) + # ) + # tracks = tracks_to_dataclasses(result.fetchall()) + # return remove_duplicates(tracks, is_album_tracks=True) - @classmethod - def get_track_by_trackhash(cls, hash: str, filepath: str = ""): - with DbEngine.manager() as conn: - if filepath: - result = conn.execute( - select(TrackTable) - .where( - (TrackTable.trackhash == hash) - & (TrackTable.filepath == filepath), - ) - .order_by(TrackTable.bitrate.desc()) - ) - else: - result = conn.execute( - select(TrackTable).where(TrackTable.trackhash == hash) - ) + # @classmethod + # def get_track_by_trackhash(cls, hash: str, filepath: str = ""): + # with DbEngine.manager() as conn: + # if filepath: + # result = conn.execute( + # select(TrackTable) + # .where( + # (TrackTable.trackhash == hash) + # & (TrackTable.filepath == filepath), + # ) + # .order_by(TrackTable.bitrate.desc()) + # ) + # else: + # result = conn.execute( + # select(TrackTable).where(TrackTable.trackhash == hash) + # ) - track = result.fetchone() + # track = result.fetchone() - if track: - return track_to_dataclass(track) + # if track: + # return track_to_dataclass(track) - @classmethod - def get_tracks_by_artisthash(cls, artisthash: str): - with DbEngine.manager() as conn: - result = conn.execute( - select(TrackTable).where(TrackTable.artists.contains(artisthash)) - ) - return tracks_to_dataclasses(result.fetchall()) + # @classmethod + # def get_tracks_by_artisthash(cls, artisthash: str): + # with DbEngine.manager() as conn: + # result = conn.execute( + # select(TrackTable).where(TrackTable.artists.contains(artisthash)) + # ) + # return tracks_to_dataclasses(result.fetchall()) @classmethod def get_tracks_in_path(cls, path: str): @@ -200,55 +190,55 @@ class TrackTable(Base): ) return tracks_to_dataclasses(result.fetchall()) - @classmethod - def get_tracks_by_trackhashes(cls, hashes: Iterable[str], limit: int | None = None): - with DbEngine.manager() as conn: - result = conn.execute( - select(TrackTable) - .where(TrackTable.trackhash.in_(hashes)) - .group_by(TrackTable.trackhash) - .limit(limit) - ) - tracks = tracks_to_dataclasses(result.fetchall()) + # @classmethod + # def get_tracks_by_trackhashes(cls, hashes: Iterable[str], limit: int | None = None): + # with DbEngine.manager() as conn: + # result = conn.execute( + # select(TrackTable) + # .where(TrackTable.trackhash.in_(hashes)) + # .group_by(TrackTable.trackhash) + # .limit(limit) + # ) + # tracks = tracks_to_dataclasses(result.fetchall()) - # order the tracks in the same order as the hashes - if type(hashes) == list: - return sorted(tracks, key=lambda x: hashes.index(x.trackhash)) + # # order the tracks in the same order as the hashes + # if type(hashes) == list: + # return sorted(tracks, key=lambda x: hashes.index(x.trackhash)) - return tracks + # return tracks + + # @classmethod + # def get_recently_added(cls, start: int, limit: int): + # with DbEngine.manager() as conn: + # result = conn.execute( + # select(TrackTable) + # .order_by(TrackTable.last_mod.desc()) + # .offset(start) + # .limit(limit) + # ) + + # return tracks_to_dataclasses(result.fetchall()) @classmethod - def get_recently_added(cls, start: int, limit: int): - with DbEngine.manager() as conn: - result = conn.execute( - select(TrackTable) - .order_by(TrackTable.last_mod.desc()) - .offset(start) - .limit(limit) - ) - - return tracks_to_dataclasses(result.fetchall()) - - @classmethod - def get_recently_played(cls, limit: int): - result = cls.execute( - select(cls) - .group_by(cls.trackhash) - .order_by(cls.lastplayed.desc()) - .limit(limit) - ) - return tracks_to_dataclasses(result.fetchall()) + # def get_recently_played(cls, limit: int): + # result = cls.execute( + # select(cls) + # .group_by(cls.trackhash) + # .order_by(cls.lastplayed.desc()) + # .limit(limit) + # ) + # return tracks_to_dataclasses(result.fetchall()) @classmethod def remove_tracks_by_filepaths(cls, filepaths: set[str]): with DbEngine.manager(commit=True) as conn: conn.execute(delete(TrackTable).where(TrackTable.filepath.in_(filepaths))) - @classmethod - def increment_playcount(cls, trackhash: str, duration: int, timestamp: int): - cls.increment_scrobblecount( - TrackTable, TrackTable.trackhash, trackhash, duration, timestamp - ) + # @classmethod + # def increment_playcount(cls, trackhash: str, duration: int, timestamp: int): + # cls.increment_scrobblecount( + # TrackTable, TrackTable.trackhash, trackhash, duration, timestamp + # ) # @classmethod # def update_artist_separators(cls, separators: set[str]): @@ -264,166 +254,166 @@ class TrackTable(Base): # ) -class AlbumTable(Base): - __tablename__ = "album" +# class AlbumTable(Base): +# __tablename__ = "album" - id: Mapped[int] = mapped_column(primary_key=True) - albumartists: Mapped[list[dict[str, str]]] = mapped_column(JSON(), index=True) - artisthashes: Mapped[list[str]] = mapped_column(JSON(), index=True) - albumhash: Mapped[str] = mapped_column(String(), unique=True, index=True) - base_title: Mapped[str] = mapped_column(String()) - color: Mapped[Optional[str]] = mapped_column(String()) - created_date: Mapped[int] = mapped_column(Integer()) - date: Mapped[int] = mapped_column(Integer()) - duration: Mapped[int] = mapped_column(Integer()) - genrehashes: Mapped[list[str]] = mapped_column(JSON(), nullable=True, index=True) - genres: Mapped[str] = mapped_column(JSON()) - og_title: Mapped[str] = mapped_column(String()) - title: Mapped[str] = mapped_column(String()) - trackcount: Mapped[int] = mapped_column(Integer()) - lastplayed: Mapped[int] = mapped_column(Integer(), default=0) - playcount: Mapped[int] = mapped_column(Integer(), default=0) - playduration: Mapped[int] = mapped_column(Integer(), default=0) - extra: Mapped[Optional[dict[str, Any]]] = mapped_column( - JSON(), default_factory=dict - ) +# id: Mapped[int] = mapped_column(primary_key=True) +# albumartists: Mapped[list[dict[str, str]]] = mapped_column(JSON(), index=True) +# artisthashes: Mapped[list[str]] = mapped_column(JSON(), index=True) +# albumhash: Mapped[str] = mapped_column(String(), unique=True, index=True) +# base_title: Mapped[str] = mapped_column(String()) +# color: Mapped[Optional[str]] = mapped_column(String()) +# created_date: Mapped[int] = mapped_column(Integer()) +# date: Mapped[int] = mapped_column(Integer()) +# duration: Mapped[int] = mapped_column(Integer()) +# genrehashes: Mapped[list[str]] = mapped_column(JSON(), nullable=True, index=True) +# genres: Mapped[str] = mapped_column(JSON()) +# og_title: Mapped[str] = mapped_column(String()) +# title: Mapped[str] = mapped_column(String()) +# trackcount: Mapped[int] = mapped_column(Integer()) +# lastplayed: Mapped[int] = mapped_column(Integer(), default=0) +# playcount: Mapped[int] = mapped_column(Integer(), default=0) +# playduration: Mapped[int] = mapped_column(Integer(), default=0) +# extra: Mapped[Optional[dict[str, Any]]] = mapped_column( +# JSON(), default_factory=dict +# ) - @classmethod - def get_all(cls): - with DbEngine.manager() as conn: - result = conn.execute(select(AlbumTable)) - all = result.fetchall() - return albums_to_dataclasses(all) +# @classmethod +# def get_all(cls): +# with DbEngine.manager() as conn: +# result = conn.execute(select(AlbumTable)) +# all = result.fetchall() +# return albums_to_dataclasses(all) - @classmethod - def get_album_by_albumhash(cls, hash: str): - with DbEngine.manager() as conn: - result = conn.execute( - select(AlbumTable).where(AlbumTable.albumhash == hash) - ) - album = result.fetchone() +# @classmethod +# def get_album_by_albumhash(cls, hash: str): +# with DbEngine.manager() as conn: +# result = conn.execute( +# select(AlbumTable).where(AlbumTable.albumhash == hash) +# ) +# album = result.fetchone() - if album: - return album_to_dataclass(album) +# if album: +# return album_to_dataclass(album) - @classmethod - def get_albums_by_albumhashes(cls, hashes: Iterable[str], limit: int | None = None): - with DbEngine.manager() as conn: - result = conn.execute( - select(AlbumTable).where(AlbumTable.albumhash.in_(hashes)).limit(limit) - ) - albums = albums_to_dataclasses(result.fetchall()) +# @classmethod +# def get_albums_by_albumhashes(cls, hashes: Iterable[str], limit: int | None = None): +# with DbEngine.manager() as conn: +# result = conn.execute( +# select(AlbumTable).where(AlbumTable.albumhash.in_(hashes)).limit(limit) +# ) +# albums = albums_to_dataclasses(result.fetchall()) - # order the albums in the same order as the hashes - if type(hashes) == list: - return sorted(albums, key=lambda x: hashes.index(x.albumhash)) +# # order the albums in the same order as the hashes +# if type(hashes) == list: +# return sorted(albums, key=lambda x: hashes.index(x.albumhash)) - return albums +# return albums - @classmethod - def get_albums_by_artisthashes(cls, artisthashes: list[str]): - with DbEngine.manager() as conn: - albums: dict[str, list[AlbumModel]] = {} +# @classmethod +# def get_albums_by_artisthashes(cls, artisthashes: list[str]): +# with DbEngine.manager() as conn: +# albums: dict[str, list[AlbumModel]] = {} - for artist in artisthashes: - result = conn.execute( - select(AlbumTable).where(AlbumTable.artisthashes.contains(artist)) - ) - albums[artist] = albums_to_dataclasses(result.fetchall()) +# for artist in artisthashes: +# result = conn.execute( +# select(AlbumTable).where(AlbumTable.artisthashes.contains(artist)) +# ) +# albums[artist] = albums_to_dataclasses(result.fetchall()) - return albums +# return albums - @classmethod - def get_albums_by_base_title(cls, base_title: str): - with DbEngine.manager() as conn: - result = conn.execute( - select(AlbumTable).where(AlbumTable.base_title == base_title) - ) - return albums_to_dataclasses(result.fetchall()) +# @classmethod +# def get_albums_by_base_title(cls, base_title: str): +# with DbEngine.manager() as conn: +# result = conn.execute( +# select(AlbumTable).where(AlbumTable.base_title == base_title) +# ) +# return albums_to_dataclasses(result.fetchall()) - @classmethod - def get_albums_by_artisthash(cls, artisthash: str): - with DbEngine.manager() as conn: - result = conn.execute( - select(AlbumTable).where(AlbumTable.artisthashes.contains(artisthash)) - ) - return albums_to_dataclasses(result.all()) +# @classmethod +# def get_albums_by_artisthash(cls, artisthash: str): +# with DbEngine.manager() as conn: +# result = conn.execute( +# select(AlbumTable).where(AlbumTable.artisthashes.contains(artisthash)) +# ) +# return albums_to_dataclasses(result.all()) - @classmethod - def increment_playcount(cls, albumhash: str, duration: int, timestamp: int): - return cls.increment_scrobblecount( - AlbumTable, AlbumTable.albumhash, albumhash, duration, timestamp - ) +# @classmethod +# def increment_playcount(cls, albumhash: str, duration: int, timestamp: int): +# return cls.increment_scrobblecount( +# AlbumTable, AlbumTable.albumhash, albumhash, duration, timestamp +# ) -class ArtistTable(Base): - __tablename__ = "artist" +# class ArtistTable(Base): +# __tablename__ = "artist" - id: Mapped[int] = mapped_column(primary_key=True) - albumcount: Mapped[int] = mapped_column(Integer()) - artisthash: Mapped[str] = mapped_column(String(), unique=True, index=True) - created_date: Mapped[int] = mapped_column(Integer()) - date: Mapped[int] = mapped_column(Integer()) - duration: Mapped[int] = mapped_column(Integer()) - genrehashes: Mapped[list[str]] = mapped_column(JSON(), nullable=True, index=True) - genres: Mapped[str] = mapped_column(JSON()) - name: Mapped[str] = mapped_column(String(), index=True) - trackcount: Mapped[int] = mapped_column(Integer()) - lastplayed: Mapped[int] = mapped_column(Integer(), default=0) - playcount: Mapped[int] = mapped_column(Integer(), default=0) - playduration: Mapped[int] = mapped_column(Integer(), default=0) - extra: Mapped[Optional[dict[str, Any]]] = mapped_column( - JSON(), default_factory=dict - ) +# id: Mapped[int] = mapped_column(primary_key=True) +# albumcount: Mapped[int] = mapped_column(Integer()) +# artisthash: Mapped[str] = mapped_column(String(), unique=True, index=True) +# created_date: Mapped[int] = mapped_column(Integer()) +# date: Mapped[int] = mapped_column(Integer()) +# duration: Mapped[int] = mapped_column(Integer()) +# genrehashes: Mapped[list[str]] = mapped_column(JSON(), nullable=True, index=True) +# genres: Mapped[str] = mapped_column(JSON()) +# name: Mapped[str] = mapped_column(String(), index=True) +# trackcount: Mapped[int] = mapped_column(Integer()) +# lastplayed: Mapped[int] = mapped_column(Integer(), default=0) +# playcount: Mapped[int] = mapped_column(Integer(), default=0) +# playduration: Mapped[int] = mapped_column(Integer(), default=0) +# extra: Mapped[Optional[dict[str, Any]]] = mapped_column( +# JSON(), default_factory=dict +# ) - @classmethod - def get_all(cls): - with DbEngine.manager() as conn: - result = conn.execute(select(cls)) - all = result.fetchall() - return artists_to_dataclasses(all) +# @classmethod +# def get_all(cls): +# with DbEngine.manager() as conn: +# result = conn.execute(select(cls)) +# all = result.fetchall() +# return artists_to_dataclasses(all) - @classmethod - def get_artist_by_hash(cls, artisthash: str): - with DbEngine.manager() as conn: - result = conn.execute( - select(ArtistTable).where(ArtistTable.artisthash == artisthash) - ) - return artist_to_dataclass(result.fetchone()) +# @classmethod +# def get_artist_by_hash(cls, artisthash: str): +# with DbEngine.manager() as conn: +# result = conn.execute( +# select(ArtistTable).where(ArtistTable.artisthash == artisthash) +# ) +# return artist_to_dataclass(result.fetchone()) - @classmethod - def get_artisthashes_not_in(cls, artisthashes: list[str]): - with DbEngine.manager() as conn: - result = conn.execute( - select(ArtistTable.artisthash, ArtistTable.name).where( - ~ArtistTable.artisthash.in_(artisthashes) - ) - ) - return [{"artisthash": row[0], "name": row[1]} for row in result.fetchall()] +# @classmethod +# def get_artisthashes_not_in(cls, artisthashes: list[str]): +# with DbEngine.manager() as conn: +# result = conn.execute( +# select(ArtistTable.artisthash, ArtistTable.name).where( +# ~ArtistTable.artisthash.in_(artisthashes) +# ) +# ) +# return [{"artisthash": row[0], "name": row[1]} for row in result.fetchall()] - @classmethod - def get_artists_by_artisthashes( - cls, hashes: Iterable[str], limit: int | None = None - ): - with DbEngine.manager() as conn: - result = conn.execute( - select(ArtistTable) - .where(ArtistTable.artisthash.in_(hashes)) - .limit(limit) - ) - return artists_to_dataclasses(result.fetchall()) +# @classmethod +# def get_artists_by_artisthashes( +# cls, hashes: Iterable[str], limit: int | None = None +# ): +# with DbEngine.manager() as conn: +# result = conn.execute( +# select(ArtistTable) +# .where(ArtistTable.artisthash.in_(hashes)) +# .limit(limit) +# ) +# return artists_to_dataclasses(result.fetchall()) - @classmethod - def increment_playcount( - cls, artisthashes: list[str], duration: int, timestamp: int - ): - cls.execute( - update(cls) - .where(ArtistTable.artisthash.in_(artisthashes)) - .values( - playcount=ArtistTable.playcount + 1, - playduration=ArtistTable.playduration + duration, - lastplayed=timestamp, - ), - commit=True, - ) +# @classmethod +# def increment_playcount( +# cls, artisthashes: list[str], duration: int, timestamp: int +# ): +# cls.execute( +# update(cls) +# .where(ArtistTable.artisthash.in_(artisthashes)) +# .values( +# playcount=ArtistTable.playcount + 1, +# playduration=ArtistTable.playduration + duration, +# lastplayed=timestamp, +# ), +# commit=True, +# ) diff --git a/app/lib/artistlib.py b/app/lib/artistlib.py index f4664e84..676e983b 100644 --- a/app/lib/artistlib.py +++ b/app/lib/artistlib.py @@ -1,4 +1,5 @@ import os +import time import urllib from concurrent.futures import ThreadPoolExecutor from io import BytesIO @@ -10,7 +11,10 @@ from requests.exceptions import ConnectionError as RequestConnectionError from requests.exceptions import ReadTimeout from app import settings -from app.db.libdata import ArtistTable +from app.models.artist import Artist +from app.store.artists import ArtistStore + +# from app.db.libdata import ArtistTable # from app.store import artists as artist_store # from app.store.tracks import TrackStore @@ -28,27 +32,51 @@ def get_artist_image_link(artist: str): """ Returns an artist image url. """ + response: requests.Response | None = None - try: + def make_request(): query = urllib.parse.quote(artist) # type: ignore - url = f"https://api.deezer.com/search/artist?q={query}" - response = requests.get(url, timeout=30) + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", + "Accept": "application/json, text/plain, */*", + "Accept-Language": "en-US,en;q=0.9", + "Referer": "https://www.deezer.com/", + "Origin": "https://www.deezer.com", + } + return requests.get(url, headers=headers, timeout=30) + + for attempt in range(5): try: - data = response.json() - except requests.exceptions.JSONDecodeError: + response = make_request() + try: + data = response.json() + except requests.exceptions.JSONDecodeError: + return None + + for res in data["data"]: + res_hash = create_hash(res["name"], decode=True) + artist_hash = create_hash(artist, decode=True) + + if res_hash == artist_hash: + return str(res["picture_big"]) + return None + except (RequestConnectionError, ReadTimeout, IndexError, KeyError): + if attempt == 4: + print("Failed to get artist image link ") - for res in data["data"]: - res_hash = create_hash(res["name"], decode=True) - artist_hash = create_hash(artist, decode=True) + if attempt <= 4: + time.sleep(10) + else: + return None - if res_hash == artist_hash: - return str(res["picture_big"]) + # except (IndexError, KeyError): + # print(f"Encountered index/key error in attempt {attempt}") + # if response is not None: + # print(response.headers) - return None - except (RequestConnectionError, ReadTimeout, IndexError, KeyError): - return None + # return None # TODO: Move network calls to utils/network.py @@ -75,11 +103,19 @@ class DownloadImage: def download(url: str) -> Image.Image | None: """ Downloads the image from the url. + Retries after 10 seconds on a connection error. """ - try: - return Image.open(BytesIO(requests.get(url, timeout=10).content)) - except UnidentifiedImageError: - return None + for attempt in range(2): + try: + response = requests.get(url, timeout=10) + return Image.open(BytesIO(response.content)) + except (RequestConnectionError, requests.Timeout, ReadTimeout): + if attempt == 0: + time.sleep(10) + else: + return None + except UnidentifiedImageError: + return None @staticmethod def save_img(img: Image.Image, entries: list[tuple[Path, int | None]]): @@ -107,7 +143,13 @@ class CheckArtistImages: # read all files in the artist image folder path = settings.Paths.get_sm_artist_img_path() processed = [path.replace(".webp", "") for path in os.listdir(path)] - unprocessed = ArtistTable.get_artisthashes_not_in(processed) + print(f"Found {len(processed)} processed artist images") + + unprocessed = [ + a for a in ArtistStore.get_flat_list() if a.artisthash not in processed + ] + + print(f"Downloading {len(unprocessed)} artist images") key_artist_map = ((instance_key, artist) for artist in unprocessed) with ThreadPoolExecutor(max_workers=14) as executor: @@ -122,7 +164,7 @@ class CheckArtistImages: list(res) @staticmethod - def download_image(_map: tuple[str, dict[str, str]]): + def download_image(_map: tuple[str, Artist]): """ Checks if an artist image exists and downloads it if not. @@ -134,14 +176,13 @@ class CheckArtistImages: return img_path = ( - Path(settings.Paths.get_sm_artist_img_path()) - / f"{artist['artisthash']}.webp" + Path(settings.Paths.get_sm_artist_img_path()) / f"{artist.artisthash}.webp" ) if img_path.exists(): return - url = get_artist_image_link(artist["name"]) + url = get_artist_image_link(artist.name) if url is not None: - return DownloadImage(url, name=f"{artist['artisthash']}.webp") + return DownloadImage(url, name=f"{artist.artisthash}.webp") diff --git a/app/settings.py b/app/settings.py index 1320454e..3577ca4f 100644 --- a/app/settings.py +++ b/app/settings.py @@ -141,7 +141,7 @@ SUPPORTED_FILES = tuple(f".{file}" for file in FILES) # ===== SQLite ===== class DbPaths: - APP_DB_NAME = "swing.db" + APP_DB_NAME = "swingmusic.db" USER_DATA_DB_NAME = "userdata.db" @classmethod diff --git a/app/setup/sqlite.py b/app/setup/sqlite.py index f3044192..33ef1c89 100644 --- a/app/setup/sqlite.py +++ b/app/setup/sqlite.py @@ -10,7 +10,7 @@ from app.settings import DbPaths from app.db.engine import DbEngine from app.db import create_all_tables -from app.db.libdata import create_all as create_user_tables +# from app.db.libdata import create_all as create_user_tables def run_migrations(): @@ -32,7 +32,7 @@ def setup_sqlite(): ) create_all_tables() - create_user_tables() + # create_user_tables() if not UserTable.get_all(): UserTable.insert_default_user()