mirror of
https://github.com/Dvorinka/swingmusic-extended.git
synced 2026-06-05 04:53:01 +00:00
try: rewrite some parts with process pools
This commit is contained in:
+2
-1
@@ -1,5 +1,4 @@
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
import gc
|
|
||||||
from sqlalchemy import Engine, event
|
from sqlalchemy import Engine, event
|
||||||
|
|
||||||
|
|
||||||
@@ -43,3 +42,5 @@ class DbEngine:
|
|||||||
raise e
|
raise e
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
del conn
|
||||||
|
cls.engine.clear_compiled_cache()
|
||||||
|
|||||||
+10
-4
@@ -143,8 +143,10 @@ class SimilarArtistTable(Base):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_all(cls):
|
def get_all(cls):
|
||||||
with DbEngine.manager() as conn:
|
with DbEngine.manager() as conn:
|
||||||
result = conn.execute(select(cls))
|
result = conn.execute(
|
||||||
return similar_artists_to_dataclass(result.fetchall())
|
select(cls.artisthash), execution_options={"stream_results": True}
|
||||||
|
)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def exists(cls, artisthash: str):
|
def exists(cls, artisthash: str):
|
||||||
@@ -156,7 +158,8 @@ class SimilarArtistTable(Base):
|
|||||||
result = conn.execute(
|
result = conn.execute(
|
||||||
select(cls.artisthash).where(cls.artisthash == artisthash)
|
select(cls.artisthash).where(cls.artisthash == artisthash)
|
||||||
)
|
)
|
||||||
return result.fetchone() is not None
|
|
||||||
|
return len(result.scalars().all()) > 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_by_hash(cls, artisthash: str):
|
def get_by_hash(cls, artisthash: str):
|
||||||
@@ -474,7 +477,10 @@ class LibDataTable(Base):
|
|||||||
result = cls.execute(
|
result = cls.execute(
|
||||||
select(cls.itemhash, cls.color).where(cls.itemtype == type)
|
select(cls.itemhash, cls.color).where(cls.itemtype == type)
|
||||||
)
|
)
|
||||||
return [{"itemhash": r[0].replace(type, ''), "color": r[1]} for r in result.fetchall()]
|
return [
|
||||||
|
{"itemhash": r[0].replace(type, ""), "color": r[1]}
|
||||||
|
for r in result.fetchall()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class MixTable(Base):
|
class MixTable(Base):
|
||||||
|
|||||||
+13
-6
@@ -1,11 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import urllib
|
import urllib
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
import requests
|
||||||
|
import multiprocessing
|
||||||
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
|
||||||
import requests
|
|
||||||
from PIL import Image, PngImagePlugin, UnidentifiedImageError
|
from PIL import Image, PngImagePlugin, UnidentifiedImageError
|
||||||
from requests.exceptions import ConnectionError as RequestConnectionError
|
from requests.exceptions import ConnectionError as RequestConnectionError
|
||||||
from requests.exceptions import ReadTimeout
|
from requests.exceptions import ReadTimeout
|
||||||
@@ -21,6 +23,7 @@ from app.store.artists import ArtistStore
|
|||||||
from app.utils.hashing import create_hash
|
from app.utils.hashing import create_hash
|
||||||
from app.utils.progressbar import tqdm
|
from app.utils.progressbar import tqdm
|
||||||
|
|
||||||
|
|
||||||
CHECK_ARTIST_IMAGES_KEY = ""
|
CHECK_ARTIST_IMAGES_KEY = ""
|
||||||
|
|
||||||
LARGE_ENOUGH_NUMBER = 100
|
LARGE_ENOUGH_NUMBER = 100
|
||||||
@@ -29,6 +32,7 @@ PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
|
|||||||
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
|
||||||
def get_artist_image_link(artist: str):
|
def get_artist_image_link(artist: str):
|
||||||
"""
|
"""
|
||||||
Returns an artist image url.
|
Returns an artist image url.
|
||||||
@@ -43,7 +47,7 @@ def get_artist_image_link(artist: str):
|
|||||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Safari/605.1.15",
|
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Safari/605.1.15",
|
||||||
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36",
|
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36",
|
||||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:90.0) Gecko/20100101 Firefox/90.0",
|
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:90.0) Gecko/20100101 Firefox/90.0",
|
||||||
"Mozilla/5.0 (iPhone; CPU iPhone OS 14_6 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Mobile/15E148 Safari/604.1"
|
"Mozilla/5.0 (iPhone; CPU iPhone OS 14_6 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Mobile/15E148 Safari/604.1",
|
||||||
]
|
]
|
||||||
headers = {
|
headers = {
|
||||||
"User-Agent": random.choice(user_agents),
|
"User-Agent": random.choice(user_agents),
|
||||||
@@ -138,7 +142,7 @@ class DownloadImage:
|
|||||||
img.save(path, format="webp")
|
img.save(path, format="webp")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
img.resize((size, int(size / ratio)), Image.ANTIALIAS).save(
|
img.resize((size, int(size / ratio)), Image.LANCZOS).save(
|
||||||
path, format="webp"
|
path, format="webp"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -150,7 +154,7 @@ class CheckArtistImages:
|
|||||||
|
|
||||||
# read all files in the artist image folder
|
# read all files in the artist image folder
|
||||||
path = settings.Paths.get_sm_artist_img_path()
|
path = settings.Paths.get_sm_artist_img_path()
|
||||||
processed = [path.replace(".webp", "") for path in os.listdir(path)]
|
processed = set(i.replace(".webp", "") for i in os.listdir(path))
|
||||||
|
|
||||||
unprocessed = [
|
unprocessed = [
|
||||||
a for a in ArtistStore.get_flat_list() if a.artisthash not in processed
|
a for a in ArtistStore.get_flat_list() if a.artisthash not in processed
|
||||||
@@ -158,7 +162,10 @@ class CheckArtistImages:
|
|||||||
|
|
||||||
key_artist_map = ((instance_key, artist) for artist in unprocessed)
|
key_artist_map = ((instance_key, artist) for artist in unprocessed)
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=14) as executor:
|
# Use number of CPU cores minus 1 to leave one core free for system processes
|
||||||
|
num_workers = max(1, multiprocessing.cpu_count() - 1)
|
||||||
|
|
||||||
|
with ProcessPoolExecutor(max_workers=num_workers) as executor:
|
||||||
res = list(
|
res = list(
|
||||||
tqdm(
|
tqdm(
|
||||||
executor.map(self.download_image, key_artist_map),
|
executor.map(self.download_image, key_artist_map),
|
||||||
|
|||||||
+7
-5
@@ -1,6 +1,6 @@
|
|||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
import os
|
import os
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
|
||||||
from requests import ConnectionError as RequestConnectionError
|
from requests import ConnectionError as RequestConnectionError
|
||||||
from requests import ReadTimeout
|
from requests import ReadTimeout
|
||||||
@@ -111,7 +111,7 @@ class ProcessTrackThumbnails:
|
|||||||
path = settings.Paths.get_sm_thumb_path()
|
path = settings.Paths.get_sm_thumb_path()
|
||||||
|
|
||||||
# read all the files in the thumbnail directory
|
# read all the files in the thumbnail directory
|
||||||
processed = "".join(os.listdir(path)).replace("webp", "")
|
processed = set(i.replace(".webp", "") for i in os.listdir(path))
|
||||||
|
|
||||||
# filter out albums that already have thumbnails
|
# filter out albums that already have thumbnails
|
||||||
albums = filter(
|
albums = filter(
|
||||||
@@ -119,10 +119,12 @@ class ProcessTrackThumbnails:
|
|||||||
)
|
)
|
||||||
albums = list(albums)
|
albums = list(albums)
|
||||||
|
|
||||||
|
print("length of albums", len(albums))
|
||||||
|
|
||||||
# process the rest
|
# process the rest
|
||||||
key_album_map = ((instance_key, album) for album in albums)
|
key_album_map = ((instance_key, album) for album in albums)
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=get_cpu_count()) as executor:
|
with ProcessPoolExecutor(max_workers=get_cpu_count()) as executor:
|
||||||
results = list(
|
results = list(
|
||||||
tqdm(
|
tqdm(
|
||||||
executor.map(get_image, key_album_map),
|
executor.map(get_image, key_album_map),
|
||||||
@@ -168,7 +170,7 @@ class FetchSimilarArtistsLastFM:
|
|||||||
def __init__(self, instance_key: str) -> None:
|
def __init__(self, instance_key: str) -> None:
|
||||||
# read all artists from db
|
# read all artists from db
|
||||||
processed = SimilarArtistTable.get_all()
|
processed = SimilarArtistTable.get_all()
|
||||||
processed = ".".join(a.artisthash for a in processed)
|
processed = ".".join(a for a in processed)
|
||||||
|
|
||||||
# filter out artists that already have similar artists
|
# filter out artists that already have similar artists
|
||||||
artists = filter(
|
artists = filter(
|
||||||
@@ -179,7 +181,7 @@ class FetchSimilarArtistsLastFM:
|
|||||||
# process the rest
|
# process the rest
|
||||||
key_artist_map = ((instance_key, artist) for artist in artists)
|
key_artist_map = ((instance_key, artist) for artist in artists)
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=get_cpu_count()) as executor:
|
with ProcessPoolExecutor(max_workers=get_cpu_count()) as executor:
|
||||||
try:
|
try:
|
||||||
print("Processing similar artists")
|
print("Processing similar artists")
|
||||||
results = list(
|
results = list(
|
||||||
|
|||||||
+41
-14
@@ -1,4 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
from functools import partial
|
||||||
|
from multiprocessing import Pool, cpu_count
|
||||||
|
|
||||||
from app import settings
|
from app import settings
|
||||||
from app.config import UserConfig
|
from app.config import UserConfig
|
||||||
from app.db.libdata import TrackTable
|
from app.db.libdata import TrackTable
|
||||||
@@ -17,7 +20,9 @@ from app.utils.progressbar import tqdm
|
|||||||
from app.logger import log
|
from app.logger import log
|
||||||
from app.utils.remove_duplicates import remove_duplicates
|
from app.utils.remove_duplicates import remove_duplicates
|
||||||
|
|
||||||
POPULATE_KEY: float = 0
|
|
||||||
|
class PopulateKey:
|
||||||
|
key: float = 0
|
||||||
|
|
||||||
|
|
||||||
class IndexTracks:
|
class IndexTracks:
|
||||||
@@ -28,9 +33,7 @@ class IndexTracks:
|
|||||||
An instance key is used to prevent multiple instances of the
|
An instance key is used to prevent multiple instances of the
|
||||||
same class from running at the same time.
|
same class from running at the same time.
|
||||||
"""
|
"""
|
||||||
global POPULATE_KEY
|
PopulateKey.key = instance_key
|
||||||
POPULATE_KEY = instance_key
|
|
||||||
|
|
||||||
dirs_to_scan = UserConfig().rootDirs
|
dirs_to_scan = UserConfig().rootDirs
|
||||||
|
|
||||||
if len(dirs_to_scan) == 0:
|
if len(dirs_to_scan) == 0:
|
||||||
@@ -120,22 +123,46 @@ class IndexTracks:
|
|||||||
def get_untagged(self):
|
def get_untagged(self):
|
||||||
tracks = TrackTable.get_all()
|
tracks = TrackTable.get_all()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _process_file(file: str, config: UserConfig, key: float) -> dict | None:
|
||||||
|
"""Worker function to process individual files"""
|
||||||
|
if PopulateKey.key != key:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return get_tags(file, config=config)
|
||||||
|
except Exception as e:
|
||||||
|
log.warning(f"Failed to process file {file}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def tag_untagged(self, files: set[str], key: float):
|
def tag_untagged(self, files: set[str], key: float):
|
||||||
config = UserConfig()
|
config = UserConfig()
|
||||||
for file in tqdm(files, desc="Reading files"):
|
|
||||||
if POPULATE_KEY != key:
|
|
||||||
log.warning("'Populate.tag_untagged': Populate key changed")
|
|
||||||
return
|
|
||||||
|
|
||||||
tags = get_tags(file, config=config)
|
# Create process pool with worker function
|
||||||
|
with Pool(processes=cpu_count()) as pool:
|
||||||
|
worker = partial(self._process_file, config=config, key=key)
|
||||||
|
|
||||||
if tags is not None:
|
# Process files and track progress
|
||||||
TrackTable.insert_one(tags)
|
results = []
|
||||||
FolderStore.filepaths.add(tags["filepath"])
|
for result in tqdm(
|
||||||
|
pool.imap_unordered(worker, files),
|
||||||
|
total=len(files),
|
||||||
|
desc="Reading files",
|
||||||
|
):
|
||||||
|
if PopulateKey.key != key:
|
||||||
|
log.warning("'Populate.tag_untagged': Populate key changed")
|
||||||
|
pool.terminate()
|
||||||
|
return
|
||||||
|
|
||||||
del tags
|
if result is not None:
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
print(f"{len(files)} new files indexed")
|
# Bulk insert results
|
||||||
|
for tags in results:
|
||||||
|
TrackTable.insert_one(tags)
|
||||||
|
FolderStore.filepaths.add(tags["filepath"])
|
||||||
|
|
||||||
|
print(f"{len(results)} new files indexed")
|
||||||
print("Done")
|
print("Done")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+9
-6
@@ -23,12 +23,13 @@ def parse_album_art(filepath: str):
|
|||||||
"""
|
"""
|
||||||
Returns the album art for a given audio file.
|
Returns the album art for a given audio file.
|
||||||
"""
|
"""
|
||||||
|
tags = TinyTag.get(filepath, image=True)
|
||||||
|
image = tags.images.any
|
||||||
|
|
||||||
try:
|
if image:
|
||||||
tags = TinyTag.get(filepath, image=True)
|
return image.data
|
||||||
return tags.get_image()
|
|
||||||
except: # pylint: disable=bare-except
|
return None
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def extract_thumb(filepath: str, webp_path: str, overwrite=False) -> bool:
|
def extract_thumb(filepath: str, webp_path: str, overwrite=False) -> bool:
|
||||||
@@ -53,7 +54,9 @@ def extract_thumb(filepath: str, webp_path: str, overwrite=False) -> bool:
|
|||||||
ratio = width / height
|
ratio = width / height
|
||||||
|
|
||||||
for path, size in images:
|
for path, size in images:
|
||||||
img.resize((size, int(size / ratio)), Image.ANTIALIAS).save(path, "webp")
|
img.resize((size, int(size / ratio)), Image.LANCZOS).save(path, "webp")
|
||||||
|
|
||||||
|
del img
|
||||||
|
|
||||||
if not overwrite and os.path.exists(sm_img_path):
|
if not overwrite and os.path.exists(sm_img_path):
|
||||||
img_size = os.path.getsize(sm_img_path)
|
img_size = os.path.getsize(sm_img_path)
|
||||||
|
|||||||
@@ -126,8 +126,8 @@ class MixesPlugin(Plugin):
|
|||||||
# if len(trackmatches) < self.TRACK_MIX_LENGTH:
|
# if len(trackmatches) < self.TRACK_MIX_LENGTH:
|
||||||
if True:
|
if True:
|
||||||
filler_tracks = self.fallback_create_artist_mix(
|
filler_tracks = self.fallback_create_artist_mix(
|
||||||
similar_artists=results["artists"],
|
similar_artists=results.get("artists", []),
|
||||||
similar_albums=results["albums"],
|
similar_albums=results.get("albums", []),
|
||||||
omit_trackhashes={t.weakhash for t in trackmatches},
|
omit_trackhashes={t.weakhash for t in trackmatches},
|
||||||
# limit=self.TRACK_MIX_LENGTH - len(trackmatches),
|
# limit=self.TRACK_MIX_LENGTH - len(trackmatches),
|
||||||
)
|
)
|
||||||
@@ -135,9 +135,9 @@ class MixesPlugin(Plugin):
|
|||||||
|
|
||||||
# try to balance the mix
|
# try to balance the mix
|
||||||
trackmatches = balance_mix(trackmatches)
|
trackmatches = balance_mix(trackmatches)
|
||||||
return trackmatches, results["albums"], results["artists"]
|
return trackmatches, results.get("albums", []), results.get("artists", [])
|
||||||
|
|
||||||
@plugin_method
|
# @plugin_method
|
||||||
# def get_artist_mix(self, artisthash: str):
|
# def get_artist_mix(self, artisthash: str):
|
||||||
# """
|
# """
|
||||||
# Given an artisthash, creates an artist mix using the
|
# Given an artisthash, creates an artist mix using the
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ def has_connection(host="google.it", port=80, timeout=3):
|
|||||||
OpenPort: 53/tcp
|
OpenPort: 53/tcp
|
||||||
Service: domain (DNS/TCP)
|
Service: domain (DNS/TCP)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
Socket.setdefaulttimeout(timeout)
|
Socket.setdefaulttimeout(timeout)
|
||||||
Socket.socket(Socket.AF_INET, Socket.SOCK_STREAM).connect((host, port))
|
Socket.socket(Socket.AF_INET, Socket.SOCK_STREAM).connect((host, port))
|
||||||
|
|||||||
@@ -173,7 +173,9 @@ def serve_client_files(path: str):
|
|||||||
|
|
||||||
# INFO: Safari doesn't support gzip encoding
|
# INFO: Safari doesn't support gzip encoding
|
||||||
# See issue: https://github.com/swingmx/swingmusic/issues/155
|
# See issue: https://github.com/swingmx/swingmusic/issues/155
|
||||||
is_safari = user_agent.find("Safari") >= 0 and user_agent.find("Chrome") < 0
|
is_safari = (
|
||||||
|
user_agent and user_agent.find("Safari") >= 0 and user_agent.find("Chrome") < 0
|
||||||
|
)
|
||||||
|
|
||||||
if is_safari:
|
if is_safari:
|
||||||
return app.send_static_file(path)
|
return app.send_static_file(path)
|
||||||
@@ -181,7 +183,7 @@ def serve_client_files(path: str):
|
|||||||
accepts_gzip = request.headers.get("Accept-Encoding", "").find("gzip") >= 0
|
accepts_gzip = request.headers.get("Accept-Encoding", "").find("gzip") >= 0
|
||||||
|
|
||||||
if accepts_gzip:
|
if accepts_gzip:
|
||||||
if os.path.exists(os.path.join(app.static_folder, gzipped_path)):
|
if os.path.exists(os.path.join(app.static_folder or "", gzipped_path)):
|
||||||
response = app.make_response(app.send_static_file(gzipped_path))
|
response = app.make_response(app.send_static_file(gzipped_path))
|
||||||
response.headers["Content-Encoding"] = "gzip"
|
response.headers["Content-Encoding"] = "gzip"
|
||||||
return response
|
return response
|
||||||
|
|||||||
Reference in New Issue
Block a user