fix: chunked audio stream

desc: faulty content range headers

+ fix: tracks not being removed from db on root dirs change
+ implement implicit jwt refreshing
+ remove enableGuest from configs
+ set jwt validity to 30 days
This commit is contained in:
mungai-njoroge
2024-05-05 23:55:25 +03:00
parent fdf3186be6
commit 36600ab782
9 changed files with 78 additions and 49 deletions
+1 -2
View File
@@ -63,12 +63,11 @@ def create_api():
) )
app = OpenAPI(__name__, info=api_info, doc_prefix="/docs") app = OpenAPI(__name__, info=api_info, doc_prefix="/docs")
print("userid", UserConfig().userId)
# JWT CONFIGS # JWT CONFIGS
app.config["JWT_SECRET_KEY"] = UserConfig().userId app.config["JWT_SECRET_KEY"] = UserConfig().userId
app.config["JWT_TOKEN_LOCATION"] = ["cookies"] app.config["JWT_TOKEN_LOCATION"] = ["cookies"]
app.config["JWT_COOKIE_CSRF_PROTECT"] = False app.config["JWT_COOKIE_CSRF_PROTECT"] = False
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = datetime.timedelta(days=1) app.config["JWT_ACCESS_TOKEN_EXPIRES"] = datetime.timedelta(days=30)
# CORS # CORS
CORS(app, origins="*", supports_credentials=True) CORS(app, origins="*", supports_credentials=True)
+5 -9
View File
@@ -191,11 +191,6 @@ def delete_user(body: DeleteUseBody):
return {"msg": "Cannot delete the only admin"}, 400 return {"msg": "Cannot delete the only admin"}, 400
authdb.delete_user_by_username(body.username) authdb.delete_user_by_username(body.username)
# if user is guest, update config
if body.username == "guest":
UserConfig().enableGuest = False
return {"msg": f"User {body.username} deleted"} return {"msg": f"User {body.username} deleted"}
@@ -225,7 +220,7 @@ def get_all_users(query: GetAllUsersQuery):
# config.enableGuest = True # config.enableGuest = True
# config.usersOnLogin = True # config.usersOnLogin = True
settings = { settings = {
"enableGuest": config.enableGuest, "enableGuest": False,
"usersOnLogin": config.usersOnLogin, "usersOnLogin": config.usersOnLogin,
} }
@@ -234,7 +229,10 @@ def get_all_users(query: GetAllUsersQuery):
"users": [], "users": [],
} }
users = authdb.get_all_users()
is_admin = current_user and "admin" in current_user["roles"] is_admin = current_user and "admin" in current_user["roles"]
settings['enableGuest'] = [user for user in users if user.username == "guest"].__len__() > 0
# if user is admin, also return settings # if user is admin, also return settings
if is_admin: if is_admin:
@@ -254,13 +252,12 @@ def get_all_users(query: GetAllUsersQuery):
): ):
return res return res
users = authdb.get_all_users()
# remove guest user # remove guest user
# if not settings["enableGuest"]: # if not settings["enableGuest"]:
# users = [user for user in users if user.username != "guest"] # users = [user for user in users if user.username != "guest"]
if not is_admin or not settings["usersOnLogin"]: if not settings["usersOnLogin"]:
users = [user for user in users if user.username == "guest"] users = [user for user in users if user.username == "guest"]
# reverse list to show latest users first # reverse list to show latest users first
@@ -268,7 +265,6 @@ def get_all_users(query: GetAllUsersQuery):
# bring admins to the front # bring admins to the front
users = sorted(users, key=lambda x: "admin" in x.roles, reverse=True) users = sorted(users, key=lambda x: "admin" in x.roles, reverse=True)
# bring current user to index 0 # bring current user to index 0
if current_user: if current_user:
users = sorted( users = sorted(
+26 -9
View File
@@ -68,15 +68,31 @@ def send_track_file(path: TrackHashSchema, query: SendTrackFileQuery):
def send_file_as_chunks(filepath: str, audio_type: str) -> Response: def send_file_as_chunks(filepath: str, audio_type: str) -> Response:
"""
Returns a Response object that streams the file in chunks.
"""
# NOTE: +1 makes sure the last byte is included in the range.
# NOTE: -1 is used to convert the end index to a 0-based index.
chunk_size = 1024 * 512
# Get file size
file_size = os.path.getsize(filepath) file_size = os.path.getsize(filepath)
start = 0 start = 0
end = file_size - 1 end = chunk_size
# Read range header
range_header = request.headers.get("Range") range_header = request.headers.get("Range")
if range_header: if range_header:
start, end = parse_range_header(range_header, file_size) start = get_start_range(range_header)
chunk_size = 1024 * 1024 # 1MB chunk size (adjust as needed) # If start + chunk_size is greater than file_size,
# set end to file_size - 1
_end = start + chunk_size - 1
if _end > file_size:
end = file_size - 1
else:
end = _end
def generate_chunks(): def generate_chunks():
with open(filepath, "rb") as file: with open(filepath, "rb") as file:
@@ -84,8 +100,11 @@ def send_file_as_chunks(filepath: str, audio_type: str) -> Response:
remaining_bytes = end - start + 1 remaining_bytes = end - start + 1
while remaining_bytes > 0: while remaining_bytes > 0:
# Read the chunk size or all the remaining bytes
chunk = file.read(min(chunk_size, remaining_bytes)) chunk = file.read(min(chunk_size, remaining_bytes))
yield chunk yield chunk
# Update the remaining bytes
remaining_bytes -= len(chunk) remaining_bytes -= len(chunk)
response = Response( response = Response(
@@ -102,15 +121,13 @@ def send_file_as_chunks(filepath: str, audio_type: str) -> Response:
return response return response
def parse_range_header(range_header: str, file_size: int) -> tuple[int, int]: def get_start_range(range_header: str):
try: try:
range_start, range_end = range_header.strip().split("=")[1].split("-") range_start, range_end = range_header.strip().split("=")[1].split("-")
start = int(range_start) return int(range_start)
end = min(int(range_end), file_size - 1)
except ValueError:
return 0, file_size - 1
return start, end except ValueError:
return 0
class GetAudioSilenceBody(BaseModel): class GetAudioSilenceBody(BaseModel):
+9 -8
View File
@@ -7,6 +7,7 @@ from app.api.auth import admin_required
from app.db.sqlite.plugins import PluginsMethods as pdb from app.db.sqlite.plugins import PluginsMethods as pdb
from app.db.sqlite.settings import SettingsSQLMethods as sdb from app.db.sqlite.settings import SettingsSQLMethods as sdb
from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb
from app.lib import populate from app.lib import populate
from app.lib.watchdogg import Watcher as WatchDog from app.lib.watchdogg import Watcher as WatchDog
from app.logger import log from app.logger import log
@@ -51,12 +52,12 @@ def reload_everything(instance_key: str):
@background @background
def rebuild_store(db_dirs: list[str]): def rebuild_store(db_dirs: list[str]):
""" """
Restarts the watchdog and rebuilds the music library. Restarts watchdog and rebuilds the music library.
""" """
instance_key = get_random_str() instance_key = get_random_str()
log.info("Rebuilding library...") log.info("Rebuilding library...")
TrackStore.remove_tracks_by_dir_except(db_dirs) trackdb.remove_tracks_not_in_folders(db_dirs)
reload_everything(instance_key) reload_everything(instance_key)
try: try:
@@ -106,10 +107,10 @@ def add_root_dirs(body: AddRootDirsBody):
removed_dirs = body.removed removed_dirs = body.removed
db_dirs = sdb.get_root_dirs() db_dirs = sdb.get_root_dirs()
_h = "$home" home = "$home"
db_home = any([d == _h for d in db_dirs]) # if $home is in db db_home = any([d == home for d in db_dirs]) # if $home is in db
incoming_home = any([d == _h for d in new_dirs]) # if $home is in incoming incoming_home = any([d == home for d in new_dirs]) # if $home is in incoming
# handle $home case # handle $home case
if db_home and incoming_home: if db_home and incoming_home:
@@ -119,8 +120,8 @@ def add_root_dirs(body: AddRootDirsBody):
sdb.remove_root_dirs(db_dirs) sdb.remove_root_dirs(db_dirs)
if incoming_home: if incoming_home:
finalize([_h], [], [Paths.USER_HOME_DIR]) finalize([home], [], [Paths.USER_HOME_DIR])
return {"root_dirs": [_h]} return {"root_dirs": [home]}
# --- # ---
@@ -135,7 +136,7 @@ def add_root_dirs(body: AddRootDirsBody):
pass pass
db_dirs.extend(new_dirs) db_dirs.extend(new_dirs)
db_dirs = [dir_ for dir_ in db_dirs if dir_ != _h] db_dirs = [dir_ for dir_ in db_dirs if dir_ != home]
finalize(new_dirs, removed_dirs, db_dirs) finalize(new_dirs, removed_dirs, db_dirs)
-1
View File
@@ -16,7 +16,6 @@ class UserConfig:
# NOTE: Don't expose the userId via the API # NOTE: Don't expose the userId via the API
userId: str = "" userId: str = ""
usersOnLogin: bool = True usersOnLogin: bool = True
enableGuest: bool = False
# lists # lists
rootDirs: list[str] = field(default_factory=list) rootDirs: list[str] = field(default_factory=list)
+5 -4
View File
@@ -112,9 +112,10 @@ class SQLiteTrackMethods:
cur.execute("DELETE FROM tracks WHERE filepath=?", (filepath,)) cur.execute("DELETE FROM tracks WHERE filepath=?", (filepath,))
@staticmethod @staticmethod
def remove_tracks_by_folders(folders: set[str]): def remove_tracks_not_in_folders(folders: set[str]):
sql = "DELETE FROM tracks WHERE folder = ?" sql = "DELETE FROM tracks WHERE folder NOT IN ({})".format(
",".join("?" * len(folders))
)
with SQLiteManager() as cur: with SQLiteManager() as cur:
for folder in folders: cur.execute(sql, tuple(folders))
cur.execute(sql, (folder,))
+3 -2
View File
@@ -1,12 +1,13 @@
""" """
This library contains all the functions related to tracks. This library contains all the functions related to tracks.
""" """
import os import os
from app.lib.pydub.pydub import AudioSegment from app.lib.pydub.pydub import AudioSegment
from app.lib.pydub.pydub.silence import detect_leading_silence, detect_silence from app.lib.pydub.pydub.silence import detect_leading_silence, detect_silence
from app.db.sqlite.tracks import SQLiteTrackMethods as tdb from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb
from app.store.tracks import TrackStore from app.store.tracks import TrackStore
from app.utils.progressbar import tqdm from app.utils.progressbar import tqdm
from app.utils.threading import ThreadWithReturnValue from app.utils.threading import ThreadWithReturnValue
@@ -19,7 +20,7 @@ def validate_tracks() -> None:
for track in tqdm(TrackStore.tracks, desc="Validating tracks"): for track in tqdm(TrackStore.tracks, desc="Validating tracks"):
if not os.path.exists(track.filepath): if not os.path.exists(track.filepath):
TrackStore.remove_track_obj(track) TrackStore.remove_track_obj(track)
tdb.remove_tracks_by_filepaths(track.filepath) trackdb.remove_tracks_by_filepaths(track.filepath)
def get_leading_silence_end(filepath: str): def get_leading_silence_end(filepath: str):
+2 -13
View File
@@ -1,7 +1,7 @@
# from tqdm import tqdm # from tqdm import tqdm
from app.db.sqlite.favorite import SQLiteFavoriteMethods as favdb from app.db.sqlite.favorite import SQLiteFavoriteMethods as favdb
from app.db.sqlite.tracks import SQLiteTrackMethods as tdb from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb
from app.models import Track from app.models import Track
from app.utils.bisection import use_bisection from app.utils.bisection import use_bisection
from app.utils.customlist import CustomList from app.utils.customlist import CustomList
@@ -23,7 +23,7 @@ class TrackStore:
global TRACKS_LOAD_KEY global TRACKS_LOAD_KEY
TRACKS_LOAD_KEY = instance_key TRACKS_LOAD_KEY = instance_key
cls.tracks = CustomList(tdb.get_all_tracks()) cls.tracks = CustomList(trackdb.get_all_tracks())
fav_hashes = favdb.get_fav_tracks() fav_hashes = favdb.get_fav_tracks()
fav_hashes = " ".join([t[1] for t in fav_hashes]) fav_hashes = " ".join([t[1] for t in fav_hashes])
@@ -84,17 +84,6 @@ class TrackStore:
if track.filepath in filepaths: if track.filepath in filepaths:
cls.remove_track_obj(track) cls.remove_track_obj(track)
@classmethod
def remove_tracks_by_dir_except(cls, dirs: list[str]):
"""Removes all tracks not in the root directories."""
to_remove = set()
for track in cls.tracks:
if not track.folder.startswith(tuple(dirs)):
to_remove.add(track.folder)
tdb.remove_tracks_by_folders(to_remove)
@classmethod @classmethod
def count_tracks_by_trackhash(cls, trackhash: str) -> int: def count_tracks_by_trackhash(cls, trackhash: str) -> int:
""" """
+27 -1
View File
@@ -2,9 +2,16 @@
This file is used to run the application. This file is used to run the application.
""" """
from datetime import datetime, timezone
import os import os
import logging import logging
from flask_jwt_extended import verify_jwt_in_request from flask_jwt_extended import (
create_access_token,
get_jwt,
get_jwt_identity,
set_access_cookies,
verify_jwt_in_request,
)
import psutil import psutil
import mimetypes import mimetypes
from flask import Response, request from flask import Response, request
@@ -70,6 +77,25 @@ def verify_auth():
verify_jwt_in_request() verify_jwt_in_request()
@app.after_request
def refresh_expiring_jwt(response: Response):
"""
Refreshes the JWT token after each request.
"""
try:
exp_timestamp = get_jwt()["exp"]
now = datetime.now(timezone.utc)
target_timestamp = datetime.timestamp(now) + 60 * 60 * 24 * 7 # 7 days
if target_timestamp > exp_timestamp:
access_token = create_access_token(identity=get_jwt_identity())
set_access_cookies(response, access_token)
return response
except (RuntimeError, KeyError):
return response
@app.route("/<path:path>") @app.route("/<path:path>")
def serve_client_files(path: str): def serve_client_files(path: str):
""" """