mirror of
https://github.com/Balshgit/gpt_chat_bot.git
synced 2025-12-16 21:20:39 +03:00
add user messages count action (#76)
* remove fastapi users dependency * add user service to chatbot service * add user save on bot info command * add user model to admin * fix tests
This commit is contained in:
7
bot_microservice/core/auth/dto.py
Normal file
7
bot_microservice/core/auth/dto.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserIsBannedDTO:
|
||||
is_banned: bool = False
|
||||
ban_reason: str | None = None
|
||||
@@ -1,25 +1,76 @@
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTable
|
||||
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTable
|
||||
from sqlalchemy import INTEGER, VARCHAR, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import INTEGER, TIMESTAMP, VARCHAR, Boolean, ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from infra.database.base import Base
|
||||
|
||||
|
||||
class User(SQLAlchemyBaseUserTable[Mapped[int]], Base):
|
||||
class User(Base):
|
||||
__tablename__ = "users" # type: ignore[assignment]
|
||||
|
||||
id: Mapped[int] = mapped_column(INTEGER, primary_key=True)
|
||||
email: Mapped[str] = mapped_column(VARCHAR(length=320), unique=True, nullable=True) # type: ignore[assignment]
|
||||
email: Mapped[str] = mapped_column(VARCHAR(length=255), unique=True, nullable=True)
|
||||
username: Mapped[str] = mapped_column(VARCHAR(length=32), unique=True, index=True, nullable=False)
|
||||
first_name: Mapped[str | None] = mapped_column(VARCHAR(length=32), nullable=True)
|
||||
last_name: Mapped[str | None] = mapped_column(VARCHAR(length=32), nullable=True)
|
||||
ban_reason: Mapped[str | None] = mapped_column(String(length=1024), nullable=True)
|
||||
hashed_password: Mapped[str] = mapped_column(String(length=1024), nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
is_superuser: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
TIMESTAMP(timezone=True), index=True, nullable=False, default=datetime.now
|
||||
)
|
||||
|
||||
user_question_count: Mapped["UserQuestionCount"] = relationship(
|
||||
"UserQuestionCount",
|
||||
primaryjoin="UserQuestionCount.user_id == User.id",
|
||||
backref="user",
|
||||
lazy="selectin",
|
||||
uselist=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def question_count(self) -> int:
|
||||
if self.user_question_count:
|
||||
return self.user_question_count.question_count
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
id: int,
|
||||
email: str | None = None,
|
||||
username: str | None = None,
|
||||
first_name: str | None = None,
|
||||
last_name: str | None = None,
|
||||
ban_reason: str | None = None,
|
||||
hashed_password: str | None = None,
|
||||
is_active: bool = True,
|
||||
is_superuser: bool = False,
|
||||
) -> "User":
|
||||
username = username or str(id)
|
||||
return User( # type: ignore[call-arg]
|
||||
id=id,
|
||||
email=email,
|
||||
username=username,
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
ban_reason=ban_reason,
|
||||
hashed_password=hashed_password,
|
||||
is_active=is_active,
|
||||
is_superuser=is_superuser,
|
||||
)
|
||||
|
||||
|
||||
class AccessToken(SQLAlchemyBaseAccessTokenTable[Mapped[int]], Base):
|
||||
class AccessToken(Base):
|
||||
__tablename__ = "access_token" # type: ignore[assignment]
|
||||
|
||||
@declared_attr
|
||||
def user_id(cls) -> Mapped[int]:
|
||||
return mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), nullable=False)
|
||||
user_id = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), nullable=False)
|
||||
token: Mapped[str] = mapped_column(String(length=42), primary_key=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
TIMESTAMP(timezone=True), index=True, nullable=False, default=datetime.now
|
||||
)
|
||||
|
||||
|
||||
class UserQuestionCount(Base):
|
||||
|
||||
78
bot_microservice/core/auth/repository.py
Normal file
78
bot_microservice/core/auth/repository.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.sqlite import insert
|
||||
from sqlalchemy.orm import load_only
|
||||
|
||||
from core.auth.dto import UserIsBannedDTO
|
||||
from core.auth.models.users import User, UserQuestionCount
|
||||
from infra.database.db_adapter import Database
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserRepository:
|
||||
db: Database
|
||||
|
||||
async def create_user(
|
||||
self,
|
||||
id: int,
|
||||
email: str | None,
|
||||
username: str | None,
|
||||
first_name: str | None,
|
||||
last_name: str | None,
|
||||
ban_reason: str | None,
|
||||
hashed_password: str | None,
|
||||
is_active: bool,
|
||||
is_superuser: bool,
|
||||
) -> User:
|
||||
user = User.build(
|
||||
id=id,
|
||||
email=email,
|
||||
username=username,
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
ban_reason=ban_reason,
|
||||
hashed_password=hashed_password,
|
||||
is_active=is_active,
|
||||
is_superuser=is_superuser,
|
||||
)
|
||||
|
||||
async with self.db.session() as session:
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
||||
async def get_user_by_id(self, user_id: int) -> User | None:
|
||||
query = select(User).filter_by(id=user_id)
|
||||
|
||||
async with self.db.session() as session:
|
||||
result = await session.execute(query)
|
||||
return result.scalar()
|
||||
|
||||
async def check_user_is_banned(self, user_id: int) -> UserIsBannedDTO:
|
||||
query = select(User).options(load_only(User.is_active, User.ban_reason)).filter_by(id=user_id)
|
||||
|
||||
async with self.db.session() as session:
|
||||
result = await session.execute(query)
|
||||
if user := result.scalar():
|
||||
return UserIsBannedDTO(is_banned=not bool(user.is_active), ban_reason=user.ban_reason)
|
||||
return UserIsBannedDTO()
|
||||
|
||||
async def update_user_message_count(self, user_id: int) -> None:
|
||||
query = (
|
||||
insert(UserQuestionCount)
|
||||
.values({UserQuestionCount.user_id: user_id, UserQuestionCount.question_count: 1})
|
||||
.on_conflict_do_update(
|
||||
index_elements=[UserQuestionCount.user_id],
|
||||
set_={
|
||||
UserQuestionCount.get_real_column_name(
|
||||
UserQuestionCount.question_count.key
|
||||
): UserQuestionCount.question_count
|
||||
+ 1
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async with self.db.session() as session:
|
||||
await session.execute(query)
|
||||
56
bot_microservice/core/auth/services.py
Normal file
56
bot_microservice/core/auth/services.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
|
||||
from core.auth.dto import UserIsBannedDTO
|
||||
from core.auth.models.users import User
|
||||
from core.auth.repository import UserRepository
|
||||
from core.auth.utils import create_password_hash
|
||||
from infra.database.db_adapter import Database
|
||||
from settings.config import settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserService:
|
||||
repository: UserRepository
|
||||
|
||||
@classmethod
|
||||
def build(cls) -> "UserService":
|
||||
db = Database(settings=settings)
|
||||
repository = UserRepository(db=db)
|
||||
return UserService(repository=repository)
|
||||
|
||||
async def get_user_by_id(self, user_id: int) -> User | None:
|
||||
return await self.repository.get_user_by_id(user_id)
|
||||
|
||||
async def get_or_create_user_by_id(
|
||||
self,
|
||||
user_id: int,
|
||||
hashed_password: str | None = None,
|
||||
email: str | None = None,
|
||||
username: str | None = None,
|
||||
first_name: str | None = None,
|
||||
last_name: str | None = None,
|
||||
ban_reason: str | None = None,
|
||||
is_active: bool = True,
|
||||
is_superuser: bool = False,
|
||||
) -> User:
|
||||
hashed_password = hashed_password or create_password_hash(uuid.uuid4().hex)
|
||||
if not (user := await self.repository.get_user_by_id(user_id=user_id)):
|
||||
user = await self.repository.create_user(
|
||||
id=user_id,
|
||||
email=email,
|
||||
username=username,
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
ban_reason=ban_reason,
|
||||
hashed_password=hashed_password,
|
||||
is_active=is_active,
|
||||
is_superuser=is_superuser,
|
||||
)
|
||||
return user
|
||||
|
||||
async def update_user_message_count(self, user_id: int) -> None:
|
||||
await self.repository.update_user_message_count(user_id)
|
||||
|
||||
async def check_user_is_banned(self, user_id: int) -> UserIsBannedDTO:
|
||||
return await self.repository.check_user_is_banned(user_id)
|
||||
9
bot_microservice/core/auth/utils.py
Normal file
9
bot_microservice/core/auth/utils.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import hashlib
|
||||
|
||||
from settings.config import settings
|
||||
|
||||
|
||||
def create_password_hash(password: str) -> str:
|
||||
if not settings.SALT:
|
||||
return password
|
||||
return hashlib.sha256((password + settings.SALT.get_secret_value()).encode()).hexdigest()
|
||||
@@ -92,6 +92,11 @@ async def ask_question(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No
|
||||
if not update.message:
|
||||
return
|
||||
|
||||
if not update.effective_user:
|
||||
logger.error('no effective user', update=update, context=context)
|
||||
await update.message.reply_text("Бот не смог определить пользователя. :(\nОб ошибке уже сообщено.")
|
||||
return
|
||||
|
||||
await update.message.reply_text(
|
||||
f"Ответ в среднем занимает 10-15 секунд.\n"
|
||||
f"- Список команд: /{BotCommands.help}\n"
|
||||
@@ -100,8 +105,16 @@ async def ask_question(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No
|
||||
|
||||
chatgpt_service = ChatGptService.build()
|
||||
logger.warning("question asked", user=update.message.from_user, question=update.message.text)
|
||||
answer = await chatgpt_service.request_to_chatgpt(question=update.message.text)
|
||||
await update.message.reply_text(answer)
|
||||
answer, user = await asyncio.gather(
|
||||
chatgpt_service.request_to_chatgpt(question=update.message.text),
|
||||
chatgpt_service.get_or_create_bot_user(
|
||||
user_id=update.effective_user.id,
|
||||
username=update.effective_user.username,
|
||||
first_name=update.effective_user.first_name,
|
||||
last_name=update.effective_user.last_name,
|
||||
),
|
||||
)
|
||||
await asyncio.gather(update.message.reply_text(answer), chatgpt_service.update_bot_user_message_count(user.id))
|
||||
|
||||
|
||||
async def voice_recognize(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
|
||||
@@ -15,6 +15,9 @@ from speech_recognition import (
|
||||
)
|
||||
|
||||
from constants import AUDIO_SEGMENT_DURATION
|
||||
from core.auth.models.users import User
|
||||
from core.auth.repository import UserRepository
|
||||
from core.auth.services import UserService
|
||||
from core.bot.models.chatgpt import ChatGptModels
|
||||
from core.bot.repository import ChatGPTRepository
|
||||
from infra.database.db_adapter import Database
|
||||
@@ -89,6 +92,15 @@ class SpeechToTextService:
|
||||
@dataclass
|
||||
class ChatGptService:
|
||||
repository: ChatGPTRepository
|
||||
user_service: UserService
|
||||
|
||||
@classmethod
|
||||
def build(cls) -> "ChatGptService":
|
||||
db = Database(settings=settings)
|
||||
repository = ChatGPTRepository(settings=settings, db=db)
|
||||
user_repository = UserRepository(db=db)
|
||||
user_service = UserService(repository=user_repository)
|
||||
return ChatGptService(repository=repository, user_service=user_service)
|
||||
|
||||
async def get_chatgpt_models(self) -> Sequence[ChatGptModels]:
|
||||
return await self.repository.get_chatgpt_models()
|
||||
@@ -117,8 +129,27 @@ class ChatGptService:
|
||||
async def delete_chatgpt_model(self, model_id: int) -> None:
|
||||
return await self.repository.delete_chatgpt_model(model_id=model_id)
|
||||
|
||||
@classmethod
|
||||
def build(cls) -> "ChatGptService":
|
||||
db = Database(settings=settings)
|
||||
repository = ChatGPTRepository(settings=settings, db=db)
|
||||
return ChatGptService(repository=repository)
|
||||
async def get_or_create_bot_user(
|
||||
self,
|
||||
user_id: int,
|
||||
email: str | None = None,
|
||||
username: str | None = None,
|
||||
first_name: str | None = None,
|
||||
last_name: str | None = None,
|
||||
ban_reason: str | None = None,
|
||||
is_active: bool = True,
|
||||
is_superuser: bool = False,
|
||||
) -> User:
|
||||
return await self.user_service.get_or_create_user_by_id(
|
||||
user_id=user_id,
|
||||
email=email,
|
||||
username=username,
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
ban_reason=ban_reason,
|
||||
is_active=is_active,
|
||||
is_superuser=is_superuser,
|
||||
)
|
||||
|
||||
async def update_bot_user_message_count(self, user_id: int) -> None:
|
||||
await self.user_service.update_user_message_count(user_id)
|
||||
|
||||
Reference in New Issue
Block a user