diff --git a/swingmusic/api/backup_and_restore.py b/swingmusic/api/backup_and_restore.py index 3fe3fea8..bf701fea 100644 --- a/swingmusic/api/backup_and_restore.py +++ b/swingmusic/api/backup_and_restore.py @@ -10,7 +10,7 @@ from flask_openapi3 import APIBlueprint import sqlalchemy.exc from swingmusic.api.auth import admin_required -from swingmusic.db.userdata import FavoritesTable, PlaylistTable, ScrobbleTable +from swingmusic.db.userdata import FavoritesTable, PlaylistTable, ScrobbleTable, CollectionTable from swingmusic.lib.index import index_everything from swingmusic.settings import Paths from datetime import datetime @@ -29,7 +29,7 @@ api = APIBlueprint( @admin_required() def backup(): """ - Create a backup file of your favorites, playlists and scrobble data. + Create a backup file of your favorites, playlists, scrobble data, and collections. """ backup_name = f"backup.{int(time())}" backup_dir = Path("~").expanduser() / "swingmusic.backup" / backup_name @@ -78,10 +78,23 @@ def backup(): shutil.copy(img_path, img_folder / playlist["image"]) # !SECTION + + # SECTION: Collections + collections_list = list(CollectionTable.get_all()) + collections_dicts = [] + + for collection in collections_list: + # Remove auto-generated id field + collection_copy = collection.copy() + if "id" in collection_copy: + del collection_copy["id"] + collections_dicts.append(collection_copy) + # !SECTION data = { "favorites": favorites, "scrobbles": scrobbles, "playlists": playlist_dicts, + "collections": collections_dicts, } with open(backup_file, "w") as f: @@ -93,11 +106,12 @@ def backup(): "scrobbles": len(scrobbles), "favorites": len(favorites), "playlists": len(playlist_dicts), + "collections": len(collections_dicts), }, 200 class RestoreBackup: - # TODO: BACKUP AND RESTORE COLLECTIONS & MIXES! + # TODO: BACKUP AND RESTORE MIXES! # TODO: IMPROVE UX WHEN WAITING FOR RESTORE TO COMPLETE! def __init__(self, backup_dir: Path): @@ -109,6 +123,7 @@ class RestoreBackup: self.restore_favorites(self.data["favorites"]) self.restore_playlists(self.data["playlists"]) self.restore_scrobbles(self.data["scrobbles"]) + self.restore_collections(self.data.get("collections", [])) def restore(self): pass @@ -161,6 +176,25 @@ class RestoreBackup: print("Integrity error, skipping scrobble:") print(scrobble) + def restore_collections(self, collections: list[dict]): + existing_collections = list(CollectionTable.get_all()) + existing_names = set(collection["name"] for collection in existing_collections) + new_collections = [ + collection for collection in collections if collection["name"] not in existing_names + ] + + for collection in new_collections: + try: + # Ensure userid is set for the collection + if collection.get("userid") is None: + from swingmusic.utils.auth import get_current_userid + collection["userid"] = get_current_userid() + + CollectionTable.insert_one(collection) + except sqlalchemy.exc.IntegrityError: + print("Integrity error, skipping collection:") + print(collection) + class RestoreBackupBody(BaseModel): @@ -175,7 +209,7 @@ class RestoreBackupBody(BaseModel): @admin_required() def restore(body: RestoreBackupBody): """ - Restore your favorites, playlists and scrobble data from a specified backup or all backups. + Restore your favorites, playlists, scrobble data, and collections from a specified backup or all backups. """ backup_base_dir = Path("~").expanduser() / "swingmusic.backup" backups = [] @@ -247,10 +281,12 @@ def list_backups(): backup_info["scrobbles"] = len(data.get("scrobbles", [])) backup_info["favorites"] = len(data.get("favorites", [])) backup_info["playlists"] = len(data.get("playlists", [])) + backup_info["collections"] = len(data.get("collections", [])) else: backup_info["scrobbles"] = 0 backup_info["favorites"] = 0 backup_info["playlists"] = 0 + backup_info["collections"] = 0 backups.append(backup_info)