attempt fix: downloading artist images

This commit is contained in:
cwilvx
2024-08-31 13:08:14 +03:00
parent 41db162869
commit bbc905c585
4 changed files with 364 additions and 333 deletions
+296 -306
View File
@@ -1,107 +1,97 @@
from app.db import ( from app.db import Base
Base as MasterBase, from app.db.utils import tracks_to_dataclasses
)
from app.db.utils import (
album_to_dataclass,
albums_to_dataclasses,
artist_to_dataclass,
artists_to_dataclasses,
track_to_dataclass,
tracks_to_dataclasses,
)
from app.models import Album as AlbumModel
from app.utils.remove_duplicates import remove_duplicates
from app.db.engine import DbEngine from app.db.engine import DbEngine
from sqlalchemy import JSON, Integer, String, delete, select, update from sqlalchemy import JSON, Integer, String, delete, select
from sqlalchemy.orm import Mapped, mapped_column, DeclarativeBase from sqlalchemy.orm import Mapped, mapped_column
from typing import Any, Iterable, Optional from typing import Any, Optional
def create_all(): # def create_all():
""" # """
Create all the tables defined in this file. # Create all the tables defined in this file.
NOTE: We need this function because the MasterBase does not collect # NOTE: We need this function because the MasterBase does not collect
the tables defined here (as they are grand-children of the MasterBase) # the tables defined here (as they are grand-children of the MasterBase)
""" # """
Base.metadata.create_all(DbEngine.engine) # Base.metadata.create_all(DbEngine.engine)
class Base(MasterBase, DeclarativeBase): # class Base(MasterBase, DeclarativeBase):
@classmethod # pass
def get_all_hashes(cls, create_date: int | None = None): # @classmethod
with DbEngine.manager() as conn: # def get_all_hashes(cls, create_date: int | None = None):
if create_date: # with DbEngine.manager() as conn:
if cls.__tablename__ == "track": # if create_date:
stmt = select(TrackTable.trackhash).where( # if cls.__tablename__ == "track":
cls.last_mod < create_date # stmt = select(TrackTable.trackhash).where(
) # cls.last_mod < create_date
elif cls.__tablename__ == "album": # )
stmt = select(AlbumTable.albumhash).where( # elif cls.__tablename__ == "album":
cls.created_date < create_date # stmt = select(AlbumTable.albumhash).where(
) # cls.created_date < create_date
elif cls.__tablename__ == "artist": # )
stmt = select(ArtistTable.artisthash).where( # elif cls.__tablename__ == "artist":
cls.created_date < create_date # stmt = select(ArtistTable.artisthash).where(
) # cls.created_date < create_date
else: # )
if cls.__tablename__ == "track": # else:
stmt = select(TrackTable.trackhash) # if cls.__tablename__ == "track":
elif cls.__tablename__ == "album": # stmt = select(TrackTable.trackhash)
stmt = select(AlbumTable.albumhash) # elif cls.__tablename__ == "album":
elif cls.__tablename__ == "artist": # stmt = select(AlbumTable.albumhash)
stmt = select(ArtistTable.artisthash) # elif cls.__tablename__ == "artist":
# stmt = select(ArtistTable.artisthash)
result = conn.execute(stmt) # result = conn.execute(stmt)
return {row[0] for row in result.fetchall()} # return {row[0] for row in result.fetchall()}
@classmethod # @classmethod
def set_is_favorite(cls, hash: str, is_favorite: bool): # def set_is_favorite(cls, hash: str, is_favorite: bool):
""" # """
Set the 'is_favorite' flag for a specific hash. # Set the 'is_favorite' flag for a specific hash.
Args: # Args:
hash (str): The hash value. # hash (str): The hash value.
is_favorite (bool): The value of the 'is_favorite' flag. # is_favorite (bool): The value of the 'is_favorite' flag.
""" # """
with DbEngine.manager(commit=True) as conn: # with DbEngine.manager(commit=True) as conn:
if cls.__tablename__ == "track": # if cls.__tablename__ == "track":
stmt = ( # stmt = (
update(cls) # update(cls)
.where(TrackTable.trackhash == hash) # .where(TrackTable.trackhash == hash)
.values(is_favorite=is_favorite) # .values(is_favorite=is_favorite)
) # )
elif cls.__tablename__ == "album": # elif cls.__tablename__ == "album":
stmt = ( # stmt = (
update(cls) # update(cls)
.where(AlbumTable.albumhash == hash) # .where(AlbumTable.albumhash == hash)
.values(is_favorite=is_favorite) # .values(is_favorite=is_favorite)
) # )
elif cls.__tablename__ == "artist": # elif cls.__tablename__ == "artist":
stmt = ( # stmt = (
update(cls) # update(cls)
.where(ArtistTable.artisthash == hash) # .where(ArtistTable.artisthash == hash)
.values(is_favorite=is_favorite) # .values(is_favorite=is_favorite)
) # )
conn.execute(stmt) # conn.execute(stmt)
@classmethod # @classmethod
def increment_scrobblecount( # def increment_scrobblecount(
cls, table: Any, field: Any, hash: str, duration: int, timestamp: int # cls, table: Any, field: Any, hash: str, duration: int, timestamp: int
): # ):
cls.execute( # cls.execute(
update(table) # update(table)
.where(field == hash) # .where(field == hash)
.values( # .values(
playcount=table.playcount + 1, # playcount=table.playcount + 1,
playduration=table.playduration + duration, # playduration=table.playduration + duration,
lastplayed=timestamp, # lastplayed=timestamp,
), # ),
commit=True, # commit=True,
) # )
class TrackTable(Base): class TrackTable(Base):
@@ -151,44 +141,44 @@ class TrackTable(Base):
) )
return tracks_to_dataclasses(result.fetchall()) return tracks_to_dataclasses(result.fetchall())
@classmethod # @classmethod
def get_tracks_by_albumhash(cls, albumhash: str): # def get_tracks_by_albumhash(cls, albumhash: str):
with DbEngine.manager() as conn: # with DbEngine.manager() as conn:
result = conn.execute( # result = conn.execute(
select(TrackTable).where(TrackTable.albumhash == albumhash) # select(TrackTable).where(TrackTable.albumhash == albumhash)
) # )
tracks = tracks_to_dataclasses(result.fetchall()) # tracks = tracks_to_dataclasses(result.fetchall())
return remove_duplicates(tracks, is_album_tracks=True) # return remove_duplicates(tracks, is_album_tracks=True)
@classmethod # @classmethod
def get_track_by_trackhash(cls, hash: str, filepath: str = ""): # def get_track_by_trackhash(cls, hash: str, filepath: str = ""):
with DbEngine.manager() as conn: # with DbEngine.manager() as conn:
if filepath: # if filepath:
result = conn.execute( # result = conn.execute(
select(TrackTable) # select(TrackTable)
.where( # .where(
(TrackTable.trackhash == hash) # (TrackTable.trackhash == hash)
& (TrackTable.filepath == filepath), # & (TrackTable.filepath == filepath),
) # )
.order_by(TrackTable.bitrate.desc()) # .order_by(TrackTable.bitrate.desc())
) # )
else: # else:
result = conn.execute( # result = conn.execute(
select(TrackTable).where(TrackTable.trackhash == hash) # select(TrackTable).where(TrackTable.trackhash == hash)
) # )
track = result.fetchone() # track = result.fetchone()
if track: # if track:
return track_to_dataclass(track) # return track_to_dataclass(track)
@classmethod # @classmethod
def get_tracks_by_artisthash(cls, artisthash: str): # def get_tracks_by_artisthash(cls, artisthash: str):
with DbEngine.manager() as conn: # with DbEngine.manager() as conn:
result = conn.execute( # result = conn.execute(
select(TrackTable).where(TrackTable.artists.contains(artisthash)) # select(TrackTable).where(TrackTable.artists.contains(artisthash))
) # )
return tracks_to_dataclasses(result.fetchall()) # return tracks_to_dataclasses(result.fetchall())
@classmethod @classmethod
def get_tracks_in_path(cls, path: str): def get_tracks_in_path(cls, path: str):
@@ -200,55 +190,55 @@ class TrackTable(Base):
) )
return tracks_to_dataclasses(result.fetchall()) return tracks_to_dataclasses(result.fetchall())
@classmethod # @classmethod
def get_tracks_by_trackhashes(cls, hashes: Iterable[str], limit: int | None = None): # def get_tracks_by_trackhashes(cls, hashes: Iterable[str], limit: int | None = None):
with DbEngine.manager() as conn: # with DbEngine.manager() as conn:
result = conn.execute( # result = conn.execute(
select(TrackTable) # select(TrackTable)
.where(TrackTable.trackhash.in_(hashes)) # .where(TrackTable.trackhash.in_(hashes))
.group_by(TrackTable.trackhash) # .group_by(TrackTable.trackhash)
.limit(limit) # .limit(limit)
) # )
tracks = tracks_to_dataclasses(result.fetchall()) # tracks = tracks_to_dataclasses(result.fetchall())
# order the tracks in the same order as the hashes # # order the tracks in the same order as the hashes
if type(hashes) == list: # if type(hashes) == list:
return sorted(tracks, key=lambda x: hashes.index(x.trackhash)) # return sorted(tracks, key=lambda x: hashes.index(x.trackhash))
return tracks # return tracks
# @classmethod
# def get_recently_added(cls, start: int, limit: int):
# with DbEngine.manager() as conn:
# result = conn.execute(
# select(TrackTable)
# .order_by(TrackTable.last_mod.desc())
# .offset(start)
# .limit(limit)
# )
# return tracks_to_dataclasses(result.fetchall())
@classmethod @classmethod
def get_recently_added(cls, start: int, limit: int): # def get_recently_played(cls, limit: int):
with DbEngine.manager() as conn: # result = cls.execute(
result = conn.execute( # select(cls)
select(TrackTable) # .group_by(cls.trackhash)
.order_by(TrackTable.last_mod.desc()) # .order_by(cls.lastplayed.desc())
.offset(start) # .limit(limit)
.limit(limit) # )
) # return tracks_to_dataclasses(result.fetchall())
return tracks_to_dataclasses(result.fetchall())
@classmethod
def get_recently_played(cls, limit: int):
result = cls.execute(
select(cls)
.group_by(cls.trackhash)
.order_by(cls.lastplayed.desc())
.limit(limit)
)
return tracks_to_dataclasses(result.fetchall())
@classmethod @classmethod
def remove_tracks_by_filepaths(cls, filepaths: set[str]): def remove_tracks_by_filepaths(cls, filepaths: set[str]):
with DbEngine.manager(commit=True) as conn: with DbEngine.manager(commit=True) as conn:
conn.execute(delete(TrackTable).where(TrackTable.filepath.in_(filepaths))) conn.execute(delete(TrackTable).where(TrackTable.filepath.in_(filepaths)))
@classmethod # @classmethod
def increment_playcount(cls, trackhash: str, duration: int, timestamp: int): # def increment_playcount(cls, trackhash: str, duration: int, timestamp: int):
cls.increment_scrobblecount( # cls.increment_scrobblecount(
TrackTable, TrackTable.trackhash, trackhash, duration, timestamp # TrackTable, TrackTable.trackhash, trackhash, duration, timestamp
) # )
# @classmethod # @classmethod
# def update_artist_separators(cls, separators: set[str]): # def update_artist_separators(cls, separators: set[str]):
@@ -264,166 +254,166 @@ class TrackTable(Base):
# ) # )
class AlbumTable(Base): # class AlbumTable(Base):
__tablename__ = "album" # __tablename__ = "album"
id: Mapped[int] = mapped_column(primary_key=True) # id: Mapped[int] = mapped_column(primary_key=True)
albumartists: Mapped[list[dict[str, str]]] = mapped_column(JSON(), index=True) # albumartists: Mapped[list[dict[str, str]]] = mapped_column(JSON(), index=True)
artisthashes: Mapped[list[str]] = mapped_column(JSON(), index=True) # artisthashes: Mapped[list[str]] = mapped_column(JSON(), index=True)
albumhash: Mapped[str] = mapped_column(String(), unique=True, index=True) # albumhash: Mapped[str] = mapped_column(String(), unique=True, index=True)
base_title: Mapped[str] = mapped_column(String()) # base_title: Mapped[str] = mapped_column(String())
color: Mapped[Optional[str]] = mapped_column(String()) # color: Mapped[Optional[str]] = mapped_column(String())
created_date: Mapped[int] = mapped_column(Integer()) # created_date: Mapped[int] = mapped_column(Integer())
date: Mapped[int] = mapped_column(Integer()) # date: Mapped[int] = mapped_column(Integer())
duration: Mapped[int] = mapped_column(Integer()) # duration: Mapped[int] = mapped_column(Integer())
genrehashes: Mapped[list[str]] = mapped_column(JSON(), nullable=True, index=True) # genrehashes: Mapped[list[str]] = mapped_column(JSON(), nullable=True, index=True)
genres: Mapped[str] = mapped_column(JSON()) # genres: Mapped[str] = mapped_column(JSON())
og_title: Mapped[str] = mapped_column(String()) # og_title: Mapped[str] = mapped_column(String())
title: Mapped[str] = mapped_column(String()) # title: Mapped[str] = mapped_column(String())
trackcount: Mapped[int] = mapped_column(Integer()) # trackcount: Mapped[int] = mapped_column(Integer())
lastplayed: Mapped[int] = mapped_column(Integer(), default=0) # lastplayed: Mapped[int] = mapped_column(Integer(), default=0)
playcount: Mapped[int] = mapped_column(Integer(), default=0) # playcount: Mapped[int] = mapped_column(Integer(), default=0)
playduration: Mapped[int] = mapped_column(Integer(), default=0) # playduration: Mapped[int] = mapped_column(Integer(), default=0)
extra: Mapped[Optional[dict[str, Any]]] = mapped_column( # extra: Mapped[Optional[dict[str, Any]]] = mapped_column(
JSON(), default_factory=dict # JSON(), default_factory=dict
) # )
@classmethod # @classmethod
def get_all(cls): # def get_all(cls):
with DbEngine.manager() as conn: # with DbEngine.manager() as conn:
result = conn.execute(select(AlbumTable)) # result = conn.execute(select(AlbumTable))
all = result.fetchall() # all = result.fetchall()
return albums_to_dataclasses(all) # return albums_to_dataclasses(all)
@classmethod # @classmethod
def get_album_by_albumhash(cls, hash: str): # def get_album_by_albumhash(cls, hash: str):
with DbEngine.manager() as conn: # with DbEngine.manager() as conn:
result = conn.execute( # result = conn.execute(
select(AlbumTable).where(AlbumTable.albumhash == hash) # select(AlbumTable).where(AlbumTable.albumhash == hash)
) # )
album = result.fetchone() # album = result.fetchone()
if album: # if album:
return album_to_dataclass(album) # return album_to_dataclass(album)
@classmethod # @classmethod
def get_albums_by_albumhashes(cls, hashes: Iterable[str], limit: int | None = None): # def get_albums_by_albumhashes(cls, hashes: Iterable[str], limit: int | None = None):
with DbEngine.manager() as conn: # with DbEngine.manager() as conn:
result = conn.execute( # result = conn.execute(
select(AlbumTable).where(AlbumTable.albumhash.in_(hashes)).limit(limit) # select(AlbumTable).where(AlbumTable.albumhash.in_(hashes)).limit(limit)
) # )
albums = albums_to_dataclasses(result.fetchall()) # albums = albums_to_dataclasses(result.fetchall())
# order the albums in the same order as the hashes # # order the albums in the same order as the hashes
if type(hashes) == list: # if type(hashes) == list:
return sorted(albums, key=lambda x: hashes.index(x.albumhash)) # return sorted(albums, key=lambda x: hashes.index(x.albumhash))
return albums # return albums
@classmethod # @classmethod
def get_albums_by_artisthashes(cls, artisthashes: list[str]): # def get_albums_by_artisthashes(cls, artisthashes: list[str]):
with DbEngine.manager() as conn: # with DbEngine.manager() as conn:
albums: dict[str, list[AlbumModel]] = {} # albums: dict[str, list[AlbumModel]] = {}
for artist in artisthashes: # for artist in artisthashes:
result = conn.execute( # result = conn.execute(
select(AlbumTable).where(AlbumTable.artisthashes.contains(artist)) # select(AlbumTable).where(AlbumTable.artisthashes.contains(artist))
) # )
albums[artist] = albums_to_dataclasses(result.fetchall()) # albums[artist] = albums_to_dataclasses(result.fetchall())
return albums # return albums
@classmethod # @classmethod
def get_albums_by_base_title(cls, base_title: str): # def get_albums_by_base_title(cls, base_title: str):
with DbEngine.manager() as conn: # with DbEngine.manager() as conn:
result = conn.execute( # result = conn.execute(
select(AlbumTable).where(AlbumTable.base_title == base_title) # select(AlbumTable).where(AlbumTable.base_title == base_title)
) # )
return albums_to_dataclasses(result.fetchall()) # return albums_to_dataclasses(result.fetchall())
@classmethod # @classmethod
def get_albums_by_artisthash(cls, artisthash: str): # def get_albums_by_artisthash(cls, artisthash: str):
with DbEngine.manager() as conn: # with DbEngine.manager() as conn:
result = conn.execute( # result = conn.execute(
select(AlbumTable).where(AlbumTable.artisthashes.contains(artisthash)) # select(AlbumTable).where(AlbumTable.artisthashes.contains(artisthash))
) # )
return albums_to_dataclasses(result.all()) # return albums_to_dataclasses(result.all())
@classmethod # @classmethod
def increment_playcount(cls, albumhash: str, duration: int, timestamp: int): # def increment_playcount(cls, albumhash: str, duration: int, timestamp: int):
return cls.increment_scrobblecount( # return cls.increment_scrobblecount(
AlbumTable, AlbumTable.albumhash, albumhash, duration, timestamp # AlbumTable, AlbumTable.albumhash, albumhash, duration, timestamp
) # )
class ArtistTable(Base): # class ArtistTable(Base):
__tablename__ = "artist" # __tablename__ = "artist"
id: Mapped[int] = mapped_column(primary_key=True) # id: Mapped[int] = mapped_column(primary_key=True)
albumcount: Mapped[int] = mapped_column(Integer()) # albumcount: Mapped[int] = mapped_column(Integer())
artisthash: Mapped[str] = mapped_column(String(), unique=True, index=True) # artisthash: Mapped[str] = mapped_column(String(), unique=True, index=True)
created_date: Mapped[int] = mapped_column(Integer()) # created_date: Mapped[int] = mapped_column(Integer())
date: Mapped[int] = mapped_column(Integer()) # date: Mapped[int] = mapped_column(Integer())
duration: Mapped[int] = mapped_column(Integer()) # duration: Mapped[int] = mapped_column(Integer())
genrehashes: Mapped[list[str]] = mapped_column(JSON(), nullable=True, index=True) # genrehashes: Mapped[list[str]] = mapped_column(JSON(), nullable=True, index=True)
genres: Mapped[str] = mapped_column(JSON()) # genres: Mapped[str] = mapped_column(JSON())
name: Mapped[str] = mapped_column(String(), index=True) # name: Mapped[str] = mapped_column(String(), index=True)
trackcount: Mapped[int] = mapped_column(Integer()) # trackcount: Mapped[int] = mapped_column(Integer())
lastplayed: Mapped[int] = mapped_column(Integer(), default=0) # lastplayed: Mapped[int] = mapped_column(Integer(), default=0)
playcount: Mapped[int] = mapped_column(Integer(), default=0) # playcount: Mapped[int] = mapped_column(Integer(), default=0)
playduration: Mapped[int] = mapped_column(Integer(), default=0) # playduration: Mapped[int] = mapped_column(Integer(), default=0)
extra: Mapped[Optional[dict[str, Any]]] = mapped_column( # extra: Mapped[Optional[dict[str, Any]]] = mapped_column(
JSON(), default_factory=dict # JSON(), default_factory=dict
) # )
@classmethod # @classmethod
def get_all(cls): # def get_all(cls):
with DbEngine.manager() as conn: # with DbEngine.manager() as conn:
result = conn.execute(select(cls)) # result = conn.execute(select(cls))
all = result.fetchall() # all = result.fetchall()
return artists_to_dataclasses(all) # return artists_to_dataclasses(all)
@classmethod # @classmethod
def get_artist_by_hash(cls, artisthash: str): # def get_artist_by_hash(cls, artisthash: str):
with DbEngine.manager() as conn: # with DbEngine.manager() as conn:
result = conn.execute( # result = conn.execute(
select(ArtistTable).where(ArtistTable.artisthash == artisthash) # select(ArtistTable).where(ArtistTable.artisthash == artisthash)
) # )
return artist_to_dataclass(result.fetchone()) # return artist_to_dataclass(result.fetchone())
@classmethod # @classmethod
def get_artisthashes_not_in(cls, artisthashes: list[str]): # def get_artisthashes_not_in(cls, artisthashes: list[str]):
with DbEngine.manager() as conn: # with DbEngine.manager() as conn:
result = conn.execute( # result = conn.execute(
select(ArtistTable.artisthash, ArtistTable.name).where( # select(ArtistTable.artisthash, ArtistTable.name).where(
~ArtistTable.artisthash.in_(artisthashes) # ~ArtistTable.artisthash.in_(artisthashes)
) # )
) # )
return [{"artisthash": row[0], "name": row[1]} for row in result.fetchall()] # return [{"artisthash": row[0], "name": row[1]} for row in result.fetchall()]
@classmethod # @classmethod
def get_artists_by_artisthashes( # def get_artists_by_artisthashes(
cls, hashes: Iterable[str], limit: int | None = None # cls, hashes: Iterable[str], limit: int | None = None
): # ):
with DbEngine.manager() as conn: # with DbEngine.manager() as conn:
result = conn.execute( # result = conn.execute(
select(ArtistTable) # select(ArtistTable)
.where(ArtistTable.artisthash.in_(hashes)) # .where(ArtistTable.artisthash.in_(hashes))
.limit(limit) # .limit(limit)
) # )
return artists_to_dataclasses(result.fetchall()) # return artists_to_dataclasses(result.fetchall())
@classmethod # @classmethod
def increment_playcount( # def increment_playcount(
cls, artisthashes: list[str], duration: int, timestamp: int # cls, artisthashes: list[str], duration: int, timestamp: int
): # ):
cls.execute( # cls.execute(
update(cls) # update(cls)
.where(ArtistTable.artisthash.in_(artisthashes)) # .where(ArtistTable.artisthash.in_(artisthashes))
.values( # .values(
playcount=ArtistTable.playcount + 1, # playcount=ArtistTable.playcount + 1,
playduration=ArtistTable.playduration + duration, # playduration=ArtistTable.playduration + duration,
lastplayed=timestamp, # lastplayed=timestamp,
), # ),
commit=True, # commit=True,
) # )
+65 -24
View File
@@ -1,4 +1,5 @@
import os import os
import time
import urllib import urllib
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from io import BytesIO from io import BytesIO
@@ -10,7 +11,10 @@ from requests.exceptions import ConnectionError as RequestConnectionError
from requests.exceptions import ReadTimeout from requests.exceptions import ReadTimeout
from app import settings from app import settings
from app.db.libdata import ArtistTable from app.models.artist import Artist
from app.store.artists import ArtistStore
# from app.db.libdata import ArtistTable
# from app.store import artists as artist_store # from app.store import artists as artist_store
# from app.store.tracks import TrackStore # from app.store.tracks import TrackStore
@@ -28,27 +32,51 @@ def get_artist_image_link(artist: str):
""" """
Returns an artist image url. Returns an artist image url.
""" """
response: requests.Response | None = None
try: def make_request():
query = urllib.parse.quote(artist) # type: ignore query = urllib.parse.quote(artist) # type: ignore
url = f"https://api.deezer.com/search/artist?q={query}" url = f"https://api.deezer.com/search/artist?q={query}"
response = requests.get(url, timeout=30) headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
"Accept": "application/json, text/plain, */*",
"Accept-Language": "en-US,en;q=0.9",
"Referer": "https://www.deezer.com/",
"Origin": "https://www.deezer.com",
}
return requests.get(url, headers=headers, timeout=30)
for attempt in range(5):
try: try:
data = response.json() response = make_request()
except requests.exceptions.JSONDecodeError: try:
data = response.json()
except requests.exceptions.JSONDecodeError:
return None
for res in data["data"]:
res_hash = create_hash(res["name"], decode=True)
artist_hash = create_hash(artist, decode=True)
if res_hash == artist_hash:
return str(res["picture_big"])
return None return None
except (RequestConnectionError, ReadTimeout, IndexError, KeyError):
if attempt == 4:
print("Failed to get artist image link ")
for res in data["data"]: if attempt <= 4:
res_hash = create_hash(res["name"], decode=True) time.sleep(10)
artist_hash = create_hash(artist, decode=True) else:
return None
if res_hash == artist_hash: # except (IndexError, KeyError):
return str(res["picture_big"]) # print(f"Encountered index/key error in attempt {attempt}")
# if response is not None:
# print(response.headers)
return None # return None
except (RequestConnectionError, ReadTimeout, IndexError, KeyError):
return None
# TODO: Move network calls to utils/network.py # TODO: Move network calls to utils/network.py
@@ -75,11 +103,19 @@ class DownloadImage:
def download(url: str) -> Image.Image | None: def download(url: str) -> Image.Image | None:
""" """
Downloads the image from the url. Downloads the image from the url.
Retries after 10 seconds on a connection error.
""" """
try: for attempt in range(2):
return Image.open(BytesIO(requests.get(url, timeout=10).content)) try:
except UnidentifiedImageError: response = requests.get(url, timeout=10)
return None return Image.open(BytesIO(response.content))
except (RequestConnectionError, requests.Timeout, ReadTimeout):
if attempt == 0:
time.sleep(10)
else:
return None
except UnidentifiedImageError:
return None
@staticmethod @staticmethod
def save_img(img: Image.Image, entries: list[tuple[Path, int | None]]): def save_img(img: Image.Image, entries: list[tuple[Path, int | None]]):
@@ -107,7 +143,13 @@ class CheckArtistImages:
# read all files in the artist image folder # read all files in the artist image folder
path = settings.Paths.get_sm_artist_img_path() path = settings.Paths.get_sm_artist_img_path()
processed = [path.replace(".webp", "") for path in os.listdir(path)] processed = [path.replace(".webp", "") for path in os.listdir(path)]
unprocessed = ArtistTable.get_artisthashes_not_in(processed) print(f"Found {len(processed)} processed artist images")
unprocessed = [
a for a in ArtistStore.get_flat_list() if a.artisthash not in processed
]
print(f"Downloading {len(unprocessed)} artist images")
key_artist_map = ((instance_key, artist) for artist in unprocessed) key_artist_map = ((instance_key, artist) for artist in unprocessed)
with ThreadPoolExecutor(max_workers=14) as executor: with ThreadPoolExecutor(max_workers=14) as executor:
@@ -122,7 +164,7 @@ class CheckArtistImages:
list(res) list(res)
@staticmethod @staticmethod
def download_image(_map: tuple[str, dict[str, str]]): def download_image(_map: tuple[str, Artist]):
""" """
Checks if an artist image exists and downloads it if not. Checks if an artist image exists and downloads it if not.
@@ -134,14 +176,13 @@ class CheckArtistImages:
return return
img_path = ( img_path = (
Path(settings.Paths.get_sm_artist_img_path()) Path(settings.Paths.get_sm_artist_img_path()) / f"{artist.artisthash}.webp"
/ f"{artist['artisthash']}.webp"
) )
if img_path.exists(): if img_path.exists():
return return
url = get_artist_image_link(artist["name"]) url = get_artist_image_link(artist.name)
if url is not None: if url is not None:
return DownloadImage(url, name=f"{artist['artisthash']}.webp") return DownloadImage(url, name=f"{artist.artisthash}.webp")
+1 -1
View File
@@ -141,7 +141,7 @@ SUPPORTED_FILES = tuple(f".{file}" for file in FILES)
# ===== SQLite ===== # ===== SQLite =====
class DbPaths: class DbPaths:
APP_DB_NAME = "swing.db" APP_DB_NAME = "swingmusic.db"
USER_DATA_DB_NAME = "userdata.db" USER_DATA_DB_NAME = "userdata.db"
@classmethod @classmethod
+2 -2
View File
@@ -10,7 +10,7 @@ from app.settings import DbPaths
from app.db.engine import DbEngine from app.db.engine import DbEngine
from app.db import create_all_tables from app.db import create_all_tables
from app.db.libdata import create_all as create_user_tables # from app.db.libdata import create_all as create_user_tables
def run_migrations(): def run_migrations():
@@ -32,7 +32,7 @@ def setup_sqlite():
) )
create_all_tables() create_all_tables()
create_user_tables() # create_user_tables()
if not UserTable.get_all(): if not UserTable.get_all():
UserTable.insert_default_user() UserTable.insert_default_user()