from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import (
    AsyncConnection,
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import Session
from sqlmodel import SQLModel, create_engine

from ..models.user import User
from ..settings import get_settings

settings = get_settings()


def _build_db_url(driver: str) -> str:
    """
    Build the PostgreSQL connection URL for the given DBAPI driver name.

    The credentials, host, port and database name are read from
    ``settings.db`` and interpolated verbatim.
    """
    db = settings.db
    # NOTE(review): values are not URL-escaped; a username or password
    # containing characters such as "@", ":" or "/" would presumably yield a
    # malformed URL -- confirm whether the configured credentials can
    # contain them.
    return (
        f"postgresql+{driver}://{db.username}:{db.password}"
        f"@{db.host}:{db.port}/{db.db_name}"
    )


async def _get_async_engine() -> AsyncEngine:
    """Create a new async engine using the asyncpg driver."""
    return create_async_engine(_build_db_url("asyncpg"), future=True)


async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Initialize a database session and yield it.
    Use this when interacting via the ORM.

    A new engine is created for every call. It is disposed once the consumer
    is done with the session, whether or not an exception was raised.
    """
    engine = await _get_async_engine()
    try:
        session_factory = async_sessionmaker(bind=engine, expire_on_commit=False)
        async with session_factory() as session:
            yield session
    finally:
        # Without this the per-call engine's connection pool is never closed.
        await engine.dispose()


@asynccontextmanager
async def get_connection() -> AsyncGenerator[AsyncConnection, None]:
    """
    Initialize a database connection and yield it.
    Only use this when you need to execute raw SQL queries.

    The engine backing the connection is disposed on exit, including when the
    body of the ``async with`` block raises.
    """
    engine = await _get_async_engine()
    try:
        async with engine.connect() as connection:
            yield connection
    finally:
        await engine.dispose()


def _get_engine() -> Engine:
    """Create a new synchronous engine using the pg8000 driver."""
    return create_engine(_build_db_url("pg8000"))


def _seed_db() -> None:
    """
    Insert the mock user when running in the "development" environment.

    Does nothing in any other environment, and does not insert a second row
    when a user with the configured mock e-mail address already exists.
    """
    if settings.app.environment != "development":
        return

    engine = _get_engine()
    try:
        # The context manager closes the session even if the query or the
        # commit raises.
        with Session(engine) as session:
            existing_user = (
                session.query(User)
                .filter(User.email == settings.app.mock_user_email)
                .first()
            )
            if not existing_user:
                mock_user = User(
                    email=settings.app.mock_user_email,
                    password_hash="test",
                )
                session.add(mock_user)
                session.commit()
    finally:
        engine.dispose()


def create_db_and_tables() -> None:
    """Create all tables registered on ``SQLModel.metadata``, then seed the database."""
    engine = _get_engine()
    try:
        SQLModel.metadata.create_all(engine)
    finally:
        # This engine is only needed for table creation; _seed_db opens its own.
        engine.dispose()
    _seed_db()