refactor: update test fixtures and remove unused environment variables

This commit is contained in:
grillazz
2026-01-03 19:04:07 +01:00
parent 88a66b1d92
commit e9ea2c627a
5 changed files with 42 additions and 15 deletions

View File

@@ -2,11 +2,12 @@ from collections.abc import AsyncGenerator
from typing import Any
import pytest
from fastapi.exceptions import ResponseValidationError
from httpx import ASGITransport, AsyncClient
from sqlalchemy import text
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from app.database import engine, get_db, get_test_db, test_engine
from app.database import engine, get_db, test_engine, TestAsyncSessionFactory
from app.main import app
from app.models.base import Base
from app.redis import get_redis
@@ -43,7 +44,7 @@ def _create_db_schema(conn) -> None:
pass
@pytest.fixture(scope="session")
@pytest.fixture(scope="session", autouse=True)
async def start_db():
# The `engine` is configured for the default 'postgres' database.
# We connect to it and create the test database.
@@ -63,16 +64,37 @@ async def start_db():
await test_engine.dispose()
@pytest.fixture(scope="session")
async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001
@pytest.fixture()
async def db_session():
connection = await test_engine.connect()
transaction = await connection.begin()
session = TestAsyncSessionFactory(bind=connection)
try:
yield session
finally:
# Rollback the overall transaction, restoring the state before the test ran.
await session.close()
if transaction.is_active:
await transaction.rollback()
await connection.close()
@pytest.fixture(scope="function")
async def client(db_session) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001
transport = ASGITransport(
app=app,
)
async def override_get_db():
yield db_session
await db_session.commit()
async with AsyncClient(
base_url="http://testserver/v1",
headers={"Content-Type": "application/json"},
transport=transport,
) as test_client:
app.dependency_overrides[get_db] = get_test_db
app.dependency_overrides[get_db] = override_get_db
app.redis = await get_redis()
yield test_client