This commit is contained in:
grillazz
2025-12-28 19:05:43 +01:00
parent ef6f9bc44b
commit adecd854f3
2 changed files with 12 additions and 51 deletions

View File

@@ -37,28 +37,17 @@ async def test_add_user(client: AsyncClient):
# TODO: parametrize test with diff urls including 404 and 401 # TODO: parametrize test with diff urls including 404 and 401
async def test_get_token(client: AsyncClient): async def test_get_token(client: AsyncClient):
# First, create the user required for this test payload = {"email": "joe@grillazz.com", "password": "s1lly"}
user_payload = {
"email": "joe@grillazz.com",
"first_name": "Joe",
"last_name": "Garcia",
"password": "s1lly",
}
create_user_response = await client.post("/user/", json=user_payload)
assert create_user_response.status_code == status.HTTP_201_CREATED
# Now, request the token for the newly created user
token_payload = {"email": "joe@grillazz.com", "password": "s1lly"}
response = await client.post( response = await client.post(
"/user/token", "/user/token",
data=token_payload, data=payload,
headers={"Content-Type": "application/x-www-form-urlencoded"}, headers={"Content-Type": "application/x-www-form-urlencoded"},
) )
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
claimset = jwt.decode( claimset = jwt.decode(
response.json()["access_token"], options={"verify_signature": False} response.json()["access_token"], options={"verify_signature": False}
) )
assert claimset["email"] == token_payload["email"] assert claimset["email"] == payload["email"]
assert claimset["expiry"] == IsPositiveFloat() assert claimset["expiry"] == IsPositiveFloat()
assert claimset["platform"] == "python-httpx/0.28.1" assert claimset["platform"] == "python-httpx/0.28.1"

View File

@@ -1,13 +1,13 @@
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from types import SimpleNamespace
from typing import Any from typing import Any
import pytest import pytest
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.exc import ProgrammingError from sqlalchemy.exc import ProgrammingError
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import engine, get_db, test_engine from app.database import engine, test_engine, get_test_db, get_db
from app.main import app from app.main import app
from app.models.base import Base from app.models.base import Base
from app.redis import get_redis from app.redis import get_redis
@@ -22,7 +22,6 @@ from app.redis import get_redis
def anyio_backend(request): def anyio_backend(request):
return request.param return request.param
def _create_db(conn) -> None: def _create_db(conn) -> None:
"""Create the test database if it doesn't exist.""" """Create the test database if it doesn't exist."""
try: try:
@@ -65,43 +64,16 @@ async def start_db():
await test_engine.dispose() await test_engine.dispose()
@pytest.fixture(scope="function") @pytest.fixture(scope="session")
async def db_session() -> AsyncGenerator[AsyncSession, Any]: async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001
""" transport = ASGITransport(
Provide a transactional database session for each test function. app=app,
Rolls back changes after the test. )
"""
connection = await test_engine.connect()
transaction = await connection.begin()
session = AsyncSession(bind=connection)
yield session
await session.close()
await transaction.rollback()
await connection.close()
@pytest.fixture(scope="function")
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, Any]:
"""
Provide a test client for making API requests.
Uses the function-scoped db_session for test isolation.
"""
def get_test_db_override():
yield db_session
app.dependency_overrides[get_db] = get_test_db_override
app.redis = await get_redis()
transport = ASGITransport(app=app)
async with AsyncClient( async with AsyncClient(
base_url="http://testserver/v1", base_url="http://testserver/v1",
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
transport=transport, transport=transport,
) as test_client: ) as test_client:
app.dependency_overrides[get_db] = get_test_db
app.redis = await get_redis()
yield test_client yield test_client
# Clean up dependency overrides
del app.dependency_overrides[get_db]