mirror of
https://github.com/Dvorinka/swingmusic-extended.git
synced 2026-06-03 20:13:02 +00:00
93 lines
2.3 KiB
Python
93 lines
2.3 KiB
Python
from typing import Any
|
|
|
|
from sqlalchemy import (
|
|
delete,
|
|
func,
|
|
insert,
|
|
select,
|
|
)
|
|
|
|
from sqlalchemy.engine import Engine
|
|
from sqlalchemy import event
|
|
from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass, Session
|
|
|
|
from app.db.engine import DbEngine
|
|
|
|
|
|
# SQLite keeps foreign-key enforcement switched off unless each connection
# opts in, so turn it on every time the pool opens a new DBAPI connection.
@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    """
    Enable ``PRAGMA foreign_keys`` on a freshly opened DBAPI connection.

    The listener is registered on the ``Engine`` class itself (not on one
    engine instance), so it fires for the "connect" event of every engine
    created in the process.

    :param dbapi_connection: the raw DBAPI connection that was just created.
    :param connection_record: the pool's bookkeeping record; not used here.
    """
    pragma_cursor = dbapi_connection.cursor()
    pragma_cursor.execute("PRAGMA foreign_keys=ON")
    pragma_cursor.close()
|
|
|
|
|
|
class DbManager:
    """
    Context manager that lends out a connection from ``DbEngine.engine``.

    Usage::

        with DbManager(commit=True) as conn:
            conn.execute(stmt)

    :param commit: when True, the transaction is committed on a *clean*
        exit from the ``with`` block. Nothing is committed if the block
        raised; the connection is closed in every case.
    """

    def __init__(self, commit: bool = False):
        self.commit = commit
        # Opened eagerly at construction time; ``__exit__`` releases it.
        # NOTE(review): an instance that is never used in a ``with`` block
        # keeps this connection checked out until it is garbage-collected.
        self.conn = DbEngine.engine.connect()

    def __enter__(self):
        # Ask SQLAlchemy to keep the cursor rowcount available on results
        # produced by this connection.
        return self.conn.execution_options(preserve_rowcount=True)

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            # Only persist work from a block that completed successfully;
            # committing after an exception would save a partial write.
            if self.commit and exc_type is None:
                self.conn.commit()
        finally:
            # Always release the connection, even if ``commit()`` fails.
            # Closing a connection with an uncommitted transaction rolls
            # that transaction back.
            self.conn.close()
|
|
|
|
|
|
class Base(MappedAsDataclass, DeclarativeBase):
    """
    Declarative base shared by every database model.

    On top of the declarative machinery it exposes class-level helpers that
    all tables can use: ``execute``, ``insert_one``, ``insert_many``,
    ``remove_all``, ``remove_one``, ``all`` and ``count``.
    """

    @classmethod
    def execute(cls, stmt: Any, commit: bool = False):
        """
        Run ``stmt`` on a connection taken from ``DbEngine.manager``.

        :param stmt: any SQLAlchemy executable (select, insert, delete, ...).
        :param commit: forwarded to the manager; True asks it to commit.
        :return: whatever ``conn.execute(stmt)`` returned.
        """
        with DbEngine.manager(commit=commit) as connection:
            result = connection.execute(stmt)
        # NOTE(review): the manager has already exited by the time the
        # result is handed back; callers such as ``count`` read from it
        # afterwards, so confirm the driver permits that.
        return result

    @classmethod
    def insert_many(cls, items: list[dict[str, Any]]):
        """
        Insert several rows into this model's table and commit.

        :param items: one dict of column values per row.
        """
        statement = insert(cls).values(items)
        return cls.execute(statement, commit=True)

    @classmethod
    def insert_one(cls, item: dict[str, Any]):
        """
        Insert a single row by delegating to ``insert_many``.

        :param item: dict of column values for the row.
        """
        single_row = [item]
        return cls.insert_many(single_row)

    @classmethod
    def remove_all(cls):
        """Delete every row of this model's table and commit."""
        statement = delete(cls)
        return cls.execute(statement, commit=True)

    @classmethod
    def remove_one(cls, id: int):
        """
        Delete the row whose ``id`` equals the given value and commit.

        Assumes the concrete model defines an ``id`` column.
        """
        statement = delete(cls).where(cls.id == id)
        return cls.execute(statement, commit=True)

    @classmethod
    def all(cls):
        """Return the result of ``SELECT`` over this model (not a list)."""
        statement = select(cls)
        return cls.execute(statement)

    @classmethod
    def count(cls):
        """Return the number of rows in this model's table."""
        total_query = select(func.count()).select_from(cls)
        return cls.execute(total_query).scalar()
|
|
|
|
|
|
def create_all_tables():
    """
    Create every table registered on ``Base.metadata``, i.e. the tables of
    all models that subclass ``Base``.

    ``MetaData.create_all`` checks for existing tables first by default, so
    tables that are already present are left untouched.
    """
    # ``metadata`` is a class attribute of the declarative base; there is no
    # need to build a throwaway ``Base()`` instance just to reach it.
    Base.metadata.create_all(DbEngine.engine)
|