diff --git a/app/api/__init__.py b/app/api/__init__.py index 3f6aa250..5c3e7f64 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -12,6 +12,7 @@ from flask_jwt_extended import JWTManager from app.settings import Keys from .plugins import lyrics as lyrics_plugin +from app.db.sqlite.auth import SQLiteAuthMethods as authdb from app.api import ( album, artist, @@ -80,14 +81,16 @@ def create_api(): # JWT jwt = JWTManager(app) - @jwt.user_identity_loader - def user_identity_lookup(user): - return user + # @jwt.user_identity_loader + # def user_identity_lookup(user): + # return user @jwt.user_lookup_loader def user_lookup_callback(_jwt_header, jwt_data): identity = jwt_data["sub"] - return identity + userid = identity["id"] + user = authdb.get_user_by_id(userid) + return user.todict() # Register all the API blueprints with app.app_context(): diff --git a/app/api/auth.py b/app/api/auth.py index b903f1e6..a22884fc 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -1,4 +1,5 @@ from dataclasses import asdict +import json from flask import jsonify from flask_jwt_extended import create_access_token, current_user, set_access_cookies from pydantic import BaseModel, Field @@ -39,6 +40,39 @@ def login(body: LoginBody): return res +class UpdateProfileBody(BaseModel): + email: str = Field("", description="The email") + username: str = Field("", description="The username", example="user0") + password: str = Field("", description="The password", example="password0") + roles: list[str] = Field([], description="The roles") + + +@api.put("/profile/update") +def update_profile(body: UpdateProfileBody): + + user = { + "id": current_user["id"], + "email": body.email, + "username": body.username, + "password": body.password, + "roles": body.roles, + } + + # only admins can update roles + if body.roles: + if "admin" in current_user["roles"]: + # prevent admin from locking themselves out + roles = set(body.roles) + roles.add("admin") + user["roles"] = json.dumps(list(roles)) + else: + user.pop("roles") + + # remove empty values + clean_user = {k: v for k, v in user.items() if v} + return authdb.update_user(clean_user) + + @api.get("/logout") def logout(): """ diff --git a/app/db/sqlite/auth.py b/app/db/sqlite/auth.py index c637de54..83821500 100644 --- a/app/db/sqlite/auth.py +++ b/app/db/sqlite/auth.py @@ -9,6 +9,33 @@ class SQLiteAuthMethods: Methods for authenticating users. """ + @staticmethod + def insert_user(user: dict[str, str]): + """ + Insert a user into the database. + + :param user: A dict with the username, password and roles. + """ + sql = """INSERT INTO users( + username, + password, + roles + ) VALUES(:username, :password, :roles) + """ + + user_tuple = tuple(user.values()) + + with SQLiteManager(userdata_db=True) as cur: + cur = cur.execute(sql, user_tuple) + userid = cur.lastrowid + + if userid: + user = SQLiteAuthMethods.get_user_by_id(userid).todict_simplified() + cur.close() + return user + + raise Exception(f"Failed to insert user: {user}") + @staticmethod def insert_default_user(): """ @@ -19,18 +46,7 @@ class SQLiteAuthMethods: "password": encode_password("admin"), "roles": json.dumps(["admin"]), } - user_tuple = tuple(user.values()) - - sql = """INSERT INTO users( - username, - password, - roles - ) VALUES(:username, :password, :roles) - """ - - with SQLiteManager(userdata_db=True) as cur: - cur.execute(sql, user_tuple) - cur.close() + return SQLiteAuthMethods.insert_user(user) @staticmethod def insert_guest_user(): @@ -40,25 +56,31 @@ class SQLiteAuthMethods: user = { "username": "guest", "password": encode_password("guest"), - "firstname": "Guest", - "lastname": "User", "roles": json.dumps(["guest"]), } - user_tuple = tuple(user.values()) - sql = """INSERT INTO users( - username, - password, - firstname, - lastname, - roles - ) VALUES(:username, :password, :firstname, :lastname, :roles) + return SQLiteAuthMethods.insert_user(user) + + @staticmethod + def update_user(user: dict[str, str]): + """ + Update a user in the database. + + :param user: A dict with the username, password and roles. + """ + # get all user dict keys + keys = list(user.keys()) + sql = f"""UPDATE users SET + {', '.join([f"{key} = :{key}" for key in keys if key != 'id'])} + WHERE id = :id """ with SQLiteManager(userdata_db=True) as cur: - cur.execute(sql, user_tuple) + cur.execute(sql, user) cur.close() + return SQLiteAuthMethods.get_user_by_id(user["id"]).todict() + @staticmethod def get_all_users(): """ @@ -90,4 +112,22 @@ class SQLiteAuthMethods: if data is not None: return User(*data) - return None \ No newline at end of file + return None + + @staticmethod + def get_user_by_id(userid: int): + """ + Get a user by id. + """ + sql = "SELECT * FROM users WHERE id = ?" + + with SQLiteManager(userdata_db=True) as cur: + cur.execute(sql, (userid,)) + + data = cur.fetchone() + cur.close() + + if data is not None: + return User(*data) + + return None diff --git a/manage.py b/manage.py index dc852b7c..de855e35 100644 --- a/manage.py +++ b/manage.py @@ -53,7 +53,6 @@ def verify_auth(): """ Verifies the JWT token before each request. """ - print(request.path) if request.path == "/" or any( request.path.endswith(ext) for ext in blacklist_extensions ): @@ -66,8 +65,7 @@ def verify_auth(): ) return - data = verify_jwt_in_request() - print(data) + verify_jwt_in_request() @app.route("/")