mirror of
https://github.com/Balshgit/gpt_chat_bot.git
synced 2025-12-16 21:20:39 +03:00
add auth context (#62)
* add user table and superuser creation * add gpt-4-stream-aivvm provider * rename user migration to auth migration
This commit is contained in:
@@ -108,7 +108,7 @@ class Database:
|
||||
|
||||
def load_all_models() -> None:
|
||||
"""Load all models from this folder."""
|
||||
package_dir = Path(__file__).resolve().parent.parent
|
||||
package_dir = Path(__file__).resolve().parent.parent.parent
|
||||
package_dir = package_dir.joinpath("core")
|
||||
modules = pkgutil.walk_packages(path=[str(package_dir)], prefix="core.")
|
||||
models_packages = [module for module in modules if module.ispkg and "models" in module.name]
|
||||
|
||||
@@ -1,20 +1,33 @@
|
||||
from typing import AsyncGenerator
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from typing import AsyncGenerator, Generator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from starlette.requests import Request
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from settings.config import settings
|
||||
|
||||
|
||||
async def get_db_session(request: Request) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Create and get database session.
|
||||
|
||||
:param request: current request.
|
||||
:yield: database session.
|
||||
"""
|
||||
session: AsyncSession = request.app.state.db_session_factory()
|
||||
|
||||
@contextmanager
|
||||
def get_sync_session() -> Generator[Session, None, None]:
|
||||
engine = create_engine(str(settings.sync_db_url), echo=settings.DB_ECHO)
|
||||
session_factory = sessionmaker(engine)
|
||||
try:
|
||||
yield session
|
||||
yield session_factory()
|
||||
finally:
|
||||
await session.commit()
|
||||
await session.close()
|
||||
engine.dispose()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async_engine = create_async_engine(
|
||||
str(settings.async_db_url),
|
||||
echo=settings.DB_ECHO,
|
||||
execution_options={"isolation_level": "AUTOCOMMIT"},
|
||||
)
|
||||
async_session_maker = async_sessionmaker(bind=async_engine, expire_on_commit=False)
|
||||
try:
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
finally:
|
||||
await async_engine.dispose()
|
||||
|
||||
@@ -6,12 +6,11 @@ Create Date: 2025-10-05 20:44:05.414977
|
||||
|
||||
"""
|
||||
from loguru import logger
|
||||
from sqlalchemy import create_engine, select, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy import select, text
|
||||
|
||||
from constants import ChatGptModelsEnum
|
||||
from core.bot.models.chat_gpt import ChatGpt
|
||||
from settings.config import settings
|
||||
from infra.database.deps import get_sync_session
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0002_create_chatgpt_models"
|
||||
@@ -19,12 +18,9 @@ down_revision = "0001_create_chatgpt_table"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | None = None
|
||||
|
||||
engine = create_engine(str(settings.async_db_url), echo=settings.DB_ECHO)
|
||||
session_factory = sessionmaker(engine)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
with session_factory() as session:
|
||||
with get_sync_session() as session:
|
||||
query = select(ChatGpt)
|
||||
results = session.execute(query)
|
||||
models = results.scalars().all()
|
||||
@@ -40,11 +36,8 @@ def upgrade() -> None:
|
||||
|
||||
def downgrade() -> None:
|
||||
chatgpt_table_name = ChatGpt.__tablename__
|
||||
with session_factory() as session:
|
||||
with get_sync_session() as session:
|
||||
# Truncate doesn't exists for SQLite
|
||||
session.execute(text(f"""DELETE FROM {chatgpt_table_name}""")) # noqa: S608
|
||||
session.commit()
|
||||
logger.info("chatgpt models table has been truncated", table=chatgpt_table_name)
|
||||
|
||||
|
||||
engine.dispose()
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
"""create_auth_tables
|
||||
|
||||
Revision ID: 0003_create_users_table
|
||||
Revises: 0002_create_chatgpt_models
|
||||
Create Date: 2023-11-28 00:58:01.984654
|
||||
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects.sqlite import insert
|
||||
|
||||
from core.auth.models.users import User
|
||||
from infra.database.deps import get_sync_session
|
||||
from settings.config import settings
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0003_create_auth_tables"
|
||||
down_revision = "0002_create_chatgpt_models"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"users",
|
||||
sa.Column("id", sa.INTEGER(), nullable=False),
|
||||
sa.Column("email", sa.VARCHAR(length=320), nullable=True),
|
||||
sa.Column("username", sa.VARCHAR(length=32), nullable=False),
|
||||
sa.Column("hashed_password", sa.String(length=1024), nullable=False),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_superuser", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_verified", sa.Boolean(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("email"),
|
||||
)
|
||||
op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True)
|
||||
op.create_table(
|
||||
"access_token",
|
||||
sa.Column("user_id", sa.INTEGER(), nullable=False),
|
||||
sa.Column("token", sa.String(length=43), nullable=False),
|
||||
sa.Column("created_at", fastapi_users_db_sqlalchemy.generics.TIMESTAMPAware(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="cascade"),
|
||||
sa.PrimaryKeyConstraint("token"),
|
||||
)
|
||||
op.create_index(op.f("ix_access_token_created_at"), "access_token", ["created_at"], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
username, password, salt = settings.SUPERUSER, settings.SUPERUSER_PASSWORD, settings.SALT
|
||||
if not all([username, password, salt]):
|
||||
return
|
||||
with get_sync_session() as session:
|
||||
hashed_password = hashlib.sha256((password.get_secret_value() + salt.get_secret_value()).encode()).hexdigest()
|
||||
query = insert(User).values({"username": username, "hashed_password": hashed_password})
|
||||
session.execute(query)
|
||||
session.commit()
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_access_token_created_at"), table_name="access_token")
|
||||
op.drop_table("access_token")
|
||||
op.drop_index(op.f("ix_users_username"), table_name="users")
|
||||
op.drop_table("users")
|
||||
# ### end Alembic commands ###
|
||||
Reference in New Issue
Block a user