mirror of
https://github.com/Dvorinka/swingmusic-extended.git
synced 2026-06-05 04:53:01 +00:00
fix multiprocessing
+ remove global locks
This commit is contained in:
+6
-16
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import urllib
|
||||
@@ -24,8 +25,6 @@ from app.utils.hashing import create_hash
|
||||
from app.utils.progressbar import tqdm
|
||||
|
||||
|
||||
CHECK_ARTIST_IMAGES_KEY = ""
|
||||
|
||||
LARGE_ENOUGH_NUMBER = 100
|
||||
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
|
||||
# https://stackoverflow.com/a/61466412
|
||||
@@ -148,10 +147,7 @@ class DownloadImage:
|
||||
|
||||
|
||||
class CheckArtistImages:
|
||||
def __init__(self, instance_key: str):
|
||||
global CHECK_ARTIST_IMAGES_KEY
|
||||
CHECK_ARTIST_IMAGES_KEY = instance_key
|
||||
|
||||
def __init__(self):
|
||||
# read all files in the artist image folder
|
||||
path = settings.Paths.get_sm_artist_img_path()
|
||||
processed = set(i.replace(".webp", "") for i in os.listdir(path))
|
||||
@@ -160,15 +156,14 @@ class CheckArtistImages:
|
||||
a for a in ArtistStore.get_flat_list() if a.artisthash not in processed
|
||||
]
|
||||
|
||||
key_artist_map = ((instance_key, artist) for artist in unprocessed)
|
||||
|
||||
# Use number of CPU cores minus 1 to leave one core free for system processes
|
||||
num_workers = max(1, multiprocessing.cpu_count() - 1)
|
||||
num_workers = max(1, math.floor(multiprocessing.cpu_count() / 2))
|
||||
print("num_workers", num_workers)
|
||||
|
||||
with ProcessPoolExecutor(max_workers=num_workers) as executor:
|
||||
res = list(
|
||||
tqdm(
|
||||
executor.map(self.download_image, key_artist_map),
|
||||
executor.map(self.download_image, unprocessed),
|
||||
total=len(unprocessed),
|
||||
desc="Downloading missing artist images",
|
||||
)
|
||||
@@ -177,17 +172,12 @@ class CheckArtistImages:
|
||||
list(res)
|
||||
|
||||
@staticmethod
|
||||
def download_image(_map: tuple[str, Artist]):
|
||||
def download_image(artist: Artist):
|
||||
"""
|
||||
Checks if an artist image exists and downloads it if not.
|
||||
|
||||
:param artist: The artist name
|
||||
"""
|
||||
instance_key, artist = _map
|
||||
|
||||
if CHECK_ARTIST_IMAGES_KEY != instance_key:
|
||||
return
|
||||
|
||||
img_path = (
|
||||
Path(settings.Paths.get_sm_artist_img_path()) / f"{artist.artisthash}.webp"
|
||||
)
|
||||
|
||||
+2
-22
@@ -10,15 +10,10 @@ from app import settings
|
||||
|
||||
from app.db.userdata import LibDataTable
|
||||
from app.logger import log
|
||||
from app.lib.errors import PopulateCancelledError
|
||||
from app.store.albums import AlbumStore
|
||||
from app.store.artists import ArtistStore
|
||||
from app.utils.progressbar import tqdm
|
||||
|
||||
PROCESS_ALBUM_COLORS_KEY = ""
|
||||
PROCESS_ARTIST_COLORS_KEY = ""
|
||||
|
||||
|
||||
def get_image_colors(image: str, count=1) -> list[str]:
|
||||
"""Extracts n number of the most dominant colors from an image."""
|
||||
try:
|
||||
@@ -54,18 +49,11 @@ class ProcessAlbumColors:
|
||||
Extracts the most dominant color from the album art and saves it to the database.
|
||||
"""
|
||||
|
||||
def __init__(self, instance_key: str) -> None:
|
||||
global PROCESS_ALBUM_COLORS_KEY
|
||||
PROCESS_ALBUM_COLORS_KEY = instance_key
|
||||
|
||||
def __init__(self) -> None:
|
||||
albums = [a for a in AlbumStore.get_flat_list() if not a.color]
|
||||
|
||||
for album in tqdm(albums, desc="Processing missing album colors"):
|
||||
albumhash = album.albumhash
|
||||
if PROCESS_ALBUM_COLORS_KEY != instance_key:
|
||||
raise PopulateCancelledError(
|
||||
"A newer 'ProcessAlbumColors' instance is running. Stopping this one."
|
||||
)
|
||||
|
||||
albumrecord = LibDataTable.find_one(albumhash, type="album")
|
||||
if albumrecord is not None and albumrecord.color is not None:
|
||||
@@ -99,21 +87,13 @@ class ProcessArtistColors:
|
||||
Extracts the most dominant color from the artist art and saves it to the database.
|
||||
"""
|
||||
|
||||
def __init__(self, instance_key: str) -> None:
|
||||
def __init__(self) -> None:
|
||||
all_artists = [a for a in ArtistStore.get_flat_list() if not a.color]
|
||||
|
||||
global PROCESS_ARTIST_COLORS_KEY
|
||||
PROCESS_ARTIST_COLORS_KEY = instance_key
|
||||
|
||||
for artist in tqdm(all_artists, desc="Processing missing artist colors"):
|
||||
artisthash = artist.artisthash
|
||||
if PROCESS_ARTIST_COLORS_KEY != instance_key:
|
||||
raise PopulateCancelledError(
|
||||
"A newer 'ProcessArtistColors' instance is running. Stopping this one."
|
||||
)
|
||||
|
||||
record = LibDataTable.find_one(artisthash, "artist")
|
||||
print(record)
|
||||
if (record is not None) and (record.color is not None):
|
||||
continue
|
||||
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
class PopulateCancelledError(Exception):
|
||||
"""
|
||||
Raised when the instance key of a looping function called
|
||||
inside Populate is changed.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
+1
-1
@@ -18,7 +18,7 @@ from app.utils.threading import background
|
||||
|
||||
class IndexEverything:
|
||||
def __init__(self) -> None:
|
||||
IndexTracks(instance_key=time())
|
||||
IndexTracks()
|
||||
|
||||
key = str(time())
|
||||
TrackStore.load_all_tracks(key)
|
||||
|
||||
+47
-71
@@ -1,6 +1,8 @@
|
||||
from dataclasses import asdict
|
||||
import math
|
||||
import os
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
import platform
|
||||
|
||||
from requests import ConnectionError as RequestConnectionError
|
||||
from requests import ReadTimeout
|
||||
@@ -8,7 +10,6 @@ from requests import ReadTimeout
|
||||
from app import settings
|
||||
from app.lib.artistlib import CheckArtistImages
|
||||
from app.lib.colorlib import ProcessAlbumColors, ProcessArtistColors
|
||||
from app.lib.errors import PopulateCancelledError
|
||||
from app.lib.taglib import extract_thumb
|
||||
from app.logger import log
|
||||
from app.models import Album, Artist
|
||||
@@ -21,33 +22,22 @@ from app.utils.progressbar import tqdm
|
||||
|
||||
from app.db.userdata import SimilarArtistTable
|
||||
|
||||
|
||||
POPULATE_KEY = ""
|
||||
|
||||
|
||||
class CordinateMedia:
|
||||
"""
|
||||
Cordinates the extracting of thumbnails
|
||||
"""
|
||||
|
||||
def __init__(self, instance_key: str):
|
||||
global POPULATE_KEY
|
||||
POPULATE_KEY = instance_key
|
||||
|
||||
try:
|
||||
ProcessTrackThumbnails(instance_key)
|
||||
ProcessAlbumColors(instance_key)
|
||||
ProcessArtistColors(instance_key)
|
||||
except PopulateCancelledError as e:
|
||||
log.warn(e)
|
||||
return
|
||||
ProcessTrackThumbnails()
|
||||
ProcessAlbumColors()
|
||||
ProcessArtistColors()
|
||||
|
||||
tried_to_download_new_images = False
|
||||
|
||||
if has_connection():
|
||||
tried_to_download_new_images = True
|
||||
try:
|
||||
CheckArtistImages(instance_key)
|
||||
CheckArtistImages()
|
||||
except (RequestConnectionError, ReadTimeout) as e:
|
||||
log.error(
|
||||
"Internet connection lost. Downloading artist images suspended."
|
||||
@@ -58,18 +48,14 @@ class CordinateMedia:
|
||||
|
||||
# Re-process the new artist images.
|
||||
if tried_to_download_new_images:
|
||||
ProcessArtistColors(instance_key=instance_key)
|
||||
ProcessArtistColors()
|
||||
|
||||
if has_connection():
|
||||
try:
|
||||
print("Attempting to download similar artists...")
|
||||
FetchSimilarArtistsLastFM(instance_key)
|
||||
except PopulateCancelledError as e:
|
||||
log.warn(e)
|
||||
return
|
||||
print("Attempting to download similar artists...")
|
||||
FetchSimilarArtistsLastFM()
|
||||
|
||||
|
||||
def get_image(_map: tuple[str, Album]):
|
||||
def get_image(album: Album):
|
||||
"""
|
||||
The function retrieves an image from an album by iterating through its tracks and extracting the thumbnail from the first track that has one.
|
||||
|
||||
@@ -77,17 +63,13 @@ def get_image(_map: tuple[str, Album]):
|
||||
:type album: Album
|
||||
:return: None
|
||||
"""
|
||||
|
||||
instance_key, album = _map
|
||||
|
||||
if POPULATE_KEY != instance_key:
|
||||
raise PopulateCancelledError("'ProcessTrackThumbnails': Populate key changed")
|
||||
|
||||
matching_tracks = AlbumStore.get_album_tracks(album.albumhash)
|
||||
|
||||
|
||||
for track in matching_tracks:
|
||||
if extract_thumb(track.filepath, track.albumhash + ".webp"):
|
||||
break
|
||||
extracted = extract_thumb(track.filepath, track.albumhash + ".webp")
|
||||
|
||||
if extracted:
|
||||
return
|
||||
|
||||
|
||||
def get_cpu_count():
|
||||
@@ -103,7 +85,31 @@ class ProcessTrackThumbnails:
|
||||
Extracts the album art from all albums in album store.
|
||||
"""
|
||||
|
||||
def __init__(self, instance_key: str) -> None:
|
||||
def extract(self, albums: list[Album]):
|
||||
"""
|
||||
Extracts the album art with platform specific logic.
|
||||
"""
|
||||
|
||||
if platform.system() == "Linux":
|
||||
# INFO: Processess are forked with access to global stores
|
||||
# It's "safe" to use a process pool
|
||||
cpus = math.floor(get_cpu_count() / 2)
|
||||
with ProcessPoolExecutor(max_workers=cpus) as executor:
|
||||
results = list(
|
||||
tqdm(
|
||||
executor.map(get_image, albums),
|
||||
total=len(albums),
|
||||
desc="Extracting track images",
|
||||
)
|
||||
)
|
||||
|
||||
list(results)
|
||||
else:
|
||||
# INFO: Use a for loop for windows (and others I guess)
|
||||
for album in tqdm(albums, desc="Extracting track images"):
|
||||
get_image(album)
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Filters out albums that already have thumbnails and
|
||||
extracts the thumbnail for the other albums.
|
||||
@@ -112,43 +118,20 @@ class ProcessTrackThumbnails:
|
||||
|
||||
# read all the files in the thumbnail directory
|
||||
processed = set(i.replace(".webp", "") for i in os.listdir(path))
|
||||
|
||||
# filter out albums that already have thumbnails
|
||||
albums = filter(
|
||||
lambda album: album.albumhash not in processed, AlbumStore.get_flat_list()
|
||||
lambda album: album.albumhash not in processed,
|
||||
AlbumStore.get_flat_list(),
|
||||
)
|
||||
|
||||
albums = list(albums)
|
||||
|
||||
print("length of albums", len(albums))
|
||||
|
||||
# process the rest
|
||||
key_album_map = ((instance_key, album) for album in albums)
|
||||
|
||||
with ProcessPoolExecutor(max_workers=get_cpu_count()) as executor:
|
||||
results = list(
|
||||
tqdm(
|
||||
executor.map(get_image, key_album_map),
|
||||
total=len(albums),
|
||||
desc="Extracting track images",
|
||||
)
|
||||
)
|
||||
|
||||
list(results)
|
||||
self.extract(albums)
|
||||
|
||||
|
||||
def save_similar_artists(_map: tuple[str, Artist]):
|
||||
def save_similar_artists(artist: Artist):
|
||||
"""
|
||||
Downloads and saves similar artists to the database.
|
||||
"""
|
||||
|
||||
instance_key, artist = _map
|
||||
|
||||
if POPULATE_KEY != instance_key:
|
||||
print("Warning: Populate key changed")
|
||||
raise PopulateCancelledError(
|
||||
"'FetchSimilarArtistsLastFM': Populate key changed"
|
||||
)
|
||||
|
||||
if SimilarArtistTable.exists(artist.artisthash):
|
||||
return
|
||||
|
||||
@@ -167,7 +150,7 @@ class FetchSimilarArtistsLastFM:
|
||||
Fetches similar artists from LastFM using a thread pool.
|
||||
"""
|
||||
|
||||
def __init__(self, instance_key: str) -> None:
|
||||
def __init__(self) -> None:
|
||||
# read all artists from db
|
||||
processed = set(a.artisthash for a in SimilarArtistTable.get_all())
|
||||
|
||||
@@ -177,25 +160,18 @@ class FetchSimilarArtistsLastFM:
|
||||
)
|
||||
artists = list(artists)
|
||||
|
||||
# process the rest
|
||||
key_artist_map = ((instance_key, artist) for artist in artists)
|
||||
|
||||
with ProcessPoolExecutor(max_workers=get_cpu_count()) as executor:
|
||||
try:
|
||||
print("Processing similar artists")
|
||||
results = list(
|
||||
tqdm(
|
||||
executor.map(save_similar_artists, key_artist_map),
|
||||
executor.map(save_similar_artists, artists),
|
||||
total=len(artists),
|
||||
desc="Fetching similar artists",
|
||||
)
|
||||
)
|
||||
|
||||
list(results)
|
||||
|
||||
except PopulateCancelledError as e:
|
||||
raise e
|
||||
|
||||
# any exception that can be raised by the pool
|
||||
except Exception as e:
|
||||
log.warn(e)
|
||||
|
||||
+5
-18
@@ -22,19 +22,14 @@ from app.logger import log
|
||||
from app.utils.remove_duplicates import remove_duplicates
|
||||
|
||||
|
||||
class PopulateKey:
|
||||
key: float = 0
|
||||
|
||||
|
||||
class IndexTracks:
|
||||
def __init__(self, instance_key: float) -> None:
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Indexes all tracks in the database.
|
||||
|
||||
An instance key is used to prevent multiple instances of the
|
||||
same class from running at the same time.
|
||||
"""
|
||||
PopulateKey.key = instance_key
|
||||
dirs_to_scan = UserConfig().rootDirs
|
||||
|
||||
if len(dirs_to_scan) == 0:
|
||||
@@ -60,7 +55,7 @@ class IndexTracks:
|
||||
unmodified, modified_tracks = self.filter_modded()
|
||||
untagged = files - unmodified
|
||||
|
||||
self.tag_untagged(untagged, instance_key)
|
||||
self.tag_untagged(untagged)
|
||||
self.extract_thumb_with_overwrite(modified_tracks)
|
||||
|
||||
@staticmethod
|
||||
@@ -125,23 +120,20 @@ class IndexTracks:
|
||||
tracks = TrackTable.get_all()
|
||||
|
||||
@staticmethod
|
||||
def _process_file(file: str, config: UserConfig, key: float) -> dict | None:
|
||||
def _process_file(file: str, config: UserConfig) -> dict | None:
|
||||
"""Worker function to process individual files"""
|
||||
if PopulateKey.key != key:
|
||||
return None
|
||||
|
||||
try:
|
||||
return get_tags(file, config=config)
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to process file {file}: {e}")
|
||||
return None
|
||||
|
||||
def tag_untagged(self, files: set[str], key: float):
|
||||
def tag_untagged(self, files: set[str]):
|
||||
config = UserConfig()
|
||||
|
||||
# Create process pool with worker function
|
||||
with Pool(processes=math.floor(cpu_count() / 2)) as pool:
|
||||
worker = partial(self._process_file, config=config, key=key)
|
||||
worker = partial(self._process_file, config=config)
|
||||
|
||||
# Process files and track progress
|
||||
results = []
|
||||
@@ -150,11 +142,6 @@ class IndexTracks:
|
||||
total=len(files),
|
||||
desc="Reading files",
|
||||
):
|
||||
if PopulateKey.key != key:
|
||||
log.warning("'Populate.tag_untagged': Populate key changed")
|
||||
pool.terminate()
|
||||
return
|
||||
|
||||
if result is not None:
|
||||
results.append(result)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user