set up auth

This commit is contained in:
mungai-njoroge
2024-04-25 18:18:52 +03:00
parent b1de2c7321
commit 04957dd5a9
15 changed files with 350 additions and 6 deletions
+28 -4
View File
@@ -8,6 +8,7 @@ from flask_compress import Compress
from flask_openapi3 import Info from flask_openapi3 import Info
from flask_openapi3 import OpenAPI from flask_openapi3 import OpenAPI
from flask_jwt_extended import JWTManager
from app.settings import Keys from app.settings import Keys
from .plugins import lyrics as lyrics_plugin from .plugins import lyrics as lyrics_plugin
@@ -27,6 +28,7 @@ from app.api import (
logger, logger,
home, home,
getall, getall,
auth,
) )
# TODO: Move this description to a separate file # TODO: Move this description to a separate file
@@ -60,13 +62,34 @@ def create_api():
app = OpenAPI(__name__, info=api_info, doc_prefix="/docs") app = OpenAPI(__name__, info=api_info, doc_prefix="/docs")
CORS(app, origins="*") # JWT CONFIGS
Compress(app) app.config["JWT_SECRET_KEY"] = Keys.JWT_SECRET_KEY
app.config["JWT_TOKEN_LOCATION"] = ["cookies"]
app.config["JWT_COOKIE_CSRF_PROTECT"] = False
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = datetime.timedelta(days=1)
# CORS
CORS(app, origins="*", supports_credentials=True)
# RESPONSE COMPRESSION
Compress(app)
app.config["COMPRESS_MIMETYPES"] = [ app.config["COMPRESS_MIMETYPES"] = [
"application/json", "application/json",
] ]
# JWT
jwt = JWTManager(app)
@jwt.user_identity_loader
def user_identity_lookup(user):
return user
@jwt.user_lookup_loader
def user_lookup_callback(_jwt_header, jwt_data):
identity = jwt_data["sub"]
return identity
# Register all the API blueprints
with app.app_context(): with app.app_context():
app.register_api(album.api) app.register_api(album.api)
app.register_api(artist.api) app.register_api(artist.api)
@@ -89,8 +112,9 @@ def create_api():
# Home # Home
app.register_api(home.api) app.register_api(home.api)
# Flask Restful
app.register_api(getall.api) app.register_api(getall.api)
# Auth
app.register_api(auth.api)
return app return app
+67
View File
@@ -0,0 +1,67 @@
from dataclasses import asdict
from flask import jsonify
from flask_jwt_extended import create_access_token, current_user, set_access_cookies
from pydantic import BaseModel, Field
from flask_openapi3 import Tag
from flask_openapi3 import APIBlueprint
from app.db.sqlite.auth import SQLiteAuthMethods as authdb
from app.utils.auth import check_password
bp_tag = Tag(name="Auth", description="Authentication")
api = APIBlueprint("auth", __name__, url_prefix="/auth", abp_tags=[bp_tag])
class LoginBody(BaseModel):
username: str = Field(description="The username", example="user0")
password: str = Field(description="The password", example="password0")
@api.post("/login")
def login(body: LoginBody):
"""
Authenticate using username and password
"""
res = jsonify({"msg": f"Logged in as {body.username}"})
user = authdb.get_user_by_username(body.username)
if user is None:
return {"msg": "User not found"}, 404
password_ok = check_password(body.password, user.password)
if not password_ok:
return {"msg": "Invalid password"}, 401
access_token = create_access_token(identity=user.todict())
set_access_cookies(res, access_token)
return res
@api.get("/logout")
def logout():
"""
Log out
"""
res = jsonify({"msg": "Logged out"})
res.delete_cookie("access_token_cookie")
return res
@api.get("/users")
def get_all_users():
"""
Get all users
"""
users = authdb.get_all_users()
return [user.todict_simplified() for user in users]
@api.route("/user")
def get_logged_in_user():
"""
Get logged in user
"""
print("current_user", current_user)
return dict(current_user)
+1
View File
@@ -162,6 +162,7 @@ mapp = {
@api.get("") @api.get("")
def get_all_settings(): def get_all_settings():
""" """
Get all settings Get all settings
+93
View File
@@ -0,0 +1,93 @@
import json
from app.models.user import User
from app.utils.auth import encode_password
from app.db.sqlite.utils import SQLiteManager
class SQLiteAuthMethods:
"""
Methods for authenticating users.
"""
@staticmethod
def insert_default_user():
"""
Inserts the default admin user.
"""
user = {
"username": "admin",
"password": encode_password("admin"),
"roles": json.dumps(["admin"]),
}
user_tuple = tuple(user.values())
sql = """INSERT INTO users(
username,
password,
roles
) VALUES(:username, :password, :roles)
"""
with SQLiteManager(userdata_db=True) as cur:
cur.execute(sql, user_tuple)
cur.close()
@staticmethod
def insert_guest_user():
"""
Inserts the default guest user.
"""
user = {
"username": "guest",
"password": encode_password("guest"),
"firstname": "Guest",
"lastname": "User",
"roles": json.dumps(["guest"]),
}
user_tuple = tuple(user.values())
sql = """INSERT INTO users(
username,
password,
firstname,
lastname,
roles
) VALUES(:username, :password, :firstname, :lastname, :roles)
"""
with SQLiteManager(userdata_db=True) as cur:
cur.execute(sql, user_tuple)
cur.close()
@staticmethod
def get_all_users():
"""
Check if there are any users in the database.
"""
sql = "SELECT * FROM users"
with SQLiteManager(userdata_db=True) as cur:
cur.execute(sql)
data = cur.fetchall()
cur.close()
return [User(*user) for user in data]
@staticmethod
def get_user_by_username(username: str):
"""
Get a user by username.
"""
sql = "SELECT * FROM users WHERE username = ?"
with SQLiteManager(userdata_db=True) as cur:
cur.execute(sql, (username,))
data = cur.fetchone()
cur.close()
if data is not None:
return User(*data)
return None
+11
View File
@@ -54,6 +54,17 @@ CREATE TABLE IF NOT EXISTS track_logger (
timestamp integer NOT NULL, timestamp integer NOT NULL,
source text, source text,
userid integer NOT NULL DEFAULT 0 userid integer NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS users (
id integer PRIMARY KEY,
username text NOT NULL UNIQUE,
firstname text,
lastname text,
password text NOT NULL,
email text,
image text,
roles text NOT NULL DEFAULT '["user"]'
) )
""" """
+32
View File
@@ -0,0 +1,32 @@
from dataclasses import asdict, field, dataclass
import json
@dataclass(slots=True)
class User:
id: int
username: str
firstname: str
lastname: str
password: str
email: str
image: str
# NOTE: roles: ['admin', 'user', 'curator']
roles: list[str] = field(default_factory=lambda: ["user"])
def __post_init__(self):
self.roles = json.loads(self.roles)
def todict(self):
this_dict = asdict(self)
del this_dict["password"]
return this_dict
def todict_simplified(self):
return {
"id": self.id,
"username": self.username,
"firstname": self.firstname,
}
+1
View File
@@ -268,6 +268,7 @@ class Keys:
SWINGMUSIC_APP_VERSION = os.environ.get("SWINGMUSIC_APP_VERSION") SWINGMUSIC_APP_VERSION = os.environ.get("SWINGMUSIC_APP_VERSION")
GIT_LATEST_COMMIT_HASH = "<unset>" GIT_LATEST_COMMIT_HASH = "<unset>"
GIT_CURRENT_BRANCH = "<unset>" GIT_CURRENT_BRANCH = "<unset>"
JWT_SECRET_KEY = "swingmusic_secret_key" # REVIEW: This should be set in the environment
@classmethod @classmethod
def load(cls): def load(cls):
+4
View File
@@ -4,6 +4,7 @@ Applies migrations.
""" """
from app.db.sqlite import create_connection, create_tables, queries from app.db.sqlite import create_connection, create_tables, queries
from app.db.sqlite.auth import SQLiteAuthMethods as authdb
from app.migrations import apply_migrations from app.migrations import apply_migrations
from app.settings import Db from app.settings import Db
@@ -29,5 +30,8 @@ def setup_sqlite():
create_tables(user_db_conn, queries.CREATE_USERDATA_TABLES) create_tables(user_db_conn, queries.CREATE_USERDATA_TABLES)
create_tables(app_db_conn, queries.CREATE_MIGRATIONS_TABLE) create_tables(app_db_conn, queries.CREATE_MIGRATIONS_TABLE)
if not authdb.get_all_users():
authdb.insert_default_user()
app_db_conn.close() app_db_conn.close()
user_db_conn.close() user_db_conn.close()
+1 -1
View File
@@ -7,7 +7,7 @@ from app.utils.network import get_ip
def log_startup_info(): def log_startup_info():
lines = "------------------------------" lines = "------------------------------"
# clears terminal 👇 # clears terminal 👇
os.system("cls" if os.name == "nt" else "echo -e \\\\033c") # os.system("cls" if os.name == "nt" else "echo -e \\\\033c")
print(lines) print(lines)
print(f"{TCOLOR.HEADER}SwingMusic {Keys.SWINGMUSIC_APP_VERSION} {TCOLOR.ENDC}") print(f"{TCOLOR.HEADER}SwingMusic {Keys.SWINGMUSIC_APP_VERSION} {TCOLOR.ENDC}")
+25
View File
@@ -0,0 +1,25 @@
import hashlib
def encode_password(password: str) -> str:
"""
This function encodes the given password.
:param password: The password to encode.
:return: The encoded password.
"""
return hashlib.sha256(password.encode("utf-8")).hexdigest()
def check_password(password: str, encoded: str) -> bool:
"""
This function checks if the given password matches the encoded password.
:param password: The password to check.
:param encoded: The encoded password.
:return: Whether the password matches.
"""
return encode_password(password) == encoded
+19
View File
@@ -1,5 +1,8 @@
import os
import sys import sys
from app.utils.filesystem import get_home_res_path
def getFlaskOpenApiPath(): def getFlaskOpenApiPath():
""" """
@@ -10,3 +13,19 @@ def getFlaskOpenApiPath():
site_packages_path = [p for p in sys.path if "site-packages" in p][0] site_packages_path = [p for p in sys.path if "site-packages" in p][0]
return f"{site_packages_path}/flask_openapi3" return f"{site_packages_path}/flask_openapi3"
def getClientFilesExtensions():
"""
Get all the file extensions for the client files
"""
client_path = get_home_res_path("client")
extensions = set()
for root, dirs, files in os.walk(client_path):
for file in files:
ext = file.split(".")[-1]
extensions.add("." + ext)
return extensions
+28
View File
@@ -4,6 +4,7 @@ This file is used to run the application.
import os import os
import logging import logging
from flask_jwt_extended import verify_jwt_in_request
import psutil import psutil
import mimetypes import mimetypes
from flask import Response, request from flask import Response, request
@@ -20,6 +21,7 @@ from app.settings import FLASKVARS, TCOLOR, Keys
from app.setup import run_setup from app.setup import run_setup
from app.start_info_logger import log_startup_info from app.start_info_logger import log_startup_info
from app.utils.filesystem import get_home_res_path from app.utils.filesystem import get_home_res_path
from app.utils.paths import getClientFilesExtensions
from app.utils.threading import background from app.utils.threading import background
mimetypes.add_type("text/css", ".css") mimetypes.add_type("text/css", ".css")
@@ -41,6 +43,32 @@ werkzeug.setLevel(logging.ERROR)
app = create_api() app = create_api()
app.static_folder = get_home_res_path("client") app.static_folder = get_home_res_path("client")
# INFO: Routes that don't need authentication
blacklist_routes = {"/auth/login", "/auth/users"}
blacklist_extensions = {".webp"}.union(getClientFilesExtensions())
@app.before_request
def verify_auth():
"""
Verifies the JWT token before each request.
"""
print(request.path)
if request.path == "/" or any(
request.path.endswith(ext) for ext in blacklist_extensions
):
return
# if request path starts with any of the blacklisted routes, don't verify jwt
if any(request.path.startswith(route) for route in blacklist_routes):
print(
"Found blacklisted route: ", request.path, "... Skipping jwt verification"
)
return
data = verify_jwt_in_request()
print(data)
@app.route("/<path:path>") @app.route("/<path:path>")
def serve_client_files(path: str): def serve_client_files(path: str):
Generated
+37 -1
View File
@@ -557,6 +557,25 @@ files = [
Flask = ">=0.9" Flask = ">=0.9"
Six = "*" Six = "*"
[[package]]
name = "flask-jwt-extended"
version = "4.6.0"
description = "Extended JWT integration with Flask"
optional = false
python-versions = ">=3.7,<4"
files = [
{file = "Flask-JWT-Extended-4.6.0.tar.gz", hash = "sha256:9215d05a9413d3855764bcd67035e75819d23af2fafb6b55197eb5a3313fdfb2"},
{file = "Flask_JWT_Extended-4.6.0-py2.py3-none-any.whl", hash = "sha256:63a28fc9731bcc6c4b8815b6f954b5904caa534fc2ae9b93b1d3ef12930dca95"},
]
[package.dependencies]
Flask = ">=2.0,<4.0"
PyJWT = ">=2.0,<3.0"
Werkzeug = ">=0.14"
[package.extras]
asymmetric-crypto = ["cryptography (>=3.3.1)"]
[[package]] [[package]]
name = "flask-openapi3" name = "flask-openapi3"
version = "3.0.2" version = "3.0.2"
@@ -1599,6 +1618,23 @@ files = [
{file = "pyinstaller_hooks_contrib-2023.9-py2.py3-none-any.whl", hash = "sha256:f34f4c6807210025c8073ebe665f422a3aa2ac5f4c7ebf4c2a26cc77bebf63b5"}, {file = "pyinstaller_hooks_contrib-2023.9-py2.py3-none-any.whl", hash = "sha256:f34f4c6807210025c8073ebe665f422a3aa2ac5f4c7ebf4c2a26cc77bebf63b5"},
] ]
[[package]]
name = "pyjwt"
version = "2.8.0"
description = "JSON Web Token implementation in Python"
optional = false
python-versions = ">=3.7"
files = [
{file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"},
{file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"},
]
[package.extras]
crypto = ["cryptography (>=3.4.0)"]
dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
[[package]] [[package]]
name = "pylint" name = "pylint"
version = "2.17.7" version = "2.17.7"
@@ -2468,4 +2504,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.10,<3.12" python-versions = ">=3.10,<3.12"
content-hash = "feb13f92b7b3a909fcb851860a405b96579feac0e2dde7681ed0e9c381c4f6cd" content-hash = "317c4094a6f768467219db94be02a6e2e9f1ed5ca18d929437094bf60be1594d"
+1
View File
@@ -25,6 +25,7 @@ waitress = "^2.1.2"
watchdog = "^4.0.0" watchdog = "^4.0.0"
pendulum = "^3.0.0" pendulum = "^3.0.0"
flask-openapi3 = "^3.0.2" flask-openapi3 = "^3.0.2"
flask-jwt-extended = "^4.6.0"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
pylint = "^2.15.5" pylint = "^2.15.5"
+2
View File
@@ -15,6 +15,7 @@ Flask==2.3.3
Flask-BasicAuth==0.2.0 Flask-BasicAuth==0.2.0
Flask-Compress==1.14 Flask-Compress==1.14
Flask-Cors==3.0.10 Flask-Cors==3.0.10
Flask-JWT-Extended==4.6.0
flask-openapi3==3.0.2 flask-openapi3==3.0.2
gevent==23.9.1 gevent==23.9.1
geventhttpclient==2.0.11 geventhttpclient==2.0.11
@@ -42,6 +43,7 @@ pydantic==2.6.3
pydantic_core==2.16.3 pydantic_core==2.16.3
pyinstaller==5.13.2 pyinstaller==5.13.2
pyinstaller-hooks-contrib==2023.9 pyinstaller-hooks-contrib==2023.9
PyJWT==2.8.0
pylint==2.17.7 pylint==2.17.7
pytest==7.4.2 pytest==7.4.2
python-dateutil==2.8.2 python-dateutil==2.8.2