mirror of
https://github.com/Balshgit/gpt_chat_bot.git
synced 2025-12-15 16:10:39 +03:00
add auth context (#62)
* add user table and superuser creation * add gpt-4-stream-aivvm provider * rename user migration to auth migration
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
from fastapi import Depends
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from starlette.requests import Request
|
||||
from telegram import Update
|
||||
|
||||
from core.auth.models.users import User
|
||||
from core.bot.app import BotApplication, BotQueue
|
||||
from core.bot.repository import ChatGPTRepository
|
||||
from core.bot.services import ChatGptService
|
||||
@@ -21,6 +24,10 @@ def get_bot_queue(request: Request) -> BotQueue:
|
||||
return request.app.state.queue
|
||||
|
||||
|
||||
def get_db_session(request: Request) -> AsyncSession:
|
||||
return request.app.state.db_session_factory()
|
||||
|
||||
|
||||
async def get_update_from_request(request: Request, bot_app: BotApplication = Depends(get_bot_app)) -> Update | None:
|
||||
data = await request.json()
|
||||
return Update.de_json(data, bot_app.bot)
|
||||
@@ -44,3 +51,9 @@ def get_chatgpt_service(
|
||||
chatgpt_repository: ChatGPTRepository = Depends(get_chatgpt_repository),
|
||||
) -> ChatGptService:
|
||||
return ChatGptService(repository=chatgpt_repository)
|
||||
|
||||
|
||||
async def get_user_db( # type: ignore[misc]
|
||||
session: AsyncSession = Depends(get_db_session),
|
||||
) -> SQLAlchemyUserDatabase: # type: ignore[type-arg]
|
||||
yield SQLAlchemyUserDatabase(session, User)
|
||||
|
||||
@@ -56,6 +56,7 @@ class ChatGptModelsEnum(StrEnum):
|
||||
gpt_3_5_turbo_ChatgptNext = "gpt-3.5-turbo-ChatgptNext"
|
||||
gpt_3_5_turbo_stream_gptTalkRu = "gpt-3.5-turbo--stream-gptTalkRu"
|
||||
Llama_2_70b_chat_hf_stream_DeepInfra = "Llama-2-70b-chat-hf-stream-DeepInfra"
|
||||
gpt_4_stream_aivvm = "gpt-4-stream-aivvm"
|
||||
llama2 = "llama2"
|
||||
gpt_3_5_turbo_stream_Berlin = "gpt-3.5-turbo-stream-Berlin"
|
||||
gpt_4_ChatGpt4Online = "gpt-4-ChatGpt4Online"
|
||||
|
||||
0
bot_microservice/core/auth/__init__.py
Normal file
0
bot_microservice/core/auth/__init__.py
Normal file
0
bot_microservice/core/auth/models/__init__.py
Normal file
0
bot_microservice/core/auth/models/__init__.py
Normal file
22
bot_microservice/core/auth/models/users.py
Normal file
22
bot_microservice/core/auth/models/users.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTable
|
||||
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTable
|
||||
from sqlalchemy import INTEGER, VARCHAR, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
|
||||
|
||||
from infra.database.base import Base
|
||||
|
||||
|
||||
class User(SQLAlchemyBaseUserTable[Mapped[int]], Base):
|
||||
__tablename__ = "users" # type: ignore[assignment]
|
||||
|
||||
id: Mapped[int] = mapped_column(INTEGER, primary_key=True)
|
||||
email: Mapped[str] = mapped_column(VARCHAR(length=320), unique=True, nullable=True) # type: ignore[assignment]
|
||||
username: Mapped[str] = mapped_column(VARCHAR(length=32), unique=True, index=True, nullable=False)
|
||||
|
||||
|
||||
class AccessToken(SQLAlchemyBaseAccessTokenTable[Mapped[int]], Base):
|
||||
__tablename__ = "access_token" # type: ignore[assignment]
|
||||
|
||||
@declared_attr
|
||||
def user_id(cls) -> Mapped[int]:
|
||||
return mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), nullable=False)
|
||||
@@ -1,18 +1,11 @@
|
||||
from asyncio import current_task
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from fastapi import FastAPI
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_scoped_session,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from settings.config import AppSettings
|
||||
from infra.database.db_adapter import Database
|
||||
|
||||
|
||||
def startup(app: FastAPI, settings: AppSettings) -> Callable[[], Awaitable[None]]:
|
||||
def startup(app: FastAPI, database: Database) -> Callable[[], Awaitable[None]]:
|
||||
"""
|
||||
Actions to run on application startup.
|
||||
|
||||
@@ -20,13 +13,13 @@ def startup(app: FastAPI, settings: AppSettings) -> Callable[[], Awaitable[None]
|
||||
such as db_engine.
|
||||
|
||||
:param app: the fastAPI application.
|
||||
:param settings: app settings
|
||||
:param database: app settings
|
||||
:return: function that actually performs actions.
|
||||
|
||||
"""
|
||||
|
||||
async def _startup() -> None:
|
||||
_setup_db(app, settings)
|
||||
_setup_db(app, database)
|
||||
|
||||
return _startup
|
||||
|
||||
@@ -46,7 +39,7 @@ def shutdown(app: FastAPI) -> Callable[[], Awaitable[None]]:
|
||||
return _shutdown
|
||||
|
||||
|
||||
def _setup_db(app: FastAPI, settings: AppSettings) -> None:
|
||||
def _setup_db(app: FastAPI, database: Database) -> None:
|
||||
"""
|
||||
Create connection to the database.
|
||||
|
||||
@@ -56,18 +49,5 @@ def _setup_db(app: FastAPI, settings: AppSettings) -> None:
|
||||
|
||||
:param app: fastAPI application.
|
||||
"""
|
||||
engine = create_async_engine(
|
||||
str(settings.async_db_url),
|
||||
echo=settings.DB_ECHO,
|
||||
execution_options={"isolation_level": "AUTOCOMMIT"},
|
||||
)
|
||||
session_factory = async_scoped_session(
|
||||
async_sessionmaker(
|
||||
engine,
|
||||
expire_on_commit=False,
|
||||
class_=AsyncSession,
|
||||
),
|
||||
scopefunc=current_task,
|
||||
)
|
||||
app.state.db_engine = engine
|
||||
app.state.db_session_factory = session_factory
|
||||
app.state.db_engine = database.async_engine
|
||||
app.state.db_session_factory = database._async_session_factory
|
||||
|
||||
@@ -108,7 +108,7 @@ class Database:
|
||||
|
||||
def load_all_models() -> None:
|
||||
"""Load all models from this folder."""
|
||||
package_dir = Path(__file__).resolve().parent.parent
|
||||
package_dir = Path(__file__).resolve().parent.parent.parent
|
||||
package_dir = package_dir.joinpath("core")
|
||||
modules = pkgutil.walk_packages(path=[str(package_dir)], prefix="core.")
|
||||
models_packages = [module for module in modules if module.ispkg and "models" in module.name]
|
||||
|
||||
@@ -1,20 +1,33 @@
|
||||
from typing import AsyncGenerator
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from typing import AsyncGenerator, Generator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from starlette.requests import Request
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from settings.config import settings
|
||||
|
||||
|
||||
async def get_db_session(request: Request) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Create and get database session.
|
||||
|
||||
:param request: current request.
|
||||
:yield: database session.
|
||||
"""
|
||||
session: AsyncSession = request.app.state.db_session_factory()
|
||||
|
||||
@contextmanager
|
||||
def get_sync_session() -> Generator[Session, None, None]:
|
||||
engine = create_engine(str(settings.sync_db_url), echo=settings.DB_ECHO)
|
||||
session_factory = sessionmaker(engine)
|
||||
try:
|
||||
yield session
|
||||
yield session_factory()
|
||||
finally:
|
||||
await session.commit()
|
||||
await session.close()
|
||||
engine.dispose()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async_engine = create_async_engine(
|
||||
str(settings.async_db_url),
|
||||
echo=settings.DB_ECHO,
|
||||
execution_options={"isolation_level": "AUTOCOMMIT"},
|
||||
)
|
||||
async_session_maker = async_sessionmaker(bind=async_engine, expire_on_commit=False)
|
||||
try:
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
finally:
|
||||
await async_engine.dispose()
|
||||
|
||||
@@ -6,12 +6,11 @@ Create Date: 2025-10-05 20:44:05.414977
|
||||
|
||||
"""
|
||||
from loguru import logger
|
||||
from sqlalchemy import create_engine, select, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy import select, text
|
||||
|
||||
from constants import ChatGptModelsEnum
|
||||
from core.bot.models.chat_gpt import ChatGpt
|
||||
from settings.config import settings
|
||||
from infra.database.deps import get_sync_session
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0002_create_chatgpt_models"
|
||||
@@ -19,12 +18,9 @@ down_revision = "0001_create_chatgpt_table"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | None = None
|
||||
|
||||
engine = create_engine(str(settings.async_db_url), echo=settings.DB_ECHO)
|
||||
session_factory = sessionmaker(engine)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
with session_factory() as session:
|
||||
with get_sync_session() as session:
|
||||
query = select(ChatGpt)
|
||||
results = session.execute(query)
|
||||
models = results.scalars().all()
|
||||
@@ -40,11 +36,8 @@ def upgrade() -> None:
|
||||
|
||||
def downgrade() -> None:
|
||||
chatgpt_table_name = ChatGpt.__tablename__
|
||||
with session_factory() as session:
|
||||
with get_sync_session() as session:
|
||||
# Truncate doesn't exists for SQLite
|
||||
session.execute(text(f"""DELETE FROM {chatgpt_table_name}""")) # noqa: S608
|
||||
session.commit()
|
||||
logger.info("chatgpt models table has been truncated", table=chatgpt_table_name)
|
||||
|
||||
|
||||
engine.dispose()
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
"""create_auth_tables
|
||||
|
||||
Revision ID: 0003_create_users_table
|
||||
Revises: 0002_create_chatgpt_models
|
||||
Create Date: 2023-11-28 00:58:01.984654
|
||||
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects.sqlite import insert
|
||||
|
||||
from core.auth.models.users import User
|
||||
from infra.database.deps import get_sync_session
|
||||
from settings.config import settings
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0003_create_auth_tables"
|
||||
down_revision = "0002_create_chatgpt_models"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"users",
|
||||
sa.Column("id", sa.INTEGER(), nullable=False),
|
||||
sa.Column("email", sa.VARCHAR(length=320), nullable=True),
|
||||
sa.Column("username", sa.VARCHAR(length=32), nullable=False),
|
||||
sa.Column("hashed_password", sa.String(length=1024), nullable=False),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_superuser", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_verified", sa.Boolean(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("email"),
|
||||
)
|
||||
op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True)
|
||||
op.create_table(
|
||||
"access_token",
|
||||
sa.Column("user_id", sa.INTEGER(), nullable=False),
|
||||
sa.Column("token", sa.String(length=43), nullable=False),
|
||||
sa.Column("created_at", fastapi_users_db_sqlalchemy.generics.TIMESTAMPAware(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="cascade"),
|
||||
sa.PrimaryKeyConstraint("token"),
|
||||
)
|
||||
op.create_index(op.f("ix_access_token_created_at"), "access_token", ["created_at"], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
username, password, salt = settings.SUPERUSER, settings.SUPERUSER_PASSWORD, settings.SALT
|
||||
if not all([username, password, salt]):
|
||||
return
|
||||
with get_sync_session() as session:
|
||||
hashed_password = hashlib.sha256((password.get_secret_value() + salt.get_secret_value()).encode()).hexdigest()
|
||||
query = insert(User).values({"username": username, "hashed_password": hashed_password})
|
||||
session.execute(query)
|
||||
session.commit()
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_access_token_created_at"), table_name="access_token")
|
||||
op.drop_table("access_token")
|
||||
op.drop_index(op.f("ix_users_username"), table_name="users")
|
||||
op.drop_table("users")
|
||||
# ### end Alembic commands ###
|
||||
@@ -35,7 +35,7 @@ class Application:
|
||||
self.app.state.queue = self._bot_queue
|
||||
self.app.state.bot_app = self.bot_app
|
||||
|
||||
self.app.on_event("startup")(startup(self.app, settings))
|
||||
self.app.on_event("startup")(startup(self.app, self.db))
|
||||
self.app.on_event("shutdown")(shutdown(self.app))
|
||||
|
||||
self.app.include_router(api_router)
|
||||
|
||||
@@ -9,6 +9,10 @@ WORKERS_COUNT=1
|
||||
RELOAD="true"
|
||||
DEBUG="true"
|
||||
|
||||
SUPERUSER="user"
|
||||
SUPERUSER_PASSWORD="hackme"
|
||||
SALT="change me"
|
||||
|
||||
# ==== sentry ====
|
||||
ENABLE_SENTRY="false"
|
||||
SENTRY_DSN=
|
||||
|
||||
@@ -4,7 +4,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import model_validator
|
||||
from pydantic import SecretStr, model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
from yarl import URL
|
||||
|
||||
@@ -71,6 +71,10 @@ class AppSettings(SentrySettings, LoggingSettings, BaseSettings):
|
||||
# Enable uvicorn reloading
|
||||
RELOAD: bool = False
|
||||
|
||||
SUPERUSER: str | None = None
|
||||
SUPERUSER_PASSWORD: SecretStr | None = None
|
||||
SALT: SecretStr | None = None
|
||||
|
||||
# telegram settings
|
||||
TELEGRAM_API_TOKEN: str = "123456789:AABBCCDDEEFFaabbccddeeff-1234567890"
|
||||
START_WITH_WEBHOOK: bool = False
|
||||
|
||||
Reference in New Issue
Block a user