From dbe64008f8ddb756845f3b1fb7d6e88c1c2ce4a2 Mon Sep 17 00:00:00 2001 From: Jakub Miazek Date: Sat, 12 Nov 2022 18:41:39 +0100 Subject: [PATCH] format --- app/api/shakespeare.py | 6 ++++-- app/config.py | 2 +- app/database.py | 29 ++++++++++++++++++++++++++--- app/main.py | 2 +- app/models/__init__.py | 4 ++-- app/models/base.py | 8 ++++---- app/models/nonsense.py | 4 +--- app/models/stuff.py | 4 +--- app/schemas/shakespeare.py | 2 +- app/utils.py | 2 +- 10 files changed, 42 insertions(+), 21 deletions(-) diff --git a/app/api/shakespeare.py b/app/api/shakespeare.py index 95528a6..ff814c7 100644 --- a/app/api/shakespeare.py +++ b/app/api/shakespeare.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends, status +from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_db @@ -7,7 +7,9 @@ from app.models.shakespeare import Paragraph router = APIRouter(prefix="/v1/shakespeare") -@router.get("/",) +@router.get( + "/", +) async def find_paragraph( character: str, db_session: AsyncSession = Depends(get_db), diff --git a/app/config.py b/app/config.py index c299105..df210f5 100644 --- a/app/config.py +++ b/app/config.py @@ -38,7 +38,7 @@ class Settings(BaseSettings): jwt_access_toke_expire_minutes: int = os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", 1) -@lru_cache() +@lru_cache def get_settings(): logger.info("Loading config settings from the environment...") return Settings() diff --git a/app/database.py b/app/database.py index 33bff22..609a41a 100644 --- a/app/database.py +++ b/app/database.py @@ -1,10 +1,15 @@ -from typing import AsyncGenerator +from collections.abc import AsyncGenerator +from http.client import HTTPException from fastapi.encoders import jsonable_encoder +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker from app import config +from app.utils import get_logger + +logger = get_logger(__name__) global_settings = config.get_settings() url = global_settings.asyncpg_url @@ -18,10 +23,28 @@ engine = create_async_engine( # expire_on_commit=False will prevent attributes from being expired # after commit. -async_session_factory = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) +AsyncSessionFactory = sessionmaker(engine, autoflush=False, expire_on_commit=False, class_=AsyncSession) # Dependency async def get_db() -> AsyncGenerator: - async with async_session_factory() as session: + async with AsyncSessionFactory() as session: + logger.debug(f"ASYNC Pool: {engine.pool.status()}") yield session + + +async def get_async_db() -> AsyncGenerator: + try: + session: AsyncSession = AsyncSessionFactory() + logger.debug(f"ASYNC Pool: {engine.pool.status()}") + yield session + except SQLAlchemyError as sql_ex: + await session.rollback() + raise sql_ex + except HTTPException as http_ex: + await session.rollback() + raise http_ex + else: + await session.commit() + finally: + await session.close() diff --git a/app/main.py b/app/main.py index e40a153..0a5aa26 100644 --- a/app/main.py +++ b/app/main.py @@ -1,8 +1,8 @@ from fastapi import FastAPI from app.api.nonsense import router as nonsense_router -from app.api.stuff import router as stuff_router from app.api.shakespeare import router as shakespeare_router +from app.api.stuff import router as stuff_router from app.utils import get_logger logger = get_logger(__name__) diff --git a/app/models/__init__.py b/app/models/__init__.py index d61f6ad..5315b1d 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -1,4 +1,4 @@ # for Alembic and unit tests -from app.models.stuff import * # noqa from app.models.nonsense import * # noqa -from app.models.shakespeare import * # noqa \ No newline at end of file +from app.models.shakespeare import * # noqa +from app.models.stuff import * # noqa diff --git a/app/models/base.py b/app/models/base.py index 547f841..d7a665d 100644 --- a/app/models/base.py +++ b/app/models/base.py @@ -13,8 +13,8 @@ class BaseReadOnly: # Generate __tablename__ automatically @declared_attr - def __tablename__(cls) -> str: - return cls.__name__.lower() + def __tablename__(self) -> str: + return self.__name__.lower() @as_declarative() @@ -24,8 +24,8 @@ class Base: # Generate __tablename__ automatically @declared_attr - def __tablename__(cls) -> str: - return cls.__name__.lower() + def __tablename__(self) -> str: + return self.__name__.lower() async def save(self, db_session: AsyncSession): """ diff --git a/app/models/nonsense.py b/app/models/nonsense.py index 9a71a3b..5f4477d 100644 --- a/app/models/nonsense.py +++ b/app/models/nonsense.py @@ -10,9 +10,7 @@ from app.models.base import Base class Nonsense(Base): __tablename__ = "nonsense" - __table_args__ = ( - {"schema": "happy_hog"}, - ) + __table_args__ = ({"schema": "happy_hog"},) id = Column(UUID(as_uuid=True), unique=True, default=uuid.uuid4, autoincrement=True) name = Column(String, nullable=False, primary_key=True, unique=True) description = Column(String, nullable=False) diff --git a/app/models/stuff.py b/app/models/stuff.py index fc7944a..37dffad 100644 --- a/app/models/stuff.py +++ b/app/models/stuff.py @@ -10,9 +10,7 @@ from app.models.base import Base class Stuff(Base): __tablename__ = "stuff" - __table_args__ = ( - {"schema": "happy_hog"}, - ) + __table_args__ = ({"schema": "happy_hog"},) id = Column(UUID(as_uuid=True), unique=True, default=uuid.uuid4, autoincrement=True) name = Column(String, nullable=False, primary_key=True, unique=True) description = Column(String, nullable=False) diff --git a/app/schemas/shakespeare.py b/app/schemas/shakespeare.py index ecb1df0..97e7ec1 100644 --- a/app/schemas/shakespeare.py +++ b/app/schemas/shakespeare.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any from pydantic import BaseModel diff --git a/app/utils.py b/app/utils.py index d2f9f52..6c1ba75 100644 --- a/app/utils.py +++ b/app/utils.py @@ -7,7 +7,7 @@ from rich.logging import RichHandler console = Console(color_system="256", width=200, style="blue") -@lru_cache() +@lru_cache def get_logger(module_name): logger = logging.getLogger(module_name) handler = RichHandler(rich_tracebacks=True, console=console, tracebacks_show_locals=True)