From 514eea72312f293c30d41d83a8a1a1e7b04d5a13 Mon Sep 17 00:00:00 2001 From: Jakub Miazek Date: Sat, 11 May 2024 16:03:33 +0200 Subject: [PATCH 1/3] - implement postgres connection pool - adapt current sqlalchemy based methods to produce raw sql --- app/api/stuff.py | 25 ++++++++++++++++++++++--- app/config.py | 24 ++++++++++++++++++++++++ app/main.py | 13 +++++++++++++ app/models/stuff.py | 19 ++++--------------- tests/api/test_stuff.py | 1 - 5 files changed, 63 insertions(+), 19 deletions(-) diff --git a/app/api/stuff.py b/app/api/stuff.py index 8c21cb1..2b9b010 100644 --- a/app/api/stuff.py +++ b/app/api/stuff.py @@ -1,4 +1,5 @@ -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, status, Request +from fastapi.exceptions import ResponseValidationError from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession @@ -21,7 +22,6 @@ async def create_multi_stuff( db_session.add_all(stuff_instances) await db_session.commit() except SQLAlchemyError as ex: - # logger.exception(ex) raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=repr(ex) ) from ex @@ -43,10 +43,29 @@ async def create_stuff( @router.get("/{name}", response_model=StuffResponse) async def find_stuff( + request: Request, name: str, + pool: bool = False, db_session: AsyncSession = Depends(get_db), ): - return await Stuff.find(db_session, name) + try: + if not pool: + result = await Stuff.find(db_session, name) + else: + # execute the compiled SQL statement + stmt = await Stuff.find(db_session, name, compile_sql=True) + result = await request.app.postgres_pool.fetchrow(str(stmt)) + result = dict(result) + except SQLAlchemyError as ex: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=repr(ex) + ) from ex + if not result: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Stuff with name {name} not found.", + ) + return result @router.delete("/{name}") diff --git a/app/config.py b/app/config.py index bf40388..9ad9a56 100644 --- a/app/config.py +++ b/app/config.py @@ -70,5 +70,29 @@ class Settings(BaseSettings): path=self.POSTGRES_DB, ) + @computed_field + @property + def postgres_url(self) -> PostgresDsn: + """ + This is a computed field that generates a PostgresDsn URL + + The URL is built using the MultiHostUrl.build method, which takes the following parameters: + - scheme: The scheme of the URL. In this case, it is "postgres". + - username: The username for the Postgres database, retrieved from the POSTGRES_USER environment variable. + - password: The password for the Postgres database, retrieved from the POSTGRES_PASSWORD environment variable. + - host: The host of the Postgres database, retrieved from the POSTGRES_HOST environment variable. + - path: The path of the Postgres database, retrieved from the POSTGRES_DB environment variable. + + Returns: + PostgresDsn: The constructed PostgresDsn URL. + """ + return MultiHostUrl.build( + scheme="postgres", + username=self.POSTGRES_USER, + password=self.POSTGRES_PASSWORD, + host=self.POSTGRES_HOST, + path=self.POSTGRES_DB, + ) + settings = Settings() diff --git a/app/main.py b/app/main.py index b26ffe4..4908f5d 100644 --- a/app/main.py +++ b/app/main.py @@ -1,3 +1,4 @@ +import asyncpg from contextlib import asynccontextmanager from fastapi import FastAPI, Depends @@ -7,6 +8,7 @@ from fastapi_cache.backends.redis import RedisBackend from app.api.nonsense import router as nonsense_router from app.api.shakespeare import router as shakespeare_router from app.api.stuff import router as stuff_router +from app.config import settings as global_settings from app.utils.logging import AppLogger from app.api.user import router as user_router from app.api.health import router as health_router @@ -21,15 +23,26 @@ async def lifespan(_app: FastAPI): # Load the redis connection _app.redis = await get_redis() + _postgres_dsn = global_settings.postgres_url.unicode_string() + try: # Initialize the cache with the redis connection redis_cache = await get_cache() FastAPICache.init(RedisBackend(redis_cache), prefix="fastapi-cache") logger.info(FastAPICache.get_cache_status_header()) + # Initialize the postgres connection pool + _app.postgres_pool = await asyncpg.create_pool( + dsn=_postgres_dsn, + min_size=5, + max_size=20, + ) + logger.info(f"Postgres pool created: {_app.postgres_pool.get_idle_size()=}") yield finally: # close redis connection and release the resources await _app.redis.close() + # close postgres connection pool and release the resources + await _app.postgres_pool.close() app = FastAPI(title="Stuff And Nonsense API", version="0.6", lifespan=lifespan) diff --git a/app/models/stuff.py b/app/models/stuff.py index 5717f92..45a0851 100644 --- a/app/models/stuff.py +++ b/app/models/stuff.py @@ -24,23 +24,12 @@ class Stuff(Base): ) @classmethod - async def find(cls, db_session: AsyncSession, name: str): - """ - - :param db_session: - :param name: - :return: - """ + async def find(cls, db_session: AsyncSession, name: str, compile_sql: bool = False): stmt = select(cls).options(joinedload(cls.nonsense)).where(cls.name == name) + if compile_sql: + return stmt.compile(compile_kwargs={"literal_binds": True}) result = await db_session.execute(stmt) - instance = result.scalars().first() - if instance is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail={"Not found": f"There is no record for name: {name}"}, - ) - else: - return instance + return result.scalars().first() class StuffFullOfNonsense(Base): diff --git a/tests/api/test_stuff.py b/tests/api/test_stuff.py index 21da64c..2e66ab8 100644 --- a/tests/api/test_stuff.py +++ b/tests/api/test_stuff.py @@ -1,4 +1,3 @@ - import pytest from fastapi import status from httpx import AsyncClient From e003454647e72d99c5bcab936cf191a174594843 Mon Sep 17 00:00:00 2001 From: Jakub Miazek Date: Sat, 11 May 2024 16:09:10 +0200 Subject: [PATCH 2/3] - split endpoint to find stuff for pool and traditional --- app/api/stuff.py | 42 +++++++++++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/app/api/stuff.py b/app/api/stuff.py index 2b9b010..5cab5c3 100644 --- a/app/api/stuff.py +++ b/app/api/stuff.py @@ -42,20 +42,43 @@ async def create_stuff( @router.get("/{name}", response_model=StuffResponse) -async def find_stuff( +async def find_stuff(name: str, db_session: AsyncSession = Depends(get_db)): + result = await Stuff.find(db_session, name) + if not result: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Stuff with name {name} not found.") + return result + + +@router.get("/pool/{name}", response_model=StuffResponse) +async def find_stuff_pool( request: Request, name: str, - pool: bool = False, db_session: AsyncSession = Depends(get_db), ): + """ + Asynchronous function to find a specific 'Stuff' object in the database using a connection pool. + + This function compiles an SQL statement to find a 'Stuff' object by its name, executes the statement + using a connection from the application's connection pool, and returns the result as a dictionary. + If the 'Stuff' object is not found, it raises an HTTPException with a 404 status code. + If an SQLAlchemyError occurs during the execution of the SQL statement, it raises an HTTPException + with a 422 status code. + + Args: + request (Request): The incoming request. Used to access the application's connection pool. + name (str): The name of the 'Stuff' object to find. + db_session (AsyncSession): The database session. Used to compile the SQL statement. + + Returns: + dict: The found 'Stuff' object as a dictionary. + + Raises: + HTTPException: If the 'Stuff' object is not found or an SQLAlchemyError occurs. + """ try: - if not pool: - result = await Stuff.find(db_session, name) - else: - # execute the compiled SQL statement - stmt = await Stuff.find(db_session, name, compile_sql=True) - result = await request.app.postgres_pool.fetchrow(str(stmt)) - result = dict(result) + stmt = await Stuff.find(db_session, name, compile_sql=True) + result = await request.app.postgres_pool.fetchrow(str(stmt)) + result = dict(result) except SQLAlchemyError as ex: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=repr(ex) @@ -68,6 +91,7 @@ async def find_stuff( return result + @router.delete("/{name}") async def delete_stuff(name: str, db_session: AsyncSession = Depends(get_db)): stuff = await Stuff.find(db_session, name) From c8efc0596c5a4309c7c4358258ec0f6c94edcf9c Mon Sep 17 00:00:00 2001 From: Jakub Miazek Date: Sat, 11 May 2024 16:29:00 +0200 Subject: [PATCH 3/3] - add decorator to compile sql for pool or return scalar from statement --- app/api/stuff.py | 6 ++++-- app/models/stuff.py | 22 +++++++++++++++++----- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/app/api/stuff.py b/app/api/stuff.py index 5cab5c3..bfb6c8c 100644 --- a/app/api/stuff.py +++ b/app/api/stuff.py @@ -45,7 +45,10 @@ async def create_stuff( async def find_stuff(name: str, db_session: AsyncSession = Depends(get_db)): result = await Stuff.find(db_session, name) if not result: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Stuff with name {name} not found.") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Stuff with name {name} not found.", + ) return result @@ -91,7 +94,6 @@ async def find_stuff_pool( return result - @router.delete("/{name}") async def delete_stuff(name: str, db_session: AsyncSession = Depends(get_db)): stuff = await Stuff.find(db_session, name) diff --git a/app/models/stuff.py b/app/models/stuff.py index 45a0851..fdd2dc7 100644 --- a/app/models/stuff.py +++ b/app/models/stuff.py @@ -9,6 +9,20 @@ from sqlalchemy.orm import mapped_column, Mapped, relationship, joinedload from app.models.base import Base from app.models.nonsense import Nonsense +from functools import wraps + + +def compile_sql_or_scalar(func): + @wraps(func) + async def wrapper(cls, db_session, name, compile_sql=False, *args, **kwargs): + stmt = await func(cls, db_session, name, *args, **kwargs) + if compile_sql: + return stmt.compile(compile_kwargs={"literal_binds": True}) + result = await db_session.execute(stmt) + return result.scalars().first() + + return wrapper + class Stuff(Base): __tablename__ = "stuff" @@ -24,12 +38,10 @@ class Stuff(Base): ) @classmethod - async def find(cls, db_session: AsyncSession, name: str, compile_sql: bool = False): + @compile_sql_or_scalar + async def find(cls, db_session: AsyncSession, name: str, compile_sql=False): stmt = select(cls).options(joinedload(cls.nonsense)).where(cls.name == name) - if compile_sql: - return stmt.compile(compile_kwargs={"literal_binds": True}) - result = await db_session.execute(stmt) - return result.scalars().first() + return stmt class StuffFullOfNonsense(Base):