diff --git a/backend/app/main.py b/backend/app/main.py index 0374406..1311bfb 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,9 +1,10 @@ -from fastapi import FastAPI +from fastapi import FastAPI, Depends from fastapi.middleware.cors import CORSMiddleware from .providers.db_provider import create_db_and_tables from .routers.houses import router as houses_router from .routers.owners import router as owners_router +from .middleware.authenticate import authenticate from contextlib import asynccontextmanager @@ -16,7 +17,8 @@ app = FastAPI( title="Fair Housing API", description="Provides access to core functionality for the fair housing platform.", version="1.0.0", - lifespan=lifespan + lifespan=lifespan, + dependencies=[Depends(authenticate)], ) app.add_middleware( diff --git a/backend/app/middleware/authenticate.py b/backend/app/middleware/authenticate.py new file mode 100644 index 0000000..8879d23 --- /dev/null +++ b/backend/app/middleware/authenticate.py @@ -0,0 +1,16 @@ +from fastapi import Request, Depends +from typing import Annotated +from ..settings import get_settings +from ..repositories.user_repository import UserRepository + +async def authenticate(request: Request, user_repository: Annotated[UserRepository, Depends()]) -> bool: + """ + Authenticate the current request. + """ + mocked_user = await user_repository.get_by_email(get_settings().app.mock_user_email) + + if not mocked_user: + raise Exception("Mock user not found.") + + request.state.user = mocked_user + return True \ No newline at end of file diff --git a/backend/app/models/house.py b/backend/app/models/house.py index bf8f57b..8f72217 100644 --- a/backend/app/models/house.py +++ b/backend/app/models/house.py @@ -8,4 +8,4 @@ class House(SQLModel, table=True): country: str = Field() price: float = Field() description: str = Field() - owner_id: UUID = Field(foreign_key="owner.id") + owner_id: UUID = Field(foreign_key="user.id") # TODO consider using owner.id diff --git a/backend/app/models/user.py b/backend/app/models/user.py index a281082..6c5c33b 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -2,7 +2,6 @@ from sqlmodel import SQLModel, Field from uuid import uuid4, UUID class User(SQLModel, table=True): - id: UUID = Field(default_factory=uuid4, primary_key=True) - username: str = Field(unique=True, nullable=False) + id: UUID = Field(default_factory=lambda: uuid4(), primary_key=True) email: str = Field(unique=True, nullable=False) password_hash: str = Field(nullable=False) diff --git a/backend/app/providers/auth_provider.py b/backend/app/providers/auth_provider.py index dec1ad7..3512541 100644 --- a/backend/app/providers/auth_provider.py +++ b/backend/app/providers/auth_provider.py @@ -2,6 +2,9 @@ from ..errors.not_authenticated import NotAuthenticatedError from ..models.user import User from ..settings import get_settings +from ..repositories.user_repository import UserRepository +from typing import Annotated +from fastapi import Depends, Request class AuthContext: """ @@ -10,23 +13,16 @@ class AuthContext: def __init__( self, + request: Request, ) -> None: if not get_settings().environment == "development": raise NotImplementedError("AuthProvider is only implemented for development environment.") - self._authenticated_user = self._get_mocked_user() - - - def _get_mocked_user(self): - return User( - email="test@test.com", - username="test", - password_hash="test", - ) + self._authenticated_user = request.state.user @property def is_authenticated(self) -> bool: - return self._authenticated_user is not None + return bool(self._authenticated_user) @property def user(self) -> User: diff --git a/backend/app/providers/db_provider.py b/backend/app/providers/db_provider.py index 6722302..a0a0f77 100644 --- a/backend/app/providers/db_provider.py +++ b/backend/app/providers/db_provider.py @@ -10,7 +10,9 @@ from sqlalchemy.ext.asyncio import ( create_async_engine, ) from sqlmodel import create_engine, SQLModel +from sqlalchemy.orm import Session +from ..models.user import User from ..settings import get_settings @@ -56,8 +58,25 @@ def _get_engine() -> Engine: ) +def _seed_db(): + if settings.app.environment != "development": + return + + session = Session(_get_engine()) + + existing_user = session.query(User).filter(User.email == settings.app.mock_user_email).first() + if not existing_user: + mock_user = User( + email=settings.app.mock_user_email, + password_hash="test", + ) + session.add(mock_user) + session.commit() + + session.close() + + def create_db_and_tables(): engine = _get_engine() SQLModel.metadata.create_all(engine) - - # TODO: add seeding + _seed_db() diff --git a/backend/app/repositories/user_repository.py b/backend/app/repositories/user_repository.py index 629be50..a01cc5a 100644 --- a/backend/app/repositories/user_repository.py +++ b/backend/app/repositories/user_repository.py @@ -1,7 +1,7 @@ from sqlalchemy.ext.asyncio.session import AsyncSession from typing import Annotated from fastapi import Depends -from app.providers.db_provider import get_session +from ..providers.db_provider import get_session from ..models.user import User from sqlmodel import select @@ -15,6 +15,11 @@ class UserRepository: statement = select(User).where(User.id == user_id) result = await self.session.execute(statement) return result.scalar_one_or_none() + + async def get_by_email(self, email: str): + statement = select(User).where(User.email == email) + result = await self.session.execute(statement) + return result.scalar_one_or_none() async def save(self, user: User) -> None: """ diff --git a/backend/app/routers/houses.py b/backend/app/routers/houses.py index d6b39d3..e0ecffa 100644 --- a/backend/app/routers/houses.py +++ b/backend/app/routers/houses.py @@ -24,7 +24,7 @@ async def create_house(body: HouseCreateRequest, auth: Annotated[AuthContext, De await house_repository.save(house) return HouseCreateResponse( - id=house.id + id=str(house.id) ) diff --git a/backend/app/settings.py b/backend/app/settings.py index 1426540..1ca91c7 100644 --- a/backend/app/settings.py +++ b/backend/app/settings.py @@ -13,7 +13,7 @@ class _BaseConfig(BaseSettings): environment: str = Field(default=os.getenv("ENVIRONMENT", "development")) class _AppSettings(_BaseConfig): - pass + mock_user_email: str = "test@test.com" class _DbSettings(_BaseConfig): username: str = Field(default=os.getenv("PG_USER"))