64 lines
1.6 KiB
Python
64 lines
1.6 KiB
Python
|
||
from collections.abc import AsyncGenerator
|
||
from contextlib import asynccontextmanager
|
||
from sqlalchemy import Engine
|
||
from sqlalchemy.ext.asyncio import (
|
||
AsyncConnection,
|
||
AsyncEngine,
|
||
AsyncSession,
|
||
async_sessionmaker,
|
||
create_async_engine,
|
||
)
|
||
from sqlmodel import create_engine, SQLModel
|
||
|
||
|
||
from ..settings import get_settings
|
||
|
||
# Application settings, loaded once when this module is imported; the engine
# factories in this module read ``settings.db`` to build their database URLs.
settings = get_settings()
|
||
|
||
|
||
async def _get_async_engine() -> AsyncEngine:
    """
    Create a new async SQLAlchemy engine for the configured Postgres database.

    The engine uses the ``asyncpg`` driver and connects to ``localhost`` with
    the username, password and database name taken from ``settings.db``.

    The URL is assembled with ``URL.create`` instead of string interpolation so
    that credentials containing reserved URL characters (``@``, ``:``, ``/``,
    ``%`` ...) are escaped correctly rather than corrupting the DSN.

    Every call builds a *new* engine (and therefore a new connection pool); the
    caller is responsible for ``await engine.dispose()`` when finished with it.
    """
    # Local import keeps the module's import block untouched.
    from sqlalchemy.engine import URL

    url = URL.create(
        drivername="postgresql+asyncpg",
        username=settings.db.username,
        password=settings.db.password,
        host="localhost",
        database=settings.db.db_name,
    )
    return create_async_engine(url, future=True)
|
||
|
||
|
||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Initialize a database session and yield it.

    Use this when interacting via the ORM.

    A fresh engine is created for every call, so it is disposed once the
    session has been closed. Without that, each call would leave its own
    connection pool (and any pooled connections) open indefinitely.

    Yields:
        An ``AsyncSession`` bound to the newly created engine, configured with
        ``expire_on_commit=False`` so loaded attributes are not expired when
        the session commits.
    """
    engine = await _get_async_engine()
    try:
        session_factory = async_sessionmaker(bind=engine, expire_on_commit=False)
        async with session_factory() as session:
            yield session
    finally:
        # Release the per-call engine's pool even if the consumer raised.
        await engine.dispose()
|
||
|
||
|
||
@asynccontextmanager
async def get_connection() -> AsyncGenerator[AsyncConnection, None]:
    """
    Initialize a database connection and yield it as an async context manager.

    Only use this when you need to execute raw SQL queries.

    A fresh engine is created for each use and is always disposed when the
    ``async with`` block exits, including when the block raises. Previously,
    disposal only ran on normal exit.

    Yields:
        An ``AsyncConnection`` opened from the newly created engine.
    """
    engine = await _get_async_engine()
    try:
        async with engine.connect() as connection:
            yield connection
    finally:
        # Must run on both normal exit and exceptions thrown into the generator.
        await engine.dispose()
|
||
|
||
|
||
def _get_engine() -> Engine:
    """
    Create a new synchronous SQLAlchemy engine for the configured Postgres DB.

    The engine uses the ``pg8000`` driver and connects to ``localhost`` with the
    username, password and database name taken from ``settings.db``.

    The URL is assembled with ``URL.create`` instead of string interpolation so
    that credentials containing reserved URL characters (``@``, ``:``, ``/``,
    ``%`` ...) are escaped correctly rather than corrupting the DSN.

    Every call builds a *new* engine; the caller should ``engine.dispose()`` it
    when finished.
    """
    # Local import keeps the module's import block untouched.
    from sqlalchemy.engine import URL

    url = URL.create(
        drivername="postgresql+pg8000",
        username=settings.db.username,
        password=settings.db.password,
        host="localhost",
        database=settings.db.db_name,
    )
    return create_engine(url)
|
||
|
||
|
||
def create_db_and_tables() -> None:
    """
    Create every table registered on ``SQLModel.metadata`` in the database.

    ``create_all`` skips tables that already exist. The engine is created only
    for this call, so it is disposed afterwards (even on failure) to avoid
    leaving its connection pool open.
    """
    engine = _get_engine()
    try:
        SQLModel.metadata.create_all(engine)
    finally:
        engine.dispose()
|
||
|
||
# TODO: add seeding
|