diff --git a/the_app/database.py b/the_app/database.py index 0652a5b..8a715e0 100644 --- a/the_app/database.py +++ b/the_app/database.py @@ -1,5 +1,6 @@ from typing import AsyncGenerator +from fastapi import HTTPException from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker @@ -22,13 +23,15 @@ async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession # Dependency async def get_db() -> AsyncGenerator: - session = async_session() - try: - yield session - await session.commit() - except SQLAlchemyError as ex: - await session.rollback() - raise ex - finally: - await session.close() - + async with async_session() as session: + try: + yield session + await session.commit() + except SQLAlchemyError as sql_ex: + await session.rollback() + raise sql_ex + except HTTPException as http_ex: + session.rollback() + raise http_ex + finally: + await session.close()