diff --git a/app/database.py b/app/database.py index bb5ef08..e400ed6 100644 --- a/app/database.py +++ b/app/database.py @@ -2,7 +2,7 @@ from collections.abc import AsyncGenerator from rotoger import AppStructLogger from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine - +from sqlalchemy.exc import SQLAlchemyError from app.config import settings as global_settings logger = AppStructLogger().get_logger() @@ -29,6 +29,11 @@ async def get_db() -> AsyncGenerator: try: yield session await session.commit() - except Exception as e: - await logger.aerror(f"Error getting database session: {e}") - raise + except Exception as ex: + if isinstance(ex, SQLAlchemyError): + # Re-raise SQLAlchemyError directly without handling + raise + else: + # Handle other exceptions + await logger.aerror(f"NonSQLAlchemyError: {repr(ex)}") + raise # Re-raise after logging diff --git a/app/exception_handlers.py b/app/exception_handlers.py new file mode 100644 index 0000000..c67ca13 --- /dev/null +++ b/app/exception_handlers.py @@ -0,0 +1,32 @@ +from fastapi import Request +from fastapi.responses import JSONResponse +from sqlalchemy.exc import SQLAlchemyError +import orjson +from fastapi import FastAPI +from rotoger import AppStructLogger + +logger = AppStructLogger().get_logger() + +async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError) -> JSONResponse: + request_path = request.url.path + try: + raw_body = await request.body() + request_body = orjson.loads(raw_body) if raw_body else None + except orjson.JSONDecodeError: + request_body = None + + await logger.aerror( + "Database error occurred", + sql_error=repr(exc), + request_url=request_path, + request_body=request_body, + ) + + return JSONResponse( + status_code=500, + content={"message": "A database error occurred. Please try again later."}, + ) + +def register_exception_handlers(app: FastAPI) -> None: + """Register all exception handlers with the FastAPI app.""" + app.add_exception_handler(SQLAlchemyError, sqlalchemy_exception_handler) diff --git a/app/exceptions.py b/app/exceptions.py deleted file mode 100644 index b6d7a2e..0000000 --- a/app/exceptions.py +++ /dev/null @@ -1,59 +0,0 @@ -from fastapi import HTTPException, status - - -class BadRequestHTTPException(HTTPException): - def __init__(self, msg: str): - super().__init__( - status_code=status.HTTP_400_BAD_REQUEST, - detail=msg or "Bad request", - ) - - -class AuthFailedHTTPException(HTTPException): - def __init__(self): - super().__init__( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Not authenticated", - headers={"WWW-Authenticate": "Bearer"}, - ) - - -class AuthTokenExpiredHTTPException(HTTPException): - def __init__(self): - super().__init__( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Expired token", - headers={"WWW-Authenticate": "Bearer"}, - ) - - -class ForbiddenHTTPException(HTTPException): - def __init__(self, msg: str): - super().__init__( - status_code=status.HTTP_403_FORBIDDEN, - detail=msg or "Requested resource is forbidden", - ) - - -class NotFoundHTTPException(HTTPException): - def __init__(self, msg: str): - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=msg or "Requested resource is not found", - ) - - -class ConflictHTTPException(HTTPException): - def __init__(self, msg: str): - super().__init__( - status_code=status.HTTP_409_CONFLICT, - detail=msg or "Conflicting resource request", - ) - - -class ServiceNotAvailableHTTPException(HTTPException): - def __init__(self, msg: str): - super().__init__( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail=msg or "Service not available", - ) diff --git a/app/main.py b/app/main.py index 8a63090..1bf1fee 100644 --- a/app/main.py +++ b/app/main.py @@ -3,10 +3,10 @@ from pathlib import Path import asyncpg from fastapi import Depends, FastAPI, Request -from fastapi.responses import HTMLResponse, JSONResponse +from fastapi.responses import HTMLResponse from fastapi.templating import Jinja2Templates from rotoger import AppStructLogger -from sqlalchemy.exc import SQLAlchemyError + from app.api.health import router as health_router from app.api.ml import router as ml_router @@ -17,6 +17,7 @@ from app.api.user import router as user_router from app.config import settings as global_settings from app.redis import get_redis from app.services.auth import AuthBearer +from app.exception_handlers import register_exception_handlers logger = AppStructLogger().get_logger() templates = Jinja2Templates(directory=Path(__file__).parent.parent / "templates") @@ -62,18 +63,8 @@ def create_app() -> FastAPI: dependencies=[Depends(AuthBearer())], ) - @app.exception_handler(SQLAlchemyError) - async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError): - await logger.aerror( - "A database error occurred", - sql_error=repr(exc), - request_url=request.url.path, - request_body=request.body, - ) - return JSONResponse( - status_code=500, - content={"message": "A database error occurred. Please try again later."}, - ) + # Register exception handlers + register_exception_handlers(app) @app.get("/index", response_class=HTMLResponse) def get_index(request: Request): diff --git a/app/models/base.py b/app/models/base.py index f04cef9..1347c6e 100644 --- a/app/models/base.py +++ b/app/models/base.py @@ -20,14 +20,11 @@ class Base(DeclarativeBase): return self.__name__.lower() async def save(self, db_session: AsyncSession): - try: - db_session.add(self) - await db_session.flush() - await db_session.refresh(self) - return self - except SQLAlchemyError as ex: - await logger.aerror(f"Error inserting instance of {self}: {repr(ex)}") - raise # This will make the exception handler catch it + db_session.add(self) + await db_session.flush() + await db_session.refresh(self) + return self + async def delete(self, db_session: AsyncSession): try: @@ -61,5 +58,4 @@ class Base(DeclarativeBase): status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=repr(exception), ) from exception - finally: - await db_session.close() +