Add auth stuff

yeah! lets fffff goooo!
This commit is contained in:
Mungai Njoroge
2024-05-11 14:27:53 -04:00
committed by GitHub
41 changed files with 1223 additions and 238 deletions
+3 -4
View File
@@ -1,7 +1,6 @@
# What's New?
- Hovering on recent favorite item will show how long ago it was ♥ed
- Recently added playlist returns a max of 100 tracks, but without a cutoff period
<!-- TODO: ELABORATE -->
- Auth
# Development
- API documentation on /openapi
## Development
+1 -1
View File
@@ -27,7 +27,7 @@ client
logs.txt
*.spec
TODO.md
# TODO.md
testdata.py
test.py
nohup.out
+2
View File
@@ -0,0 +1,2 @@
- Fix migrations!
- Use total length instead of release version length
+35 -8
View File
@@ -8,9 +8,12 @@ from flask_compress import Compress
from flask_openapi3 import Info
from flask_openapi3 import OpenAPI
from flask_jwt_extended import JWTManager
from app.config import UserConfig
from app.settings import Keys
from app.settings import Info as AppInfo
from .plugins import lyrics as lyrics_plugin
from app.db.sqlite.auth import SQLiteAuthMethods as authdb
from app.api import (
album,
artist,
@@ -20,13 +23,14 @@ from app.api import (
imgserver,
playlist,
search,
send_file,
settings,
lyrics,
plugins,
logger,
home,
getall,
auth,
stream,
)
# TODO: Move this description to a separate file
@@ -54,23 +58,45 @@ def create_api():
"""
api_info = Info(
title=f"Swing Music",
version=f"v{Keys.SWINGMUSIC_APP_VERSION}",
version=f"v{AppInfo.SWINGMUSIC_APP_VERSION}",
description=open_api_description,
)
app = OpenAPI(__name__, info=api_info, doc_prefix="/docs")
# JWT CONFIGS
app.config["JWT_SECRET_KEY"] = UserConfig().userId
app.config["JWT_TOKEN_LOCATION"] = ["cookies"]
app.config["JWT_COOKIE_CSRF_PROTECT"] = False
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = datetime.timedelta(days=30)
CORS(app, origins="*")
# CORS
CORS(app, origins="*", supports_credentials=True)
# RESPONSE COMPRESSION
Compress(app)
app.config["COMPRESS_MIMETYPES"] = [
"application/json",
]
# JWT
jwt = JWTManager(app)
# @jwt.user_identity_loader
# def user_identity_lookup(user):
# return user
@jwt.user_lookup_loader
def user_lookup_callback(_jwt_header, jwt_data):
identity = jwt_data["sub"]
userid = identity["id"]
user = authdb.get_user_by_id(userid)
return user.todict()
# Register all the API blueprints
with app.app_context():
app.register_api(album.api)
app.register_api(artist.api)
app.register_api(send_file.api)
app.register_api(stream.api)
app.register_api(search.api)
app.register_api(folder.api)
app.register_api(playlist.api)
@@ -89,8 +115,9 @@ def create_api():
# Home
app.register_api(home.api)
# Flask Restful
app.register_api(getall.api)
# Auth
app.register_api(auth.api)
return app
+290
View File
@@ -0,0 +1,290 @@
import json
from dataclasses import asdict
from functools import wraps
import sqlite3
from flask import jsonify
from flask_jwt_extended import (
create_access_token,
current_user,
jwt_required,
set_access_cookies,
)
from pydantic import BaseModel, Field
from flask_openapi3 import Tag
from flask_openapi3 import APIBlueprint
from app.db.sqlite.auth import SQLiteAuthMethods as authdb
from app.utils.auth import check_password, hash_password
from app.config import UserConfig
bp_tag = Tag(name="Auth", description="Authentication stuff")
api = APIBlueprint("auth", __name__, url_prefix="/auth", abp_tags=[bp_tag])
def admin_required():
"""
Decorator to require admin role
"""
def wrapper(fn):
@wraps(fn)
def decorator(*args, **kwargs):
if "admin" not in current_user["roles"]:
return {"msg": "Only admins can do that!"}, 403
return fn(*args, **kwargs)
return decorator
return wrapper
class LoginBody(BaseModel):
username: str = Field(description="The username", example="user0")
password: str = Field(description="The password", example="password0")
@api.post("/login")
def login(body: LoginBody):
"""
Authenticate using username and password
"""
res = jsonify({"msg": f"Logged in as {body.username}"})
user = authdb.get_user_by_username(body.username)
if user is None:
return {"msg": "User not found"}, 404
password_ok = check_password(body.password, user.password)
if not password_ok:
return {"msg": "Hehe! invalid password"}, 401
access_token = create_access_token(identity=user.todict())
set_access_cookies(res, access_token)
return res
class UpdateProfileBody(BaseModel):
id: int = Field(0, description="The user id")
email: str = Field("", description="The email")
username: str = Field("", description="The username", example="user0")
password: str = Field("", description="The password", example="password0")
roles: list[str] = Field(None, description="The roles")
@api.put("/profile/update")
def update_profile(body: UpdateProfileBody):
user = {
"id": body.id,
"email": body.email,
"username": body.username,
"password": body.password,
"roles": body.roles,
}
# prevent updating guest
if current_user["username"] == "guest" or user["username"] == "guest":
return {"msg": "Cannot update guest user"}, 400
# if not id, update self
if not user["id"]:
user["id"] = current_user["id"]
if body.roles is not None:
# only admins can update roles
if "admin" not in current_user["roles"]:
return {"msg": "Only admins can update roles"}, 403
all_users = authdb.get_all_users()
if "admin" not in body.roles:
# check if we're removing the last admin
admins = [user for user in all_users if "admin" in user.roles]
if len(admins) == 1 and admins[0].id == user["id"]:
return {"msg": "Cannot remove the only admin"}, 400
# guest roles cannot be updated
_user = [u for u in all_users if u.id == user["id"]][0]
if "guest" in _user.roles:
return {"msg": "Cannot update guest user"}, 400
# finally, convert roles to json string
user["roles"] = json.dumps(body.roles)
if user["password"]:
user["password"] = hash_password(user["password"])
# remove empty values
clean_user = {k: v for k, v in user.items() if v}
try:
return authdb.update_user(clean_user)
except sqlite3.IntegrityError:
return {"msg": "Username already exists"}, 400
@api.post("/profile/create")
@admin_required()
def create_user(body: UpdateProfileBody):
if not body.username or not body.password:
return {"msg": "Username and password are required"}, 400
user = {
"username": body.username,
"password": hash_password(body.password),
"roles": json.dumps([]),
}
# check if user already exists
if authdb.get_user_by_username(user["username"]):
return {"msg": "Username already exists"}, 400
userid = authdb.insert_user(user)
return authdb.get_user_by_id(userid).todict()
@api.post("/profile/guest/create")
@admin_required()
def create_guest_user():
"""
Create a guest user
"""
# check if guest user already exists
guest_user = authdb.get_user_by_username("guest")
if guest_user:
return {
"msg": "Guest user already exists",
}, 400
userid = authdb.insert_guest_user()
if userid:
return {
"msg": "Guest user created",
}
return {
"msg": "Failed to create guest user",
}, 500
class DeleteUseBody(BaseModel):
username: str = Field("", description="The username")
@api.delete("/profile/delete")
@admin_required()
def delete_user(body: DeleteUseBody):
"""
Delete a user by username
"""
# prevent admin from deleting themselves
if body.username == current_user["username"]:
return {"msg": "Sorry! you cannot delete yourselfu"}, 400
# prevent deleting the only admin
users = authdb.get_all_users()
admins = [user for user in users if "admin" in user.roles]
if len(admins) == 1 and admins[0].username == body.username:
return {"msg": "Cannot delete the only admin"}, 400
authdb.delete_user_by_username(body.username)
return {"msg": f"User {body.username} deleted"}
@api.get("/logout")
def logout():
"""
Log out
"""
res = jsonify({"msg": "Logged out"})
res.delete_cookie("access_token_cookie")
return res
class GetAllUsersQuery(BaseModel):
simplified: bool = Field(
False, description="Whether to return simplified user data"
)
@api.get("/users")
@jwt_required(optional=True)
def get_all_users(query: GetAllUsersQuery):
"""
Get all users (if you're an admin, you will also receive accounts settings)
"""
config = UserConfig()
# config.enableGuest = True
# config.usersOnLogin = True
settings = {
"enableGuest": False,
"usersOnLogin": config.usersOnLogin,
}
res = {
"settings": {},
"users": [],
}
users = authdb.get_all_users()
is_admin = current_user and "admin" in current_user["roles"]
settings["enableGuest"] = [
user for user in users if user.username == "guest"
].__len__() > 0
# if user is admin, also return settings
if is_admin:
res = {
"settings": settings,
}
# if is normal user, return empty response
elif current_user:
return res
# if not logged in and showing users on login is disabled, return empty response
elif (
not current_user
and not settings["usersOnLogin"]
and not settings["enableGuest"]
):
return res
# remove guest user
# if not settings["enableGuest"]:
# users = [user for user in users if user.username != "guest"]
if not settings["usersOnLogin"]:
users = [user for user in users if user.username == "guest"]
# reverse list to show latest users first
users = list(reversed(users))
# bring admins to the front
users = sorted(users, key=lambda x: "admin" in x.roles, reverse=True)
# bring current user to index 0
if current_user:
users = sorted(
users,
key=lambda x: x.username == current_user["username"],
reverse=True,
)
if query.simplified:
res["users"] = [user.todict_simplified() for user in users]
res["users"] = [user.todict() for user in users]
return res
@api.route("/user")
def get_logged_in_user():
"""
Get logged in user
"""
return dict(current_user)
+5 -8
View File
@@ -22,7 +22,7 @@ bp_tag = Tag(name="Get all", description="List all items")
api = APIBlueprint("getall", __name__, url_prefix="/getall", abp_tags=[bp_tag])
class GetAllItemsBody(GenericLimitSchema):
class GetAllItemsQuery(GenericLimitSchema):
start: int = Field(
description="The start index of the items to return",
example=0,
@@ -34,10 +34,10 @@ class GetAllItemsBody(GenericLimitSchema):
default="created_date",
)
reverse: int = Field(
reverse: str = Field(
description="Reverse the sort",
example=1,
default=1,
default="1",
)
@@ -50,7 +50,7 @@ class GetAllItemsPath(BaseModel):
@api.get("/<itemtype>")
def get_all_items(path: GetAllItemsPath, query: GetAllItemsBody):
def get_all_items(path: GetAllItemsPath, query: GetAllItemsQuery):
"""
Get all items
@@ -67,10 +67,7 @@ def get_all_items(path: GetAllItemsPath, query: GetAllItemsBody):
start = query.start
limit = query.limit
sort = query.sortby
reverse = query.reverse == 1
# if sort == "":
# sort = "created_date"
reverse = query.reverse == "1"
sort_is_count = sort == "count"
sort_is_duration = sort == "duration"
+59 -39
View File
@@ -13,6 +13,9 @@ api = APIBlueprint("imgserver", __name__, url_prefix="/img", abp_tags=[bp_tag])
def send_fallback_img(filename: str = "default.webp"):
"""
Returns the fallback image from the assets folder.
"""
folder = Paths.get_assets_path()
img = Path(folder) / filename
@@ -22,6 +25,18 @@ def send_fallback_img(filename: str = "default.webp"):
return send_from_directory(folder, filename)
def send_file_or_fallback(folder: str, filename: str, fallback: str = "default.webp"):
"""
Returns the file from the folder or the fallback image.
"""
fpath = Path(folder) / filename
if fpath.exists():
return send_from_directory(folder, filename)
return send_fallback_img(fallback)
class ImagePath(BaseModel):
imgpath: str = Field(
description="The image filename",
@@ -43,62 +58,72 @@ class ImagePath(BaseModel):
# return send_fallback_img()
@api.get("/t/<imgpath>")
# TRACK THUMBNAILS
@api.get("/thumbnail/<imgpath>")
def send_lg_thumbnail(path: ImagePath):
"""
Get large thumbnail (500 x 500)
"""
folder = Paths.get_lg_thumb_path()
fpath = Path(folder) / path.imgpath
if fpath.exists():
return send_from_directory(folder, path.imgpath)
return send_fallback_img()
return send_file_or_fallback(folder, path.imgpath)
@api.get("/t/s/<imgpath>")
@api.get("/thumbnail/xsmall/<imgpath>")
def send_xsm_thumbnail(path: ImagePath):
"""
Get extra small thumbnail (64px)
"""
folder = Paths.get_xsm_thumb_path()
return send_file_or_fallback(folder, path.imgpath)
@api.get("/thumbnail/small/<imgpath>")
def send_sm_thumbnail(path: ImagePath):
"""
Get small thumbnail (64 x 64)
Get small thumbnail (96px)
"""
folder = Paths.get_sm_thumb_path()
fpath = Path(folder) / path.imgpath
if fpath.exists():
return send_from_directory(folder, path.imgpath)
return send_fallback_img()
return send_file_or_fallback(folder, path.imgpath)
@api.get("/a/<imgpath>")
@api.get("/thumbnail/medium/<imgpath>")
def send_md_thumbnail(path: ImagePath):
"""
Get medium thumbnail (256px)
"""
folder = Paths.get_md_thumb_path()
return send_file_or_fallback(folder, path.imgpath)
# ARTISTS
@api.get("/artist/<imgpath>")
def send_lg_artist_image(path: ImagePath):
"""
Get large artist image (500 x 500)
"""
folder = Paths.get_artist_img_lg_path()
fpath = Path(folder) / path.imgpath
if fpath.exists():
return send_from_directory(folder, path.imgpath)
return send_fallback_img("artist.webp")
folder = Paths.get_lg_artist_img_path()
return send_file_or_fallback(folder, path.imgpath, "artist.webp")
@api.get("/a/s/<imgpath>")
@api.get("/artist/small/<imgpath>")
def send_sm_artist_image(path: ImagePath):
"""
Get small artist image (64 x 64)
Get small artist image (128)
"""
folder = Paths.get_artist_img_sm_path()
fpath = Path(folder) / path.imgpath
if fpath.exists():
return send_from_directory(folder, path.imgpath)
return send_fallback_img("artist.webp")
folder = Paths.get_sm_artist_img_path()
return send_file_or_fallback(folder, path.imgpath, "artist.webp")
@api.get("/artist/medium/<imgpath>")
def send_md_artist_image(path: ImagePath):
"""
Get medium artist image (256px)
"""
folder = Paths.get_md_artist_img_path()
return send_file_or_fallback(folder, path.imgpath, "artist.webp")
# PLAYLISTS
class PlaylistImagePath(BaseModel):
imgpath: str = Field(
description="The image path",
@@ -106,7 +131,7 @@ class PlaylistImagePath(BaseModel):
)
@api.get("/p/<imgpath>")
@api.get("/playlist/<imgpath>")
def send_playlist_image(path: PlaylistImagePath):
"""
Get playlist image
@@ -114,9 +139,4 @@ def send_playlist_image(path: PlaylistImagePath):
Images are constructed as '{playlist_id}.webp'
"""
folder = Paths.get_playlist_img_path()
fpath = Path(folder) / path.imgpath
if fpath.exists():
return send_from_directory(folder, path.imgpath)
return send_fallback_img("playlist.svg")
return send_file_or_fallback(folder, path.imgpath, "playlist.svg")
+2 -1
View File
@@ -26,6 +26,7 @@ api = APIBlueprint("playlists", __name__, url_prefix="/playlists", abp_tags=[tag
PL = SQLitePlaylistMethods
class SendAllPlaylistsQuery(BaseModel):
no_images: bool = Field(False, description="Whether to include images")
@@ -410,7 +411,7 @@ def save_item_as_playlist(body: SavePlaylistAsItemBody):
filename = itemhash + ".webp"
base_path = (
Paths.get_artist_img_lg_path()
Paths.get_lg_artist_img_path()
if itemtype == "artist"
else Paths.get_lg_thumb_path()
)
+3
View File
@@ -3,6 +3,7 @@ from flask import Blueprint, request
from flask_openapi3 import Tag
from flask_openapi3 import APIBlueprint
from pydantic import BaseModel, Field
from app.api.auth import admin_required
from app.db.sqlite.plugins import PluginsMethods
bp_tag = Tag(name="Plugins", description="Manage plugins")
@@ -30,6 +31,7 @@ class PluginActivateBody(PluginBody):
@api.post("/setactive")
@admin_required()
def activate_deactivate_plugin(body: PluginActivateBody):
"""
Activate/Deactivate plugin
@@ -49,6 +51,7 @@ class PluginSettingsBody(PluginBody):
@api.post("/settings")
@admin_required()
def update_plugin_settings(body: PluginSettingsBody):
"""
Update plugin settings
+40 -10
View File
@@ -3,18 +3,21 @@ from flask import request
from flask_openapi3 import Tag
from flask_openapi3 import APIBlueprint
from pydantic import BaseModel, Field
from app.api.auth import admin_required
from app.db.sqlite.plugins import PluginsMethods as pdb
from app.db.sqlite.settings import SettingsSQLMethods as sdb
from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb
from app.lib import populate
from app.lib.watchdogg import Watcher as WatchDog
from app.logger import log
from app.settings import Keys, Paths, SessionVarKeys, set_flag
from app.settings import Info, Paths, SessionVarKeys, set_flag
from app.store.albums import AlbumStore
from app.store.artists import ArtistStore
from app.store.tracks import TrackStore
from app.utils.generators import get_random_str
from app.utils.threading import background
from app.config import UserConfig
bp_tag = Tag(name="Settings", description="Customize stuff")
api = APIBlueprint("settings", __name__, url_prefix="/notsettings", abp_tags=[bp_tag])
@@ -49,12 +52,12 @@ def reload_everything(instance_key: str):
@background
def rebuild_store(db_dirs: list[str]):
"""
Restarts the watchdog and rebuilds the music library.
Restarts watchdog and rebuilds the music library.
"""
instance_key = get_random_str()
log.info("Rebuilding library...")
TrackStore.remove_tracks_by_dir_except(db_dirs)
trackdb.remove_tracks_not_in_folders(db_dirs)
reload_everything(instance_key)
try:
@@ -95,6 +98,7 @@ class AddRootDirsBody(BaseModel):
@api.post("/add-root-dirs")
@admin_required()
def add_root_dirs(body: AddRootDirsBody):
"""
Add custom root directories to the database.
@@ -103,10 +107,10 @@ def add_root_dirs(body: AddRootDirsBody):
removed_dirs = body.removed
db_dirs = sdb.get_root_dirs()
_h = "$home"
home = "$home"
db_home = any([d == _h for d in db_dirs]) # if $home is in db
incoming_home = any([d == _h for d in new_dirs]) # if $home is in incoming
db_home = any([d == home for d in db_dirs]) # if $home is in db
incoming_home = any([d == home for d in new_dirs]) # if $home is in incoming
# handle $home case
if db_home and incoming_home:
@@ -116,8 +120,8 @@ def add_root_dirs(body: AddRootDirsBody):
sdb.remove_root_dirs(db_dirs)
if incoming_home:
finalize([_h], [], [Paths.USER_HOME_DIR])
return {"root_dirs": [_h]}
finalize([home], [], [Paths.USER_HOME_DIR])
return {"root_dirs": [home]}
# ---
@@ -132,7 +136,7 @@ def add_root_dirs(body: AddRootDirsBody):
pass
db_dirs.extend(new_dirs)
db_dirs = [dir_ for dir_ in db_dirs if dir_ != _h]
db_dirs = [dir_ for dir_ in db_dirs if dir_ != home]
finalize(new_dirs, removed_dirs, db_dirs)
@@ -190,7 +194,7 @@ def get_all_settings():
root_dirs = sdb.get_root_dirs()
s["root_dirs"] = root_dirs
s["plugins"] = plugins
s["version"] = Keys.SWINGMUSIC_APP_VERSION
s["version"] = Info.SWINGMUSIC_APP_VERSION
return {
"settings": s,
@@ -214,6 +218,7 @@ class SetSettingBody(BaseModel):
@api.post("/set")
@admin_required()
def set_setting(body: SetSettingBody):
"""
Set a setting.
@@ -264,3 +269,28 @@ def trigger_scan():
run_populate()
return {"msg": "Scan triggered!"}
class UpdateConfigBody(BaseModel):
key: str = Field(
description="The setting key",
example="usersOnLogin",
)
value: Any = Field(
description="The setting value",
example=False,
)
@api.put("/update")
@admin_required()
def update_config(body: UpdateConfigBody):
"""
Update the config file
"""
config = UserConfig()
setattr(config, body.key, body.value)
return {
"msg": "Config updated!",
}
+31 -15
View File
@@ -3,14 +3,17 @@ Contains all the track routes.
"""
import os
import time
from flask import Blueprint, send_file, request, Response
from flask_openapi3 import APIBlueprint, Tag
from pydantic import BaseModel, Field
from app.api.apischemas import TrackHashSchema
from app.lib.pydub.pydub.audio_segment import AudioSegment
from app.lib.trackslib import get_silence_paddings
from app.store.tracks import TrackStore
from app.utils.files import guess_mime_type
bp_tag = Tag(name="File", description="Audio files")
api = APIBlueprint("track", __name__, url_prefix="/file", abp_tags=[bp_tag])
@@ -33,10 +36,6 @@ def send_track_file(path: TrackHashSchema, query: SendTrackFileQuery):
filepath = query.filepath
msg = {"msg": "File Not Found"}
def get_mime(filename: str) -> str:
ext = filename.rsplit(".", maxsplit=1)[-1]
return f"audio/{ext}"
# If filepath is provided, try to send that
if filepath is not None:
try:
@@ -47,7 +46,7 @@ def send_track_file(path: TrackHashSchema, query: SendTrackFileQuery):
track_exists = track is not None and os.path.exists(track.filepath)
if track_exists:
audio_type = get_mime(filepath)
audio_type = guess_mime_type(filepath)
return send_file_as_chunks(track.filepath, audio_type)
# Else, find file by trackhash
@@ -57,7 +56,7 @@ def send_track_file(path: TrackHashSchema, query: SendTrackFileQuery):
if track is None:
return msg, 404
audio_type = get_mime(track.filepath)
audio_type = guess_mime_type(track.filepath)
try:
return send_file_as_chunks(track.filepath, audio_type)
@@ -68,15 +67,31 @@ def send_track_file(path: TrackHashSchema, query: SendTrackFileQuery):
def send_file_as_chunks(filepath: str, audio_type: str) -> Response:
"""
Returns a Response object that streams the file in chunks.
"""
# NOTE: +1 makes sure the last byte is included in the range.
# NOTE: -1 is used to convert the end index to a 0-based index.
chunk_size = 1024 * 360 # 360 KB
# Get file size
file_size = os.path.getsize(filepath)
start = 0
end = file_size - 1
end = chunk_size
# Read range header
range_header = request.headers.get("Range")
if range_header:
start, end = parse_range_header(range_header, file_size)
start = get_start_range(range_header)
chunk_size = 1024 * 1024 # 1MB chunk size (adjust as needed)
# If start + chunk_size is greater than file_size,
# set end to file_size - 1
_end = start + chunk_size - 1
if _end > file_size:
end = file_size - 1
else:
end = _end
def generate_chunks():
with open(filepath, "rb") as file:
@@ -84,8 +99,11 @@ def send_file_as_chunks(filepath: str, audio_type: str) -> Response:
remaining_bytes = end - start + 1
while remaining_bytes > 0:
# Read the chunk size or all the remaining bytes
chunk = file.read(min(chunk_size, remaining_bytes))
yield chunk
# Update the remaining bytes
remaining_bytes -= len(chunk)
response = Response(
@@ -102,15 +120,13 @@ def send_file_as_chunks(filepath: str, audio_type: str) -> Response:
return response
def parse_range_header(range_header: str, file_size: int) -> tuple[int, int]:
def get_start_range(range_header: str):
try:
range_start, range_end = range_header.strip().split("=")[1].split("-")
start = int(range_start)
end = min(int(range_end), file_size - 1)
except ValueError:
return 0, file_size - 1
return int(range_start)
return start, end
except ValueError:
return 0
class GetAudioSilenceBody(BaseModel):
+62 -15
View File
@@ -2,6 +2,7 @@
Handles arguments passed to the program.
"""
from getpass import getpass
import os.path
import sys
@@ -10,27 +11,37 @@ import PyInstaller.__main__ as bundler
from app import settings
from app.logger import log
from app.print_help import HELP_MESSAGE
from app.utils.auth import hash_password
from app.utils.paths import getFlaskOpenApiPath
from app.utils.xdg_utils import get_xdg_config_dir
from app.utils.wintools import is_windows
from app.db.sqlite.auth import SQLiteAuthMethods as authdb
ALLARGS = settings.ALLARGS
ARGS = sys.argv[1:]
class HandleArgs:
class ProcessArgs:
"""
Processes the arguments passed to the program.
"""
def __init__(self) -> None:
# resolve config path
self.handle_config_path() # 1
# handles that exit
self.handle_password_recovery()
self.handle_build()
self.handle_host()
self.handle_port()
self.handle_config_path()
self.handle_periodic_scan()
self.handle_periodic_scan_interval()
self.handle_help()
self.handle_version()
# non-exiting handles
self.handle_host()
self.handle_port()
self.handle_periodic_scan()
self.handle_periodic_scan_interval()
@staticmethod
def handle_build():
"""
@@ -45,7 +56,7 @@ class HandleArgs:
print("https://www.youtube.com/watch?v=wZv62ShoStY")
sys.exit(0)
config_keys = [
info_keys = [
"SWINGMUSIC_APP_VERSION",
"GIT_LATEST_COMMIT_HASH",
"GIT_CURRENT_BRANCH",
@@ -53,16 +64,17 @@ class HandleArgs:
lines = []
for key in config_keys:
value = settings.Keys.get(key)
for key in info_keys:
value = settings.Info.get(key)
if not value:
log.error(f"WARNING: {key} not set in environment")
log.error(f"WARNING: {key} not resolved. Exiting ...")
sys.exit(0)
lines.append(f'{key} = "{value}"\n')
try:
# write the info to the config file
with open("./app/configs.py", "w", encoding="utf-8") as file:
# copy the api keys to the config file
file.writelines(lines)
@@ -88,7 +100,7 @@ class HandleArgs:
finally:
# revert and remove the api keys for dev mode
with open("./app/configs.py", "w", encoding="utf-8") as file:
lines = [f'{key} = ""\n' for key in config_keys]
lines = [f'{key} = ""\n' for key in info_keys]
file.writelines(lines)
sys.exit(0)
@@ -184,8 +196,43 @@ class HandleArgs:
@staticmethod
def handle_version():
if any((a in ARGS for a in ALLARGS.version)):
print(f"VERSION: v{settings.Keys.SWINGMUSIC_APP_VERSION}")
print(f"VERSION: v{settings.Info.SWINGMUSIC_APP_VERSION}")
print(
f"COMMIT#: {settings.Keys.GIT_CURRENT_BRANCH}/{settings.Keys.GIT_LATEST_COMMIT_HASH}"
f"COMMIT#: {settings.Info.GIT_CURRENT_BRANCH}/{settings.Info.GIT_LATEST_COMMIT_HASH}"
)
sys.exit(0)
@staticmethod
def handle_password_recovery():
if ALLARGS.pswd in ARGS:
print("SWING MUSIC v2.0.0 ")
print("PASSWORD RECOVERY \n")
username: str = ""
password: str = ""
# collect username
try:
username = input("Enter username: ")
except KeyboardInterrupt:
print("\nOperation cancelled! Exiting ...")
sys.exit(0)
username = username.strip()
user = authdb.get_user_by_username(username)
if not user:
print(f"User {username} not found")
sys.exit(0)
# collect password
try:
password = getpass("Enter new password: ")
except KeyboardInterrupt:
print("\nOperation cancelled! Exiting ...")
sys.exit(0)
password = hash_password(password)
user = authdb.update_user({"id": user.id, "password": password})
sys.exit(0)
+95
View File
@@ -0,0 +1,95 @@
from dataclasses import dataclass, asdict, field
import json
import os
from typing import Any
from .settings import Paths
# TODO: Publish this on PyPi
@dataclass
class UserConfig:
_config_path: str = ""
# NOTE: only auth stuff are used (the others are still reading/writing to db)
# TODO: Move the rest of the settings to the config file
# auth stuff
# NOTE: Don't expose the userId via the API
userId: str = ""
usersOnLogin: bool = True
# lists
rootDirs: list[str] = field(default_factory=list)
excludeDirs: list[str] = field(default_factory=list)
artistSeparators: set[str] = field(default_factory=list)
# tracks
extractFeaturedArtists: bool = True
removeProdBy: bool = True
removeRemasterInfo: bool = True
# albums
mergeAlbums: bool = False
cleanAlbumTitle: bool = True
showAlbumsAsSingles: bool = False
def __post_init__(self):
"""
Loads the config file and sets the values to this instance
"""
# set config path locally to avoid writing to file
config_path = Paths.get_config_file_path()
try:
config = self.load_config(config_path)
except FileNotFoundError:
self._config_path = config_path
return
# loop through the config file and set the values
for key, value in config.items():
setattr(self, key, value)
# finally set the config path
self._config_path = config_path
def setup_config_file(self) -> None:
"""
Creates the config file with the default settings
if it doesn't exist
"""
# if not exists, create the config file
if not os.path.exists(self._config_path):
self.write_to_file(asdict(self))
def load_config(self, path: str) -> dict[str, Any]:
"""
Reads the settings from the config file.
Returns a dictget_root_dirs
"""
with open(path, "r") as f:
settings = json.load(f)
return settings
def write_to_file(self, settings: dict[str, Any]):
"""
Writes the settings to the config file
"""
# remove internal attributes
settings = {k: v for k, v in settings.items() if not k.startswith("_")}
with open(self._config_path, "w") as f:
json.dump(settings, f, indent=4)
def __setattr__(self, key: str, value: Any) -> None:
"""
Writes to the config file whenever a value is set
"""
super().__setattr__(key, value)
# if is internal attribute, don't write to file
if key.startswith("_") or not self._config_path:
return
print(f"writing to file: {key}={value}")
self.write_to_file(asdict(self))
+146
View File
@@ -0,0 +1,146 @@
import json
from app.models.user import User
from app.utils.auth import hash_password
from app.db.sqlite.utils import SQLiteManager
class SQLiteAuthMethods:
"""
Methods for authenticating users.
"""
@staticmethod
def insert_user(user: dict[str, str]):
"""
Insert a user into the database.
:param user: A dict with the username, password and roles.
"""
sql = """INSERT INTO users(
username,
password,
roles
) VALUES(:username, :password, :roles)
"""
user_tuple = tuple(user.values())
with SQLiteManager(userdata_db=True) as cur:
cur = cur.execute(sql, user_tuple)
userid = cur.lastrowid
return userid
# if userid:
# # sleep
# user = SQLiteAuthMethods.get_user_by_id(userid).todict_simplified()
# cur.close()
# return user
raise Exception(f"Failed to insert user: {user}")
@staticmethod
def insert_default_user():
"""
Inserts the default admin user.
"""
user = {
"username": "admin",
"password": hash_password("admin"),
"roles": json.dumps(["admin"]),
}
return SQLiteAuthMethods.insert_user(user)
@staticmethod
def insert_guest_user():
"""
Inserts the default guest user.
"""
user = {
"username": "guest",
"password": hash_password("guest"),
"roles": json.dumps(["guest"]),
}
return SQLiteAuthMethods.insert_user(user)
@staticmethod
def update_user(user: dict[str, str]):
"""
Update a user in the database.
:param user: A dict with the user id and the fields to update. Ommited fields will not be updated.
"""
# get all user dict keys
keys = list(user.keys())
sql = f"""UPDATE users SET
{', '.join([f"{key} = :{key}" for key in keys if key != 'id'])}
WHERE id = :id
"""
print(sql, user)
with SQLiteManager(userdata_db=True) as cur:
cur.execute(sql, user)
cur.close()
return SQLiteAuthMethods.get_user_by_id(user["id"]).todict()
@staticmethod
def get_all_users():
"""
Check if there are any users in the database.
"""
sql = "SELECT * FROM users"
with SQLiteManager(userdata_db=True) as cur:
cur.execute(sql)
data = cur.fetchall()
cur.close()
return [User(*user) for user in data]
@staticmethod
def get_user_by_username(username: str):
"""
Get a user by username.
"""
sql = "SELECT * FROM users WHERE username = ?"
with SQLiteManager(userdata_db=True) as cur:
cur.execute(sql, (username,))
data = cur.fetchone()
cur.close()
if data is not None:
return User(*data)
return None
@staticmethod
def get_user_by_id(userid: int):
"""
Get a user by id.
"""
sql = "SELECT * FROM users WHERE id = ?"
with SQLiteManager(userdata_db=True) as cur:
cur.execute(sql, (userid,))
data = cur.fetchone()
cur.close()
if data is not None:
return User(*data)
return None
@staticmethod
def delete_user_by_username(username: str):
"""
Delete a user by username.
"""
sql = "DELETE FROM users WHERE username = ?"
with SQLiteManager(userdata_db=True) as cur:
cur.execute(sql, (username,))
cur.close()
+4 -4
View File
@@ -7,9 +7,9 @@ from app.db.sqlite.utils import SQLiteManager
class MigrationManager:
@staticmethod
def get_version() -> int:
def get_index() -> int:
"""
Returns the latest userdata database version.
Returns the latest databases migrations index.
"""
sql = "SELECT * FROM dbmigrations"
with SQLiteManager() as cur:
@@ -21,9 +21,9 @@ class MigrationManager:
# 👇 Setters 👇
@staticmethod
def set_version(version: int):
def set_index(version: int):
"""
Sets the userdata pre-init database version.
Updates the databases migrations index.
"""
sql = "UPDATE dbmigrations SET version = ? WHERE id = 1"
with SQLiteManager() as cur:
+11
View File
@@ -54,6 +54,17 @@ CREATE TABLE IF NOT EXISTS track_logger (
timestamp integer NOT NULL,
source text,
userid integer NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS users (
id integer PRIMARY KEY,
username text NOT NULL UNIQUE,
firstname text,
lastname text,
password text NOT NULL,
email text,
image text,
roles text NOT NULL DEFAULT '["user"]'
)
"""
+5 -4
View File
@@ -112,9 +112,10 @@ class SQLiteTrackMethods:
cur.execute("DELETE FROM tracks WHERE filepath=?", (filepath,))
@staticmethod
def remove_tracks_by_folders(folders: set[str]):
sql = "DELETE FROM tracks WHERE folder = ?"
def remove_tracks_not_in_folders(folders: set[str]):
sql = "DELETE FROM tracks WHERE folder NOT IN ({})".format(
",".join("?" * len(folders))
)
with SQLiteManager() as cur:
for folder in folders:
cur.execute(sql, (folder,))
cur.execute(sql, tuple(folders))
+27 -11
View File
@@ -55,13 +55,22 @@ def get_artist_image_link(artist: str):
# TODO: Move network calls to utils/network.py
class DownloadImage:
def __init__(self, url: str, name: str) -> None:
sm_path = Path(settings.Paths.get_artist_img_sm_path()) / name
lg_path = Path(settings.Paths.get_artist_img_lg_path()) / name
img = self.download(url)
if img is not None:
self.save_img(img, sm_path, lg_path)
if img is None:
return
sm_path = Path(settings.Paths.get_sm_artist_img_path()) / name
lg_path = Path(settings.Paths.get_lg_artist_img_path()) / name
md_path = Path(settings.Paths.get_md_artist_img_path()) / name
entries = [
(lg_path, None), # save in the original size
(sm_path, settings.Defaults.SM_ARTIST_IMG_SIZE),
(md_path, settings.Defaults.MD_ARTIST_IMG_SIZE),
]
self.save_img(img, entries)
@staticmethod
def download(url: str) -> Image.Image | None:
@@ -74,14 +83,21 @@ class DownloadImage:
return None
@staticmethod
def save_img(img: Image.Image, sm_path: Path, lg_path: Path):
def save_img(img: Image.Image, entries: list[tuple[Path, int | None]]):
"""
Saves the image to the destinations.
"""
img.save(lg_path, format="webp")
ratio = img.width / img.height
for entry in entries:
path, size = entry
sm_size = settings.Defaults.SM_ARTIST_IMG_SIZE
img.resize((sm_size, sm_size), Image.ANTIALIAS).save(sm_path, format="webp")
if size is None:
img.save(path, format="webp")
continue
img.resize((size, int(size / ratio)), Image.ANTIALIAS).save(
path, format="webp"
)
class CheckArtistImages:
@@ -90,7 +106,7 @@ class CheckArtistImages:
CHECK_ARTIST_IMAGES_KEY = instance_key
# read all files in the artist image folder
path = settings.Paths.get_artist_img_sm_path()
path = settings.Paths.get_sm_artist_img_path()
processed = "".join(os.listdir(path)).replace("webp", "")
# filter out artists that already have an image
@@ -126,7 +142,7 @@ class CheckArtistImages:
return
img_path = (
Path(settings.Paths.get_artist_img_sm_path()) / f"{artist.artisthash}.webp"
Path(settings.Paths.get_sm_artist_img_path()) / f"{artist.artisthash}.webp"
)
if img_path.exists():
+1 -1
View File
@@ -42,7 +42,7 @@ def process_color(item_hash: str, is_album=True):
path = (
settings.Paths.get_sm_thumb_path()
if is_album
else settings.Paths.get_artist_img_sm_path()
else settings.Paths.get_sm_artist_img_path()
)
path = Path(path) / (item_hash + ".webp")
+5 -1
View File
@@ -82,9 +82,13 @@ def get_recently_played(limit=7):
if entry.type == "folder":
folder = entry.type_src
if not folder:
continue
if not folder.endswith("/"):
folder += "/"
is_home_dir = entry.type_src == "$home"
if is_home_dir:
@@ -98,7 +102,7 @@ def get_recently_played(limit=7):
{
"type": "folder",
"item": {
"path": entry.type_src,
"path": folder,
"count": count,
"help_text": "folder",
"time": timestamp_to_time_passed(entry.timestamp),
+10 -8
View File
@@ -33,20 +33,22 @@ def extract_thumb(filepath: str, webp_path: str, overwrite=False) -> bool:
"""
lg_img_path = os.path.join(Paths.get_lg_thumb_path(), webp_path)
sm_img_path = os.path.join(Paths.get_sm_thumb_path(), webp_path)
xms_img_path = os.path.join(Paths.get_xsm_thumb_path(), webp_path)
md_img_path = os.path.join(Paths.get_md_thumb_path(), webp_path)
tsize = Defaults.THUMB_SIZE
sm_tsize = Defaults.SM_THUMB_SIZE
images = [
(lg_img_path, Defaults.LG_THUMB_SIZE),
(sm_img_path, Defaults.SM_THUMB_SIZE),
(xms_img_path, Defaults.XSM_THUMB_SIZE),
(md_img_path, Defaults.MD_THUMB_SIZE),
]
def save_image(img: Image.Image):
width, height = img.size
ratio = width / height
img.resize((tsize, int(tsize / ratio)), Image.ANTIALIAS).save(
lg_img_path, "webp"
)
img.resize((sm_tsize, int(sm_tsize / ratio)), Image.ANTIALIAS).save(
sm_img_path, "webp"
)
for path, size in images:
img.resize((size, int(size / ratio)), Image.ANTIALIAS).save(path, "webp")
if not overwrite and os.path.exists(sm_img_path):
img_size = os.path.getsize(sm_img_path)
+3 -2
View File
@@ -1,12 +1,13 @@
"""
This library contains all the functions related to tracks.
"""
import os
from app.lib.pydub.pydub import AudioSegment
from app.lib.pydub.pydub.silence import detect_leading_silence, detect_silence
from app.db.sqlite.tracks import SQLiteTrackMethods as tdb
from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb
from app.store.tracks import TrackStore
from app.utils.progressbar import tqdm
from app.utils.threading import ThreadWithReturnValue
@@ -19,7 +20,7 @@ def validate_tracks() -> None:
for track in tqdm(TrackStore.tracks, desc="Validating tracks"):
if not os.path.exists(track.filepath):
TrackStore.remove_track_obj(track)
tdb.remove_tracks_by_filepaths(track.filepath)
trackdb.remove_tracks_by_filepaths(track.filepath)
def get_leading_silence_end(filepath: str):
+25 -20
View File
@@ -2,12 +2,6 @@
Migrations module.
Reads and applies the latest database migrations.
PLEASE NOTE: OLDER MIGRATIONS CAN NEVER BE DELETED.
ONLY MODIFY OLD MIGRATIONS FOR BUG FIXES OR ENHANCEMENTS ONLY
[TRY NOT TO MODIFY BEHAVIOR, UNLESS YOU KNOW WHAT YOU'RE DOING].
PS: Fuck that! Do what you want.
"""
from app.db.sqlite.migrations import MigrationManager
@@ -33,21 +27,32 @@ migrations: list[list[Migration]] = [
def apply_migrations():
"""
Applies the latest database migrations.
The length of all the migrations is stored in the database
and used to check for new migrations. When the length of the
migrations list is larger than the number stored in the db,
migrations past that index are applied and the new length
is stored as the new migration index.
"""
version = MigrationManager.get_version()
index = MigrationManager.get_index()
all_migrations = [migration for sublist in migrations for migration in sublist]
if version != len(migrations):
# INFO: Apply new migrations
for migration in migrations[version:]:
for m in migration:
try:
m.migrate()
log.info("Applied migration: %s", m.__name__)
except:
log.error("Failed to run migration: %s", m.__name__)
to_apply: list[Migration] = []
print("Migrations applied successfully.")
print("Current migration version: ", len(migrations))
# bump migration version
MigrationManager.set_version(len(migrations))
# if index is from old release,
# get migrations from the "migrations" list
if index < 3:
_migrations = migrations[index:]
to_apply = [migration for sublist in _migrations for migration in sublist]
else:
to_apply = all_migrations[index:]
for migration in to_apply:
try:
migration.migrate()
log.info("Applied migration: %s", migration.__name__)
except:
log.error("Failed to run migration: %s", migration.__name__)
MigrationManager.set_index(len(all_migrations))
+16
View File
@@ -58,3 +58,19 @@ class DeleteOriginalThumbnails(Migration):
if os.path.exists(og_imgpath):
shutil.rmtree(og_imgpath)
class DeleteOriginalThumbnailsa(Migration):
"""
Original thumbnails are too large and are not needed.
"""
# TODO: Implement this migration
@staticmethod
def migrate():
imgpath = Paths.get_thumbs_path()
og_imgpath = os.path.join(imgpath, "original")
if os.path.exists(og_imgpath):
shutil.rmtree(og_imgpath)
+32
View File
@@ -0,0 +1,32 @@
from dataclasses import asdict, field, dataclass
import json
@dataclass(slots=True)
class User:
id: int
username: str
firstname: str
lastname: str
password: str
email: str
image: str
# NOTE: roles: ['admin', 'user', 'curator']
roles: list[str] = field(default_factory=lambda: ["user"])
def __post_init__(self):
self.roles = json.loads(self.roles)
def todict(self):
this_dict = asdict(self)
del this_dict["password"]
return this_dict
def todict_simplified(self):
return {
"id": self.id,
"username": self.username,
"firstname": self.firstname,
}
+6
View File
@@ -13,6 +13,12 @@ from app.logger import log
def run_periodic_scans():
"""
Runs periodic scans.
Periodic scans are checks that run every few minutes
in the background to do stuff like:
- checking for new music
- delete deleted entries
- downloading artist images, and other data.
"""
# ValidateAlbumThumbs()
# ValidatePlaylistThumbs()
+1 -1
View File
@@ -8,7 +8,7 @@ import requests
from app.db.sqlite.plugins import PluginsMethods
from app.plugins import Plugin, plugin_method
from app.settings import Keys, Paths
from app.settings import Paths
class LRCProvider:
+8 -5
View File
@@ -1,4 +1,4 @@
from app.settings import ALLARGS
from app.settings import ALLARGS, Info
from tabulate import tabulate
args = ALLARGS
@@ -10,6 +10,7 @@ help_args_list = [
["--port", "", "Set the port"],
["--config", "", "Set the config path"],
["--no-periodic-scan", "-nps", "Disable periodic scan"],
["--pswd", "", "Recover a password"],
[
"--scan-interval",
"-psi",
@@ -23,10 +24,12 @@ help_args_list = [
]
HELP_MESSAGE = f"""
Swing Music is a beautiful, self-hosted music player for your
local audio files. Like a cooler Spotify ... but bring your own music.
Swing Music v{Info.SWINGMUSIC_APP_VERSION}
Usage: swingmusic [options] [args]
A beautiful, self-hosted music player for your local audio files.
Like Spotify ... but bring your own music.
{tabulate(help_args_list, headers=["Option", "Short", "Description"], tablefmt="simple_grid", maxcolwidths=[None, None, 40])}
Usage: ./swingmusic [options] [args]
{tabulate(help_args_list, headers=["Option", "Alias", "Description"], tablefmt="psql", maxcolwidths=[None, None, 40])}
"""
+52 -18
View File
@@ -45,22 +45,24 @@ class Paths:
def get_img_path(cls):
return join(cls.get_app_dir(), "images")
# ARTISTS
@classmethod
def get_artist_img_path(cls):
return join(cls.get_img_path(), "artists")
@classmethod
def get_artist_img_sm_path(cls):
def get_sm_artist_img_path(cls):
return join(cls.get_artist_img_path(), "small")
@classmethod
def get_artist_img_lg_path(cls):
return join(cls.get_artist_img_path(), "large")
def get_md_artist_img_path(cls):
return join(cls.get_artist_img_path(), "medium")
@classmethod
def get_playlist_img_path(cls):
return join(cls.get_img_path(), "playlists")
def get_lg_artist_img_path(cls):
return join(cls.get_artist_img_path(), "large")
# TRACK THUMBNAILS
@classmethod
def get_thumbs_path(cls):
return join(cls.get_img_path(), "thumbnails")
@@ -69,10 +71,23 @@ class Paths:
def get_sm_thumb_path(cls):
return join(cls.get_thumbs_path(), "small")
@classmethod
def get_xsm_thumb_path(cls):
return join(cls.get_thumbs_path(), "xsmall")
@classmethod
def get_md_thumb_path(cls):
return join(cls.get_thumbs_path(), "medium")
@classmethod
def get_lg_thumb_path(cls):
return join(cls.get_thumbs_path(), "large")
# OTHERS
@classmethod
def get_playlist_img_path(cls):
return join(cls.get_img_path(), "playlists")
@classmethod
def get_assets_path(cls):
return join(Paths.get_app_dir(), "assets")
@@ -85,15 +100,32 @@ class Paths:
def get_lyrics_plugins_path(cls):
return join(Paths.get_plugins_path(), "lyrics")
@classmethod
def get_config_file_path(cls):
return join(cls.get_app_dir(), "settings.json")
# defaults
class Defaults:
THUMB_SIZE = 512
SM_THUMB_SIZE = 128
"""
Contains default values for various settings.
XSM_THUMB_SIZE: extra small thumbnail size for web client tracklist
SM_THUMB_SIZE: small thumbnail size for android client tracklist
MD_THUMB_SIZE: medium thumbnail size for web client album cards
LG_THUMB_SIZE: large thumbnail size for web client now playing album art
NOTE: LG_ARTIST_IMG_SIZE is not defined as the images are saved in the original size (500px)
"""
XSM_THUMB_SIZE = 64
SM_THUMB_SIZE = 96
MD_THUMB_SIZE = 256
LG_THUMB_SIZE = 512
SM_ARTIST_IMG_SIZE = 128
"""
The size of extracted images in pixels
"""
MD_ARTIST_IMG_SIZE = 256
HASH_LENGTH = 10
API_ALBUMHASH = "bfe300e966"
API_ARTISTHASH = "cae59f1fc5"
@@ -101,7 +133,6 @@ class Defaults:
API_ALBUMNAME = "The Goat"
API_ARTISTNAME = "Polo G"
API_TRACKNAME = "Martin & Gina"
API_CARD_LIMIT = 6
@@ -158,6 +189,8 @@ class ALLARGS:
host = "--host"
config = "--config"
pswd = "--pswd"
show_feat = ("--show-feat", "-sf")
show_prod = ("--show-prod", "-sp")
dont_clean_albums = ("--no-clean-albums", "-nca")
@@ -264,7 +297,14 @@ def getCurrentBranch():
return ""
class Keys:
class Info:
"""
Contains information about the app
NOTE: This class initially written to load keys when running in build mode.
TODO: Remove this class entirely, and implement functionality where needed.
"""
SWINGMUSIC_APP_VERSION = os.environ.get("SWINGMUSIC_APP_VERSION")
GIT_LATEST_COMMIT_HASH = "<unset>"
GIT_CURRENT_BRANCH = "<unset>"
@@ -279,12 +319,6 @@ class Keys:
cls.GIT_LATEST_COMMIT_HASH = getLatestCommitHash()
cls.GIT_CURRENT_BRANCH = getCurrentBranch()
cls.verify_keys()
@classmethod
def verify_keys(cls):
pass
@classmethod
def get(cls, key: str):
return getattr(cls, key, None)
+18
View File
@@ -2,6 +2,7 @@
Prepares the server for use.
"""
import uuid
from app.db.sqlite.settings import load_settings
from app.setup.files import create_config_dir
from app.setup.sqlite import run_migrations, setup_sqlite
@@ -9,10 +10,22 @@ from app.store.albums import AlbumStore
from app.store.artists import ArtistStore
from app.store.tracks import TrackStore
from app.utils.generators import get_random_str
from app.config import UserConfig
def run_setup():
"""
Creates the config directory, runs migrations, and loads settings.
"""
create_config_dir()
# setup config file
config = UserConfig()
config.setup_config_file()
if not config.userId:
config.userId = str(uuid.uuid4())
setup_sqlite()
run_migrations()
@@ -22,6 +35,11 @@ def run_setup():
# settings table is empty
pass
def load_into_mem():
"""
Load all tracks, albums, and artists into memory.
"""
instance_key = get_random_str()
# INFO: Load all tracks, albums, and artists into memory
+16 -14
View File
@@ -51,28 +51,28 @@ def create_config_dir() -> None:
"""
Creates the config directory if it doesn't exist.
"""
thumb_path = os.path.join("images", "thumbnails")
small_thumb_path = os.path.join(thumb_path, "small")
large_thumb_path = os.path.join(thumb_path, "large")
sm_thumb_path = settings.Paths.get_sm_thumb_path()
lg_thumb_path = settings.Paths.get_lg_thumb_path()
md_thumb_path = settings.Paths.get_md_thumb_path()
xsm_thumb_path = settings.Paths.get_xsm_thumb_path()
artist_img_path = os.path.join("images", "artists")
small_artist_img_path = os.path.join(artist_img_path, "small")
large_artist_img_path = os.path.join(artist_img_path, "large")
small_artist_img_path = settings.Paths.get_sm_artist_img_path()
md_artist_img_path = settings.Paths.get_md_artist_img_path()
large_artist_img_path = settings.Paths.get_lg_artist_img_path()
playlist_img_path = os.path.join("images", "playlists")
dirs = [
"", # creates the config folder
"images",
"plugins",
sm_thumb_path,
lg_thumb_path,
md_thumb_path,
xsm_thumb_path,
"plugins/lyrics",
thumb_path,
small_thumb_path,
large_thumb_path,
artist_img_path,
playlist_img_path,
md_artist_img_path,
small_artist_img_path,
large_artist_img_path,
playlist_img_path,
]
for _dir in dirs:
@@ -80,7 +80,9 @@ def create_config_dir() -> None:
exists = os.path.exists(path)
if not exists:
os.makedirs(path)
# exist_ok=True to create parent directories if they don't exist
os.makedirs(path, exist_ok=True)
os.chmod(path, 0o755)
# copy assets to the app directory
CopyFiles()
+4
View File
@@ -4,6 +4,7 @@ Applies migrations.
"""
from app.db.sqlite import create_connection, create_tables, queries
from app.db.sqlite.auth import SQLiteAuthMethods as authdb
from app.migrations import apply_migrations
from app.settings import Db
@@ -29,5 +30,8 @@ def setup_sqlite():
create_tables(user_db_conn, queries.CREATE_USERDATA_TABLES)
create_tables(app_db_conn, queries.CREATE_MIGRATIONS_TABLE)
if not authdb.get_all_users():
authdb.insert_default_user()
app_db_conn.close()
user_db_conn.close()
+3 -3
View File
@@ -1,16 +1,16 @@
import os
from app.settings import FLASKVARS, TCOLOR, Keys, Paths
from app.settings import FLASKVARS, TCOLOR, Info, Paths
from app.utils.network import get_ip
def log_startup_info():
lines = "------------------------------"
# clears terminal 👇
os.system("cls" if os.name == "nt" else "echo -e \\\\033c")
# os.system("cls" if os.name == "nt" else "echo -e \\\\033c")
print(lines)
print(f"{TCOLOR.HEADER}SwingMusic {Keys.SWINGMUSIC_APP_VERSION} {TCOLOR.ENDC}")
print(f"{TCOLOR.HEADER}Swing Music v{Info.SWINGMUSIC_APP_VERSION} {TCOLOR.ENDC}")
adresses = [FLASKVARS.get_flask_host()]
+2 -13
View File
@@ -1,7 +1,7 @@
# from tqdm import tqdm
from app.db.sqlite.favorite import SQLiteFavoriteMethods as favdb
from app.db.sqlite.tracks import SQLiteTrackMethods as tdb
from app.db.sqlite.tracks import SQLiteTrackMethods as trackdb
from app.models import Track
from app.utils.bisection import use_bisection
from app.utils.customlist import CustomList
@@ -23,7 +23,7 @@ class TrackStore:
global TRACKS_LOAD_KEY
TRACKS_LOAD_KEY = instance_key
cls.tracks = CustomList(tdb.get_all_tracks())
cls.tracks = CustomList(trackdb.get_all_tracks())
fav_hashes = favdb.get_fav_tracks()
fav_hashes = " ".join([t[1] for t in fav_hashes])
@@ -84,17 +84,6 @@ class TrackStore:
if track.filepath in filepaths:
cls.remove_track_obj(track)
@classmethod
def remove_tracks_by_dir_except(cls, dirs: list[str]):
"""Removes all tracks not in the root directories."""
to_remove = set()
for track in cls.tracks:
if not track.folder.startswith(tuple(dirs)):
to_remove.add(track.folder)
tdb.remove_tracks_by_folders(to_remove)
@classmethod
def count_tracks_by_trackhash(cls, trackhash: str) -> int:
"""
+31
View File
@@ -0,0 +1,31 @@
import hmac
import hashlib
from app.config import UserConfig
def hash_password(password: str) -> str:
"""
Hashes the given password using sha256 algorithm and the user id as salt.
:param password: The password to hash.
:return: The hashed password.
"""
return hashlib.pbkdf2_hmac(
"sha256", password.encode("utf-8"), UserConfig().userId.encode("utf-8"), 100000
).hex()
def check_password(password: str, hashed: str) -> bool:
"""
This function checks if the given password matches the hashed password.
:param password: The password to check.
:param hashed: The hashed password.
:return: Whether the password matches.
"""
return hmac.compare_digest(hash_password(password), hashed)
+21
View File
@@ -0,0 +1,21 @@
import mimetypes
def get_mime_from_ext(filename: str):
"""
Constructs a mime type from a file extension.
"""
ext = filename.rsplit(".", maxsplit=1)[-1]
return f"audio/{ext}"
def guess_mime_type(filename: str):
"""
Guess the mime type of a file.
"""
type = mimetypes.guess_type(filename)[0]
if type is None:
return get_mime_from_ext(filename)
return type
+19
View File
@@ -1,5 +1,8 @@
import os
import sys
from app.utils.filesystem import get_home_res_path
def getFlaskOpenApiPath():
"""
@@ -10,3 +13,19 @@ def getFlaskOpenApiPath():
site_packages_path = [p for p in sys.path if "site-packages" in p][0]
return f"{site_packages_path}/flask_openapi3"
def getClientFilesExtensions():
"""
Get all the file extensions for the client files
"""
client_path = get_home_res_path("client")
extensions = set()
for root, dirs, files in os.walk(client_path):
for file in files:
ext = file.split(".")[-1]
extensions.add("." + ext)
return extensions
+87 -29
View File
@@ -2,8 +2,16 @@
This file is used to run the application.
"""
from datetime import datetime, timezone
import os
import logging
from flask_jwt_extended import (
create_access_token,
get_jwt,
get_jwt_identity,
set_access_cookies,
verify_jwt_in_request,
)
import psutil
import mimetypes
from flask import Response, request
@@ -12,14 +20,15 @@ import waitress
import setproctitle
from app.api import create_api
from app.arg_handler import HandleArgs
from app.arg_handler import ProcessArgs
from app.lib.watchdogg import Watcher as WatchDog
from app.periodic_scan import run_periodic_scans
from app.plugins.register import register_plugins
from app.settings import FLASKVARS, TCOLOR, Keys
from app.setup import run_setup
from app.settings import FLASKVARS, TCOLOR, Info
from app.setup import load_into_mem, run_setup
from app.start_info_logger import log_startup_info
from app.utils.filesystem import get_home_res_path
from app.utils.paths import getClientFilesExtensions
from app.utils.threading import background
mimetypes.add_type("text/css", ".css")
@@ -38,9 +47,84 @@ mimetypes.add_type("application/manifest+json", ".webmanifest")
werkzeug = logging.getLogger("werkzeug")
werkzeug.setLevel(logging.ERROR)
# Background tasks
@background
def bg_run_setup():
run_periodic_scans()
@background
def start_watchdog():
WatchDog().run()
@background
def run_swingmusic():
log_startup_info()
bg_run_setup()
register_plugins()
start_watchdog()
setproctitle.setproctitle(f"swingmusic ::{FLASKVARS.get_flask_port()}")
# Setup function calls
Info.load()
ProcessArgs()
run_setup()
load_into_mem()
run_swingmusic()
# Create the Flask app
app = create_api()
app.static_folder = get_home_res_path("client")
# INFO: Routes that don't need authentication
whitelisted_routes = {"/auth/login", "/auth/users", "/auth/logout", "/docs"}
blacklist_extensions = {".webp"}.union(getClientFilesExtensions())
@app.before_request
def verify_auth():
"""
Verifies the JWT token before each request.
"""
if request.path == "/" or any(
request.path.endswith(ext) for ext in blacklist_extensions
):
return
# if request path starts with any of the blacklisted routes, don't verify jwt
if any(request.path.startswith(route) for route in whitelisted_routes):
# print(
# "Found whitelisted route: ", request.path, "... Skipping jwt verification"
# )
return
verify_jwt_in_request()
@app.after_request
def refresh_expiring_jwt(response: Response):
"""
Refreshes the JWT token after each request.
"""
try:
exp_timestamp = get_jwt()["exp"]
now = datetime.now(timezone.utc)
target_timestamp = datetime.timestamp(now) + 60 * 60 * 24 * 7 # 7 days
if target_timestamp > exp_timestamp:
access_token = create_access_token(identity=get_jwt_identity())
set_access_cookies(response, access_token)
return response
except (RuntimeError, KeyError):
return response
@app.route("/<path:path>")
def serve_client_files(path: str):
@@ -106,33 +190,7 @@ def print_memory_usage(response: Response):
return response
@background
def bg_run_setup() -> None:
run_periodic_scans()
@background
def start_watchdog():
WatchDog().run()
@background
def run_swingmusic():
log_startup_info()
run_setup()
bg_run_setup()
register_plugins()
start_watchdog()
setproctitle.setproctitle(f"swingmusic ::{FLASKVARS.get_flask_port()}")
if __name__ == "__main__":
Keys.load()
HandleArgs()
run_swingmusic()
host = FLASKVARS.get_flask_host()
port = FLASKVARS.get_flask_port()
Generated
+37 -1
View File
@@ -557,6 +557,25 @@ files = [
Flask = ">=0.9"
Six = "*"
[[package]]
name = "flask-jwt-extended"
version = "4.6.0"
description = "Extended JWT integration with Flask"
optional = false
python-versions = ">=3.7,<4"
files = [
{file = "Flask-JWT-Extended-4.6.0.tar.gz", hash = "sha256:9215d05a9413d3855764bcd67035e75819d23af2fafb6b55197eb5a3313fdfb2"},
{file = "Flask_JWT_Extended-4.6.0-py2.py3-none-any.whl", hash = "sha256:63a28fc9731bcc6c4b8815b6f954b5904caa534fc2ae9b93b1d3ef12930dca95"},
]
[package.dependencies]
Flask = ">=2.0,<4.0"
PyJWT = ">=2.0,<3.0"
Werkzeug = ">=0.14"
[package.extras]
asymmetric-crypto = ["cryptography (>=3.3.1)"]
[[package]]
name = "flask-openapi3"
version = "3.0.2"
@@ -1599,6 +1618,23 @@ files = [
{file = "pyinstaller_hooks_contrib-2023.9-py2.py3-none-any.whl", hash = "sha256:f34f4c6807210025c8073ebe665f422a3aa2ac5f4c7ebf4c2a26cc77bebf63b5"},
]
[[package]]
name = "pyjwt"
version = "2.8.0"
description = "JSON Web Token implementation in Python"
optional = false
python-versions = ">=3.7"
files = [
{file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"},
{file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"},
]
[package.extras]
crypto = ["cryptography (>=3.4.0)"]
dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
[[package]]
name = "pylint"
version = "2.17.7"
@@ -2468,4 +2504,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.12"
content-hash = "feb13f92b7b3a909fcb851860a405b96579feac0e2dde7681ed0e9c381c4f6cd"
content-hash = "317c4094a6f768467219db94be02a6e2e9f1ed5ca18d929437094bf60be1594d"
+1
View File
@@ -25,6 +25,7 @@ waitress = "^2.1.2"
watchdog = "^4.0.0"
pendulum = "^3.0.0"
flask-openapi3 = "^3.0.2"
flask-jwt-extended = "^4.6.0"
[tool.poetry.dev-dependencies]
pylint = "^2.15.5"
+2
View File
@@ -15,6 +15,7 @@ Flask==2.3.3
Flask-BasicAuth==0.2.0
Flask-Compress==1.14
Flask-Cors==3.0.10
Flask-JWT-Extended==4.6.0
flask-openapi3==3.0.2
gevent==23.9.1
geventhttpclient==2.0.11
@@ -42,6 +43,7 @@ pydantic==2.6.3
pydantic_core==2.16.3
pyinstaller==5.13.2
pyinstaller-hooks-contrib==2023.9
PyJWT==2.8.0
pylint==2.17.7
pytest==7.4.2
python-dateutil==2.8.2