add auth context (#62)

* add user table and superuser creation

* add gpt-4-stream-aivvm provider

* rename user migration to auth migration
This commit is contained in:
Dmitry Afanasyev
2023-11-28 23:06:26 +03:00
committed by GitHub
parent c80b001740
commit 2359481fb7
17 changed files with 668 additions and 280 deletions

View File

@@ -1,7 +1,10 @@
from fastapi import Depends
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.requests import Request
from telegram import Update
from core.auth.models.users import User
from core.bot.app import BotApplication, BotQueue
from core.bot.repository import ChatGPTRepository
from core.bot.services import ChatGptService
@@ -21,6 +24,10 @@ def get_bot_queue(request: Request) -> BotQueue:
return request.app.state.queue
def get_db_session(request: Request) -> AsyncSession:
return request.app.state.db_session_factory()
async def get_update_from_request(request: Request, bot_app: BotApplication = Depends(get_bot_app)) -> Update | None:
data = await request.json()
return Update.de_json(data, bot_app.bot)
@@ -44,3 +51,9 @@ def get_chatgpt_service(
chatgpt_repository: ChatGPTRepository = Depends(get_chatgpt_repository),
) -> ChatGptService:
return ChatGptService(repository=chatgpt_repository)
async def get_user_db( # type: ignore[misc]
session: AsyncSession = Depends(get_db_session),
) -> SQLAlchemyUserDatabase: # type: ignore[type-arg]
yield SQLAlchemyUserDatabase(session, User)

View File

@@ -56,6 +56,7 @@ class ChatGptModelsEnum(StrEnum):
gpt_3_5_turbo_ChatgptNext = "gpt-3.5-turbo-ChatgptNext"
gpt_3_5_turbo_stream_gptTalkRu = "gpt-3.5-turbo--stream-gptTalkRu"
Llama_2_70b_chat_hf_stream_DeepInfra = "Llama-2-70b-chat-hf-stream-DeepInfra"
gpt_4_stream_aivvm = "gpt-4-stream-aivvm"
llama2 = "llama2"
gpt_3_5_turbo_stream_Berlin = "gpt-3.5-turbo-stream-Berlin"
gpt_4_ChatGpt4Online = "gpt-4-ChatGpt4Online"

View File

View File

@@ -0,0 +1,22 @@
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTable
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTable
from sqlalchemy import INTEGER, VARCHAR, ForeignKey
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
from infra.database.base import Base
class User(SQLAlchemyBaseUserTable[Mapped[int]], Base):
__tablename__ = "users" # type: ignore[assignment]
id: Mapped[int] = mapped_column(INTEGER, primary_key=True)
email: Mapped[str] = mapped_column(VARCHAR(length=320), unique=True, nullable=True) # type: ignore[assignment]
username: Mapped[str] = mapped_column(VARCHAR(length=32), unique=True, index=True, nullable=False)
class AccessToken(SQLAlchemyBaseAccessTokenTable[Mapped[int]], Base):
__tablename__ = "access_token" # type: ignore[assignment]
@declared_attr
def user_id(cls) -> Mapped[int]:
return mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), nullable=False)

View File

@@ -1,18 +1,11 @@
from asyncio import current_task
from typing import Awaitable, Callable
from fastapi import FastAPI
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_scoped_session,
async_sessionmaker,
create_async_engine,
)
from settings.config import AppSettings
from infra.database.db_adapter import Database
def startup(app: FastAPI, settings: AppSettings) -> Callable[[], Awaitable[None]]:
def startup(app: FastAPI, database: Database) -> Callable[[], Awaitable[None]]:
"""
Actions to run on application startup.
@@ -20,13 +13,13 @@ def startup(app: FastAPI, settings: AppSettings) -> Callable[[], Awaitable[None]
such as db_engine.
:param app: the fastAPI application.
:param settings: app settings
:param database: app settings
:return: function that actually performs actions.
"""
async def _startup() -> None:
_setup_db(app, settings)
_setup_db(app, database)
return _startup
@@ -46,7 +39,7 @@ def shutdown(app: FastAPI) -> Callable[[], Awaitable[None]]:
return _shutdown
def _setup_db(app: FastAPI, settings: AppSettings) -> None:
def _setup_db(app: FastAPI, database: Database) -> None:
"""
Create connection to the database.
@@ -56,18 +49,5 @@ def _setup_db(app: FastAPI, settings: AppSettings) -> None:
:param app: fastAPI application.
"""
engine = create_async_engine(
str(settings.async_db_url),
echo=settings.DB_ECHO,
execution_options={"isolation_level": "AUTOCOMMIT"},
)
session_factory = async_scoped_session(
async_sessionmaker(
engine,
expire_on_commit=False,
class_=AsyncSession,
),
scopefunc=current_task,
)
app.state.db_engine = engine
app.state.db_session_factory = session_factory
app.state.db_engine = database.async_engine
app.state.db_session_factory = database._async_session_factory

View File

@@ -108,7 +108,7 @@ class Database:
def load_all_models() -> None:
"""Load all models from this folder."""
package_dir = Path(__file__).resolve().parent.parent
package_dir = Path(__file__).resolve().parent.parent.parent
package_dir = package_dir.joinpath("core")
modules = pkgutil.walk_packages(path=[str(package_dir)], prefix="core.")
models_packages = [module for module in modules if module.ispkg and "models" in module.name]

View File

@@ -1,20 +1,33 @@
from typing import AsyncGenerator
from contextlib import asynccontextmanager, contextmanager
from typing import AsyncGenerator, Generator
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.requests import Request
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import Session, sessionmaker
from settings.config import settings
async def get_db_session(request: Request) -> AsyncGenerator[AsyncSession, None]:
"""
Create and get database session.
:param request: current request.
:yield: database session.
"""
session: AsyncSession = request.app.state.db_session_factory()
@contextmanager
def get_sync_session() -> Generator[Session, None, None]:
engine = create_engine(str(settings.sync_db_url), echo=settings.DB_ECHO)
session_factory = sessionmaker(engine)
try:
yield session
yield session_factory()
finally:
await session.commit()
await session.close()
engine.dispose()
@asynccontextmanager
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async_engine = create_async_engine(
str(settings.async_db_url),
echo=settings.DB_ECHO,
execution_options={"isolation_level": "AUTOCOMMIT"},
)
async_session_maker = async_sessionmaker(bind=async_engine, expire_on_commit=False)
try:
async with async_session_maker() as session:
yield session
finally:
await async_engine.dispose()

View File

@@ -6,12 +6,11 @@ Create Date: 2025-10-05 20:44:05.414977
"""
from loguru import logger
from sqlalchemy import create_engine, select, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy import select, text
from constants import ChatGptModelsEnum
from core.bot.models.chat_gpt import ChatGpt
from settings.config import settings
from infra.database.deps import get_sync_session
# revision identifiers, used by Alembic.
revision = "0002_create_chatgpt_models"
@@ -19,12 +18,9 @@ down_revision = "0001_create_chatgpt_table"
branch_labels: str | None = None
depends_on: str | None = None
engine = create_engine(str(settings.async_db_url), echo=settings.DB_ECHO)
session_factory = sessionmaker(engine)
def upgrade() -> None:
with session_factory() as session:
with get_sync_session() as session:
query = select(ChatGpt)
results = session.execute(query)
models = results.scalars().all()
@@ -40,11 +36,8 @@ def upgrade() -> None:
def downgrade() -> None:
chatgpt_table_name = ChatGpt.__tablename__
with session_factory() as session:
with get_sync_session() as session:
# Truncate doesn't exists for SQLite
session.execute(text(f"""DELETE FROM {chatgpt_table_name}""")) # noqa: S608
session.commit()
logger.info("chatgpt models table has been truncated", table=chatgpt_table_name)
engine.dispose()

View File

@@ -0,0 +1,68 @@
"""create_auth_tables
Revision ID: 0003_create_users_table
Revises: 0002_create_chatgpt_models
Create Date: 2023-11-28 00:58:01.984654
"""
import hashlib
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.sqlite import insert
from core.auth.models.users import User
from infra.database.deps import get_sync_session
from settings.config import settings
# revision identifiers, used by Alembic.
revision = "0003_create_auth_tables"
down_revision = "0002_create_chatgpt_models"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"users",
sa.Column("id", sa.INTEGER(), nullable=False),
sa.Column("email", sa.VARCHAR(length=320), nullable=True),
sa.Column("username", sa.VARCHAR(length=32), nullable=False),
sa.Column("hashed_password", sa.String(length=1024), nullable=False),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("is_superuser", sa.Boolean(), nullable=False),
sa.Column("is_verified", sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("email"),
)
op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True)
op.create_table(
"access_token",
sa.Column("user_id", sa.INTEGER(), nullable=False),
sa.Column("token", sa.String(length=43), nullable=False),
sa.Column("created_at", fastapi_users_db_sqlalchemy.generics.TIMESTAMPAware(timezone=True), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="cascade"),
sa.PrimaryKeyConstraint("token"),
)
op.create_index(op.f("ix_access_token_created_at"), "access_token", ["created_at"], unique=False)
# ### end Alembic commands ###
username, password, salt = settings.SUPERUSER, settings.SUPERUSER_PASSWORD, settings.SALT
if not all([username, password, salt]):
return
with get_sync_session() as session:
hashed_password = hashlib.sha256((password.get_secret_value() + salt.get_secret_value()).encode()).hexdigest()
query = insert(User).values({"username": username, "hashed_password": hashed_password})
session.execute(query)
session.commit()
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_access_token_created_at"), table_name="access_token")
op.drop_table("access_token")
op.drop_index(op.f("ix_users_username"), table_name="users")
op.drop_table("users")
# ### end Alembic commands ###

View File

@@ -35,7 +35,7 @@ class Application:
self.app.state.queue = self._bot_queue
self.app.state.bot_app = self.bot_app
self.app.on_event("startup")(startup(self.app, settings))
self.app.on_event("startup")(startup(self.app, self.db))
self.app.on_event("shutdown")(shutdown(self.app))
self.app.include_router(api_router)

View File

@@ -9,6 +9,10 @@ WORKERS_COUNT=1
RELOAD="true"
DEBUG="true"
SUPERUSER="user"
SUPERUSER_PASSWORD="hackme"
SALT="change me"
# ==== sentry ====
ENABLE_SENTRY="false"
SENTRY_DSN=

View File

@@ -4,7 +4,7 @@ from pathlib import Path
from typing import Any
from dotenv import load_dotenv
from pydantic import model_validator
from pydantic import SecretStr, model_validator
from pydantic_settings import BaseSettings
from yarl import URL
@@ -71,6 +71,10 @@ class AppSettings(SentrySettings, LoggingSettings, BaseSettings):
# Enable uvicorn reloading
RELOAD: bool = False
SUPERUSER: str | None = None
SUPERUSER_PASSWORD: SecretStr | None = None
SALT: SecretStr | None = None
# telegram settings
TELEGRAM_API_TOKEN: str = "123456789:AABBCCDDEEFFaabbccddeeff-1234567890"
START_WITH_WEBHOOK: bool = False