58 lines
1.5 KiB
Python
58 lines
1.5 KiB
Python
import os
from contextlib import contextmanager
from typing import Iterator

from flask import current_app, has_app_context
from sqlalchemy import create_engine, event
from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
|
|
# Process-wide database state. All three start unset and are populated
# together by get_engine(), which rebuilds them whenever the resolved
# database URL differs from the cached one:
#   _engine     -- the active SQLAlchemy engine
#   _engine_url -- the URL that _engine was created for
#   _Session    -- a sessionmaker bound to _engine
_engine = _engine_url = _Session = None
|
|
|
|
|
|
def get_database_url() -> str:
    """Resolve the database URL to connect to.

    Outside a Flask application context the ``DATABASE_URL`` environment
    variable is used, falling back to the SQLite file ``l4d2web.db`` in the
    working directory. Inside an application context the app's
    ``DATABASE_URL`` config entry is used instead (a missing entry raises
    ``KeyError``), converted to ``str``.
    """
    if not has_app_context():
        return os.getenv("DATABASE_URL", "sqlite:///l4d2web.db")
    configured_url = current_app.config["DATABASE_URL"]
    return str(configured_url)
|
|
|
|
|
|
def get_engine():
    """Return the cached SQLAlchemy engine for the current database URL.

    The engine is created lazily and cached in the module globals together
    with the URL it was built for and a ``sessionmaker`` bound to it
    (``expire_on_commit=False``). When ``get_database_url()`` resolves to a
    different URL than the cached one, a new engine and sessionmaker are
    built and the previous engine's connection pool is disposed.

    For SQLite URLs:

    * ``check_same_thread=False`` is passed to the driver so pooled
      connections may be used from threads other than the creating one;
    * every new DBAPI connection gets ``PRAGMA busy_timeout=5000``;
    * the database is switched to WAL journal mode once at engine creation.
    """
    global _engine, _engine_url, _Session

    db_url = get_database_url()
    if _engine is not None and _engine_url == db_url:
        return _engine

    is_sqlite = db_url.startswith("sqlite")
    connect_args = {"check_same_thread": False} if is_sqlite else {}
    engine = create_engine(db_url, connect_args=connect_args)

    if is_sqlite:
        # busy_timeout is a per-connection setting: run it on every
        # connection the pool opens, not only on the first one.
        @event.listens_for(engine, "connect")
        def _set_sqlite_busy_timeout(dbapi_connection, connection_record):
            cursor = dbapi_connection.cursor()
            try:
                cursor.execute("PRAGMA busy_timeout=5000;")
            finally:
                cursor.close()

        # journal_mode=WAL is persisted in the database file itself, so
        # setting it once per engine is sufficient.
        with engine.connect() as conn:
            conn.exec_driver_sql("PRAGMA journal_mode=WAL;")

    previous_engine = _engine
    _engine = engine
    _engine_url = db_url
    _Session = sessionmaker(bind=_engine, expire_on_commit=False)

    if previous_engine is not None:
        # Release the replaced engine's pooled connections; without this
        # every URL switch leaks a whole connection pool. Connections that
        # are still checked out are detached, not forcibly closed.
        previous_engine.dispose()

    return _engine
|
|
|
|
|
|
def init_db() -> None:
    """Create the tables declared on ``l4d2web.models.Base``.

    The tables are created through the engine returned by ``get_engine()``,
    i.e. in the database selected by the current database URL.
    """
    # Imported inside the function rather than at module level, presumably
    # to avoid an import cycle with the models module (TODO confirm).
    from l4d2web.models import Base

    engine = get_engine()
    Base.metadata.create_all(bind=engine)
|
|
|
|
|
|
@contextmanager
def session_scope() -> Iterator[Session]:
    """Provide a transactional ``Session`` for the current database.

    Usage::

        with session_scope() as session:
            session.add(obj)

    The yielded session comes from the sessionmaker built by
    ``get_engine()``. If the ``with`` body completes normally, the session
    is committed. If it raises an ``Exception``, the session is rolled back
    and the exception is re-raised. The session is closed in every case.

    Raises:
        RuntimeError: if ``get_engine()`` did not set up the session factory.
    """
    # Always go through get_engine() so that a change of database URL (for
    # example a different Flask app's config) is honoured, instead of reusing
    # a sessionmaker still bound to the previous database.
    get_engine()
    if _Session is None:
        raise RuntimeError("get_engine() did not initialise the session factory")

    session = _Session()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
|