Add db provider
This commit is contained in:
parent
267a0e1512
commit
1e137a4f3a
@ -1,13 +1,6 @@
|
|||||||
|
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
import asyncpg
|
|
||||||
import pg8000
|
|
||||||
import pg8000.dbapi
|
|
||||||
from google.cloud.sql.connector import Connector, IPTypes, create_async_connector
|
|
||||||
from sqlalchemy import Engine
|
from sqlalchemy import Engine
|
||||||
from sqlalchemy.ext.asyncio import (
|
from sqlalchemy.ext.asyncio import (
|
||||||
AsyncConnection,
|
AsyncConnection,
|
||||||
@ -16,35 +9,19 @@ from sqlalchemy.ext.asyncio import (
|
|||||||
async_sessionmaker,
|
async_sessionmaker,
|
||||||
create_async_engine,
|
create_async_engine,
|
||||||
)
|
)
|
||||||
from sqlmodel import create_engine
|
from sqlmodel import create_engine, SQLModel
|
||||||
|
|
||||||
from ..config import get_settings
|
|
||||||
from ..db.seed import seed
|
from ..settings import get_settings
|
||||||
from ..models.base_db_model import BaseDBModel
|
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
|
||||||
async def _get_async_engine() -> AsyncEngine:
|
async def _get_async_engine() -> AsyncEngine:
|
||||||
settings = get_settings()
|
return create_async_engine(
|
||||||
if settings.app.environment == "development":
|
|
||||||
engine = create_async_engine(
|
|
||||||
f"postgresql+asyncpg://{settings.db.username}:{settings.db.password}@localhost/{settings.db.db_name}",
|
f"postgresql+asyncpg://{settings.db.username}:{settings.db.password}@localhost/{settings.db.db_name}",
|
||||||
future=True,
|
future=True,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
connector = await create_async_connector()
|
|
||||||
|
|
||||||
async def getconn() -> asyncpg.Connection:
|
|
||||||
return await connector.connect_async(
|
|
||||||
settings.db.connection_name,
|
|
||||||
"asyncpg",
|
|
||||||
user=settings.db.username,
|
|
||||||
password=settings.db.password,
|
|
||||||
db=settings.db.db_name,
|
|
||||||
ip_type=IPTypes.PUBLIC,
|
|
||||||
)
|
|
||||||
|
|
||||||
engine = create_async_engine("postgresql+asyncpg://", async_creator=getconn, future=True)
|
|
||||||
return engine
|
|
||||||
|
|
||||||
|
|
||||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
@ -74,45 +51,13 @@ async def get_connection() -> AsyncGenerator[AsyncConnection, None]:
|
|||||||
|
|
||||||
|
|
||||||
def _get_engine() -> Engine:
|
def _get_engine() -> Engine:
|
||||||
settings = get_settings()
|
return create_engine(
|
||||||
if settings.app.environment == "development":
|
f"postgresql+pg8000://{settings.db.username}:{settings.db.password}@localhost/{settings.db.db_name}"
|
||||||
engine = create_engine(
|
)
|
||||||
f"postgresql+pg8000://{settings.db.username}:{settings.db.password}@localhost/{settings.db.db_name}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
connector = Connector()
|
|
||||||
|
|
||||||
def getconn() -> pg8000.dbapi.Connection:
|
|
||||||
conn: pg8000.dbapi.Connection = connector.connect(
|
|
||||||
settings.db.connection_name,
|
|
||||||
"pg8000",
|
|
||||||
user=settings.db.username,
|
|
||||||
password=settings.db.password,
|
|
||||||
db=settings.db.db_name,
|
|
||||||
ip_type=IPTypes.PUBLIC,
|
|
||||||
)
|
|
||||||
return conn
|
|
||||||
|
|
||||||
engine = create_engine("postgresql+pg8000://", creator=getconn)
|
|
||||||
return engine
|
|
||||||
|
|
||||||
|
|
||||||
def create_db_and_tables():
|
def create_db_and_tables():
|
||||||
# TODO Move this to use asyncpg
|
|
||||||
engine = _get_engine()
|
engine = _get_engine()
|
||||||
BaseDBModel.metadata.create_all(engine)
|
SQLModel.metadata.create_all(engine)
|
||||||
|
|
||||||
if get_settings().app.environment == "development":
|
# TODO: add seeding
|
||||||
seed(engine)
|
|
||||||
|
|
||||||
|
|
||||||
def startup_migrations():
|
|
||||||
"""Run Alembic migrations"""
|
|
||||||
print("Running Alembic migrations...")
|
|
||||||
|
|
||||||
api_path = os.path.dirname(os.path.abspath(__file__)) + "/../.."
|
|
||||||
try:
|
|
||||||
subprocess.run(["alembic", "upgrade", "head"], check=True, cwd=api_path)
|
|
||||||
print("Migrations applied successfully!")
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
print(f"Error applying migrations: {e}")
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user