mirror of
https://github.com/Dvorinka/swingmusic-extended.git
synced 2026-06-04 04:23:01 +00:00
try: rewrite some parts with process pools
This commit is contained in:
+2
-1
@@ -1,5 +1,4 @@
|
||||
from contextlib import contextmanager
|
||||
import gc
|
||||
from sqlalchemy import Engine, event
|
||||
|
||||
|
||||
@@ -43,3 +42,5 @@ class DbEngine:
|
||||
raise e
|
||||
finally:
|
||||
conn.close()
|
||||
del conn
|
||||
cls.engine.clear_compiled_cache()
|
||||
|
||||
+10
-4
@@ -143,8 +143,10 @@ class SimilarArtistTable(Base):
|
||||
@classmethod
|
||||
def get_all(cls):
|
||||
with DbEngine.manager() as conn:
|
||||
result = conn.execute(select(cls))
|
||||
return similar_artists_to_dataclass(result.fetchall())
|
||||
result = conn.execute(
|
||||
select(cls.artisthash), execution_options={"stream_results": True}
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
def exists(cls, artisthash: str):
|
||||
@@ -156,7 +158,8 @@ class SimilarArtistTable(Base):
|
||||
result = conn.execute(
|
||||
select(cls.artisthash).where(cls.artisthash == artisthash)
|
||||
)
|
||||
return result.fetchone() is not None
|
||||
|
||||
return len(result.scalars().all()) > 0
|
||||
|
||||
@classmethod
|
||||
def get_by_hash(cls, artisthash: str):
|
||||
@@ -474,7 +477,10 @@ class LibDataTable(Base):
|
||||
result = cls.execute(
|
||||
select(cls.itemhash, cls.color).where(cls.itemtype == type)
|
||||
)
|
||||
return [{"itemhash": r[0].replace(type, ''), "color": r[1]} for r in result.fetchall()]
|
||||
return [
|
||||
{"itemhash": r[0].replace(type, ""), "color": r[1]}
|
||||
for r in result.fetchall()
|
||||
]
|
||||
|
||||
|
||||
class MixTable(Base):
|
||||
|
||||
+13
-6
@@ -1,11 +1,13 @@
|
||||
import os
|
||||
import time
|
||||
import urllib
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import requests
|
||||
import multiprocessing
|
||||
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
import requests
|
||||
from PIL import Image, PngImagePlugin, UnidentifiedImageError
|
||||
from requests.exceptions import ConnectionError as RequestConnectionError
|
||||
from requests.exceptions import ReadTimeout
|
||||
@@ -21,6 +23,7 @@ from app.store.artists import ArtistStore
|
||||
from app.utils.hashing import create_hash
|
||||
from app.utils.progressbar import tqdm
|
||||
|
||||
|
||||
CHECK_ARTIST_IMAGES_KEY = ""
|
||||
|
||||
LARGE_ENOUGH_NUMBER = 100
|
||||
@@ -29,6 +32,7 @@ PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
|
||||
|
||||
import random
|
||||
|
||||
|
||||
def get_artist_image_link(artist: str):
|
||||
"""
|
||||
Returns an artist image url.
|
||||
@@ -43,7 +47,7 @@ def get_artist_image_link(artist: str):
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Safari/605.1.15",
|
||||
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36",
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:90.0) Gecko/20100101 Firefox/90.0",
|
||||
"Mozilla/5.0 (iPhone; CPU iPhone OS 14_6 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Mobile/15E148 Safari/604.1"
|
||||
"Mozilla/5.0 (iPhone; CPU iPhone OS 14_6 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Mobile/15E148 Safari/604.1",
|
||||
]
|
||||
headers = {
|
||||
"User-Agent": random.choice(user_agents),
|
||||
@@ -138,7 +142,7 @@ class DownloadImage:
|
||||
img.save(path, format="webp")
|
||||
continue
|
||||
|
||||
img.resize((size, int(size / ratio)), Image.ANTIALIAS).save(
|
||||
img.resize((size, int(size / ratio)), Image.LANCZOS).save(
|
||||
path, format="webp"
|
||||
)
|
||||
|
||||
@@ -150,7 +154,7 @@ class CheckArtistImages:
|
||||
|
||||
# read all files in the artist image folder
|
||||
path = settings.Paths.get_sm_artist_img_path()
|
||||
processed = [path.replace(".webp", "") for path in os.listdir(path)]
|
||||
processed = set(i.replace(".webp", "") for i in os.listdir(path))
|
||||
|
||||
unprocessed = [
|
||||
a for a in ArtistStore.get_flat_list() if a.artisthash not in processed
|
||||
@@ -158,7 +162,10 @@ class CheckArtistImages:
|
||||
|
||||
key_artist_map = ((instance_key, artist) for artist in unprocessed)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=14) as executor:
|
||||
# Use number of CPU cores minus 1 to leave one core free for system processes
|
||||
num_workers = max(1, multiprocessing.cpu_count() - 1)
|
||||
|
||||
with ProcessPoolExecutor(max_workers=num_workers) as executor:
|
||||
res = list(
|
||||
tqdm(
|
||||
executor.map(self.download_image, key_artist_map),
|
||||
|
||||
+7
-5
@@ -1,6 +1,6 @@
|
||||
from dataclasses import asdict
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
from requests import ConnectionError as RequestConnectionError
|
||||
from requests import ReadTimeout
|
||||
@@ -111,7 +111,7 @@ class ProcessTrackThumbnails:
|
||||
path = settings.Paths.get_sm_thumb_path()
|
||||
|
||||
# read all the files in the thumbnail directory
|
||||
processed = "".join(os.listdir(path)).replace("webp", "")
|
||||
processed = set(i.replace(".webp", "") for i in os.listdir(path))
|
||||
|
||||
# filter out albums that already have thumbnails
|
||||
albums = filter(
|
||||
@@ -119,10 +119,12 @@ class ProcessTrackThumbnails:
|
||||
)
|
||||
albums = list(albums)
|
||||
|
||||
print("length of albums", len(albums))
|
||||
|
||||
# process the rest
|
||||
key_album_map = ((instance_key, album) for album in albums)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=get_cpu_count()) as executor:
|
||||
with ProcessPoolExecutor(max_workers=get_cpu_count()) as executor:
|
||||
results = list(
|
||||
tqdm(
|
||||
executor.map(get_image, key_album_map),
|
||||
@@ -168,7 +170,7 @@ class FetchSimilarArtistsLastFM:
|
||||
def __init__(self, instance_key: str) -> None:
|
||||
# read all artists from db
|
||||
processed = SimilarArtistTable.get_all()
|
||||
processed = ".".join(a.artisthash for a in processed)
|
||||
processed = ".".join(a for a in processed)
|
||||
|
||||
# filter out artists that already have similar artists
|
||||
artists = filter(
|
||||
@@ -179,7 +181,7 @@ class FetchSimilarArtistsLastFM:
|
||||
# process the rest
|
||||
key_artist_map = ((instance_key, artist) for artist in artists)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=get_cpu_count()) as executor:
|
||||
with ProcessPoolExecutor(max_workers=get_cpu_count()) as executor:
|
||||
try:
|
||||
print("Processing similar artists")
|
||||
results = list(
|
||||
|
||||
+41
-14
@@ -1,4 +1,7 @@
|
||||
import os
|
||||
from functools import partial
|
||||
from multiprocessing import Pool, cpu_count
|
||||
|
||||
from app import settings
|
||||
from app.config import UserConfig
|
||||
from app.db.libdata import TrackTable
|
||||
@@ -17,7 +20,9 @@ from app.utils.progressbar import tqdm
|
||||
from app.logger import log
|
||||
from app.utils.remove_duplicates import remove_duplicates
|
||||
|
||||
POPULATE_KEY: float = 0
|
||||
|
||||
class PopulateKey:
|
||||
key: float = 0
|
||||
|
||||
|
||||
class IndexTracks:
|
||||
@@ -28,9 +33,7 @@ class IndexTracks:
|
||||
An instance key is used to prevent multiple instances of the
|
||||
same class from running at the same time.
|
||||
"""
|
||||
global POPULATE_KEY
|
||||
POPULATE_KEY = instance_key
|
||||
|
||||
PopulateKey.key = instance_key
|
||||
dirs_to_scan = UserConfig().rootDirs
|
||||
|
||||
if len(dirs_to_scan) == 0:
|
||||
@@ -120,22 +123,46 @@ class IndexTracks:
|
||||
def get_untagged(self):
|
||||
tracks = TrackTable.get_all()
|
||||
|
||||
@staticmethod
|
||||
def _process_file(file: str, config: UserConfig, key: float) -> dict | None:
|
||||
"""Worker function to process individual files"""
|
||||
if PopulateKey.key != key:
|
||||
return None
|
||||
|
||||
try:
|
||||
return get_tags(file, config=config)
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to process file {file}: {e}")
|
||||
return None
|
||||
|
||||
def tag_untagged(self, files: set[str], key: float):
|
||||
config = UserConfig()
|
||||
for file in tqdm(files, desc="Reading files"):
|
||||
if POPULATE_KEY != key:
|
||||
log.warning("'Populate.tag_untagged': Populate key changed")
|
||||
return
|
||||
|
||||
tags = get_tags(file, config=config)
|
||||
# Create process pool with worker function
|
||||
with Pool(processes=cpu_count()) as pool:
|
||||
worker = partial(self._process_file, config=config, key=key)
|
||||
|
||||
if tags is not None:
|
||||
TrackTable.insert_one(tags)
|
||||
FolderStore.filepaths.add(tags["filepath"])
|
||||
# Process files and track progress
|
||||
results = []
|
||||
for result in tqdm(
|
||||
pool.imap_unordered(worker, files),
|
||||
total=len(files),
|
||||
desc="Reading files",
|
||||
):
|
||||
if PopulateKey.key != key:
|
||||
log.warning("'Populate.tag_untagged': Populate key changed")
|
||||
pool.terminate()
|
||||
return
|
||||
|
||||
del tags
|
||||
if result is not None:
|
||||
results.append(result)
|
||||
|
||||
print(f"{len(files)} new files indexed")
|
||||
# Bulk insert results
|
||||
for tags in results:
|
||||
TrackTable.insert_one(tags)
|
||||
FolderStore.filepaths.add(tags["filepath"])
|
||||
|
||||
print(f"{len(results)} new files indexed")
|
||||
print("Done")
|
||||
|
||||
|
||||
|
||||
+9
-6
@@ -23,12 +23,13 @@ def parse_album_art(filepath: str):
|
||||
"""
|
||||
Returns the album art for a given audio file.
|
||||
"""
|
||||
tags = TinyTag.get(filepath, image=True)
|
||||
image = tags.images.any
|
||||
|
||||
try:
|
||||
tags = TinyTag.get(filepath, image=True)
|
||||
return tags.get_image()
|
||||
except: # pylint: disable=bare-except
|
||||
return None
|
||||
if image:
|
||||
return image.data
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_thumb(filepath: str, webp_path: str, overwrite=False) -> bool:
|
||||
@@ -53,7 +54,9 @@ def extract_thumb(filepath: str, webp_path: str, overwrite=False) -> bool:
|
||||
ratio = width / height
|
||||
|
||||
for path, size in images:
|
||||
img.resize((size, int(size / ratio)), Image.ANTIALIAS).save(path, "webp")
|
||||
img.resize((size, int(size / ratio)), Image.LANCZOS).save(path, "webp")
|
||||
|
||||
del img
|
||||
|
||||
if not overwrite and os.path.exists(sm_img_path):
|
||||
img_size = os.path.getsize(sm_img_path)
|
||||
|
||||
@@ -126,8 +126,8 @@ class MixesPlugin(Plugin):
|
||||
# if len(trackmatches) < self.TRACK_MIX_LENGTH:
|
||||
if True:
|
||||
filler_tracks = self.fallback_create_artist_mix(
|
||||
similar_artists=results["artists"],
|
||||
similar_albums=results["albums"],
|
||||
similar_artists=results.get("artists", []),
|
||||
similar_albums=results.get("albums", []),
|
||||
omit_trackhashes={t.weakhash for t in trackmatches},
|
||||
# limit=self.TRACK_MIX_LENGTH - len(trackmatches),
|
||||
)
|
||||
@@ -135,9 +135,9 @@ class MixesPlugin(Plugin):
|
||||
|
||||
# try to balance the mix
|
||||
trackmatches = balance_mix(trackmatches)
|
||||
return trackmatches, results["albums"], results["artists"]
|
||||
return trackmatches, results.get("albums", []), results.get("artists", [])
|
||||
|
||||
@plugin_method
|
||||
# @plugin_method
|
||||
# def get_artist_mix(self, artisthash: str):
|
||||
# """
|
||||
# Given an artisthash, creates an artist mix using the
|
||||
|
||||
@@ -8,7 +8,6 @@ def has_connection(host="google.it", port=80, timeout=3):
|
||||
OpenPort: 53/tcp
|
||||
Service: domain (DNS/TCP)
|
||||
"""
|
||||
|
||||
try:
|
||||
Socket.setdefaulttimeout(timeout)
|
||||
Socket.socket(Socket.AF_INET, Socket.SOCK_STREAM).connect((host, port))
|
||||
|
||||
Reference in New Issue
Block a user