mirror of
https://github.com/grillazz/fastapi-sqlalchemy-asyncpg.git
synced 2026-01-17 11:40:39 +03:00
refactor: update test fixtures and remove unused environment variables
This commit is contained in:
1
.env
1
.env
@@ -7,7 +7,6 @@ POSTGRES_PORT=5432
|
||||
POSTGRES_DB=devdb
|
||||
POSTGRES_USER=devdb
|
||||
POSTGRES_TEST_DB=testdb
|
||||
POSTGRES_TEST_USER=testdb
|
||||
POSTGRES_PASSWORD=secret
|
||||
|
||||
# Redis
|
||||
|
||||
@@ -33,7 +33,6 @@ class Settings(BaseSettings):
|
||||
POSTGRES_PASSWORD: str
|
||||
POSTGRES_HOST: str
|
||||
POSTGRES_DB: str
|
||||
POSTGRES_TEST_USER: str
|
||||
POSTGRES_TEST_DB: str
|
||||
|
||||
@computed_field
|
||||
|
||||
@@ -18,7 +18,7 @@ engine = create_async_engine(
|
||||
test_engine = create_async_engine(
|
||||
global_settings.test_asyncpg_url.unicode_string(),
|
||||
future=True,
|
||||
echo=True,
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# expire_on_commit=False will prevent attributes from being expired
|
||||
|
||||
@@ -4,8 +4,10 @@ from fastapi import status
|
||||
from httpx import AsyncClient
|
||||
from inline_snapshot import snapshot
|
||||
from polyfactory.factories.pydantic_factory import ModelFactory
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.schemas.stuff import StuffSchema
|
||||
from app.models import Stuff
|
||||
|
||||
pytestmark = pytest.mark.anyio
|
||||
|
||||
@@ -14,7 +16,7 @@ class StuffFactory(ModelFactory[StuffSchema]):
|
||||
__model__ = StuffSchema
|
||||
|
||||
|
||||
async def test_add_stuff(client: AsyncClient):
|
||||
async def test_add_stuff(client: AsyncClient, db_session: AsyncSession):
|
||||
stuff = StuffFactory.build(factory_use_constructors=True).model_dump(mode="json")
|
||||
response = await client.post("/stuff", json=stuff)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
@@ -32,22 +34,27 @@ async def test_add_stuff(client: AsyncClient):
|
||||
)
|
||||
|
||||
|
||||
async def test_get_stuff(client: AsyncClient):
|
||||
async def test_get_stuff(client: AsyncClient, db_session: AsyncSession):
|
||||
response = await client.get("/stuff/nonexistent")
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert response.json() == snapshot(
|
||||
{"no_response": "The requested resource was not found"}
|
||||
)
|
||||
stuff = StuffFactory.build(factory_use_constructors=True).model_dump(mode="json")
|
||||
await client.post("/stuff", json=stuff)
|
||||
name = stuff["name"]
|
||||
# await client.post("/stuff", json=stuff)
|
||||
# name = stuff["name"]
|
||||
stuff = Stuff(**stuff)
|
||||
name = stuff.name
|
||||
db_session.add(stuff)
|
||||
await db_session.commit()
|
||||
|
||||
response = await client.get(f"/stuff/{name}")
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json() == snapshot(
|
||||
{
|
||||
"id": IsUUID(4),
|
||||
"name": stuff["name"],
|
||||
"description": stuff["description"],
|
||||
"name": stuff.name,
|
||||
"description": stuff.description,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -2,11 +2,12 @@ from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from fastapi.exceptions import ResponseValidationError
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import ProgrammingError
|
||||
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
||||
|
||||
from app.database import engine, get_db, get_test_db, test_engine
|
||||
from app.database import engine, get_db, test_engine, TestAsyncSessionFactory
|
||||
from app.main import app
|
||||
from app.models.base import Base
|
||||
from app.redis import get_redis
|
||||
@@ -43,7 +44,7 @@ def _create_db_schema(conn) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
async def start_db():
|
||||
# The `engine` is configured for the default 'postgres' database.
|
||||
# We connect to it and create the test database.
|
||||
@@ -63,16 +64,37 @@ async def start_db():
|
||||
await test_engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001
|
||||
@pytest.fixture()
|
||||
async def db_session():
|
||||
connection = await test_engine.connect()
|
||||
transaction = await connection.begin()
|
||||
session = TestAsyncSessionFactory(bind=connection)
|
||||
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
# Rollback the overall transaction, restoring the state before the test ran.
|
||||
await session.close()
|
||||
if transaction.is_active:
|
||||
await transaction.rollback()
|
||||
await connection.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def client(db_session) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001
|
||||
transport = ASGITransport(
|
||||
app=app,
|
||||
)
|
||||
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
await db_session.commit()
|
||||
|
||||
async with AsyncClient(
|
||||
base_url="http://testserver/v1",
|
||||
headers={"Content-Type": "application/json"},
|
||||
transport=transport,
|
||||
) as test_client:
|
||||
app.dependency_overrides[get_db] = get_test_db
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.redis = await get_redis()
|
||||
yield test_client
|
||||
|
||||
Reference in New Issue
Block a user