diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index b5584b5..9000ef3 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -37,28 +37,17 @@ async def test_add_user(client: AsyncClient): # TODO: parametrize test with diff urls including 404 and 401 async def test_get_token(client: AsyncClient): - # First, create the user required for this test - user_payload = { - "email": "joe@grillazz.com", - "first_name": "Joe", - "last_name": "Garcia", - "password": "s1lly", - } - create_user_response = await client.post("/user/", json=user_payload) - assert create_user_response.status_code == status.HTTP_201_CREATED - - # Now, request the token for the newly created user - token_payload = {"email": "joe@grillazz.com", "password": "s1lly"} + payload = {"email": "joe@grillazz.com", "password": "s1lly"} response = await client.post( "/user/token", - data=token_payload, + data=payload, headers={"Content-Type": "application/x-www-form-urlencoded"}, ) assert response.status_code == status.HTTP_201_CREATED claimset = jwt.decode( response.json()["access_token"], options={"verify_signature": False} ) - assert claimset["email"] == token_payload["email"] + assert claimset["email"] == payload["email"] assert claimset["expiry"] == IsPositiveFloat() assert claimset["platform"] == "python-httpx/0.28.1" diff --git a/tests/conftest.py b/tests/conftest.py index 0e17675..8d5b600 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,13 @@ from collections.abc import AsyncGenerator +from types import SimpleNamespace from typing import Any import pytest from httpx import ASGITransport, AsyncClient from sqlalchemy import text from sqlalchemy.exc import ProgrammingError -from sqlalchemy.ext.asyncio import AsyncSession -from app.database import engine, get_db, test_engine +from app.database import engine, test_engine, get_test_db, get_db from app.main import app from app.models.base import Base from app.redis import get_redis @@ -22,7 +22,6 @@ from app.redis import get_redis def anyio_backend(request): return request.param - def _create_db(conn) -> None: """Create the test database if it doesn't exist.""" try: @@ -65,43 +64,16 @@ async def start_db(): await test_engine.dispose() -@pytest.fixture(scope="function") -async def db_session() -> AsyncGenerator[AsyncSession, Any]: - """ - Provide a transactional database session for each test function. - Rolls back changes after the test. - """ - connection = await test_engine.connect() - transaction = await connection.begin() - session = AsyncSession(bind=connection) - - yield session - - await session.close() - await transaction.rollback() - await connection.close() - - -@pytest.fixture(scope="function") -async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, Any]: - """ - Provide a test client for making API requests. - Uses the function-scoped db_session for test isolation. - """ - - def get_test_db_override(): - yield db_session - - app.dependency_overrides[get_db] = get_test_db_override - app.redis = await get_redis() - - transport = ASGITransport(app=app) +@pytest.fixture(scope="session") +async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001 + transport = ASGITransport( + app=app, + ) async with AsyncClient( base_url="http://testserver/v1", headers={"Content-Type": "application/json"}, transport=transport, ) as test_client: + app.dependency_overrides[get_db] = get_test_db + app.redis = await get_redis() yield test_client - - # Clean up dependency overrides - del app.dependency_overrides[get_db]