diff --git a/app/api/__init__.py b/app/api/__init__.py index 3a7e1afb..2fb0e1c5 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -63,12 +63,11 @@ def create_api(): ) app = OpenAPI(__name__, info=api_info, doc_prefix="/docs") - print("userid", UserConfig().userId) # JWT CONFIGS app.config["JWT_SECRET_KEY"] = UserConfig().userId app.config["JWT_TOKEN_LOCATION"] = ["cookies"] app.config["JWT_COOKIE_CSRF_PROTECT"] = False - app.config["JWT_ACCESS_TOKEN_EXPIRES"] = datetime.timedelta(days=1) + app.config["JWT_ACCESS_TOKEN_EXPIRES"] = datetime.timedelta(days=30) # CORS CORS(app, origins="*", supports_credentials=True) diff --git a/app/api/auth.py b/app/api/auth.py index 0edc2424..83dfda38 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -191,11 +191,6 @@ def delete_user(body: DeleteUseBody): return {"msg": "Cannot delete the only admin"}, 400 authdb.delete_user_by_username(body.username) - - # if user is guest, update config - if body.username == "guest": - UserConfig().enableGuest = False - return {"msg": f"User {body.username} deleted"} @@ -225,7 +220,7 @@ def get_all_users(query: GetAllUsersQuery): # config.enableGuest = True # config.usersOnLogin = True settings = { - "enableGuest": config.enableGuest, + "enableGuest": False, "usersOnLogin": config.usersOnLogin, } @@ -234,7 +229,10 @@ def get_all_users(query: GetAllUsersQuery): "users": [], } + users = authdb.get_all_users() + is_admin = current_user and "admin" in current_user["roles"] + settings['enableGuest'] = [user for user in users if user.username == "guest"].__len__() > 0 # if user is admin, also return settings if is_admin: @@ -254,13 +252,12 @@ def get_all_users(query: GetAllUsersQuery): ): return res - users = authdb.get_all_users() # remove guest user # if not settings["enableGuest"]: # users = [user for user in users if user.username != "guest"] - if not is_admin or not settings["usersOnLogin"]: + if not settings["usersOnLogin"]: users = [user for user in users if user.username == "guest"] # reverse list to show latest users first @@ -268,7 +265,6 @@ def get_all_users(query: GetAllUsersQuery): # bring admins to the front users = sorted(users, key=lambda x: "admin" in x.roles, reverse=True) - # bring current user to index 0 if current_user: users = sorted( diff --git a/app/api/send_file.py b/app/api/send_file.py index 72fe8b37..5c804079 100644 --- a/app/api/send_file.py +++ b/app/api/send_file.py @@ -68,15 +68,31 @@ def send_track_file(path: TrackHashSchema, query: SendTrackFileQuery): def send_file_as_chunks(filepath: str, audio_type: str) -> Response: + """ + Returns a Response object that streams the file in chunks. + """ + # NOTE: +1 makes sure the last byte is included in the range. + # NOTE: -1 is used to convert the end index to a 0-based index. + chunk_size = 1024 * 512 + + # Get file size file_size = os.path.getsize(filepath) start = 0 - end = file_size - 1 + end = chunk_size + # Read range header range_header = request.headers.get("Range") if range_header: - start, end = parse_range_header(range_header, file_size) + start = get_start_range(range_header) - chunk_size = 1024 * 1024 # 1MB chunk size (adjust as needed) + # If start + chunk_size is greater than file_size, + # set end to file_size - 1 + _end = start + chunk_size - 1 + + if _end > file_size: + end = file_size - 1 + else: + end = _end def generate_chunks(): with open(filepath, "rb") as file: @@ -84,8 +100,11 @@ def send_file_as_chunks(filepath: str, audio_type: str) -> Response: remaining_bytes = end - start + 1 while remaining_bytes > 0: + # Read the chunk size or all the remaining bytes chunk = file.read(min(chunk_size, remaining_bytes)) yield chunk + + # Update the remaining bytes remaining_bytes -= len(chunk) response = Response( @@ -102,15 +121,13 @@ def send_file_as_chunks(filepath: str, audio_type: str) -> Response: return response -def parse_range_header(range_header: str, file_size: int) -> tuple[int, int]: +def get_start_range(range_header: str): try: range_start, range_end = range_header.strip().split("=")[1].split("-") - start = int(range_start) - end = min(int(range_end), file_size - 1) - except ValueError: - return 0, file_size - 1 + return int(range_start) - return start, end + except ValueError: + return 0 class GetAudioSilenceBody(BaseModel): diff --git a/app/api/settings.py b/app/api/settings.py index 78ecad82..4931bc52 100644 --- a/app/api/settings.py +++ b/app/api/settings.py @@ -7,6 +7,7 @@ from app.api.auth import admin_required from app.db.sqlite.plugins import PluginsMethods as pdb from app.db.sqlite.settings import SettingsSQLMethods as sdb +from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb from app.lib import populate from app.lib.watchdogg import Watcher as WatchDog from app.logger import log @@ -51,12 +52,12 @@ def reload_everything(instance_key: str): @background def rebuild_store(db_dirs: list[str]): """ - Restarts the watchdog and rebuilds the music library. + Restarts watchdog and rebuilds the music library. """ instance_key = get_random_str() log.info("Rebuilding library...") - TrackStore.remove_tracks_by_dir_except(db_dirs) + trackdb.remove_tracks_not_in_folders(db_dirs) reload_everything(instance_key) try: @@ -106,10 +107,10 @@ def add_root_dirs(body: AddRootDirsBody): removed_dirs = body.removed db_dirs = sdb.get_root_dirs() - _h = "$home" + home = "$home" - db_home = any([d == _h for d in db_dirs]) # if $home is in db - incoming_home = any([d == _h for d in new_dirs]) # if $home is in incoming + db_home = any([d == home for d in db_dirs]) # if $home is in db + incoming_home = any([d == home for d in new_dirs]) # if $home is in incoming # handle $home case if db_home and incoming_home: @@ -119,8 +120,8 @@ def add_root_dirs(body: AddRootDirsBody): sdb.remove_root_dirs(db_dirs) if incoming_home: - finalize([_h], [], [Paths.USER_HOME_DIR]) - return {"root_dirs": [_h]} + finalize([home], [], [Paths.USER_HOME_DIR]) + return {"root_dirs": [home]} # --- @@ -135,7 +136,7 @@ def add_root_dirs(body: AddRootDirsBody): pass db_dirs.extend(new_dirs) - db_dirs = [dir_ for dir_ in db_dirs if dir_ != _h] + db_dirs = [dir_ for dir_ in db_dirs if dir_ != home] finalize(new_dirs, removed_dirs, db_dirs) diff --git a/app/config.py b/app/config.py index 4d4d997b..00e6cde5 100644 --- a/app/config.py +++ b/app/config.py @@ -16,7 +16,6 @@ class UserConfig: # NOTE: Don't expose the userId via the API userId: str = "" usersOnLogin: bool = True - enableGuest: bool = False # lists rootDirs: list[str] = field(default_factory=list) diff --git a/app/db/sqlite/tracks.py b/app/db/sqlite/tracks.py index ab45f888..c8bed28c 100644 --- a/app/db/sqlite/tracks.py +++ b/app/db/sqlite/tracks.py @@ -112,9 +112,10 @@ class SQLiteTrackMethods: cur.execute("DELETE FROM tracks WHERE filepath=?", (filepath,)) @staticmethod - def remove_tracks_by_folders(folders: set[str]): - sql = "DELETE FROM tracks WHERE folder = ?" + def remove_tracks_not_in_folders(folders: set[str]): + sql = "DELETE FROM tracks WHERE folder NOT IN ({})".format( + ",".join("?" * len(folders)) + ) with SQLiteManager() as cur: - for folder in folders: - cur.execute(sql, (folder,)) + cur.execute(sql, tuple(folders)) diff --git a/app/lib/trackslib.py b/app/lib/trackslib.py index a685c0bd..085dda0c 100644 --- a/app/lib/trackslib.py +++ b/app/lib/trackslib.py @@ -1,12 +1,13 @@ """ This library contains all the functions related to tracks. """ + import os from app.lib.pydub.pydub import AudioSegment from app.lib.pydub.pydub.silence import detect_leading_silence, detect_silence -from app.db.sqlite.tracks import SQLiteTrackMethods as tdb +from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb from app.store.tracks import TrackStore from app.utils.progressbar import tqdm from app.utils.threading import ThreadWithReturnValue @@ -19,7 +20,7 @@ def validate_tracks() -> None: for track in tqdm(TrackStore.tracks, desc="Validating tracks"): if not os.path.exists(track.filepath): TrackStore.remove_track_obj(track) - tdb.remove_tracks_by_filepaths(track.filepath) + trackdb.remove_tracks_by_filepaths(track.filepath) def get_leading_silence_end(filepath: str): diff --git a/app/store/tracks.py b/app/store/tracks.py index 6b56a829..ae699e01 100644 --- a/app/store/tracks.py +++ b/app/store/tracks.py @@ -1,7 +1,7 @@ # from tqdm import tqdm from app.db.sqlite.favorite import SQLiteFavoriteMethods as favdb -from app.db.sqlite.tracks import SQLiteTrackMethods as tdb +from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb from app.models import Track from app.utils.bisection import use_bisection from app.utils.customlist import CustomList @@ -23,7 +23,7 @@ class TrackStore: global TRACKS_LOAD_KEY TRACKS_LOAD_KEY = instance_key - cls.tracks = CustomList(tdb.get_all_tracks()) + cls.tracks = CustomList(trackdb.get_all_tracks()) fav_hashes = favdb.get_fav_tracks() fav_hashes = " ".join([t[1] for t in fav_hashes]) @@ -84,17 +84,6 @@ class TrackStore: if track.filepath in filepaths: cls.remove_track_obj(track) - @classmethod - def remove_tracks_by_dir_except(cls, dirs: list[str]): - """Removes all tracks not in the root directories.""" - to_remove = set() - - for track in cls.tracks: - if not track.folder.startswith(tuple(dirs)): - to_remove.add(track.folder) - - tdb.remove_tracks_by_folders(to_remove) - @classmethod def count_tracks_by_trackhash(cls, trackhash: str) -> int: """ diff --git a/manage.py b/manage.py index b7068dfd..cd83f9c8 100644 --- a/manage.py +++ b/manage.py @@ -2,9 +2,16 @@ This file is used to run the application. """ +from datetime import datetime, timezone import os import logging -from flask_jwt_extended import verify_jwt_in_request +from flask_jwt_extended import ( + create_access_token, + get_jwt, + get_jwt_identity, + set_access_cookies, + verify_jwt_in_request, +) import psutil import mimetypes from flask import Response, request @@ -70,6 +77,25 @@ def verify_auth(): verify_jwt_in_request() +@app.after_request +def refresh_expiring_jwt(response: Response): + """ + Refreshes the JWT token after each request. + """ + try: + exp_timestamp = get_jwt()["exp"] + now = datetime.now(timezone.utc) + target_timestamp = datetime.timestamp(now) + 60 * 60 * 24 * 7 # 7 days + + if target_timestamp > exp_timestamp: + access_token = create_access_token(identity=get_jwt_identity()) + set_access_cookies(response, access_token) + + return response + except (RuntimeError, KeyError): + return response + + @app.route("/") def serve_client_files(path: str): """