From 1188734a0f1fc0eda59b94b510ede88ed8e9aca9 Mon Sep 17 00:00:00 2001 From: grillazz Date: Sun, 28 Dec 2025 18:41:22 +0100 Subject: [PATCH] refactor: update test fixtures for improved database session management and isolation --- tests/conftest.py | 42 +++++++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 8d5b600..30765e8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ import pytest from httpx import ASGITransport, AsyncClient from sqlalchemy import text from sqlalchemy.exc import ProgrammingError +from sqlalchemy.ext.asyncio import AsyncSession from app.database import engine, test_engine, get_test_db, get_db from app.main import app @@ -64,16 +65,43 @@ async def start_db(): await test_engine.dispose() -@pytest.fixture(scope="session") -async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001 - transport = ASGITransport( - app=app, - ) +@pytest.fixture(scope="function") +async def db_session(start_db) -> AsyncGenerator[AsyncSession, Any]: + """ + Provide a transactional database session for each test function. + Rolls back changes after the test. + """ + connection = await test_engine.connect() + transaction = await connection.begin() + session = AsyncSession(bind=connection) + + yield session + + await session.close() + await transaction.rollback() + await connection.close() + + +@pytest.fixture(scope="function") +async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, Any]: + """ + Provide a test client for making API requests. + Uses the function-scoped db_session for test isolation. + """ + + def get_test_db_override(): + yield db_session + + app.dependency_overrides[get_db] = get_test_db_override + app.redis = await get_redis() + + transport = ASGITransport(app=app) async with AsyncClient( base_url="http://testserver/v1", headers={"Content-Type": "application/json"}, transport=transport, ) as test_client: - app.dependency_overrides[get_db] = get_test_db - app.redis = await get_redis() yield test_client + + # Clean up dependency overrides + del app.dependency_overrides[get_db]