fix multiprocessing

+ remove global locks
This commit is contained in:
cwilvx
2025-03-10 21:05:28 +03:00
parent 9725fd427b
commit 0eef23880b
12 changed files with 140 additions and 689 deletions
+6 -16
View File
@@ -1,3 +1,4 @@
import math
import os
import time
import urllib
@@ -24,8 +25,6 @@ from app.utils.hashing import create_hash
from app.utils.progressbar import tqdm
CHECK_ARTIST_IMAGES_KEY = ""
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
# https://stackoverflow.com/a/61466412
@@ -148,10 +147,7 @@ class DownloadImage:
class CheckArtistImages:
def __init__(self, instance_key: str):
global CHECK_ARTIST_IMAGES_KEY
CHECK_ARTIST_IMAGES_KEY = instance_key
def __init__(self):
# read all files in the artist image folder
path = settings.Paths.get_sm_artist_img_path()
processed = set(i.replace(".webp", "") for i in os.listdir(path))
@@ -160,15 +156,14 @@ class CheckArtistImages:
a for a in ArtistStore.get_flat_list() if a.artisthash not in processed
]
key_artist_map = ((instance_key, artist) for artist in unprocessed)
# Use number of CPU cores minus 1 to leave one core free for system processes
num_workers = max(1, multiprocessing.cpu_count() - 1)
num_workers = max(1, math.floor(multiprocessing.cpu_count() / 2))
print("num_workers", num_workers)
with ProcessPoolExecutor(max_workers=num_workers) as executor:
res = list(
tqdm(
executor.map(self.download_image, key_artist_map),
executor.map(self.download_image, unprocessed),
total=len(unprocessed),
desc="Downloading missing artist images",
)
@@ -177,17 +172,12 @@ class CheckArtistImages:
list(res)
@staticmethod
def download_image(_map: tuple[str, Artist]):
def download_image(artist: Artist):
"""
Checks if an artist image exists and downloads it if not.
:param artist: The artist name
"""
instance_key, artist = _map
if CHECK_ARTIST_IMAGES_KEY != instance_key:
return
img_path = (
Path(settings.Paths.get_sm_artist_img_path()) / f"{artist.artisthash}.webp"
)
+2 -22
View File
@@ -10,15 +10,10 @@ from app import settings
from app.db.userdata import LibDataTable
from app.logger import log
from app.lib.errors import PopulateCancelledError
from app.store.albums import AlbumStore
from app.store.artists import ArtistStore
from app.utils.progressbar import tqdm
PROCESS_ALBUM_COLORS_KEY = ""
PROCESS_ARTIST_COLORS_KEY = ""
def get_image_colors(image: str, count=1) -> list[str]:
"""Extracts n number of the most dominant colors from an image."""
try:
@@ -54,18 +49,11 @@ class ProcessAlbumColors:
Extracts the most dominant color from the album art and saves it to the database.
"""
def __init__(self, instance_key: str) -> None:
global PROCESS_ALBUM_COLORS_KEY
PROCESS_ALBUM_COLORS_KEY = instance_key
def __init__(self) -> None:
albums = [a for a in AlbumStore.get_flat_list() if not a.color]
for album in tqdm(albums, desc="Processing missing album colors"):
albumhash = album.albumhash
if PROCESS_ALBUM_COLORS_KEY != instance_key:
raise PopulateCancelledError(
"A newer 'ProcessAlbumColors' instance is running. Stopping this one."
)
albumrecord = LibDataTable.find_one(albumhash, type="album")
if albumrecord is not None and albumrecord.color is not None:
@@ -99,21 +87,13 @@ class ProcessArtistColors:
Extracts the most dominant color from the artist art and saves it to the database.
"""
def __init__(self, instance_key: str) -> None:
def __init__(self) -> None:
all_artists = [a for a in ArtistStore.get_flat_list() if not a.color]
global PROCESS_ARTIST_COLORS_KEY
PROCESS_ARTIST_COLORS_KEY = instance_key
for artist in tqdm(all_artists, desc="Processing missing artist colors"):
artisthash = artist.artisthash
if PROCESS_ARTIST_COLORS_KEY != instance_key:
raise PopulateCancelledError(
"A newer 'ProcessArtistColors' instance is running. Stopping this one."
)
record = LibDataTable.find_one(artisthash, "artist")
print(record)
if (record is not None) and (record.color is not None):
continue
-7
View File
@@ -1,7 +0,0 @@
class PopulateCancelledError(Exception):
"""
Raised when the instance key of a looping function called
inside Populate is changed.
"""
pass
+1 -1
View File
@@ -18,7 +18,7 @@ from app.utils.threading import background
class IndexEverything:
def __init__(self) -> None:
IndexTracks(instance_key=time())
IndexTracks()
key = str(time())
TrackStore.load_all_tracks(key)
+47 -71
View File
@@ -1,6 +1,8 @@
from dataclasses import asdict
import math
import os
from concurrent.futures import ProcessPoolExecutor
import platform
from requests import ConnectionError as RequestConnectionError
from requests import ReadTimeout
@@ -8,7 +10,6 @@ from requests import ReadTimeout
from app import settings
from app.lib.artistlib import CheckArtistImages
from app.lib.colorlib import ProcessAlbumColors, ProcessArtistColors
from app.lib.errors import PopulateCancelledError
from app.lib.taglib import extract_thumb
from app.logger import log
from app.models import Album, Artist
@@ -21,33 +22,22 @@ from app.utils.progressbar import tqdm
from app.db.userdata import SimilarArtistTable
POPULATE_KEY = ""
class CordinateMedia:
"""
Cordinates the extracting of thumbnails
"""
def __init__(self, instance_key: str):
global POPULATE_KEY
POPULATE_KEY = instance_key
try:
ProcessTrackThumbnails(instance_key)
ProcessAlbumColors(instance_key)
ProcessArtistColors(instance_key)
except PopulateCancelledError as e:
log.warn(e)
return
ProcessTrackThumbnails()
ProcessAlbumColors()
ProcessArtistColors()
tried_to_download_new_images = False
if has_connection():
tried_to_download_new_images = True
try:
CheckArtistImages(instance_key)
CheckArtistImages()
except (RequestConnectionError, ReadTimeout) as e:
log.error(
"Internet connection lost. Downloading artist images suspended."
@@ -58,18 +48,14 @@ class CordinateMedia:
# Re-process the new artist images.
if tried_to_download_new_images:
ProcessArtistColors(instance_key=instance_key)
ProcessArtistColors()
if has_connection():
try:
print("Attempting to download similar artists...")
FetchSimilarArtistsLastFM(instance_key)
except PopulateCancelledError as e:
log.warn(e)
return
print("Attempting to download similar artists...")
FetchSimilarArtistsLastFM()
def get_image(_map: tuple[str, Album]):
def get_image(album: Album):
"""
The function retrieves an image from an album by iterating through its tracks and extracting the thumbnail from the first track that has one.
@@ -77,17 +63,13 @@ def get_image(_map: tuple[str, Album]):
:type album: Album
:return: None
"""
instance_key, album = _map
if POPULATE_KEY != instance_key:
raise PopulateCancelledError("'ProcessTrackThumbnails': Populate key changed")
matching_tracks = AlbumStore.get_album_tracks(album.albumhash)
for track in matching_tracks:
if extract_thumb(track.filepath, track.albumhash + ".webp"):
break
extracted = extract_thumb(track.filepath, track.albumhash + ".webp")
if extracted:
return
def get_cpu_count():
@@ -103,7 +85,31 @@ class ProcessTrackThumbnails:
Extracts the album art from all albums in album store.
"""
def __init__(self, instance_key: str) -> None:
def extract(self, albums: list[Album]):
"""
Extracts the album art with platform specific logic.
"""
if platform.system() == "Linux":
# INFO: Processess are forked with access to global stores
# It's "safe" to use a process pool
cpus = math.floor(get_cpu_count() / 2)
with ProcessPoolExecutor(max_workers=cpus) as executor:
results = list(
tqdm(
executor.map(get_image, albums),
total=len(albums),
desc="Extracting track images",
)
)
list(results)
else:
# INFO: Use a for loop for windows (and others I guess)
for album in tqdm(albums, desc="Extracting track images"):
get_image(album)
def __init__(self) -> None:
"""
Filters out albums that already have thumbnails and
extracts the thumbnail for the other albums.
@@ -112,43 +118,20 @@ class ProcessTrackThumbnails:
# read all the files in the thumbnail directory
processed = set(i.replace(".webp", "") for i in os.listdir(path))
# filter out albums that already have thumbnails
albums = filter(
lambda album: album.albumhash not in processed, AlbumStore.get_flat_list()
lambda album: album.albumhash not in processed,
AlbumStore.get_flat_list(),
)
albums = list(albums)
print("length of albums", len(albums))
# process the rest
key_album_map = ((instance_key, album) for album in albums)
with ProcessPoolExecutor(max_workers=get_cpu_count()) as executor:
results = list(
tqdm(
executor.map(get_image, key_album_map),
total=len(albums),
desc="Extracting track images",
)
)
list(results)
self.extract(albums)
def save_similar_artists(_map: tuple[str, Artist]):
def save_similar_artists(artist: Artist):
"""
Downloads and saves similar artists to the database.
"""
instance_key, artist = _map
if POPULATE_KEY != instance_key:
print("Warning: Populate key changed")
raise PopulateCancelledError(
"'FetchSimilarArtistsLastFM': Populate key changed"
)
if SimilarArtistTable.exists(artist.artisthash):
return
@@ -167,7 +150,7 @@ class FetchSimilarArtistsLastFM:
Fetches similar artists from LastFM using a thread pool.
"""
def __init__(self, instance_key: str) -> None:
def __init__(self) -> None:
# read all artists from db
processed = set(a.artisthash for a in SimilarArtistTable.get_all())
@@ -177,25 +160,18 @@ class FetchSimilarArtistsLastFM:
)
artists = list(artists)
# process the rest
key_artist_map = ((instance_key, artist) for artist in artists)
with ProcessPoolExecutor(max_workers=get_cpu_count()) as executor:
try:
print("Processing similar artists")
results = list(
tqdm(
executor.map(save_similar_artists, key_artist_map),
executor.map(save_similar_artists, artists),
total=len(artists),
desc="Fetching similar artists",
)
)
list(results)
except PopulateCancelledError as e:
raise e
# any exception that can be raised by the pool
except Exception as e:
log.warn(e)
+5 -18
View File
@@ -22,19 +22,14 @@ from app.logger import log
from app.utils.remove_duplicates import remove_duplicates
class PopulateKey:
key: float = 0
class IndexTracks:
def __init__(self, instance_key: float) -> None:
def __init__(self) -> None:
"""
Indexes all tracks in the database.
An instance key is used to prevent multiple instances of the
same class from running at the same time.
"""
PopulateKey.key = instance_key
dirs_to_scan = UserConfig().rootDirs
if len(dirs_to_scan) == 0:
@@ -60,7 +55,7 @@ class IndexTracks:
unmodified, modified_tracks = self.filter_modded()
untagged = files - unmodified
self.tag_untagged(untagged, instance_key)
self.tag_untagged(untagged)
self.extract_thumb_with_overwrite(modified_tracks)
@staticmethod
@@ -125,23 +120,20 @@ class IndexTracks:
tracks = TrackTable.get_all()
@staticmethod
def _process_file(file: str, config: UserConfig, key: float) -> dict | None:
def _process_file(file: str, config: UserConfig) -> dict | None:
"""Worker function to process individual files"""
if PopulateKey.key != key:
return None
try:
return get_tags(file, config=config)
except Exception as e:
log.warning(f"Failed to process file {file}: {e}")
return None
def tag_untagged(self, files: set[str], key: float):
def tag_untagged(self, files: set[str]):
config = UserConfig()
# Create process pool with worker function
with Pool(processes=math.floor(cpu_count() / 2)) as pool:
worker = partial(self._process_file, config=config, key=key)
worker = partial(self._process_file, config=config)
# Process files and track progress
results = []
@@ -150,11 +142,6 @@ class IndexTracks:
total=len(files),
desc="Reading files",
):
if PopulateKey.key != key:
log.warning("'Populate.tag_untagged': Populate key changed")
pool.terminate()
return
if result is not None:
results.append(result)