mirror of
https://github.com/Dvorinka/swingmusic-extended.git
synced 2026-06-03 20:13:02 +00:00
rewrite color extraction with multiprocessing 💀
This commit is contained in:
+208
-56
@@ -2,20 +2,24 @@
|
||||
Contains everything that deals with image color extraction.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
import colorgram
|
||||
from pathlib import Path
|
||||
from typing import Callable, Generator
|
||||
from swingmusic.utils.progressbar import tqdm
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
|
||||
from swingmusic import settings
|
||||
|
||||
from swingmusic.db.userdata import LibDataTable
|
||||
from swingmusic.logger import log
|
||||
from swingmusic.store.albums import AlbumStore
|
||||
from swingmusic.db.userdata import LibDataTable
|
||||
from swingmusic.store.artists import ArtistStore
|
||||
from swingmusic.utils.progressbar import tqdm
|
||||
|
||||
|
||||
def get_image_colors(image: str, count=1) -> list[str]:
|
||||
"""Extracts n number of the most dominant colors from an image."""
|
||||
"""
|
||||
Extracts n number of the most dominant colors from an image.
|
||||
"""
|
||||
try:
|
||||
colors = sorted(colorgram.extract(image, count), key=lambda c: c.hsl.h)
|
||||
except OSError:
|
||||
@@ -44,76 +48,224 @@ def process_color(item_hash: str, is_album=True):
|
||||
return get_image_colors(str(path))
|
||||
|
||||
|
||||
class ProcessAlbumColors:
|
||||
def extract_color_worker(item_data: dict) -> dict:
|
||||
"""
|
||||
Extracts the most dominant color from the album art and saves it to the database.
|
||||
Generic worker function for extracting colors in parallel.
|
||||
Returns data to main process for batch database operations.
|
||||
Works for both albums and artists based on item_data configuration.
|
||||
"""
|
||||
hash_field: str = item_data["hash_field"]
|
||||
path_func: Callable = item_data["path_func"]
|
||||
item_hash: str = item_data[hash_field]
|
||||
|
||||
path = Path(path_func()) / (item_hash + ".webp")
|
||||
|
||||
if not path.exists():
|
||||
return {hash_field: item_hash, "color": None, "error": "Image not found"}
|
||||
|
||||
colors = get_image_colors(str(path))
|
||||
|
||||
if not colors:
|
||||
return {
|
||||
hash_field: item_hash,
|
||||
"color": None,
|
||||
"error": "Color extraction failed",
|
||||
}
|
||||
|
||||
return {hash_field: item_hash, "color": colors[0], "error": None}
|
||||
|
||||
|
||||
class ColorProcessor:
|
||||
"""
|
||||
Generic color processor for extracting dominant colors from images.
|
||||
Uses multiprocessing for parallel color extraction and batch database operations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
albums = [a for a in AlbumStore.get_flat_list() if not a.color]
|
||||
def __init__(
|
||||
self,
|
||||
item_type: str,
|
||||
store: AlbumStore | ArtistStore,
|
||||
path_func: Callable,
|
||||
hash_field: str,
|
||||
):
|
||||
"""
|
||||
Initialize the color processor.
|
||||
|
||||
for album in tqdm(albums, desc="Processing missing album colors"):
|
||||
albumhash = album.albumhash
|
||||
Args:
|
||||
item_type: Type of item ("album" or "artist")
|
||||
store: Store object (AlbumStore or ArtistStore)
|
||||
path_func: Function to get the image path
|
||||
hash_field: Name of the hash field ("albumhash" or "artisthash")
|
||||
"""
|
||||
self.item_type = item_type
|
||||
self.store = store
|
||||
self.path_func = path_func
|
||||
self.hash_field = hash_field
|
||||
|
||||
albumrecord = LibDataTable.find_one(albumhash, type="album")
|
||||
if albumrecord is not None and albumrecord.color is not None:
|
||||
# Read existing colors from database to filter out already processed items
|
||||
existing_colors = set()
|
||||
for color_data in LibDataTable.get_all_colors(item_type):
|
||||
if color_data["color"]:
|
||||
existing_colors.add(color_data["itemhash"])
|
||||
|
||||
# Filter items that need color processing
|
||||
items_needing_colors = self._get_items_needing_colors(existing_colors)
|
||||
|
||||
if not items_needing_colors:
|
||||
return
|
||||
|
||||
self._process_colors_parallel(items_needing_colors)
|
||||
|
||||
def _get_items_needing_colors(
|
||||
self, existing_colors: set
|
||||
) -> Generator[dict, None, None]:
|
||||
"""
|
||||
Generator that yields items needing color processing.
|
||||
"""
|
||||
for item in self.store.get_flat_list():
|
||||
# Skip if item already has color in memory store
|
||||
if item.color:
|
||||
continue
|
||||
|
||||
colors = process_color(albumhash)
|
||||
|
||||
if colors is None:
|
||||
# Skip if item already has color in database
|
||||
item_hash = getattr(item, self.hash_field)
|
||||
if item_hash in existing_colors:
|
||||
continue
|
||||
|
||||
album = AlbumStore.albummap.get(albumhash)
|
||||
yield {
|
||||
self.hash_field: item_hash,
|
||||
"item_type": self.item_type,
|
||||
"path_func": self.path_func,
|
||||
"hash_field": self.hash_field,
|
||||
}
|
||||
|
||||
if album:
|
||||
album.set_color(colors[0])
|
||||
def _process_colors_parallel(self, items: Generator[dict, None, None]) -> None:
|
||||
"""
|
||||
Process colors using multiprocessing and batch database operations.
|
||||
"""
|
||||
items_list = list(items)
|
||||
|
||||
# INFO: Write to the database.
|
||||
if albumrecord is None:
|
||||
LibDataTable.insert_one(
|
||||
if not items_list:
|
||||
return
|
||||
|
||||
cpus = max(1, (os.cpu_count() or 1) // 2)
|
||||
batch_size = 20 # Process results in batches
|
||||
|
||||
with ProcessPoolExecutor(max_workers=cpus) as executor:
|
||||
# Submit all jobs
|
||||
future_to_item = {
|
||||
executor.submit(extract_color_worker, item): item for item in items_list
|
||||
}
|
||||
|
||||
batch = []
|
||||
processed_count = 0
|
||||
|
||||
# Process results as they complete
|
||||
progress_bar = tqdm(
|
||||
as_completed(future_to_item),
|
||||
total=len(items_list),
|
||||
desc=f"Processing {self.item_type} colors",
|
||||
)
|
||||
|
||||
for future in progress_bar:
|
||||
try:
|
||||
result = future.result()
|
||||
|
||||
if result["color"] is not None:
|
||||
batch.append(result)
|
||||
|
||||
# Process batch when it reaches batch_size or we're done
|
||||
if len(batch) >= batch_size or processed_count + 1 >= len(
|
||||
items_list
|
||||
):
|
||||
if batch:
|
||||
self._process_batch(batch)
|
||||
batch = []
|
||||
|
||||
processed_count += 1
|
||||
|
||||
except Exception as e:
|
||||
item_data = future_to_item[future]
|
||||
item_hash = item_data[self.hash_field]
|
||||
log.error(f"Error processing {self.item_type} {item_hash}: {e}")
|
||||
|
||||
def _process_batch(self, batch: list[dict]) -> None:
|
||||
"""
|
||||
Process a batch of color results - update database and memory stores.
|
||||
"""
|
||||
if not batch:
|
||||
return
|
||||
|
||||
# Prepare database records
|
||||
db_inserts = []
|
||||
db_updates = []
|
||||
|
||||
for result in batch:
|
||||
item_hash = result[self.hash_field]
|
||||
color = result["color"]
|
||||
|
||||
# Check if record exists in database
|
||||
existing_record = LibDataTable.find_one(item_hash, type=self.item_type)
|
||||
|
||||
if existing_record is None:
|
||||
db_inserts.append(
|
||||
{
|
||||
"itemhash": "album" + albumhash,
|
||||
"color": colors[0],
|
||||
"itemtype": "album",
|
||||
"itemhash": self.item_type + item_hash,
|
||||
"color": color,
|
||||
"itemtype": self.item_type,
|
||||
}
|
||||
)
|
||||
else:
|
||||
LibDataTable.update_one(albumhash, {"color": colors[0]})
|
||||
db_updates.append(
|
||||
{"itemhash": self.item_type + item_hash, "color": color}
|
||||
)
|
||||
|
||||
# Batch database operations
|
||||
if db_inserts:
|
||||
LibDataTable.insert_many(db_inserts)
|
||||
|
||||
if db_updates:
|
||||
for update_data in db_updates:
|
||||
clean_hash = update_data["itemhash"].replace(self.item_type, "")
|
||||
LibDataTable.update_one(clean_hash, {"color": update_data["color"]})
|
||||
|
||||
# Update in-memory store
|
||||
store_map = getattr(self.store, f"{self.item_type}map")
|
||||
|
||||
for result in batch:
|
||||
item_hash = result[self.hash_field]
|
||||
color = result["color"]
|
||||
|
||||
item = store_map.get(item_hash)
|
||||
if item:
|
||||
item.set_color(color)
|
||||
|
||||
|
||||
class ProcessAlbumColors:
|
||||
"""
|
||||
Extracts the most dominant color from the album art and saves it to the database.
|
||||
Uses multiprocessing for parallel color extraction and batch database operations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
ColorProcessor(
|
||||
item_type="album",
|
||||
store=AlbumStore,
|
||||
path_func=settings.Paths.get_sm_thumb_path,
|
||||
hash_field="albumhash",
|
||||
)
|
||||
|
||||
|
||||
class ProcessArtistColors:
|
||||
"""
|
||||
Extracts the most dominant color from the artist art and saves it to the database.
|
||||
Uses multiprocessing for parallel color extraction and batch database operations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
all_artists = [a for a in ArtistStore.get_flat_list() if not a.color]
|
||||
|
||||
for artist in tqdm(all_artists, desc="Processing missing artist colors"):
|
||||
artisthash = artist.artisthash
|
||||
|
||||
record = LibDataTable.find_one(artisthash, "artist")
|
||||
if (record is not None) and (record.color is not None):
|
||||
continue
|
||||
|
||||
colors = process_color(artisthash, is_album=False)
|
||||
|
||||
if colors is None:
|
||||
continue
|
||||
|
||||
artist = ArtistStore.artistmap.get(artisthash)
|
||||
|
||||
if artist:
|
||||
artist.set_color(colors[0])
|
||||
|
||||
if record is None:
|
||||
LibDataTable.insert_one(
|
||||
{
|
||||
"itemhash": "artist" + artisthash,
|
||||
"color": colors[0],
|
||||
"itemtype": "artist",
|
||||
}
|
||||
)
|
||||
else:
|
||||
LibDataTable.update_one("artist" + artisthash, {"color": colors[0]})
|
||||
ColorProcessor(
|
||||
item_type="artist",
|
||||
store=ArtistStore,
|
||||
path_func=settings.Paths.get_sm_artist_img_path,
|
||||
hash_field="artisthash",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user