try: rewrite some parts with process pools

This commit is contained in:
cwilvx
2025-02-12 21:28:53 +03:00
parent beec5bc7d3
commit fa7c781610
9 changed files with 90 additions and 43 deletions
+2 -1
View File
@@ -1,5 +1,4 @@
from contextlib import contextmanager
import gc
from sqlalchemy import Engine, event
@@ -43,3 +42,5 @@ class DbEngine:
raise e
finally:
conn.close()
del conn
cls.engine.clear_compiled_cache()
+10 -4
View File
@@ -143,8 +143,10 @@ class SimilarArtistTable(Base):
@classmethod
def get_all(cls):
with DbEngine.manager() as conn:
result = conn.execute(select(cls))
return similar_artists_to_dataclass(result.fetchall())
result = conn.execute(
select(cls.artisthash), execution_options={"stream_results": True}
)
return result.scalars().all()
@classmethod
def exists(cls, artisthash: str):
@@ -156,7 +158,8 @@ class SimilarArtistTable(Base):
result = conn.execute(
select(cls.artisthash).where(cls.artisthash == artisthash)
)
return result.fetchone() is not None
return len(result.scalars().all()) > 0
@classmethod
def get_by_hash(cls, artisthash: str):
@@ -474,7 +477,10 @@ class LibDataTable(Base):
result = cls.execute(
select(cls.itemhash, cls.color).where(cls.itemtype == type)
)
return [{"itemhash": r[0].replace(type, ''), "color": r[1]} for r in result.fetchall()]
return [
{"itemhash": r[0].replace(type, ""), "color": r[1]}
for r in result.fetchall()
]
class MixTable(Base):
+13 -6
View File
@@ -1,11 +1,13 @@
import os
import time
import urllib
from concurrent.futures import ThreadPoolExecutor
import requests
import multiprocessing
from io import BytesIO
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor
import requests
from PIL import Image, PngImagePlugin, UnidentifiedImageError
from requests.exceptions import ConnectionError as RequestConnectionError
from requests.exceptions import ReadTimeout
@@ -21,6 +23,7 @@ from app.store.artists import ArtistStore
from app.utils.hashing import create_hash
from app.utils.progressbar import tqdm
CHECK_ARTIST_IMAGES_KEY = ""
LARGE_ENOUGH_NUMBER = 100
@@ -29,6 +32,7 @@ PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
import random
def get_artist_image_link(artist: str):
"""
Returns an artist image url.
@@ -43,7 +47,7 @@ def get_artist_image_link(artist: str):
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Safari/605.1.15",
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:90.0) Gecko/20100101 Firefox/90.0",
"Mozilla/5.0 (iPhone; CPU iPhone OS 14_6 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Mobile/15E148 Safari/604.1"
"Mozilla/5.0 (iPhone; CPU iPhone OS 14_6 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Mobile/15E148 Safari/604.1",
]
headers = {
"User-Agent": random.choice(user_agents),
@@ -138,7 +142,7 @@ class DownloadImage:
img.save(path, format="webp")
continue
img.resize((size, int(size / ratio)), Image.ANTIALIAS).save(
img.resize((size, int(size / ratio)), Image.LANCZOS).save(
path, format="webp"
)
@@ -150,7 +154,7 @@ class CheckArtistImages:
# read all files in the artist image folder
path = settings.Paths.get_sm_artist_img_path()
processed = [path.replace(".webp", "") for path in os.listdir(path)]
processed = set(i.replace(".webp", "") for i in os.listdir(path))
unprocessed = [
a for a in ArtistStore.get_flat_list() if a.artisthash not in processed
@@ -158,7 +162,10 @@ class CheckArtistImages:
key_artist_map = ((instance_key, artist) for artist in unprocessed)
with ThreadPoolExecutor(max_workers=14) as executor:
# Use number of CPU cores minus 1 to leave one core free for system processes
num_workers = max(1, multiprocessing.cpu_count() - 1)
with ProcessPoolExecutor(max_workers=num_workers) as executor:
res = list(
tqdm(
executor.map(self.download_image, key_artist_map),
+7 -5
View File
@@ -1,6 +1,6 @@
from dataclasses import asdict
import os
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ProcessPoolExecutor
from requests import ConnectionError as RequestConnectionError
from requests import ReadTimeout
@@ -111,7 +111,7 @@ class ProcessTrackThumbnails:
path = settings.Paths.get_sm_thumb_path()
# read all the files in the thumbnail directory
processed = "".join(os.listdir(path)).replace("webp", "")
processed = set(i.replace(".webp", "") for i in os.listdir(path))
# filter out albums that already have thumbnails
albums = filter(
@@ -119,10 +119,12 @@ class ProcessTrackThumbnails:
)
albums = list(albums)
print("length of albums", len(albums))
# process the rest
key_album_map = ((instance_key, album) for album in albums)
with ThreadPoolExecutor(max_workers=get_cpu_count()) as executor:
with ProcessPoolExecutor(max_workers=get_cpu_count()) as executor:
results = list(
tqdm(
executor.map(get_image, key_album_map),
@@ -168,7 +170,7 @@ class FetchSimilarArtistsLastFM:
def __init__(self, instance_key: str) -> None:
# read all artists from db
processed = SimilarArtistTable.get_all()
processed = ".".join(a.artisthash for a in processed)
processed = ".".join(a for a in processed)
# filter out artists that already have similar artists
artists = filter(
@@ -179,7 +181,7 @@ class FetchSimilarArtistsLastFM:
# process the rest
key_artist_map = ((instance_key, artist) for artist in artists)
with ThreadPoolExecutor(max_workers=get_cpu_count()) as executor:
with ProcessPoolExecutor(max_workers=get_cpu_count()) as executor:
try:
print("Processing similar artists")
results = list(
+38 -11
View File
@@ -1,4 +1,7 @@
import os
from functools import partial
from multiprocessing import Pool, cpu_count
from app import settings
from app.config import UserConfig
from app.db.libdata import TrackTable
@@ -17,7 +20,9 @@ from app.utils.progressbar import tqdm
from app.logger import log
from app.utils.remove_duplicates import remove_duplicates
POPULATE_KEY: float = 0
class PopulateKey:
    """
    Process-wide holder for the key of the active populate run.

    `IndexTracks` stores its instance key here and long-running loops
    compare against it to detect that a newer run has taken over.
    """

    # 0 is the initial value, before any run has stored its key.
    key: float = 0
class IndexTracks:
@@ -28,9 +33,7 @@ class IndexTracks:
An instance key is used to prevent multiple instances of the
same class from running at the same time.
"""
global POPULATE_KEY
POPULATE_KEY = instance_key
PopulateKey.key = instance_key
dirs_to_scan = UserConfig().rootDirs
if len(dirs_to_scan) == 0:
@@ -120,22 +123,46 @@ class IndexTracks:
def get_untagged(self):
tracks = TrackTable.get_all()
@staticmethod
def _process_file(file: str, config: UserConfig, key: float) -> dict | None:
    """
    Worker function that reads the tags of a single file.

    This runs inside a pool worker process. Module-level state such as
    `PopulateKey.key` is NOT shared with the parent: under the "spawn" and
    "forkserver" start methods the worker re-imports the module and sees
    the default value, and under "fork" it sees a copy frozen at fork
    time. Cancellation is therefore left to the parent loop, which
    compares the key and terminates the pool.

    :param file: Path of the file to read tags from.
    :param config: User configuration forwarded to `get_tags`.
    :param key: Populate key of the run that scheduled this file. Unused
        in the worker; kept so existing callers that bind it (e.g. via
        `functools.partial`) keep working.
    :return: Whatever `get_tags` returns, or None if it raised.
    """
    try:
        return get_tags(file, config=config)
    except Exception as e:
        # One unreadable file must not abort the whole pool: log and skip.
        log.warning(f"Failed to process file {file}: {e}")
        return None
def tag_untagged(self, files: set[str], key: float):
config = UserConfig()
for file in tqdm(files, desc="Reading files"):
if POPULATE_KEY != key:
# Create process pool with worker function
with Pool(processes=cpu_count()) as pool:
worker = partial(self._process_file, config=config, key=key)
# Process files and track progress
results = []
for result in tqdm(
pool.imap_unordered(worker, files),
total=len(files),
desc="Reading files",
):
if PopulateKey.key != key:
log.warning("'Populate.tag_untagged': Populate key changed")
pool.terminate()
return
tags = get_tags(file, config=config)
if result is not None:
results.append(result)
if tags is not None:
# Bulk insert results
for tags in results:
TrackTable.insert_one(tags)
FolderStore.filepaths.add(tags["filepath"])
del tags
print(f"{len(files)} new files indexed")
print(f"{len(results)} new files indexed")
print("Done")
+8 -5
View File
@@ -23,11 +23,12 @@ def parse_album_art(filepath: str):
"""
Returns the album art for a given audio file.
"""
try:
tags = TinyTag.get(filepath, image=True)
return tags.get_image()
except: # pylint: disable=bare-except
image = tags.images.any
if image:
return image.data
return None
@@ -53,7 +54,9 @@ def extract_thumb(filepath: str, webp_path: str, overwrite=False) -> bool:
ratio = width / height
for path, size in images:
img.resize((size, int(size / ratio)), Image.ANTIALIAS).save(path, "webp")
img.resize((size, int(size / ratio)), Image.LANCZOS).save(path, "webp")
del img
if not overwrite and os.path.exists(sm_img_path):
img_size = os.path.getsize(sm_img_path)
+4 -4
View File
@@ -126,8 +126,8 @@ class MixesPlugin(Plugin):
# if len(trackmatches) < self.TRACK_MIX_LENGTH:
if True:
filler_tracks = self.fallback_create_artist_mix(
similar_artists=results["artists"],
similar_albums=results["albums"],
similar_artists=results.get("artists", []),
similar_albums=results.get("albums", []),
omit_trackhashes={t.weakhash for t in trackmatches},
# limit=self.TRACK_MIX_LENGTH - len(trackmatches),
)
@@ -135,9 +135,9 @@ class MixesPlugin(Plugin):
# try to balance the mix
trackmatches = balance_mix(trackmatches)
return trackmatches, results["albums"], results["artists"]
return trackmatches, results.get("albums", []), results.get("artists", [])
@plugin_method
# @plugin_method
# def get_artist_mix(self, artisthash: str):
# """
# Given an artisthash, creates an artist mix using the
-1
View File
@@ -8,7 +8,6 @@ def has_connection(host="google.it", port=80, timeout=3):
OpenPort: 53/tcp
Service: domain (DNS/TCP)
"""
try:
Socket.setdefaulttimeout(timeout)
Socket.socket(Socket.AF_INET, Socket.SOCK_STREAM).connect((host, port))
+4 -2
View File
@@ -173,7 +173,9 @@ def serve_client_files(path: str):
# INFO: Safari doesn't support gzip encoding
# See issue: https://github.com/swingmx/swingmusic/issues/155
is_safari = user_agent.find("Safari") >= 0 and user_agent.find("Chrome") < 0
is_safari = (
user_agent and user_agent.find("Safari") >= 0 and user_agent.find("Chrome") < 0
)
if is_safari:
return app.send_static_file(path)
@@ -181,7 +183,7 @@ def serve_client_files(path: str):
accepts_gzip = request.headers.get("Accept-Encoding", "").find("gzip") >= 0
if accepts_gzip:
if os.path.exists(os.path.join(app.static_folder, gzipped_path)):
if os.path.exists(os.path.join(app.static_folder or "", gzipped_path)):
response = app.make_response(app.send_static_file(gzipped_path))
response.headers["Content-Encoding"] = "gzip"
return response