diff --git a/app/api/stuff.py b/app/api/stuff.py index 5cab5c3..bfb6c8c 100644 --- a/app/api/stuff.py +++ b/app/api/stuff.py @@ -45,7 +45,10 @@ async def create_stuff( async def find_stuff(name: str, db_session: AsyncSession = Depends(get_db)): result = await Stuff.find(db_session, name) if not result: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Stuff with name {name} not found.") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Stuff with name {name} not found.", + ) return result @@ -91,7 +94,6 @@ async def find_stuff_pool( return result - @router.delete("/{name}") async def delete_stuff(name: str, db_session: AsyncSession = Depends(get_db)): stuff = await Stuff.find(db_session, name) diff --git a/app/models/stuff.py b/app/models/stuff.py index 45a0851..fdd2dc7 100644 --- a/app/models/stuff.py +++ b/app/models/stuff.py @@ -9,6 +9,20 @@ from sqlalchemy.orm import mapped_column, Mapped, relationship, joinedload from app.models.base import Base from app.models.nonsense import Nonsense +from functools import wraps + + +def compile_sql_or_scalar(func): + @wraps(func) + async def wrapper(cls, db_session, name, compile_sql=False, *args, **kwargs): + stmt = await func(cls, db_session, name, *args, **kwargs) + if compile_sql: + return stmt.compile(compile_kwargs={"literal_binds": True}) + result = await db_session.execute(stmt) + return result.scalars().first() + + return wrapper + class Stuff(Base): __tablename__ = "stuff" @@ -24,12 +38,10 @@ class Stuff(Base): ) @classmethod - async def find(cls, db_session: AsyncSession, name: str, compile_sql: bool = False): + @compile_sql_or_scalar + async def find(cls, db_session: AsyncSession, name: str, compile_sql=False): stmt = select(cls).options(joinedload(cls.nonsense)).where(cls.name == name) - if compile_sql: - return stmt.compile(compile_kwargs={"literal_binds": True}) - result = await db_session.execute(stmt) - return result.scalars().first() + return stmt class StuffFullOfNonsense(Base):