fix: chunked audio stream

desc: faulty content range headers

+ fix: tracks not being removed from db on root dirs change
+ implement implicit jwt refreshing
+ remove enableGuest from configs
+ set jwt validity to 30 days
This commit is contained in:
mungai-njoroge
2024-05-05 23:55:25 +03:00
parent fdf3186be6
commit 36600ab782
9 changed files with 78 additions and 49 deletions
+1 -2
View File
@@ -63,12 +63,11 @@ def create_api():
)
app = OpenAPI(__name__, info=api_info, doc_prefix="/docs")
print("userid", UserConfig().userId)
# JWT CONFIGS
app.config["JWT_SECRET_KEY"] = UserConfig().userId
app.config["JWT_TOKEN_LOCATION"] = ["cookies"]
app.config["JWT_COOKIE_CSRF_PROTECT"] = False
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = datetime.timedelta(days=1)
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = datetime.timedelta(days=30)
# CORS
CORS(app, origins="*", supports_credentials=True)
+5 -9
View File
@@ -191,11 +191,6 @@ def delete_user(body: DeleteUseBody):
return {"msg": "Cannot delete the only admin"}, 400
authdb.delete_user_by_username(body.username)
# if user is guest, update config
if body.username == "guest":
UserConfig().enableGuest = False
return {"msg": f"User {body.username} deleted"}
@@ -225,7 +220,7 @@ def get_all_users(query: GetAllUsersQuery):
# config.enableGuest = True
# config.usersOnLogin = True
settings = {
"enableGuest": config.enableGuest,
"enableGuest": False,
"usersOnLogin": config.usersOnLogin,
}
@@ -234,7 +229,10 @@ def get_all_users(query: GetAllUsersQuery):
"users": [],
}
users = authdb.get_all_users()
is_admin = current_user and "admin" in current_user["roles"]
settings['enableGuest'] = [user for user in users if user.username == "guest"].__len__() > 0
# if user is admin, also return settings
if is_admin:
@@ -254,13 +252,12 @@ def get_all_users(query: GetAllUsersQuery):
):
return res
users = authdb.get_all_users()
# remove guest user
# if not settings["enableGuest"]:
# users = [user for user in users if user.username != "guest"]
if not is_admin or not settings["usersOnLogin"]:
if not settings["usersOnLogin"]:
users = [user for user in users if user.username == "guest"]
# reverse list to show latest users first
@@ -268,7 +265,6 @@ def get_all_users(query: GetAllUsersQuery):
# bring admins to the front
users = sorted(users, key=lambda x: "admin" in x.roles, reverse=True)
# bring current user to index 0
if current_user:
users = sorted(
+26 -9
View File
@@ -68,15 +68,31 @@ def send_track_file(path: TrackHashSchema, query: SendTrackFileQuery):
def send_file_as_chunks(filepath: str, audio_type: str) -> Response:
"""
Returns a Response object that streams the file in chunks.
"""
# NOTE: +1 makes sure the last byte is included in the range.
# NOTE: -1 is used to convert the end index to a 0-based index.
chunk_size = 1024 * 512
# Get file size
file_size = os.path.getsize(filepath)
start = 0
end = file_size - 1
end = chunk_size
# Read range header
range_header = request.headers.get("Range")
if range_header:
start, end = parse_range_header(range_header, file_size)
start = get_start_range(range_header)
chunk_size = 1024 * 1024 # 1MB chunk size (adjust as needed)
# If start + chunk_size is greater than file_size,
# set end to file_size - 1
_end = start + chunk_size - 1
if _end > file_size:
end = file_size - 1
else:
end = _end
def generate_chunks():
with open(filepath, "rb") as file:
@@ -84,8 +100,11 @@ def send_file_as_chunks(filepath: str, audio_type: str) -> Response:
remaining_bytes = end - start + 1
while remaining_bytes > 0:
# Read the chunk size or all the remaining bytes
chunk = file.read(min(chunk_size, remaining_bytes))
yield chunk
# Update the remaining bytes
remaining_bytes -= len(chunk)
response = Response(
@@ -102,15 +121,13 @@ def send_file_as_chunks(filepath: str, audio_type: str) -> Response:
return response
def parse_range_header(range_header: str, file_size: int) -> tuple[int, int]:
def get_start_range(range_header: str):
try:
range_start, range_end = range_header.strip().split("=")[1].split("-")
start = int(range_start)
end = min(int(range_end), file_size - 1)
except ValueError:
return 0, file_size - 1
return int(range_start)
return start, end
except ValueError:
return 0
class GetAudioSilenceBody(BaseModel):
+9 -8
View File
@@ -7,6 +7,7 @@ from app.api.auth import admin_required
from app.db.sqlite.plugins import PluginsMethods as pdb
from app.db.sqlite.settings import SettingsSQLMethods as sdb
from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb
from app.lib import populate
from app.lib.watchdogg import Watcher as WatchDog
from app.logger import log
@@ -51,12 +52,12 @@ def reload_everything(instance_key: str):
@background
def rebuild_store(db_dirs: list[str]):
"""
Restarts the watchdog and rebuilds the music library.
Restarts watchdog and rebuilds the music library.
"""
instance_key = get_random_str()
log.info("Rebuilding library...")
TrackStore.remove_tracks_by_dir_except(db_dirs)
trackdb.remove_tracks_not_in_folders(db_dirs)
reload_everything(instance_key)
try:
@@ -106,10 +107,10 @@ def add_root_dirs(body: AddRootDirsBody):
removed_dirs = body.removed
db_dirs = sdb.get_root_dirs()
_h = "$home"
home = "$home"
db_home = any([d == _h for d in db_dirs]) # if $home is in db
incoming_home = any([d == _h for d in new_dirs]) # if $home is in incoming
db_home = any([d == home for d in db_dirs]) # if $home is in db
incoming_home = any([d == home for d in new_dirs]) # if $home is in incoming
# handle $home case
if db_home and incoming_home:
@@ -119,8 +120,8 @@ def add_root_dirs(body: AddRootDirsBody):
sdb.remove_root_dirs(db_dirs)
if incoming_home:
finalize([_h], [], [Paths.USER_HOME_DIR])
return {"root_dirs": [_h]}
finalize([home], [], [Paths.USER_HOME_DIR])
return {"root_dirs": [home]}
# ---
@@ -135,7 +136,7 @@ def add_root_dirs(body: AddRootDirsBody):
pass
db_dirs.extend(new_dirs)
db_dirs = [dir_ for dir_ in db_dirs if dir_ != _h]
db_dirs = [dir_ for dir_ in db_dirs if dir_ != home]
finalize(new_dirs, removed_dirs, db_dirs)
-1
View File
@@ -16,7 +16,6 @@ class UserConfig:
# NOTE: Don't expose the userId via the API
userId: str = ""
usersOnLogin: bool = True
enableGuest: bool = False
# lists
rootDirs: list[str] = field(default_factory=list)
+5 -4
View File
@@ -112,9 +112,10 @@ class SQLiteTrackMethods:
cur.execute("DELETE FROM tracks WHERE filepath=?", (filepath,))
@staticmethod
def remove_tracks_by_folders(folders: set[str]):
sql = "DELETE FROM tracks WHERE folder = ?"
def remove_tracks_not_in_folders(folders: set[str]):
sql = "DELETE FROM tracks WHERE folder NOT IN ({})".format(
",".join("?" * len(folders))
)
with SQLiteManager() as cur:
for folder in folders:
cur.execute(sql, (folder,))
cur.execute(sql, tuple(folders))
+3 -2
View File
@@ -1,12 +1,13 @@
"""
This library contains all the functions related to tracks.
"""
import os
from app.lib.pydub.pydub import AudioSegment
from app.lib.pydub.pydub.silence import detect_leading_silence, detect_silence
from app.db.sqlite.tracks import SQLiteTrackMethods as tdb
from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb
from app.store.tracks import TrackStore
from app.utils.progressbar import tqdm
from app.utils.threading import ThreadWithReturnValue
@@ -19,7 +20,7 @@ def validate_tracks() -> None:
for track in tqdm(TrackStore.tracks, desc="Validating tracks"):
if not os.path.exists(track.filepath):
TrackStore.remove_track_obj(track)
tdb.remove_tracks_by_filepaths(track.filepath)
trackdb.remove_tracks_by_filepaths(track.filepath)
def get_leading_silence_end(filepath: str):
+2 -13
View File
@@ -1,7 +1,7 @@
# from tqdm import tqdm
from app.db.sqlite.favorite import SQLiteFavoriteMethods as favdb
from app.db.sqlite.tracks import SQLiteTrackMethods as tdb
from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb
from app.models import Track
from app.utils.bisection import use_bisection
from app.utils.customlist import CustomList
@@ -23,7 +23,7 @@ class TrackStore:
global TRACKS_LOAD_KEY
TRACKS_LOAD_KEY = instance_key
cls.tracks = CustomList(tdb.get_all_tracks())
cls.tracks = CustomList(trackdb.get_all_tracks())
fav_hashes = favdb.get_fav_tracks()
fav_hashes = " ".join([t[1] for t in fav_hashes])
@@ -84,17 +84,6 @@ class TrackStore:
if track.filepath in filepaths:
cls.remove_track_obj(track)
@classmethod
def remove_tracks_by_dir_except(cls, dirs: list[str]):
"""Removes all tracks not in the root directories."""
to_remove = set()
for track in cls.tracks:
if not track.folder.startswith(tuple(dirs)):
to_remove.add(track.folder)
tdb.remove_tracks_by_folders(to_remove)
@classmethod
def count_tracks_by_trackhash(cls, trackhash: str) -> int:
"""
+27 -1
View File
@@ -2,9 +2,16 @@
This file is used to run the application.
"""
from datetime import datetime, timezone
import os
import logging
from flask_jwt_extended import verify_jwt_in_request
from flask_jwt_extended import (
create_access_token,
get_jwt,
get_jwt_identity,
set_access_cookies,
verify_jwt_in_request,
)
import psutil
import mimetypes
from flask import Response, request
@@ -70,6 +77,25 @@ def verify_auth():
verify_jwt_in_request()
@app.after_request
def refresh_expiring_jwt(response: Response):
"""
Refreshes the JWT token after each request.
"""
try:
exp_timestamp = get_jwt()["exp"]
now = datetime.now(timezone.utc)
target_timestamp = datetime.timestamp(now) + 60 * 60 * 24 * 7 # 7 days
if target_timestamp > exp_timestamp:
access_token = create_access_token(identity=get_jwt_identity())
set_access_cookies(response, access_token)
return response
except (RuntimeError, KeyError):
return response
@app.route("/<path:path>")
def serve_client_files(path: str):
"""