diff --git a/tests/api/test_stuff.py b/tests/api/test_stuff.py index ba4bab4..d2900e8 100644 --- a/tests/api/test_stuff.py +++ b/tests/api/test_stuff.py @@ -1,4 +1,5 @@ from uuid import UUID + import pytest from fastapi import status from httpx import AsyncClient @@ -11,10 +12,7 @@ pytestmark = pytest.mark.asyncio "payload, status_code", ( ( - { - "name": "motorhead", - "description": "we play rock and roll" - }, + {"name": "motorhead", "description": "we play rock and roll"}, status.HTTP_201_CREATED, ), ), @@ -29,10 +27,7 @@ async def test_add_configuration(client: AsyncClient, payload: dict, status_code "payload, status_code", ( ( - { - "name": "motorhead", - "description": "we play rock and roll" - }, + {"name": "motorhead", "description": "we play rock and roll"}, status.HTTP_200_OK, ), ), @@ -50,18 +45,15 @@ async def test_get_configuration(client: AsyncClient, payload: dict, status_code "payload, status_code", ( ( - { - "name": "motorhead", - "description": "we play rock and roll" - }, + {"name": "motorhead", "description": "we play rock and roll"}, status.HTTP_200_OK, ), ), ) async def test_delete_configuration(client: AsyncClient, payload: dict, status_code: int): response = await client.post("/v1/", json=payload) - stuff_id = response.json()["id"] - response = await client.delete(f"/v1/?stuff_id={stuff_id}") + name = response.json()["name"] + response = await client.delete(f"/v1/?name={name}") assert response.status_code == status_code @@ -69,10 +61,7 @@ async def test_delete_configuration(client: AsyncClient, payload: dict, status_c "payload, status_code", ( ( - { - "name": "motorhead", - "description": "we play rock and roll" - }, + {"name": "motorhead", "description": "we play rock and roll"}, status.HTTP_200_OK, ), ), @@ -81,15 +70,14 @@ async def test_delete_configuration(client: AsyncClient, payload: dict, status_c "patch_payload, patch_status_code", ( ( - { - "name": "motorhead", - "description": "we play loud" - }, + {"name": "motorhead", "description": "we play loud"}, status.HTTP_200_OK, ), ), ) -async def test_update_configuration(client: AsyncClient, payload: dict, status_code: int, patch_payload: dict, patch_status_code: int): +async def test_update_configuration( + client: AsyncClient, payload: dict, status_code: int, patch_payload: dict, patch_status_code: int +): await client.post("/v1/", json=payload) name = payload["name"] response = await client.patch(f"/v1/?name={name}", json=patch_payload) diff --git a/tests/conftest.py b/tests/conftest.py index 2d48f1c..e4c76be 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,18 +1,15 @@ import pytest from httpx import AsyncClient - -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.ext.asyncio import create_async_engine -from sqlalchemy.orm import sessionmaker from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import NullPool -from the_app.main import app from the_app import config from the_app.database import get_db +from the_app.main import app from the_app.models.base import Base -from sqlalchemy.pool import NullPool - global_settings = config.get_settings() url = global_settings.asyncpg_test_url engine = create_async_engine(url, poolclass=NullPool) @@ -31,6 +28,7 @@ async def get_test_db(): finally: await session.close() + app.dependency_overrides[get_db] = get_test_db diff --git a/the_app/api/stuff.py b/the_app/api/stuff.py index e8284b6..a7e9a42 100644 --- a/the_app/api/stuff.py +++ b/the_app/api/stuff.py @@ -10,30 +10,33 @@ from the_app.schemas.stuff import StuffResponse, StuffSchema router = APIRouter() -@router.post("/", status_code=status.HTTP_201_CREATED) +@router.post("/", status_code=status.HTTP_201_CREATED, response_model=StuffResponse) async def create_stuff(stuff: StuffSchema, db_session: AsyncSession = Depends(get_db)): - stuff_id = await Stuff.create(db_session, stuff) - return {**stuff.dict(), "id": stuff_id} + stuff_instance = await Stuff.create(db_session, stuff) + return stuff_instance.__dict__ -@router.delete("/") -async def delete_stuff(stuff_id: UUID, db_session: AsyncSession = Depends(get_db)): - return await Stuff.delete(db_session, stuff_id) - - -@router.get("/") +@router.get("/", response_model=StuffResponse) async def find_stuff( name: str, db_session: AsyncSession = Depends(get_db), ): - return await Stuff.find(db_session, name) + stuff_instance = await Stuff.find(db_session, name) + return stuff_instance.__dict__ -@router.patch("/") -async def update_config( +@router.delete("/") +async def delete_stuff(name: str, db_session: AsyncSession = Depends(get_db)): + stuff_instance = await Stuff.find(db_session, name) + return await Stuff.delete(stuff_instance, db_session) + + +@router.patch("/", response_model=StuffResponse) +async def update_stuff( stuff: StuffSchema, name: str, db_session: AsyncSession = Depends(get_db), ): - instance_of_the_stuff = await Stuff.find(db_session, name) - return instance_of_the_stuff.update(db_session, stuff) + stuff_instance = await Stuff.find(db_session, name) + stuff_instance = await stuff_instance.update(db_session, stuff) + return stuff_instance.__dict__ diff --git a/the_app/exceptions.py b/the_app/exceptions.py new file mode 100644 index 0000000..8f1806a --- /dev/null +++ b/the_app/exceptions.py @@ -0,0 +1,59 @@ +from fastapi import HTTPException, status + + +class BadRequestHTTPException(HTTPException): + def __init__(self, msg: str): + super().__init__( + status_code=status.HTTP_400_BAD_REQUEST, + detail=msg if msg else "Bad request", + ) + + +class AuthFailedHTTPException(HTTPException): + def __init__(self): + super().__init__( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +class AuthTokenExpiredHTTPException(HTTPException): + def __init__(self): + super().__init__( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Expired token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +class ForbiddenHTTPException(HTTPException): + def __init__(self, msg: str): + super().__init__( + status_code=status.HTTP_403_FORBIDDEN, + detail=msg if msg else "Requested resource is forbidden", + ) + + +class NotFoundHTTPException(HTTPException): + def __init__(self, msg: str): + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=msg if msg else "Requested resource is not found", + ) + + +class ConflictHTTPException(HTTPException): + def __init__(self, msg: str): + super().__init__( + status_code=status.HTTP_409_CONFLICT, + detail=msg if msg else "Conflicting resource request", + ) + + +class ServiceNotAvailableHTTPException(HTTPException): + def __init__(self, msg: str): + super().__init__( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=msg if msg else "Service not available", + ) diff --git a/the_app/models/base.py b/the_app/models/base.py index 09b6e26..ec2a316 100644 --- a/the_app/models/base.py +++ b/the_app/models/base.py @@ -1,10 +1,10 @@ from typing import Any +from fastapi import HTTPException, status +from icecream import ic from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.declarative import as_declarative, declared_attr -from fastapi import HTTPException, status -from icecream import ic @as_declarative() @@ -24,3 +24,13 @@ class Base: ic("Have to rollback, save failed:") ic(ex) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=ex.__str__()) + + async def delete(self, db_session: AsyncSession): + try: + await db_session.delete(self) + await db_session.commit() + return True + except SQLAlchemyError as ex: + ic("Have to rollback, save failed:") + ic(ex) + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=ex.__str__()) diff --git a/the_app/models/stuff.py b/the_app/models/stuff.py index 5b3c72b..80aafe6 100644 --- a/the_app/models/stuff.py +++ b/the_app/models/stuff.py @@ -1,6 +1,6 @@ import uuid -from sqlalchemy import Column, String, delete, select +from sqlalchemy import Column, String, select from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.ext.asyncio import AsyncSession @@ -25,22 +25,16 @@ class Stuff(Base): description=schema.description, ) await stuff.save(db_session) - return stuff.id + return stuff async def update(self, db_session: AsyncSession, schema: StuffSchema): self.name = schema.name self.description = schema.description - return await self.save(db_session) + await self.save(db_session) + return self @classmethod async def find(cls, db_session: AsyncSession, name: str): stmt = select(cls).where(cls.name == name) result = await db_session.execute(stmt) return result.scalars().first() - - @classmethod - async def delete(cls, db_session: AsyncSession, stuff_id: UUID): - stmt = delete(cls).where(cls.id == stuff_id) - await db_session.execute(stmt) - await db_session.commit() - return True