83 lines
2.1 KiB
Python
83 lines
2.1 KiB
Python
|
|
from collections.abc import AsyncGenerator
|
|
from contextlib import asynccontextmanager
|
|
from sqlalchemy import Engine
|
|
from sqlalchemy.ext.asyncio import (
|
|
AsyncConnection,
|
|
AsyncEngine,
|
|
AsyncSession,
|
|
async_sessionmaker,
|
|
create_async_engine,
|
|
)
|
|
from sqlmodel import create_engine, SQLModel
|
|
from sqlalchemy.orm import Session
|
|
|
|
from ..models.user import User
|
|
|
|
from ..settings import get_settings
|
|
|
|
# Application settings, loaded once at import time. The helpers below read
# database credentials (settings.db) and app options (settings.app) from it.
settings = get_settings()
|
|
|
|
|
|
async def _get_async_engine() -> AsyncEngine:
    """Create a new asyncpg-backed engine for the configured database.

    Every call builds a fresh engine, with its own connection pool. The caller
    owns it and should ``await engine.dispose()`` when finished.

    The URL is assembled with ``URL.create`` rather than string interpolation.
    This means credentials containing URL-reserved characters (``@``, ``:``,
    ``/``, ``%`` ...) are escaped correctly instead of producing a malformed
    connection URL.

    Returns:
        A new ``AsyncEngine`` connected to ``localhost``.
    """
    # Local import keeps this fix self-contained; URL lives in sqlalchemy.engine.
    from sqlalchemy.engine import URL

    url = URL.create(
        drivername="postgresql+asyncpg",
        username=settings.db.username,
        password=settings.db.password,
        host="localhost",
        database=settings.db.db_name,
    )
    # Declared async only so existing `await _get_async_engine()` callers keep
    # working; nothing in here actually awaits.
    return create_async_engine(url, future=True)
|
|
|
|
|
|
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ORM session bound to a freshly created async engine.

    Use this when interacting via the ORM.

    The session is closed when the generator finishes. The engine and its
    connection pool are then disposed, so the per-call engine does not leak
    connections.

    Yields:
        An ``AsyncSession`` created with ``expire_on_commit=False``, so loaded
        attributes are not expired when the session commits.
    """
    engine = await _get_async_engine()
    try:
        session_factory = async_sessionmaker(bind=engine, expire_on_commit=False)
        async with session_factory() as session:
            yield session
    finally:
        # _get_async_engine() builds a new engine (and pool) on every call.
        # Without disposing it here, its pooled connections outlive the session.
        await engine.dispose()
|
|
|
|
|
|
@asynccontextmanager
async def get_connection() -> AsyncGenerator[AsyncConnection, None]:
    """Provide a raw database connection for the duration of an ``async with``.

    Only use this when you need to execute raw SQL queries.

    A fresh engine is created for the connection and is always disposed on
    exit, including when the ``async with`` body raises.

    Yields:
        An open ``AsyncConnection``. It is closed before the engine is disposed.
    """
    engine = await _get_async_engine()
    try:
        async with engine.connect() as connection:
            yield connection
    finally:
        # An exception from the caller's block is re-raised at the `yield`.
        # The finally ensures the engine's pool is still released in that case.
        await engine.dispose()
|
|
|
|
|
|
def _get_engine() -> Engine:
    """Create a new synchronous (pg8000) engine for the configured database.

    Every call builds a fresh engine with its own connection pool. The caller
    should call ``engine.dispose()`` when finished.

    The URL is assembled with ``URL.create`` so that credentials containing
    URL-reserved characters are escaped instead of corrupting the URL.

    Returns:
        A new ``Engine`` connected to ``localhost``.
    """
    # Local import keeps this fix self-contained; URL lives in sqlalchemy.engine.
    from sqlalchemy.engine import URL

    url = URL.create(
        drivername="postgresql+pg8000",
        username=settings.db.username,
        password=settings.db.password,
        host="localhost",
        database=settings.db.db_name,
    )
    return create_engine(url)
|
|
|
|
|
|
def _seed_db() -> None:
    """Insert the mock user when running in development, if it is missing.

    Does nothing unless ``settings.app.environment == "development"``. The user
    is looked up by ``settings.app.mock_user_email``. If no row matches, a
    ``User`` with that email and the literal ``password_hash="test"`` is added
    and committed.
    """
    if settings.app.environment != "development":
        return

    engine = _get_engine()
    try:
        # The context manager closes the session even if the query or the
        # commit raises. A manual close() at the end would be skipped then.
        with Session(engine) as session:
            existing_user = (
                session.query(User)
                .filter(User.email == settings.app.mock_user_email)
                .first()
            )
            if existing_user is None:
                session.add(
                    User(
                        email=settings.app.mock_user_email,
                        password_hash="test",
                    )
                )
                session.commit()
    finally:
        # _get_engine() returns a brand-new engine; release its pool here.
        engine.dispose()
|
|
|
|
|
|
def create_db_and_tables() -> None:
    """Create all SQLModel tables, then seed development data.

    Tables come from ``SQLModel.metadata``, so only models whose modules have
    been imported are registered (this module imports ``User``). ``create_all``
    checks for existing tables first, so re-running it leaves them untouched.

    The engine used for table creation is disposed before seeding, because
    ``_seed_db`` opens its own engine.
    """
    engine = _get_engine()
    try:
        SQLModel.metadata.create_all(engine)
    finally:
        # _get_engine() builds a fresh engine per call; don't leave its pool open.
        engine.dispose()

    _seed_db()
|