diff --git a/tests/api/test_stuff.py b/tests/api/test_stuff.py index 0e703dd..ba4bab4 100644 --- a/tests/api/test_stuff.py +++ b/tests/api/test_stuff.py @@ -1,3 +1,4 @@ +from uuid import UUID import pytest from fastapi import status from httpx import AsyncClient @@ -6,5 +7,92 @@ from httpx import AsyncClient pytestmark = pytest.mark.asyncio -async def test_add_stuff(client: AsyncClient): - assert True +@pytest.mark.parametrize( + "payload, status_code", + ( + ( + { + "name": "motorhead", + "description": "we play rock and roll" + }, + status.HTTP_201_CREATED, + ), + ), +) +async def test_add_configuration(client: AsyncClient, payload: dict, status_code: int): + response = await client.post("/v1/", json=payload) + assert response.status_code == status_code + assert payload["name"] == response.json()["name"] + + +@pytest.mark.parametrize( + "payload, status_code", + ( + ( + { + "name": "motorhead", + "description": "we play rock and roll" + }, + status.HTTP_200_OK, + ), + ), +) +async def test_get_configuration(client: AsyncClient, payload: dict, status_code: int): + await client.post("/v1/", json=payload) + name = payload["name"] + response = await client.get(f"/v1/?name={name}") + assert response.status_code == status_code + assert payload["name"] == response.json()["name"] + assert UUID(response.json()["id"]) + + +@pytest.mark.parametrize( + "payload, status_code", + ( + ( + { + "name": "motorhead", + "description": "we play rock and roll" + }, + status.HTTP_200_OK, + ), + ), +) +async def test_delete_configuration(client: AsyncClient, payload: dict, status_code: int): + response = await client.post("/v1/", json=payload) + stuff_id = response.json()["id"] + response = await client.delete(f"/v1/?stuff_id={stuff_id}") + assert response.status_code == status_code + + +@pytest.mark.parametrize( + "payload, status_code", + ( + ( + { + "name": "motorhead", + "description": "we play rock and roll" + }, + status.HTTP_200_OK, + ), + ), +) +@pytest.mark.parametrize( + "patch_payload, patch_status_code", + ( + ( + { + "name": "motorhead", + "description": "we play loud" + }, + status.HTTP_200_OK, + ), + ), +) +async def test_update_configuration(client: AsyncClient, payload: dict, status_code: int, patch_payload: dict, patch_status_code: int): + await client.post("/v1/", json=payload) + name = payload["name"] + response = await client.patch(f"/v1/?name={name}", json=patch_payload) + assert response.status_code == patch_status_code + response = await client.get(f"/v1/?name={name}") + assert patch_payload["description"] == response.json()["description"] diff --git a/tests/conftest.py b/tests/conftest.py index 8682e7d..2d48f1c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,12 +8,14 @@ from sqlalchemy.exc import SQLAlchemyError from the_app.main import app from the_app import config -from the_app.database import get_db, engine +from the_app.database import get_db from the_app.models.base import Base +from sqlalchemy.pool import NullPool + global_settings = config.get_settings() url = global_settings.asyncpg_test_url -engine = create_async_engine(url) +engine = create_async_engine(url, poolclass=NullPool) async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) diff --git a/the_app/api/stuff.py b/the_app/api/stuff.py index 05df710..e8284b6 100644 --- a/the_app/api/stuff.py +++ b/the_app/api/stuff.py @@ -1,6 +1,6 @@ from uuid import UUID -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, status from sqlalchemy.ext.asyncio import AsyncSession from the_app.database import get_db @@ -10,7 +10,7 @@ from the_app.schemas.stuff import StuffResponse, StuffSchema router = APIRouter() -@router.post("/", response_model=StuffResponse) +@router.post("/", status_code=status.HTTP_201_CREATED) async def create_stuff(stuff: StuffSchema, db_session: AsyncSession = Depends(get_db)): stuff_id = await Stuff.create(db_session, stuff) return {**stuff.dict(), "id": stuff_id} diff --git a/the_app/models/base.py b/the_app/models/base.py index bf4d2ba..09b6e26 100644 --- a/the_app/models/base.py +++ b/the_app/models/base.py @@ -3,6 +3,8 @@ from typing import Any from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.declarative import as_declarative, declared_attr +from fastapi import HTTPException, status +from icecream import ic @as_declarative() @@ -19,5 +21,6 @@ class Base: db_session.add(self) return await db_session.commit() except SQLAlchemyError as ex: - print(f"Have to rollback, save failed: {ex}") - raise + ic("Have to rollback, save failed:") + ic(ex) + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=ex.__str__()) diff --git a/the_app/models/stuff.py b/the_app/models/stuff.py index 73171eb..5b3c72b 100644 --- a/the_app/models/stuff.py +++ b/the_app/models/stuff.py @@ -12,7 +12,7 @@ class Stuff(Base): __tablename__ = "stuff" id = Column(UUID(as_uuid=True), unique=True, default=uuid.uuid4, autoincrement=True) name = Column(String, nullable=False, primary_key=True, unique=True) - description = Column(String, nullable=False, unique=True) + description = Column(String, nullable=False) def __init__(self, name: str, description: str): self.name = name