protect settings write routes

+ prevent updating guest user
+ add docs to whitelisted auth routes
+ fix: sort in get all route
+ fix: folders not having trailing slash in recentlyplayed
This commit is contained in:
mungai-njoroge
2024-05-01 23:44:38 +03:00
parent cfeff7ff51
commit 5d947f3ad9
6 changed files with 37 additions and 23 deletions
+19 -11
View File
@@ -83,26 +83,33 @@ def update_profile(body: UpdateProfileBody):
"roles": body.roles, "roles": body.roles,
} }
# prevent updating guest
if current_user["username"] == "guest" or user["username"] == "guest":
return {"msg": "Cannot update guest user"}, 400
# if not id, update self # if not id, update self
if not user["id"]: if not user["id"]:
user["id"] = current_user["id"] user["id"] = current_user["id"]
print("current_user: ", current_user)
# only admins can update roles
if body.roles is not None: if body.roles is not None:
# only admins can update roles
if "admin" not in current_user["roles"]: if "admin" not in current_user["roles"]:
return {"msg": "Only admins can update roles"}, 403 return {"msg": "Only admins can update roles"}, 403
all_users = authdb.get_all_users()
if "admin" not in body.roles: if "admin" not in body.roles:
# check if we're removing the last admin # check if we're removing the last admin
users = authdb.get_all_users() admins = [user for user in all_users if "admin" in user.roles]
admins = [user for user in users if "admin" in user.roles]
if len(admins) == 1 and admins[0].id == user["id"]: if len(admins) == 1 and admins[0].id == user["id"]:
return {"msg": "Cannot remove the only admin"}, 400 return {"msg": "Cannot remove the only admin"}, 400
# guest roles cannot be updated
_user = [u for u in all_users if u.id == user["id"]][0]
if "guest" in _user.roles:
return {"msg": "Cannot update guest user"}, 400
# finally, convert roles to json string
user["roles"] = json.dumps(body.roles) user["roles"] = json.dumps(body.roles)
if user["password"]: if user["password"]:
@@ -227,8 +234,10 @@ def get_all_users(query: GetAllUsersQuery):
"users": [], "users": [],
} }
is_admin = current_user and "admin" in current_user["roles"]
# if user is admin, also return settings # if user is admin, also return settings
if current_user and "admin" in current_user["roles"]: if is_admin:
res = { res = {
"settings": settings, "settings": settings,
} }
@@ -248,11 +257,10 @@ def get_all_users(query: GetAllUsersQuery):
users = authdb.get_all_users() users = authdb.get_all_users()
# remove guest user # remove guest user
print("settings: ", settings["enableGuest"]) # if not settings["enableGuest"]:
if not settings["enableGuest"]: # users = [user for user in users if user.username != "guest"]
users = [user for user in users if user.username != "guest"]
if not settings["usersOnLogin"]: if not is_admin or not settings["usersOnLogin"]:
users = [user for user in users if user.username == "guest"] users = [user for user in users if user.username == "guest"]
# reverse list to show latest users first # reverse list to show latest users first
+5 -8
View File
@@ -22,7 +22,7 @@ bp_tag = Tag(name="Get all", description="List all items")
api = APIBlueprint("getall", __name__, url_prefix="/getall", abp_tags=[bp_tag]) api = APIBlueprint("getall", __name__, url_prefix="/getall", abp_tags=[bp_tag])
class GetAllItemsBody(GenericLimitSchema): class GetAllItemsQuery(GenericLimitSchema):
start: int = Field( start: int = Field(
description="The start index of the items to return", description="The start index of the items to return",
example=0, example=0,
@@ -34,10 +34,10 @@ class GetAllItemsBody(GenericLimitSchema):
default="created_date", default="created_date",
) )
reverse: int = Field( reverse: str = Field(
description="Reverse the sort", description="Reverse the sort",
example=1, example=1,
default=1, default="1",
) )
@@ -50,7 +50,7 @@ class GetAllItemsPath(BaseModel):
@api.get("/<itemtype>") @api.get("/<itemtype>")
def get_all_items(path: GetAllItemsPath, query: GetAllItemsBody): def get_all_items(path: GetAllItemsPath, query: GetAllItemsQuery):
""" """
Get all items Get all items
@@ -67,10 +67,7 @@ def get_all_items(path: GetAllItemsPath, query: GetAllItemsBody):
start = query.start start = query.start
limit = query.limit limit = query.limit
sort = query.sortby sort = query.sortby
reverse = query.reverse == 1 reverse = query.reverse == "1"
# if sort == "":
# sort = "created_date"
sort_is_count = sort == "count" sort_is_count = sort == "count"
sort_is_duration = sort == "duration" sort_is_duration = sort == "duration"
+3
View File
@@ -3,6 +3,7 @@ from flask import Blueprint, request
from flask_openapi3 import Tag from flask_openapi3 import Tag
from flask_openapi3 import APIBlueprint from flask_openapi3 import APIBlueprint
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.api.auth import admin_required
from app.db.sqlite.plugins import PluginsMethods from app.db.sqlite.plugins import PluginsMethods
bp_tag = Tag(name="Plugins", description="Manage plugins") bp_tag = Tag(name="Plugins", description="Manage plugins")
@@ -30,6 +31,7 @@ class PluginActivateBody(PluginBody):
@api.post("/setactive") @api.post("/setactive")
@admin_required()
def activate_deactivate_plugin(body: PluginActivateBody): def activate_deactivate_plugin(body: PluginActivateBody):
""" """
Activate/Deactivate plugin Activate/Deactivate plugin
@@ -49,6 +51,7 @@ class PluginSettingsBody(PluginBody):
@api.post("/settings") @api.post("/settings")
@admin_required()
def update_plugin_settings(body: PluginSettingsBody): def update_plugin_settings(body: PluginSettingsBody):
""" """
Update plugin settings Update plugin settings
+2
View File
@@ -97,6 +97,7 @@ class AddRootDirsBody(BaseModel):
@api.post("/add-root-dirs") @api.post("/add-root-dirs")
@admin_required()
def add_root_dirs(body: AddRootDirsBody): def add_root_dirs(body: AddRootDirsBody):
""" """
Add custom root directories to the database. Add custom root directories to the database.
@@ -216,6 +217,7 @@ class SetSettingBody(BaseModel):
@api.post("/set") @api.post("/set")
@admin_required()
def set_setting(body: SetSettingBody): def set_setting(body: SetSettingBody):
""" """
Set a setting. Set a setting.
+5 -1
View File
@@ -82,9 +82,13 @@ def get_recently_played(limit=7):
if entry.type == "folder": if entry.type == "folder":
folder = entry.type_src folder = entry.type_src
if not folder: if not folder:
continue continue
if not folder.endswith("/"):
folder += "/"
is_home_dir = entry.type_src == "$home" is_home_dir = entry.type_src == "$home"
if is_home_dir: if is_home_dir:
@@ -98,7 +102,7 @@ def get_recently_played(limit=7):
{ {
"type": "folder", "type": "folder",
"item": { "item": {
"path": entry.type_src, "path": folder,
"count": count, "count": count,
"help_text": "folder", "help_text": "folder",
"time": timestamp_to_time_passed(entry.timestamp), "time": timestamp_to_time_passed(entry.timestamp),
+3 -3
View File
@@ -44,7 +44,7 @@ app = create_api()
app.static_folder = get_home_res_path("client") app.static_folder = get_home_res_path("client")
# INFO: Routes that don't need authentication # INFO: Routes that don't need authentication
blacklist_routes = {"/auth/login", "/auth/users", "/auth/logout"} whitelisted_routes = {"/auth/login", "/auth/users", "/auth/logout", "/docs"}
blacklist_extensions = {".webp"}.union(getClientFilesExtensions()) blacklist_extensions = {".webp"}.union(getClientFilesExtensions())
@@ -59,9 +59,9 @@ def verify_auth():
return return
# if request path starts with any of the blacklisted routes, don't verify jwt # if request path starts with any of the blacklisted routes, don't verify jwt
if any(request.path.startswith(route) for route in blacklist_routes): if any(request.path.startswith(route) for route in whitelisted_routes):
# print( # print(
# "Found blacklisted route: ", request.path, "... Skipping jwt verification" # "Found whitelisted route: ", request.path, "... Skipping jwt verification"
# ) # )
return return