"""Database engine, session and connection helpers.

In the ``development`` environment the engines connect to a local Postgres
over TCP. Otherwise they connect through the Cloud SQL Python Connector
using the instance connection name from settings.
"""

import os
import subprocess
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Optional

import asyncpg
import pg8000
import pg8000.dbapi
from google.cloud.sql.connector import Connector, IPTypes, create_async_connector
from sqlalchemy import URL, Engine
from sqlalchemy.ext.asyncio import (
    AsyncConnection,
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlmodel import create_engine

from ..config import get_settings
from ..db.seed import seed
from ..models.base_db_model import BaseDBModel


def _local_url(drivername: str) -> URL:
    """Build the URL for the local development database.

    ``URL.create`` escapes the credentials. Interpolating them into an
    f-string would produce a malformed URL whenever the username or password
    contains characters such as ``@``, ``:``, ``/`` or ``%``.
    """
    settings = get_settings()
    return URL.create(
        drivername,
        username=settings.db.username,
        password=settings.db.password,
        host="localhost",
        database=settings.db.db_name,
    )


async def _build_async_engine() -> tuple[AsyncEngine, Optional[Connector]]:
    """Create an asyncpg-backed engine plus the Cloud SQL connector it uses.

    Returns:
        ``(engine, connector)``. ``connector`` is ``None`` in development.
        The caller owns both and must dispose/close them.
    """
    settings = get_settings()
    if settings.app.environment == "development":
        engine = create_async_engine(_local_url("postgresql+asyncpg"), future=True)
        return engine, None

    connector = await create_async_connector()

    async def getconn() -> asyncpg.Connection:
        # Every pooled connection is opened through the Cloud SQL connector.
        return await connector.connect_async(
            settings.db.connection_name,
            "asyncpg",
            user=settings.db.username,
            password=settings.db.password,
            db=settings.db.db_name,
            ip_type=IPTypes.PUBLIC,
        )

    engine = create_async_engine(
        "postgresql+asyncpg://", async_creator=getconn, future=True
    )
    return engine, connector


async def _get_async_engine() -> AsyncEngine:
    """Return a new async engine (kept for backward compatibility).

    The caller must ``await engine.dispose()``. Outside development the
    underlying Cloud SQL connector is not reachable from here and is not
    closed. Prefer ``get_session`` / ``get_connection``, which clean up fully.
    """
    engine, _connector = await _build_async_engine()
    return engine


@asynccontextmanager
async def _async_engine_scope() -> AsyncGenerator[AsyncEngine, None]:
    """Yield a fresh async engine, then dispose it and close its connector.

    Cleanup runs even if the body raises.
    """
    engine, connector = await _build_async_engine()
    try:
        yield engine
    finally:
        await engine.dispose()
        if connector is not None:
            await connector.close_async()


async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Initialize a database session and return it.

    Use this when interacting via the ORM. The session is closed, and the
    engine created for it is disposed, once the caller is done with it.
    """
    async with _async_engine_scope() as engine:
        async_session = async_sessionmaker(bind=engine, expire_on_commit=False)
        async with async_session() as session:
            yield session


@asynccontextmanager
async def get_connection() -> AsyncGenerator[AsyncConnection, None]:
    """
    Initialize a database connection and return it.

    Only use this when you need to execute raw SQL queries. The engine is
    disposed on exit even if the ``async with`` body raises.
    """
    async with _async_engine_scope() as engine:
        async with engine.connect() as connection:
            yield connection


def _build_engine() -> tuple[Engine, Optional[Connector]]:
    """Create a synchronous pg8000-backed engine plus its Cloud SQL connector.

    Returns:
        ``(engine, connector)``. ``connector`` is ``None`` in development.
        The caller owns both and must dispose/close them.
    """
    settings = get_settings()
    if settings.app.environment == "development":
        return create_engine(_local_url("postgresql+pg8000")), None

    connector = Connector()

    def getconn() -> pg8000.dbapi.Connection:
        conn: pg8000.dbapi.Connection = connector.connect(
            settings.db.connection_name,
            "pg8000",
            user=settings.db.username,
            password=settings.db.password,
            db=settings.db.db_name,
            ip_type=IPTypes.PUBLIC,
        )
        return conn

    engine = create_engine("postgresql+pg8000://", creator=getconn)
    return engine, connector


def _get_engine() -> Engine:
    """Return a new synchronous engine (kept for backward compatibility).

    The caller must call ``engine.dispose()``. Outside development the
    underlying Cloud SQL connector is not closed by this function.
    """
    engine, _connector = _build_engine()
    return engine


def create_db_and_tables() -> None:
    """Create all tables declared on ``BaseDBModel`` and, in development, seed them.

    The engine and connector used here are released before returning.
    """
    # TODO Move this to use asyncpg
    engine, connector = _build_engine()
    try:
        BaseDBModel.metadata.create_all(engine)
        if get_settings().app.environment == "development":
            seed(engine)
    finally:
        engine.dispose()
        if connector is not None:
            connector.close()


def startup_migrations() -> None:
    """Run Alembic migrations"""
    print("Running Alembic migrations...")
    # Alembic is run from the directory two levels above this module.
    api_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir
    )
    try:
        subprocess.run(["alembic", "upgrade", "head"], check=True, cwd=api_path)
        print("Migrations applied successfully!")
    except subprocess.CalledProcessError as e:
        # Deliberately best-effort: a failed migration is reported but does
        # not stop the caller.
        print(f"Error applying migrations: {e}")