Files
fastapi-sqlalchemy-asyncpg/tests/conftest.py

101 lines
3.3 KiB
Python

from collections.abc import AsyncGenerator
from typing import Any
import pytest
from fastapi.exceptions import ResponseValidationError
from httpx import ASGITransport, AsyncClient
from sqlalchemy import text
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from app.database import engine, get_db, test_engine, TestAsyncSessionFactory
from app.main import app
from app.models.base import Base
from app.redis import get_redis
@pytest.fixture(
    scope="session",
    params=[
        pytest.param(("asyncio", {"use_uvloop": True}), id="asyncio+uvloop"),
    ],
)
def anyio_backend(request):
    """Return the backend specification for the current fixture parameter.

    The fixture is session-scoped and parametrized with a single entry,
    ``("asyncio", {"use_uvloop": True})``, reported under the id
    ``asyncio+uvloop``.
    """
    backend_spec = request.param
    return backend_spec
def _create_db(conn) -> None:
    """Create the ``testdb`` database, tolerating a failure to do so.

    Args:
        conn: Connection object whose ``execute`` method runs the statement.
    """
    create_stmt = text("CREATE DATABASE testdb")
    try:
        conn.execute(create_stmt)
    except ProgrammingError:
        # The statement carries no `IF NOT EXISTS` guard, so a
        # ProgrammingError is treated as "the database is already there"
        # and deliberately ignored.
        # NOTE(review): every other ProgrammingError is swallowed as well.
        return
def _create_db_schema(conn) -> None:
    """Create the database schemas used by the tests if they don't exist.

    Issues ``CREATE SCHEMA IF NOT EXISTS`` for ``happy_hog`` and then for
    ``shakespeare``.

    Args:
        conn: Connection object whose ``execute`` method runs the statements.
    """
    try:
        conn.execute(text("CREATE SCHEMA IF NOT EXISTS happy_hog"))
        conn.execute(text("CREATE SCHEMA IF NOT EXISTS shakespeare"))
    except ProgrammingError:
        # Best effort: this might be raised by databases that don't support
        # `IF NOT EXISTS` when the schema already exists, so it is ignored.
        # If the first statement fails, the second one is not attempted.
        pass
@pytest.fixture(scope="session", autouse=True)
async def start_db():
    """Prepare the test database once per session (autouse).

    Steps, in order:

    1. Through ``engine``: send ``COMMIT`` and run ``_create_db``.
    2. Through ``test_engine``: run ``_create_db_schema``, then drop and
       re-create every table registered on ``Base.metadata``.
    3. Dispose both engines.
    """
    # Per the original note, `engine` targets the default 'postgres'
    # database. CREATE DATABASE cannot run inside a transaction block, so an
    # explicit COMMIT is sent first to make sure none is open.
    # NOTE(review): an AUTOCOMMIT isolation level may be a cleaner way to
    # achieve this -- confirm before changing.
    async with engine.connect() as admin_conn:
        await admin_conn.execute(text("COMMIT"))
        await admin_conn.run_sync(_create_db)

    # Switch to `test_engine` for the freshly created `testdb` and rebuild
    # its schemas and tables inside one transaction.
    setup_steps = (
        _create_db_schema,
        Base.metadata.drop_all,
        Base.metadata.create_all,
    )
    async with test_engine.begin() as test_conn:
        for setup_step in setup_steps:
            await test_conn.run_sync(setup_step)

    # Close and clean up the pooled connections opened during setup.
    await engine.dispose()
    await test_engine.dispose()
@pytest.fixture()
async def db_session():
    """Yield a session bound to a connection whose transaction is rolled back.

    A connection is taken from ``test_engine``, a transaction is begun on it,
    and a session from ``TestAsyncSessionFactory`` is bound to that
    connection. After the test, the session is closed and the transaction is
    rolled back if it is still active.

    Yields:
        The session created by ``TestAsyncSessionFactory(bind=connection)``.
    """
    # `async with` guarantees the connection is closed even when beginning
    # the transaction, building the session, or the teardown itself raises.
    async with test_engine.connect() as connection:
        transaction = await connection.begin()
        session = TestAsyncSessionFactory(bind=connection)
        try:
            yield session
        finally:
            try:
                await session.close()
            finally:
                # Rollback the overall transaction, restoring the state
                # before the test ran; runs even if closing the session fails.
                if transaction.is_active:
                    await transaction.rollback()
@pytest.fixture(scope="function")
async def client(db_session) -> AsyncGenerator[AsyncClient, Any]:  # noqa: ARG001
    """Yield an ``AsyncClient`` that talks to ``app`` in-process.

    While the fixture is active, the ``get_db`` dependency is overridden so
    that request handlers receive ``db_session``, and ``app.redis`` is set to
    the result of ``get_redis()``. The override is removed on teardown.

    Args:
        db_session: Session supplied by the ``db_session`` fixture.

    Yields:
        An ``AsyncClient`` with base URL ``http://testserver/v1`` and a JSON
        ``Content-Type`` header, using an ``ASGITransport`` over ``app``.
    """
    transport = ASGITransport(
        app=app,
    )

    async def override_get_db():
        # Hand the test's session to the handler, then commit once the
        # generator is resumed after the handler is done with it.
        yield db_session
        await db_session.commit()

    async with AsyncClient(
        base_url="http://testserver/v1",
        headers={"Content-Type": "application/json"},
        transport=transport,
    ) as test_client:
        app.dependency_overrides[get_db] = override_get_db
        try:
            app.redis = await get_redis()
            yield test_client
        finally:
            # Remove the override so `app` does not keep a dependency that
            # closes over this test's (by then finished) session.
            app.dependency_overrides.pop(get_db, None)