document + rename stuff

This commit is contained in:
cwilvx
2024-07-07 16:07:27 +03:00
parent 32a2684ea2
commit 2ba5d6c1d7
11 changed files with 72 additions and 72 deletions
+20 -12
View File
@@ -9,14 +9,12 @@ from sqlalchemy import (
from sqlalchemy.engine import Engine
from sqlalchemy import event
from sqlalchemy.orm import (
DeclarativeBase,
MappedAsDataclass,
Session
)
from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass, Session
from app.db.engine import DbEngine
# Enable foreign key constraints for SQLite
@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
@@ -25,9 +23,12 @@ def set_sqlite_pragma(dbapi_connection, connection_record):
class DbManager:
""" """
def __init__(self, commit: bool = False):
self.commit = commit
self.conn = DbEngine.engine.connect()
with Session(DbEngine.engine) as session:
session.connection
@@ -42,9 +43,15 @@ class DbManager:
class Base(MappedAsDataclass, DeclarativeBase):
"""
Base class for all database models.
It has methods common to all tables. eg. `insert_one`, `insert_many`, `remove_all`, `remove_one`, `all`, `count`.
"""
@classmethod
def execute(cls, stmt: Any, commit: bool = False):
with DbManager(commit=commit) as conn:
with DbEngine.manager(commit=commit) as conn:
return conn.execute(stmt)
@classmethod
@@ -52,8 +59,7 @@ class Base(MappedAsDataclass, DeclarativeBase):
"""
Inserts multiple items into the database.
"""
with DbManager(commit=True) as conn:
return conn.execute(insert(cls).values(items))
return cls.execute(insert(cls).values(items), commit=True)
@classmethod
def insert_one(cls, item: dict[str, Any]):
@@ -64,12 +70,11 @@ class Base(MappedAsDataclass, DeclarativeBase):
@classmethod
def remove_all(cls):
with DbManager(commit=True) as conn:
conn.execute(delete(cls))
return cls.execute(delete(cls), commit=True)
@classmethod
def remove_one(cls, id: int):
cls.execute(delete(cls).where(cls.id == id), commit=True)
return cls.execute(delete(cls).where(cls.id == id), commit=True)
@classmethod
def all(cls):
@@ -80,5 +85,8 @@ class Base(MappedAsDataclass, DeclarativeBase):
return cls.execute(select(func.count()).select_from(cls)).scalar()
def create_all():
def create_all_tables():
"""
Creates all the tables that build on the Base class.
"""
Base().metadata.create_all(DbEngine.engine)