diff --git a/.env b/.env index 13f0864..3dbfa84 100644 --- a/.env +++ b/.env @@ -6,6 +6,8 @@ POSTGRES_HOST=postgres POSTGRES_PORT=5432 POSTGRES_DB=devdb POSTGRES_USER=devdb +POSTGRES_TEST_DB=testdb +POSTGRES_TEST_USER=testdb POSTGRES_PASSWORD=secret # Redis diff --git a/Makefile b/Makefile index 892f499..c67520d 100644 --- a/Makefile +++ b/Makefile @@ -45,7 +45,7 @@ docker-create-db-migration: ## Create a new alembic database migration. Example # ==================================================================================== .PHONY: docker-test docker-test: ## Run project tests - docker compose -f compose.yml -f test-compose.yml run --rm api1 pytest tests --durations=0 -vv + docker compose -f compose.yml run --rm api1 pytest tests --durations=0 -vv .PHONY: docker-test-snapshot docker-test-snapshot: ## Run project tests and update snapshots diff --git a/app/config.py b/app/config.py index 260ac36..81722b7 100644 --- a/app/config.py +++ b/app/config.py @@ -33,6 +33,8 @@ class Settings(BaseSettings): POSTGRES_PASSWORD: str POSTGRES_HOST: str POSTGRES_DB: str + POSTGRES_TEST_USER: str + POSTGRES_TEST_DB: str @computed_field @property @@ -80,6 +82,17 @@ class Settings(BaseSettings): path=self.POSTGRES_DB, ) + @computed_field + @property + def test_asyncpg_url(self) -> PostgresDsn: + return MultiHostUrl.build( + scheme="postgresql+asyncpg", + username=self.POSTGRES_USER, + password=self.POSTGRES_PASSWORD, + host=self.POSTGRES_HOST, + path=self.POSTGRES_TEST_DB, + ) + @computed_field @property def postgres_url(self) -> PostgresDsn: diff --git a/app/database.py b/app/database.py index 4917d9b..c893572 100644 --- a/app/database.py +++ b/app/database.py @@ -15,6 +15,12 @@ engine = create_async_engine( echo=True, ) +test_engine = create_async_engine( + global_settings.test_asyncpg_url.unicode_string(), + future=True, + echo=True, +) + # expire_on_commit=False will prevent attributes from being expired # after commit. AsyncSessionFactory = async_sessionmaker( @@ -23,6 +29,12 @@ AsyncSessionFactory = async_sessionmaker( expire_on_commit=False, ) +TestAsyncSessionFactory = async_sessionmaker( + test_engine, + autoflush=False, + expire_on_commit=False, +) + # Dependency async def get_db() -> AsyncGenerator: @@ -38,3 +50,18 @@ async def get_db() -> AsyncGenerator: if not isinstance(ex, ResponseValidationError): await logger.aerror(f"Database-related error: {repr(ex)}") raise # Re-raise to be handled by appropriate handlers + + +async def get_test_db() -> AsyncGenerator: + async with TestAsyncSessionFactory() as session: + try: + yield session + await session.commit() + except SQLAlchemyError: + # Re-raise SQLAlchemy errors to be handled by the global handler + raise + except Exception as ex: + # Only log actual database-related issues, not response validation + if not isinstance(ex, ResponseValidationError): + await logger.aerror(f"Database-related error: {repr(ex)}") + raise # Re-raise to be handled by appropriate handlers \ No newline at end of file diff --git a/db/Dockerfile b/db/Dockerfile index 0da0c0a..08170b4 100644 --- a/db/Dockerfile +++ b/db/Dockerfile @@ -1,9 +1,6 @@ # pull official base image FROM postgres:17.6-alpine -# run create.sql on init -ADD create.sql /docker-entrypoint-initdb.d - WORKDIR /home/gx/code COPY shakespeare_chapter.sql /home/gx/code/shakespeare_chapter.sql diff --git a/pyproject.toml b/pyproject.toml index 06bfee5..95dc7d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,8 +40,8 @@ dev-dependencies = [ "ipython>=9.5.0", "sqlacodegen<=3.1.1", "tryceratops>=2.4.1", - "locust>=2.40.5" - + "locust>=2.40.5", + "sqlalchemy-utils>=0.41.1" ] diff --git a/test-compose.yml b/test-compose.yml deleted file mode 100644 index 7a3c6ba..0000000 --- a/test-compose.yml +++ /dev/null @@ -1,9 +0,0 @@ -services: - api1: - environment: - - POSTGRES_DB=testdb - - postgres: - environment: - - POSTGRES_USER=${POSTGRES_USER} - - POSTGRES_DB=testdb \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index cee66a1..3f232c6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,13 @@ from collections.abc import AsyncGenerator +from types import SimpleNamespace from typing import Any import pytest from httpx import ASGITransport, AsyncClient +from sqlalchemy import text +from sqlalchemy.exc import ProgrammingError -from app.database import engine +from app.database import engine, test_engine, get_test_db, get_db from app.main import app from app.models.base import Base from app.redis import get_redis @@ -19,15 +22,46 @@ from app.redis import get_redis def anyio_backend(request): return request.param +def _create_db(conn) -> None: + """Create a database schema if it doesn't exist.""" + try: + conn.execute(text("CREATE DATABASE testdb")) + except ProgrammingError: + # This might be raised by databases that don't support `IF NOT EXISTS` + # and the schema already exists. You can choose to ignore it. + pass + + +def _create_db_schema(conn) -> None: + """Create a database schema if it doesn't exist.""" + try: + conn.execute(text("CREATE SCHEMA happy_hog")) + conn.execute(text("CREATE SCHEMA shakespeare")) + except ProgrammingError: + # This might be raised by databases that don't support `IF NOT EXISTS` + # and the schema already exists. You can choose to ignore it. + pass + @pytest.fixture(scope="session") async def start_db(): - async with engine.begin() as conn: + # The `engine` is configured for the default 'postgres' database. + # We connect to it and create the test database. + # A transaction block is not used, as CREATE DATABASE cannot run inside it. + async with engine.connect() as conn: + await conn.execute(text("COMMIT")) # Ensure we're not in a transaction + await conn.run_sync(_create_db) + + # Now, connect to the newly created `testdb` with `test_engine` + async with test_engine.begin() as conn: + await conn.execute(text("COMMIT")) # Ensure we're not in a transaction + await conn.run_sync(_create_db_schema) await conn.run_sync(Base.metadata.drop_all) await conn.run_sync(Base.metadata.create_all) # for AsyncEngine created in function scope, close and # clean-up pooled connections await engine.dispose() + await test_engine.dispose() @pytest.fixture(scope="session") @@ -40,5 +74,6 @@ async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001 headers={"Content-Type": "application/json"}, transport=transport, ) as test_client: + app.dependency_overrides[get_db] = get_test_db app.redis = await get_redis() yield test_client diff --git a/uv.lock b/uv.lock index f1c47e6..41a04ef 100644 --- a/uv.lock +++ b/uv.lock @@ -500,6 +500,7 @@ dev = [ { name = "pyupgrade" }, { name = "ruff" }, { name = "sqlacodegen" }, + { name = "sqlalchemy-utils" }, { name = "tryceratops" }, ] @@ -540,6 +541,7 @@ dev = [ { name = "pyupgrade", specifier = ">=3.20.0" }, { name = "ruff", specifier = ">=0.13.1" }, { name = "sqlacodegen", specifier = "<=3.1.1" }, + { name = "sqlalchemy-utils", specifier = ">=0.41.1" }, { name = "tryceratops", specifier = ">=2.4.1" }, ] @@ -1660,6 +1662,18 @@ asyncio = [ { name = "greenlet" }, ] +[[package]] +name = "sqlalchemy-utils" +version = "0.42.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "sqlalchemy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0f/7d/eb9565b6a49426552a5bf5c57e7c239c506dc0e4e5315aec6d1e8241dc7c/sqlalchemy_utils-0.42.1.tar.gz", hash = "sha256:881f9cd9e5044dc8f827bccb0425ce2e55490ce44fc0bb848c55cc8ee44cc02e", size = 130789, upload-time = "2025-12-13T03:14:13.591Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/25/7400c18c3ee97914cc99c90007795c00a4ec5b60c853b49db7ba24d11179/sqlalchemy_utils-0.42.1-py3-none-any.whl", hash = "sha256:243cfe1b3a1dae3c74118ae633f1d1e0ed8c787387bc33e556e37c990594ac80", size = 91761, upload-time = "2025-12-13T03:14:15.014Z" }, +] + [[package]] name = "stack-data" version = "0.6.3"