refactor: update test fixtures for improved database session management and isolation

This commit is contained in:
grillazz
2025-12-28 18:41:22 +01:00
parent 1fe0faa5c9
commit 1188734a0f

View File

@@ -6,6 +6,7 @@ import pytest
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.exc import ProgrammingError from sqlalchemy.exc import ProgrammingError
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import engine, test_engine, get_test_db, get_db from app.database import engine, test_engine, get_test_db, get_db
from app.main import app from app.main import app
@@ -64,16 +65,43 @@ async def start_db():
await test_engine.dispose() await test_engine.dispose()
@pytest.fixture(scope="session") @pytest.fixture(scope="function")
async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001 async def db_session(start_db) -> AsyncGenerator[AsyncSession, Any]:
transport = ASGITransport( """
app=app, Provide a transactional database session for each test function.
) Rolls back changes after the test.
"""
connection = await test_engine.connect()
transaction = await connection.begin()
session = AsyncSession(bind=connection)
yield session
await session.close()
await transaction.rollback()
await connection.close()
@pytest.fixture(scope="function")
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, Any]:
"""
Provide a test client for making API requests.
Uses the function-scoped db_session for test isolation.
"""
def get_test_db_override():
yield db_session
app.dependency_overrides[get_db] = get_test_db_override
app.redis = await get_redis()
transport = ASGITransport(app=app)
async with AsyncClient( async with AsyncClient(
base_url="http://testserver/v1", base_url="http://testserver/v1",
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
transport=transport, transport=transport,
) as test_client: ) as test_client:
app.dependency_overrides[get_db] = get_test_db
app.redis = await get_redis()
yield test_client yield test_client
# Clean up dependency overrides
del app.dependency_overrides[get_db]