diff --git a/.github/changelog.md b/.github/changelog.md index 1d00e7f1..4c1511df 100644 --- a/.github/changelog.md +++ b/.github/changelog.md @@ -1,7 +1,6 @@ # What's New? -- Hovering on recent favorite item will show how long ago it was ♥ed -- Recently added playlist returns a max of 100 tracks, but without a cutoff period + +- Auth -# Development -- API documentation on /openapi \ No newline at end of file +## Development diff --git a/.gitignore b/.gitignore index 7cbee9ee..b2412028 100644 --- a/.gitignore +++ b/.gitignore @@ -27,7 +27,7 @@ client logs.txt *.spec -TODO.md +# TODO.md testdata.py test.py nohup.out diff --git a/TODO.md b/TODO.md new file mode 100644 index 00000000..c78efee4 --- /dev/null +++ b/TODO.md @@ -0,0 +1,2 @@ +- Fix migrations! + - Use total length instead of release version length \ No newline at end of file diff --git a/app/api/__init__.py b/app/api/__init__.py index a74624bc..4319499d 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -8,9 +8,12 @@ from flask_compress import Compress from flask_openapi3 import Info from flask_openapi3 import OpenAPI +from flask_jwt_extended import JWTManager +from app.config import UserConfig -from app.settings import Keys +from app.settings import Info as AppInfo from .plugins import lyrics as lyrics_plugin +from app.db.sqlite.auth import SQLiteAuthMethods as authdb from app.api import ( album, artist, @@ -20,13 +23,14 @@ from app.api import ( imgserver, playlist, search, - send_file, settings, lyrics, plugins, logger, home, getall, + auth, + stream, ) # TODO: Move this description to a separate file @@ -54,23 +58,45 @@ def create_api(): """ api_info = Info( title=f"Swing Music", - version=f"v{Keys.SWINGMUSIC_APP_VERSION}", + version=f"v{AppInfo.SWINGMUSIC_APP_VERSION}", description=open_api_description, ) app = OpenAPI(__name__, info=api_info, doc_prefix="/docs") + # JWT CONFIGS + app.config["JWT_SECRET_KEY"] = UserConfig().userId + app.config["JWT_TOKEN_LOCATION"] = ["cookies"] + app.config["JWT_COOKIE_CSRF_PROTECT"] = False + app.config["JWT_ACCESS_TOKEN_EXPIRES"] = datetime.timedelta(days=30) - CORS(app, origins="*") + # CORS + CORS(app, origins="*", supports_credentials=True) + + # RESPONSE COMPRESSION Compress(app) - app.config["COMPRESS_MIMETYPES"] = [ "application/json", ] + # JWT + jwt = JWTManager(app) + + # @jwt.user_identity_loader + # def user_identity_lookup(user): + # return user + + @jwt.user_lookup_loader + def user_lookup_callback(_jwt_header, jwt_data): + identity = jwt_data["sub"] + userid = identity["id"] + user = authdb.get_user_by_id(userid) + return user.todict() + + # Register all the API blueprints with app.app_context(): app.register_api(album.api) app.register_api(artist.api) - app.register_api(send_file.api) + app.register_api(stream.api) app.register_api(search.api) app.register_api(folder.api) app.register_api(playlist.api) @@ -89,8 +115,9 @@ def create_api(): # Home app.register_api(home.api) - - # Flask Restful app.register_api(getall.api) + # Auth + app.register_api(auth.api) + return app diff --git a/app/api/auth.py b/app/api/auth.py new file mode 100644 index 00000000..b8af863a --- /dev/null +++ b/app/api/auth.py @@ -0,0 +1,290 @@ +import json +from dataclasses import asdict +from functools import wraps +import sqlite3 +from flask import jsonify +from flask_jwt_extended import ( + create_access_token, + current_user, + jwt_required, + set_access_cookies, +) +from pydantic import BaseModel, Field +from flask_openapi3 import Tag +from flask_openapi3 import APIBlueprint + +from app.db.sqlite.auth import SQLiteAuthMethods as authdb +from app.utils.auth import check_password, hash_password +from app.config import UserConfig + +bp_tag = Tag(name="Auth", description="Authentication stuff") +api = APIBlueprint("auth", __name__, url_prefix="/auth", abp_tags=[bp_tag]) + + +def admin_required(): + """ + Decorator to require admin role + """ + + def wrapper(fn): + @wraps(fn) + def decorator(*args, **kwargs): + if "admin" not in current_user["roles"]: + return {"msg": "Only admins can do that!"}, 403 + return fn(*args, **kwargs) + + return decorator + + return wrapper + + +class LoginBody(BaseModel): + username: str = Field(description="The username", example="user0") + password: str = Field(description="The password", example="password0") + + +@api.post("/login") +def login(body: LoginBody): + """ + Authenticate using username and password + """ + res = jsonify({"msg": f"Logged in as {body.username}"}) + + user = authdb.get_user_by_username(body.username) + + if user is None: + return {"msg": "User not found"}, 404 + + password_ok = check_password(body.password, user.password) + + if not password_ok: + return {"msg": "Hehe! invalid password"}, 401 + + access_token = create_access_token(identity=user.todict()) + set_access_cookies(res, access_token) + return res + + +class UpdateProfileBody(BaseModel): + id: int = Field(0, description="The user id") + email: str = Field("", description="The email") + username: str = Field("", description="The username", example="user0") + password: str = Field("", description="The password", example="password0") + roles: list[str] = Field(None, description="The roles") + + +@api.put("/profile/update") +def update_profile(body: UpdateProfileBody): + user = { + "id": body.id, + "email": body.email, + "username": body.username, + "password": body.password, + "roles": body.roles, + } + + # prevent updating guest + if current_user["username"] == "guest" or user["username"] == "guest": + return {"msg": "Cannot update guest user"}, 400 + + # if not id, update self + if not user["id"]: + user["id"] = current_user["id"] + + if body.roles is not None: + # only admins can update roles + if "admin" not in current_user["roles"]: + return {"msg": "Only admins can update roles"}, 403 + + all_users = authdb.get_all_users() + if "admin" not in body.roles: + # check if we're removing the last admin + admins = [user for user in all_users if "admin" in user.roles] + + if len(admins) == 1 and admins[0].id == user["id"]: + return {"msg": "Cannot remove the only admin"}, 400 + + # guest roles cannot be updated + _user = [u for u in all_users if u.id == user["id"]][0] + if "guest" in _user.roles: + return {"msg": "Cannot update guest user"}, 400 + + # finally, convert roles to json string + user["roles"] = json.dumps(body.roles) + + if user["password"]: + user["password"] = hash_password(user["password"]) + + # remove empty values + clean_user = {k: v for k, v in user.items() if v} + + try: + return authdb.update_user(clean_user) + except sqlite3.IntegrityError: + return {"msg": "Username already exists"}, 400 + + +@api.post("/profile/create") +@admin_required() +def create_user(body: UpdateProfileBody): + if not body.username or not body.password: + return {"msg": "Username and password are required"}, 400 + + user = { + "username": body.username, + "password": hash_password(body.password), + "roles": json.dumps([]), + } + + # check if user already exists + if authdb.get_user_by_username(user["username"]): + return {"msg": "Username already exists"}, 400 + + userid = authdb.insert_user(user) + return authdb.get_user_by_id(userid).todict() + + +@api.post("/profile/guest/create") +@admin_required() +def create_guest_user(): + """ + Create a guest user + """ + # check if guest user already exists + guest_user = authdb.get_user_by_username("guest") + + if guest_user: + return { + "msg": "Guest user already exists", + }, 400 + + userid = authdb.insert_guest_user() + + if userid: + return { + "msg": "Guest user created", + } + + return { + "msg": "Failed to create guest user", + }, 500 + + +class DeleteUseBody(BaseModel): + username: str = Field("", description="The username") + + +@api.delete("/profile/delete") +@admin_required() +def delete_user(body: DeleteUseBody): + """ + Delete a user by username + """ + # prevent admin from deleting themselves + if body.username == current_user["username"]: + return {"msg": "Sorry! you cannot delete yourselfu"}, 400 + + # prevent deleting the only admin + users = authdb.get_all_users() + admins = [user for user in users if "admin" in user.roles] + if len(admins) == 1 and admins[0].username == body.username: + return {"msg": "Cannot delete the only admin"}, 400 + + authdb.delete_user_by_username(body.username) + return {"msg": f"User {body.username} deleted"} + + +@api.get("/logout") +def logout(): + """ + Log out + """ + res = jsonify({"msg": "Logged out"}) + res.delete_cookie("access_token_cookie") + return res + + +class GetAllUsersQuery(BaseModel): + simplified: bool = Field( + False, description="Whether to return simplified user data" + ) + + +@api.get("/users") +@jwt_required(optional=True) +def get_all_users(query: GetAllUsersQuery): + """ + Get all users (if you're an admin, you will also receive accounts settings) + """ + config = UserConfig() + # config.enableGuest = True + # config.usersOnLogin = True + settings = { + "enableGuest": False, + "usersOnLogin": config.usersOnLogin, + } + + res = { + "settings": {}, + "users": [], + } + + users = authdb.get_all_users() + + is_admin = current_user and "admin" in current_user["roles"] + settings["enableGuest"] = [ + user for user in users if user.username == "guest" + ].__len__() > 0 + + # if user is admin, also return settings + if is_admin: + res = { + "settings": settings, + } + + # if is normal user, return empty response + elif current_user: + return res + + # if not logged in and showing users on login is disabled, return empty response + elif ( + not current_user + and not settings["usersOnLogin"] + and not settings["enableGuest"] + ): + return res + + # remove guest user + # if not settings["enableGuest"]: + # users = [user for user in users if user.username != "guest"] + + if not settings["usersOnLogin"]: + users = [user for user in users if user.username == "guest"] + + # reverse list to show latest users first + users = list(reversed(users)) + + # bring admins to the front + users = sorted(users, key=lambda x: "admin" in x.roles, reverse=True) + # bring current user to index 0 + if current_user: + users = sorted( + users, + key=lambda x: x.username == current_user["username"], + reverse=True, + ) + + if query.simplified: + res["users"] = [user.todict_simplified() for user in users] + + res["users"] = [user.todict() for user in users] + + return res + + +@api.route("/user") +def get_logged_in_user(): + """ + Get logged in user + """ + return dict(current_user) diff --git a/app/api/getall/__init__.py b/app/api/getall/__init__.py index 8a6c7385..b3199ca9 100644 --- a/app/api/getall/__init__.py +++ b/app/api/getall/__init__.py @@ -22,7 +22,7 @@ bp_tag = Tag(name="Get all", description="List all items") api = APIBlueprint("getall", __name__, url_prefix="/getall", abp_tags=[bp_tag]) -class GetAllItemsBody(GenericLimitSchema): +class GetAllItemsQuery(GenericLimitSchema): start: int = Field( description="The start index of the items to return", example=0, @@ -34,10 +34,10 @@ class GetAllItemsBody(GenericLimitSchema): default="created_date", ) - reverse: int = Field( + reverse: str = Field( description="Reverse the sort", example=1, - default=1, + default="1", ) @@ -50,7 +50,7 @@ class GetAllItemsPath(BaseModel): @api.get("/") -def get_all_items(path: GetAllItemsPath, query: GetAllItemsBody): +def get_all_items(path: GetAllItemsPath, query: GetAllItemsQuery): """ Get all items @@ -67,10 +67,7 @@ def get_all_items(path: GetAllItemsPath, query: GetAllItemsBody): start = query.start limit = query.limit sort = query.sortby - reverse = query.reverse == 1 - - # if sort == "": - # sort = "created_date" + reverse = query.reverse == "1" sort_is_count = sort == "count" sort_is_duration = sort == "duration" diff --git a/app/api/imgserver.py b/app/api/imgserver.py index 680abba1..c5e2db49 100644 --- a/app/api/imgserver.py +++ b/app/api/imgserver.py @@ -13,6 +13,9 @@ api = APIBlueprint("imgserver", __name__, url_prefix="/img", abp_tags=[bp_tag]) def send_fallback_img(filename: str = "default.webp"): + """ + Returns the fallback image from the assets folder. + """ folder = Paths.get_assets_path() img = Path(folder) / filename @@ -22,6 +25,18 @@ def send_fallback_img(filename: str = "default.webp"): return send_from_directory(folder, filename) +def send_file_or_fallback(folder: str, filename: str, fallback: str = "default.webp"): + """ + Returns the file from the folder or the fallback image. + """ + fpath = Path(folder) / filename + + if fpath.exists(): + return send_from_directory(folder, filename) + + return send_fallback_img(fallback) + + class ImagePath(BaseModel): imgpath: str = Field( description="The image filename", @@ -43,62 +58,72 @@ class ImagePath(BaseModel): # return send_fallback_img() -@api.get("/t/") +# TRACK THUMBNAILS +@api.get("/thumbnail/") def send_lg_thumbnail(path: ImagePath): """ Get large thumbnail (500 x 500) """ folder = Paths.get_lg_thumb_path() - fpath = Path(folder) / path.imgpath - - if fpath.exists(): - return send_from_directory(folder, path.imgpath) - - return send_fallback_img() + return send_file_or_fallback(folder, path.imgpath) -@api.get("/t/s/") +@api.get("/thumbnail/xsmall/") +def send_xsm_thumbnail(path: ImagePath): + """ + Get extra small thumbnail (64px) + """ + folder = Paths.get_xsm_thumb_path() + return send_file_or_fallback(folder, path.imgpath) + + +@api.get("/thumbnail/small/") def send_sm_thumbnail(path: ImagePath): """ - Get small thumbnail (64 x 64) + Get small thumbnail (96px) """ folder = Paths.get_sm_thumb_path() - fpath = Path(folder) / path.imgpath - - if fpath.exists(): - return send_from_directory(folder, path.imgpath) - - return send_fallback_img() + return send_file_or_fallback(folder, path.imgpath) -@api.get("/a/") +@api.get("/thumbnail/medium/") +def send_md_thumbnail(path: ImagePath): + """ + Get medium thumbnail (256px) + """ + folder = Paths.get_md_thumb_path() + return send_file_or_fallback(folder, path.imgpath) + + +# ARTISTS +@api.get("/artist/") def send_lg_artist_image(path: ImagePath): """ Get large artist image (500 x 500) """ - folder = Paths.get_artist_img_lg_path() - fpath = Path(folder) / path.imgpath - - if fpath.exists(): - return send_from_directory(folder, path.imgpath) - - return send_fallback_img("artist.webp") + folder = Paths.get_lg_artist_img_path() + return send_file_or_fallback(folder, path.imgpath, "artist.webp") -@api.get("/a/s/") +@api.get("/artist/small/") def send_sm_artist_image(path: ImagePath): """ - Get small artist image (64 x 64) + Get small artist image (128) """ - folder = Paths.get_artist_img_sm_path() - fpath = Path(folder) / path.imgpath - - if fpath.exists(): - return send_from_directory(folder, path.imgpath) - - return send_fallback_img("artist.webp") + folder = Paths.get_sm_artist_img_path() + return send_file_or_fallback(folder, path.imgpath, "artist.webp") +@api.get("/artist/medium/") +def send_md_artist_image(path: ImagePath): + """ + Get medium artist image (256px) + """ + folder = Paths.get_md_artist_img_path() + return send_file_or_fallback(folder, path.imgpath, "artist.webp") + + +# PLAYLISTS class PlaylistImagePath(BaseModel): imgpath: str = Field( description="The image path", @@ -106,7 +131,7 @@ class PlaylistImagePath(BaseModel): ) -@api.get("/p/") +@api.get("/playlist/") def send_playlist_image(path: PlaylistImagePath): """ Get playlist image @@ -114,9 +139,4 @@ def send_playlist_image(path: PlaylistImagePath): Images are constructed as '{playlist_id}.webp' """ folder = Paths.get_playlist_img_path() - fpath = Path(folder) / path.imgpath - - if fpath.exists(): - return send_from_directory(folder, path.imgpath) - - return send_fallback_img("playlist.svg") + return send_file_or_fallback(folder, path.imgpath, "playlist.svg") diff --git a/app/api/playlist.py b/app/api/playlist.py index 91461fdf..8ccd1582 100644 --- a/app/api/playlist.py +++ b/app/api/playlist.py @@ -26,6 +26,7 @@ api = APIBlueprint("playlists", __name__, url_prefix="/playlists", abp_tags=[tag PL = SQLitePlaylistMethods + class SendAllPlaylistsQuery(BaseModel): no_images: bool = Field(False, description="Whether to include images") @@ -410,7 +411,7 @@ def save_item_as_playlist(body: SavePlaylistAsItemBody): filename = itemhash + ".webp" base_path = ( - Paths.get_artist_img_lg_path() + Paths.get_lg_artist_img_path() if itemtype == "artist" else Paths.get_lg_thumb_path() ) diff --git a/app/api/plugins/__init__.py b/app/api/plugins/__init__.py index 6fffe4ab..930bf603 100644 --- a/app/api/plugins/__init__.py +++ b/app/api/plugins/__init__.py @@ -3,6 +3,7 @@ from flask import Blueprint, request from flask_openapi3 import Tag from flask_openapi3 import APIBlueprint from pydantic import BaseModel, Field +from app.api.auth import admin_required from app.db.sqlite.plugins import PluginsMethods bp_tag = Tag(name="Plugins", description="Manage plugins") @@ -30,6 +31,7 @@ class PluginActivateBody(PluginBody): @api.post("/setactive") +@admin_required() def activate_deactivate_plugin(body: PluginActivateBody): """ Activate/Deactivate plugin @@ -49,6 +51,7 @@ class PluginSettingsBody(PluginBody): @api.post("/settings") +@admin_required() def update_plugin_settings(body: PluginSettingsBody): """ Update plugin settings diff --git a/app/api/settings.py b/app/api/settings.py index e5d3306b..4931bc52 100644 --- a/app/api/settings.py +++ b/app/api/settings.py @@ -3,18 +3,21 @@ from flask import request from flask_openapi3 import Tag from flask_openapi3 import APIBlueprint from pydantic import BaseModel, Field +from app.api.auth import admin_required from app.db.sqlite.plugins import PluginsMethods as pdb from app.db.sqlite.settings import SettingsSQLMethods as sdb +from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb from app.lib import populate from app.lib.watchdogg import Watcher as WatchDog from app.logger import log -from app.settings import Keys, Paths, SessionVarKeys, set_flag +from app.settings import Info, Paths, SessionVarKeys, set_flag from app.store.albums import AlbumStore from app.store.artists import ArtistStore from app.store.tracks import TrackStore from app.utils.generators import get_random_str from app.utils.threading import background +from app.config import UserConfig bp_tag = Tag(name="Settings", description="Customize stuff") api = APIBlueprint("settings", __name__, url_prefix="/notsettings", abp_tags=[bp_tag]) @@ -49,12 +52,12 @@ def reload_everything(instance_key: str): @background def rebuild_store(db_dirs: list[str]): """ - Restarts the watchdog and rebuilds the music library. + Restarts watchdog and rebuilds the music library. """ instance_key = get_random_str() log.info("Rebuilding library...") - TrackStore.remove_tracks_by_dir_except(db_dirs) + trackdb.remove_tracks_not_in_folders(db_dirs) reload_everything(instance_key) try: @@ -69,7 +72,7 @@ def rebuild_store(db_dirs: list[str]): log.info("Rebuilding library... ✅") -# I freaking don't know what this function does anymore +# I freaking don't know what this function does anymore def finalize(new_: list[str], removed_: list[str], db_dirs_: list[str]): """ Params: @@ -95,6 +98,7 @@ class AddRootDirsBody(BaseModel): @api.post("/add-root-dirs") +@admin_required() def add_root_dirs(body: AddRootDirsBody): """ Add custom root directories to the database. @@ -103,10 +107,10 @@ def add_root_dirs(body: AddRootDirsBody): removed_dirs = body.removed db_dirs = sdb.get_root_dirs() - _h = "$home" + home = "$home" - db_home = any([d == _h for d in db_dirs]) # if $home is in db - incoming_home = any([d == _h for d in new_dirs]) # if $home is in incoming + db_home = any([d == home for d in db_dirs]) # if $home is in db + incoming_home = any([d == home for d in new_dirs]) # if $home is in incoming # handle $home case if db_home and incoming_home: @@ -116,8 +120,8 @@ def add_root_dirs(body: AddRootDirsBody): sdb.remove_root_dirs(db_dirs) if incoming_home: - finalize([_h], [], [Paths.USER_HOME_DIR]) - return {"root_dirs": [_h]} + finalize([home], [], [Paths.USER_HOME_DIR]) + return {"root_dirs": [home]} # --- @@ -132,7 +136,7 @@ def add_root_dirs(body: AddRootDirsBody): pass db_dirs.extend(new_dirs) - db_dirs = [dir_ for dir_ in db_dirs if dir_ != _h] + db_dirs = [dir_ for dir_ in db_dirs if dir_ != home] finalize(new_dirs, removed_dirs, db_dirs) @@ -190,7 +194,7 @@ def get_all_settings(): root_dirs = sdb.get_root_dirs() s["root_dirs"] = root_dirs s["plugins"] = plugins - s["version"] = Keys.SWINGMUSIC_APP_VERSION + s["version"] = Info.SWINGMUSIC_APP_VERSION return { "settings": s, @@ -214,6 +218,7 @@ class SetSettingBody(BaseModel): @api.post("/set") +@admin_required() def set_setting(body: SetSettingBody): """ Set a setting. @@ -264,3 +269,28 @@ def trigger_scan(): run_populate() return {"msg": "Scan triggered!"} + + +class UpdateConfigBody(BaseModel): + key: str = Field( + description="The setting key", + example="usersOnLogin", + ) + value: Any = Field( + description="The setting value", + example=False, + ) + + +@api.put("/update") +@admin_required() +def update_config(body: UpdateConfigBody): + """ + Update the config file + """ + config = UserConfig() + setattr(config, body.key, body.value) + + return { + "msg": "Config updated!", + } diff --git a/app/api/send_file.py b/app/api/stream.py similarity index 78% rename from app/api/send_file.py rename to app/api/stream.py index 72fe8b37..714c3599 100644 --- a/app/api/send_file.py +++ b/app/api/stream.py @@ -3,14 +3,17 @@ Contains all the track routes. """ import os +import time from flask import Blueprint, send_file, request, Response from flask_openapi3 import APIBlueprint, Tag from pydantic import BaseModel, Field from app.api.apischemas import TrackHashSchema +from app.lib.pydub.pydub.audio_segment import AudioSegment from app.lib.trackslib import get_silence_paddings from app.store.tracks import TrackStore +from app.utils.files import guess_mime_type bp_tag = Tag(name="File", description="Audio files") api = APIBlueprint("track", __name__, url_prefix="/file", abp_tags=[bp_tag]) @@ -33,10 +36,6 @@ def send_track_file(path: TrackHashSchema, query: SendTrackFileQuery): filepath = query.filepath msg = {"msg": "File Not Found"} - def get_mime(filename: str) -> str: - ext = filename.rsplit(".", maxsplit=1)[-1] - return f"audio/{ext}" - # If filepath is provided, try to send that if filepath is not None: try: @@ -47,7 +46,7 @@ def send_track_file(path: TrackHashSchema, query: SendTrackFileQuery): track_exists = track is not None and os.path.exists(track.filepath) if track_exists: - audio_type = get_mime(filepath) + audio_type = guess_mime_type(filepath) return send_file_as_chunks(track.filepath, audio_type) # Else, find file by trackhash @@ -57,7 +56,7 @@ def send_track_file(path: TrackHashSchema, query: SendTrackFileQuery): if track is None: return msg, 404 - audio_type = get_mime(track.filepath) + audio_type = guess_mime_type(track.filepath) try: return send_file_as_chunks(track.filepath, audio_type) @@ -68,15 +67,31 @@ def send_track_file(path: TrackHashSchema, query: SendTrackFileQuery): def send_file_as_chunks(filepath: str, audio_type: str) -> Response: + """ + Returns a Response object that streams the file in chunks. + """ + # NOTE: +1 makes sure the last byte is included in the range. + # NOTE: -1 is used to convert the end index to a 0-based index. + chunk_size = 1024 * 360 # 360 KB + + # Get file size file_size = os.path.getsize(filepath) start = 0 - end = file_size - 1 + end = chunk_size + # Read range header range_header = request.headers.get("Range") if range_header: - start, end = parse_range_header(range_header, file_size) + start = get_start_range(range_header) - chunk_size = 1024 * 1024 # 1MB chunk size (adjust as needed) + # If start + chunk_size is greater than file_size, + # set end to file_size - 1 + _end = start + chunk_size - 1 + + if _end > file_size: + end = file_size - 1 + else: + end = _end def generate_chunks(): with open(filepath, "rb") as file: @@ -84,8 +99,11 @@ def send_file_as_chunks(filepath: str, audio_type: str) -> Response: remaining_bytes = end - start + 1 while remaining_bytes > 0: + # Read the chunk size or all the remaining bytes chunk = file.read(min(chunk_size, remaining_bytes)) yield chunk + + # Update the remaining bytes remaining_bytes -= len(chunk) response = Response( @@ -102,15 +120,13 @@ def send_file_as_chunks(filepath: str, audio_type: str) -> Response: return response -def parse_range_header(range_header: str, file_size: int) -> tuple[int, int]: +def get_start_range(range_header: str): try: range_start, range_end = range_header.strip().split("=")[1].split("-") - start = int(range_start) - end = min(int(range_end), file_size - 1) - except ValueError: - return 0, file_size - 1 + return int(range_start) - return start, end + except ValueError: + return 0 class GetAudioSilenceBody(BaseModel): diff --git a/app/arg_handler.py b/app/arg_handler.py index dec203ac..b74acfe4 100644 --- a/app/arg_handler.py +++ b/app/arg_handler.py @@ -2,6 +2,7 @@ Handles arguments passed to the program. """ +from getpass import getpass import os.path import sys @@ -10,27 +11,37 @@ import PyInstaller.__main__ as bundler from app import settings from app.logger import log from app.print_help import HELP_MESSAGE +from app.utils.auth import hash_password from app.utils.paths import getFlaskOpenApiPath from app.utils.xdg_utils import get_xdg_config_dir from app.utils.wintools import is_windows +from app.db.sqlite.auth import SQLiteAuthMethods as authdb ALLARGS = settings.ALLARGS ARGS = sys.argv[1:] -class HandleArgs: +class ProcessArgs: + """ + Processes the arguments passed to the program. + """ + def __init__(self) -> None: + # resolve config path + self.handle_config_path() # 1 + + # handles that exit + self.handle_password_recovery() self.handle_build() - self.handle_host() - self.handle_port() - self.handle_config_path() - - self.handle_periodic_scan() - self.handle_periodic_scan_interval() - self.handle_help() self.handle_version() + # non-exiting handles + self.handle_host() + self.handle_port() + self.handle_periodic_scan() + self.handle_periodic_scan_interval() + @staticmethod def handle_build(): """ @@ -45,7 +56,7 @@ class HandleArgs: print("https://www.youtube.com/watch?v=wZv62ShoStY") sys.exit(0) - config_keys = [ + info_keys = [ "SWINGMUSIC_APP_VERSION", "GIT_LATEST_COMMIT_HASH", "GIT_CURRENT_BRANCH", @@ -53,16 +64,17 @@ class HandleArgs: lines = [] - for key in config_keys: - value = settings.Keys.get(key) + for key in info_keys: + value = settings.Info.get(key) if not value: - log.error(f"WARNING: {key} not set in environment") + log.error(f"WARNING: {key} not resolved. Exiting ...") sys.exit(0) lines.append(f'{key} = "{value}"\n') try: + # write the info to the config file with open("./app/configs.py", "w", encoding="utf-8") as file: # copy the api keys to the config file file.writelines(lines) @@ -88,7 +100,7 @@ class HandleArgs: finally: # revert and remove the api keys for dev mode with open("./app/configs.py", "w", encoding="utf-8") as file: - lines = [f'{key} = ""\n' for key in config_keys] + lines = [f'{key} = ""\n' for key in info_keys] file.writelines(lines) sys.exit(0) @@ -184,8 +196,43 @@ class HandleArgs: @staticmethod def handle_version(): if any((a in ARGS for a in ALLARGS.version)): - print(f"VERSION: v{settings.Keys.SWINGMUSIC_APP_VERSION}") + print(f"VERSION: v{settings.Info.SWINGMUSIC_APP_VERSION}") print( - f"COMMIT#: {settings.Keys.GIT_CURRENT_BRANCH}/{settings.Keys.GIT_LATEST_COMMIT_HASH}" + f"COMMIT#: {settings.Info.GIT_CURRENT_BRANCH}/{settings.Info.GIT_LATEST_COMMIT_HASH}" ) sys.exit(0) + + @staticmethod + def handle_password_recovery(): + if ALLARGS.pswd in ARGS: + print("SWING MUSIC v2.0.0 ") + print("PASSWORD RECOVERY \n") + + username: str = "" + password: str = "" + + # collect username + try: + username = input("Enter username: ") + except KeyboardInterrupt: + print("\nOperation cancelled! Exiting ...") + sys.exit(0) + + username = username.strip() + user = authdb.get_user_by_username(username) + + if not user: + print(f"User {username} not found") + sys.exit(0) + + # collect password + try: + password = getpass("Enter new password: ") + except KeyboardInterrupt: + print("\nOperation cancelled! Exiting ...") + sys.exit(0) + + password = hash_password(password) + user = authdb.update_user({"id": user.id, "password": password}) + + sys.exit(0) diff --git a/app/config.py b/app/config.py new file mode 100644 index 00000000..00e6cde5 --- /dev/null +++ b/app/config.py @@ -0,0 +1,95 @@ +from dataclasses import dataclass, asdict, field +import json +import os +from typing import Any +from .settings import Paths + +# TODO: Publish this on PyPi + +@dataclass +class UserConfig: + _config_path: str = "" + # NOTE: only auth stuff are used (the others are still reading/writing to db) + # TODO: Move the rest of the settings to the config file + + # auth stuff + # NOTE: Don't expose the userId via the API + userId: str = "" + usersOnLogin: bool = True + + # lists + rootDirs: list[str] = field(default_factory=list) + excludeDirs: list[str] = field(default_factory=list) + artistSeparators: set[str] = field(default_factory=list) + + # tracks + extractFeaturedArtists: bool = True + removeProdBy: bool = True + removeRemasterInfo: bool = True + + # albums + mergeAlbums: bool = False + cleanAlbumTitle: bool = True + showAlbumsAsSingles: bool = False + + def __post_init__(self): + """ + Loads the config file and sets the values to this instance + """ + # set config path locally to avoid writing to file + config_path = Paths.get_config_file_path() + + try: + config = self.load_config(config_path) + except FileNotFoundError: + self._config_path = config_path + return + + # loop through the config file and set the values + for key, value in config.items(): + setattr(self, key, value) + + # finally set the config path + self._config_path = config_path + + def setup_config_file(self) -> None: + """ + Creates the config file with the default settings + if it doesn't exist + """ + # if not exists, create the config file + if not os.path.exists(self._config_path): + self.write_to_file(asdict(self)) + + def load_config(self, path: str) -> dict[str, Any]: + """ + Reads the settings from the config file. + Returns a dictget_root_dirs + """ + with open(path, "r") as f: + settings = json.load(f) + + return settings + + def write_to_file(self, settings: dict[str, Any]): + """ + Writes the settings to the config file + """ + # remove internal attributes + settings = {k: v for k, v in settings.items() if not k.startswith("_")} + + with open(self._config_path, "w") as f: + json.dump(settings, f, indent=4) + + def __setattr__(self, key: str, value: Any) -> None: + """ + Writes to the config file whenever a value is set + """ + super().__setattr__(key, value) + + # if is internal attribute, don't write to file + if key.startswith("_") or not self._config_path: + return + + print(f"writing to file: {key}={value}") + self.write_to_file(asdict(self)) diff --git a/app/db/sqlite/auth.py b/app/db/sqlite/auth.py new file mode 100644 index 00000000..3d9da234 --- /dev/null +++ b/app/db/sqlite/auth.py @@ -0,0 +1,146 @@ +import json +from app.models.user import User +from app.utils.auth import hash_password +from app.db.sqlite.utils import SQLiteManager + + +class SQLiteAuthMethods: + """ + Methods for authenticating users. + """ + + @staticmethod + def insert_user(user: dict[str, str]): + """ + Insert a user into the database. + + :param user: A dict with the username, password and roles. + """ + sql = """INSERT INTO users( + username, + password, + roles + ) VALUES(:username, :password, :roles) + """ + + user_tuple = tuple(user.values()) + + with SQLiteManager(userdata_db=True) as cur: + cur = cur.execute(sql, user_tuple) + userid = cur.lastrowid + return userid + # if userid: + # # sleep + # user = SQLiteAuthMethods.get_user_by_id(userid).todict_simplified() + # cur.close() + # return user + + raise Exception(f"Failed to insert user: {user}") + + @staticmethod + def insert_default_user(): + """ + Inserts the default admin user. + """ + user = { + "username": "admin", + "password": hash_password("admin"), + "roles": json.dumps(["admin"]), + } + return SQLiteAuthMethods.insert_user(user) + + @staticmethod + def insert_guest_user(): + """ + Inserts the default guest user. + """ + user = { + "username": "guest", + "password": hash_password("guest"), + "roles": json.dumps(["guest"]), + } + + return SQLiteAuthMethods.insert_user(user) + + @staticmethod + def update_user(user: dict[str, str]): + """ + Update a user in the database. + + :param user: A dict with the user id and the fields to update. Ommited fields will not be updated. + """ + # get all user dict keys + keys = list(user.keys()) + sql = f"""UPDATE users SET + {', '.join([f"{key} = :{key}" for key in keys if key != 'id'])} + WHERE id = :id + """ + print(sql, user) + + with SQLiteManager(userdata_db=True) as cur: + cur.execute(sql, user) + cur.close() + + return SQLiteAuthMethods.get_user_by_id(user["id"]).todict() + + @staticmethod + def get_all_users(): + """ + Check if there are any users in the database. + """ + sql = "SELECT * FROM users" + + with SQLiteManager(userdata_db=True) as cur: + cur.execute(sql) + + data = cur.fetchall() + cur.close() + + return [User(*user) for user in data] + + @staticmethod + def get_user_by_username(username: str): + """ + Get a user by username. + """ + sql = "SELECT * FROM users WHERE username = ?" + + with SQLiteManager(userdata_db=True) as cur: + cur.execute(sql, (username,)) + + data = cur.fetchone() + cur.close() + + if data is not None: + return User(*data) + + return None + + @staticmethod + def get_user_by_id(userid: int): + """ + Get a user by id. + """ + sql = "SELECT * FROM users WHERE id = ?" + + with SQLiteManager(userdata_db=True) as cur: + cur.execute(sql, (userid,)) + + data = cur.fetchone() + cur.close() + + if data is not None: + return User(*data) + + return None + + @staticmethod + def delete_user_by_username(username: str): + """ + Delete a user by username. + """ + sql = "DELETE FROM users WHERE username = ?" + + with SQLiteManager(userdata_db=True) as cur: + cur.execute(sql, (username,)) + cur.close() diff --git a/app/db/sqlite/migrations.py b/app/db/sqlite/migrations.py index 86d36899..550673a0 100644 --- a/app/db/sqlite/migrations.py +++ b/app/db/sqlite/migrations.py @@ -7,9 +7,9 @@ from app.db.sqlite.utils import SQLiteManager class MigrationManager: @staticmethod - def get_version() -> int: + def get_index() -> int: """ - Returns the latest userdata database version. + Returns the latest databases migrations index. """ sql = "SELECT * FROM dbmigrations" with SQLiteManager() as cur: @@ -21,9 +21,9 @@ class MigrationManager: # 👇 Setters 👇 @staticmethod - def set_version(version: int): + def set_index(version: int): """ - Sets the userdata pre-init database version. + Updates the databases migrations index. """ sql = "UPDATE dbmigrations SET version = ? WHERE id = 1" with SQLiteManager() as cur: diff --git a/app/db/sqlite/queries.py b/app/db/sqlite/queries.py index 96b8b5b1..cf084d6a 100644 --- a/app/db/sqlite/queries.py +++ b/app/db/sqlite/queries.py @@ -54,6 +54,17 @@ CREATE TABLE IF NOT EXISTS track_logger ( timestamp integer NOT NULL, source text, userid integer NOT NULL DEFAULT 0 +); + +CREATE TABLE IF NOT EXISTS users ( + id integer PRIMARY KEY, + username text NOT NULL UNIQUE, + firstname text, + lastname text, + password text NOT NULL, + email text, + image text, + roles text NOT NULL DEFAULT '["user"]' ) """ diff --git a/app/db/sqlite/tracks.py b/app/db/sqlite/tracks.py index ab45f888..c8bed28c 100644 --- a/app/db/sqlite/tracks.py +++ b/app/db/sqlite/tracks.py @@ -112,9 +112,10 @@ class SQLiteTrackMethods: cur.execute("DELETE FROM tracks WHERE filepath=?", (filepath,)) @staticmethod - def remove_tracks_by_folders(folders: set[str]): - sql = "DELETE FROM tracks WHERE folder = ?" + def remove_tracks_not_in_folders(folders: set[str]): + sql = "DELETE FROM tracks WHERE folder NOT IN ({})".format( + ",".join("?" * len(folders)) + ) with SQLiteManager() as cur: - for folder in folders: - cur.execute(sql, (folder,)) + cur.execute(sql, tuple(folders)) diff --git a/app/lib/artistlib.py b/app/lib/artistlib.py index f49fcf0c..b40cb56a 100644 --- a/app/lib/artistlib.py +++ b/app/lib/artistlib.py @@ -55,13 +55,22 @@ def get_artist_image_link(artist: str): # TODO: Move network calls to utils/network.py class DownloadImage: def __init__(self, url: str, name: str) -> None: - sm_path = Path(settings.Paths.get_artist_img_sm_path()) / name - lg_path = Path(settings.Paths.get_artist_img_lg_path()) / name - img = self.download(url) - if img is not None: - self.save_img(img, sm_path, lg_path) + if img is None: + return + + sm_path = Path(settings.Paths.get_sm_artist_img_path()) / name + lg_path = Path(settings.Paths.get_lg_artist_img_path()) / name + md_path = Path(settings.Paths.get_md_artist_img_path()) / name + + entries = [ + (lg_path, None), # save in the original size + (sm_path, settings.Defaults.SM_ARTIST_IMG_SIZE), + (md_path, settings.Defaults.MD_ARTIST_IMG_SIZE), + ] + + self.save_img(img, entries) @staticmethod def download(url: str) -> Image.Image | None: @@ -74,14 +83,21 @@ class DownloadImage: return None @staticmethod - def save_img(img: Image.Image, sm_path: Path, lg_path: Path): + def save_img(img: Image.Image, entries: list[tuple[Path, int | None]]): """ Saves the image to the destinations. """ - img.save(lg_path, format="webp") + ratio = img.width / img.height + for entry in entries: + path, size = entry - sm_size = settings.Defaults.SM_ARTIST_IMG_SIZE - img.resize((sm_size, sm_size), Image.ANTIALIAS).save(sm_path, format="webp") + if size is None: + img.save(path, format="webp") + continue + + img.resize((size, int(size / ratio)), Image.ANTIALIAS).save( + path, format="webp" + ) class CheckArtistImages: @@ -90,7 +106,7 @@ class CheckArtistImages: CHECK_ARTIST_IMAGES_KEY = instance_key # read all files in the artist image folder - path = settings.Paths.get_artist_img_sm_path() + path = settings.Paths.get_sm_artist_img_path() processed = "".join(os.listdir(path)).replace("webp", "") # filter out artists that already have an image @@ -126,7 +142,7 @@ class CheckArtistImages: return img_path = ( - Path(settings.Paths.get_artist_img_sm_path()) / f"{artist.artisthash}.webp" + Path(settings.Paths.get_sm_artist_img_path()) / f"{artist.artisthash}.webp" ) if img_path.exists(): diff --git a/app/lib/colorlib.py b/app/lib/colorlib.py index 17dd5c4c..086fcb2b 100644 --- a/app/lib/colorlib.py +++ b/app/lib/colorlib.py @@ -42,7 +42,7 @@ def process_color(item_hash: str, is_album=True): path = ( settings.Paths.get_sm_thumb_path() if is_album - else settings.Paths.get_artist_img_sm_path() + else settings.Paths.get_sm_artist_img_path() ) path = Path(path) / (item_hash + ".webp") diff --git a/app/lib/home/recentlyplayed.py b/app/lib/home/recentlyplayed.py index 6d41c649..99fd1268 100644 --- a/app/lib/home/recentlyplayed.py +++ b/app/lib/home/recentlyplayed.py @@ -82,9 +82,13 @@ def get_recently_played(limit=7): if entry.type == "folder": folder = entry.type_src + if not folder: continue + if not folder.endswith("/"): + folder += "/" + is_home_dir = entry.type_src == "$home" if is_home_dir: @@ -98,7 +102,7 @@ def get_recently_played(limit=7): { "type": "folder", "item": { - "path": entry.type_src, + "path": folder, "count": count, "help_text": "folder", "time": timestamp_to_time_passed(entry.timestamp), diff --git a/app/lib/taglib.py b/app/lib/taglib.py index 2a523144..3d7e9f6d 100644 --- a/app/lib/taglib.py +++ b/app/lib/taglib.py @@ -33,20 +33,22 @@ def extract_thumb(filepath: str, webp_path: str, overwrite=False) -> bool: """ lg_img_path = os.path.join(Paths.get_lg_thumb_path(), webp_path) sm_img_path = os.path.join(Paths.get_sm_thumb_path(), webp_path) + xms_img_path = os.path.join(Paths.get_xsm_thumb_path(), webp_path) + md_img_path = os.path.join(Paths.get_md_thumb_path(), webp_path) - tsize = Defaults.THUMB_SIZE - sm_tsize = Defaults.SM_THUMB_SIZE + images = [ + (lg_img_path, Defaults.LG_THUMB_SIZE), + (sm_img_path, Defaults.SM_THUMB_SIZE), + (xms_img_path, Defaults.XSM_THUMB_SIZE), + (md_img_path, Defaults.MD_THUMB_SIZE), + ] def save_image(img: Image.Image): width, height = img.size ratio = width / height - img.resize((tsize, int(tsize / ratio)), Image.ANTIALIAS).save( - lg_img_path, "webp" - ) - img.resize((sm_tsize, int(sm_tsize / ratio)), Image.ANTIALIAS).save( - sm_img_path, "webp" - ) + for path, size in images: + img.resize((size, int(size / ratio)), Image.ANTIALIAS).save(path, "webp") if not overwrite and os.path.exists(sm_img_path): img_size = os.path.getsize(sm_img_path) diff --git a/app/lib/trackslib.py b/app/lib/trackslib.py index a685c0bd..085dda0c 100644 --- a/app/lib/trackslib.py +++ b/app/lib/trackslib.py @@ -1,12 +1,13 @@ """ This library contains all the functions related to tracks. """ + import os from app.lib.pydub.pydub import AudioSegment from app.lib.pydub.pydub.silence import detect_leading_silence, detect_silence -from app.db.sqlite.tracks import SQLiteTrackMethods as tdb +from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb from app.store.tracks import TrackStore from app.utils.progressbar import tqdm from app.utils.threading import ThreadWithReturnValue @@ -19,7 +20,7 @@ def validate_tracks() -> None: for track in tqdm(TrackStore.tracks, desc="Validating tracks"): if not os.path.exists(track.filepath): TrackStore.remove_track_obj(track) - tdb.remove_tracks_by_filepaths(track.filepath) + trackdb.remove_tracks_by_filepaths(track.filepath) def get_leading_silence_end(filepath: str): diff --git a/app/migrations/__init__.py b/app/migrations/__init__.py index 58efc9f0..e6d5c431 100644 --- a/app/migrations/__init__.py +++ b/app/migrations/__init__.py @@ -2,12 +2,6 @@ Migrations module. Reads and applies the latest database migrations. - -PLEASE NOTE: OLDER MIGRATIONS CAN NEVER BE DELETED. -ONLY MODIFY OLD MIGRATIONS FOR BUG FIXES OR ENHANCEMENTS ONLY -[TRY NOT TO MODIFY BEHAVIOR, UNLESS YOU KNOW WHAT YOU'RE DOING]. - -PS: Fuck that! Do what you want. """ from app.db.sqlite.migrations import MigrationManager @@ -33,21 +27,32 @@ migrations: list[list[Migration]] = [ def apply_migrations(): """ Applies the latest database migrations. + + The length of all the migrations is stored in the database + and used to check for new migrations. When the length of the + migrations list is larger than the number stored in the db, + migrations past that index are applied and the new length + is stored as the new migration index. """ - version = MigrationManager.get_version() + index = MigrationManager.get_index() + all_migrations = [migration for sublist in migrations for migration in sublist] - if version != len(migrations): - # INFO: Apply new migrations - for migration in migrations[version:]: - for m in migration: - try: - m.migrate() - log.info("Applied migration: %s", m.__name__) - except: - log.error("Failed to run migration: %s", m.__name__) - - print("Migrations applied successfully.") - print("Current migration version: ", len(migrations)) - # bump migration version - MigrationManager.set_version(len(migrations)) + to_apply: list[Migration] = [] + + # if index is from old release, + # get migrations from the "migrations" list + if index < 3: + _migrations = migrations[index:] + to_apply = [migration for sublist in _migrations for migration in sublist] + else: + to_apply = all_migrations[index:] + + for migration in to_apply: + try: + migration.migrate() + log.info("Applied migration: %s", migration.__name__) + except: + log.error("Failed to run migration: %s", migration.__name__) + + MigrationManager.set_index(len(all_migrations)) diff --git a/app/migrations/v1_4_9/__init__.py b/app/migrations/v1_4_9/__init__.py index 0b37ee98..6f0de717 100644 --- a/app/migrations/v1_4_9/__init__.py +++ b/app/migrations/v1_4_9/__init__.py @@ -58,3 +58,19 @@ class DeleteOriginalThumbnails(Migration): if os.path.exists(og_imgpath): shutil.rmtree(og_imgpath) + +class DeleteOriginalThumbnailsa(Migration): + """ + Original thumbnails are too large and are not needed. + """ + + # TODO: Implement this migration + + @staticmethod + def migrate(): + imgpath = Paths.get_thumbs_path() + og_imgpath = os.path.join(imgpath, "original") + + if os.path.exists(og_imgpath): + shutil.rmtree(og_imgpath) + diff --git a/app/models/user.py b/app/models/user.py new file mode 100644 index 00000000..b8d05d01 --- /dev/null +++ b/app/models/user.py @@ -0,0 +1,32 @@ +from dataclasses import asdict, field, dataclass +import json + + +@dataclass(slots=True) +class User: + id: int + username: str + firstname: str + lastname: str + password: str + email: str + image: str + + # NOTE: roles: ['admin', 'user', 'curator'] + roles: list[str] = field(default_factory=lambda: ["user"]) + + def __post_init__(self): + self.roles = json.loads(self.roles) + + def todict(self): + this_dict = asdict(self) + del this_dict["password"] + + return this_dict + + def todict_simplified(self): + return { + "id": self.id, + "username": self.username, + "firstname": self.firstname, + } diff --git a/app/periodic_scan.py b/app/periodic_scan.py index bf7d8c79..7a0af7e3 100644 --- a/app/periodic_scan.py +++ b/app/periodic_scan.py @@ -13,6 +13,12 @@ from app.logger import log def run_periodic_scans(): """ Runs periodic scans. + + Periodic scans are checks that run every few minutes + in the background to do stuff like: + - checking for new music + - delete deleted entries + - downloading artist images, and other data. """ # ValidateAlbumThumbs() # ValidatePlaylistThumbs() diff --git a/app/plugins/lyrics.py b/app/plugins/lyrics.py index 330693ec..81c73b59 100644 --- a/app/plugins/lyrics.py +++ b/app/plugins/lyrics.py @@ -8,7 +8,7 @@ import requests from app.db.sqlite.plugins import PluginsMethods from app.plugins import Plugin, plugin_method -from app.settings import Keys, Paths +from app.settings import Paths class LRCProvider: diff --git a/app/print_help.py b/app/print_help.py index 78393340..2ce7825c 100644 --- a/app/print_help.py +++ b/app/print_help.py @@ -1,4 +1,4 @@ -from app.settings import ALLARGS +from app.settings import ALLARGS, Info from tabulate import tabulate args = ALLARGS @@ -10,6 +10,7 @@ help_args_list = [ ["--port", "", "Set the port"], ["--config", "", "Set the config path"], ["--no-periodic-scan", "-nps", "Disable periodic scan"], + ["--pswd", "", "Recover a password"], [ "--scan-interval", "-psi", @@ -23,10 +24,12 @@ help_args_list = [ ] HELP_MESSAGE = f""" -Swing Music is a beautiful, self-hosted music player for your -local audio files. Like a cooler Spotify ... but bring your own music. +Swing Music v{Info.SWINGMUSIC_APP_VERSION} -Usage: swingmusic [options] [args] +A beautiful, self-hosted music player for your local audio files. +Like Spotify ... but bring your own music. -{tabulate(help_args_list, headers=["Option", "Short", "Description"], tablefmt="simple_grid", maxcolwidths=[None, None, 40])} +Usage: ./swingmusic [options] [args] + +{tabulate(help_args_list, headers=["Option", "Alias", "Description"], tablefmt="psql", maxcolwidths=[None, None, 40])} """ diff --git a/app/settings.py b/app/settings.py index 1dcd7c0f..c3c308fd 100644 --- a/app/settings.py +++ b/app/settings.py @@ -45,22 +45,24 @@ class Paths: def get_img_path(cls): return join(cls.get_app_dir(), "images") + # ARTISTS @classmethod def get_artist_img_path(cls): return join(cls.get_img_path(), "artists") @classmethod - def get_artist_img_sm_path(cls): + def get_sm_artist_img_path(cls): return join(cls.get_artist_img_path(), "small") @classmethod - def get_artist_img_lg_path(cls): - return join(cls.get_artist_img_path(), "large") + def get_md_artist_img_path(cls): + return join(cls.get_artist_img_path(), "medium") @classmethod - def get_playlist_img_path(cls): - return join(cls.get_img_path(), "playlists") + def get_lg_artist_img_path(cls): + return join(cls.get_artist_img_path(), "large") + # TRACK THUMBNAILS @classmethod def get_thumbs_path(cls): return join(cls.get_img_path(), "thumbnails") @@ -69,10 +71,23 @@ class Paths: def get_sm_thumb_path(cls): return join(cls.get_thumbs_path(), "small") + @classmethod + def get_xsm_thumb_path(cls): + return join(cls.get_thumbs_path(), "xsmall") + + @classmethod + def get_md_thumb_path(cls): + return join(cls.get_thumbs_path(), "medium") + @classmethod def get_lg_thumb_path(cls): return join(cls.get_thumbs_path(), "large") + # OTHERS + @classmethod + def get_playlist_img_path(cls): + return join(cls.get_img_path(), "playlists") + @classmethod def get_assets_path(cls): return join(Paths.get_app_dir(), "assets") @@ -85,15 +100,32 @@ class Paths: def get_lyrics_plugins_path(cls): return join(Paths.get_plugins_path(), "lyrics") + @classmethod + def get_config_file_path(cls): + return join(cls.get_app_dir(), "settings.json") + # defaults class Defaults: - THUMB_SIZE = 512 - SM_THUMB_SIZE = 128 + """ + Contains default values for various settings. + + XSM_THUMB_SIZE: extra small thumbnail size for web client tracklist + SM_THUMB_SIZE: small thumbnail size for android client tracklist + MD_THUMB_SIZE: medium thumbnail size for web client album cards + LG_THUMB_SIZE: large thumbnail size for web client now playing album art + + NOTE: LG_ARTIST_IMG_SIZE is not defined as the images are saved in the original size (500px) + """ + + XSM_THUMB_SIZE = 64 + SM_THUMB_SIZE = 96 + MD_THUMB_SIZE = 256 + LG_THUMB_SIZE = 512 + SM_ARTIST_IMG_SIZE = 128 - """ - The size of extracted images in pixels - """ + MD_ARTIST_IMG_SIZE = 256 + HASH_LENGTH = 10 API_ALBUMHASH = "bfe300e966" API_ARTISTHASH = "cae59f1fc5" @@ -101,7 +133,6 @@ class Defaults: API_ALBUMNAME = "The Goat" API_ARTISTNAME = "Polo G" API_TRACKNAME = "Martin & Gina" - API_CARD_LIMIT = 6 @@ -158,6 +189,8 @@ class ALLARGS: host = "--host" config = "--config" + pswd = "--pswd" + show_feat = ("--show-feat", "-sf") show_prod = ("--show-prod", "-sp") dont_clean_albums = ("--no-clean-albums", "-nca") @@ -264,7 +297,14 @@ def getCurrentBranch(): return "" -class Keys: +class Info: + """ + Contains information about the app + + NOTE: This class initially written to load keys when running in build mode. + TODO: Remove this class entirely, and implement functionality where needed. + """ + SWINGMUSIC_APP_VERSION = os.environ.get("SWINGMUSIC_APP_VERSION") GIT_LATEST_COMMIT_HASH = "" GIT_CURRENT_BRANCH = "" @@ -279,12 +319,6 @@ class Keys: cls.GIT_LATEST_COMMIT_HASH = getLatestCommitHash() cls.GIT_CURRENT_BRANCH = getCurrentBranch() - cls.verify_keys() - - @classmethod - def verify_keys(cls): - pass - @classmethod def get(cls, key: str): return getattr(cls, key, None) diff --git a/app/setup/__init__.py b/app/setup/__init__.py index 4e840c44..36e7e356 100644 --- a/app/setup/__init__.py +++ b/app/setup/__init__.py @@ -2,6 +2,7 @@ Prepares the server for use. """ +import uuid from app.db.sqlite.settings import load_settings from app.setup.files import create_config_dir from app.setup.sqlite import run_migrations, setup_sqlite @@ -9,10 +10,22 @@ from app.store.albums import AlbumStore from app.store.artists import ArtistStore from app.store.tracks import TrackStore from app.utils.generators import get_random_str +from app.config import UserConfig def run_setup(): + """ + Creates the config directory, runs migrations, and loads settings. + """ create_config_dir() + + # setup config file + config = UserConfig() + config.setup_config_file() + + if not config.userId: + config.userId = str(uuid.uuid4()) + setup_sqlite() run_migrations() @@ -22,6 +35,11 @@ def run_setup(): # settings table is empty pass + +def load_into_mem(): + """ + Load all tracks, albums, and artists into memory. + """ instance_key = get_random_str() # INFO: Load all tracks, albums, and artists into memory diff --git a/app/setup/files.py b/app/setup/files.py index 46f6a001..4e875750 100644 --- a/app/setup/files.py +++ b/app/setup/files.py @@ -51,28 +51,28 @@ def create_config_dir() -> None: """ Creates the config directory if it doesn't exist. """ - thumb_path = os.path.join("images", "thumbnails") - small_thumb_path = os.path.join(thumb_path, "small") - large_thumb_path = os.path.join(thumb_path, "large") + sm_thumb_path = settings.Paths.get_sm_thumb_path() + lg_thumb_path = settings.Paths.get_lg_thumb_path() + md_thumb_path = settings.Paths.get_md_thumb_path() + xsm_thumb_path = settings.Paths.get_xsm_thumb_path() - artist_img_path = os.path.join("images", "artists") - small_artist_img_path = os.path.join(artist_img_path, "small") - large_artist_img_path = os.path.join(artist_img_path, "large") + small_artist_img_path = settings.Paths.get_sm_artist_img_path() + md_artist_img_path = settings.Paths.get_md_artist_img_path() + large_artist_img_path = settings.Paths.get_lg_artist_img_path() playlist_img_path = os.path.join("images", "playlists") dirs = [ "", # creates the config folder - "images", - "plugins", + sm_thumb_path, + lg_thumb_path, + md_thumb_path, + xsm_thumb_path, "plugins/lyrics", - thumb_path, - small_thumb_path, - large_thumb_path, - artist_img_path, + playlist_img_path, + md_artist_img_path, small_artist_img_path, large_artist_img_path, - playlist_img_path, ] for _dir in dirs: @@ -80,7 +80,9 @@ def create_config_dir() -> None: exists = os.path.exists(path) if not exists: - os.makedirs(path) + # exist_ok=True to create parent directories if they don't exist + os.makedirs(path, exist_ok=True) os.chmod(path, 0o755) + # copy assets to the app directory CopyFiles() diff --git a/app/setup/sqlite.py b/app/setup/sqlite.py index 0506fb82..667d2e4b 100644 --- a/app/setup/sqlite.py +++ b/app/setup/sqlite.py @@ -4,6 +4,7 @@ Applies migrations. """ from app.db.sqlite import create_connection, create_tables, queries +from app.db.sqlite.auth import SQLiteAuthMethods as authdb from app.migrations import apply_migrations from app.settings import Db @@ -29,5 +30,8 @@ def setup_sqlite(): create_tables(user_db_conn, queries.CREATE_USERDATA_TABLES) create_tables(app_db_conn, queries.CREATE_MIGRATIONS_TABLE) + if not authdb.get_all_users(): + authdb.insert_default_user() + app_db_conn.close() user_db_conn.close() diff --git a/app/start_info_logger.py b/app/start_info_logger.py index 94cba209..5b497341 100644 --- a/app/start_info_logger.py +++ b/app/start_info_logger.py @@ -1,16 +1,16 @@ import os -from app.settings import FLASKVARS, TCOLOR, Keys, Paths +from app.settings import FLASKVARS, TCOLOR, Info, Paths from app.utils.network import get_ip def log_startup_info(): lines = "------------------------------" # clears terminal 👇 - os.system("cls" if os.name == "nt" else "echo -e \\\\033c") + # os.system("cls" if os.name == "nt" else "echo -e \\\\033c") print(lines) - print(f"{TCOLOR.HEADER}SwingMusic {Keys.SWINGMUSIC_APP_VERSION} {TCOLOR.ENDC}") + print(f"{TCOLOR.HEADER}Swing Music v{Info.SWINGMUSIC_APP_VERSION} {TCOLOR.ENDC}") adresses = [FLASKVARS.get_flask_host()] diff --git a/app/store/tracks.py b/app/store/tracks.py index 6b56a829..ae699e01 100644 --- a/app/store/tracks.py +++ b/app/store/tracks.py @@ -1,7 +1,7 @@ # from tqdm import tqdm from app.db.sqlite.favorite import SQLiteFavoriteMethods as favdb -from app.db.sqlite.tracks import SQLiteTrackMethods as tdb +from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb from app.models import Track from app.utils.bisection import use_bisection from app.utils.customlist import CustomList @@ -23,7 +23,7 @@ class TrackStore: global TRACKS_LOAD_KEY TRACKS_LOAD_KEY = instance_key - cls.tracks = CustomList(tdb.get_all_tracks()) + cls.tracks = CustomList(trackdb.get_all_tracks()) fav_hashes = favdb.get_fav_tracks() fav_hashes = " ".join([t[1] for t in fav_hashes]) @@ -84,17 +84,6 @@ class TrackStore: if track.filepath in filepaths: cls.remove_track_obj(track) - @classmethod - def remove_tracks_by_dir_except(cls, dirs: list[str]): - """Removes all tracks not in the root directories.""" - to_remove = set() - - for track in cls.tracks: - if not track.folder.startswith(tuple(dirs)): - to_remove.add(track.folder) - - tdb.remove_tracks_by_folders(to_remove) - @classmethod def count_tracks_by_trackhash(cls, trackhash: str) -> int: """ diff --git a/app/utils/auth.py b/app/utils/auth.py new file mode 100644 index 00000000..87f20c75 --- /dev/null +++ b/app/utils/auth.py @@ -0,0 +1,31 @@ +import hmac +import hashlib + +from app.config import UserConfig + + +def hash_password(password: str) -> str: + """ + Hashes the given password using sha256 algorithm and the user id as salt. + + :param password: The password to hash. + + :return: The hashed password. + """ + + return hashlib.pbkdf2_hmac( + "sha256", password.encode("utf-8"), UserConfig().userId.encode("utf-8"), 100000 + ).hex() + + +def check_password(password: str, hashed: str) -> bool: + """ + This function checks if the given password matches the hashed password. + + :param password: The password to check. + :param hashed: The hashed password. + + :return: Whether the password matches. + """ + + return hmac.compare_digest(hash_password(password), hashed) diff --git a/app/utils/files.py b/app/utils/files.py new file mode 100644 index 00000000..382ac423 --- /dev/null +++ b/app/utils/files.py @@ -0,0 +1,21 @@ +import mimetypes + + +def get_mime_from_ext(filename: str): + """ + Constructs a mime type from a file extension. + """ + ext = filename.rsplit(".", maxsplit=1)[-1] + return f"audio/{ext}" + + +def guess_mime_type(filename: str): + """ + Guess the mime type of a file. + """ + type = mimetypes.guess_type(filename)[0] + + if type is None: + return get_mime_from_ext(filename) + + return type diff --git a/app/utils/paths.py b/app/utils/paths.py index fe6bcfb4..d817010d 100644 --- a/app/utils/paths.py +++ b/app/utils/paths.py @@ -1,5 +1,8 @@ +import os import sys +from app.utils.filesystem import get_home_res_path + def getFlaskOpenApiPath(): """ @@ -10,3 +13,19 @@ def getFlaskOpenApiPath(): site_packages_path = [p for p in sys.path if "site-packages" in p][0] return f"{site_packages_path}/flask_openapi3" + + +def getClientFilesExtensions(): + """ + Get all the file extensions for the client files + """ + + client_path = get_home_res_path("client") + + extensions = set() + for root, dirs, files in os.walk(client_path): + for file in files: + ext = file.split(".")[-1] + extensions.add("." + ext) + + return extensions diff --git a/manage.py b/manage.py index 56698253..f27c26bd 100644 --- a/manage.py +++ b/manage.py @@ -2,8 +2,16 @@ This file is used to run the application. """ +from datetime import datetime, timezone import os import logging +from flask_jwt_extended import ( + create_access_token, + get_jwt, + get_jwt_identity, + set_access_cookies, + verify_jwt_in_request, +) import psutil import mimetypes from flask import Response, request @@ -12,14 +20,15 @@ import waitress import setproctitle from app.api import create_api -from app.arg_handler import HandleArgs +from app.arg_handler import ProcessArgs from app.lib.watchdogg import Watcher as WatchDog from app.periodic_scan import run_periodic_scans from app.plugins.register import register_plugins -from app.settings import FLASKVARS, TCOLOR, Keys -from app.setup import run_setup +from app.settings import FLASKVARS, TCOLOR, Info +from app.setup import load_into_mem, run_setup from app.start_info_logger import log_startup_info from app.utils.filesystem import get_home_res_path +from app.utils.paths import getClientFilesExtensions from app.utils.threading import background mimetypes.add_type("text/css", ".css") @@ -38,9 +47,84 @@ mimetypes.add_type("application/manifest+json", ".webmanifest") werkzeug = logging.getLogger("werkzeug") werkzeug.setLevel(logging.ERROR) +# Background tasks +@background +def bg_run_setup(): + run_periodic_scans() + + +@background +def start_watchdog(): + WatchDog().run() + + +@background +def run_swingmusic(): + log_startup_info() + bg_run_setup() + register_plugins() + + start_watchdog() + + setproctitle.setproctitle(f"swingmusic ::{FLASKVARS.get_flask_port()}") + + +# Setup function calls +Info.load() +ProcessArgs() +run_setup() +load_into_mem() +run_swingmusic() + + +# Create the Flask app + app = create_api() app.static_folder = get_home_res_path("client") +# INFO: Routes that don't need authentication +whitelisted_routes = {"/auth/login", "/auth/users", "/auth/logout", "/docs"} +blacklist_extensions = {".webp"}.union(getClientFilesExtensions()) + + +@app.before_request +def verify_auth(): + """ + Verifies the JWT token before each request. + """ + if request.path == "/" or any( + request.path.endswith(ext) for ext in blacklist_extensions + ): + return + + # if request path starts with any of the blacklisted routes, don't verify jwt + if any(request.path.startswith(route) for route in whitelisted_routes): + # print( + # "Found whitelisted route: ", request.path, "... Skipping jwt verification" + # ) + return + + verify_jwt_in_request() + + +@app.after_request +def refresh_expiring_jwt(response: Response): + """ + Refreshes the JWT token after each request. + """ + try: + exp_timestamp = get_jwt()["exp"] + now = datetime.now(timezone.utc) + target_timestamp = datetime.timestamp(now) + 60 * 60 * 24 * 7 # 7 days + + if target_timestamp > exp_timestamp: + access_token = create_access_token(identity=get_jwt_identity()) + set_access_cookies(response, access_token) + + return response + except (RuntimeError, KeyError): + return response + @app.route("/") def serve_client_files(path: str): @@ -106,33 +190,7 @@ def print_memory_usage(response: Response): return response -@background -def bg_run_setup() -> None: - run_periodic_scans() - - -@background -def start_watchdog(): - WatchDog().run() - - -@background -def run_swingmusic(): - log_startup_info() - run_setup() - bg_run_setup() - register_plugins() - - start_watchdog() - - setproctitle.setproctitle(f"swingmusic ::{FLASKVARS.get_flask_port()}") - - if __name__ == "__main__": - Keys.load() - HandleArgs() - run_swingmusic() - host = FLASKVARS.get_flask_host() port = FLASKVARS.get_flask_port() diff --git a/poetry.lock b/poetry.lock index 0b9fa9b7..481d868a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -557,6 +557,25 @@ files = [ Flask = ">=0.9" Six = "*" +[[package]] +name = "flask-jwt-extended" +version = "4.6.0" +description = "Extended JWT integration with Flask" +optional = false +python-versions = ">=3.7,<4" +files = [ + {file = "Flask-JWT-Extended-4.6.0.tar.gz", hash = "sha256:9215d05a9413d3855764bcd67035e75819d23af2fafb6b55197eb5a3313fdfb2"}, + {file = "Flask_JWT_Extended-4.6.0-py2.py3-none-any.whl", hash = "sha256:63a28fc9731bcc6c4b8815b6f954b5904caa534fc2ae9b93b1d3ef12930dca95"}, +] + +[package.dependencies] +Flask = ">=2.0,<4.0" +PyJWT = ">=2.0,<3.0" +Werkzeug = ">=0.14" + +[package.extras] +asymmetric-crypto = ["cryptography (>=3.3.1)"] + [[package]] name = "flask-openapi3" version = "3.0.2" @@ -1599,6 +1618,23 @@ files = [ {file = "pyinstaller_hooks_contrib-2023.9-py2.py3-none-any.whl", hash = "sha256:f34f4c6807210025c8073ebe665f422a3aa2ac5f4c7ebf4c2a26cc77bebf63b5"}, ] +[[package]] +name = "pyjwt" +version = "2.8.0" +description = "JSON Web Token implementation in Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"}, + {file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"}, +] + +[package.extras] +crypto = ["cryptography (>=3.4.0)"] +dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] +docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] +tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] + [[package]] name = "pylint" version = "2.17.7" @@ -2468,4 +2504,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "feb13f92b7b3a909fcb851860a405b96579feac0e2dde7681ed0e9c381c4f6cd" +content-hash = "317c4094a6f768467219db94be02a6e2e9f1ed5ca18d929437094bf60be1594d" diff --git a/pyproject.toml b/pyproject.toml index 005eaed5..5ff4920b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ waitress = "^2.1.2" watchdog = "^4.0.0" pendulum = "^3.0.0" flask-openapi3 = "^3.0.2" +flask-jwt-extended = "^4.6.0" [tool.poetry.dev-dependencies] pylint = "^2.15.5" diff --git a/requirements.txt b/requirements.txt index 51b16ecf..1089e561 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,7 @@ Flask==2.3.3 Flask-BasicAuth==0.2.0 Flask-Compress==1.14 Flask-Cors==3.0.10 +Flask-JWT-Extended==4.6.0 flask-openapi3==3.0.2 gevent==23.9.1 geventhttpclient==2.0.11 @@ -42,6 +43,7 @@ pydantic==2.6.3 pydantic_core==2.16.3 pyinstaller==5.13.2 pyinstaller-hooks-contrib==2023.9 +PyJWT==2.8.0 pylint==2.17.7 pytest==7.4.2 python-dateutil==2.8.2