refactor: update test fixtures and remove unused environment variables

This commit is contained in:
grillazz
2026-01-03 19:04:07 +01:00
parent 88a66b1d92
commit e9ea2c627a
5 changed files with 42 additions and 15 deletions

View File

@@ -4,8 +4,10 @@ from fastapi import status
from httpx import AsyncClient
from inline_snapshot import snapshot
from polyfactory.factories.pydantic_factory import ModelFactory
from sqlalchemy.ext.asyncio import AsyncSession
from app.schemas.stuff import StuffSchema
from app.models import Stuff
pytestmark = pytest.mark.anyio
@@ -14,7 +16,7 @@ class StuffFactory(ModelFactory[StuffSchema]):
__model__ = StuffSchema
async def test_add_stuff(client: AsyncClient):
async def test_add_stuff(client: AsyncClient, db_session: AsyncSession):
stuff = StuffFactory.build(factory_use_constructors=True).model_dump(mode="json")
response = await client.post("/stuff", json=stuff)
assert response.status_code == status.HTTP_201_CREATED
@@ -32,22 +34,27 @@ async def test_add_stuff(client: AsyncClient):
)
async def test_get_stuff(client: AsyncClient):
async def test_get_stuff(client: AsyncClient, db_session: AsyncSession):
response = await client.get("/stuff/nonexistent")
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.json() == snapshot(
{"no_response": "The requested resource was not found"}
)
stuff = StuffFactory.build(factory_use_constructors=True).model_dump(mode="json")
await client.post("/stuff", json=stuff)
name = stuff["name"]
# await client.post("/stuff", json=stuff)
# name = stuff["name"]
stuff = Stuff(**stuff)
name = stuff.name
db_session.add(stuff)
await db_session.commit()
response = await client.get(f"/stuff/{name}")
assert response.status_code == status.HTTP_200_OK
assert response.json() == snapshot(
{
"id": IsUUID(4),
"name": stuff["name"],
"description": stuff["description"],
"name": stuff.name,
"description": stuff.description,
}
)

View File

@@ -2,11 +2,12 @@ from collections.abc import AsyncGenerator
from typing import Any
import pytest
from fastapi.exceptions import ResponseValidationError
from httpx import ASGITransport, AsyncClient
from sqlalchemy import text
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from app.database import engine, get_db, get_test_db, test_engine
from app.database import engine, get_db, test_engine, TestAsyncSessionFactory
from app.main import app
from app.models.base import Base
from app.redis import get_redis
@@ -43,7 +44,7 @@ def _create_db_schema(conn) -> None:
pass
@pytest.fixture(scope="session")
@pytest.fixture(scope="session", autouse=True)
async def start_db():
# The `engine` is configured for the default 'postgres' database.
# We connect to it and create the test database.
@@ -63,16 +64,37 @@ async def start_db():
await test_engine.dispose()
@pytest.fixture(scope="session")
async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001
@pytest.fixture()
async def db_session():
connection = await test_engine.connect()
transaction = await connection.begin()
session = TestAsyncSessionFactory(bind=connection)
try:
yield session
finally:
# Rollback the overall transaction, restoring the state before the test ran.
await session.close()
if transaction.is_active:
await transaction.rollback()
await connection.close()
@pytest.fixture(scope="function")
async def client(db_session) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001
transport = ASGITransport(
app=app,
)
async def override_get_db():
yield db_session
await db_session.commit()
async with AsyncClient(
base_url="http://testserver/v1",
headers={"Content-Type": "application/json"},
transport=transport,
) as test_client:
app.dependency_overrides[get_db] = get_test_db
app.dependency_overrides[get_db] = override_get_db
app.redis = await get_redis()
yield test_client