2025-02-19 12:49:42 +01:00

64 lines
1.6 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from sqlalchemy import URL, Engine
from sqlalchemy.ext.asyncio import (
    AsyncConnection,
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlmodel import SQLModel, create_engine

from ..settings import get_settings
# Application settings, resolved once at import time; the engine factories in
# this module read the database credentials from ``settings.db``.
settings = get_settings()
async def _get_async_engine() -> AsyncEngine:
    """
    Build a new async (asyncpg) engine for the configured Postgres database.

    Every call creates a fresh engine with its own connection pool; the caller
    owns it and is responsible for ``await engine.dispose()`` when done.

    The URL is assembled with ``URL.create`` instead of an f-string so that
    credentials containing URL-reserved characters (``@``, ``:``, ``/``,
    ``%`` ...) are escaped correctly rather than corrupting the DSN.
    """
    url = URL.create(
        drivername="postgresql+asyncpg",
        username=settings.db.username,
        password=settings.db.password,
        host="localhost",
        database=settings.db.db_name,
    )
    return create_async_engine(url, future=True)
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Yield an ORM session bound to a freshly created async engine.

    Use this when interacting via the ORM. When the generator finishes, the
    session is closed by its ``async with`` block, and the engine created for
    it is then disposed. This matches what ``get_connection`` does, so the
    engine's connection pool does not outlive the session.
    """
    engine = await _get_async_engine()
    try:
        session_factory = async_sessionmaker(bind=engine, expire_on_commit=False)
        async with session_factory() as session:
            yield session
    finally:
        # Each call builds its own engine; without disposing it, every session
        # would leave an open pool (and its Postgres connections) behind.
        await engine.dispose()
@asynccontextmanager
async def get_connection() -> AsyncGenerator[AsyncConnection, None]:
    """
    Open a raw connection on a freshly created async engine.

    Only use this when you need to execute raw SQL queries; for ORM work use
    ``get_session``. The engine is disposed on exit even when the body of the
    caller's ``async with`` raises, so its pool never outlives the context.
    """
    engine = await _get_async_engine()
    try:
        async with engine.connect() as connection:
            yield connection
    finally:
        # Exceptions from the caller are thrown in at ``yield``; ``finally``
        # guarantees the per-call engine is still cleaned up in that case.
        await engine.dispose()
def _get_engine() -> Engine:
    """
    Build a new synchronous (pg8000) engine for the configured Postgres database.

    Each call returns a fresh engine that the caller owns and should dispose.
    The URL is built with ``URL.create`` so that credentials containing
    URL-reserved characters are escaped instead of breaking the DSN.
    """
    url = URL.create(
        drivername="postgresql+pg8000",
        username=settings.db.username,
        password=settings.db.password,
        host="localhost",
        database=settings.db.db_name,
    )
    return create_engine(url)
def create_db_and_tables() -> None:
    """
    Create all tables registered on ``SQLModel.metadata`` in the configured database.

    Uses a short-lived synchronous engine, which is disposed afterwards so no
    pooled connection lingers once the schema work is done.
    """
    engine = _get_engine()
    try:
        SQLModel.metadata.create_all(engine)
    finally:
        engine.dispose()
    # TODO: add seeding