add json config and its manager class

+ rewrite logic to prevent removing last admin role
+ handle showing users on login and enabling guest
This commit is contained in:
mungai-njoroge
2024-04-29 16:31:30 +03:00
parent 0ff5661765
commit cfeff7ff51
6 changed files with 234 additions and 23 deletions
+95 -19
View File
@@ -1,14 +1,21 @@
import json import json
from dataclasses import asdict from dataclasses import asdict
from functools import wraps from functools import wraps
import sqlite3
from flask import jsonify from flask import jsonify
from flask_jwt_extended import create_access_token, current_user, set_access_cookies from flask_jwt_extended import (
create_access_token,
current_user,
jwt_required,
set_access_cookies,
)
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from flask_openapi3 import Tag from flask_openapi3 import Tag
from flask_openapi3 import APIBlueprint from flask_openapi3 import APIBlueprint
from app.db.sqlite.auth import SQLiteAuthMethods as authdb from app.db.sqlite.auth import SQLiteAuthMethods as authdb
from app.utils.auth import check_password, encode_password from app.utils.auth import check_password, encode_password
from app.config import UserConfig
bp_tag = Tag(name="Auth", description="Authentication stuff") bp_tag = Tag(name="Auth", description="Authentication stuff")
api = APIBlueprint("auth", __name__, url_prefix="/auth", abp_tags=[bp_tag]) api = APIBlueprint("auth", __name__, url_prefix="/auth", abp_tags=[bp_tag])
@@ -51,7 +58,7 @@ def login(body: LoginBody):
password_ok = check_password(body.password, user.password) password_ok = check_password(body.password, user.password)
if not password_ok: if not password_ok:
return {"msg": "Invalid password"}, 401 return {"msg": "Hehe! invalid password"}, 401
access_token = create_access_token(identity=user.todict()) access_token = create_access_token(identity=user.todict())
set_access_cookies(res, access_token) set_access_cookies(res, access_token)
@@ -59,38 +66,55 @@ def login(body: LoginBody):
class UpdateProfileBody(BaseModel): class UpdateProfileBody(BaseModel):
id: int = Field(0, description="The user id")
email: str = Field("", description="The email") email: str = Field("", description="The email")
username: str = Field("", description="The username", example="user0") username: str = Field("", description="The username", example="user0")
password: str = Field("", description="The password", example="password0") password: str = Field("", description="The password", example="password0")
roles: list[str] = Field([], description="The roles") roles: list[str] = Field(None, description="The roles")
@api.put("/profile/update") @api.put("/profile/update")
def update_profile(body: UpdateProfileBody): def update_profile(body: UpdateProfileBody):
user = { user = {
"id": current_user["id"], "id": body.id,
"email": body.email, "email": body.email,
"username": body.username, "username": body.username,
"password": body.password, "password": body.password,
"roles": body.roles, "roles": body.roles,
} }
# if not id, update self
if not user["id"]:
user["id"] = current_user["id"]
print("current_user: ", current_user)
# only admins can update roles # only admins can update roles
if body.roles:
if "admin" in current_user["roles"]: if body.roles is not None:
# prevent admin from locking themselves out if "admin" not in current_user["roles"]:
roles = set(body.roles) return {"msg": "Only admins can update roles"}, 403
roles.add("admin")
user["roles"] = json.dumps(list(roles)) if "admin" not in body.roles:
else: # check if we're removing the last admin
user.pop("roles") users = authdb.get_all_users()
admins = [user for user in users if "admin" in user.roles]
if len(admins) == 1 and admins[0].id == user["id"]:
return {"msg": "Cannot remove the only admin"}, 400
user["roles"] = json.dumps(body.roles)
if user["password"]: if user["password"]:
user["password"] = encode_password(user["password"]) user["password"] = encode_password(user["password"])
# remove empty values # remove empty values
clean_user = {k: v for k, v in user.items() if v} clean_user = {k: v for k, v in user.items() if v}
return authdb.update_user(clean_user)
try:
return authdb.update_user(clean_user)
except sqlite3.IntegrityError:
return {"msg": "Username already exists"}, 400
@api.post("/profile/create") @api.post("/profile/create")
@@ -150,7 +174,7 @@ def delete_user(body: DeleteUseBody):
Delete a user by username Delete a user by username
""" """
# prevent admin from deleting themselves # prevent admin from deleting themselves
if body.username == current_user["usrname"]: if body.username == current_user["username"]:
return {"msg": "Sorry! you cannot delete yourselfu"}, 400 return {"msg": "Sorry! you cannot delete yourselfu"}, 400
# prevent deleting the only admin # prevent deleting the only admin
@@ -160,6 +184,11 @@ def delete_user(body: DeleteUseBody):
return {"msg": "Cannot delete the only admin"}, 400 return {"msg": "Cannot delete the only admin"}, 400
authdb.delete_user_by_username(body.username) authdb.delete_user_by_username(body.username)
# if user is guest, update config
if body.username == "guest":
UserConfig().enableGuest = False
return {"msg": f"User {body.username} deleted"} return {"msg": f"User {body.username} deleted"}
@@ -180,14 +209,51 @@ class GetAllUsersQuery(BaseModel):
@api.get("/users") @api.get("/users")
@jwt_required(optional=True)
def get_all_users(query: GetAllUsersQuery): def get_all_users(query: GetAllUsersQuery):
""" """
Get all users Get all users (if you're an admin, you will also receive accounts settings)
""" """
config = UserConfig()
# config.enableGuest = True
# config.usersOnLogin = True
settings = {
"enableGuest": config.enableGuest,
"usersOnLogin": config.usersOnLogin,
}
res = {
"settings": {},
"users": [],
}
# if user is admin, also return settings
if current_user and "admin" in current_user["roles"]:
res = {
"settings": settings,
}
# if is normal user, return empty response
elif current_user:
return res
# if not logged in and showing users on login is disabled, return empty response
elif (
not current_user
and not settings["usersOnLogin"]
and not settings["enableGuest"]
):
return res
users = authdb.get_all_users() users = authdb.get_all_users()
# remove guest user # remove guest user
users = [user for user in users if user.username != "guest"] print("settings: ", settings["enableGuest"])
if not settings["enableGuest"]:
users = [user for user in users if user.username != "guest"]
if not settings["usersOnLogin"]:
users = [user for user in users if user.username == "guest"]
# reverse list to show latest users first # reverse list to show latest users first
users = list(reversed(users)) users = list(reversed(users))
@@ -195,10 +261,20 @@ def get_all_users(query: GetAllUsersQuery):
# bring admins to the front # bring admins to the front
users = sorted(users, key=lambda x: "admin" in x.roles, reverse=True) users = sorted(users, key=lambda x: "admin" in x.roles, reverse=True)
if query.simplified: # bring current user to index 0
return [user.todict_simplified() for user in users] if current_user:
users = sorted(
users,
key=lambda x: x.username == current_user["username"],
reverse=True,
)
return [user.todict() for user in users] if query.simplified:
res["users"] = [user.todict_simplified() for user in users]
res["users"] = [user.todict() for user in users]
return res
@api.route("/user") @api.route("/user")
+28 -2
View File
@@ -3,6 +3,7 @@ from flask import request
from flask_openapi3 import Tag from flask_openapi3 import Tag
from flask_openapi3 import APIBlueprint from flask_openapi3 import APIBlueprint
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.api.auth import admin_required
from app.db.sqlite.plugins import PluginsMethods as pdb from app.db.sqlite.plugins import PluginsMethods as pdb
from app.db.sqlite.settings import SettingsSQLMethods as sdb from app.db.sqlite.settings import SettingsSQLMethods as sdb
@@ -15,6 +16,7 @@ from app.store.artists import ArtistStore
from app.store.tracks import TrackStore from app.store.tracks import TrackStore
from app.utils.generators import get_random_str from app.utils.generators import get_random_str
from app.utils.threading import background from app.utils.threading import background
from app.config import UserConfig
bp_tag = Tag(name="Settings", description="Customize stuff") bp_tag = Tag(name="Settings", description="Customize stuff")
api = APIBlueprint("settings", __name__, url_prefix="/notsettings", abp_tags=[bp_tag]) api = APIBlueprint("settings", __name__, url_prefix="/notsettings", abp_tags=[bp_tag])
@@ -69,7 +71,7 @@ def rebuild_store(db_dirs: list[str]):
log.info("Rebuilding library... ✅") log.info("Rebuilding library... ✅")
# I freaking don't know what this function does anymore # I freaking don't know what this function does anymore
def finalize(new_: list[str], removed_: list[str], db_dirs_: list[str]): def finalize(new_: list[str], removed_: list[str], db_dirs_: list[str]):
""" """
Params: Params:
@@ -162,7 +164,6 @@ mapp = {
@api.get("") @api.get("")
def get_all_settings(): def get_all_settings():
""" """
Get all settings Get all settings
@@ -265,3 +266,28 @@ def trigger_scan():
run_populate() run_populate()
return {"msg": "Scan triggered!"} return {"msg": "Scan triggered!"}
class UpdateConfigBody(BaseModel):
key: str = Field(
description="The setting key",
example="usersOnLogin",
)
value: Any = Field(
description="The setting value",
example=False,
)
@api.put("/update")
@admin_required()
def update_config(body: UpdateConfigBody):
"""
Update the config file
"""
config = UserConfig()
setattr(config, body.key, body.value)
return {
"msg": "Config updated!",
}
+97
View File
@@ -0,0 +1,97 @@
from dataclasses import dataclass, asdict, field
import json
import os
import time
from typing import Any
from .settings import Paths
# TODO: Publish this on PyPi
@dataclass
class UserConfig:
_config_path: str = ""
# NOTE: only auth stuff are used (the others are still reading/writing to db)
# TODO: Move the rest of the settings to the config file
# auth stuff
usersOnLogin: bool = True
enableGuest: bool = False
# lists
rootDirs: list[str] = field(default_factory=list)
excludeDirs: list[str] = field(default_factory=list)
artistSeparators: set[str] = field(default_factory=list)
# tracks
extractFeaturedArtists: bool = True
removeProdBy: bool = True
removeRemasterInfo: bool = True
# albums
mergeAlbums: bool = False
cleanAlbumTitle: bool = True
showAlbumsAsSingles: bool = False
def __post_init__(self):
"""
Loads the config file and sets the values to this instance
"""
# set config path locally to avoid writing to file
config_path = Paths.get_config_file_path()
try:
config = self.load_config(config_path)
except FileNotFoundError:
self._config_path = config_path
return
# loop through the config file and set the values
for key, value in config.items():
setattr(self, key, value)
# finally set the config path
self._config_path = config_path
def setup_config_file(self) -> None:
"""
Creates the config file with the default settings
if it doesn't exist
"""
print("config path: ", self._config_path)
# if not exists, create the config file
if not os.path.exists(self._config_path):
self.write_to_file(asdict(self))
def load_config(self, path: str) -> dict[str, Any]:
"""
Reads the settings from the config file.
Returns a dictget_root_dirs
"""
with open(path, "r") as f:
settings = json.load(f)
return settings
def write_to_file(self, settings: dict[str, Any]):
"""
Writes the settings to the config file
"""
# remove internal attributes
settings = {k: v for k, v in settings.items() if not k.startswith("_")}
with open(self._config_path, "w") as f:
json.dump(settings, f, indent=4)
def __setattr__(self, key: str, value: Any) -> None:
"""
Writes to the config file whenever a value is set
"""
super().__setattr__(key, value)
# if is internal attribute, don't write to file
if key.startswith("_") or not self._config_path:
return
print(f"writing to file: {key}={value}")
self.write_to_file(asdict(self))
+7 -1
View File
@@ -85,6 +85,10 @@ class Paths:
def get_lyrics_plugins_path(cls): def get_lyrics_plugins_path(cls):
return join(Paths.get_plugins_path(), "lyrics") return join(Paths.get_plugins_path(), "lyrics")
@classmethod
def get_config_file_path(cls):
return join(cls.get_app_dir(), "settings.json")
# defaults # defaults
class Defaults: class Defaults:
@@ -268,7 +272,9 @@ class Keys:
SWINGMUSIC_APP_VERSION = os.environ.get("SWINGMUSIC_APP_VERSION") SWINGMUSIC_APP_VERSION = os.environ.get("SWINGMUSIC_APP_VERSION")
GIT_LATEST_COMMIT_HASH = "<unset>" GIT_LATEST_COMMIT_HASH = "<unset>"
GIT_CURRENT_BRANCH = "<unset>" GIT_CURRENT_BRANCH = "<unset>"
JWT_SECRET_KEY = "swingmusic_secret_key" # REVIEW: This should be set in the environment JWT_SECRET_KEY = (
"swingmusic_secret_key" # REVIEW: This should be set in the environment
)
@classmethod @classmethod
def load(cls): def load(cls):
+6
View File
@@ -2,6 +2,7 @@
Prepares the server for use. Prepares the server for use.
""" """
from dataclasses import asdict
from app.db.sqlite.settings import load_settings from app.db.sqlite.settings import load_settings
from app.setup.files import create_config_dir from app.setup.files import create_config_dir
from app.setup.sqlite import run_migrations, setup_sqlite from app.setup.sqlite import run_migrations, setup_sqlite
@@ -9,6 +10,7 @@ from app.store.albums import AlbumStore
from app.store.artists import ArtistStore from app.store.artists import ArtistStore
from app.store.tracks import TrackStore from app.store.tracks import TrackStore
from app.utils.generators import get_random_str from app.utils.generators import get_random_str
from app.config import UserConfig
def run_setup(): def run_setup():
@@ -22,6 +24,10 @@ def run_setup():
# settings table is empty # settings table is empty
pass pass
# setup config file
config = UserConfig()
config.setup_config_file()
instance_key = get_random_str() instance_key = get_random_str()
# INFO: Load all tracks, albums, and artists into memory # INFO: Load all tracks, albums, and artists into memory
+1 -1
View File
@@ -44,7 +44,7 @@ app = create_api()
app.static_folder = get_home_res_path("client") app.static_folder = get_home_res_path("client")
# INFO: Routes that don't need authentication # INFO: Routes that don't need authentication
blacklist_routes = {"/auth/login", "/auth/users"} blacklist_routes = {"/auth/login", "/auth/users", "/auth/logout"}
blacklist_extensions = {".webp"}.union(getClientFilesExtensions()) blacklist_extensions = {".webp"}.union(getClientFilesExtensions())