diff --git a/.env b/.env index 3dbfa84..842a23b 100644 --- a/.env +++ b/.env @@ -7,7 +7,6 @@ POSTGRES_PORT=5432 POSTGRES_DB=devdb POSTGRES_USER=devdb POSTGRES_TEST_DB=testdb -POSTGRES_TEST_USER=testdb POSTGRES_PASSWORD=secret # Redis diff --git a/app/config.py b/app/config.py index c79d0d6..dbe4877 100644 --- a/app/config.py +++ b/app/config.py @@ -33,7 +33,6 @@ class Settings(BaseSettings): POSTGRES_PASSWORD: str POSTGRES_HOST: str POSTGRES_DB: str - POSTGRES_TEST_USER: str POSTGRES_TEST_DB: str @computed_field diff --git a/app/database.py b/app/database.py index 7f34e70..e1003d7 100644 --- a/app/database.py +++ b/app/database.py @@ -18,7 +18,7 @@ engine = create_async_engine( test_engine = create_async_engine( global_settings.test_asyncpg_url.unicode_string(), future=True, - echo=True, + echo=False, ) # expire_on_commit=False will prevent attributes from being expired diff --git a/tests/api/test_stuff.py b/tests/api/test_stuff.py index e420b07..985c892 100644 --- a/tests/api/test_stuff.py +++ b/tests/api/test_stuff.py @@ -4,8 +4,10 @@ from fastapi import status from httpx import AsyncClient from inline_snapshot import snapshot from polyfactory.factories.pydantic_factory import ModelFactory +from sqlalchemy.ext.asyncio import AsyncSession from app.schemas.stuff import StuffSchema +from app.models import Stuff pytestmark = pytest.mark.anyio @@ -14,7 +16,7 @@ class StuffFactory(ModelFactory[StuffSchema]): __model__ = StuffSchema -async def test_add_stuff(client: AsyncClient): +async def test_add_stuff(client: AsyncClient, db_session: AsyncSession): stuff = StuffFactory.build(factory_use_constructors=True).model_dump(mode="json") response = await client.post("/stuff", json=stuff) assert response.status_code == status.HTTP_201_CREATED @@ -32,22 +34,27 @@ async def test_add_stuff(client: AsyncClient): ) -async def test_get_stuff(client: AsyncClient): +async def test_get_stuff(client: AsyncClient, db_session: AsyncSession): response = await client.get("/stuff/nonexistent") assert response.status_code == status.HTTP_404_NOT_FOUND assert response.json() == snapshot( {"no_response": "The requested resource was not found"} ) stuff = StuffFactory.build(factory_use_constructors=True).model_dump(mode="json") - await client.post("/stuff", json=stuff) - name = stuff["name"] + # await client.post("/stuff", json=stuff) + # name = stuff["name"] + stuff = Stuff(**stuff) + name = stuff.name + db_session.add(stuff) + await db_session.commit() + response = await client.get(f"/stuff/{name}") assert response.status_code == status.HTTP_200_OK assert response.json() == snapshot( { "id": IsUUID(4), - "name": stuff["name"], - "description": stuff["description"], + "name": stuff.name, + "description": stuff.description, } ) diff --git a/tests/conftest.py b/tests/conftest.py index 107d3bc..9ad5d46 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,11 +2,12 @@ from collections.abc import AsyncGenerator from typing import Any import pytest +from fastapi.exceptions import ResponseValidationError from httpx import ASGITransport, AsyncClient from sqlalchemy import text -from sqlalchemy.exc import ProgrammingError +from sqlalchemy.exc import ProgrammingError, SQLAlchemyError -from app.database import engine, get_db, get_test_db, test_engine +from app.database import engine, get_db, test_engine, TestAsyncSessionFactory from app.main import app from app.models.base import Base from app.redis import get_redis @@ -43,7 +44,7 @@ def _create_db_schema(conn) -> None: pass -@pytest.fixture(scope="session") +@pytest.fixture(scope="session", autouse=True) async def start_db(): # The `engine` is configured for the default 'postgres' database. # We connect to it and create the test database. @@ -63,16 +64,37 @@ async def start_db(): await test_engine.dispose() -@pytest.fixture(scope="session") -async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001 +@pytest.fixture() +async def db_session(): + connection = await test_engine.connect() + transaction = await connection.begin() + session = TestAsyncSessionFactory(bind=connection) + + try: + yield session + finally: + # Rollback the overall transaction, restoring the state before the test ran. + await session.close() + if transaction.is_active: + await transaction.rollback() + await connection.close() + + +@pytest.fixture(scope="function") +async def client(db_session) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001 transport = ASGITransport( app=app, ) + + async def override_get_db(): + yield db_session + await db_session.commit() + async with AsyncClient( base_url="http://testserver/v1", headers={"Content-Type": "application/json"}, transport=transport, ) as test_client: - app.dependency_overrides[get_db] = get_test_db + app.dependency_overrides[get_db] = override_get_db app.redis = await get_redis() yield test_client