2025-02-19 15:36:28 +01:00

83 lines
2.1 KiB
Python

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import (
AsyncConnection,
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlmodel import create_engine, SQLModel
from sqlalchemy.orm import Session
from ..models.user import User
from ..settings import get_settings
# Application settings, resolved once when this module is imported; the
# helpers below read the database and app configuration from this object.
settings = get_settings()
async def _get_async_engine() -> AsyncEngine:
    """
    Build a new async SQLAlchemy engine for the configured PostgreSQL database.

    A fresh engine (with its own connection pool) is created on every call, so
    the caller is responsible for disposing of it once it is no longer needed.

    The connection URL is assembled with ``URL.create`` instead of string
    interpolation so that credentials containing reserved characters
    (``@``, ``/``, ``:``, ``%`` ...) reach the driver verbatim rather than
    corrupting the DSN.
    """
    # Imported locally so this helper does not depend on the module's import block.
    from sqlalchemy.engine import URL

    url = URL.create(
        drivername="postgresql+asyncpg",
        username=settings.db.username,
        password=settings.db.password,
        host=settings.db.host,
        port=settings.db.port,
        database=settings.db.db_name,
    )
    return create_async_engine(url, future=True)
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Initialize a database session and yield it.

    Use this when interacting via the ORM.

    The session is closed when the generator is finalized, and the engine that
    was created for it is disposed afterwards so its connection pool does not
    outlive the session.
    """
    engine = await _get_async_engine()
    try:
        async_session = async_sessionmaker(bind=engine, expire_on_commit=False)
        async with async_session() as session:
            yield session
    finally:
        # Every call builds its own engine; without an explicit dispose its
        # pooled connections would stay open until garbage collection.
        await engine.dispose()
@asynccontextmanager
async def get_connection() -> AsyncGenerator[AsyncConnection, None]:
    """
    Initialize a database connection and yield it.

    Only use this when you need to execute raw SQL queries.

    The connection is closed when the ``async with`` block exits, and the
    engine created for it is always disposed afterwards, including when the
    block raises.
    """
    engine = await _get_async_engine()
    try:
        async with engine.connect() as connection:
            yield connection
    finally:
        # An exception raised inside the caller's ``async with`` body is
        # re-raised at the ``yield``; ``finally`` guarantees the engine's pool
        # is still released in that case.
        await engine.dispose()
def _get_engine() -> Engine:
    """
    Build a new synchronous engine (pg8000 driver) for the configured database.

    A fresh engine is created on every call. The connection URL is assembled
    with ``URL.create`` instead of string interpolation so that credentials
    containing reserved characters (``@``, ``/``, ``:``, ``%`` ...) reach the
    driver verbatim rather than corrupting the DSN.
    """
    # Imported locally so this helper does not depend on the module's import block.
    from sqlalchemy.engine import URL

    url = URL.create(
        drivername="postgresql+pg8000",
        username=settings.db.username,
        password=settings.db.password,
        host=settings.db.host,
        port=settings.db.port,
        database=settings.db.db_name,
    )
    return create_engine(url)
def _seed_db():
    """
    Insert the mock user when running in the development environment.

    Does nothing in any other environment, and does nothing when a user with
    the configured mock e-mail address already exists.
    """
    if settings.app.environment != "development":
        return

    engine = _get_engine()
    try:
        # The context manager closes the session even if the query or the
        # commit raises, which a trailing ``session.close()`` would not.
        with Session(engine) as session:
            existing_user = (
                session.query(User)
                .filter(User.email == settings.app.mock_user_email)
                .first()
            )
            if not existing_user:
                mock_user = User(
                    email=settings.app.mock_user_email,
                    password_hash="test",
                )
                session.add(mock_user)
                session.commit()
    finally:
        # The engine was created only for seeding; release its pool.
        engine.dispose()
def create_db_and_tables():
    """
    Create every table registered on ``SQLModel.metadata``, then run the
    database seeding step.
    """
    SQLModel.metadata.create_all(bind=_get_engine())
    _seed_db()