mirror of
https://github.com/Dvorinka/swingmusic-extended.git
synced 2026-06-03 20:13:02 +00:00
set up auth
This commit is contained in:
+28
-4
@@ -8,6 +8,7 @@ from flask_compress import Compress
|
||||
|
||||
from flask_openapi3 import Info
|
||||
from flask_openapi3 import OpenAPI
|
||||
from flask_jwt_extended import JWTManager
|
||||
|
||||
from app.settings import Keys
|
||||
from .plugins import lyrics as lyrics_plugin
|
||||
@@ -27,6 +28,7 @@ from app.api import (
|
||||
logger,
|
||||
home,
|
||||
getall,
|
||||
auth,
|
||||
)
|
||||
|
||||
# TODO: Move this description to a separate file
|
||||
@@ -60,13 +62,34 @@ def create_api():
|
||||
|
||||
app = OpenAPI(__name__, info=api_info, doc_prefix="/docs")
|
||||
|
||||
CORS(app, origins="*")
|
||||
Compress(app)
|
||||
# JWT CONFIGS
|
||||
app.config["JWT_SECRET_KEY"] = Keys.JWT_SECRET_KEY
|
||||
app.config["JWT_TOKEN_LOCATION"] = ["cookies"]
|
||||
app.config["JWT_COOKIE_CSRF_PROTECT"] = False
|
||||
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = datetime.timedelta(days=1)
|
||||
|
||||
# CORS
|
||||
CORS(app, origins="*", supports_credentials=True)
|
||||
|
||||
# RESPONSE COMPRESSION
|
||||
Compress(app)
|
||||
app.config["COMPRESS_MIMETYPES"] = [
|
||||
"application/json",
|
||||
]
|
||||
|
||||
# JWT
|
||||
jwt = JWTManager(app)
|
||||
|
||||
@jwt.user_identity_loader
|
||||
def user_identity_lookup(user):
|
||||
return user
|
||||
|
||||
@jwt.user_lookup_loader
|
||||
def user_lookup_callback(_jwt_header, jwt_data):
|
||||
identity = jwt_data["sub"]
|
||||
return identity
|
||||
|
||||
# Register all the API blueprints
|
||||
with app.app_context():
|
||||
app.register_api(album.api)
|
||||
app.register_api(artist.api)
|
||||
@@ -89,8 +112,9 @@ def create_api():
|
||||
|
||||
# Home
|
||||
app.register_api(home.api)
|
||||
|
||||
# Flask Restful
|
||||
app.register_api(getall.api)
|
||||
|
||||
# Auth
|
||||
app.register_api(auth.api)
|
||||
|
||||
return app
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
from dataclasses import asdict
|
||||
from flask import jsonify
|
||||
from flask_jwt_extended import create_access_token, current_user, set_access_cookies
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_openapi3 import Tag
|
||||
from flask_openapi3 import APIBlueprint
|
||||
|
||||
from app.db.sqlite.auth import SQLiteAuthMethods as authdb
|
||||
from app.utils.auth import check_password
|
||||
|
||||
bp_tag = Tag(name="Auth", description="Authentication")
|
||||
api = APIBlueprint("auth", __name__, url_prefix="/auth", abp_tags=[bp_tag])
|
||||
|
||||
|
||||
class LoginBody(BaseModel):
|
||||
username: str = Field(description="The username", example="user0")
|
||||
password: str = Field(description="The password", example="password0")
|
||||
|
||||
|
||||
@api.post("/login")
|
||||
def login(body: LoginBody):
|
||||
"""
|
||||
Authenticate using username and password
|
||||
"""
|
||||
res = jsonify({"msg": f"Logged in as {body.username}"})
|
||||
|
||||
user = authdb.get_user_by_username(body.username)
|
||||
|
||||
if user is None:
|
||||
return {"msg": "User not found"}, 404
|
||||
|
||||
password_ok = check_password(body.password, user.password)
|
||||
|
||||
if not password_ok:
|
||||
return {"msg": "Invalid password"}, 401
|
||||
|
||||
access_token = create_access_token(identity=user.todict())
|
||||
set_access_cookies(res, access_token)
|
||||
return res
|
||||
|
||||
|
||||
@api.get("/logout")
|
||||
def logout():
|
||||
"""
|
||||
Log out
|
||||
"""
|
||||
res = jsonify({"msg": "Logged out"})
|
||||
res.delete_cookie("access_token_cookie")
|
||||
return res
|
||||
|
||||
|
||||
@api.get("/users")
|
||||
def get_all_users():
|
||||
"""
|
||||
Get all users
|
||||
"""
|
||||
users = authdb.get_all_users()
|
||||
return [user.todict_simplified() for user in users]
|
||||
|
||||
|
||||
@api.route("/user")
|
||||
def get_logged_in_user():
|
||||
"""
|
||||
Get logged in user
|
||||
"""
|
||||
print("current_user", current_user)
|
||||
return dict(current_user)
|
||||
@@ -162,6 +162,7 @@ mapp = {
|
||||
|
||||
|
||||
@api.get("")
|
||||
|
||||
def get_all_settings():
|
||||
"""
|
||||
Get all settings
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
import json
|
||||
from app.models.user import User
|
||||
from app.utils.auth import encode_password
|
||||
from app.db.sqlite.utils import SQLiteManager
|
||||
|
||||
|
||||
class SQLiteAuthMethods:
|
||||
"""
|
||||
Methods for authenticating users.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def insert_default_user():
|
||||
"""
|
||||
Inserts the default admin user.
|
||||
"""
|
||||
user = {
|
||||
"username": "admin",
|
||||
"password": encode_password("admin"),
|
||||
"roles": json.dumps(["admin"]),
|
||||
}
|
||||
user_tuple = tuple(user.values())
|
||||
|
||||
sql = """INSERT INTO users(
|
||||
username,
|
||||
password,
|
||||
roles
|
||||
) VALUES(:username, :password, :roles)
|
||||
"""
|
||||
|
||||
with SQLiteManager(userdata_db=True) as cur:
|
||||
cur.execute(sql, user_tuple)
|
||||
cur.close()
|
||||
|
||||
@staticmethod
|
||||
def insert_guest_user():
|
||||
"""
|
||||
Inserts the default guest user.
|
||||
"""
|
||||
user = {
|
||||
"username": "guest",
|
||||
"password": encode_password("guest"),
|
||||
"firstname": "Guest",
|
||||
"lastname": "User",
|
||||
"roles": json.dumps(["guest"]),
|
||||
}
|
||||
user_tuple = tuple(user.values())
|
||||
|
||||
sql = """INSERT INTO users(
|
||||
username,
|
||||
password,
|
||||
firstname,
|
||||
lastname,
|
||||
roles
|
||||
) VALUES(:username, :password, :firstname, :lastname, :roles)
|
||||
"""
|
||||
|
||||
with SQLiteManager(userdata_db=True) as cur:
|
||||
cur.execute(sql, user_tuple)
|
||||
cur.close()
|
||||
|
||||
@staticmethod
|
||||
def get_all_users():
|
||||
"""
|
||||
Check if there are any users in the database.
|
||||
"""
|
||||
sql = "SELECT * FROM users"
|
||||
|
||||
with SQLiteManager(userdata_db=True) as cur:
|
||||
cur.execute(sql)
|
||||
|
||||
data = cur.fetchall()
|
||||
cur.close()
|
||||
|
||||
return [User(*user) for user in data]
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_username(username: str):
|
||||
"""
|
||||
Get a user by username.
|
||||
"""
|
||||
sql = "SELECT * FROM users WHERE username = ?"
|
||||
|
||||
with SQLiteManager(userdata_db=True) as cur:
|
||||
cur.execute(sql, (username,))
|
||||
|
||||
data = cur.fetchone()
|
||||
cur.close()
|
||||
|
||||
if data is not None:
|
||||
return User(*data)
|
||||
|
||||
return None
|
||||
@@ -54,6 +54,17 @@ CREATE TABLE IF NOT EXISTS track_logger (
|
||||
timestamp integer NOT NULL,
|
||||
source text,
|
||||
userid integer NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id integer PRIMARY KEY,
|
||||
username text NOT NULL UNIQUE,
|
||||
firstname text,
|
||||
lastname text,
|
||||
password text NOT NULL,
|
||||
email text,
|
||||
image text,
|
||||
roles text NOT NULL DEFAULT '["user"]'
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
from dataclasses import asdict, field, dataclass
|
||||
import json
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class User:
|
||||
id: int
|
||||
username: str
|
||||
firstname: str
|
||||
lastname: str
|
||||
password: str
|
||||
email: str
|
||||
image: str
|
||||
|
||||
# NOTE: roles: ['admin', 'user', 'curator']
|
||||
roles: list[str] = field(default_factory=lambda: ["user"])
|
||||
|
||||
def __post_init__(self):
|
||||
self.roles = json.loads(self.roles)
|
||||
|
||||
def todict(self):
|
||||
this_dict = asdict(self)
|
||||
del this_dict["password"]
|
||||
|
||||
return this_dict
|
||||
|
||||
def todict_simplified(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"username": self.username,
|
||||
"firstname": self.firstname,
|
||||
}
|
||||
@@ -268,6 +268,7 @@ class Keys:
|
||||
SWINGMUSIC_APP_VERSION = os.environ.get("SWINGMUSIC_APP_VERSION")
|
||||
GIT_LATEST_COMMIT_HASH = "<unset>"
|
||||
GIT_CURRENT_BRANCH = "<unset>"
|
||||
JWT_SECRET_KEY = "swingmusic_secret_key" # REVIEW: This should be set in the environment
|
||||
|
||||
@classmethod
|
||||
def load(cls):
|
||||
|
||||
@@ -4,6 +4,7 @@ Applies migrations.
|
||||
"""
|
||||
|
||||
from app.db.sqlite import create_connection, create_tables, queries
|
||||
from app.db.sqlite.auth import SQLiteAuthMethods as authdb
|
||||
from app.migrations import apply_migrations
|
||||
from app.settings import Db
|
||||
|
||||
@@ -29,5 +30,8 @@ def setup_sqlite():
|
||||
create_tables(user_db_conn, queries.CREATE_USERDATA_TABLES)
|
||||
create_tables(app_db_conn, queries.CREATE_MIGRATIONS_TABLE)
|
||||
|
||||
if not authdb.get_all_users():
|
||||
authdb.insert_default_user()
|
||||
|
||||
app_db_conn.close()
|
||||
user_db_conn.close()
|
||||
|
||||
@@ -7,7 +7,7 @@ from app.utils.network import get_ip
|
||||
def log_startup_info():
|
||||
lines = "------------------------------"
|
||||
# clears terminal 👇
|
||||
os.system("cls" if os.name == "nt" else "echo -e \\\\033c")
|
||||
# os.system("cls" if os.name == "nt" else "echo -e \\\\033c")
|
||||
|
||||
print(lines)
|
||||
print(f"{TCOLOR.HEADER}SwingMusic {Keys.SWINGMUSIC_APP_VERSION} {TCOLOR.ENDC}")
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
import hashlib
|
||||
|
||||
|
||||
def encode_password(password: str) -> str:
|
||||
"""
|
||||
This function encodes the given password.
|
||||
|
||||
:param password: The password to encode.
|
||||
|
||||
:return: The encoded password.
|
||||
"""
|
||||
|
||||
return hashlib.sha256(password.encode("utf-8")).hexdigest()
|
||||
|
||||
def check_password(password: str, encoded: str) -> bool:
|
||||
"""
|
||||
This function checks if the given password matches the encoded password.
|
||||
|
||||
:param password: The password to check.
|
||||
:param encoded: The encoded password.
|
||||
|
||||
:return: Whether the password matches.
|
||||
"""
|
||||
|
||||
return encode_password(password) == encoded
|
||||
@@ -1,5 +1,8 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from app.utils.filesystem import get_home_res_path
|
||||
|
||||
|
||||
def getFlaskOpenApiPath():
|
||||
"""
|
||||
@@ -10,3 +13,19 @@ def getFlaskOpenApiPath():
|
||||
site_packages_path = [p for p in sys.path if "site-packages" in p][0]
|
||||
|
||||
return f"{site_packages_path}/flask_openapi3"
|
||||
|
||||
|
||||
def getClientFilesExtensions():
|
||||
"""
|
||||
Get all the file extensions for the client files
|
||||
"""
|
||||
|
||||
client_path = get_home_res_path("client")
|
||||
|
||||
extensions = set()
|
||||
for root, dirs, files in os.walk(client_path):
|
||||
for file in files:
|
||||
ext = file.split(".")[-1]
|
||||
extensions.add("." + ext)
|
||||
|
||||
return extensions
|
||||
|
||||
Reference in New Issue
Block a user