Add db provider

This commit is contained in:
Jacob Windsor 2025-02-19 12:49:42 +01:00
parent 267a0e1512
commit 1e137a4f3a

View File

@ -1,13 +1,6 @@
import os
import subprocess
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import asyncpg
import pg8000
import pg8000.dbapi
from google.cloud.sql.connector import Connector, IPTypes, create_async_connector
from sqlalchemy import Engine from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import ( from sqlalchemy.ext.asyncio import (
AsyncConnection, AsyncConnection,
@ -16,35 +9,19 @@ from sqlalchemy.ext.asyncio import (
async_sessionmaker, async_sessionmaker,
create_async_engine, create_async_engine,
) )
from sqlmodel import create_engine from sqlmodel import create_engine, SQLModel
from ..config import get_settings
from ..db.seed import seed from ..settings import get_settings
from ..models.base_db_model import BaseDBModel
settings = get_settings()
async def _get_async_engine() -> AsyncEngine: async def _get_async_engine() -> AsyncEngine:
settings = get_settings() return create_async_engine(
if settings.app.environment == "development":
engine = create_async_engine(
f"postgresql+asyncpg://{settings.db.username}:{settings.db.password}@localhost/{settings.db.db_name}", f"postgresql+asyncpg://{settings.db.username}:{settings.db.password}@localhost/{settings.db.db_name}",
future=True, future=True,
) )
else:
connector = await create_async_connector()
async def getconn() -> asyncpg.Connection:
return await connector.connect_async(
settings.db.connection_name,
"asyncpg",
user=settings.db.username,
password=settings.db.password,
db=settings.db.db_name,
ip_type=IPTypes.PUBLIC,
)
engine = create_async_engine("postgresql+asyncpg://", async_creator=getconn, future=True)
return engine
async def get_session() -> AsyncGenerator[AsyncSession, None]: async def get_session() -> AsyncGenerator[AsyncSession, None]:
@ -74,45 +51,13 @@ async def get_connection() -> AsyncGenerator[AsyncConnection, None]:
def _get_engine() -> Engine: def _get_engine() -> Engine:
settings = get_settings() return create_engine(
if settings.app.environment == "development": f"postgresql+pg8000://{settings.db.username}:{settings.db.password}@localhost/{settings.db.db_name}"
engine = create_engine( )
f"postgresql+pg8000://{settings.db.username}:{settings.db.password}@localhost/{settings.db.db_name}"
)
else:
connector = Connector()
def getconn() -> pg8000.dbapi.Connection:
conn: pg8000.dbapi.Connection = connector.connect(
settings.db.connection_name,
"pg8000",
user=settings.db.username,
password=settings.db.password,
db=settings.db.db_name,
ip_type=IPTypes.PUBLIC,
)
return conn
engine = create_engine("postgresql+pg8000://", creator=getconn)
return engine
def create_db_and_tables(): def create_db_and_tables():
# TODO Move this to use asyncpg
engine = _get_engine() engine = _get_engine()
BaseDBModel.metadata.create_all(engine) SQLModel.metadata.create_all(engine)
if get_settings().app.environment == "development": # TODO: add seeding
seed(engine)
def startup_migrations():
"""Run Alembic migrations"""
print("Running Alembic migrations...")
api_path = os.path.dirname(os.path.abspath(__file__)) + "/../.."
try:
subprocess.run(["alembic", "upgrade", "head"], check=True, cwd=api_path)
print("Migrations applied successfully!")
except subprocess.CalledProcessError as e:
print(f"Error applying migrations: {e}")