from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from sqlalchemy import Engine
from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import (
    AsyncConnection,
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlmodel import SQLModel, create_engine

from ..settings import get_settings

settings = get_settings()


def _db_url(drivername: str) -> URL:
    """Build the URL of the configured Postgres database for *drivername*.

    ``URL.create`` escapes the username and password. Credentials containing
    characters such as ``@``, ``:``, ``/`` or ``%`` therefore cannot corrupt
    the URL, which they could when it was assembled with an f-string.
    """
    return URL.create(
        drivername=drivername,
        username=settings.db.username,
        password=settings.db.password,
        host="localhost",
        database=settings.db.db_name,
    )


async def _get_async_engine() -> AsyncEngine:
    """Create a new asyncpg-backed async engine for the configured database.

    Each call returns a brand-new engine with its own connection pool. The
    caller is responsible for ``await engine.dispose()`` when done.
    """
    return create_async_engine(_db_url("postgresql+asyncpg"), future=True)


async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Initialize a database session and return it.
    Use this when interacting via the ORM.
    """
    engine = await _get_async_engine()
    session_factory = async_sessionmaker(bind=engine, expire_on_commit=False)
    try:
        async with session_factory() as session:
            yield session
    finally:
        # The engine (and its connection pool) was created just for this
        # session; dispose it so pooled connections are closed rather than
        # leaking on every call.
        await engine.dispose()


@asynccontextmanager
async def get_connection() -> AsyncGenerator[AsyncConnection, None]:
    """
    Initialize a database connection and return it.
    Only use this when you need to execute raw SQL queries.
    """
    engine = await _get_async_engine()
    try:
        async with engine.connect() as connection:
            yield connection
    finally:
        # Run even when the caller's ``async with`` body raises, so the
        # per-call engine is always disposed.
        await engine.dispose()


def _get_engine() -> Engine:
    """Create a new synchronous (pg8000) engine for the configured database.

    The caller is responsible for calling ``engine.dispose()`` when done.
    """
    return create_engine(_db_url("postgresql+pg8000"))


def create_db_and_tables() -> None:
    """Create every table registered on ``SQLModel.metadata``.

    ``create_all`` only creates tables that do not already exist.
    """
    engine = _get_engine()
    try:
        SQLModel.metadata.create_all(engine)
    finally:
        # The engine is only needed for this one-off DDL run; release its pool.
        engine.dispose()
    # TODO: add seeding