move delete method to base class, minor api updates and code linting

This commit is contained in:
grillazz 2021-04-05 17:29:09 +02:00
parent c3fc4e3146
commit 6a83b80340
6 changed files with 108 additions and 56 deletions

View File

@ -1,4 +1,5 @@
from uuid import UUID from uuid import UUID
import pytest import pytest
from fastapi import status from fastapi import status
from httpx import AsyncClient from httpx import AsyncClient
@ -11,10 +12,7 @@ pytestmark = pytest.mark.asyncio
"payload, status_code", "payload, status_code",
( (
( (
{ {"name": "motorhead", "description": "we play rock and roll"},
"name": "motorhead",
"description": "we play rock and roll"
},
status.HTTP_201_CREATED, status.HTTP_201_CREATED,
), ),
), ),
@ -29,10 +27,7 @@ async def test_add_configuration(client: AsyncClient, payload: dict, status_code
"payload, status_code", "payload, status_code",
( (
( (
{ {"name": "motorhead", "description": "we play rock and roll"},
"name": "motorhead",
"description": "we play rock and roll"
},
status.HTTP_200_OK, status.HTTP_200_OK,
), ),
), ),
@ -50,18 +45,15 @@ async def test_get_configuration(client: AsyncClient, payload: dict, status_code
"payload, status_code", "payload, status_code",
( (
( (
{ {"name": "motorhead", "description": "we play rock and roll"},
"name": "motorhead",
"description": "we play rock and roll"
},
status.HTTP_200_OK, status.HTTP_200_OK,
), ),
), ),
) )
async def test_delete_configuration(client: AsyncClient, payload: dict, status_code: int): async def test_delete_configuration(client: AsyncClient, payload: dict, status_code: int):
response = await client.post("/v1/", json=payload) response = await client.post("/v1/", json=payload)
stuff_id = response.json()["id"] name = response.json()["name"]
response = await client.delete(f"/v1/?stuff_id={stuff_id}") response = await client.delete(f"/v1/?name={name}")
assert response.status_code == status_code assert response.status_code == status_code
@ -69,10 +61,7 @@ async def test_delete_configuration(client: AsyncClient, payload: dict, status_c
"payload, status_code", "payload, status_code",
( (
( (
{ {"name": "motorhead", "description": "we play rock and roll"},
"name": "motorhead",
"description": "we play rock and roll"
},
status.HTTP_200_OK, status.HTTP_200_OK,
), ),
), ),
@ -81,15 +70,14 @@ async def test_delete_configuration(client: AsyncClient, payload: dict, status_c
"patch_payload, patch_status_code", "patch_payload, patch_status_code",
( (
( (
{ {"name": "motorhead", "description": "we play loud"},
"name": "motorhead",
"description": "we play loud"
},
status.HTTP_200_OK, status.HTTP_200_OK,
), ),
), ),
) )
async def test_update_configuration(client: AsyncClient, payload: dict, status_code: int, patch_payload: dict, patch_status_code: int): async def test_update_configuration(
client: AsyncClient, payload: dict, status_code: int, patch_payload: dict, patch_status_code: int
):
await client.post("/v1/", json=payload) await client.post("/v1/", json=payload)
name = payload["name"] name = payload["name"]
response = await client.patch(f"/v1/?name={name}", json=patch_payload) response = await client.patch(f"/v1/?name={name}", json=patch_payload)

View File

@ -1,18 +1,15 @@
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from the_app.main import app
from the_app import config from the_app import config
from the_app.database import get_db from the_app.database import get_db
from the_app.main import app
from the_app.models.base import Base from the_app.models.base import Base
from sqlalchemy.pool import NullPool
global_settings = config.get_settings() global_settings = config.get_settings()
url = global_settings.asyncpg_test_url url = global_settings.asyncpg_test_url
engine = create_async_engine(url, poolclass=NullPool) engine = create_async_engine(url, poolclass=NullPool)
@ -31,6 +28,7 @@ async def get_test_db():
finally: finally:
await session.close() await session.close()
app.dependency_overrides[get_db] = get_test_db app.dependency_overrides[get_db] = get_test_db

View File

@ -10,30 +10,33 @@ from the_app.schemas.stuff import StuffResponse, StuffSchema
router = APIRouter() router = APIRouter()
@router.post("/", status_code=status.HTTP_201_CREATED) @router.post("/", status_code=status.HTTP_201_CREATED, response_model=StuffResponse)
async def create_stuff(stuff: StuffSchema, db_session: AsyncSession = Depends(get_db)): async def create_stuff(stuff: StuffSchema, db_session: AsyncSession = Depends(get_db)):
stuff_id = await Stuff.create(db_session, stuff) stuff_instance = await Stuff.create(db_session, stuff)
return {**stuff.dict(), "id": stuff_id} return stuff_instance.__dict__
@router.delete("/") @router.get("/", response_model=StuffResponse)
async def delete_stuff(stuff_id: UUID, db_session: AsyncSession = Depends(get_db)):
return await Stuff.delete(db_session, stuff_id)
@router.get("/")
async def find_stuff( async def find_stuff(
name: str, name: str,
db_session: AsyncSession = Depends(get_db), db_session: AsyncSession = Depends(get_db),
): ):
return await Stuff.find(db_session, name) stuff_instance = await Stuff.find(db_session, name)
return stuff_instance.__dict__
@router.patch("/") @router.delete("/")
async def update_config( async def delete_stuff(name: str, db_session: AsyncSession = Depends(get_db)):
stuff_instance = await Stuff.find(db_session, name)
return await Stuff.delete(stuff_instance, db_session)
@router.patch("/", response_model=StuffResponse)
async def update_stuff(
stuff: StuffSchema, stuff: StuffSchema,
name: str, name: str,
db_session: AsyncSession = Depends(get_db), db_session: AsyncSession = Depends(get_db),
): ):
instance_of_the_stuff = await Stuff.find(db_session, name) stuff_instance = await Stuff.find(db_session, name)
return instance_of_the_stuff.update(db_session, stuff) stuff_instance = await stuff_instance.update(db_session, stuff)
return stuff_instance.__dict__

59
the_app/exceptions.py Normal file
View File

@ -0,0 +1,59 @@
from fastapi import HTTPException, status
class BadRequestHTTPException(HTTPException):
def __init__(self, msg: str):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail=msg if msg else "Bad request",
)
class AuthFailedHTTPException(HTTPException):
def __init__(self):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
class AuthTokenExpiredHTTPException(HTTPException):
def __init__(self):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Expired token",
headers={"WWW-Authenticate": "Bearer"},
)
class ForbiddenHTTPException(HTTPException):
def __init__(self, msg: str):
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
detail=msg if msg else "Requested resource is forbidden",
)
class NotFoundHTTPException(HTTPException):
def __init__(self, msg: str):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=msg if msg else "Requested resource is not found",
)
class ConflictHTTPException(HTTPException):
def __init__(self, msg: str):
super().__init__(
status_code=status.HTTP_409_CONFLICT,
detail=msg if msg else "Conflicting resource request",
)
class ServiceNotAvailableHTTPException(HTTPException):
def __init__(self, msg: str):
super().__init__(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=msg if msg else "Service not available",
)

View File

@ -1,10 +1,10 @@
from typing import Any from typing import Any
from fastapi import HTTPException, status
from icecream import ic
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.declarative import as_declarative, declared_attr from sqlalchemy.ext.declarative import as_declarative, declared_attr
from fastapi import HTTPException, status
from icecream import ic
@as_declarative() @as_declarative()
@ -24,3 +24,13 @@ class Base:
ic("Have to rollback, save failed:") ic("Have to rollback, save failed:")
ic(ex) ic(ex)
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=ex.__str__()) raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=ex.__str__())
async def delete(self, db_session: AsyncSession):
try:
await db_session.delete(self)
await db_session.commit()
return True
except SQLAlchemyError as ex:
ic("Have to rollback, save failed:")
ic(ex)
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=ex.__str__())

View File

@ -1,6 +1,6 @@
import uuid import uuid
from sqlalchemy import Column, String, delete, select from sqlalchemy import Column, String, select
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@ -25,22 +25,16 @@ class Stuff(Base):
description=schema.description, description=schema.description,
) )
await stuff.save(db_session) await stuff.save(db_session)
return stuff.id return stuff
async def update(self, db_session: AsyncSession, schema: StuffSchema): async def update(self, db_session: AsyncSession, schema: StuffSchema):
self.name = schema.name self.name = schema.name
self.description = schema.description self.description = schema.description
return await self.save(db_session) await self.save(db_session)
return self
@classmethod @classmethod
async def find(cls, db_session: AsyncSession, name: str): async def find(cls, db_session: AsyncSession, name: str):
stmt = select(cls).where(cls.name == name) stmt = select(cls).where(cls.name == name)
result = await db_session.execute(stmt) result = await db_session.execute(stmt)
return result.scalars().first() return result.scalars().first()
@classmethod
async def delete(cls, db_session: AsyncSession, stuff_id: UUID):
stmt = delete(cls).where(cls.id == stuff_id)
await db_session.execute(stmt)
await db_session.commit()
return True