mirror of
https://github.com/Balshgit/gpt_chat_bot.git
synced 2025-09-10 17:20:41 +03:00
Compare commits
4 Commits
8266342214
...
28d01f551b
Author | SHA1 | Date | |
---|---|---|---|
|
28d01f551b | ||
|
28895f3510 | ||
|
7cbe7b7c50 | ||
|
de55d873f9 |
@ -110,6 +110,7 @@ Init alembic
|
||||
cd bot_microservice
|
||||
alembic revision --autogenerate -m 'create_chatgpt_table'
|
||||
alembic upgrade head
|
||||
python3 core/bot/managment/update_gpt_models.py
|
||||
```
|
||||
|
||||
|
||||
@ -130,6 +131,7 @@ Run migrations
|
||||
```bash
|
||||
cd ./bot_microservice # alembic root
|
||||
alembic --config ./alembic.ini upgrade head
|
||||
python3 core/bot/managment/update_gpt_models.py
|
||||
```
|
||||
|
||||
**downgrade to 0001_create_chatgpt_table revision:**
|
||||
@ -166,4 +168,4 @@ alembic downgrade base
|
||||
- [x] add sentry
|
||||
- [x] add graylog integration and availability to log to file
|
||||
- [x] add user model
|
||||
- [ ] add messages statistic
|
||||
- [x] add messages statistic
|
||||
|
1
bot_microservice/api/bot/constants.py
Normal file
1
bot_microservice/api/bot/constants.py
Normal file
@ -0,0 +1 @@
|
||||
BOT_ACCESS_API_HEADER = "BOT-API-KEY"
|
@ -3,13 +3,19 @@ from starlette import status
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from telegram import Update
|
||||
|
||||
from api.bot.deps import get_bot_queue, get_chatgpt_service, get_update_from_request
|
||||
from api.bot.deps import (
|
||||
get_access_to_bot_api_or_403,
|
||||
get_bot_queue,
|
||||
get_chatgpt_service,
|
||||
get_update_from_request,
|
||||
)
|
||||
from api.bot.serializers import (
|
||||
ChatGptModelSerializer,
|
||||
ChatGptModelsPrioritySerializer,
|
||||
GETChatGptModelsSerializer,
|
||||
LightChatGptModel,
|
||||
)
|
||||
from api.exceptions import PermissionMissingResponse
|
||||
from core.bot.app import BotQueue
|
||||
from core.bot.services import ChatGptService
|
||||
from settings.config import settings
|
||||
@ -53,6 +59,10 @@ async def models_list(
|
||||
@router.put(
|
||||
"/chatgpt/models/{model_id}/priority",
|
||||
name="bot:change_model_priority",
|
||||
dependencies=[Depends(get_access_to_bot_api_or_403)],
|
||||
responses={
|
||||
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
|
||||
},
|
||||
response_class=Response,
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="change gpt model priority",
|
||||
@ -69,6 +79,10 @@ async def change_model_priority(
|
||||
@router.put(
|
||||
"/chatgpt/models/priority/reset",
|
||||
name="bot:reset_models_priority",
|
||||
dependencies=[Depends(get_access_to_bot_api_or_403)],
|
||||
responses={
|
||||
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
|
||||
},
|
||||
response_class=Response,
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="reset all model priority to default",
|
||||
@ -83,6 +97,11 @@ async def reset_models_priority(
|
||||
@router.post(
|
||||
"/chatgpt/models",
|
||||
name="bot:add_new_model",
|
||||
dependencies=[Depends(get_access_to_bot_api_or_403)],
|
||||
responses={
|
||||
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
|
||||
status.HTTP_201_CREATED: {"model": ChatGptModelSerializer},
|
||||
},
|
||||
response_model=ChatGptModelSerializer,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="add new model",
|
||||
@ -100,6 +119,10 @@ async def add_new_model(
|
||||
@router.delete(
|
||||
"/chatgpt/models/{model_id}",
|
||||
name="bot:delete_gpt_model",
|
||||
dependencies=[Depends(get_access_to_bot_api_or_403)],
|
||||
responses={
|
||||
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
|
||||
},
|
||||
response_class=Response,
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="delete gpt model",
|
||||
|
@ -1,15 +1,17 @@
|
||||
from fastapi import Depends
|
||||
from fastapi import Depends, Header, HTTPException
|
||||
from starlette import status
|
||||
from starlette.requests import Request
|
||||
from telegram import Update
|
||||
|
||||
from api.auth.deps import get_user_service
|
||||
from api.bot.constants import BOT_ACCESS_API_HEADER
|
||||
from api.deps import get_database
|
||||
from core.auth.services import UserService
|
||||
from core.bot.app import BotApplication, BotQueue
|
||||
from core.bot.repository import ChatGPTRepository
|
||||
from core.bot.services import ChatGptService
|
||||
from infra.database.db_adapter import Database
|
||||
from settings.config import AppSettings, get_settings
|
||||
from settings.config import AppSettings, get_settings, settings
|
||||
|
||||
|
||||
def get_bot_app(request: Request) -> BotApplication:
|
||||
@ -40,3 +42,13 @@ def get_chatgpt_service(
|
||||
user_service: UserService = Depends(get_user_service),
|
||||
) -> ChatGptService:
|
||||
return ChatGptService(repository=chatgpt_repository, user_service=user_service)
|
||||
|
||||
|
||||
async def get_access_to_bot_api_or_403(
|
||||
bot_api_key: str | None = Header(None, alias=BOT_ACCESS_API_HEADER, description="Ключ доступа до API бота"),
|
||||
user_service: UserService = Depends(get_user_service),
|
||||
) -> None:
|
||||
access_token = await user_service.get_user_access_token_by_username(settings.SUPERUSER)
|
||||
|
||||
if not access_token or access_token != bot_api_key:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Could not validate api header")
|
||||
|
@ -1,11 +1,39 @@
|
||||
from typing import Any
|
||||
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from starlette import status
|
||||
from starlette.requests import Request
|
||||
|
||||
from api.base_schemas import BaseError, BaseResponse
|
||||
|
||||
|
||||
class BaseAPIException(Exception):
|
||||
pass
|
||||
_content_type: str = "application/json"
|
||||
model: type[BaseResponse] = BaseResponse
|
||||
status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
title: str | None = None
|
||||
type: str | None = None
|
||||
detail: str | None = None
|
||||
instance: str | None = None
|
||||
headers: dict[str, str] | None = None
|
||||
|
||||
def __init__(self, **ctx: Any) -> None:
|
||||
self.__dict__ = ctx
|
||||
|
||||
@classmethod
|
||||
def example(cls) -> dict[str, Any] | None:
|
||||
if isinstance(cls.model.Config.schema_extra, dict): # type: ignore[attr-defined]
|
||||
return cls.model.Config.schema_extra.get("example") # type: ignore[attr-defined]
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def response(cls) -> dict[str, Any]:
|
||||
return {
|
||||
"model": cls.model,
|
||||
"content": {
|
||||
cls._content_type: cls.model.Config.schema_extra, # type: ignore[attr-defined]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class InternalServerError(BaseError):
|
||||
@ -28,6 +56,31 @@ class InternalServerErrorResponse(BaseResponse):
|
||||
}
|
||||
|
||||
|
||||
class PermissionMissing(BaseError):
|
||||
pass
|
||||
|
||||
|
||||
class PermissionMissingResponse(BaseResponse):
|
||||
error: PermissionMissing
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"status": 403,
|
||||
"error": {
|
||||
"type": "PermissionMissing",
|
||||
"title": "Permission required for this endpoint is missing",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class PermissionMissingError(BaseAPIException):
|
||||
model = PermissionMissingResponse
|
||||
status_code = status.HTTP_403_FORBIDDEN
|
||||
title: str = "Permission required for this endpoint is missing"
|
||||
|
||||
|
||||
async def internal_server_error_handler(_request: Request, _exception: Exception) -> ORJSONResponse:
|
||||
error = InternalServerError(title="Something went wrong!", type="InternalServerError")
|
||||
response = InternalServerErrorResponse(status=500, error=error).model_dump(exclude_unset=True)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import INTEGER, TIMESTAMP, VARCHAR, Boolean, ForeignKey, String
|
||||
@ -18,15 +19,13 @@ class User(Base):
|
||||
hashed_password: Mapped[str] = mapped_column(String(length=1024), nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
is_superuser: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
TIMESTAMP(timezone=True), index=True, nullable=False, default=datetime.now
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), nullable=False, default=datetime.now)
|
||||
|
||||
user_question_count: Mapped["UserQuestionCount"] = relationship(
|
||||
"UserQuestionCount",
|
||||
primaryjoin="UserQuestionCount.user_id == User.id",
|
||||
backref="user",
|
||||
lazy="selectin",
|
||||
lazy="noload",
|
||||
uselist=False,
|
||||
cascade="delete",
|
||||
)
|
||||
@ -37,6 +36,12 @@ class User(Base):
|
||||
return self.user_question_count.question_count
|
||||
return 0
|
||||
|
||||
@property
|
||||
def last_question_at(self) -> str | None:
|
||||
if self.user_question_count:
|
||||
return self.user_question_count.last_question_at.strftime("%Y-%m-%d %H:%M:%S")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
@ -68,14 +73,27 @@ class AccessToken(Base):
|
||||
__tablename__ = "access_token" # type: ignore[assignment]
|
||||
|
||||
user_id = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), nullable=False)
|
||||
token: Mapped[str] = mapped_column(String(length=42), primary_key=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
TIMESTAMP(timezone=True), index=True, nullable=False, default=datetime.now
|
||||
token: Mapped[str] = mapped_column(String(length=42), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
created_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), nullable=False, default=datetime.now)
|
||||
|
||||
user: Mapped["User"] = relationship(
|
||||
"User",
|
||||
backref="access_token",
|
||||
lazy="selectin",
|
||||
uselist=False,
|
||||
cascade="expunge",
|
||||
)
|
||||
|
||||
@property
|
||||
def username(self) -> str:
|
||||
if self.user:
|
||||
return self.user.username
|
||||
return ""
|
||||
|
||||
|
||||
class UserQuestionCount(Base):
|
||||
__tablename__ = "user_question_count" # type: ignore[assignment]
|
||||
|
||||
user_id: Mapped[int] = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), primary_key=True)
|
||||
question_count: Mapped[int] = mapped_column(INTEGER, default=0, nullable=False)
|
||||
last_question_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), nullable=False, default=datetime.now)
|
||||
|
@ -1,11 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.dialects.sqlite import insert
|
||||
from sqlalchemy.orm import load_only
|
||||
|
||||
from core.auth.dto import UserIsBannedDTO
|
||||
from core.auth.models.users import User, UserQuestionCount
|
||||
from core.auth.models.users import AccessToken, User, UserQuestionCount
|
||||
from infra.database.db_adapter import Database
|
||||
|
||||
|
||||
@ -69,10 +69,18 @@ class UserRepository:
|
||||
UserQuestionCount.get_real_column_name(
|
||||
UserQuestionCount.question_count.key
|
||||
): UserQuestionCount.question_count
|
||||
+ 1
|
||||
+ 1,
|
||||
UserQuestionCount.get_real_column_name(UserQuestionCount.last_question_at.key): func.now(),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async with self.db.session() as session:
|
||||
await session.execute(query)
|
||||
|
||||
async def get_user_access_token(self, username: str | None) -> str | None:
|
||||
query = select(AccessToken.token).join(AccessToken.user).where(User.username == username)
|
||||
|
||||
async with self.db.session() as session:
|
||||
result = await session.execute(query)
|
||||
return result.scalar()
|
||||
|
@ -62,6 +62,9 @@ class UserService:
|
||||
async def check_user_is_banned(self, user_id: int) -> UserIsBannedDTO:
|
||||
return await self.repository.check_user_is_banned(user_id)
|
||||
|
||||
async def get_user_access_token_by_username(self, username: str | None) -> str | None:
|
||||
return await self.repository.get_user_access_token(username)
|
||||
|
||||
|
||||
def check_user_is_banned(func: Any) -> Any:
|
||||
@wraps(func)
|
||||
|
0
bot_microservice/core/bot/managment/__init__.py
Normal file
0
bot_microservice/core/bot/managment/__init__.py
Normal file
13
bot_microservice/core/bot/managment/update_gpt_models.py
Normal file
13
bot_microservice/core/bot/managment/update_gpt_models.py
Normal file
@ -0,0 +1,13 @@
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent.parent))
|
||||
from core.bot.services import ChatGptService
|
||||
|
||||
chatgpt_service = ChatGptService.build()
|
||||
asyncio.run(chatgpt_service.update_chatgpt_models())
|
||||
logger.info("chatgpt models has been updated")
|
@ -38,6 +38,18 @@ class ChatGPTRepository:
|
||||
async with self.db.session() as session:
|
||||
await session.execute(query)
|
||||
|
||||
async def delete_all_chatgpt_models(self) -> None:
|
||||
query = delete(ChatGptModels)
|
||||
async with self.db.session() as session:
|
||||
await session.execute(query)
|
||||
|
||||
async def bulk_insert_chatgpt_models(self, models_priority: list[dict[str, Any]]) -> None:
|
||||
models = [ChatGptModels(**model_priority) for model_priority in models_priority]
|
||||
|
||||
async with self.db.session() as session:
|
||||
session.add_all(models)
|
||||
await session.commit()
|
||||
|
||||
async def add_chatgpt_model(self, model: str, priority: int) -> dict[str, str | int]:
|
||||
query = (
|
||||
insert(ChatGptModels)
|
||||
|
@ -14,7 +14,7 @@ from speech_recognition import (
|
||||
UnknownValueError as SpeechRecognizerError,
|
||||
)
|
||||
|
||||
from constants import AUDIO_SEGMENT_DURATION
|
||||
from constants import AUDIO_SEGMENT_DURATION, ChatGptModelsEnum
|
||||
from core.auth.models.users import User
|
||||
from core.auth.repository import UserRepository
|
||||
from core.auth.services import UserService
|
||||
@ -24,6 +24,77 @@ from infra.database.db_adapter import Database
|
||||
from settings.config import settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatGptService:
|
||||
repository: ChatGPTRepository
|
||||
user_service: UserService
|
||||
|
||||
@classmethod
|
||||
def build(cls) -> "ChatGptService":
|
||||
db = Database(settings=settings)
|
||||
repository = ChatGPTRepository(settings=settings, db=db)
|
||||
user_repository = UserRepository(db=db)
|
||||
user_service = UserService(repository=user_repository)
|
||||
return ChatGptService(repository=repository, user_service=user_service)
|
||||
|
||||
async def get_chatgpt_models(self) -> Sequence[ChatGptModels]:
|
||||
return await self.repository.get_chatgpt_models()
|
||||
|
||||
async def request_to_chatgpt(self, question: str | None) -> str:
|
||||
question = question or "Привет!"
|
||||
chatgpt_model = await self.get_current_chatgpt_model()
|
||||
return await self.repository.ask_question(question=question, chatgpt_model=chatgpt_model)
|
||||
|
||||
async def request_to_chatgpt_microservice(self, question: str) -> Response:
|
||||
chatgpt_model = await self.get_current_chatgpt_model()
|
||||
return await self.repository.request_to_chatgpt_microservice(question=question, chatgpt_model=chatgpt_model)
|
||||
|
||||
async def get_current_chatgpt_model(self) -> str:
|
||||
return await self.repository.get_current_chatgpt_model()
|
||||
|
||||
async def change_chatgpt_model_priority(self, model_id: int, priority: int) -> None:
|
||||
return await self.repository.change_chatgpt_model_priority(model_id=model_id, priority=priority)
|
||||
|
||||
async def update_chatgpt_models(self) -> None:
|
||||
await self.repository.delete_all_chatgpt_models()
|
||||
models = ChatGptModelsEnum.base_models_priority()
|
||||
await self.repository.bulk_insert_chatgpt_models(models)
|
||||
|
||||
async def reset_all_chatgpt_models_priority(self) -> None:
|
||||
return await self.repository.reset_all_chatgpt_models_priority()
|
||||
|
||||
async def add_chatgpt_model(self, gpt_model: str, priority: int) -> dict[str, str | int]:
|
||||
return await self.repository.add_chatgpt_model(model=gpt_model, priority=priority)
|
||||
|
||||
async def delete_chatgpt_model(self, model_id: int) -> None:
|
||||
return await self.repository.delete_chatgpt_model(model_id=model_id)
|
||||
|
||||
async def get_or_create_bot_user(
|
||||
self,
|
||||
user_id: int,
|
||||
email: str | None = None,
|
||||
username: str | None = None,
|
||||
first_name: str | None = None,
|
||||
last_name: str | None = None,
|
||||
ban_reason: str | None = None,
|
||||
is_active: bool = True,
|
||||
is_superuser: bool = False,
|
||||
) -> User:
|
||||
return await self.user_service.get_or_create_user_by_id(
|
||||
user_id=user_id,
|
||||
email=email,
|
||||
username=username,
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
ban_reason=ban_reason,
|
||||
is_active=is_active,
|
||||
is_superuser=is_superuser,
|
||||
)
|
||||
|
||||
async def update_bot_user_message_count(self, user_id: int) -> None:
|
||||
await self.user_service.update_user_message_count(user_id)
|
||||
|
||||
|
||||
class SpeechToTextService:
|
||||
def __init__(self, filename: str) -> None:
|
||||
self.filename = filename
|
||||
@ -87,69 +158,3 @@ class SpeechToTextService:
|
||||
except SpeechRecognizerError as error:
|
||||
logger.error("error recognizing text with google", error=error)
|
||||
raise
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatGptService:
|
||||
repository: ChatGPTRepository
|
||||
user_service: UserService
|
||||
|
||||
@classmethod
|
||||
def build(cls) -> "ChatGptService":
|
||||
db = Database(settings=settings)
|
||||
repository = ChatGPTRepository(settings=settings, db=db)
|
||||
user_repository = UserRepository(db=db)
|
||||
user_service = UserService(repository=user_repository)
|
||||
return ChatGptService(repository=repository, user_service=user_service)
|
||||
|
||||
async def get_chatgpt_models(self) -> Sequence[ChatGptModels]:
|
||||
return await self.repository.get_chatgpt_models()
|
||||
|
||||
async def request_to_chatgpt(self, question: str | None) -> str:
|
||||
question = question or "Привет!"
|
||||
chatgpt_model = await self.get_current_chatgpt_model()
|
||||
return await self.repository.ask_question(question=question, chatgpt_model=chatgpt_model)
|
||||
|
||||
async def request_to_chatgpt_microservice(self, question: str) -> Response:
|
||||
chatgpt_model = await self.get_current_chatgpt_model()
|
||||
return await self.repository.request_to_chatgpt_microservice(question=question, chatgpt_model=chatgpt_model)
|
||||
|
||||
async def get_current_chatgpt_model(self) -> str:
|
||||
return await self.repository.get_current_chatgpt_model()
|
||||
|
||||
async def change_chatgpt_model_priority(self, model_id: int, priority: int) -> None:
|
||||
return await self.repository.change_chatgpt_model_priority(model_id=model_id, priority=priority)
|
||||
|
||||
async def reset_all_chatgpt_models_priority(self) -> None:
|
||||
return await self.repository.reset_all_chatgpt_models_priority()
|
||||
|
||||
async def add_chatgpt_model(self, gpt_model: str, priority: int) -> dict[str, str | int]:
|
||||
return await self.repository.add_chatgpt_model(model=gpt_model, priority=priority)
|
||||
|
||||
async def delete_chatgpt_model(self, model_id: int) -> None:
|
||||
return await self.repository.delete_chatgpt_model(model_id=model_id)
|
||||
|
||||
async def get_or_create_bot_user(
|
||||
self,
|
||||
user_id: int,
|
||||
email: str | None = None,
|
||||
username: str | None = None,
|
||||
first_name: str | None = None,
|
||||
last_name: str | None = None,
|
||||
ban_reason: str | None = None,
|
||||
is_active: bool = True,
|
||||
is_superuser: bool = False,
|
||||
) -> User:
|
||||
return await self.user_service.get_or_create_user_by_id(
|
||||
user_id=user_id,
|
||||
email=email,
|
||||
username=username,
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
ban_reason=ban_reason,
|
||||
is_active=is_active,
|
||||
is_superuser=is_superuser,
|
||||
)
|
||||
|
||||
async def update_bot_user_message_count(self, user_id: int) -> None:
|
||||
await self.user_service.update_user_message_count(user_id)
|
||||
|
@ -1,8 +1,11 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqladmin import Admin, ModelView
|
||||
from sqlalchemy import Select, desc, select
|
||||
from sqlalchemy.orm import contains_eager, load_only
|
||||
from starlette.requests import Request
|
||||
|
||||
from core.auth.models.users import User
|
||||
from core.auth.models.users import AccessToken, User, UserQuestionCount
|
||||
from core.bot.models.chatgpt import ChatGptModels
|
||||
from core.utils import build_uri
|
||||
from settings.config import settings
|
||||
@ -34,12 +37,44 @@ class UserAdmin(ModelView, model=User):
|
||||
User.is_active,
|
||||
User.ban_reason,
|
||||
"question_count",
|
||||
"last_question_at",
|
||||
User.created_at,
|
||||
]
|
||||
column_sortable_list = [User.created_at]
|
||||
|
||||
column_default_sort = ("created_at", True)
|
||||
form_widget_args = {"created_at": {"readonly": True}}
|
||||
|
||||
def list_query(self, request: Request) -> Select[tuple[User]]:
|
||||
return (
|
||||
select(User)
|
||||
.options(
|
||||
load_only(
|
||||
User.id,
|
||||
User.username,
|
||||
User.first_name,
|
||||
User.last_name,
|
||||
User.is_active,
|
||||
User.created_at,
|
||||
)
|
||||
)
|
||||
.outerjoin(User.user_question_count)
|
||||
.options(
|
||||
contains_eager(User.user_question_count).options(
|
||||
load_only(
|
||||
UserQuestionCount.question_count,
|
||||
UserQuestionCount.last_question_at,
|
||||
)
|
||||
)
|
||||
)
|
||||
).order_by(desc(UserQuestionCount.question_count))
|
||||
|
||||
|
||||
class AccessTokenAdmin(ModelView, model=AccessToken):
|
||||
name = "API access token"
|
||||
name_plural = "API access tokens"
|
||||
column_list = [AccessToken.user_id, "username", AccessToken.token, AccessToken.created_at]
|
||||
form_widget_args = {"created_at": {"readonly": True}}
|
||||
|
||||
|
||||
def create_admin(application: "Application") -> Admin:
|
||||
admin = Admin(
|
||||
@ -51,4 +86,5 @@ def create_admin(application: "Application") -> Admin:
|
||||
)
|
||||
admin.add_view(ChatGptAdmin)
|
||||
admin.add_view(UserAdmin)
|
||||
admin.add_view(AccessTokenAdmin)
|
||||
return admin
|
||||
|
@ -10,9 +10,8 @@ from datetime import datetime
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy import TIMESTAMP
|
||||
from sqlalchemy.dialects.sqlite import insert
|
||||
|
||||
from core.auth.models.users import User
|
||||
from core.auth.models.users import AccessToken, User
|
||||
from core.auth.utils import create_password_hash
|
||||
from infra.database.deps import get_sync_session
|
||||
from settings.config import settings
|
||||
@ -58,8 +57,14 @@ def upgrade() -> None:
|
||||
return
|
||||
with get_sync_session() as session:
|
||||
hashed_password = create_password_hash(password.get_secret_value())
|
||||
query = insert(User).values({"username": username, "hashed_password": hashed_password})
|
||||
session.execute(query)
|
||||
user = User(username=username, hashed_password=hashed_password)
|
||||
session.add(user)
|
||||
session.flush()
|
||||
session.refresh(user)
|
||||
|
||||
access_token = AccessToken(user_id=user.id)
|
||||
session.add(access_token)
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
|
@ -1,43 +0,0 @@
|
||||
"""create chatgpt models
|
||||
|
||||
Revision ID: 0004_add_chatgpt_models
|
||||
Revises: 0003_create_user_question_count_table
|
||||
Create Date: 2025-10-05 20:44:05.414977
|
||||
|
||||
"""
|
||||
from loguru import logger
|
||||
from sqlalchemy import select, text
|
||||
|
||||
from constants import ChatGptModelsEnum
|
||||
from core.bot.models.chatgpt import ChatGptModels
|
||||
from infra.database.deps import get_sync_session
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0004_add_chatgpt_models"
|
||||
down_revision = "0003_create_user_question_count_table"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
with get_sync_session() as session:
|
||||
query = select(ChatGptModels)
|
||||
results = session.execute(query)
|
||||
models = results.scalars().all()
|
||||
|
||||
if models:
|
||||
return
|
||||
models = []
|
||||
for data in ChatGptModelsEnum.base_models_priority():
|
||||
models.append(ChatGptModels(**data))
|
||||
session.add_all(models)
|
||||
session.commit()
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
chatgpt_table_name = ChatGptModels.__tablename__
|
||||
with get_sync_session() as session:
|
||||
# Truncate doesn't exists for SQLite
|
||||
session.execute(text(f"""DELETE FROM {chatgpt_table_name}""")) # noqa: S608
|
||||
session.commit()
|
||||
logger.info("chatgpt models table has been truncated", table=chatgpt_table_name)
|
@ -0,0 +1,29 @@
|
||||
"""add_last_question_at
|
||||
|
||||
Revision ID: 0004_add_last_question_at
|
||||
Revises: 0003_create_user_question_count_table
|
||||
Create Date: 2024-01-08 20:56:34.815976
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0004_add_last_question_at"
|
||||
down_revision = "0003_create_user_question_count_table"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("ix_access_token_created_at", table_name="access_token")
|
||||
op.add_column("user_question_count", sa.Column("last_question_at", sa.TIMESTAMP(timezone=True), nullable=False))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("user_question_count", "last_question_at")
|
||||
op.create_index("ix_access_token_created_at", "access_token", ["created_at"], unique=False)
|
||||
# ### end Alembic commands ###
|
@ -22,7 +22,7 @@ class Application:
|
||||
self.app = FastAPI(
|
||||
title="Chat gpt bot",
|
||||
description="Bot for proxy to chat gpt in telegram",
|
||||
version="0.0.3",
|
||||
version="1.1.12",
|
||||
docs_url=build_uri([settings.api_prefix, "docs"]),
|
||||
redoc_url=build_uri([settings.api_prefix, "redocs"]),
|
||||
openapi_url=build_uri([settings.api_prefix, "openapi.json"]),
|
||||
|
@ -11,6 +11,8 @@ TELEGRAM_API_TOKEN="123456789:AABBCCDDEEFFaabbccddeeff-1234567890"
|
||||
# set to true to start with webhook. Else bot will start on polling method
|
||||
START_WITH_WEBHOOK="false"
|
||||
|
||||
SUPERUSER="Superuser"
|
||||
|
||||
# ==== domain settings ====
|
||||
DOMAIN="http://localhost"
|
||||
URL_PREFIX=
|
||||
|
@ -6,7 +6,9 @@ from sqlalchemy import desc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.bot.models.chatgpt import ChatGptModels
|
||||
from settings.config import AppSettings
|
||||
from tests.integration.factories.bot import ChatGptModelFactory
|
||||
from tests.integration.factories.user import AccessTokenFactory, UserFactory
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.asyncio,
|
||||
@ -51,11 +53,18 @@ async def test_change_chatgpt_model_priority(
|
||||
dbsession: Session,
|
||||
rest_client: AsyncClient,
|
||||
faker: Faker,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
model1 = ChatGptModelFactory(priority=0)
|
||||
model2 = ChatGptModelFactory(priority=1)
|
||||
priority = faker.random_int(min=2, max=7)
|
||||
response = await rest_client.put(url=f"/api/chatgpt/models/{model2.id}/priority", json={"priority": priority})
|
||||
user = UserFactory(username=test_settings.SUPERUSER)
|
||||
access_token = AccessTokenFactory(user_id=user.id)
|
||||
response = await rest_client.put(
|
||||
url=f"/api/chatgpt/models/{model2.id}/priority",
|
||||
json={"priority": priority},
|
||||
headers={"BOT-API-KEY": access_token.token},
|
||||
)
|
||||
assert response.status_code == 202
|
||||
|
||||
upd_model1, upd_model2 = dbsession.query(ChatGptModels).order_by(ChatGptModels.priority).all()
|
||||
@ -66,14 +75,53 @@ async def test_change_chatgpt_model_priority(
|
||||
assert upd_model2.priority == priority
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"bot_api_key_header",
|
||||
[
|
||||
pytest.param({"BOT-API-KEY": ""}, id="empty key in header"),
|
||||
pytest.param({"BOT-API-KEY": "Hello World"}, id="incorrect token"),
|
||||
pytest.param({"Hello-World": "lasdkj3-wer-weqwe_34"}, id="correct token but wrong api key"),
|
||||
pytest.param({}, id="no api key header"),
|
||||
],
|
||||
)
|
||||
async def test_cant_change_chatgpt_model_priority_with_wrong_api_key(
|
||||
dbsession: Session,
|
||||
rest_client: AsyncClient,
|
||||
faker: Faker,
|
||||
test_settings: AppSettings,
|
||||
bot_api_key_header: dict[str, str],
|
||||
) -> None:
|
||||
ChatGptModelFactory(priority=0)
|
||||
model2 = ChatGptModelFactory(priority=1)
|
||||
priority = faker.random_int(min=2, max=7)
|
||||
user = UserFactory(username=test_settings.SUPERUSER)
|
||||
AccessTokenFactory(user_id=user.id, token="lasdkj3-wer-weqwe_34") # noqa: S106
|
||||
response = await rest_client.put(
|
||||
url=f"/api/chatgpt/models/{model2.id}/priority",
|
||||
json={"priority": priority},
|
||||
headers=bot_api_key_header,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
changed_model = dbsession.query(ChatGptModels).filter_by(id=model2.id).one()
|
||||
assert changed_model.priority == model2.priority
|
||||
|
||||
|
||||
async def test_reset_chatgpt_models_priority(
|
||||
dbsession: Session,
|
||||
rest_client: AsyncClient,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
ChatGptModelFactory.create_batch(size=4)
|
||||
ChatGptModelFactory(priority=42)
|
||||
|
||||
response = await rest_client.put(url="/api/chatgpt/models/priority/reset")
|
||||
user = UserFactory(username=test_settings.SUPERUSER)
|
||||
access_token = AccessTokenFactory(user_id=user.id)
|
||||
|
||||
response = await rest_client.put(
|
||||
url="/api/chatgpt/models/priority/reset",
|
||||
headers={"BOT-API-KEY": access_token.token},
|
||||
)
|
||||
assert response.status_code == 202
|
||||
|
||||
models = dbsession.query(ChatGptModels).all()
|
||||
@ -85,14 +133,49 @@ async def test_reset_chatgpt_models_priority(
|
||||
assert model.priority == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"bot_api_key_header",
|
||||
[
|
||||
pytest.param({"BOT-API-KEY": ""}, id="empty key in header"),
|
||||
pytest.param({"BOT-API-KEY": "Hello World"}, id="incorrect token"),
|
||||
],
|
||||
)
|
||||
async def test_cant_reset_chatgpt_models_priority_with_wrong_api_key(
|
||||
dbsession: Session, rest_client: AsyncClient, test_settings: AppSettings, bot_api_key_header: dict[str, str]
|
||||
) -> None:
|
||||
chat_gpt_models = ChatGptModelFactory.create_batch(size=2)
|
||||
model_with_highest_priority = ChatGptModelFactory(priority=42)
|
||||
|
||||
priorities = sorted([model.priority for model in chat_gpt_models] + [model_with_highest_priority.priority])
|
||||
|
||||
user = UserFactory(username=test_settings.SUPERUSER)
|
||||
AccessTokenFactory(user_id=user.id)
|
||||
|
||||
response = await rest_client.put(
|
||||
url="/api/chatgpt/models/priority/reset",
|
||||
headers=bot_api_key_header,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
models = dbsession.query(ChatGptModels).all()
|
||||
|
||||
changed_priorities = sorted([model.priority for model in models])
|
||||
|
||||
assert changed_priorities == priorities
|
||||
|
||||
|
||||
async def test_create_new_chatgpt_model(
|
||||
dbsession: Session,
|
||||
rest_client: AsyncClient,
|
||||
faker: Faker,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
ChatGptModelFactory.create_batch(size=2)
|
||||
ChatGptModelFactory(priority=42)
|
||||
|
||||
user = UserFactory(username=test_settings.SUPERUSER)
|
||||
access_token = AccessTokenFactory(user_id=user.id)
|
||||
|
||||
model_name = "new-gpt-model"
|
||||
model_priority = faker.random_int(min=1, max=5)
|
||||
|
||||
@ -105,6 +188,7 @@ async def test_create_new_chatgpt_model(
|
||||
"model": model_name,
|
||||
"priority": model_priority,
|
||||
},
|
||||
headers={"BOT-API-KEY": access_token.token},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
@ -121,13 +205,47 @@ async def test_create_new_chatgpt_model(
|
||||
}
|
||||
|
||||
|
||||
async def test_cant_create_new_chatgpt_model_with_wrong_api_key(
|
||||
dbsession: Session,
|
||||
rest_client: AsyncClient,
|
||||
faker: Faker,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
ChatGptModelFactory.create_batch(size=2)
|
||||
ChatGptModelFactory(priority=42)
|
||||
|
||||
user = UserFactory(username=test_settings.SUPERUSER)
|
||||
AccessTokenFactory(user_id=user.id)
|
||||
|
||||
model_name = "new-gpt-model"
|
||||
model_priority = faker.random_int(min=1, max=5)
|
||||
|
||||
models = dbsession.query(ChatGptModels).all()
|
||||
assert len(models) == 3
|
||||
|
||||
response = await rest_client.post(
|
||||
url="/api/chatgpt/models",
|
||||
json={
|
||||
"model": model_name,
|
||||
"priority": model_priority,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
models = dbsession.query(ChatGptModels).all()
|
||||
assert len(models) == 3
|
||||
|
||||
|
||||
async def test_add_existing_chatgpt_model(
|
||||
dbsession: Session,
|
||||
rest_client: AsyncClient,
|
||||
faker: Faker,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
ChatGptModelFactory.create_batch(size=2)
|
||||
model = ChatGptModelFactory(priority=42)
|
||||
user = UserFactory(username=test_settings.SUPERUSER)
|
||||
access_token = AccessTokenFactory(user_id=user.id)
|
||||
|
||||
model_name = model.model
|
||||
model_priority = faker.random_int(min=1, max=5)
|
||||
@ -141,6 +259,7 @@ async def test_add_existing_chatgpt_model(
|
||||
"model": model_name,
|
||||
"priority": model_priority,
|
||||
},
|
||||
headers={"BOT-API-KEY": access_token.token},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
@ -151,17 +270,48 @@ async def test_add_existing_chatgpt_model(
|
||||
async def test_delete_chatgpt_model(
|
||||
dbsession: Session,
|
||||
rest_client: AsyncClient,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
ChatGptModelFactory.create_batch(size=2)
|
||||
model = ChatGptModelFactory(priority=42)
|
||||
|
||||
user = UserFactory(username=test_settings.SUPERUSER)
|
||||
access_token = AccessTokenFactory(user_id=user.id)
|
||||
|
||||
models = dbsession.query(ChatGptModels).all()
|
||||
assert len(models) == 3
|
||||
|
||||
response = await rest_client.delete(url=f"/api/chatgpt/models/{model.id}")
|
||||
response = await rest_client.delete(
|
||||
url=f"/api/chatgpt/models/{model.id}",
|
||||
headers={"BOT-API-KEY": access_token.token},
|
||||
)
|
||||
assert response.status_code == 204
|
||||
|
||||
models = dbsession.query(ChatGptModels).all()
|
||||
assert len(models) == 2
|
||||
|
||||
assert model not in models
|
||||
|
||||
|
||||
async def test_cant_delete_chatgpt_model_with_wrong_api_key(
|
||||
dbsession: Session,
|
||||
rest_client: AsyncClient,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
ChatGptModelFactory.create_batch(size=2)
|
||||
model = ChatGptModelFactory(priority=42)
|
||||
|
||||
user = UserFactory(username=test_settings.SUPERUSER)
|
||||
access_token = AccessTokenFactory(user_id=user.id)
|
||||
|
||||
models = dbsession.query(ChatGptModels).all()
|
||||
assert len(models) == 3
|
||||
|
||||
response = await rest_client.delete(
|
||||
url=f"/api/chatgpt/models/{model.id}",
|
||||
headers={"ROOT-ACCESS": access_token.token},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
models = dbsession.query(ChatGptModels).all()
|
||||
assert len(models) == 3
|
||||
|
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
from unittest import mock
|
||||
|
||||
import httpx
|
||||
@ -11,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Update
|
||||
|
||||
from constants import BotStagesEnum
|
||||
from core.auth.models.users import User, UserQuestionCount
|
||||
from core.bot.app import BotApplication, BotQueue
|
||||
from main import Application
|
||||
from settings.config import AppSettings
|
||||
@ -22,6 +24,7 @@ from tests.integration.factories.bot import (
|
||||
CallBackFactory,
|
||||
ChatGptModelFactory,
|
||||
)
|
||||
from tests.integration.factories.user import UserFactory, UserQuestionCountFactory
|
||||
from tests.integration.utils import mocked_ask_question_api
|
||||
|
||||
pytestmark = [
|
||||
@ -113,6 +116,33 @@ async def test_help_command(
|
||||
)
|
||||
|
||||
|
||||
async def test_help_command_user_is_banned(
|
||||
main_application: Application,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
message = BotMessageFactory.create_instance(text="/help")
|
||||
user = message["from"]
|
||||
UserFactory(
|
||||
id=user["id"],
|
||||
first_name=user["first_name"],
|
||||
last_name=user["last_name"],
|
||||
username=user["username"],
|
||||
is_active=False,
|
||||
ban_reason="test reason",
|
||||
)
|
||||
with mock.patch.object(
|
||||
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
|
||||
) as mocked_send_message:
|
||||
bot_update = BotUpdateFactory(message=message)
|
||||
|
||||
await main_application.bot_app.application.process_update(
|
||||
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
|
||||
)
|
||||
assert mocked_send_message.call_args.kwargs["text"] == (
|
||||
"You have banned for reason: *test reason*.\nPlease contact the /developer"
|
||||
)
|
||||
|
||||
|
||||
async def test_start_entry(
|
||||
main_application: Application,
|
||||
test_settings: AppSettings,
|
||||
@ -232,6 +262,35 @@ async def test_website_callback_action(
|
||||
assert mocked_reply_text.call_args.args == ("Веб версия: http://localhost/chat/",)
|
||||
|
||||
|
||||
async def test_website_callback_action_user_is_banned(
|
||||
main_application: Application,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
message = BotMessageFactory.create_instance(text="Список основных команд:")
|
||||
user = message["from"]
|
||||
UserFactory(
|
||||
id=user["id"],
|
||||
first_name=user["first_name"],
|
||||
last_name=user["last_name"],
|
||||
username=user["username"],
|
||||
is_active=False,
|
||||
ban_reason="test reason",
|
||||
)
|
||||
with mock.patch.object(telegram._message.Message, "reply_text") as mocked_reply_text:
|
||||
bot_update = BotCallBackQueryFactory(
|
||||
message=message,
|
||||
callback_query=CallBackFactory(data=BotStagesEnum.website),
|
||||
)
|
||||
|
||||
await main_application.bot_app.application.process_update(
|
||||
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
|
||||
)
|
||||
|
||||
assert mocked_reply_text.call_args.kwargs["text"] == (
|
||||
"You have banned for reason: *test reason*.\nPlease contact the /developer"
|
||||
)
|
||||
|
||||
|
||||
async def test_bug_report_action(
|
||||
main_application: Application,
|
||||
test_settings: AppSettings,
|
||||
@ -261,6 +320,35 @@ async def test_bug_report_action(
|
||||
)
|
||||
|
||||
|
||||
async def test_bug_report_action_user_is_banned(
|
||||
main_application: Application,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
message = BotMessageFactory.create_instance(text="/bug_report")
|
||||
user = message["from"]
|
||||
UserFactory(
|
||||
id=user["id"],
|
||||
first_name=user["first_name"],
|
||||
last_name=user["last_name"],
|
||||
username=user["username"],
|
||||
is_active=False,
|
||||
ban_reason="test reason",
|
||||
)
|
||||
with (
|
||||
mock.patch.object(telegram._message.Message, "reply_text") as mocked_reply_text,
|
||||
mock.patch.object(telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)),
|
||||
):
|
||||
bot_update = BotUpdateFactory(message=message)
|
||||
|
||||
await main_application.bot_app.application.process_update(
|
||||
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
|
||||
)
|
||||
|
||||
assert mocked_reply_text.call_args.kwargs["text"] == (
|
||||
"You have banned for reason: *test reason*.\nPlease contact the /developer"
|
||||
)
|
||||
|
||||
|
||||
async def test_get_developer_action(
|
||||
main_application: Application,
|
||||
test_settings: AppSettings,
|
||||
@ -278,19 +366,28 @@ async def test_get_developer_action(
|
||||
assert mocked_reply_text.call_args.args == ("Автор бота: *Дмитрий Афанасьев*\n\nTg nickname: *Balshtg*",)
|
||||
|
||||
|
||||
async def test_ask_question_action(
|
||||
async def test_ask_question_action_bot_user_not_exists(
|
||||
dbsession: Session,
|
||||
main_application: Application,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
ChatGptModelFactory.create_batch(size=3)
|
||||
users = dbsession.query(User).all()
|
||||
users_question_count = dbsession.query(UserQuestionCount).all()
|
||||
|
||||
assert len(users) == 0
|
||||
assert len(users_question_count) == 0
|
||||
|
||||
message = BotMessageFactory.create_instance(text="Привет!")
|
||||
user = message["from"]
|
||||
|
||||
with mock.patch.object(
|
||||
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
|
||||
) as mocked_send_message, mocked_ask_question_api(
|
||||
host=test_settings.GPT_BASE_HOST,
|
||||
return_value=Response(status_code=httpx.codes.OK, text="Привет! Как я могу помочь вам сегодня?"),
|
||||
):
|
||||
bot_update = BotUpdateFactory(message=BotMessageFactory.create_instance(text="Привет!"))
|
||||
bot_update = BotUpdateFactory(message=message)
|
||||
bot_update["message"].pop("entities")
|
||||
|
||||
await main_application.bot_app.application.process_update(
|
||||
@ -313,6 +410,120 @@ async def test_ask_question_action(
|
||||
include=["text", "chat_id"],
|
||||
)
|
||||
|
||||
created_user = dbsession.query(User).filter_by(id=user["id"]).one()
|
||||
assert created_user.username == user["username"]
|
||||
|
||||
created_user_question_count = dbsession.query(UserQuestionCount).filter_by(user_id=user["id"]).one()
|
||||
assert created_user_question_count.question_count == 1
|
||||
assert created_user_question_count.last_question_at - datetime.datetime.now() < datetime.timedelta( # noqa: DTZ005
|
||||
seconds=2
|
||||
)
|
||||
|
||||
|
||||
async def test_ask_question_action_bot_user_already_exists(
|
||||
dbsession: Session,
|
||||
main_application: Application,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
ChatGptModelFactory.create_batch(size=3)
|
||||
message = BotMessageFactory.create_instance(text="Привет!")
|
||||
user = message["from"]
|
||||
existing_user = UserFactory(
|
||||
id=user["id"], first_name=user["first_name"], last_name=user["last_name"], username=user["username"]
|
||||
)
|
||||
existing_user_question_count = UserQuestionCountFactory(user_id=existing_user.id).question_count
|
||||
|
||||
users = dbsession.query(User).all()
|
||||
assert len(users) == 1
|
||||
|
||||
with mock.patch.object(
|
||||
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
|
||||
) as mocked_send_message, mocked_ask_question_api(
|
||||
host=test_settings.GPT_BASE_HOST,
|
||||
return_value=Response(status_code=httpx.codes.OK, text="Привет! Как я могу помочь вам сегодня?"),
|
||||
):
|
||||
bot_update = BotUpdateFactory(message=message)
|
||||
bot_update["message"].pop("entities")
|
||||
|
||||
await main_application.bot_app.application.process_update(
|
||||
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
|
||||
)
|
||||
assert_that(mocked_send_message.call_args_list[0].kwargs).is_equal_to(
|
||||
{
|
||||
"text": (
|
||||
"Ответ в среднем занимает 10-15 секунд.\n- Список команд: /help\n- Сообщить об ошибке: /bug_report"
|
||||
),
|
||||
"chat_id": bot_update["message"]["chat"]["id"],
|
||||
},
|
||||
include=["text", "chat_id"],
|
||||
)
|
||||
assert_that(mocked_send_message.call_args_list[1].kwargs).is_equal_to(
|
||||
{
|
||||
"text": "Привет! Как я могу помочь вам сегодня?",
|
||||
"chat_id": bot_update["message"]["chat"]["id"],
|
||||
},
|
||||
include=["text", "chat_id"],
|
||||
)
|
||||
|
||||
users = dbsession.query(User).all()
|
||||
assert len(users) == 1
|
||||
|
||||
updated_user_question_count = dbsession.query(UserQuestionCount).filter_by(user_id=user["id"]).one()
|
||||
assert updated_user_question_count.question_count == existing_user_question_count + 1
|
||||
assert updated_user_question_count.last_question_at - datetime.datetime.now() < datetime.timedelta( # noqa: DTZ005
|
||||
seconds=2
|
||||
)
|
||||
|
||||
|
||||
async def test_ask_question_action_user_is_banned(
|
||||
dbsession: Session,
|
||||
main_application: Application,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
ChatGptModelFactory.create_batch(size=3)
|
||||
|
||||
users_question_count = dbsession.query(UserQuestionCount).all()
|
||||
assert len(users_question_count) == 0
|
||||
|
||||
message = BotMessageFactory.create_instance(text="Привет!")
|
||||
user = message["from"]
|
||||
UserFactory(
|
||||
id=user["id"],
|
||||
first_name=user["first_name"],
|
||||
last_name=user["last_name"],
|
||||
username=user["username"],
|
||||
is_active=False,
|
||||
ban_reason="test reason",
|
||||
)
|
||||
|
||||
with mock.patch.object(
|
||||
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
|
||||
) as mocked_send_message, mocked_ask_question_api(
|
||||
host=test_settings.GPT_BASE_HOST,
|
||||
return_value=Response(status_code=httpx.codes.OK, text="Привет! Как я могу помочь вам сегодня?"),
|
||||
assert_all_called=False,
|
||||
):
|
||||
bot_update = BotUpdateFactory(message=message)
|
||||
bot_update["message"].pop("entities")
|
||||
|
||||
await main_application.bot_app.application.process_update(
|
||||
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
|
||||
)
|
||||
assert_that(mocked_send_message.call_args_list[0].kwargs).is_equal_to(
|
||||
{
|
||||
"text": ("You have banned for reason: *test reason*.\nPlease contact the /developer"),
|
||||
"chat_id": bot_update["message"]["chat"]["id"],
|
||||
},
|
||||
include=["text", "chat_id"],
|
||||
)
|
||||
|
||||
created_user = dbsession.query(User).filter_by(id=user["id"]).one()
|
||||
assert created_user.username == user["username"]
|
||||
assert created_user.is_active is False
|
||||
|
||||
created_user_question_count = dbsession.query(UserQuestionCount).filter_by(user_id=user["id"]).scalar()
|
||||
assert created_user_question_count is None
|
||||
|
||||
|
||||
async def test_ask_question_action_not_success(
|
||||
dbsession: Session,
|
||||
|
@ -0,0 +1,30 @@
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants import ChatGptModelsEnum
|
||||
from core.bot.models.chatgpt import ChatGptModels
|
||||
from core.bot.services import ChatGptService
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.asyncio,
|
||||
pytest.mark.enable_socket,
|
||||
]
|
||||
|
||||
|
||||
async def test_models_update(dbsession: Session) -> None:
|
||||
models = dbsession.query(ChatGptModels).all()
|
||||
|
||||
assert len(models) == 0
|
||||
|
||||
chatgpt_service = ChatGptService.build()
|
||||
|
||||
await chatgpt_service.update_chatgpt_models()
|
||||
|
||||
models = dbsession.query(ChatGptModels).all()
|
||||
|
||||
model_priorities = {model.model: model.priority for model in models}
|
||||
|
||||
assert len(models) == len(ChatGptModelsEnum.base_models_priority())
|
||||
|
||||
for model_priority in ChatGptModelsEnum.base_models_priority():
|
||||
assert model_priorities[model_priority["model"]] == model_priority["priority"]
|
@ -1,13 +1,15 @@
|
||||
import uuid
|
||||
|
||||
import factory
|
||||
|
||||
from core.auth.models.users import User
|
||||
from core.auth.models.users import AccessToken, User, UserQuestionCount
|
||||
from tests.integration.factories.utils import BaseModelFactory
|
||||
|
||||
|
||||
class UserFactory(BaseModelFactory):
|
||||
id = factory.Sequence(lambda n: n + 1)
|
||||
email = factory.Faker("email")
|
||||
username = factory.Faker("user_name", locale="en_EN")
|
||||
username = factory.Faker("user_name", locale="en")
|
||||
first_name = factory.Faker("word")
|
||||
last_name = factory.Faker("word")
|
||||
ban_reason = factory.Faker("text", max_nb_chars=100)
|
||||
@ -18,3 +20,21 @@ class UserFactory(BaseModelFactory):
|
||||
|
||||
class Meta:
|
||||
model = User
|
||||
|
||||
|
||||
class AccessTokenFactory(BaseModelFactory):
|
||||
user_id = factory.Sequence(lambda n: n + 1)
|
||||
token = factory.LazyAttribute(lambda o: str(uuid.uuid4()))
|
||||
created_at = factory.Faker("past_datetime")
|
||||
|
||||
class Meta:
|
||||
model = AccessToken
|
||||
|
||||
|
||||
class UserQuestionCountFactory(BaseModelFactory):
|
||||
user_id = factory.Sequence(lambda n: n + 1)
|
||||
question_count = factory.Faker("random_int")
|
||||
last_question_at = factory.Faker("past_datetime")
|
||||
|
||||
class Meta:
|
||||
model = UserQuestionCount
|
||||
|
@ -9,11 +9,11 @@ from settings.config import settings
|
||||
|
||||
@contextmanager
|
||||
def mocked_ask_question_api(
|
||||
host: str, return_value: Response | None = None, side_effect: Any | None = None
|
||||
host: str, return_value: Response | None = None, side_effect: Any | None = None, assert_all_called: bool = True
|
||||
) -> Iterator[respx.MockRouter]:
|
||||
with respx.mock(
|
||||
assert_all_mocked=True,
|
||||
assert_all_called=True,
|
||||
assert_all_called=assert_all_called,
|
||||
base_url=host,
|
||||
) as respx_mock:
|
||||
ask_question_route = respx_mock.post(url=settings.chatgpt_backend_url, name="ask_question")
|
||||
|
69
poetry.lock
generated
69
poetry.lock
generated
@ -607,25 +607,25 @@ toml = ["tomli"]
|
||||
|
||||
[[package]]
|
||||
name = "cyclonedx-python-lib"
|
||||
version = "5.2.0"
|
||||
version = "6.3.0"
|
||||
description = "Python library for CycloneDX"
|
||||
optional = false
|
||||
python-versions = ">=3.8,<4.0"
|
||||
files = [
|
||||
{file = "cyclonedx_python_lib-5.2.0-py3-none-any.whl", hash = "sha256:1b43065205cdc53490c825fcfbda73142b758aa40ca169c968e342e88d06734f"},
|
||||
{file = "cyclonedx_python_lib-5.2.0.tar.gz", hash = "sha256:b9ebf2c0520721d2f8ee16aadc2bbb9d4e015862c84ab1691a49b177f3014d99"},
|
||||
{file = "cyclonedx_python_lib-6.3.0-py3-none-any.whl", hash = "sha256:0e73c1036c2f7fc67adc28aef807e6b44340ea70202aab197fb06b20ea165de8"},
|
||||
{file = "cyclonedx_python_lib-6.3.0.tar.gz", hash = "sha256:82f2489de3c0cadad5af1ad7fa6b6a185f985746370245d38769699c734533c6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
license-expression = ">=30,<31"
|
||||
packageurl-python = ">=0.11"
|
||||
py-serializable = ">=0.15,<0.16"
|
||||
py-serializable = ">=0.16,<0.18"
|
||||
sortedcontainers = ">=2.4.0,<3.0.0"
|
||||
|
||||
[package.extras]
|
||||
json-validation = ["jsonschema[format] (>=4.18,<5.0)"]
|
||||
validation = ["jsonschema[format] (>=4.18,<5.0)", "lxml (>=4,<5)"]
|
||||
xml-validation = ["lxml (>=4,<5)"]
|
||||
validation = ["jsonschema[format] (>=4.18,<5.0)", "lxml (>=4,<6)"]
|
||||
xml-validation = ["lxml (>=4,<6)"]
|
||||
|
||||
[[package]]
|
||||
name = "decorator"
|
||||
@ -761,19 +761,19 @@ typing = ["typing-extensions (>=4.8)"]
|
||||
|
||||
[[package]]
|
||||
name = "flake8"
|
||||
version = "6.1.0"
|
||||
version = "7.0.0"
|
||||
description = "the modular source code checker: pep8 pyflakes and co"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1"
|
||||
files = [
|
||||
{file = "flake8-6.1.0-py2.py3-none-any.whl", hash = "sha256:ffdfce58ea94c6580c77888a86506937f9a1a227dfcd15f245d694ae20a6b6e5"},
|
||||
{file = "flake8-6.1.0.tar.gz", hash = "sha256:d5b3857f07c030bdb5bf41c7f53799571d75c4491748a3adcd47de929e34cd23"},
|
||||
{file = "flake8-7.0.0-py2.py3-none-any.whl", hash = "sha256:a6dfbb75e03252917f2473ea9653f7cd799c3064e54d4c8140044c5c065f53c3"},
|
||||
{file = "flake8-7.0.0.tar.gz", hash = "sha256:33f96621059e65eec474169085dc92bf26e7b2d47366b70be2f67ab80dc25132"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
mccabe = ">=0.7.0,<0.8.0"
|
||||
pycodestyle = ">=2.11.0,<2.12.0"
|
||||
pyflakes = ">=3.1.0,<3.2.0"
|
||||
pyflakes = ">=3.2.0,<3.3.0"
|
||||
|
||||
[[package]]
|
||||
name = "flake8-aaa"
|
||||
@ -954,25 +954,6 @@ files = [
|
||||
[package.dependencies]
|
||||
flake8 = ">3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "flake8-noqa"
|
||||
version = "1.3.2"
|
||||
description = "Flake8 noqa comment validation"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "flake8-noqa-1.3.2.tar.gz", hash = "sha256:b12ddf7b02dedabaca0f807cb436ea7992cc0106cb6fa41e997ad45a0a3bf754"},
|
||||
{file = "flake8_noqa-1.3.2-py3-none-any.whl", hash = "sha256:a2c139c4cc223f268fb262cd32a46fa72f509225d038058baa87c0ff8ac4d348"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
flake8 = ">=3.8.0,<7.0"
|
||||
typing-extensions = ">=3.7.4.2"
|
||||
|
||||
[package.extras]
|
||||
dev = ["flake8 (>=3.8.0,<6.0.0)", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake8-commas", "flake8-comprehensions", "flake8-continuation", "flake8-datetimez", "flake8-docstrings", "flake8-import-order", "flake8-literal", "flake8-modern-annotations", "flake8-noqa", "flake8-polyfill", "flake8-pyproject", "flake8-requirements", "flake8-typechecking-import", "flake8-use-fstring", "mypy", "pep8-naming"]
|
||||
test = ["flake8-docstrings"]
|
||||
|
||||
[[package]]
|
||||
name = "flake8-plugin-utils"
|
||||
version = "1.3.3"
|
||||
@ -1410,13 +1391,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "ipython"
|
||||
version = "8.19.0"
|
||||
version = "8.20.0"
|
||||
description = "IPython: Productive Interactive Computing"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
files = [
|
||||
{file = "ipython-8.19.0-py3-none-any.whl", hash = "sha256:2f55d59370f59d0d2b2212109fe0e6035cfea436b1c0e6150ad2244746272ec5"},
|
||||
{file = "ipython-8.19.0.tar.gz", hash = "sha256:ac4da4ecf0042fb4e0ce57c60430c2db3c719fa8bdf92f8631d6bd8a5785d1f0"},
|
||||
{file = "ipython-8.20.0-py3-none-any.whl", hash = "sha256:bc9716aad6f29f36c449e30821c9dd0c1c1a7b59ddcc26931685b87b4c569619"},
|
||||
{file = "ipython-8.20.0.tar.gz", hash = "sha256:2f21bd3fc1d51550c89ee3944ae04bbc7bc79e129ea0937da6e6c68bfdbf117a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -2089,18 +2070,18 @@ pip = "*"
|
||||
|
||||
[[package]]
|
||||
name = "pip-audit"
|
||||
version = "2.6.2"
|
||||
version = "2.6.3"
|
||||
description = "A tool for scanning Python environments for known vulnerabilities"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pip_audit-2.6.2-py3-none-any.whl", hash = "sha256:ac3a4b6e977ef2c574aa8d19a5d71d12201bdb65bba2d67d9df49f53f0be5e7d"},
|
||||
{file = "pip_audit-2.6.2.tar.gz", hash = "sha256:0bbd023a199a104b29f949f063a872d41113b5a9048285666820fa35a76a7794"},
|
||||
{file = "pip_audit-2.6.3-py3-none-any.whl", hash = "sha256:216983210db4a15393f9e80e4d24a805f5767e4c8e0c31fc70c336acc629613b"},
|
||||
{file = "pip_audit-2.6.3.tar.gz", hash = "sha256:bd796066f69684b2f4fc2c2b6d222589e23190db0bbde069cea5c2b0be2cc57d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
CacheControl = {version = ">=0.13.0", extras = ["filecache"]}
|
||||
cyclonedx-python-lib = ">=4,<6"
|
||||
cyclonedx-python-lib = ">=5,<7"
|
||||
html5lib = ">=1.1"
|
||||
packaging = ">=23.0.0"
|
||||
pip-api = ">=0.0.28"
|
||||
@ -2112,7 +2093,7 @@ toml = ">=0.10"
|
||||
[package.extras]
|
||||
dev = ["build", "bump (>=1.3.2)", "pip-audit[doc,lint,test]"]
|
||||
doc = ["pdoc"]
|
||||
lint = ["interrogate", "mypy", "ruff (<0.1.9)", "types-html5lib", "types-requests", "types-toml"]
|
||||
lint = ["interrogate", "mypy", "ruff (<0.1.12)", "types-html5lib", "types-requests", "types-toml"]
|
||||
test = ["coverage[toml] (>=7.0,!=7.3.3,<8.0)", "pretend", "pytest", "pytest-cov"]
|
||||
|
||||
[[package]]
|
||||
@ -2216,13 +2197,13 @@ tests = ["pytest"]
|
||||
|
||||
[[package]]
|
||||
name = "py-serializable"
|
||||
version = "0.15.0"
|
||||
version = "0.17.1"
|
||||
description = "Library for serializing and deserializing Python Objects to and from JSON and XML."
|
||||
optional = false
|
||||
python-versions = ">=3.7,<4.0"
|
||||
files = [
|
||||
{file = "py-serializable-0.15.0.tar.gz", hash = "sha256:8fc41457d8ee5f5c5a12f41fd87bf1a4f2ecf9da39fee92059b728e78f320771"},
|
||||
{file = "py_serializable-0.15.0-py3-none-any.whl", hash = "sha256:d3f1201b33420c481aa83f7860c7bf2c2f036ba3ea82b6e15a96696457c36cd2"},
|
||||
{file = "py-serializable-0.17.1.tar.gz", hash = "sha256:875bb9c01df77f563dfcd1e75bb4244b5596083d3aad4ccd3fb63e1f5a9d3e5f"},
|
||||
{file = "py_serializable-0.17.1-py3-none-any.whl", hash = "sha256:389c2254d912bec3a44acdac667c947d73c59325050d5ae66386e1ed7108a45a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -2407,13 +2388,13 @@ resolved_reference = "996cec42e9621701edb83354232b2c0ca0121560"
|
||||
|
||||
[[package]]
|
||||
name = "pyflakes"
|
||||
version = "3.1.0"
|
||||
version = "3.2.0"
|
||||
description = "passive checker of Python programs"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pyflakes-3.1.0-py2.py3-none-any.whl", hash = "sha256:4132f6d49cb4dae6819e5379898f2b8cce3c5f23994194c24b77d5da2e36f774"},
|
||||
{file = "pyflakes-3.1.0.tar.gz", hash = "sha256:a0aae034c444db0071aa077972ba4768d40c830d9539fd45bf4cd3f8f6992efc"},
|
||||
{file = "pyflakes-3.2.0-py2.py3-none-any.whl", hash = "sha256:84b5be138a2dfbb40689ca07e2152deb896a65c3a3e24c251c5c62489568074a"},
|
||||
{file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -3732,4 +3713,4 @@ multidict = ">=4.0"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.12"
|
||||
content-hash = "3a78cb3473202950a11ebbe28e5662f6a1a57508a0e2703761c3c448744b0bed"
|
||||
content-hash = "30a4b709eed24d6b55f78015994cdf1901d0d7279fefa86b37e9490e189a2c52"
|
||||
|
@ -1,11 +1,11 @@
|
||||
[tool.poetry]
|
||||
name = "chat_gpt_bot"
|
||||
version = "1.4.1"
|
||||
version = "1.5.1"
|
||||
description = "Bot to integrated with Chat gpt"
|
||||
authors = ["Dmitry Afanasyev <Balshbox@gmail.com>"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.7.0"]
|
||||
requires = ["poetry-core>=1.7.1"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
@ -84,7 +84,6 @@ autoflake = "^2.2"
|
||||
flake8-aaa = "^0.17.0"
|
||||
flake8-variables-names = "^0.0.6"
|
||||
flake8-deprecated = "^2.2.1"
|
||||
flake8-noqa = "^1.3.2"
|
||||
flake8-annotations-complexity = "^0.0.8"
|
||||
flake8-useless-assert = "^0.4.4"
|
||||
flake8-newspaper-style = "^0.4.3"
|
||||
|
@ -2,12 +2,8 @@
|
||||
|
||||
set -e
|
||||
|
||||
if [ -f shared/${DB_NAME:-chatgpt.db} ]
|
||||
then
|
||||
alembic downgrade -1 && alembic upgrade "head"
|
||||
else
|
||||
alembic upgrade "head"
|
||||
fi
|
||||
alembic upgrade "head" && \
|
||||
python3 core/bot/managment/update_gpt_models.py
|
||||
|
||||
echo "starting the bot"
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user