From 04957dd5a9c91956e930b0a4138ca862596ded87 Mon Sep 17 00:00:00 2001 From: mungai-njoroge Date: Thu, 25 Apr 2024 18:18:52 +0300 Subject: [PATCH 01/12] set up auth --- app/api/__init__.py | 32 ++++++++++++-- app/api/auth.py | 67 +++++++++++++++++++++++++++++ app/api/settings.py | 1 + app/db/sqlite/auth.py | 93 ++++++++++++++++++++++++++++++++++++++++ app/db/sqlite/queries.py | 11 +++++ app/models/user.py | 32 ++++++++++++++ app/settings.py | 1 + app/setup/sqlite.py | 4 ++ app/start_info_logger.py | 2 +- app/utils/auth.py | 25 +++++++++++ app/utils/paths.py | 19 ++++++++ manage.py | 28 ++++++++++++ poetry.lock | 38 +++++++++++++++- pyproject.toml | 1 + requirements.txt | 2 + 15 files changed, 350 insertions(+), 6 deletions(-) create mode 100644 app/api/auth.py create mode 100644 app/db/sqlite/auth.py create mode 100644 app/models/user.py create mode 100644 app/utils/auth.py diff --git a/app/api/__init__.py b/app/api/__init__.py index a74624bc..3f6aa250 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -8,6 +8,7 @@ from flask_compress import Compress from flask_openapi3 import Info from flask_openapi3 import OpenAPI +from flask_jwt_extended import JWTManager from app.settings import Keys from .plugins import lyrics as lyrics_plugin @@ -27,6 +28,7 @@ from app.api import ( logger, home, getall, + auth, ) # TODO: Move this description to a separate file @@ -60,13 +62,34 @@ def create_api(): app = OpenAPI(__name__, info=api_info, doc_prefix="/docs") - CORS(app, origins="*") - Compress(app) + # JWT CONFIGS + app.config["JWT_SECRET_KEY"] = Keys.JWT_SECRET_KEY + app.config["JWT_TOKEN_LOCATION"] = ["cookies"] + app.config["JWT_COOKIE_CSRF_PROTECT"] = False + app.config["JWT_ACCESS_TOKEN_EXPIRES"] = datetime.timedelta(days=1) + # CORS + CORS(app, origins="*", supports_credentials=True) + + # RESPONSE COMPRESSION + Compress(app) app.config["COMPRESS_MIMETYPES"] = [ "application/json", ] + # JWT + jwt = JWTManager(app) + + @jwt.user_identity_loader + def user_identity_lookup(user): + return user + + @jwt.user_lookup_loader + def user_lookup_callback(_jwt_header, jwt_data): + identity = jwt_data["sub"] + return identity + + # Register all the API blueprints with app.app_context(): app.register_api(album.api) app.register_api(artist.api) @@ -89,8 +112,9 @@ def create_api(): # Home app.register_api(home.api) - - # Flask Restful app.register_api(getall.api) + # Auth + app.register_api(auth.api) + return app diff --git a/app/api/auth.py b/app/api/auth.py new file mode 100644 index 00000000..b903f1e6 --- /dev/null +++ b/app/api/auth.py @@ -0,0 +1,67 @@ +from dataclasses import asdict +from flask import jsonify +from flask_jwt_extended import create_access_token, current_user, set_access_cookies +from pydantic import BaseModel, Field +from flask_openapi3 import Tag +from flask_openapi3 import APIBlueprint + +from app.db.sqlite.auth import SQLiteAuthMethods as authdb +from app.utils.auth import check_password + +bp_tag = Tag(name="Auth", description="Authentication") +api = APIBlueprint("auth", __name__, url_prefix="/auth", abp_tags=[bp_tag]) + + +class LoginBody(BaseModel): + username: str = Field(description="The username", example="user0") + password: str = Field(description="The password", example="password0") + + +@api.post("/login") +def login(body: LoginBody): + """ + Authenticate using username and password + """ + res = jsonify({"msg": f"Logged in as {body.username}"}) + + user = authdb.get_user_by_username(body.username) + + if user is None: + return {"msg": "User not found"}, 404 + + password_ok = check_password(body.password, user.password) + + if not password_ok: + return {"msg": "Invalid password"}, 401 + + access_token = create_access_token(identity=user.todict()) + set_access_cookies(res, access_token) + return res + + +@api.get("/logout") +def logout(): + """ + Log out + """ + res = jsonify({"msg": "Logged out"}) + res.delete_cookie("access_token_cookie") + return res + + +@api.get("/users") +def get_all_users(): + """ + Get all users + """ + users = authdb.get_all_users() + return [user.todict_simplified() for user in users] + + +@api.route("/user") +def get_logged_in_user(): + """ + Get logged in user + """ + print("current_user", current_user) + return dict(current_user) diff --git a/app/api/settings.py b/app/api/settings.py index e5d3306b..900f031d 100644 --- a/app/api/settings.py +++ b/app/api/settings.py @@ -162,6 +162,7 @@ mapp = { @api.get("") + def get_all_settings(): """ Get all settings diff --git a/app/db/sqlite/auth.py b/app/db/sqlite/auth.py new file mode 100644 index 00000000..c637de54 --- /dev/null +++ b/app/db/sqlite/auth.py @@ -0,0 +1,93 @@ +import json +from app.models.user import User +from app.utils.auth import encode_password +from app.db.sqlite.utils import SQLiteManager + + +class SQLiteAuthMethods: + """ + Methods for authenticating users. + """ + + @staticmethod + def insert_default_user(): + """ + Inserts the default admin user. + """ + user = { + "username": "admin", + "password": encode_password("admin"), + "roles": json.dumps(["admin"]), + } + user_tuple = tuple(user.values()) + + sql = """INSERT INTO users( + username, + password, + roles + ) VALUES(:username, :password, :roles) + """ + + with SQLiteManager(userdata_db=True) as cur: + cur.execute(sql, user_tuple) + cur.close() + + @staticmethod + def insert_guest_user(): + """ + Inserts the default guest user. + """ + user = { + "username": "guest", + "password": encode_password("guest"), + "firstname": "Guest", + "lastname": "User", + "roles": json.dumps(["guest"]), + } + user_tuple = tuple(user.values()) + + sql = """INSERT INTO users( + username, + password, + firstname, + lastname, + roles + ) VALUES(:username, :password, :firstname, :lastname, :roles) + """ + + with SQLiteManager(userdata_db=True) as cur: + cur.execute(sql, user_tuple) + cur.close() + + @staticmethod + def get_all_users(): + """ + Check if there are any users in the database. + """ + sql = "SELECT * FROM users" + + with SQLiteManager(userdata_db=True) as cur: + cur.execute(sql) + + data = cur.fetchall() + cur.close() + + return [User(*user) for user in data] + + @staticmethod + def get_user_by_username(username: str): + """ + Get a user by username. + """ + sql = "SELECT * FROM users WHERE username = ?" + + with SQLiteManager(userdata_db=True) as cur: + cur.execute(sql, (username,)) + + data = cur.fetchone() + cur.close() + + if data is not None: + return User(*data) + + return None \ No newline at end of file diff --git a/app/db/sqlite/queries.py b/app/db/sqlite/queries.py index 96b8b5b1..cf084d6a 100644 --- a/app/db/sqlite/queries.py +++ b/app/db/sqlite/queries.py @@ -54,6 +54,17 @@ CREATE TABLE IF NOT EXISTS track_logger ( timestamp integer NOT NULL, source text, userid integer NOT NULL DEFAULT 0 +); + +CREATE TABLE IF NOT EXISTS users ( + id integer PRIMARY KEY, + username text NOT NULL UNIQUE, + firstname text, + lastname text, + password text NOT NULL, + email text, + image text, + roles text NOT NULL DEFAULT '["user"]' ) """ diff --git a/app/models/user.py b/app/models/user.py new file mode 100644 index 00000000..b8d05d01 --- /dev/null +++ b/app/models/user.py @@ -0,0 +1,32 @@ +from dataclasses import asdict, field, dataclass +import json + + +@dataclass(slots=True) +class User: + id: int + username: str + firstname: str + lastname: str + password: str + email: str + image: str + + # NOTE: roles: ['admin', 'user', 'curator'] + roles: list[str] = field(default_factory=lambda: ["user"]) + + def __post_init__(self): + self.roles = json.loads(self.roles) + + def todict(self): + this_dict = asdict(self) + del this_dict["password"] + + return this_dict + + def todict_simplified(self): + return { + "id": self.id, + "username": self.username, + "firstname": self.firstname, + } diff --git a/app/settings.py b/app/settings.py index 1dcd7c0f..7df25e5f 100644 --- a/app/settings.py +++ b/app/settings.py @@ -268,6 +268,7 @@ class Keys: SWINGMUSIC_APP_VERSION = os.environ.get("SWINGMUSIC_APP_VERSION") GIT_LATEST_COMMIT_HASH = "" GIT_CURRENT_BRANCH = "" + JWT_SECRET_KEY = "swingmusic_secret_key" # REVIEW: This should be set in the environment @classmethod def load(cls): diff --git a/app/setup/sqlite.py b/app/setup/sqlite.py index 0506fb82..667d2e4b 100644 --- a/app/setup/sqlite.py +++ b/app/setup/sqlite.py @@ -4,6 +4,7 @@ Applies migrations. """ from app.db.sqlite import create_connection, create_tables, queries +from app.db.sqlite.auth import SQLiteAuthMethods as authdb from app.migrations import apply_migrations from app.settings import Db @@ -29,5 +30,8 @@ def setup_sqlite(): create_tables(user_db_conn, queries.CREATE_USERDATA_TABLES) create_tables(app_db_conn, queries.CREATE_MIGRATIONS_TABLE) + if not authdb.get_all_users(): + authdb.insert_default_user() + app_db_conn.close() user_db_conn.close() diff --git a/app/start_info_logger.py b/app/start_info_logger.py index 94cba209..45674b94 100644 --- a/app/start_info_logger.py +++ b/app/start_info_logger.py @@ -7,7 +7,7 @@ from app.utils.network import get_ip def log_startup_info(): lines = "------------------------------" # clears terminal 👇 - os.system("cls" if os.name == "nt" else "echo -e \\\\033c") + # os.system("cls" if os.name == "nt" else "echo -e \\\\033c") print(lines) print(f"{TCOLOR.HEADER}SwingMusic {Keys.SWINGMUSIC_APP_VERSION} {TCOLOR.ENDC}") diff --git a/app/utils/auth.py b/app/utils/auth.py new file mode 100644 index 00000000..fac68e22 --- /dev/null +++ b/app/utils/auth.py @@ -0,0 +1,25 @@ +import hashlib + + +def encode_password(password: str) -> str: + """ + This function encodes the given password. + + :param password: The password to encode. + + :return: The encoded password. + """ + + return hashlib.sha256(password.encode("utf-8")).hexdigest() + +def check_password(password: str, encoded: str) -> bool: + """ + This function checks if the given password matches the encoded password. + + :param password: The password to check. + :param encoded: The encoded password. + + :return: Whether the password matches. + """ + + return encode_password(password) == encoded \ No newline at end of file diff --git a/app/utils/paths.py b/app/utils/paths.py index fe6bcfb4..d817010d 100644 --- a/app/utils/paths.py +++ b/app/utils/paths.py @@ -1,5 +1,8 @@ +import os import sys +from app.utils.filesystem import get_home_res_path + def getFlaskOpenApiPath(): """ @@ -10,3 +13,19 @@ def getFlaskOpenApiPath(): site_packages_path = [p for p in sys.path if "site-packages" in p][0] return f"{site_packages_path}/flask_openapi3" + + +def getClientFilesExtensions(): + """ + Get all the file extensions for the client files + """ + + client_path = get_home_res_path("client") + + extensions = set() + for root, dirs, files in os.walk(client_path): + for file in files: + ext = file.split(".")[-1] + extensions.add("." + ext) + + return extensions diff --git a/manage.py b/manage.py index 56698253..dc852b7c 100644 --- a/manage.py +++ b/manage.py @@ -4,6 +4,7 @@ This file is used to run the application. import os import logging +from flask_jwt_extended import verify_jwt_in_request import psutil import mimetypes from flask import Response, request @@ -20,6 +21,7 @@ from app.settings import FLASKVARS, TCOLOR, Keys from app.setup import run_setup from app.start_info_logger import log_startup_info from app.utils.filesystem import get_home_res_path +from app.utils.paths import getClientFilesExtensions from app.utils.threading import background mimetypes.add_type("text/css", ".css") @@ -41,6 +43,32 @@ werkzeug.setLevel(logging.ERROR) app = create_api() app.static_folder = get_home_res_path("client") +# INFO: Routes that don't need authentication +blacklist_routes = {"/auth/login", "/auth/users"} +blacklist_extensions = {".webp"}.union(getClientFilesExtensions()) + + +@app.before_request +def verify_auth(): + """ + Verifies the JWT token before each request. + """ + print(request.path) + if request.path == "/" or any( + request.path.endswith(ext) for ext in blacklist_extensions + ): + return + + # if request path starts with any of the blacklisted routes, don't verify jwt + if any(request.path.startswith(route) for route in blacklist_routes): + print( + "Found blacklisted route: ", request.path, "... Skipping jwt verification" + ) + return + + data = verify_jwt_in_request() + print(data) + @app.route("/") def serve_client_files(path: str): diff --git a/poetry.lock b/poetry.lock index 0b9fa9b7..481d868a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -557,6 +557,25 @@ files = [ Flask = ">=0.9" Six = "*" +[[package]] +name = "flask-jwt-extended" +version = "4.6.0" +description = "Extended JWT integration with Flask" +optional = false +python-versions = ">=3.7,<4" +files = [ + {file = "Flask-JWT-Extended-4.6.0.tar.gz", hash = "sha256:9215d05a9413d3855764bcd67035e75819d23af2fafb6b55197eb5a3313fdfb2"}, + {file = "Flask_JWT_Extended-4.6.0-py2.py3-none-any.whl", hash = "sha256:63a28fc9731bcc6c4b8815b6f954b5904caa534fc2ae9b93b1d3ef12930dca95"}, +] + +[package.dependencies] +Flask = ">=2.0,<4.0" +PyJWT = ">=2.0,<3.0" +Werkzeug = ">=0.14" + +[package.extras] +asymmetric-crypto = ["cryptography (>=3.3.1)"] + [[package]] name = "flask-openapi3" version = "3.0.2" @@ -1599,6 +1618,23 @@ files = [ {file = "pyinstaller_hooks_contrib-2023.9-py2.py3-none-any.whl", hash = "sha256:f34f4c6807210025c8073ebe665f422a3aa2ac5f4c7ebf4c2a26cc77bebf63b5"}, ] +[[package]] +name = "pyjwt" +version = "2.8.0" +description = "JSON Web Token implementation in Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"}, + {file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"}, +] + +[package.extras] +crypto = ["cryptography (>=3.4.0)"] +dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] +docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] +tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] + [[package]] name = "pylint" version = "2.17.7" @@ -2468,4 +2504,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "feb13f92b7b3a909fcb851860a405b96579feac0e2dde7681ed0e9c381c4f6cd" +content-hash = "317c4094a6f768467219db94be02a6e2e9f1ed5ca18d929437094bf60be1594d" diff --git a/pyproject.toml b/pyproject.toml index 005eaed5..5ff4920b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ waitress = "^2.1.2" watchdog = "^4.0.0" pendulum = "^3.0.0" flask-openapi3 = "^3.0.2" +flask-jwt-extended = "^4.6.0" [tool.poetry.dev-dependencies] pylint = "^2.15.5" diff --git a/requirements.txt b/requirements.txt index 51b16ecf..1089e561 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,7 @@ Flask==2.3.3 Flask-BasicAuth==0.2.0 Flask-Compress==1.14 Flask-Cors==3.0.10 +Flask-JWT-Extended==4.6.0 flask-openapi3==3.0.2 gevent==23.9.1 geventhttpclient==2.0.11 @@ -42,6 +43,7 @@ pydantic==2.6.3 pydantic_core==2.16.3 pyinstaller==5.13.2 pyinstaller-hooks-contrib==2023.9 +PyJWT==2.8.0 pylint==2.17.7 pytest==7.4.2 python-dateutil==2.8.2 From 1eeab2d49e5e677e27ef12ea66bd33deba75b2c2 Mon Sep 17 00:00:00 2001 From: mungai-njoroge Date: Thu, 25 Apr 2024 20:05:02 +0300 Subject: [PATCH 02/12] add update profile logic --- app/api/__init__.py | 11 ++++-- app/api/auth.py | 34 +++++++++++++++++ app/db/sqlite/auth.py | 88 +++++++++++++++++++++++++++++++------------ manage.py | 4 +- 4 files changed, 106 insertions(+), 31 deletions(-) diff --git a/app/api/__init__.py b/app/api/__init__.py index 3f6aa250..5c3e7f64 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -12,6 +12,7 @@ from flask_jwt_extended import JWTManager from app.settings import Keys from .plugins import lyrics as lyrics_plugin +from app.db.sqlite.auth import SQLiteAuthMethods as authdb from app.api import ( album, artist, @@ -80,14 +81,16 @@ def create_api(): # JWT jwt = JWTManager(app) - @jwt.user_identity_loader - def user_identity_lookup(user): - return user + # @jwt.user_identity_loader + # def user_identity_lookup(user): + # return user @jwt.user_lookup_loader def user_lookup_callback(_jwt_header, jwt_data): identity = jwt_data["sub"] - return identity + userid = identity["id"] + user = authdb.get_user_by_id(userid) + return user.todict() # Register all the API blueprints with app.app_context(): diff --git a/app/api/auth.py b/app/api/auth.py index b903f1e6..a22884fc 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -1,4 +1,5 @@ from dataclasses import asdict +import json from flask import jsonify from flask_jwt_extended import create_access_token, current_user, set_access_cookies from pydantic import BaseModel, Field @@ -39,6 +40,39 @@ def login(body: LoginBody): return res +class UpdateProfileBody(BaseModel): + email: str = Field("", description="The email") + username: str = Field("", description="The username", example="user0") + password: str = Field("", description="The password", example="password0") + roles: list[str] = Field([], description="The roles") + + +@api.put("/profile/update") +def update_profile(body: UpdateProfileBody): + + user = { + "id": current_user["id"], + "email": body.email, + "username": body.username, + "password": body.password, + "roles": body.roles, + } + + # only admins can update roles + if body.roles: + if "admin" in current_user["roles"]: + # prevent admin from locking themselves out + roles = set(body.roles) + roles.add("admin") + user["roles"] = json.dumps(list(roles)) + else: + user.pop("roles") + + # remove empty values + clean_user = {k: v for k, v in user.items() if v} + return authdb.update_user(clean_user) + + @api.get("/logout") def logout(): """ diff --git a/app/db/sqlite/auth.py b/app/db/sqlite/auth.py index c637de54..83821500 100644 --- a/app/db/sqlite/auth.py +++ b/app/db/sqlite/auth.py @@ -9,6 +9,33 @@ class SQLiteAuthMethods: Methods for authenticating users. """ + @staticmethod + def insert_user(user: dict[str, str]): + """ + Insert a user into the database. + + :param user: A dict with the username, password and roles. + """ + sql = """INSERT INTO users( + username, + password, + roles + ) VALUES(:username, :password, :roles) + """ + + user_tuple = tuple(user.values()) + + with SQLiteManager(userdata_db=True) as cur: + cur = cur.execute(sql, user_tuple) + userid = cur.lastrowid + + if userid: + user = SQLiteAuthMethods.get_user_by_id(userid).todict_simplified() + cur.close() + return user + + raise Exception(f"Failed to insert user: {user}") + @staticmethod def insert_default_user(): """ @@ -19,18 +46,7 @@ class SQLiteAuthMethods: "password": encode_password("admin"), "roles": json.dumps(["admin"]), } - user_tuple = tuple(user.values()) - - sql = """INSERT INTO users( - username, - password, - roles - ) VALUES(:username, :password, :roles) - """ - - with SQLiteManager(userdata_db=True) as cur: - cur.execute(sql, user_tuple) - cur.close() + return SQLiteAuthMethods.insert_user(user) @staticmethod def insert_guest_user(): @@ -40,25 +56,31 @@ class SQLiteAuthMethods: user = { "username": "guest", "password": encode_password("guest"), - "firstname": "Guest", - "lastname": "User", "roles": json.dumps(["guest"]), } - user_tuple = tuple(user.values()) - sql = """INSERT INTO users( - username, - password, - firstname, - lastname, - roles - ) VALUES(:username, :password, :firstname, :lastname, :roles) + return SQLiteAuthMethods.insert_user(user) + + @staticmethod + def update_user(user: dict[str, str]): + """ + Update a user in the database. + + :param user: A dict with the username, password and roles. + """ + # get all user dict keys + keys = list(user.keys()) + sql = f"""UPDATE users SET + {', '.join([f"{key} = :{key}" for key in keys if key != 'id'])} + WHERE id = :id """ with SQLiteManager(userdata_db=True) as cur: - cur.execute(sql, user_tuple) + cur.execute(sql, user) cur.close() + return SQLiteAuthMethods.get_user_by_id(user["id"]).todict() + @staticmethod def get_all_users(): """ @@ -90,4 +112,22 @@ class SQLiteAuthMethods: if data is not None: return User(*data) - return None \ No newline at end of file + return None + + @staticmethod + def get_user_by_id(userid: int): + """ + Get a user by id. + """ + sql = "SELECT * FROM users WHERE id = ?" + + with SQLiteManager(userdata_db=True) as cur: + cur.execute(sql, (userid,)) + + data = cur.fetchone() + cur.close() + + if data is not None: + return User(*data) + + return None diff --git a/manage.py b/manage.py index dc852b7c..de855e35 100644 --- a/manage.py +++ b/manage.py @@ -53,7 +53,6 @@ def verify_auth(): """ Verifies the JWT token before each request. """ - print(request.path) if request.path == "/" or any( request.path.endswith(ext) for ext in blacklist_extensions ): @@ -66,8 +65,7 @@ def verify_auth(): ) return - data = verify_jwt_in_request() - print(data) + verify_jwt_in_request() @app.route("/") From 0ff566176596b82657936f6f031b105efba907b2 Mon Sep 17 00:00:00 2001 From: mungai-njoroge Date: Sat, 27 Apr 2024 10:05:15 +0300 Subject: [PATCH 03/12] add routes to create user + route to delete user + add admin_required decorator --- .github/changelog.md | 7 ++- app/api/auth.py | 122 +++++++++++++++++++++++++++++++++++++++--- app/db/sqlite/auth.py | 22 ++++++-- manage.py | 6 +-- 4 files changed, 138 insertions(+), 19 deletions(-) diff --git a/.github/changelog.md b/.github/changelog.md index 1d00e7f1..4c1511df 100644 --- a/.github/changelog.md +++ b/.github/changelog.md @@ -1,7 +1,6 @@ # What's New? -- Hovering on recent favorite item will show how long ago it was ♥ed -- Recently added playlist returns a max of 100 tracks, but without a cutoff period + +- Auth -# Development -- API documentation on /openapi \ No newline at end of file +## Development diff --git a/app/api/auth.py b/app/api/auth.py index a22884fc..07a6da39 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -1,5 +1,6 @@ -from dataclasses import asdict import json +from dataclasses import asdict +from functools import wraps from flask import jsonify from flask_jwt_extended import create_access_token, current_user, set_access_cookies from pydantic import BaseModel, Field @@ -7,12 +8,29 @@ from flask_openapi3 import Tag from flask_openapi3 import APIBlueprint from app.db.sqlite.auth import SQLiteAuthMethods as authdb -from app.utils.auth import check_password +from app.utils.auth import check_password, encode_password -bp_tag = Tag(name="Auth", description="Authentication") +bp_tag = Tag(name="Auth", description="Authentication stuff") api = APIBlueprint("auth", __name__, url_prefix="/auth", abp_tags=[bp_tag]) +def admin_required(): + """ + Decorator to require admin role + """ + + def wrapper(fn): + @wraps(fn) + def decorator(*args, **kwargs): + if "admin" not in current_user["roles"]: + return {"msg": "Only admins can do that!"}, 403 + return fn(*args, **kwargs) + + return decorator + + return wrapper + + class LoginBody(BaseModel): username: str = Field(description="The username", example="user0") password: str = Field(description="The password", example="password0") @@ -49,7 +67,6 @@ class UpdateProfileBody(BaseModel): @api.put("/profile/update") def update_profile(body: UpdateProfileBody): - user = { "id": current_user["id"], "email": body.email, @@ -68,11 +85,84 @@ def update_profile(body: UpdateProfileBody): else: user.pop("roles") + if user["password"]: + user["password"] = encode_password(user["password"]) + # remove empty values clean_user = {k: v for k, v in user.items() if v} return authdb.update_user(clean_user) +@api.post("/profile/create") +@admin_required() +def create_user(body: UpdateProfileBody): + if not body.username or not body.password: + return {"msg": "Username and password are required"}, 400 + + user = { + "username": body.username, + "password": encode_password(body.password), + "roles": json.dumps([]), + } + + # check if user already exists + if authdb.get_user_by_username(user["username"]): + return {"msg": "Username already exists"}, 400 + + userid = authdb.insert_user(user) + return authdb.get_user_by_id(userid).todict() + + +@api.post("/profile/guest/create") +@admin_required() +def create_guest_user(): + """ + Create a guest user + """ + # check if guest user already exists + guest_user = authdb.get_user_by_username("guest") + + if guest_user: + return { + "msg": "Guest user already exists", + }, 400 + + userid = authdb.insert_guest_user() + + if userid: + return { + "msg": "Guest user created", + } + + return { + "msg": "Failed to create guest user", + }, 500 + + +class DeleteUseBody(BaseModel): + username: str = Field("", description="The username") + + +@api.delete("/profile/delete") +@admin_required() +def delete_user(body: DeleteUseBody): + """ + Delete a user by username + """ + # prevent admin from deleting themselves + if body.username == current_user["usrname"]: + return {"msg": "Sorry! you cannot delete yourselfu"}, 400 + + # prevent deleting the only admin + users = authdb.get_all_users() + admins = [user for user in users if "admin" in user.roles] + if len(admins) == 1 and admins[0].username == body.username: + return {"msg": "Cannot delete the only admin"}, 400 + + authdb.delete_user_by_username(body.username) + return {"msg": f"User {body.username} deleted"} + + @api.get("/logout") def logout(): """ @@ -83,13 +173,32 @@ def logout(): return res +class GetAllUsersQuery(BaseModel): + simplified: bool = Field( + False, description="Whether to return simplified user data" + ) + + @api.get("/users") -def get_all_users(): +def get_all_users(query: GetAllUsersQuery): """ Get all users """ users = authdb.get_all_users() - return [user.todict_simplified() for user in users] + + # remove guest user + users = [user for user in users if user.username != "guest"] + + # reverse list to show latest users first + users = list(reversed(users)) + + # bring admins to the front + users = sorted(users, key=lambda x: "admin" in x.roles, reverse=True) + + if query.simplified: + return [user.todict_simplified() for user in users] + + return [user.todict() for user in users] @api.route("/user") @@ -97,5 +206,4 @@ def get_logged_in_user(): """ Get logged in user """ - print("current_user", current_user) return dict(current_user) diff --git a/app/db/sqlite/auth.py b/app/db/sqlite/auth.py index 83821500..2c36d183 100644 --- a/app/db/sqlite/auth.py +++ b/app/db/sqlite/auth.py @@ -28,11 +28,12 @@ class SQLiteAuthMethods: with SQLiteManager(userdata_db=True) as cur: cur = cur.execute(sql, user_tuple) userid = cur.lastrowid - - if userid: - user = SQLiteAuthMethods.get_user_by_id(userid).todict_simplified() - cur.close() - return user + return userid + # if userid: + # # sleep + # user = SQLiteAuthMethods.get_user_by_id(userid).todict_simplified() + # cur.close() + # return user raise Exception(f"Failed to insert user: {user}") @@ -131,3 +132,14 @@ class SQLiteAuthMethods: return User(*data) return None + + @staticmethod + def delete_user_by_username(username: str): + """ + Delete a user by username. + """ + sql = "DELETE FROM users WHERE username = ?" + + with SQLiteManager(userdata_db=True) as cur: + cur.execute(sql, (username,)) + cur.close() diff --git a/manage.py b/manage.py index de855e35..a3e106c9 100644 --- a/manage.py +++ b/manage.py @@ -60,9 +60,9 @@ def verify_auth(): # if request path starts with any of the blacklisted routes, don't verify jwt if any(request.path.startswith(route) for route in blacklist_routes): - print( - "Found blacklisted route: ", request.path, "... Skipping jwt verification" - ) + # print( + # "Found blacklisted route: ", request.path, "... Skipping jwt verification" + # ) return verify_jwt_in_request() From cfeff7ff5161910d8b5d5fd04344ec75f234b1f8 Mon Sep 17 00:00:00 2001 From: mungai-njoroge Date: Mon, 29 Apr 2024 16:31:30 +0300 Subject: [PATCH 04/12] add json config and its manager class + rewrite logic to prevent removing last admin role + handle showing users on login and enabling guest --- app/api/auth.py | 114 +++++++++++++++++++++++++++++++++++------- app/api/settings.py | 30 ++++++++++- app/config.py | 97 +++++++++++++++++++++++++++++++++++ app/settings.py | 8 ++- app/setup/__init__.py | 6 +++ manage.py | 2 +- 6 files changed, 234 insertions(+), 23 deletions(-) create mode 100644 app/config.py diff --git a/app/api/auth.py b/app/api/auth.py index 07a6da39..fbba9a6b 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -1,14 +1,21 @@ import json from dataclasses import asdict from functools import wraps +import sqlite3 from flask import jsonify -from flask_jwt_extended import create_access_token, current_user, set_access_cookies +from flask_jwt_extended import ( + create_access_token, + current_user, + jwt_required, + set_access_cookies, +) from pydantic import BaseModel, Field from flask_openapi3 import Tag from flask_openapi3 import APIBlueprint from app.db.sqlite.auth import SQLiteAuthMethods as authdb from app.utils.auth import check_password, encode_password +from app.config import UserConfig bp_tag = Tag(name="Auth", description="Authentication stuff") api = APIBlueprint("auth", __name__, url_prefix="/auth", abp_tags=[bp_tag]) @@ -51,7 +58,7 @@ def login(body: LoginBody): password_ok = check_password(body.password, user.password) if not password_ok: - return {"msg": "Invalid password"}, 401 + return {"msg": "Hehe! invalid password"}, 401 access_token = create_access_token(identity=user.todict()) set_access_cookies(res, access_token) @@ -59,38 +66,55 @@ def login(body: LoginBody): class UpdateProfileBody(BaseModel): + id: int = Field(0, description="The user id") email: str = Field("", description="The email") username: str = Field("", description="The username", example="user0") password: str = Field("", description="The password", example="password0") - roles: list[str] = Field([], description="The roles") + roles: list[str] = Field(None, description="The roles") @api.put("/profile/update") def update_profile(body: UpdateProfileBody): user = { - "id": current_user["id"], + "id": body.id, "email": body.email, "username": body.username, "password": body.password, "roles": body.roles, } + # if not id, update self + if not user["id"]: + user["id"] = current_user["id"] + + print("current_user: ", current_user) + # only admins can update roles - if body.roles: - if "admin" in current_user["roles"]: - # prevent admin from locking themselves out - roles = set(body.roles) - roles.add("admin") - user["roles"] = json.dumps(list(roles)) - else: - user.pop("roles") + + if body.roles is not None: + if "admin" not in current_user["roles"]: + return {"msg": "Only admins can update roles"}, 403 + + if "admin" not in body.roles: + # check if we're removing the last admin + users = authdb.get_all_users() + admins = [user for user in users if "admin" in user.roles] + + if len(admins) == 1 and admins[0].id == user["id"]: + return {"msg": "Cannot remove the only admin"}, 400 + + user["roles"] = json.dumps(body.roles) if user["password"]: user["password"] = encode_password(user["password"]) # remove empty values clean_user = {k: v for k, v in user.items() if v} - return authdb.update_user(clean_user) + + try: + return authdb.update_user(clean_user) + except sqlite3.IntegrityError: + return {"msg": "Username already exists"}, 400 @api.post("/profile/create") @@ -150,7 +174,7 @@ def delete_user(body: DeleteUseBody): Delete a user by username """ # prevent admin from deleting themselves - if body.username == current_user["usrname"]: + if body.username == current_user["username"]: return {"msg": "Sorry! you cannot delete yourselfu"}, 400 # prevent deleting the only admin @@ -160,6 +184,11 @@ def delete_user(body: DeleteUseBody): return {"msg": "Cannot delete the only admin"}, 400 authdb.delete_user_by_username(body.username) + + # if user is guest, update config + if body.username == "guest": + UserConfig().enableGuest = False + return {"msg": f"User {body.username} deleted"} @@ -180,14 +209,51 @@ class GetAllUsersQuery(BaseModel): @api.get("/users") +@jwt_required(optional=True) def get_all_users(query: GetAllUsersQuery): """ - Get all users + Get all users (if you're an admin, you will also receive accounts settings) """ + config = UserConfig() + # config.enableGuest = True + # config.usersOnLogin = True + settings = { + "enableGuest": config.enableGuest, + "usersOnLogin": config.usersOnLogin, + } + + res = { + "settings": {}, + "users": [], + } + + # if user is admin, also return settings + if current_user and "admin" in current_user["roles"]: + res = { + "settings": settings, + } + + # if is normal user, return empty response + elif current_user: + return res + + # if not logged in and showing users on login is disabled, return empty response + elif ( + not current_user + and not settings["usersOnLogin"] + and not settings["enableGuest"] + ): + return res + users = authdb.get_all_users() # remove guest user - users = [user for user in users if user.username != "guest"] + print("settings: ", settings["enableGuest"]) + if not settings["enableGuest"]: + users = [user for user in users if user.username != "guest"] + + if not settings["usersOnLogin"]: + users = [user for user in users if user.username == "guest"] # reverse list to show latest users first users = list(reversed(users)) @@ -195,10 +261,20 @@ def get_all_users(query: GetAllUsersQuery): # bring admins to the front users = sorted(users, key=lambda x: "admin" in x.roles, reverse=True) - if query.simplified: - return [user.todict_simplified() for user in users] + # bring current user to index 0 + if current_user: + users = sorted( + users, + key=lambda x: x.username == current_user["username"], + reverse=True, + ) - return [user.todict() for user in users] + if query.simplified: + res["users"] = [user.todict_simplified() for user in users] + + res["users"] = [user.todict() for user in users] + + return res @api.route("/user") diff --git a/app/api/settings.py b/app/api/settings.py index 900f031d..7eca6d5c 100644 --- a/app/api/settings.py +++ b/app/api/settings.py @@ -3,6 +3,7 @@ from flask import request from flask_openapi3 import Tag from flask_openapi3 import APIBlueprint from pydantic import BaseModel, Field +from app.api.auth import admin_required from app.db.sqlite.plugins import PluginsMethods as pdb from app.db.sqlite.settings import SettingsSQLMethods as sdb @@ -15,6 +16,7 @@ from app.store.artists import ArtistStore from app.store.tracks import TrackStore from app.utils.generators import get_random_str from app.utils.threading import background +from app.config import UserConfig bp_tag = Tag(name="Settings", description="Customize stuff") api = APIBlueprint("settings", __name__, url_prefix="/notsettings", abp_tags=[bp_tag]) @@ -69,7 +71,7 @@ def rebuild_store(db_dirs: list[str]): log.info("Rebuilding library... ✅") -# I freaking don't know what this function does anymore +# I freaking don't know what this function does anymore def finalize(new_: list[str], removed_: list[str], db_dirs_: list[str]): """ Params: @@ -162,7 +164,6 @@ mapp = { @api.get("") - def get_all_settings(): """ Get all settings @@ -265,3 +266,28 @@ def trigger_scan(): run_populate() return {"msg": "Scan triggered!"} + + +class UpdateConfigBody(BaseModel): + key: str = Field( + description="The setting key", + example="usersOnLogin", + ) + value: Any = Field( + description="The setting value", + example=False, + ) + + +@api.put("/update") +@admin_required() +def update_config(body: UpdateConfigBody): + """ + Update the config file + """ + config = UserConfig() + setattr(config, body.key, body.value) + + return { + "msg": "Config updated!", + } \ No newline at end of file diff --git a/app/config.py b/app/config.py new file mode 100644 index 00000000..6f7e773c --- /dev/null +++ b/app/config.py @@ -0,0 +1,97 @@ +from dataclasses import dataclass, asdict, field +import json +import os +import time +from typing import Any +from .settings import Paths + +# TODO: Publish this on PyPi + +@dataclass +class UserConfig: + _config_path: str = "" + # NOTE: only auth stuff are used (the others are still reading/writing to db) + # TODO: Move the rest of the settings to the config file + + # auth stuff + usersOnLogin: bool = True + enableGuest: bool = False + + # lists + rootDirs: list[str] = field(default_factory=list) + excludeDirs: list[str] = field(default_factory=list) + artistSeparators: set[str] = field(default_factory=list) + + # tracks + extractFeaturedArtists: bool = True + removeProdBy: bool = True + removeRemasterInfo: bool = True + + # albums + mergeAlbums: bool = False + cleanAlbumTitle: bool = True + showAlbumsAsSingles: bool = False + + def __post_init__(self): + """ + Loads the config file and sets the values to this instance + """ + # set config path locally to avoid writing to file + config_path = Paths.get_config_file_path() + + try: + config = self.load_config(config_path) + except FileNotFoundError: + self._config_path = config_path + return + + # loop through the config file and set the values + for key, value in config.items(): + setattr(self, key, value) + + # finally set the config path + self._config_path = config_path + + def setup_config_file(self) -> None: + """ + Creates the config file with the default settings + if it doesn't exist + """ + print("config path: ", self._config_path) + + # if not exists, create the config file + if not os.path.exists(self._config_path): + self.write_to_file(asdict(self)) + + def load_config(self, path: str) -> dict[str, Any]: + """ + Reads the settings from the config file. + Returns a dictget_root_dirs + """ + with open(path, "r") as f: + settings = json.load(f) + + return settings + + def write_to_file(self, settings: dict[str, Any]): + """ + Writes the settings to the config file + """ + # remove internal attributes + settings = {k: v for k, v in settings.items() if not k.startswith("_")} + + with open(self._config_path, "w") as f: + json.dump(settings, f, indent=4) + + def __setattr__(self, key: str, value: Any) -> None: + """ + Writes to the config file whenever a value is set + """ + super().__setattr__(key, value) + + # if is internal attribute, don't write to file + if key.startswith("_") or not self._config_path: + return + + print(f"writing to file: {key}={value}") + self.write_to_file(asdict(self)) diff --git a/app/settings.py b/app/settings.py index 7df25e5f..d8a05ea4 100644 --- a/app/settings.py +++ b/app/settings.py @@ -85,6 +85,10 @@ class Paths: def get_lyrics_plugins_path(cls): return join(Paths.get_plugins_path(), "lyrics") + @classmethod + def get_config_file_path(cls): + return join(cls.get_app_dir(), "settings.json") + # defaults class Defaults: @@ -268,7 +272,9 @@ class Keys: SWINGMUSIC_APP_VERSION = os.environ.get("SWINGMUSIC_APP_VERSION") GIT_LATEST_COMMIT_HASH = "" GIT_CURRENT_BRANCH = "" - JWT_SECRET_KEY = "swingmusic_secret_key" # REVIEW: This should be set in the environment + JWT_SECRET_KEY = ( + "swingmusic_secret_key" # REVIEW: This should be set in the environment + ) @classmethod def load(cls): diff --git a/app/setup/__init__.py b/app/setup/__init__.py index 4e840c44..5b741782 100644 --- a/app/setup/__init__.py +++ b/app/setup/__init__.py @@ -2,6 +2,7 @@ Prepares the server for use. """ +from dataclasses import asdict from app.db.sqlite.settings import load_settings from app.setup.files import create_config_dir from app.setup.sqlite import run_migrations, setup_sqlite @@ -9,6 +10,7 @@ from app.store.albums import AlbumStore from app.store.artists import ArtistStore from app.store.tracks import TrackStore from app.utils.generators import get_random_str +from app.config import UserConfig def run_setup(): @@ -22,6 +24,10 @@ def run_setup(): # settings table is empty pass + # setup config file + config = UserConfig() + config.setup_config_file() + instance_key = get_random_str() # INFO: Load all tracks, albums, and artists into memory diff --git a/manage.py b/manage.py index a3e106c9..3d20f402 100644 --- a/manage.py +++ b/manage.py @@ -44,7 +44,7 @@ app = create_api() app.static_folder = get_home_res_path("client") # INFO: Routes that don't need authentication -blacklist_routes = {"/auth/login", "/auth/users"} +blacklist_routes = {"/auth/login", "/auth/users", "/auth/logout"} blacklist_extensions = {".webp"}.union(getClientFilesExtensions()) From 5d947f3ad93fe13c45fb753c07bfa9c309318e09 Mon Sep 17 00:00:00 2001 From: mungai-njoroge Date: Wed, 1 May 2024 23:44:38 +0300 Subject: [PATCH 05/12] protect settings write routes + prevent updating guest user + add docs to whitelisted auth routes + fix: sort in get all route + fix: folders not having trailing slash in recentlyplayed --- app/api/auth.py | 30 +++++++++++++++++++----------- app/api/getall/__init__.py | 13 +++++-------- app/api/plugins/__init__.py | 3 +++ app/api/settings.py | 2 ++ app/lib/home/recentlyplayed.py | 6 +++++- manage.py | 6 +++--- 6 files changed, 37 insertions(+), 23 deletions(-) diff --git a/app/api/auth.py b/app/api/auth.py index fbba9a6b..0edc2424 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -83,26 +83,33 @@ def update_profile(body: UpdateProfileBody): "roles": body.roles, } + # prevent updating guest + if current_user["username"] == "guest" or user["username"] == "guest": + return {"msg": "Cannot update guest user"}, 400 + # if not id, update self if not user["id"]: user["id"] = current_user["id"] - print("current_user: ", current_user) - - # only admins can update roles - if body.roles is not None: + # only admins can update roles if "admin" not in current_user["roles"]: return {"msg": "Only admins can update roles"}, 403 + all_users = authdb.get_all_users() if "admin" not in body.roles: # check if we're removing the last admin - users = authdb.get_all_users() - admins = [user for user in users if "admin" in user.roles] + admins = [user for user in all_users if "admin" in user.roles] if len(admins) == 1 and admins[0].id == user["id"]: return {"msg": "Cannot remove the only admin"}, 400 + # guest roles cannot be updated + _user = [u for u in all_users if u.id == user["id"]][0] + if "guest" in _user.roles: + return {"msg": "Cannot update guest user"}, 400 + + # finally, convert roles to json string user["roles"] = json.dumps(body.roles) if user["password"]: @@ -227,8 +234,10 @@ def get_all_users(query: GetAllUsersQuery): "users": [], } + is_admin = current_user and "admin" in current_user["roles"] + # if user is admin, also return settings - if current_user and "admin" in current_user["roles"]: + if is_admin: res = { "settings": settings, } @@ -248,11 +257,10 @@ def get_all_users(query: GetAllUsersQuery): users = authdb.get_all_users() # remove guest user - print("settings: ", settings["enableGuest"]) - if not settings["enableGuest"]: - users = [user for user in users if user.username != "guest"] + # if not settings["enableGuest"]: + # users = [user for user in users if user.username != "guest"] - if not settings["usersOnLogin"]: + if not is_admin or not settings["usersOnLogin"]: users = [user for user in users if user.username == "guest"] # reverse list to show latest users first diff --git a/app/api/getall/__init__.py b/app/api/getall/__init__.py index 8a6c7385..b3199ca9 100644 --- a/app/api/getall/__init__.py +++ b/app/api/getall/__init__.py @@ -22,7 +22,7 @@ bp_tag = Tag(name="Get all", description="List all items") api = APIBlueprint("getall", __name__, url_prefix="/getall", abp_tags=[bp_tag]) -class GetAllItemsBody(GenericLimitSchema): +class GetAllItemsQuery(GenericLimitSchema): start: int = Field( description="The start index of the items to return", example=0, @@ -34,10 +34,10 @@ class GetAllItemsBody(GenericLimitSchema): default="created_date", ) - reverse: int = Field( + reverse: str = Field( description="Reverse the sort", example=1, - default=1, + default="1", ) @@ -50,7 +50,7 @@ class GetAllItemsPath(BaseModel): @api.get("/") -def get_all_items(path: GetAllItemsPath, query: GetAllItemsBody): +def get_all_items(path: GetAllItemsPath, query: GetAllItemsQuery): """ Get all items @@ -67,10 +67,7 @@ def get_all_items(path: GetAllItemsPath, query: GetAllItemsBody): start = query.start limit = query.limit sort = query.sortby - reverse = query.reverse == 1 - - # if sort == "": - # sort = "created_date" + reverse = query.reverse == "1" sort_is_count = sort == "count" sort_is_duration = sort == "duration" diff --git a/app/api/plugins/__init__.py b/app/api/plugins/__init__.py index 6fffe4ab..930bf603 100644 --- a/app/api/plugins/__init__.py +++ b/app/api/plugins/__init__.py @@ -3,6 +3,7 @@ from flask import Blueprint, request from flask_openapi3 import Tag from flask_openapi3 import APIBlueprint from pydantic import BaseModel, Field +from app.api.auth import admin_required from app.db.sqlite.plugins import PluginsMethods bp_tag = Tag(name="Plugins", description="Manage plugins") @@ -30,6 +31,7 @@ class PluginActivateBody(PluginBody): @api.post("/setactive") +@admin_required() def activate_deactivate_plugin(body: PluginActivateBody): """ Activate/Deactivate plugin @@ -49,6 +51,7 @@ class PluginSettingsBody(PluginBody): @api.post("/settings") +@admin_required() def update_plugin_settings(body: PluginSettingsBody): """ Update plugin settings diff --git a/app/api/settings.py b/app/api/settings.py index 7eca6d5c..65ff0a0b 100644 --- a/app/api/settings.py +++ b/app/api/settings.py @@ -97,6 +97,7 @@ class AddRootDirsBody(BaseModel): @api.post("/add-root-dirs") +@admin_required() def add_root_dirs(body: AddRootDirsBody): """ Add custom root directories to the database. @@ -216,6 +217,7 @@ class SetSettingBody(BaseModel): @api.post("/set") +@admin_required() def set_setting(body: SetSettingBody): """ Set a setting. diff --git a/app/lib/home/recentlyplayed.py b/app/lib/home/recentlyplayed.py index 6d41c649..99fd1268 100644 --- a/app/lib/home/recentlyplayed.py +++ b/app/lib/home/recentlyplayed.py @@ -82,9 +82,13 @@ def get_recently_played(limit=7): if entry.type == "folder": folder = entry.type_src + if not folder: continue + if not folder.endswith("/"): + folder += "/" + is_home_dir = entry.type_src == "$home" if is_home_dir: @@ -98,7 +102,7 @@ def get_recently_played(limit=7): { "type": "folder", "item": { - "path": entry.type_src, + "path": folder, "count": count, "help_text": "folder", "time": timestamp_to_time_passed(entry.timestamp), diff --git a/manage.py b/manage.py index 3d20f402..af47a766 100644 --- a/manage.py +++ b/manage.py @@ -44,7 +44,7 @@ app = create_api() app.static_folder = get_home_res_path("client") # INFO: Routes that don't need authentication -blacklist_routes = {"/auth/login", "/auth/users", "/auth/logout"} +whitelisted_routes = {"/auth/login", "/auth/users", "/auth/logout", "/docs"} blacklist_extensions = {".webp"}.union(getClientFilesExtensions()) @@ -59,9 +59,9 @@ def verify_auth(): return # if request path starts with any of the blacklisted routes, don't verify jwt - if any(request.path.startswith(route) for route in blacklist_routes): + if any(request.path.startswith(route) for route in whitelisted_routes): # print( - # "Found blacklisted route: ", request.path, "... Skipping jwt verification" + # "Found whitelisted route: ", request.path, "... Skipping jwt verification" # ) return From fdf3186be6fa5ccfc27cd1814971a0c50d5bf24a Mon Sep 17 00:00:00 2001 From: mungai-njoroge Date: Fri, 3 May 2024 23:22:09 +0300 Subject: [PATCH 06/12] salt passwords using userid --- app/api/__init__.py | 9 +++++---- app/api/settings.py | 6 +++--- app/arg_handler.py | 12 ++++++------ app/config.py | 5 ++--- app/plugins/lyrics.py | 2 +- app/settings.py | 11 +++++++---- app/setup/__init__.py | 6 ++++-- app/start_info_logger.py | 4 ++-- app/utils/auth.py | 10 ++++++++-- manage.py | 7 ++++--- 10 files changed, 42 insertions(+), 30 deletions(-) diff --git a/app/api/__init__.py b/app/api/__init__.py index 5c3e7f64..3a7e1afb 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -9,8 +9,9 @@ from flask_compress import Compress from flask_openapi3 import Info from flask_openapi3 import OpenAPI from flask_jwt_extended import JWTManager +from app.config import UserConfig -from app.settings import Keys +from app.settings import Info as AppInfo from .plugins import lyrics as lyrics_plugin from app.db.sqlite.auth import SQLiteAuthMethods as authdb from app.api import ( @@ -57,14 +58,14 @@ def create_api(): """ api_info = Info( title=f"Swing Music", - version=f"v{Keys.SWINGMUSIC_APP_VERSION}", + version=f"v{AppInfo.SWINGMUSIC_APP_VERSION}", description=open_api_description, ) app = OpenAPI(__name__, info=api_info, doc_prefix="/docs") - + print("userid", UserConfig().userId) # JWT CONFIGS - app.config["JWT_SECRET_KEY"] = Keys.JWT_SECRET_KEY + app.config["JWT_SECRET_KEY"] = UserConfig().userId app.config["JWT_TOKEN_LOCATION"] = ["cookies"] app.config["JWT_COOKIE_CSRF_PROTECT"] = False app.config["JWT_ACCESS_TOKEN_EXPIRES"] = datetime.timedelta(days=1) diff --git a/app/api/settings.py b/app/api/settings.py index 65ff0a0b..78ecad82 100644 --- a/app/api/settings.py +++ b/app/api/settings.py @@ -10,7 +10,7 @@ from app.db.sqlite.settings import SettingsSQLMethods as sdb from app.lib import populate from app.lib.watchdogg import Watcher as WatchDog from app.logger import log -from app.settings import Keys, Paths, SessionVarKeys, set_flag +from app.settings import Info, Paths, SessionVarKeys, set_flag from app.store.albums import AlbumStore from app.store.artists import ArtistStore from app.store.tracks import TrackStore @@ -193,7 +193,7 @@ def get_all_settings(): root_dirs = sdb.get_root_dirs() s["root_dirs"] = root_dirs s["plugins"] = plugins - s["version"] = Keys.SWINGMUSIC_APP_VERSION + s["version"] = Info.SWINGMUSIC_APP_VERSION return { "settings": s, @@ -292,4 +292,4 @@ def update_config(body: UpdateConfigBody): return { "msg": "Config updated!", - } \ No newline at end of file + } diff --git a/app/arg_handler.py b/app/arg_handler.py index dec203ac..02b06919 100644 --- a/app/arg_handler.py +++ b/app/arg_handler.py @@ -45,7 +45,7 @@ class HandleArgs: print("https://www.youtube.com/watch?v=wZv62ShoStY") sys.exit(0) - config_keys = [ + info_keys = [ "SWINGMUSIC_APP_VERSION", "GIT_LATEST_COMMIT_HASH", "GIT_CURRENT_BRANCH", @@ -53,8 +53,8 @@ class HandleArgs: lines = [] - for key in config_keys: - value = settings.Keys.get(key) + for key in info_keys: + value = settings.Info.get(key) if not value: log.error(f"WARNING: {key} not set in environment") @@ -88,7 +88,7 @@ class HandleArgs: finally: # revert and remove the api keys for dev mode with open("./app/configs.py", "w", encoding="utf-8") as file: - lines = [f'{key} = ""\n' for key in config_keys] + lines = [f'{key} = ""\n' for key in info_keys] file.writelines(lines) sys.exit(0) @@ -184,8 +184,8 @@ class HandleArgs: @staticmethod def handle_version(): if any((a in ARGS for a in ALLARGS.version)): - print(f"VERSION: v{settings.Keys.SWINGMUSIC_APP_VERSION}") + print(f"VERSION: v{settings.Info.SWINGMUSIC_APP_VERSION}") print( - f"COMMIT#: {settings.Keys.GIT_CURRENT_BRANCH}/{settings.Keys.GIT_LATEST_COMMIT_HASH}" + f"COMMIT#: {settings.Info.GIT_CURRENT_BRANCH}/{settings.Info.GIT_LATEST_COMMIT_HASH}" ) sys.exit(0) diff --git a/app/config.py b/app/config.py index 6f7e773c..4d4d997b 100644 --- a/app/config.py +++ b/app/config.py @@ -1,7 +1,6 @@ from dataclasses import dataclass, asdict, field import json import os -import time from typing import Any from .settings import Paths @@ -14,6 +13,8 @@ class UserConfig: # TODO: Move the rest of the settings to the config file # auth stuff + # NOTE: Don't expose the userId via the API + userId: str = "" usersOnLogin: bool = True enableGuest: bool = False @@ -57,8 +58,6 @@ class UserConfig: Creates the config file with the default settings if it doesn't exist """ - print("config path: ", self._config_path) - # if not exists, create the config file if not os.path.exists(self._config_path): self.write_to_file(asdict(self)) diff --git a/app/plugins/lyrics.py b/app/plugins/lyrics.py index 330693ec..81c73b59 100644 --- a/app/plugins/lyrics.py +++ b/app/plugins/lyrics.py @@ -8,7 +8,7 @@ import requests from app.db.sqlite.plugins import PluginsMethods from app.plugins import Plugin, plugin_method -from app.settings import Keys, Paths +from app.settings import Paths class LRCProvider: diff --git a/app/settings.py b/app/settings.py index d8a05ea4..cbaf0b35 100644 --- a/app/settings.py +++ b/app/settings.py @@ -268,13 +268,16 @@ def getCurrentBranch(): return "" -class Keys: +class Info: + """ + Contains information about the app + + NOTE: This class initially written to load keys when running in build mode. + TODO: Remove this class entirely, and implement functionality where needed. + """ SWINGMUSIC_APP_VERSION = os.environ.get("SWINGMUSIC_APP_VERSION") GIT_LATEST_COMMIT_HASH = "" GIT_CURRENT_BRANCH = "" - JWT_SECRET_KEY = ( - "swingmusic_secret_key" # REVIEW: This should be set in the environment - ) @classmethod def load(cls): diff --git a/app/setup/__init__.py b/app/setup/__init__.py index 5b741782..7a8ed41e 100644 --- a/app/setup/__init__.py +++ b/app/setup/__init__.py @@ -1,8 +1,7 @@ """ Prepares the server for use. """ - -from dataclasses import asdict +import uuid from app.db.sqlite.settings import load_settings from app.setup.files import create_config_dir from app.setup.sqlite import run_migrations, setup_sqlite @@ -28,6 +27,9 @@ def run_setup(): config = UserConfig() config.setup_config_file() + if not config.userId: + config.userId = str(uuid.uuid4()) + instance_key = get_random_str() # INFO: Load all tracks, albums, and artists into memory diff --git a/app/start_info_logger.py b/app/start_info_logger.py index 45674b94..e27a361e 100644 --- a/app/start_info_logger.py +++ b/app/start_info_logger.py @@ -1,6 +1,6 @@ import os -from app.settings import FLASKVARS, TCOLOR, Keys, Paths +from app.settings import FLASKVARS, TCOLOR, Info, Paths from app.utils.network import get_ip @@ -10,7 +10,7 @@ def log_startup_info(): # os.system("cls" if os.name == "nt" else "echo -e \\\\033c") print(lines) - print(f"{TCOLOR.HEADER}SwingMusic {Keys.SWINGMUSIC_APP_VERSION} {TCOLOR.ENDC}") + print(f"{TCOLOR.HEADER}SwingMusic {Info.SWINGMUSIC_APP_VERSION} {TCOLOR.ENDC}") adresses = [FLASKVARS.get_flask_host()] diff --git a/app/utils/auth.py b/app/utils/auth.py index fac68e22..e8b909af 100644 --- a/app/utils/auth.py +++ b/app/utils/auth.py @@ -1,5 +1,8 @@ +import hmac import hashlib +from app.config import UserConfig + def encode_password(password: str) -> str: """ @@ -10,7 +13,10 @@ def encode_password(password: str) -> str: :return: The encoded password. """ - return hashlib.sha256(password.encode("utf-8")).hexdigest() + return hashlib.pbkdf2_hmac( + "sha256", password.encode("utf-8"), UserConfig().userId.encode("utf-8"), 100000 + ).hex() + def check_password(password: str, encoded: str) -> bool: """ @@ -22,4 +28,4 @@ def check_password(password: str, encoded: str) -> bool: :return: Whether the password matches. """ - return encode_password(password) == encoded \ No newline at end of file + return hmac.compare_digest(encode_password(password), encoded) diff --git a/manage.py b/manage.py index af47a766..b7068dfd 100644 --- a/manage.py +++ b/manage.py @@ -17,7 +17,7 @@ from app.arg_handler import HandleArgs from app.lib.watchdogg import Watcher as WatchDog from app.periodic_scan import run_periodic_scans from app.plugins.register import register_plugins -from app.settings import FLASKVARS, TCOLOR, Keys +from app.settings import FLASKVARS, TCOLOR, Info from app.setup import run_setup from app.start_info_logger import log_startup_info from app.utils.filesystem import get_home_res_path @@ -40,6 +40,8 @@ mimetypes.add_type("application/manifest+json", ".webmanifest") werkzeug = logging.getLogger("werkzeug") werkzeug.setLevel(logging.ERROR) +HandleArgs() + app = create_api() app.static_folder = get_home_res_path("client") @@ -155,8 +157,7 @@ def run_swingmusic(): if __name__ == "__main__": - Keys.load() - HandleArgs() + Info.load() run_swingmusic() host = FLASKVARS.get_flask_host() From 36600ab782832e9c0c2d11f7edce741bf67f81e3 Mon Sep 17 00:00:00 2001 From: mungai-njoroge Date: Sun, 5 May 2024 23:55:25 +0300 Subject: [PATCH 07/12] fix: chunked audio stream desc: faulty content range headers + fix: tracks not being removed from db on root dirs change + implement implicit jwt refreshing + remove enableGuest from configs + set jwt validity to 30 days --- app/api/__init__.py | 3 +-- app/api/auth.py | 14 +++++--------- app/api/send_file.py | 35 ++++++++++++++++++++++++++--------- app/api/settings.py | 17 +++++++++-------- app/config.py | 1 - app/db/sqlite/tracks.py | 9 +++++---- app/lib/trackslib.py | 5 +++-- app/store/tracks.py | 15 ++------------- manage.py | 28 +++++++++++++++++++++++++++- 9 files changed, 78 insertions(+), 49 deletions(-) diff --git a/app/api/__init__.py b/app/api/__init__.py index 3a7e1afb..2fb0e1c5 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -63,12 +63,11 @@ def create_api(): ) app = OpenAPI(__name__, info=api_info, doc_prefix="/docs") - print("userid", UserConfig().userId) # JWT CONFIGS app.config["JWT_SECRET_KEY"] = UserConfig().userId app.config["JWT_TOKEN_LOCATION"] = ["cookies"] app.config["JWT_COOKIE_CSRF_PROTECT"] = False - app.config["JWT_ACCESS_TOKEN_EXPIRES"] = datetime.timedelta(days=1) + app.config["JWT_ACCESS_TOKEN_EXPIRES"] = datetime.timedelta(days=30) # CORS CORS(app, origins="*", supports_credentials=True) diff --git a/app/api/auth.py b/app/api/auth.py index 0edc2424..83dfda38 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -191,11 +191,6 @@ def delete_user(body: DeleteUseBody): return {"msg": "Cannot delete the only admin"}, 400 authdb.delete_user_by_username(body.username) - - # if user is guest, update config - if body.username == "guest": - UserConfig().enableGuest = False - return {"msg": f"User {body.username} deleted"} @@ -225,7 +220,7 @@ def get_all_users(query: GetAllUsersQuery): # config.enableGuest = True # config.usersOnLogin = True settings = { - "enableGuest": config.enableGuest, + "enableGuest": False, "usersOnLogin": config.usersOnLogin, } @@ -234,7 +229,10 @@ def get_all_users(query: GetAllUsersQuery): "users": [], } + users = authdb.get_all_users() + is_admin = current_user and "admin" in current_user["roles"] + settings['enableGuest'] = [user for user in users if user.username == "guest"].__len__() > 0 # if user is admin, also return settings if is_admin: @@ -254,13 +252,12 @@ def get_all_users(query: GetAllUsersQuery): ): return res - users = authdb.get_all_users() # remove guest user # if not settings["enableGuest"]: # users = [user for user in users if user.username != "guest"] - if not is_admin or not settings["usersOnLogin"]: + if not settings["usersOnLogin"]: users = [user for user in users if user.username == "guest"] # reverse list to show latest users first @@ -268,7 +265,6 @@ def get_all_users(query: GetAllUsersQuery): # bring admins to the front users = sorted(users, key=lambda x: "admin" in x.roles, reverse=True) - # bring current user to index 0 if current_user: users = sorted( diff --git a/app/api/send_file.py b/app/api/send_file.py index 72fe8b37..5c804079 100644 --- a/app/api/send_file.py +++ b/app/api/send_file.py @@ -68,15 +68,31 @@ def send_track_file(path: TrackHashSchema, query: SendTrackFileQuery): def send_file_as_chunks(filepath: str, audio_type: str) -> Response: + """ + Returns a Response object that streams the file in chunks. + """ + # NOTE: +1 makes sure the last byte is included in the range. + # NOTE: -1 is used to convert the end index to a 0-based index. + chunk_size = 1024 * 512 + + # Get file size file_size = os.path.getsize(filepath) start = 0 - end = file_size - 1 + end = chunk_size + # Read range header range_header = request.headers.get("Range") if range_header: - start, end = parse_range_header(range_header, file_size) + start = get_start_range(range_header) - chunk_size = 1024 * 1024 # 1MB chunk size (adjust as needed) + # If start + chunk_size is greater than file_size, + # set end to file_size - 1 + _end = start + chunk_size - 1 + + if _end > file_size: + end = file_size - 1 + else: + end = _end def generate_chunks(): with open(filepath, "rb") as file: @@ -84,8 +100,11 @@ def send_file_as_chunks(filepath: str, audio_type: str) -> Response: remaining_bytes = end - start + 1 while remaining_bytes > 0: + # Read the chunk size or all the remaining bytes chunk = file.read(min(chunk_size, remaining_bytes)) yield chunk + + # Update the remaining bytes remaining_bytes -= len(chunk) response = Response( @@ -102,15 +121,13 @@ def send_file_as_chunks(filepath: str, audio_type: str) -> Response: return response -def parse_range_header(range_header: str, file_size: int) -> tuple[int, int]: +def get_start_range(range_header: str): try: range_start, range_end = range_header.strip().split("=")[1].split("-") - start = int(range_start) - end = min(int(range_end), file_size - 1) - except ValueError: - return 0, file_size - 1 + return int(range_start) - return start, end + except ValueError: + return 0 class GetAudioSilenceBody(BaseModel): diff --git a/app/api/settings.py b/app/api/settings.py index 78ecad82..4931bc52 100644 --- a/app/api/settings.py +++ b/app/api/settings.py @@ -7,6 +7,7 @@ from app.api.auth import admin_required from app.db.sqlite.plugins import PluginsMethods as pdb from app.db.sqlite.settings import SettingsSQLMethods as sdb +from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb from app.lib import populate from app.lib.watchdogg import Watcher as WatchDog from app.logger import log @@ -51,12 +52,12 @@ def reload_everything(instance_key: str): @background def rebuild_store(db_dirs: list[str]): """ - Restarts the watchdog and rebuilds the music library. + Restarts watchdog and rebuilds the music library. """ instance_key = get_random_str() log.info("Rebuilding library...") - TrackStore.remove_tracks_by_dir_except(db_dirs) + trackdb.remove_tracks_not_in_folders(db_dirs) reload_everything(instance_key) try: @@ -106,10 +107,10 @@ def add_root_dirs(body: AddRootDirsBody): removed_dirs = body.removed db_dirs = sdb.get_root_dirs() - _h = "$home" + home = "$home" - db_home = any([d == _h for d in db_dirs]) # if $home is in db - incoming_home = any([d == _h for d in new_dirs]) # if $home is in incoming + db_home = any([d == home for d in db_dirs]) # if $home is in db + incoming_home = any([d == home for d in new_dirs]) # if $home is in incoming # handle $home case if db_home and incoming_home: @@ -119,8 +120,8 @@ def add_root_dirs(body: AddRootDirsBody): sdb.remove_root_dirs(db_dirs) if incoming_home: - finalize([_h], [], [Paths.USER_HOME_DIR]) - return {"root_dirs": [_h]} + finalize([home], [], [Paths.USER_HOME_DIR]) + return {"root_dirs": [home]} # --- @@ -135,7 +136,7 @@ def add_root_dirs(body: AddRootDirsBody): pass db_dirs.extend(new_dirs) - db_dirs = [dir_ for dir_ in db_dirs if dir_ != _h] + db_dirs = [dir_ for dir_ in db_dirs if dir_ != home] finalize(new_dirs, removed_dirs, db_dirs) diff --git a/app/config.py b/app/config.py index 4d4d997b..00e6cde5 100644 --- a/app/config.py +++ b/app/config.py @@ -16,7 +16,6 @@ class UserConfig: # NOTE: Don't expose the userId via the API userId: str = "" usersOnLogin: bool = True - enableGuest: bool = False # lists rootDirs: list[str] = field(default_factory=list) diff --git a/app/db/sqlite/tracks.py b/app/db/sqlite/tracks.py index ab45f888..c8bed28c 100644 --- a/app/db/sqlite/tracks.py +++ b/app/db/sqlite/tracks.py @@ -112,9 +112,10 @@ class SQLiteTrackMethods: cur.execute("DELETE FROM tracks WHERE filepath=?", (filepath,)) @staticmethod - def remove_tracks_by_folders(folders: set[str]): - sql = "DELETE FROM tracks WHERE folder = ?" + def remove_tracks_not_in_folders(folders: set[str]): + sql = "DELETE FROM tracks WHERE folder NOT IN ({})".format( + ",".join("?" * len(folders)) + ) with SQLiteManager() as cur: - for folder in folders: - cur.execute(sql, (folder,)) + cur.execute(sql, tuple(folders)) diff --git a/app/lib/trackslib.py b/app/lib/trackslib.py index a685c0bd..085dda0c 100644 --- a/app/lib/trackslib.py +++ b/app/lib/trackslib.py @@ -1,12 +1,13 @@ """ This library contains all the functions related to tracks. """ + import os from app.lib.pydub.pydub import AudioSegment from app.lib.pydub.pydub.silence import detect_leading_silence, detect_silence -from app.db.sqlite.tracks import SQLiteTrackMethods as tdb +from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb from app.store.tracks import TrackStore from app.utils.progressbar import tqdm from app.utils.threading import ThreadWithReturnValue @@ -19,7 +20,7 @@ def validate_tracks() -> None: for track in tqdm(TrackStore.tracks, desc="Validating tracks"): if not os.path.exists(track.filepath): TrackStore.remove_track_obj(track) - tdb.remove_tracks_by_filepaths(track.filepath) + trackdb.remove_tracks_by_filepaths(track.filepath) def get_leading_silence_end(filepath: str): diff --git a/app/store/tracks.py b/app/store/tracks.py index 6b56a829..ae699e01 100644 --- a/app/store/tracks.py +++ b/app/store/tracks.py @@ -1,7 +1,7 @@ # from tqdm import tqdm from app.db.sqlite.favorite import SQLiteFavoriteMethods as favdb -from app.db.sqlite.tracks import SQLiteTrackMethods as tdb +from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb from app.models import Track from app.utils.bisection import use_bisection from app.utils.customlist import CustomList @@ -23,7 +23,7 @@ class TrackStore: global TRACKS_LOAD_KEY TRACKS_LOAD_KEY = instance_key - cls.tracks = CustomList(tdb.get_all_tracks()) + cls.tracks = CustomList(trackdb.get_all_tracks()) fav_hashes = favdb.get_fav_tracks() fav_hashes = " ".join([t[1] for t in fav_hashes]) @@ -84,17 +84,6 @@ class TrackStore: if track.filepath in filepaths: cls.remove_track_obj(track) - @classmethod - def remove_tracks_by_dir_except(cls, dirs: list[str]): - """Removes all tracks not in the root directories.""" - to_remove = set() - - for track in cls.tracks: - if not track.folder.startswith(tuple(dirs)): - to_remove.add(track.folder) - - tdb.remove_tracks_by_folders(to_remove) - @classmethod def count_tracks_by_trackhash(cls, trackhash: str) -> int: """ diff --git a/manage.py b/manage.py index b7068dfd..cd83f9c8 100644 --- a/manage.py +++ b/manage.py @@ -2,9 +2,16 @@ This file is used to run the application. """ +from datetime import datetime, timezone import os import logging -from flask_jwt_extended import verify_jwt_in_request +from flask_jwt_extended import ( + create_access_token, + get_jwt, + get_jwt_identity, + set_access_cookies, + verify_jwt_in_request, +) import psutil import mimetypes from flask import Response, request @@ -70,6 +77,25 @@ def verify_auth(): verify_jwt_in_request() +@app.after_request +def refresh_expiring_jwt(response: Response): + """ + Refreshes the JWT token after each request. + """ + try: + exp_timestamp = get_jwt()["exp"] + now = datetime.now(timezone.utc) + target_timestamp = datetime.timestamp(now) + 60 * 60 * 24 * 7 # 7 days + + if target_timestamp > exp_timestamp: + access_token = create_access_token(identity=get_jwt_identity()) + set_access_cookies(response, access_token) + + return response + except (RuntimeError, KeyError): + return response + + @app.route("/") def serve_client_files(path: str): """ From 10b613513c7ee5acd6c579ecb192ed82365c4515 Mon Sep 17 00:00:00 2001 From: cwilvx Date: Tue, 7 May 2024 23:00:53 +0300 Subject: [PATCH 08/12] fix: default user inserted before userId is created moved application setup function calls before flask app creation --- app/setup/__init__.py | 16 ++++++++------ manage.py | 51 ++++++++++++++++++++++++------------------- 2 files changed, 38 insertions(+), 29 deletions(-) diff --git a/app/setup/__init__.py b/app/setup/__init__.py index 7a8ed41e..4f784bea 100644 --- a/app/setup/__init__.py +++ b/app/setup/__init__.py @@ -1,6 +1,7 @@ """ Prepares the server for use. """ + import uuid from app.db.sqlite.settings import load_settings from app.setup.files import create_config_dir @@ -14,6 +15,14 @@ from app.config import UserConfig def run_setup(): create_config_dir() + + # setup config file + config = UserConfig() + config.setup_config_file() + + if not config.userId: + config.userId = str(uuid.uuid4()) + setup_sqlite() run_migrations() @@ -23,13 +32,6 @@ def run_setup(): # settings table is empty pass - # setup config file - config = UserConfig() - config.setup_config_file() - - if not config.userId: - config.userId = str(uuid.uuid4()) - instance_key = get_random_str() # INFO: Load all tracks, albums, and artists into memory diff --git a/manage.py b/manage.py index cd83f9c8..7ad50916 100644 --- a/manage.py +++ b/manage.py @@ -47,8 +47,37 @@ mimetypes.add_type("application/manifest+json", ".webmanifest") werkzeug = logging.getLogger("werkzeug") werkzeug.setLevel(logging.ERROR) +# Set up the application HandleArgs() +@background +def bg_run_setup() -> None: + run_periodic_scans() + + +@background +def start_watchdog(): + WatchDog().run() + + +@background +def run_swingmusic(): + log_startup_info() + run_setup() + bg_run_setup() + register_plugins() + + start_watchdog() + + setproctitle.setproctitle(f"swingmusic ::{FLASKVARS.get_flask_port()}") + + +Info.load() +run_swingmusic() + + +# Create the Flask app + app = create_api() app.static_folder = get_home_res_path("client") @@ -160,31 +189,9 @@ def print_memory_usage(response: Response): return response -@background -def bg_run_setup() -> None: - run_periodic_scans() - - -@background -def start_watchdog(): - WatchDog().run() - - -@background -def run_swingmusic(): - log_startup_info() - run_setup() - bg_run_setup() - register_plugins() - - start_watchdog() - - setproctitle.setproctitle(f"swingmusic ::{FLASKVARS.get_flask_port()}") if __name__ == "__main__": - Info.load() - run_swingmusic() host = FLASKVARS.get_flask_host() port = FLASKVARS.get_flask_port() From 6692c78110fee0591e9c4af6edba811a062e0c79 Mon Sep 17 00:00:00 2001 From: cwilvx Date: Tue, 7 May 2024 23:16:56 +0300 Subject: [PATCH 09/12] fix: setup beginning before folders are created --- app/arg_handler.py | 2 +- manage.py | 12 +++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/app/arg_handler.py b/app/arg_handler.py index 02b06919..a1aa5037 100644 --- a/app/arg_handler.py +++ b/app/arg_handler.py @@ -18,7 +18,7 @@ ALLARGS = settings.ALLARGS ARGS = sys.argv[1:] -class HandleArgs: +class ProcessArgs: def __init__(self) -> None: self.handle_build() self.handle_host() diff --git a/manage.py b/manage.py index 7ad50916..4210463e 100644 --- a/manage.py +++ b/manage.py @@ -20,7 +20,7 @@ import waitress import setproctitle from app.api import create_api -from app.arg_handler import HandleArgs +from app.arg_handler import ProcessArgs from app.lib.watchdogg import Watcher as WatchDog from app.periodic_scan import run_periodic_scans from app.plugins.register import register_plugins @@ -47,9 +47,8 @@ mimetypes.add_type("application/manifest+json", ".webmanifest") werkzeug = logging.getLogger("werkzeug") werkzeug.setLevel(logging.ERROR) -# Set up the application -HandleArgs() +# Background tasks @background def bg_run_setup() -> None: run_periodic_scans() @@ -63,7 +62,6 @@ def start_watchdog(): @background def run_swingmusic(): log_startup_info() - run_setup() bg_run_setup() register_plugins() @@ -72,6 +70,9 @@ def run_swingmusic(): setproctitle.setproctitle(f"swingmusic ::{FLASKVARS.get_flask_port()}") +# Setup function calls +ProcessArgs() +run_setup() Info.load() run_swingmusic() @@ -189,10 +190,7 @@ def print_memory_usage(response: Response): return response - - if __name__ == "__main__": - host = FLASKVARS.get_flask_host() port = FLASKVARS.get_flask_port() From 999b802f7d7b21b67991396404430a3c5abc9ffb Mon Sep 17 00:00:00 2001 From: cwilvx Date: Wed, 8 May 2024 11:15:02 +0300 Subject: [PATCH 10/12] guess audio mimetypes --- app/api/send_file.py | 9 +++------ app/utils/files.py | 21 +++++++++++++++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) create mode 100644 app/utils/files.py diff --git a/app/api/send_file.py b/app/api/send_file.py index 5c804079..74d1b88d 100644 --- a/app/api/send_file.py +++ b/app/api/send_file.py @@ -11,6 +11,7 @@ from app.api.apischemas import TrackHashSchema from app.lib.trackslib import get_silence_paddings from app.store.tracks import TrackStore +from app.utils.files import guess_mime_type bp_tag = Tag(name="File", description="Audio files") api = APIBlueprint("track", __name__, url_prefix="/file", abp_tags=[bp_tag]) @@ -33,10 +34,6 @@ def send_track_file(path: TrackHashSchema, query: SendTrackFileQuery): filepath = query.filepath msg = {"msg": "File Not Found"} - def get_mime(filename: str) -> str: - ext = filename.rsplit(".", maxsplit=1)[-1] - return f"audio/{ext}" - # If filepath is provided, try to send that if filepath is not None: try: @@ -47,7 +44,7 @@ def send_track_file(path: TrackHashSchema, query: SendTrackFileQuery): track_exists = track is not None and os.path.exists(track.filepath) if track_exists: - audio_type = get_mime(filepath) + audio_type = guess_mime_type(filepath) return send_file_as_chunks(track.filepath, audio_type) # Else, find file by trackhash @@ -57,7 +54,7 @@ def send_track_file(path: TrackHashSchema, query: SendTrackFileQuery): if track is None: return msg, 404 - audio_type = get_mime(track.filepath) + audio_type = guess_mime_type(track.filepath) try: return send_file_as_chunks(track.filepath, audio_type) diff --git a/app/utils/files.py b/app/utils/files.py new file mode 100644 index 00000000..382ac423 --- /dev/null +++ b/app/utils/files.py @@ -0,0 +1,21 @@ +import mimetypes + + +def get_mime_from_ext(filename: str): + """ + Constructs a mime type from a file extension. + """ + ext = filename.rsplit(".", maxsplit=1)[-1] + return f"audio/{ext}" + + +def guess_mime_type(filename: str): + """ + Guess the mime type of a file. + """ + type = mimetypes.guess_type(filename)[0] + + if type is None: + return get_mime_from_ext(filename) + + return type From 1e857c1e89f952456a22d3d7b3224a31dea2f8a5 Mon Sep 17 00:00:00 2001 From: cwilvx Date: Fri, 10 May 2024 12:08:57 +0300 Subject: [PATCH 11/12] rename send_file.py -> stream.py --- app/api/__init__.py | 4 ++-- app/api/{send_file.py => stream.py} | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) rename app/api/{send_file.py => stream.py} (97%) diff --git a/app/api/__init__.py b/app/api/__init__.py index 2fb0e1c5..4319499d 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -23,7 +23,6 @@ from app.api import ( imgserver, playlist, search, - send_file, settings, lyrics, plugins, @@ -31,6 +30,7 @@ from app.api import ( home, getall, auth, + stream, ) # TODO: Move this description to a separate file @@ -96,7 +96,7 @@ def create_api(): with app.app_context(): app.register_api(album.api) app.register_api(artist.api) - app.register_api(send_file.api) + app.register_api(stream.api) app.register_api(search.api) app.register_api(folder.api) app.register_api(playlist.api) diff --git a/app/api/send_file.py b/app/api/stream.py similarity index 97% rename from app/api/send_file.py rename to app/api/stream.py index 74d1b88d..714c3599 100644 --- a/app/api/send_file.py +++ b/app/api/stream.py @@ -3,11 +3,13 @@ Contains all the track routes. """ import os +import time from flask import Blueprint, send_file, request, Response from flask_openapi3 import APIBlueprint, Tag from pydantic import BaseModel, Field from app.api.apischemas import TrackHashSchema +from app.lib.pydub.pydub.audio_segment import AudioSegment from app.lib.trackslib import get_silence_paddings from app.store.tracks import TrackStore @@ -70,7 +72,7 @@ def send_file_as_chunks(filepath: str, audio_type: str) -> Response: """ # NOTE: +1 makes sure the last byte is included in the range. # NOTE: -1 is used to convert the end index to a 0-based index. - chunk_size = 1024 * 512 + chunk_size = 1024 * 360 # 360 KB # Get file size file_size = os.path.getsize(filepath) From b40f05cc7cb51dea2b6e7fd92eb20ed7a8462a9e Mon Sep 17 00:00:00 2001 From: cwilvx Date: Sat, 11 May 2024 21:26:03 +0300 Subject: [PATCH 12/12] implement CLI password recovery (hacky :omg:) + rewrite migrations logic + rename encode_password to hash_password + update image sizes (add medium size) + rename image endpoints --- .gitignore | 2 +- TODO.md | 2 + app/api/auth.py | 11 ++-- app/api/imgserver.py | 98 +++++++++++++++++++------------ app/api/playlist.py | 3 +- app/arg_handler.py | 63 +++++++++++++++++--- app/db/sqlite/auth.py | 9 +-- app/db/sqlite/migrations.py | 8 +-- app/lib/artistlib.py | 38 ++++++++---- app/lib/colorlib.py | 2 +- app/lib/taglib.py | 18 +++--- app/migrations/__init__.py | 47 ++++++++------- app/migrations/v1_4_9/__init__.py | 16 +++++ app/periodic_scan.py | 6 ++ app/print_help.py | 13 ++-- app/settings.py | 58 ++++++++++++------ app/setup/__init__.py | 8 +++ app/setup/files.py | 30 +++++----- app/start_info_logger.py | 2 +- app/utils/auth.py | 16 ++--- manage.py | 8 +-- 21 files changed, 306 insertions(+), 152 deletions(-) create mode 100644 TODO.md diff --git a/.gitignore b/.gitignore index 7cbee9ee..b2412028 100644 --- a/.gitignore +++ b/.gitignore @@ -27,7 +27,7 @@ client logs.txt *.spec -TODO.md +# TODO.md testdata.py test.py nohup.out diff --git a/TODO.md b/TODO.md new file mode 100644 index 00000000..c78efee4 --- /dev/null +++ b/TODO.md @@ -0,0 +1,2 @@ +- Fix migrations! + - Use total length instead of release version length \ No newline at end of file diff --git a/app/api/auth.py b/app/api/auth.py index 83dfda38..b8af863a 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -14,7 +14,7 @@ from flask_openapi3 import Tag from flask_openapi3 import APIBlueprint from app.db.sqlite.auth import SQLiteAuthMethods as authdb -from app.utils.auth import check_password, encode_password +from app.utils.auth import check_password, hash_password from app.config import UserConfig bp_tag = Tag(name="Auth", description="Authentication stuff") @@ -113,7 +113,7 @@ def update_profile(body: UpdateProfileBody): user["roles"] = json.dumps(body.roles) if user["password"]: - user["password"] = encode_password(user["password"]) + user["password"] = hash_password(user["password"]) # remove empty values clean_user = {k: v for k, v in user.items() if v} @@ -132,7 +132,7 @@ def create_user(body: UpdateProfileBody): user = { "username": body.username, - "password": encode_password(body.password), + "password": hash_password(body.password), "roles": json.dumps([]), } @@ -232,7 +232,9 @@ def get_all_users(query: GetAllUsersQuery): users = authdb.get_all_users() is_admin = current_user and "admin" in current_user["roles"] - settings['enableGuest'] = [user for user in users if user.username == "guest"].__len__() > 0 + settings["enableGuest"] = [ + user for user in users if user.username == "guest" + ].__len__() > 0 # if user is admin, also return settings if is_admin: @@ -252,7 +254,6 @@ def get_all_users(query: GetAllUsersQuery): ): return res - # remove guest user # if not settings["enableGuest"]: # users = [user for user in users if user.username != "guest"] diff --git a/app/api/imgserver.py b/app/api/imgserver.py index 680abba1..c5e2db49 100644 --- a/app/api/imgserver.py +++ b/app/api/imgserver.py @@ -13,6 +13,9 @@ api = APIBlueprint("imgserver", __name__, url_prefix="/img", abp_tags=[bp_tag]) def send_fallback_img(filename: str = "default.webp"): + """ + Returns the fallback image from the assets folder. + """ folder = Paths.get_assets_path() img = Path(folder) / filename @@ -22,6 +25,18 @@ def send_fallback_img(filename: str = "default.webp"): return send_from_directory(folder, filename) +def send_file_or_fallback(folder: str, filename: str, fallback: str = "default.webp"): + """ + Returns the file from the folder or the fallback image. + """ + fpath = Path(folder) / filename + + if fpath.exists(): + return send_from_directory(folder, filename) + + return send_fallback_img(fallback) + + class ImagePath(BaseModel): imgpath: str = Field( description="The image filename", @@ -43,62 +58,72 @@ class ImagePath(BaseModel): # return send_fallback_img() -@api.get("/t/") +# TRACK THUMBNAILS +@api.get("/thumbnail/") def send_lg_thumbnail(path: ImagePath): """ Get large thumbnail (500 x 500) """ folder = Paths.get_lg_thumb_path() - fpath = Path(folder) / path.imgpath - - if fpath.exists(): - return send_from_directory(folder, path.imgpath) - - return send_fallback_img() + return send_file_or_fallback(folder, path.imgpath) -@api.get("/t/s/") +@api.get("/thumbnail/xsmall/") +def send_xsm_thumbnail(path: ImagePath): + """ + Get extra small thumbnail (64px) + """ + folder = Paths.get_xsm_thumb_path() + return send_file_or_fallback(folder, path.imgpath) + + +@api.get("/thumbnail/small/") def send_sm_thumbnail(path: ImagePath): """ - Get small thumbnail (64 x 64) + Get small thumbnail (96px) """ folder = Paths.get_sm_thumb_path() - fpath = Path(folder) / path.imgpath - - if fpath.exists(): - return send_from_directory(folder, path.imgpath) - - return send_fallback_img() + return send_file_or_fallback(folder, path.imgpath) -@api.get("/a/") +@api.get("/thumbnail/medium/") +def send_md_thumbnail(path: ImagePath): + """ + Get medium thumbnail (256px) + """ + folder = Paths.get_md_thumb_path() + return send_file_or_fallback(folder, path.imgpath) + + +# ARTISTS +@api.get("/artist/") def send_lg_artist_image(path: ImagePath): """ Get large artist image (500 x 500) """ - folder = Paths.get_artist_img_lg_path() - fpath = Path(folder) / path.imgpath - - if fpath.exists(): - return send_from_directory(folder, path.imgpath) - - return send_fallback_img("artist.webp") + folder = Paths.get_lg_artist_img_path() + return send_file_or_fallback(folder, path.imgpath, "artist.webp") -@api.get("/a/s/") +@api.get("/artist/small/") def send_sm_artist_image(path: ImagePath): """ - Get small artist image (64 x 64) + Get small artist image (128) """ - folder = Paths.get_artist_img_sm_path() - fpath = Path(folder) / path.imgpath - - if fpath.exists(): - return send_from_directory(folder, path.imgpath) - - return send_fallback_img("artist.webp") + folder = Paths.get_sm_artist_img_path() + return send_file_or_fallback(folder, path.imgpath, "artist.webp") +@api.get("/artist/medium/") +def send_md_artist_image(path: ImagePath): + """ + Get medium artist image (256px) + """ + folder = Paths.get_md_artist_img_path() + return send_file_or_fallback(folder, path.imgpath, "artist.webp") + + +# PLAYLISTS class PlaylistImagePath(BaseModel): imgpath: str = Field( description="The image path", @@ -106,7 +131,7 @@ class PlaylistImagePath(BaseModel): ) -@api.get("/p/") +@api.get("/playlist/") def send_playlist_image(path: PlaylistImagePath): """ Get playlist image @@ -114,9 +139,4 @@ def send_playlist_image(path: PlaylistImagePath): Images are constructed as '{playlist_id}.webp' """ folder = Paths.get_playlist_img_path() - fpath = Path(folder) / path.imgpath - - if fpath.exists(): - return send_from_directory(folder, path.imgpath) - - return send_fallback_img("playlist.svg") + return send_file_or_fallback(folder, path.imgpath, "playlist.svg") diff --git a/app/api/playlist.py b/app/api/playlist.py index 91461fdf..8ccd1582 100644 --- a/app/api/playlist.py +++ b/app/api/playlist.py @@ -26,6 +26,7 @@ api = APIBlueprint("playlists", __name__, url_prefix="/playlists", abp_tags=[tag PL = SQLitePlaylistMethods + class SendAllPlaylistsQuery(BaseModel): no_images: bool = Field(False, description="Whether to include images") @@ -410,7 +411,7 @@ def save_item_as_playlist(body: SavePlaylistAsItemBody): filename = itemhash + ".webp" base_path = ( - Paths.get_artist_img_lg_path() + Paths.get_lg_artist_img_path() if itemtype == "artist" else Paths.get_lg_thumb_path() ) diff --git a/app/arg_handler.py b/app/arg_handler.py index a1aa5037..b74acfe4 100644 --- a/app/arg_handler.py +++ b/app/arg_handler.py @@ -2,6 +2,7 @@ Handles arguments passed to the program. """ +from getpass import getpass import os.path import sys @@ -10,27 +11,37 @@ import PyInstaller.__main__ as bundler from app import settings from app.logger import log from app.print_help import HELP_MESSAGE +from app.utils.auth import hash_password from app.utils.paths import getFlaskOpenApiPath from app.utils.xdg_utils import get_xdg_config_dir from app.utils.wintools import is_windows +from app.db.sqlite.auth import SQLiteAuthMethods as authdb ALLARGS = settings.ALLARGS ARGS = sys.argv[1:] class ProcessArgs: + """ + Processes the arguments passed to the program. + """ + def __init__(self) -> None: + # resolve config path + self.handle_config_path() # 1 + + # handles that exit + self.handle_password_recovery() self.handle_build() - self.handle_host() - self.handle_port() - self.handle_config_path() - - self.handle_periodic_scan() - self.handle_periodic_scan_interval() - self.handle_help() self.handle_version() + # non-exiting handles + self.handle_host() + self.handle_port() + self.handle_periodic_scan() + self.handle_periodic_scan_interval() + @staticmethod def handle_build(): """ @@ -57,12 +68,13 @@ class ProcessArgs: value = settings.Info.get(key) if not value: - log.error(f"WARNING: {key} not set in environment") + log.error(f"WARNING: {key} not resolved. Exiting ...") sys.exit(0) lines.append(f'{key} = "{value}"\n') try: + # write the info to the config file with open("./app/configs.py", "w", encoding="utf-8") as file: # copy the api keys to the config file file.writelines(lines) @@ -189,3 +201,38 @@ class ProcessArgs: f"COMMIT#: {settings.Info.GIT_CURRENT_BRANCH}/{settings.Info.GIT_LATEST_COMMIT_HASH}" ) sys.exit(0) + + @staticmethod + def handle_password_recovery(): + if ALLARGS.pswd in ARGS: + print("SWING MUSIC v2.0.0 ") + print("PASSWORD RECOVERY \n") + + username: str = "" + password: str = "" + + # collect username + try: + username = input("Enter username: ") + except KeyboardInterrupt: + print("\nOperation cancelled! Exiting ...") + sys.exit(0) + + username = username.strip() + user = authdb.get_user_by_username(username) + + if not user: + print(f"User {username} not found") + sys.exit(0) + + # collect password + try: + password = getpass("Enter new password: ") + except KeyboardInterrupt: + print("\nOperation cancelled! Exiting ...") + sys.exit(0) + + password = hash_password(password) + user = authdb.update_user({"id": user.id, "password": password}) + + sys.exit(0) diff --git a/app/db/sqlite/auth.py b/app/db/sqlite/auth.py index 2c36d183..3d9da234 100644 --- a/app/db/sqlite/auth.py +++ b/app/db/sqlite/auth.py @@ -1,6 +1,6 @@ import json from app.models.user import User -from app.utils.auth import encode_password +from app.utils.auth import hash_password from app.db.sqlite.utils import SQLiteManager @@ -44,7 +44,7 @@ class SQLiteAuthMethods: """ user = { "username": "admin", - "password": encode_password("admin"), + "password": hash_password("admin"), "roles": json.dumps(["admin"]), } return SQLiteAuthMethods.insert_user(user) @@ -56,7 +56,7 @@ class SQLiteAuthMethods: """ user = { "username": "guest", - "password": encode_password("guest"), + "password": hash_password("guest"), "roles": json.dumps(["guest"]), } @@ -67,7 +67,7 @@ class SQLiteAuthMethods: """ Update a user in the database. - :param user: A dict with the username, password and roles. + :param user: A dict with the user id and the fields to update. Ommited fields will not be updated. """ # get all user dict keys keys = list(user.keys()) @@ -75,6 +75,7 @@ class SQLiteAuthMethods: {', '.join([f"{key} = :{key}" for key in keys if key != 'id'])} WHERE id = :id """ + print(sql, user) with SQLiteManager(userdata_db=True) as cur: cur.execute(sql, user) diff --git a/app/db/sqlite/migrations.py b/app/db/sqlite/migrations.py index 86d36899..550673a0 100644 --- a/app/db/sqlite/migrations.py +++ b/app/db/sqlite/migrations.py @@ -7,9 +7,9 @@ from app.db.sqlite.utils import SQLiteManager class MigrationManager: @staticmethod - def get_version() -> int: + def get_index() -> int: """ - Returns the latest userdata database version. + Returns the latest databases migrations index. """ sql = "SELECT * FROM dbmigrations" with SQLiteManager() as cur: @@ -21,9 +21,9 @@ class MigrationManager: # 👇 Setters 👇 @staticmethod - def set_version(version: int): + def set_index(version: int): """ - Sets the userdata pre-init database version. + Updates the databases migrations index. """ sql = "UPDATE dbmigrations SET version = ? WHERE id = 1" with SQLiteManager() as cur: diff --git a/app/lib/artistlib.py b/app/lib/artistlib.py index f49fcf0c..b40cb56a 100644 --- a/app/lib/artistlib.py +++ b/app/lib/artistlib.py @@ -55,13 +55,22 @@ def get_artist_image_link(artist: str): # TODO: Move network calls to utils/network.py class DownloadImage: def __init__(self, url: str, name: str) -> None: - sm_path = Path(settings.Paths.get_artist_img_sm_path()) / name - lg_path = Path(settings.Paths.get_artist_img_lg_path()) / name - img = self.download(url) - if img is not None: - self.save_img(img, sm_path, lg_path) + if img is None: + return + + sm_path = Path(settings.Paths.get_sm_artist_img_path()) / name + lg_path = Path(settings.Paths.get_lg_artist_img_path()) / name + md_path = Path(settings.Paths.get_md_artist_img_path()) / name + + entries = [ + (lg_path, None), # save in the original size + (sm_path, settings.Defaults.SM_ARTIST_IMG_SIZE), + (md_path, settings.Defaults.MD_ARTIST_IMG_SIZE), + ] + + self.save_img(img, entries) @staticmethod def download(url: str) -> Image.Image | None: @@ -74,14 +83,21 @@ class DownloadImage: return None @staticmethod - def save_img(img: Image.Image, sm_path: Path, lg_path: Path): + def save_img(img: Image.Image, entries: list[tuple[Path, int | None]]): """ Saves the image to the destinations. """ - img.save(lg_path, format="webp") + ratio = img.width / img.height + for entry in entries: + path, size = entry - sm_size = settings.Defaults.SM_ARTIST_IMG_SIZE - img.resize((sm_size, sm_size), Image.ANTIALIAS).save(sm_path, format="webp") + if size is None: + img.save(path, format="webp") + continue + + img.resize((size, int(size / ratio)), Image.ANTIALIAS).save( + path, format="webp" + ) class CheckArtistImages: @@ -90,7 +106,7 @@ class CheckArtistImages: CHECK_ARTIST_IMAGES_KEY = instance_key # read all files in the artist image folder - path = settings.Paths.get_artist_img_sm_path() + path = settings.Paths.get_sm_artist_img_path() processed = "".join(os.listdir(path)).replace("webp", "") # filter out artists that already have an image @@ -126,7 +142,7 @@ class CheckArtistImages: return img_path = ( - Path(settings.Paths.get_artist_img_sm_path()) / f"{artist.artisthash}.webp" + Path(settings.Paths.get_sm_artist_img_path()) / f"{artist.artisthash}.webp" ) if img_path.exists(): diff --git a/app/lib/colorlib.py b/app/lib/colorlib.py index 17dd5c4c..086fcb2b 100644 --- a/app/lib/colorlib.py +++ b/app/lib/colorlib.py @@ -42,7 +42,7 @@ def process_color(item_hash: str, is_album=True): path = ( settings.Paths.get_sm_thumb_path() if is_album - else settings.Paths.get_artist_img_sm_path() + else settings.Paths.get_sm_artist_img_path() ) path = Path(path) / (item_hash + ".webp") diff --git a/app/lib/taglib.py b/app/lib/taglib.py index 2a523144..3d7e9f6d 100644 --- a/app/lib/taglib.py +++ b/app/lib/taglib.py @@ -33,20 +33,22 @@ def extract_thumb(filepath: str, webp_path: str, overwrite=False) -> bool: """ lg_img_path = os.path.join(Paths.get_lg_thumb_path(), webp_path) sm_img_path = os.path.join(Paths.get_sm_thumb_path(), webp_path) + xms_img_path = os.path.join(Paths.get_xsm_thumb_path(), webp_path) + md_img_path = os.path.join(Paths.get_md_thumb_path(), webp_path) - tsize = Defaults.THUMB_SIZE - sm_tsize = Defaults.SM_THUMB_SIZE + images = [ + (lg_img_path, Defaults.LG_THUMB_SIZE), + (sm_img_path, Defaults.SM_THUMB_SIZE), + (xms_img_path, Defaults.XSM_THUMB_SIZE), + (md_img_path, Defaults.MD_THUMB_SIZE), + ] def save_image(img: Image.Image): width, height = img.size ratio = width / height - img.resize((tsize, int(tsize / ratio)), Image.ANTIALIAS).save( - lg_img_path, "webp" - ) - img.resize((sm_tsize, int(sm_tsize / ratio)), Image.ANTIALIAS).save( - sm_img_path, "webp" - ) + for path, size in images: + img.resize((size, int(size / ratio)), Image.ANTIALIAS).save(path, "webp") if not overwrite and os.path.exists(sm_img_path): img_size = os.path.getsize(sm_img_path) diff --git a/app/migrations/__init__.py b/app/migrations/__init__.py index 58efc9f0..e6d5c431 100644 --- a/app/migrations/__init__.py +++ b/app/migrations/__init__.py @@ -2,12 +2,6 @@ Migrations module. Reads and applies the latest database migrations. - -PLEASE NOTE: OLDER MIGRATIONS CAN NEVER BE DELETED. -ONLY MODIFY OLD MIGRATIONS FOR BUG FIXES OR ENHANCEMENTS ONLY -[TRY NOT TO MODIFY BEHAVIOR, UNLESS YOU KNOW WHAT YOU'RE DOING]. - -PS: Fuck that! Do what you want. """ from app.db.sqlite.migrations import MigrationManager @@ -33,21 +27,32 @@ migrations: list[list[Migration]] = [ def apply_migrations(): """ Applies the latest database migrations. + + The length of all the migrations is stored in the database + and used to check for new migrations. When the length of the + migrations list is larger than the number stored in the db, + migrations past that index are applied and the new length + is stored as the new migration index. """ - version = MigrationManager.get_version() + index = MigrationManager.get_index() + all_migrations = [migration for sublist in migrations for migration in sublist] - if version != len(migrations): - # INFO: Apply new migrations - for migration in migrations[version:]: - for m in migration: - try: - m.migrate() - log.info("Applied migration: %s", m.__name__) - except: - log.error("Failed to run migration: %s", m.__name__) - - print("Migrations applied successfully.") - print("Current migration version: ", len(migrations)) - # bump migration version - MigrationManager.set_version(len(migrations)) + to_apply: list[Migration] = [] + + # if index is from old release, + # get migrations from the "migrations" list + if index < 3: + _migrations = migrations[index:] + to_apply = [migration for sublist in _migrations for migration in sublist] + else: + to_apply = all_migrations[index:] + + for migration in to_apply: + try: + migration.migrate() + log.info("Applied migration: %s", migration.__name__) + except: + log.error("Failed to run migration: %s", migration.__name__) + + MigrationManager.set_index(len(all_migrations)) diff --git a/app/migrations/v1_4_9/__init__.py b/app/migrations/v1_4_9/__init__.py index 0b37ee98..6f0de717 100644 --- a/app/migrations/v1_4_9/__init__.py +++ b/app/migrations/v1_4_9/__init__.py @@ -58,3 +58,19 @@ class DeleteOriginalThumbnails(Migration): if os.path.exists(og_imgpath): shutil.rmtree(og_imgpath) + +class DeleteOriginalThumbnailsa(Migration): + """ + Original thumbnails are too large and are not needed. + """ + + # TODO: Implement this migration + + @staticmethod + def migrate(): + imgpath = Paths.get_thumbs_path() + og_imgpath = os.path.join(imgpath, "original") + + if os.path.exists(og_imgpath): + shutil.rmtree(og_imgpath) + diff --git a/app/periodic_scan.py b/app/periodic_scan.py index bf7d8c79..7a0af7e3 100644 --- a/app/periodic_scan.py +++ b/app/periodic_scan.py @@ -13,6 +13,12 @@ from app.logger import log def run_periodic_scans(): """ Runs periodic scans. + + Periodic scans are checks that run every few minutes + in the background to do stuff like: + - checking for new music + - delete deleted entries + - downloading artist images, and other data. """ # ValidateAlbumThumbs() # ValidatePlaylistThumbs() diff --git a/app/print_help.py b/app/print_help.py index 78393340..2ce7825c 100644 --- a/app/print_help.py +++ b/app/print_help.py @@ -1,4 +1,4 @@ -from app.settings import ALLARGS +from app.settings import ALLARGS, Info from tabulate import tabulate args = ALLARGS @@ -10,6 +10,7 @@ help_args_list = [ ["--port", "", "Set the port"], ["--config", "", "Set the config path"], ["--no-periodic-scan", "-nps", "Disable periodic scan"], + ["--pswd", "", "Recover a password"], [ "--scan-interval", "-psi", @@ -23,10 +24,12 @@ help_args_list = [ ] HELP_MESSAGE = f""" -Swing Music is a beautiful, self-hosted music player for your -local audio files. Like a cooler Spotify ... but bring your own music. +Swing Music v{Info.SWINGMUSIC_APP_VERSION} -Usage: swingmusic [options] [args] +A beautiful, self-hosted music player for your local audio files. +Like Spotify ... but bring your own music. -{tabulate(help_args_list, headers=["Option", "Short", "Description"], tablefmt="simple_grid", maxcolwidths=[None, None, 40])} +Usage: ./swingmusic [options] [args] + +{tabulate(help_args_list, headers=["Option", "Alias", "Description"], tablefmt="psql", maxcolwidths=[None, None, 40])} """ diff --git a/app/settings.py b/app/settings.py index cbaf0b35..c3c308fd 100644 --- a/app/settings.py +++ b/app/settings.py @@ -45,22 +45,24 @@ class Paths: def get_img_path(cls): return join(cls.get_app_dir(), "images") + # ARTISTS @classmethod def get_artist_img_path(cls): return join(cls.get_img_path(), "artists") @classmethod - def get_artist_img_sm_path(cls): + def get_sm_artist_img_path(cls): return join(cls.get_artist_img_path(), "small") @classmethod - def get_artist_img_lg_path(cls): - return join(cls.get_artist_img_path(), "large") + def get_md_artist_img_path(cls): + return join(cls.get_artist_img_path(), "medium") @classmethod - def get_playlist_img_path(cls): - return join(cls.get_img_path(), "playlists") + def get_lg_artist_img_path(cls): + return join(cls.get_artist_img_path(), "large") + # TRACK THUMBNAILS @classmethod def get_thumbs_path(cls): return join(cls.get_img_path(), "thumbnails") @@ -69,10 +71,23 @@ class Paths: def get_sm_thumb_path(cls): return join(cls.get_thumbs_path(), "small") + @classmethod + def get_xsm_thumb_path(cls): + return join(cls.get_thumbs_path(), "xsmall") + + @classmethod + def get_md_thumb_path(cls): + return join(cls.get_thumbs_path(), "medium") + @classmethod def get_lg_thumb_path(cls): return join(cls.get_thumbs_path(), "large") + # OTHERS + @classmethod + def get_playlist_img_path(cls): + return join(cls.get_img_path(), "playlists") + @classmethod def get_assets_path(cls): return join(Paths.get_app_dir(), "assets") @@ -92,12 +107,25 @@ class Paths: # defaults class Defaults: - THUMB_SIZE = 512 - SM_THUMB_SIZE = 128 + """ + Contains default values for various settings. + + XSM_THUMB_SIZE: extra small thumbnail size for web client tracklist + SM_THUMB_SIZE: small thumbnail size for android client tracklist + MD_THUMB_SIZE: medium thumbnail size for web client album cards + LG_THUMB_SIZE: large thumbnail size for web client now playing album art + + NOTE: LG_ARTIST_IMG_SIZE is not defined as the images are saved in the original size (500px) + """ + + XSM_THUMB_SIZE = 64 + SM_THUMB_SIZE = 96 + MD_THUMB_SIZE = 256 + LG_THUMB_SIZE = 512 + SM_ARTIST_IMG_SIZE = 128 - """ - The size of extracted images in pixels - """ + MD_ARTIST_IMG_SIZE = 256 + HASH_LENGTH = 10 API_ALBUMHASH = "bfe300e966" API_ARTISTHASH = "cae59f1fc5" @@ -105,7 +133,6 @@ class Defaults: API_ALBUMNAME = "The Goat" API_ARTISTNAME = "Polo G" API_TRACKNAME = "Martin & Gina" - API_CARD_LIMIT = 6 @@ -162,6 +189,8 @@ class ALLARGS: host = "--host" config = "--config" + pswd = "--pswd" + show_feat = ("--show-feat", "-sf") show_prod = ("--show-prod", "-sp") dont_clean_albums = ("--no-clean-albums", "-nca") @@ -275,6 +304,7 @@ class Info: NOTE: This class initially written to load keys when running in build mode. TODO: Remove this class entirely, and implement functionality where needed. """ + SWINGMUSIC_APP_VERSION = os.environ.get("SWINGMUSIC_APP_VERSION") GIT_LATEST_COMMIT_HASH = "" GIT_CURRENT_BRANCH = "" @@ -289,12 +319,6 @@ class Info: cls.GIT_LATEST_COMMIT_HASH = getLatestCommitHash() cls.GIT_CURRENT_BRANCH = getCurrentBranch() - cls.verify_keys() - - @classmethod - def verify_keys(cls): - pass - @classmethod def get(cls, key: str): return getattr(cls, key, None) diff --git a/app/setup/__init__.py b/app/setup/__init__.py index 4f784bea..36e7e356 100644 --- a/app/setup/__init__.py +++ b/app/setup/__init__.py @@ -14,6 +14,9 @@ from app.config import UserConfig def run_setup(): + """ + Creates the config directory, runs migrations, and loads settings. + """ create_config_dir() # setup config file @@ -32,6 +35,11 @@ def run_setup(): # settings table is empty pass + +def load_into_mem(): + """ + Load all tracks, albums, and artists into memory. + """ instance_key = get_random_str() # INFO: Load all tracks, albums, and artists into memory diff --git a/app/setup/files.py b/app/setup/files.py index 46f6a001..4e875750 100644 --- a/app/setup/files.py +++ b/app/setup/files.py @@ -51,28 +51,28 @@ def create_config_dir() -> None: """ Creates the config directory if it doesn't exist. """ - thumb_path = os.path.join("images", "thumbnails") - small_thumb_path = os.path.join(thumb_path, "small") - large_thumb_path = os.path.join(thumb_path, "large") + sm_thumb_path = settings.Paths.get_sm_thumb_path() + lg_thumb_path = settings.Paths.get_lg_thumb_path() + md_thumb_path = settings.Paths.get_md_thumb_path() + xsm_thumb_path = settings.Paths.get_xsm_thumb_path() - artist_img_path = os.path.join("images", "artists") - small_artist_img_path = os.path.join(artist_img_path, "small") - large_artist_img_path = os.path.join(artist_img_path, "large") + small_artist_img_path = settings.Paths.get_sm_artist_img_path() + md_artist_img_path = settings.Paths.get_md_artist_img_path() + large_artist_img_path = settings.Paths.get_lg_artist_img_path() playlist_img_path = os.path.join("images", "playlists") dirs = [ "", # creates the config folder - "images", - "plugins", + sm_thumb_path, + lg_thumb_path, + md_thumb_path, + xsm_thumb_path, "plugins/lyrics", - thumb_path, - small_thumb_path, - large_thumb_path, - artist_img_path, + playlist_img_path, + md_artist_img_path, small_artist_img_path, large_artist_img_path, - playlist_img_path, ] for _dir in dirs: @@ -80,7 +80,9 @@ def create_config_dir() -> None: exists = os.path.exists(path) if not exists: - os.makedirs(path) + # exist_ok=True to create parent directories if they don't exist + os.makedirs(path, exist_ok=True) os.chmod(path, 0o755) + # copy assets to the app directory CopyFiles() diff --git a/app/start_info_logger.py b/app/start_info_logger.py index e27a361e..5b497341 100644 --- a/app/start_info_logger.py +++ b/app/start_info_logger.py @@ -10,7 +10,7 @@ def log_startup_info(): # os.system("cls" if os.name == "nt" else "echo -e \\\\033c") print(lines) - print(f"{TCOLOR.HEADER}SwingMusic {Info.SWINGMUSIC_APP_VERSION} {TCOLOR.ENDC}") + print(f"{TCOLOR.HEADER}Swing Music v{Info.SWINGMUSIC_APP_VERSION} {TCOLOR.ENDC}") adresses = [FLASKVARS.get_flask_host()] diff --git a/app/utils/auth.py b/app/utils/auth.py index e8b909af..87f20c75 100644 --- a/app/utils/auth.py +++ b/app/utils/auth.py @@ -4,13 +4,13 @@ import hashlib from app.config import UserConfig -def encode_password(password: str) -> str: +def hash_password(password: str) -> str: """ - This function encodes the given password. + Hashes the given password using sha256 algorithm and the user id as salt. - :param password: The password to encode. + :param password: The password to hash. - :return: The encoded password. + :return: The hashed password. """ return hashlib.pbkdf2_hmac( @@ -18,14 +18,14 @@ def encode_password(password: str) -> str: ).hex() -def check_password(password: str, encoded: str) -> bool: +def check_password(password: str, hashed: str) -> bool: """ - This function checks if the given password matches the encoded password. + This function checks if the given password matches the hashed password. :param password: The password to check. - :param encoded: The encoded password. + :param hashed: The hashed password. :return: Whether the password matches. """ - return hmac.compare_digest(encode_password(password), encoded) + return hmac.compare_digest(hash_password(password), hashed) diff --git a/manage.py b/manage.py index 4210463e..f27c26bd 100644 --- a/manage.py +++ b/manage.py @@ -25,7 +25,7 @@ from app.lib.watchdogg import Watcher as WatchDog from app.periodic_scan import run_periodic_scans from app.plugins.register import register_plugins from app.settings import FLASKVARS, TCOLOR, Info -from app.setup import run_setup +from app.setup import load_into_mem, run_setup from app.start_info_logger import log_startup_info from app.utils.filesystem import get_home_res_path from app.utils.paths import getClientFilesExtensions @@ -47,10 +47,9 @@ mimetypes.add_type("application/manifest+json", ".webmanifest") werkzeug = logging.getLogger("werkzeug") werkzeug.setLevel(logging.ERROR) - # Background tasks @background -def bg_run_setup() -> None: +def bg_run_setup(): run_periodic_scans() @@ -71,9 +70,10 @@ def run_swingmusic(): # Setup function calls +Info.load() ProcessArgs() run_setup() -Info.load() +load_into_mem() run_swingmusic()