diff --git a/app/models/base.py b/app/models/base.py index 7dc7e3e..8b66dc2 100644 --- a/app/models/base.py +++ b/app/models/base.py @@ -1,7 +1,7 @@ from typing import Any from fastapi import HTTPException, status -from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.exc import SQLAlchemyError, IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.declarative import as_declarative, declared_attr @@ -67,3 +67,18 @@ class Base: setattr(self, k, v) await self.save(db_session) + async def save_or_update(self, db: AsyncSession): + # TODO: this will be successor of update meth + _success = False + try: + db.add(self) + _success = await db.commit() + return _success + except IntegrityError as exception: + if not _success: + # TODO: check if exception is instance of class 'asyncpg.exceptions.UniqueViolationError' + return await db.merge(self) + else: + raise exception + finally: + await db.close() \ No newline at end of file