try: rewrite some parts with process pools

This commit is contained in:
cwilvx
2025-02-12 21:28:53 +03:00
parent beec5bc7d3
commit fa7c781610
9 changed files with 90 additions and 43 deletions
+2 -1
View File
@@ -1,5 +1,4 @@
from contextlib import contextmanager from contextlib import contextmanager
import gc
from sqlalchemy import Engine, event from sqlalchemy import Engine, event
@@ -43,3 +42,5 @@ class DbEngine:
raise e raise e
finally: finally:
conn.close() conn.close()
del conn
cls.engine.clear_compiled_cache()
+10 -4
View File
@@ -143,8 +143,10 @@ class SimilarArtistTable(Base):
@classmethod @classmethod
def get_all(cls): def get_all(cls):
with DbEngine.manager() as conn: with DbEngine.manager() as conn:
result = conn.execute(select(cls)) result = conn.execute(
return similar_artists_to_dataclass(result.fetchall()) select(cls.artisthash), execution_options={"stream_results": True}
)
return result.scalars().all()
@classmethod @classmethod
def exists(cls, artisthash: str): def exists(cls, artisthash: str):
@@ -156,7 +158,8 @@ class SimilarArtistTable(Base):
result = conn.execute( result = conn.execute(
select(cls.artisthash).where(cls.artisthash == artisthash) select(cls.artisthash).where(cls.artisthash == artisthash)
) )
return result.fetchone() is not None
return len(result.scalars().all()) > 0
@classmethod @classmethod
def get_by_hash(cls, artisthash: str): def get_by_hash(cls, artisthash: str):
@@ -474,7 +477,10 @@ class LibDataTable(Base):
result = cls.execute( result = cls.execute(
select(cls.itemhash, cls.color).where(cls.itemtype == type) select(cls.itemhash, cls.color).where(cls.itemtype == type)
) )
return [{"itemhash": r[0].replace(type, ''), "color": r[1]} for r in result.fetchall()] return [
{"itemhash": r[0].replace(type, ""), "color": r[1]}
for r in result.fetchall()
]
class MixTable(Base): class MixTable(Base):
+13 -6
View File
@@ -1,11 +1,13 @@
import os import os
import time import time
import urllib import urllib
from concurrent.futures import ThreadPoolExecutor import requests
import multiprocessing
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from concurrent.futures import ProcessPoolExecutor
import requests
from PIL import Image, PngImagePlugin, UnidentifiedImageError from PIL import Image, PngImagePlugin, UnidentifiedImageError
from requests.exceptions import ConnectionError as RequestConnectionError from requests.exceptions import ConnectionError as RequestConnectionError
from requests.exceptions import ReadTimeout from requests.exceptions import ReadTimeout
@@ -21,6 +23,7 @@ from app.store.artists import ArtistStore
from app.utils.hashing import create_hash from app.utils.hashing import create_hash
from app.utils.progressbar import tqdm from app.utils.progressbar import tqdm
CHECK_ARTIST_IMAGES_KEY = "" CHECK_ARTIST_IMAGES_KEY = ""
LARGE_ENOUGH_NUMBER = 100 LARGE_ENOUGH_NUMBER = 100
@@ -29,6 +32,7 @@ PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
import random import random
def get_artist_image_link(artist: str): def get_artist_image_link(artist: str):
""" """
Returns an artist image url. Returns an artist image url.
@@ -43,7 +47,7 @@ def get_artist_image_link(artist: str):
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Safari/605.1.15", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Safari/605.1.15",
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:90.0) Gecko/20100101 Firefox/90.0", "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:90.0) Gecko/20100101 Firefox/90.0",
"Mozilla/5.0 (iPhone; CPU iPhone OS 14_6 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Mobile/15E148 Safari/604.1" "Mozilla/5.0 (iPhone; CPU iPhone OS 14_6 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Mobile/15E148 Safari/604.1",
] ]
headers = { headers = {
"User-Agent": random.choice(user_agents), "User-Agent": random.choice(user_agents),
@@ -138,7 +142,7 @@ class DownloadImage:
img.save(path, format="webp") img.save(path, format="webp")
continue continue
img.resize((size, int(size / ratio)), Image.ANTIALIAS).save( img.resize((size, int(size / ratio)), Image.LANCZOS).save(
path, format="webp" path, format="webp"
) )
@@ -150,7 +154,7 @@ class CheckArtistImages:
# read all files in the artist image folder # read all files in the artist image folder
path = settings.Paths.get_sm_artist_img_path() path = settings.Paths.get_sm_artist_img_path()
processed = [path.replace(".webp", "") for path in os.listdir(path)] processed = set(i.replace(".webp", "") for i in os.listdir(path))
unprocessed = [ unprocessed = [
a for a in ArtistStore.get_flat_list() if a.artisthash not in processed a for a in ArtistStore.get_flat_list() if a.artisthash not in processed
@@ -158,7 +162,10 @@ class CheckArtistImages:
key_artist_map = ((instance_key, artist) for artist in unprocessed) key_artist_map = ((instance_key, artist) for artist in unprocessed)
with ThreadPoolExecutor(max_workers=14) as executor: # Use number of CPU cores minus 1 to leave one core free for system processes
num_workers = max(1, multiprocessing.cpu_count() - 1)
with ProcessPoolExecutor(max_workers=num_workers) as executor:
res = list( res = list(
tqdm( tqdm(
executor.map(self.download_image, key_artist_map), executor.map(self.download_image, key_artist_map),
+7 -5
View File
@@ -1,6 +1,6 @@
from dataclasses import asdict from dataclasses import asdict
import os import os
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ProcessPoolExecutor
from requests import ConnectionError as RequestConnectionError from requests import ConnectionError as RequestConnectionError
from requests import ReadTimeout from requests import ReadTimeout
@@ -111,7 +111,7 @@ class ProcessTrackThumbnails:
path = settings.Paths.get_sm_thumb_path() path = settings.Paths.get_sm_thumb_path()
# read all the files in the thumbnail directory # read all the files in the thumbnail directory
processed = "".join(os.listdir(path)).replace("webp", "") processed = set(i.replace(".webp", "") for i in os.listdir(path))
# filter out albums that already have thumbnails # filter out albums that already have thumbnails
albums = filter( albums = filter(
@@ -119,10 +119,12 @@ class ProcessTrackThumbnails:
) )
albums = list(albums) albums = list(albums)
print("length of albums", len(albums))
# process the rest # process the rest
key_album_map = ((instance_key, album) for album in albums) key_album_map = ((instance_key, album) for album in albums)
with ThreadPoolExecutor(max_workers=get_cpu_count()) as executor: with ProcessPoolExecutor(max_workers=get_cpu_count()) as executor:
results = list( results = list(
tqdm( tqdm(
executor.map(get_image, key_album_map), executor.map(get_image, key_album_map),
@@ -168,7 +170,7 @@ class FetchSimilarArtistsLastFM:
def __init__(self, instance_key: str) -> None: def __init__(self, instance_key: str) -> None:
# read all artists from db # read all artists from db
processed = SimilarArtistTable.get_all() processed = SimilarArtistTable.get_all()
processed = ".".join(a.artisthash for a in processed) processed = ".".join(a for a in processed)
# filter out artists that already have similar artists # filter out artists that already have similar artists
artists = filter( artists = filter(
@@ -179,7 +181,7 @@ class FetchSimilarArtistsLastFM:
# process the rest # process the rest
key_artist_map = ((instance_key, artist) for artist in artists) key_artist_map = ((instance_key, artist) for artist in artists)
with ThreadPoolExecutor(max_workers=get_cpu_count()) as executor: with ProcessPoolExecutor(max_workers=get_cpu_count()) as executor:
try: try:
print("Processing similar artists") print("Processing similar artists")
results = list( results = list(
+38 -11
View File
@@ -1,4 +1,7 @@
import os import os
from functools import partial
from multiprocessing import Pool, cpu_count
from app import settings from app import settings
from app.config import UserConfig from app.config import UserConfig
from app.db.libdata import TrackTable from app.db.libdata import TrackTable
@@ -17,7 +20,9 @@ from app.utils.progressbar import tqdm
from app.logger import log from app.logger import log
from app.utils.remove_duplicates import remove_duplicates from app.utils.remove_duplicates import remove_duplicates
POPULATE_KEY: float = 0
class PopulateKey:
key: float = 0
class IndexTracks: class IndexTracks:
@@ -28,9 +33,7 @@ class IndexTracks:
An instance key is used to prevent multiple instances of the An instance key is used to prevent multiple instances of the
same class from running at the same time. same class from running at the same time.
""" """
global POPULATE_KEY PopulateKey.key = instance_key
POPULATE_KEY = instance_key
dirs_to_scan = UserConfig().rootDirs dirs_to_scan = UserConfig().rootDirs
if len(dirs_to_scan) == 0: if len(dirs_to_scan) == 0:
@@ -120,22 +123,46 @@ class IndexTracks:
def get_untagged(self): def get_untagged(self):
tracks = TrackTable.get_all() tracks = TrackTable.get_all()
@staticmethod
def _process_file(file: str, config: UserConfig, key: float) -> dict | None:
"""Worker function to process individual files"""
if PopulateKey.key != key:
return None
try:
return get_tags(file, config=config)
except Exception as e:
log.warning(f"Failed to process file {file}: {e}")
return None
def tag_untagged(self, files: set[str], key: float): def tag_untagged(self, files: set[str], key: float):
config = UserConfig() config = UserConfig()
for file in tqdm(files, desc="Reading files"):
if POPULATE_KEY != key: # Create process pool with worker function
with Pool(processes=cpu_count()) as pool:
worker = partial(self._process_file, config=config, key=key)
# Process files and track progress
results = []
for result in tqdm(
pool.imap_unordered(worker, files),
total=len(files),
desc="Reading files",
):
if PopulateKey.key != key:
log.warning("'Populate.tag_untagged': Populate key changed") log.warning("'Populate.tag_untagged': Populate key changed")
pool.terminate()
return return
tags = get_tags(file, config=config) if result is not None:
results.append(result)
if tags is not None: # Bulk insert results
for tags in results:
TrackTable.insert_one(tags) TrackTable.insert_one(tags)
FolderStore.filepaths.add(tags["filepath"]) FolderStore.filepaths.add(tags["filepath"])
del tags print(f"{len(results)} new files indexed")
print(f"{len(files)} new files indexed")
print("Done") print("Done")
+8 -5
View File
@@ -23,11 +23,12 @@ def parse_album_art(filepath: str):
""" """
Returns the album art for a given audio file. Returns the album art for a given audio file.
""" """
try:
tags = TinyTag.get(filepath, image=True) tags = TinyTag.get(filepath, image=True)
return tags.get_image() image = tags.images.any
except: # pylint: disable=bare-except
if image:
return image.data
return None return None
@@ -53,7 +54,9 @@ def extract_thumb(filepath: str, webp_path: str, overwrite=False) -> bool:
ratio = width / height ratio = width / height
for path, size in images: for path, size in images:
img.resize((size, int(size / ratio)), Image.ANTIALIAS).save(path, "webp") img.resize((size, int(size / ratio)), Image.LANCZOS).save(path, "webp")
del img
if not overwrite and os.path.exists(sm_img_path): if not overwrite and os.path.exists(sm_img_path):
img_size = os.path.getsize(sm_img_path) img_size = os.path.getsize(sm_img_path)
+4 -4
View File
@@ -126,8 +126,8 @@ class MixesPlugin(Plugin):
# if len(trackmatches) < self.TRACK_MIX_LENGTH: # if len(trackmatches) < self.TRACK_MIX_LENGTH:
if True: if True:
filler_tracks = self.fallback_create_artist_mix( filler_tracks = self.fallback_create_artist_mix(
similar_artists=results["artists"], similar_artists=results.get("artists", []),
similar_albums=results["albums"], similar_albums=results.get("albums", []),
omit_trackhashes={t.weakhash for t in trackmatches}, omit_trackhashes={t.weakhash for t in trackmatches},
# limit=self.TRACK_MIX_LENGTH - len(trackmatches), # limit=self.TRACK_MIX_LENGTH - len(trackmatches),
) )
@@ -135,9 +135,9 @@ class MixesPlugin(Plugin):
# try to balance the mix # try to balance the mix
trackmatches = balance_mix(trackmatches) trackmatches = balance_mix(trackmatches)
return trackmatches, results["albums"], results["artists"] return trackmatches, results.get("albums", []), results.get("artists", [])
@plugin_method # @plugin_method
# def get_artist_mix(self, artisthash: str): # def get_artist_mix(self, artisthash: str):
# """ # """
# Given an artisthash, creates an artist mix using the # Given an artisthash, creates an artist mix using the
-1
View File
@@ -8,7 +8,6 @@ def has_connection(host="google.it", port=80, timeout=3):
OpenPort: 53/tcp OpenPort: 53/tcp
Service: domain (DNS/TCP) Service: domain (DNS/TCP)
""" """
try: try:
Socket.setdefaulttimeout(timeout) Socket.setdefaulttimeout(timeout)
Socket.socket(Socket.AF_INET, Socket.SOCK_STREAM).connect((host, port)) Socket.socket(Socket.AF_INET, Socket.SOCK_STREAM).connect((host, port))
+4 -2
View File
@@ -173,7 +173,9 @@ def serve_client_files(path: str):
# INFO: Safari doesn't support gzip encoding # INFO: Safari doesn't support gzip encoding
# See issue: https://github.com/swingmx/swingmusic/issues/155 # See issue: https://github.com/swingmx/swingmusic/issues/155
is_safari = user_agent.find("Safari") >= 0 and user_agent.find("Chrome") < 0 is_safari = (
user_agent and user_agent.find("Safari") >= 0 and user_agent.find("Chrome") < 0
)
if is_safari: if is_safari:
return app.send_static_file(path) return app.send_static_file(path)
@@ -181,7 +183,7 @@ def serve_client_files(path: str):
accepts_gzip = request.headers.get("Accept-Encoding", "").find("gzip") >= 0 accepts_gzip = request.headers.get("Accept-Encoding", "").find("gzip") >= 0
if accepts_gzip: if accepts_gzip:
if os.path.exists(os.path.join(app.static_folder, gzipped_path)): if os.path.exists(os.path.join(app.static_folder or "", gzipped_path)):
response = app.make_response(app.send_static_file(gzipped_path)) response = app.make_response(app.send_static_file(gzipped_path))
response.headers["Content-Encoding"] = "gzip" response.headers["Content-Encoding"] = "gzip"
return response return response