add set of unit tests

This commit is contained in:
grillazz 2021-03-28 17:12:35 +02:00
parent 786bb23eab
commit c3fc4e3146
5 changed files with 102 additions and 9 deletions

View File

@ -1,3 +1,4 @@
from uuid import UUID
import pytest import pytest
from fastapi import status from fastapi import status
from httpx import AsyncClient from httpx import AsyncClient
@ -6,5 +7,92 @@ from httpx import AsyncClient
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
async def test_add_stuff(client: AsyncClient): @pytest.mark.parametrize(
assert True "payload, status_code",
(
(
{
"name": "motorhead",
"description": "we play rock and roll"
},
status.HTTP_201_CREATED,
),
),
)
async def test_add_configuration(client: AsyncClient, payload: dict, status_code: int):
response = await client.post("/v1/", json=payload)
assert response.status_code == status_code
assert payload["name"] == response.json()["name"]
@pytest.mark.parametrize(
"payload, status_code",
(
(
{
"name": "motorhead",
"description": "we play rock and roll"
},
status.HTTP_200_OK,
),
),
)
async def test_get_configuration(client: AsyncClient, payload: dict, status_code: int):
await client.post("/v1/", json=payload)
name = payload["name"]
response = await client.get(f"/v1/?name={name}")
assert response.status_code == status_code
assert payload["name"] == response.json()["name"]
assert UUID(response.json()["id"])
@pytest.mark.parametrize(
"payload, status_code",
(
(
{
"name": "motorhead",
"description": "we play rock and roll"
},
status.HTTP_200_OK,
),
),
)
async def test_delete_configuration(client: AsyncClient, payload: dict, status_code: int):
response = await client.post("/v1/", json=payload)
stuff_id = response.json()["id"]
response = await client.delete(f"/v1/?stuff_id={stuff_id}")
assert response.status_code == status_code
@pytest.mark.parametrize(
"payload, status_code",
(
(
{
"name": "motorhead",
"description": "we play rock and roll"
},
status.HTTP_200_OK,
),
),
)
@pytest.mark.parametrize(
"patch_payload, patch_status_code",
(
(
{
"name": "motorhead",
"description": "we play loud"
},
status.HTTP_200_OK,
),
),
)
async def test_update_configuration(client: AsyncClient, payload: dict, status_code: int, patch_payload: dict, patch_status_code: int):
await client.post("/v1/", json=payload)
name = payload["name"]
response = await client.patch(f"/v1/?name={name}", json=patch_payload)
assert response.status_code == patch_status_code
response = await client.get(f"/v1/?name={name}")
assert patch_payload["description"] == response.json()["description"]

View File

@ -8,12 +8,14 @@ from sqlalchemy.exc import SQLAlchemyError
from the_app.main import app from the_app.main import app
from the_app import config from the_app import config
from the_app.database import get_db, engine from the_app.database import get_db
from the_app.models.base import Base from the_app.models.base import Base
from sqlalchemy.pool import NullPool
global_settings = config.get_settings() global_settings = config.get_settings()
url = global_settings.asyncpg_test_url url = global_settings.asyncpg_test_url
engine = create_async_engine(url) engine = create_async_engine(url, poolclass=NullPool)
async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)

View File

@ -1,6 +1,6 @@
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends, status
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from the_app.database import get_db from the_app.database import get_db
@ -10,7 +10,7 @@ from the_app.schemas.stuff import StuffResponse, StuffSchema
router = APIRouter() router = APIRouter()
@router.post("/", response_model=StuffResponse) @router.post("/", status_code=status.HTTP_201_CREATED)
async def create_stuff(stuff: StuffSchema, db_session: AsyncSession = Depends(get_db)): async def create_stuff(stuff: StuffSchema, db_session: AsyncSession = Depends(get_db)):
stuff_id = await Stuff.create(db_session, stuff) stuff_id = await Stuff.create(db_session, stuff)
return {**stuff.dict(), "id": stuff_id} return {**stuff.dict(), "id": stuff_id}

View File

@ -3,6 +3,8 @@ from typing import Any
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.declarative import as_declarative, declared_attr from sqlalchemy.ext.declarative import as_declarative, declared_attr
from fastapi import HTTPException, status
from icecream import ic
@as_declarative() @as_declarative()
@ -19,5 +21,6 @@ class Base:
db_session.add(self) db_session.add(self)
return await db_session.commit() return await db_session.commit()
except SQLAlchemyError as ex: except SQLAlchemyError as ex:
print(f"Have to rollback, save failed: {ex}") ic("Have to rollback, save failed:")
raise ic(ex)
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=ex.__str__())

View File

@ -12,7 +12,7 @@ class Stuff(Base):
__tablename__ = "stuff" __tablename__ = "stuff"
id = Column(UUID(as_uuid=True), unique=True, default=uuid.uuid4, autoincrement=True) id = Column(UUID(as_uuid=True), unique=True, default=uuid.uuid4, autoincrement=True)
name = Column(String, nullable=False, primary_key=True, unique=True) name = Column(String, nullable=False, primary_key=True, unique=True)
description = Column(String, nullable=False, unique=True) description = Column(String, nullable=False)
def __init__(self, name: str, description: str): def __init__(self, name: str, description: str):
self.name = name self.name = name