mirror of
https://github.com/Balshgit/gpt_chat_bot.git
synced 2025-09-10 17:20:41 +03:00
Compare commits
4 Commits
8266342214
...
28d01f551b
Author | SHA1 | Date | |
---|---|---|---|
|
28d01f551b | ||
|
28895f3510 | ||
|
7cbe7b7c50 | ||
|
de55d873f9 |
@ -110,6 +110,7 @@ Init alembic
|
|||||||
cd bot_microservice
|
cd bot_microservice
|
||||||
alembic revision --autogenerate -m 'create_chatgpt_table'
|
alembic revision --autogenerate -m 'create_chatgpt_table'
|
||||||
alembic upgrade head
|
alembic upgrade head
|
||||||
|
python3 core/bot/managment/update_gpt_models.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
@ -130,6 +131,7 @@ Run migrations
|
|||||||
```bash
|
```bash
|
||||||
cd ./bot_microservice # alembic root
|
cd ./bot_microservice # alembic root
|
||||||
alembic --config ./alembic.ini upgrade head
|
alembic --config ./alembic.ini upgrade head
|
||||||
|
python3 core/bot/managment/update_gpt_models.py
|
||||||
```
|
```
|
||||||
|
|
||||||
**downgrade to 0001_create_chatgpt_table revision:**
|
**downgrade to 0001_create_chatgpt_table revision:**
|
||||||
@ -166,4 +168,4 @@ alembic downgrade base
|
|||||||
- [x] add sentry
|
- [x] add sentry
|
||||||
- [x] add graylog integration and availability to log to file
|
- [x] add graylog integration and availability to log to file
|
||||||
- [x] add user model
|
- [x] add user model
|
||||||
- [ ] add messages statistic
|
- [x] add messages statistic
|
||||||
|
1
bot_microservice/api/bot/constants.py
Normal file
1
bot_microservice/api/bot/constants.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
BOT_ACCESS_API_HEADER = "BOT-API-KEY"
|
@ -3,13 +3,19 @@ from starlette import status
|
|||||||
from starlette.responses import JSONResponse, Response
|
from starlette.responses import JSONResponse, Response
|
||||||
from telegram import Update
|
from telegram import Update
|
||||||
|
|
||||||
from api.bot.deps import get_bot_queue, get_chatgpt_service, get_update_from_request
|
from api.bot.deps import (
|
||||||
|
get_access_to_bot_api_or_403,
|
||||||
|
get_bot_queue,
|
||||||
|
get_chatgpt_service,
|
||||||
|
get_update_from_request,
|
||||||
|
)
|
||||||
from api.bot.serializers import (
|
from api.bot.serializers import (
|
||||||
ChatGptModelSerializer,
|
ChatGptModelSerializer,
|
||||||
ChatGptModelsPrioritySerializer,
|
ChatGptModelsPrioritySerializer,
|
||||||
GETChatGptModelsSerializer,
|
GETChatGptModelsSerializer,
|
||||||
LightChatGptModel,
|
LightChatGptModel,
|
||||||
)
|
)
|
||||||
|
from api.exceptions import PermissionMissingResponse
|
||||||
from core.bot.app import BotQueue
|
from core.bot.app import BotQueue
|
||||||
from core.bot.services import ChatGptService
|
from core.bot.services import ChatGptService
|
||||||
from settings.config import settings
|
from settings.config import settings
|
||||||
@ -53,6 +59,10 @@ async def models_list(
|
|||||||
@router.put(
|
@router.put(
|
||||||
"/chatgpt/models/{model_id}/priority",
|
"/chatgpt/models/{model_id}/priority",
|
||||||
name="bot:change_model_priority",
|
name="bot:change_model_priority",
|
||||||
|
dependencies=[Depends(get_access_to_bot_api_or_403)],
|
||||||
|
responses={
|
||||||
|
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
|
||||||
|
},
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
status_code=status.HTTP_202_ACCEPTED,
|
status_code=status.HTTP_202_ACCEPTED,
|
||||||
summary="change gpt model priority",
|
summary="change gpt model priority",
|
||||||
@ -69,6 +79,10 @@ async def change_model_priority(
|
|||||||
@router.put(
|
@router.put(
|
||||||
"/chatgpt/models/priority/reset",
|
"/chatgpt/models/priority/reset",
|
||||||
name="bot:reset_models_priority",
|
name="bot:reset_models_priority",
|
||||||
|
dependencies=[Depends(get_access_to_bot_api_or_403)],
|
||||||
|
responses={
|
||||||
|
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
|
||||||
|
},
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
status_code=status.HTTP_202_ACCEPTED,
|
status_code=status.HTTP_202_ACCEPTED,
|
||||||
summary="reset all model priority to default",
|
summary="reset all model priority to default",
|
||||||
@ -83,6 +97,11 @@ async def reset_models_priority(
|
|||||||
@router.post(
|
@router.post(
|
||||||
"/chatgpt/models",
|
"/chatgpt/models",
|
||||||
name="bot:add_new_model",
|
name="bot:add_new_model",
|
||||||
|
dependencies=[Depends(get_access_to_bot_api_or_403)],
|
||||||
|
responses={
|
||||||
|
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
|
||||||
|
status.HTTP_201_CREATED: {"model": ChatGptModelSerializer},
|
||||||
|
},
|
||||||
response_model=ChatGptModelSerializer,
|
response_model=ChatGptModelSerializer,
|
||||||
status_code=status.HTTP_201_CREATED,
|
status_code=status.HTTP_201_CREATED,
|
||||||
summary="add new model",
|
summary="add new model",
|
||||||
@ -100,6 +119,10 @@ async def add_new_model(
|
|||||||
@router.delete(
|
@router.delete(
|
||||||
"/chatgpt/models/{model_id}",
|
"/chatgpt/models/{model_id}",
|
||||||
name="bot:delete_gpt_model",
|
name="bot:delete_gpt_model",
|
||||||
|
dependencies=[Depends(get_access_to_bot_api_or_403)],
|
||||||
|
responses={
|
||||||
|
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
|
||||||
|
},
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
status_code=status.HTTP_204_NO_CONTENT,
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
summary="delete gpt model",
|
summary="delete gpt model",
|
||||||
|
@ -1,15 +1,17 @@
|
|||||||
from fastapi import Depends
|
from fastapi import Depends, Header, HTTPException
|
||||||
|
from starlette import status
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from telegram import Update
|
from telegram import Update
|
||||||
|
|
||||||
from api.auth.deps import get_user_service
|
from api.auth.deps import get_user_service
|
||||||
|
from api.bot.constants import BOT_ACCESS_API_HEADER
|
||||||
from api.deps import get_database
|
from api.deps import get_database
|
||||||
from core.auth.services import UserService
|
from core.auth.services import UserService
|
||||||
from core.bot.app import BotApplication, BotQueue
|
from core.bot.app import BotApplication, BotQueue
|
||||||
from core.bot.repository import ChatGPTRepository
|
from core.bot.repository import ChatGPTRepository
|
||||||
from core.bot.services import ChatGptService
|
from core.bot.services import ChatGptService
|
||||||
from infra.database.db_adapter import Database
|
from infra.database.db_adapter import Database
|
||||||
from settings.config import AppSettings, get_settings
|
from settings.config import AppSettings, get_settings, settings
|
||||||
|
|
||||||
|
|
||||||
def get_bot_app(request: Request) -> BotApplication:
|
def get_bot_app(request: Request) -> BotApplication:
|
||||||
@ -40,3 +42,13 @@ def get_chatgpt_service(
|
|||||||
user_service: UserService = Depends(get_user_service),
|
user_service: UserService = Depends(get_user_service),
|
||||||
) -> ChatGptService:
|
) -> ChatGptService:
|
||||||
return ChatGptService(repository=chatgpt_repository, user_service=user_service)
|
return ChatGptService(repository=chatgpt_repository, user_service=user_service)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_access_to_bot_api_or_403(
|
||||||
|
bot_api_key: str | None = Header(None, alias=BOT_ACCESS_API_HEADER, description="Ключ доступа до API бота"),
|
||||||
|
user_service: UserService = Depends(get_user_service),
|
||||||
|
) -> None:
|
||||||
|
access_token = await user_service.get_user_access_token_by_username(settings.SUPERUSER)
|
||||||
|
|
||||||
|
if not access_token or access_token != bot_api_key:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Could not validate api header")
|
||||||
|
@ -1,11 +1,39 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from fastapi.responses import ORJSONResponse
|
from fastapi.responses import ORJSONResponse
|
||||||
|
from starlette import status
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from api.base_schemas import BaseError, BaseResponse
|
from api.base_schemas import BaseError, BaseResponse
|
||||||
|
|
||||||
|
|
||||||
class BaseAPIException(Exception):
|
class BaseAPIException(Exception):
|
||||||
pass
|
_content_type: str = "application/json"
|
||||||
|
model: type[BaseResponse] = BaseResponse
|
||||||
|
status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||||
|
title: str | None = None
|
||||||
|
type: str | None = None
|
||||||
|
detail: str | None = None
|
||||||
|
instance: str | None = None
|
||||||
|
headers: dict[str, str] | None = None
|
||||||
|
|
||||||
|
def __init__(self, **ctx: Any) -> None:
|
||||||
|
self.__dict__ = ctx
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def example(cls) -> dict[str, Any] | None:
|
||||||
|
if isinstance(cls.model.Config.schema_extra, dict): # type: ignore[attr-defined]
|
||||||
|
return cls.model.Config.schema_extra.get("example") # type: ignore[attr-defined]
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def response(cls) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"model": cls.model,
|
||||||
|
"content": {
|
||||||
|
cls._content_type: cls.model.Config.schema_extra, # type: ignore[attr-defined]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class InternalServerError(BaseError):
|
class InternalServerError(BaseError):
|
||||||
@ -28,6 +56,31 @@ class InternalServerErrorResponse(BaseResponse):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionMissing(BaseError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionMissingResponse(BaseResponse):
|
||||||
|
error: PermissionMissing
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"status": 403,
|
||||||
|
"error": {
|
||||||
|
"type": "PermissionMissing",
|
||||||
|
"title": "Permission required for this endpoint is missing",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionMissingError(BaseAPIException):
|
||||||
|
model = PermissionMissingResponse
|
||||||
|
status_code = status.HTTP_403_FORBIDDEN
|
||||||
|
title: str = "Permission required for this endpoint is missing"
|
||||||
|
|
||||||
|
|
||||||
async def internal_server_error_handler(_request: Request, _exception: Exception) -> ORJSONResponse:
|
async def internal_server_error_handler(_request: Request, _exception: Exception) -> ORJSONResponse:
|
||||||
error = InternalServerError(title="Something went wrong!", type="InternalServerError")
|
error = InternalServerError(title="Something went wrong!", type="InternalServerError")
|
||||||
response = InternalServerErrorResponse(status=500, error=error).model_dump(exclude_unset=True)
|
response = InternalServerErrorResponse(status=500, error=error).model_dump(exclude_unset=True)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import INTEGER, TIMESTAMP, VARCHAR, Boolean, ForeignKey, String
|
from sqlalchemy import INTEGER, TIMESTAMP, VARCHAR, Boolean, ForeignKey, String
|
||||||
@ -18,15 +19,13 @@ class User(Base):
|
|||||||
hashed_password: Mapped[str] = mapped_column(String(length=1024), nullable=False)
|
hashed_password: Mapped[str] = mapped_column(String(length=1024), nullable=False)
|
||||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||||
is_superuser: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
is_superuser: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), nullable=False, default=datetime.now)
|
||||||
TIMESTAMP(timezone=True), index=True, nullable=False, default=datetime.now
|
|
||||||
)
|
|
||||||
|
|
||||||
user_question_count: Mapped["UserQuestionCount"] = relationship(
|
user_question_count: Mapped["UserQuestionCount"] = relationship(
|
||||||
"UserQuestionCount",
|
"UserQuestionCount",
|
||||||
primaryjoin="UserQuestionCount.user_id == User.id",
|
primaryjoin="UserQuestionCount.user_id == User.id",
|
||||||
backref="user",
|
backref="user",
|
||||||
lazy="selectin",
|
lazy="noload",
|
||||||
uselist=False,
|
uselist=False,
|
||||||
cascade="delete",
|
cascade="delete",
|
||||||
)
|
)
|
||||||
@ -37,6 +36,12 @@ class User(Base):
|
|||||||
return self.user_question_count.question_count
|
return self.user_question_count.question_count
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_question_at(self) -> str | None:
|
||||||
|
if self.user_question_count:
|
||||||
|
return self.user_question_count.last_question_at.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build(
|
def build(
|
||||||
cls,
|
cls,
|
||||||
@ -68,14 +73,27 @@ class AccessToken(Base):
|
|||||||
__tablename__ = "access_token" # type: ignore[assignment]
|
__tablename__ = "access_token" # type: ignore[assignment]
|
||||||
|
|
||||||
user_id = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), nullable=False)
|
user_id = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), nullable=False)
|
||||||
token: Mapped[str] = mapped_column(String(length=42), primary_key=True)
|
token: Mapped[str] = mapped_column(String(length=42), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), nullable=False, default=datetime.now)
|
||||||
TIMESTAMP(timezone=True), index=True, nullable=False, default=datetime.now
|
|
||||||
|
user: Mapped["User"] = relationship(
|
||||||
|
"User",
|
||||||
|
backref="access_token",
|
||||||
|
lazy="selectin",
|
||||||
|
uselist=False,
|
||||||
|
cascade="expunge",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def username(self) -> str:
|
||||||
|
if self.user:
|
||||||
|
return self.user.username
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
class UserQuestionCount(Base):
|
class UserQuestionCount(Base):
|
||||||
__tablename__ = "user_question_count" # type: ignore[assignment]
|
__tablename__ = "user_question_count" # type: ignore[assignment]
|
||||||
|
|
||||||
user_id: Mapped[int] = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), primary_key=True)
|
user_id: Mapped[int] = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), primary_key=True)
|
||||||
question_count: Mapped[int] = mapped_column(INTEGER, default=0, nullable=False)
|
question_count: Mapped[int] = mapped_column(INTEGER, default=0, nullable=False)
|
||||||
|
last_question_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), nullable=False, default=datetime.now)
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy.dialects.sqlite import insert
|
from sqlalchemy.dialects.sqlite import insert
|
||||||
from sqlalchemy.orm import load_only
|
from sqlalchemy.orm import load_only
|
||||||
|
|
||||||
from core.auth.dto import UserIsBannedDTO
|
from core.auth.dto import UserIsBannedDTO
|
||||||
from core.auth.models.users import User, UserQuestionCount
|
from core.auth.models.users import AccessToken, User, UserQuestionCount
|
||||||
from infra.database.db_adapter import Database
|
from infra.database.db_adapter import Database
|
||||||
|
|
||||||
|
|
||||||
@ -69,10 +69,18 @@ class UserRepository:
|
|||||||
UserQuestionCount.get_real_column_name(
|
UserQuestionCount.get_real_column_name(
|
||||||
UserQuestionCount.question_count.key
|
UserQuestionCount.question_count.key
|
||||||
): UserQuestionCount.question_count
|
): UserQuestionCount.question_count
|
||||||
+ 1
|
+ 1,
|
||||||
|
UserQuestionCount.get_real_column_name(UserQuestionCount.last_question_at.key): func.now(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self.db.session() as session:
|
async with self.db.session() as session:
|
||||||
await session.execute(query)
|
await session.execute(query)
|
||||||
|
|
||||||
|
async def get_user_access_token(self, username: str | None) -> str | None:
|
||||||
|
query = select(AccessToken.token).join(AccessToken.user).where(User.username == username)
|
||||||
|
|
||||||
|
async with self.db.session() as session:
|
||||||
|
result = await session.execute(query)
|
||||||
|
return result.scalar()
|
||||||
|
@ -62,6 +62,9 @@ class UserService:
|
|||||||
async def check_user_is_banned(self, user_id: int) -> UserIsBannedDTO:
|
async def check_user_is_banned(self, user_id: int) -> UserIsBannedDTO:
|
||||||
return await self.repository.check_user_is_banned(user_id)
|
return await self.repository.check_user_is_banned(user_id)
|
||||||
|
|
||||||
|
async def get_user_access_token_by_username(self, username: str | None) -> str | None:
|
||||||
|
return await self.repository.get_user_access_token(username)
|
||||||
|
|
||||||
|
|
||||||
def check_user_is_banned(func: Any) -> Any:
|
def check_user_is_banned(func: Any) -> Any:
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
|
0
bot_microservice/core/bot/managment/__init__.py
Normal file
0
bot_microservice/core/bot/managment/__init__.py
Normal file
13
bot_microservice/core/bot/managment/update_gpt_models.py
Normal file
13
bot_microservice/core/bot/managment/update_gpt_models.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent.parent.parent))
|
||||||
|
from core.bot.services import ChatGptService
|
||||||
|
|
||||||
|
chatgpt_service = ChatGptService.build()
|
||||||
|
asyncio.run(chatgpt_service.update_chatgpt_models())
|
||||||
|
logger.info("chatgpt models has been updated")
|
@ -38,6 +38,18 @@ class ChatGPTRepository:
|
|||||||
async with self.db.session() as session:
|
async with self.db.session() as session:
|
||||||
await session.execute(query)
|
await session.execute(query)
|
||||||
|
|
||||||
|
async def delete_all_chatgpt_models(self) -> None:
|
||||||
|
query = delete(ChatGptModels)
|
||||||
|
async with self.db.session() as session:
|
||||||
|
await session.execute(query)
|
||||||
|
|
||||||
|
async def bulk_insert_chatgpt_models(self, models_priority: list[dict[str, Any]]) -> None:
|
||||||
|
models = [ChatGptModels(**model_priority) for model_priority in models_priority]
|
||||||
|
|
||||||
|
async with self.db.session() as session:
|
||||||
|
session.add_all(models)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
async def add_chatgpt_model(self, model: str, priority: int) -> dict[str, str | int]:
|
async def add_chatgpt_model(self, model: str, priority: int) -> dict[str, str | int]:
|
||||||
query = (
|
query = (
|
||||||
insert(ChatGptModels)
|
insert(ChatGptModels)
|
||||||
|
@ -14,7 +14,7 @@ from speech_recognition import (
|
|||||||
UnknownValueError as SpeechRecognizerError,
|
UnknownValueError as SpeechRecognizerError,
|
||||||
)
|
)
|
||||||
|
|
||||||
from constants import AUDIO_SEGMENT_DURATION
|
from constants import AUDIO_SEGMENT_DURATION, ChatGptModelsEnum
|
||||||
from core.auth.models.users import User
|
from core.auth.models.users import User
|
||||||
from core.auth.repository import UserRepository
|
from core.auth.repository import UserRepository
|
||||||
from core.auth.services import UserService
|
from core.auth.services import UserService
|
||||||
@ -24,6 +24,77 @@ from infra.database.db_adapter import Database
|
|||||||
from settings.config import settings
|
from settings.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatGptService:
|
||||||
|
repository: ChatGPTRepository
|
||||||
|
user_service: UserService
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build(cls) -> "ChatGptService":
|
||||||
|
db = Database(settings=settings)
|
||||||
|
repository = ChatGPTRepository(settings=settings, db=db)
|
||||||
|
user_repository = UserRepository(db=db)
|
||||||
|
user_service = UserService(repository=user_repository)
|
||||||
|
return ChatGptService(repository=repository, user_service=user_service)
|
||||||
|
|
||||||
|
async def get_chatgpt_models(self) -> Sequence[ChatGptModels]:
|
||||||
|
return await self.repository.get_chatgpt_models()
|
||||||
|
|
||||||
|
async def request_to_chatgpt(self, question: str | None) -> str:
|
||||||
|
question = question or "Привет!"
|
||||||
|
chatgpt_model = await self.get_current_chatgpt_model()
|
||||||
|
return await self.repository.ask_question(question=question, chatgpt_model=chatgpt_model)
|
||||||
|
|
||||||
|
async def request_to_chatgpt_microservice(self, question: str) -> Response:
|
||||||
|
chatgpt_model = await self.get_current_chatgpt_model()
|
||||||
|
return await self.repository.request_to_chatgpt_microservice(question=question, chatgpt_model=chatgpt_model)
|
||||||
|
|
||||||
|
async def get_current_chatgpt_model(self) -> str:
|
||||||
|
return await self.repository.get_current_chatgpt_model()
|
||||||
|
|
||||||
|
async def change_chatgpt_model_priority(self, model_id: int, priority: int) -> None:
|
||||||
|
return await self.repository.change_chatgpt_model_priority(model_id=model_id, priority=priority)
|
||||||
|
|
||||||
|
async def update_chatgpt_models(self) -> None:
|
||||||
|
await self.repository.delete_all_chatgpt_models()
|
||||||
|
models = ChatGptModelsEnum.base_models_priority()
|
||||||
|
await self.repository.bulk_insert_chatgpt_models(models)
|
||||||
|
|
||||||
|
async def reset_all_chatgpt_models_priority(self) -> None:
|
||||||
|
return await self.repository.reset_all_chatgpt_models_priority()
|
||||||
|
|
||||||
|
async def add_chatgpt_model(self, gpt_model: str, priority: int) -> dict[str, str | int]:
|
||||||
|
return await self.repository.add_chatgpt_model(model=gpt_model, priority=priority)
|
||||||
|
|
||||||
|
async def delete_chatgpt_model(self, model_id: int) -> None:
|
||||||
|
return await self.repository.delete_chatgpt_model(model_id=model_id)
|
||||||
|
|
||||||
|
async def get_or_create_bot_user(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
email: str | None = None,
|
||||||
|
username: str | None = None,
|
||||||
|
first_name: str | None = None,
|
||||||
|
last_name: str | None = None,
|
||||||
|
ban_reason: str | None = None,
|
||||||
|
is_active: bool = True,
|
||||||
|
is_superuser: bool = False,
|
||||||
|
) -> User:
|
||||||
|
return await self.user_service.get_or_create_user_by_id(
|
||||||
|
user_id=user_id,
|
||||||
|
email=email,
|
||||||
|
username=username,
|
||||||
|
first_name=first_name,
|
||||||
|
last_name=last_name,
|
||||||
|
ban_reason=ban_reason,
|
||||||
|
is_active=is_active,
|
||||||
|
is_superuser=is_superuser,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_bot_user_message_count(self, user_id: int) -> None:
|
||||||
|
await self.user_service.update_user_message_count(user_id)
|
||||||
|
|
||||||
|
|
||||||
class SpeechToTextService:
|
class SpeechToTextService:
|
||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, filename: str) -> None:
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
@ -87,69 +158,3 @@ class SpeechToTextService:
|
|||||||
except SpeechRecognizerError as error:
|
except SpeechRecognizerError as error:
|
||||||
logger.error("error recognizing text with google", error=error)
|
logger.error("error recognizing text with google", error=error)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ChatGptService:
|
|
||||||
repository: ChatGPTRepository
|
|
||||||
user_service: UserService
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def build(cls) -> "ChatGptService":
|
|
||||||
db = Database(settings=settings)
|
|
||||||
repository = ChatGPTRepository(settings=settings, db=db)
|
|
||||||
user_repository = UserRepository(db=db)
|
|
||||||
user_service = UserService(repository=user_repository)
|
|
||||||
return ChatGptService(repository=repository, user_service=user_service)
|
|
||||||
|
|
||||||
async def get_chatgpt_models(self) -> Sequence[ChatGptModels]:
|
|
||||||
return await self.repository.get_chatgpt_models()
|
|
||||||
|
|
||||||
async def request_to_chatgpt(self, question: str | None) -> str:
|
|
||||||
question = question or "Привет!"
|
|
||||||
chatgpt_model = await self.get_current_chatgpt_model()
|
|
||||||
return await self.repository.ask_question(question=question, chatgpt_model=chatgpt_model)
|
|
||||||
|
|
||||||
async def request_to_chatgpt_microservice(self, question: str) -> Response:
|
|
||||||
chatgpt_model = await self.get_current_chatgpt_model()
|
|
||||||
return await self.repository.request_to_chatgpt_microservice(question=question, chatgpt_model=chatgpt_model)
|
|
||||||
|
|
||||||
async def get_current_chatgpt_model(self) -> str:
|
|
||||||
return await self.repository.get_current_chatgpt_model()
|
|
||||||
|
|
||||||
async def change_chatgpt_model_priority(self, model_id: int, priority: int) -> None:
|
|
||||||
return await self.repository.change_chatgpt_model_priority(model_id=model_id, priority=priority)
|
|
||||||
|
|
||||||
async def reset_all_chatgpt_models_priority(self) -> None:
|
|
||||||
return await self.repository.reset_all_chatgpt_models_priority()
|
|
||||||
|
|
||||||
async def add_chatgpt_model(self, gpt_model: str, priority: int) -> dict[str, str | int]:
|
|
||||||
return await self.repository.add_chatgpt_model(model=gpt_model, priority=priority)
|
|
||||||
|
|
||||||
async def delete_chatgpt_model(self, model_id: int) -> None:
|
|
||||||
return await self.repository.delete_chatgpt_model(model_id=model_id)
|
|
||||||
|
|
||||||
async def get_or_create_bot_user(
|
|
||||||
self,
|
|
||||||
user_id: int,
|
|
||||||
email: str | None = None,
|
|
||||||
username: str | None = None,
|
|
||||||
first_name: str | None = None,
|
|
||||||
last_name: str | None = None,
|
|
||||||
ban_reason: str | None = None,
|
|
||||||
is_active: bool = True,
|
|
||||||
is_superuser: bool = False,
|
|
||||||
) -> User:
|
|
||||||
return await self.user_service.get_or_create_user_by_id(
|
|
||||||
user_id=user_id,
|
|
||||||
email=email,
|
|
||||||
username=username,
|
|
||||||
first_name=first_name,
|
|
||||||
last_name=last_name,
|
|
||||||
ban_reason=ban_reason,
|
|
||||||
is_active=is_active,
|
|
||||||
is_superuser=is_superuser,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def update_bot_user_message_count(self, user_id: int) -> None:
|
|
||||||
await self.user_service.update_user_message_count(user_id)
|
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from sqladmin import Admin, ModelView
|
from sqladmin import Admin, ModelView
|
||||||
|
from sqlalchemy import Select, desc, select
|
||||||
|
from sqlalchemy.orm import contains_eager, load_only
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
from core.auth.models.users import User
|
from core.auth.models.users import AccessToken, User, UserQuestionCount
|
||||||
from core.bot.models.chatgpt import ChatGptModels
|
from core.bot.models.chatgpt import ChatGptModels
|
||||||
from core.utils import build_uri
|
from core.utils import build_uri
|
||||||
from settings.config import settings
|
from settings.config import settings
|
||||||
@ -34,12 +37,44 @@ class UserAdmin(ModelView, model=User):
|
|||||||
User.is_active,
|
User.is_active,
|
||||||
User.ban_reason,
|
User.ban_reason,
|
||||||
"question_count",
|
"question_count",
|
||||||
|
"last_question_at",
|
||||||
User.created_at,
|
User.created_at,
|
||||||
]
|
]
|
||||||
column_sortable_list = [User.created_at]
|
|
||||||
column_default_sort = ("created_at", True)
|
column_default_sort = ("created_at", True)
|
||||||
form_widget_args = {"created_at": {"readonly": True}}
|
form_widget_args = {"created_at": {"readonly": True}}
|
||||||
|
|
||||||
|
def list_query(self, request: Request) -> Select[tuple[User]]:
|
||||||
|
return (
|
||||||
|
select(User)
|
||||||
|
.options(
|
||||||
|
load_only(
|
||||||
|
User.id,
|
||||||
|
User.username,
|
||||||
|
User.first_name,
|
||||||
|
User.last_name,
|
||||||
|
User.is_active,
|
||||||
|
User.created_at,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.outerjoin(User.user_question_count)
|
||||||
|
.options(
|
||||||
|
contains_eager(User.user_question_count).options(
|
||||||
|
load_only(
|
||||||
|
UserQuestionCount.question_count,
|
||||||
|
UserQuestionCount.last_question_at,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).order_by(desc(UserQuestionCount.question_count))
|
||||||
|
|
||||||
|
|
||||||
|
class AccessTokenAdmin(ModelView, model=AccessToken):
|
||||||
|
name = "API access token"
|
||||||
|
name_plural = "API access tokens"
|
||||||
|
column_list = [AccessToken.user_id, "username", AccessToken.token, AccessToken.created_at]
|
||||||
|
form_widget_args = {"created_at": {"readonly": True}}
|
||||||
|
|
||||||
|
|
||||||
def create_admin(application: "Application") -> Admin:
|
def create_admin(application: "Application") -> Admin:
|
||||||
admin = Admin(
|
admin = Admin(
|
||||||
@ -51,4 +86,5 @@ def create_admin(application: "Application") -> Admin:
|
|||||||
)
|
)
|
||||||
admin.add_view(ChatGptAdmin)
|
admin.add_view(ChatGptAdmin)
|
||||||
admin.add_view(UserAdmin)
|
admin.add_view(UserAdmin)
|
||||||
|
admin.add_view(AccessTokenAdmin)
|
||||||
return admin
|
return admin
|
||||||
|
@ -10,9 +10,8 @@ from datetime import datetime
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from alembic import op
|
from alembic import op
|
||||||
from sqlalchemy import TIMESTAMP
|
from sqlalchemy import TIMESTAMP
|
||||||
from sqlalchemy.dialects.sqlite import insert
|
|
||||||
|
|
||||||
from core.auth.models.users import User
|
from core.auth.models.users import AccessToken, User
|
||||||
from core.auth.utils import create_password_hash
|
from core.auth.utils import create_password_hash
|
||||||
from infra.database.deps import get_sync_session
|
from infra.database.deps import get_sync_session
|
||||||
from settings.config import settings
|
from settings.config import settings
|
||||||
@ -58,8 +57,14 @@ def upgrade() -> None:
|
|||||||
return
|
return
|
||||||
with get_sync_session() as session:
|
with get_sync_session() as session:
|
||||||
hashed_password = create_password_hash(password.get_secret_value())
|
hashed_password = create_password_hash(password.get_secret_value())
|
||||||
query = insert(User).values({"username": username, "hashed_password": hashed_password})
|
user = User(username=username, hashed_password=hashed_password)
|
||||||
session.execute(query)
|
session.add(user)
|
||||||
|
session.flush()
|
||||||
|
session.refresh(user)
|
||||||
|
|
||||||
|
access_token = AccessToken(user_id=user.id)
|
||||||
|
session.add(access_token)
|
||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,43 +0,0 @@
|
|||||||
"""create chatgpt models
|
|
||||||
|
|
||||||
Revision ID: 0004_add_chatgpt_models
|
|
||||||
Revises: 0003_create_user_question_count_table
|
|
||||||
Create Date: 2025-10-05 20:44:05.414977
|
|
||||||
|
|
||||||
"""
|
|
||||||
from loguru import logger
|
|
||||||
from sqlalchemy import select, text
|
|
||||||
|
|
||||||
from constants import ChatGptModelsEnum
|
|
||||||
from core.bot.models.chatgpt import ChatGptModels
|
|
||||||
from infra.database.deps import get_sync_session
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision = "0004_add_chatgpt_models"
|
|
||||||
down_revision = "0003_create_user_question_count_table"
|
|
||||||
branch_labels: str | None = None
|
|
||||||
depends_on: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
with get_sync_session() as session:
|
|
||||||
query = select(ChatGptModels)
|
|
||||||
results = session.execute(query)
|
|
||||||
models = results.scalars().all()
|
|
||||||
|
|
||||||
if models:
|
|
||||||
return
|
|
||||||
models = []
|
|
||||||
for data in ChatGptModelsEnum.base_models_priority():
|
|
||||||
models.append(ChatGptModels(**data))
|
|
||||||
session.add_all(models)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
chatgpt_table_name = ChatGptModels.__tablename__
|
|
||||||
with get_sync_session() as session:
|
|
||||||
# Truncate doesn't exists for SQLite
|
|
||||||
session.execute(text(f"""DELETE FROM {chatgpt_table_name}""")) # noqa: S608
|
|
||||||
session.commit()
|
|
||||||
logger.info("chatgpt models table has been truncated", table=chatgpt_table_name)
|
|
@ -0,0 +1,29 @@
|
|||||||
|
"""add_last_question_at
|
||||||
|
|
||||||
|
Revision ID: 0004_add_last_question_at
|
||||||
|
Revises: 0003_create_user_question_count_table
|
||||||
|
Create Date: 2024-01-08 20:56:34.815976
|
||||||
|
|
||||||
|
"""
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "0004_add_last_question_at"
|
||||||
|
down_revision = "0003_create_user_question_count_table"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index("ix_access_token_created_at", table_name="access_token")
|
||||||
|
op.add_column("user_question_count", sa.Column("last_question_at", sa.TIMESTAMP(timezone=True), nullable=False))
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column("user_question_count", "last_question_at")
|
||||||
|
op.create_index("ix_access_token_created_at", "access_token", ["created_at"], unique=False)
|
||||||
|
# ### end Alembic commands ###
|
@ -22,7 +22,7 @@ class Application:
|
|||||||
self.app = FastAPI(
|
self.app = FastAPI(
|
||||||
title="Chat gpt bot",
|
title="Chat gpt bot",
|
||||||
description="Bot for proxy to chat gpt in telegram",
|
description="Bot for proxy to chat gpt in telegram",
|
||||||
version="0.0.3",
|
version="1.1.12",
|
||||||
docs_url=build_uri([settings.api_prefix, "docs"]),
|
docs_url=build_uri([settings.api_prefix, "docs"]),
|
||||||
redoc_url=build_uri([settings.api_prefix, "redocs"]),
|
redoc_url=build_uri([settings.api_prefix, "redocs"]),
|
||||||
openapi_url=build_uri([settings.api_prefix, "openapi.json"]),
|
openapi_url=build_uri([settings.api_prefix, "openapi.json"]),
|
||||||
|
@ -11,6 +11,8 @@ TELEGRAM_API_TOKEN="123456789:AABBCCDDEEFFaabbccddeeff-1234567890"
|
|||||||
# set to true to start with webhook. Else bot will start on polling method
|
# set to true to start with webhook. Else bot will start on polling method
|
||||||
START_WITH_WEBHOOK="false"
|
START_WITH_WEBHOOK="false"
|
||||||
|
|
||||||
|
SUPERUSER="Superuser"
|
||||||
|
|
||||||
# ==== domain settings ====
|
# ==== domain settings ====
|
||||||
DOMAIN="http://localhost"
|
DOMAIN="http://localhost"
|
||||||
URL_PREFIX=
|
URL_PREFIX=
|
||||||
|
@ -6,7 +6,9 @@ from sqlalchemy import desc
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.bot.models.chatgpt import ChatGptModels
|
from core.bot.models.chatgpt import ChatGptModels
|
||||||
|
from settings.config import AppSettings
|
||||||
from tests.integration.factories.bot import ChatGptModelFactory
|
from tests.integration.factories.bot import ChatGptModelFactory
|
||||||
|
from tests.integration.factories.user import AccessTokenFactory, UserFactory
|
||||||
|
|
||||||
pytestmark = [
|
pytestmark = [
|
||||||
pytest.mark.asyncio,
|
pytest.mark.asyncio,
|
||||||
@ -51,11 +53,18 @@ async def test_change_chatgpt_model_priority(
|
|||||||
dbsession: Session,
|
dbsession: Session,
|
||||||
rest_client: AsyncClient,
|
rest_client: AsyncClient,
|
||||||
faker: Faker,
|
faker: Faker,
|
||||||
|
test_settings: AppSettings,
|
||||||
) -> None:
|
) -> None:
|
||||||
model1 = ChatGptModelFactory(priority=0)
|
model1 = ChatGptModelFactory(priority=0)
|
||||||
model2 = ChatGptModelFactory(priority=1)
|
model2 = ChatGptModelFactory(priority=1)
|
||||||
priority = faker.random_int(min=2, max=7)
|
priority = faker.random_int(min=2, max=7)
|
||||||
response = await rest_client.put(url=f"/api/chatgpt/models/{model2.id}/priority", json={"priority": priority})
|
user = UserFactory(username=test_settings.SUPERUSER)
|
||||||
|
access_token = AccessTokenFactory(user_id=user.id)
|
||||||
|
response = await rest_client.put(
|
||||||
|
url=f"/api/chatgpt/models/{model2.id}/priority",
|
||||||
|
json={"priority": priority},
|
||||||
|
headers={"BOT-API-KEY": access_token.token},
|
||||||
|
)
|
||||||
assert response.status_code == 202
|
assert response.status_code == 202
|
||||||
|
|
||||||
upd_model1, upd_model2 = dbsession.query(ChatGptModels).order_by(ChatGptModels.priority).all()
|
upd_model1, upd_model2 = dbsession.query(ChatGptModels).order_by(ChatGptModels.priority).all()
|
||||||
@ -66,14 +75,53 @@ async def test_change_chatgpt_model_priority(
|
|||||||
assert upd_model2.priority == priority
|
assert upd_model2.priority == priority
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"bot_api_key_header",
|
||||||
|
[
|
||||||
|
pytest.param({"BOT-API-KEY": ""}, id="empty key in header"),
|
||||||
|
pytest.param({"BOT-API-KEY": "Hello World"}, id="incorrect token"),
|
||||||
|
pytest.param({"Hello-World": "lasdkj3-wer-weqwe_34"}, id="correct token but wrong api key"),
|
||||||
|
pytest.param({}, id="no api key header"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_cant_change_chatgpt_model_priority_with_wrong_api_key(
|
||||||
|
dbsession: Session,
|
||||||
|
rest_client: AsyncClient,
|
||||||
|
faker: Faker,
|
||||||
|
test_settings: AppSettings,
|
||||||
|
bot_api_key_header: dict[str, str],
|
||||||
|
) -> None:
|
||||||
|
ChatGptModelFactory(priority=0)
|
||||||
|
model2 = ChatGptModelFactory(priority=1)
|
||||||
|
priority = faker.random_int(min=2, max=7)
|
||||||
|
user = UserFactory(username=test_settings.SUPERUSER)
|
||||||
|
AccessTokenFactory(user_id=user.id, token="lasdkj3-wer-weqwe_34") # noqa: S106
|
||||||
|
response = await rest_client.put(
|
||||||
|
url=f"/api/chatgpt/models/{model2.id}/priority",
|
||||||
|
json={"priority": priority},
|
||||||
|
headers=bot_api_key_header,
|
||||||
|
)
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
changed_model = dbsession.query(ChatGptModels).filter_by(id=model2.id).one()
|
||||||
|
assert changed_model.priority == model2.priority
|
||||||
|
|
||||||
|
|
||||||
async def test_reset_chatgpt_models_priority(
|
async def test_reset_chatgpt_models_priority(
|
||||||
dbsession: Session,
|
dbsession: Session,
|
||||||
rest_client: AsyncClient,
|
rest_client: AsyncClient,
|
||||||
|
test_settings: AppSettings,
|
||||||
) -> None:
|
) -> None:
|
||||||
ChatGptModelFactory.create_batch(size=4)
|
ChatGptModelFactory.create_batch(size=4)
|
||||||
ChatGptModelFactory(priority=42)
|
ChatGptModelFactory(priority=42)
|
||||||
|
|
||||||
response = await rest_client.put(url="/api/chatgpt/models/priority/reset")
|
user = UserFactory(username=test_settings.SUPERUSER)
|
||||||
|
access_token = AccessTokenFactory(user_id=user.id)
|
||||||
|
|
||||||
|
response = await rest_client.put(
|
||||||
|
url="/api/chatgpt/models/priority/reset",
|
||||||
|
headers={"BOT-API-KEY": access_token.token},
|
||||||
|
)
|
||||||
assert response.status_code == 202
|
assert response.status_code == 202
|
||||||
|
|
||||||
models = dbsession.query(ChatGptModels).all()
|
models = dbsession.query(ChatGptModels).all()
|
||||||
@ -85,14 +133,49 @@ async def test_reset_chatgpt_models_priority(
|
|||||||
assert model.priority == 0
|
assert model.priority == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"bot_api_key_header",
|
||||||
|
[
|
||||||
|
pytest.param({"BOT-API-KEY": ""}, id="empty key in header"),
|
||||||
|
pytest.param({"BOT-API-KEY": "Hello World"}, id="incorrect token"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_cant_reset_chatgpt_models_priority_with_wrong_api_key(
|
||||||
|
dbsession: Session, rest_client: AsyncClient, test_settings: AppSettings, bot_api_key_header: dict[str, str]
|
||||||
|
) -> None:
|
||||||
|
chat_gpt_models = ChatGptModelFactory.create_batch(size=2)
|
||||||
|
model_with_highest_priority = ChatGptModelFactory(priority=42)
|
||||||
|
|
||||||
|
priorities = sorted([model.priority for model in chat_gpt_models] + [model_with_highest_priority.priority])
|
||||||
|
|
||||||
|
user = UserFactory(username=test_settings.SUPERUSER)
|
||||||
|
AccessTokenFactory(user_id=user.id)
|
||||||
|
|
||||||
|
response = await rest_client.put(
|
||||||
|
url="/api/chatgpt/models/priority/reset",
|
||||||
|
headers=bot_api_key_header,
|
||||||
|
)
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
models = dbsession.query(ChatGptModels).all()
|
||||||
|
|
||||||
|
changed_priorities = sorted([model.priority for model in models])
|
||||||
|
|
||||||
|
assert changed_priorities == priorities
|
||||||
|
|
||||||
|
|
||||||
async def test_create_new_chatgpt_model(
|
async def test_create_new_chatgpt_model(
|
||||||
dbsession: Session,
|
dbsession: Session,
|
||||||
rest_client: AsyncClient,
|
rest_client: AsyncClient,
|
||||||
faker: Faker,
|
faker: Faker,
|
||||||
|
test_settings: AppSettings,
|
||||||
) -> None:
|
) -> None:
|
||||||
ChatGptModelFactory.create_batch(size=2)
|
ChatGptModelFactory.create_batch(size=2)
|
||||||
ChatGptModelFactory(priority=42)
|
ChatGptModelFactory(priority=42)
|
||||||
|
|
||||||
|
user = UserFactory(username=test_settings.SUPERUSER)
|
||||||
|
access_token = AccessTokenFactory(user_id=user.id)
|
||||||
|
|
||||||
model_name = "new-gpt-model"
|
model_name = "new-gpt-model"
|
||||||
model_priority = faker.random_int(min=1, max=5)
|
model_priority = faker.random_int(min=1, max=5)
|
||||||
|
|
||||||
@ -105,6 +188,7 @@ async def test_create_new_chatgpt_model(
|
|||||||
"model": model_name,
|
"model": model_name,
|
||||||
"priority": model_priority,
|
"priority": model_priority,
|
||||||
},
|
},
|
||||||
|
headers={"BOT-API-KEY": access_token.token},
|
||||||
)
|
)
|
||||||
assert response.status_code == 201
|
assert response.status_code == 201
|
||||||
|
|
||||||
@ -121,13 +205,47 @@ async def test_create_new_chatgpt_model(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_cant_create_new_chatgpt_model_with_wrong_api_key(
|
||||||
|
dbsession: Session,
|
||||||
|
rest_client: AsyncClient,
|
||||||
|
faker: Faker,
|
||||||
|
test_settings: AppSettings,
|
||||||
|
) -> None:
|
||||||
|
ChatGptModelFactory.create_batch(size=2)
|
||||||
|
ChatGptModelFactory(priority=42)
|
||||||
|
|
||||||
|
user = UserFactory(username=test_settings.SUPERUSER)
|
||||||
|
AccessTokenFactory(user_id=user.id)
|
||||||
|
|
||||||
|
model_name = "new-gpt-model"
|
||||||
|
model_priority = faker.random_int(min=1, max=5)
|
||||||
|
|
||||||
|
models = dbsession.query(ChatGptModels).all()
|
||||||
|
assert len(models) == 3
|
||||||
|
|
||||||
|
response = await rest_client.post(
|
||||||
|
url="/api/chatgpt/models",
|
||||||
|
json={
|
||||||
|
"model": model_name,
|
||||||
|
"priority": model_priority,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
models = dbsession.query(ChatGptModels).all()
|
||||||
|
assert len(models) == 3
|
||||||
|
|
||||||
|
|
||||||
async def test_add_existing_chatgpt_model(
|
async def test_add_existing_chatgpt_model(
|
||||||
dbsession: Session,
|
dbsession: Session,
|
||||||
rest_client: AsyncClient,
|
rest_client: AsyncClient,
|
||||||
faker: Faker,
|
faker: Faker,
|
||||||
|
test_settings: AppSettings,
|
||||||
) -> None:
|
) -> None:
|
||||||
ChatGptModelFactory.create_batch(size=2)
|
ChatGptModelFactory.create_batch(size=2)
|
||||||
model = ChatGptModelFactory(priority=42)
|
model = ChatGptModelFactory(priority=42)
|
||||||
|
user = UserFactory(username=test_settings.SUPERUSER)
|
||||||
|
access_token = AccessTokenFactory(user_id=user.id)
|
||||||
|
|
||||||
model_name = model.model
|
model_name = model.model
|
||||||
model_priority = faker.random_int(min=1, max=5)
|
model_priority = faker.random_int(min=1, max=5)
|
||||||
@ -141,6 +259,7 @@ async def test_add_existing_chatgpt_model(
|
|||||||
"model": model_name,
|
"model": model_name,
|
||||||
"priority": model_priority,
|
"priority": model_priority,
|
||||||
},
|
},
|
||||||
|
headers={"BOT-API-KEY": access_token.token},
|
||||||
)
|
)
|
||||||
assert response.status_code == 201
|
assert response.status_code == 201
|
||||||
|
|
||||||
@ -151,17 +270,48 @@ async def test_add_existing_chatgpt_model(
|
|||||||
async def test_delete_chatgpt_model(
|
async def test_delete_chatgpt_model(
|
||||||
dbsession: Session,
|
dbsession: Session,
|
||||||
rest_client: AsyncClient,
|
rest_client: AsyncClient,
|
||||||
|
test_settings: AppSettings,
|
||||||
) -> None:
|
) -> None:
|
||||||
ChatGptModelFactory.create_batch(size=2)
|
ChatGptModelFactory.create_batch(size=2)
|
||||||
model = ChatGptModelFactory(priority=42)
|
model = ChatGptModelFactory(priority=42)
|
||||||
|
|
||||||
|
user = UserFactory(username=test_settings.SUPERUSER)
|
||||||
|
access_token = AccessTokenFactory(user_id=user.id)
|
||||||
|
|
||||||
models = dbsession.query(ChatGptModels).all()
|
models = dbsession.query(ChatGptModels).all()
|
||||||
assert len(models) == 3
|
assert len(models) == 3
|
||||||
|
|
||||||
response = await rest_client.delete(url=f"/api/chatgpt/models/{model.id}")
|
response = await rest_client.delete(
|
||||||
|
url=f"/api/chatgpt/models/{model.id}",
|
||||||
|
headers={"BOT-API-KEY": access_token.token},
|
||||||
|
)
|
||||||
assert response.status_code == 204
|
assert response.status_code == 204
|
||||||
|
|
||||||
models = dbsession.query(ChatGptModels).all()
|
models = dbsession.query(ChatGptModels).all()
|
||||||
assert len(models) == 2
|
assert len(models) == 2
|
||||||
|
|
||||||
assert model not in models
|
assert model not in models
|
||||||
|
|
||||||
|
|
||||||
|
async def test_cant_delete_chatgpt_model_with_wrong_api_key(
|
||||||
|
dbsession: Session,
|
||||||
|
rest_client: AsyncClient,
|
||||||
|
test_settings: AppSettings,
|
||||||
|
) -> None:
|
||||||
|
ChatGptModelFactory.create_batch(size=2)
|
||||||
|
model = ChatGptModelFactory(priority=42)
|
||||||
|
|
||||||
|
user = UserFactory(username=test_settings.SUPERUSER)
|
||||||
|
access_token = AccessTokenFactory(user_id=user.id)
|
||||||
|
|
||||||
|
models = dbsession.query(ChatGptModels).all()
|
||||||
|
assert len(models) == 3
|
||||||
|
|
||||||
|
response = await rest_client.delete(
|
||||||
|
url=f"/api/chatgpt/models/{model.id}",
|
||||||
|
headers={"ROOT-ACCESS": access_token.token},
|
||||||
|
)
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
models = dbsession.query(ChatGptModels).all()
|
||||||
|
assert len(models) == 3
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import datetime
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@ -11,6 +12,7 @@ from sqlalchemy.orm import Session
|
|||||||
from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Update
|
from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Update
|
||||||
|
|
||||||
from constants import BotStagesEnum
|
from constants import BotStagesEnum
|
||||||
|
from core.auth.models.users import User, UserQuestionCount
|
||||||
from core.bot.app import BotApplication, BotQueue
|
from core.bot.app import BotApplication, BotQueue
|
||||||
from main import Application
|
from main import Application
|
||||||
from settings.config import AppSettings
|
from settings.config import AppSettings
|
||||||
@ -22,6 +24,7 @@ from tests.integration.factories.bot import (
|
|||||||
CallBackFactory,
|
CallBackFactory,
|
||||||
ChatGptModelFactory,
|
ChatGptModelFactory,
|
||||||
)
|
)
|
||||||
|
from tests.integration.factories.user import UserFactory, UserQuestionCountFactory
|
||||||
from tests.integration.utils import mocked_ask_question_api
|
from tests.integration.utils import mocked_ask_question_api
|
||||||
|
|
||||||
pytestmark = [
|
pytestmark = [
|
||||||
@ -113,6 +116,33 @@ async def test_help_command(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_help_command_user_is_banned(
|
||||||
|
main_application: Application,
|
||||||
|
test_settings: AppSettings,
|
||||||
|
) -> None:
|
||||||
|
message = BotMessageFactory.create_instance(text="/help")
|
||||||
|
user = message["from"]
|
||||||
|
UserFactory(
|
||||||
|
id=user["id"],
|
||||||
|
first_name=user["first_name"],
|
||||||
|
last_name=user["last_name"],
|
||||||
|
username=user["username"],
|
||||||
|
is_active=False,
|
||||||
|
ban_reason="test reason",
|
||||||
|
)
|
||||||
|
with mock.patch.object(
|
||||||
|
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
|
||||||
|
) as mocked_send_message:
|
||||||
|
bot_update = BotUpdateFactory(message=message)
|
||||||
|
|
||||||
|
await main_application.bot_app.application.process_update(
|
||||||
|
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
|
||||||
|
)
|
||||||
|
assert mocked_send_message.call_args.kwargs["text"] == (
|
||||||
|
"You have banned for reason: *test reason*.\nPlease contact the /developer"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_start_entry(
|
async def test_start_entry(
|
||||||
main_application: Application,
|
main_application: Application,
|
||||||
test_settings: AppSettings,
|
test_settings: AppSettings,
|
||||||
@ -232,6 +262,35 @@ async def test_website_callback_action(
|
|||||||
assert mocked_reply_text.call_args.args == ("Веб версия: http://localhost/chat/",)
|
assert mocked_reply_text.call_args.args == ("Веб версия: http://localhost/chat/",)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_website_callback_action_user_is_banned(
|
||||||
|
main_application: Application,
|
||||||
|
test_settings: AppSettings,
|
||||||
|
) -> None:
|
||||||
|
message = BotMessageFactory.create_instance(text="Список основных команд:")
|
||||||
|
user = message["from"]
|
||||||
|
UserFactory(
|
||||||
|
id=user["id"],
|
||||||
|
first_name=user["first_name"],
|
||||||
|
last_name=user["last_name"],
|
||||||
|
username=user["username"],
|
||||||
|
is_active=False,
|
||||||
|
ban_reason="test reason",
|
||||||
|
)
|
||||||
|
with mock.patch.object(telegram._message.Message, "reply_text") as mocked_reply_text:
|
||||||
|
bot_update = BotCallBackQueryFactory(
|
||||||
|
message=message,
|
||||||
|
callback_query=CallBackFactory(data=BotStagesEnum.website),
|
||||||
|
)
|
||||||
|
|
||||||
|
await main_application.bot_app.application.process_update(
|
||||||
|
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mocked_reply_text.call_args.kwargs["text"] == (
|
||||||
|
"You have banned for reason: *test reason*.\nPlease contact the /developer"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_bug_report_action(
|
async def test_bug_report_action(
|
||||||
main_application: Application,
|
main_application: Application,
|
||||||
test_settings: AppSettings,
|
test_settings: AppSettings,
|
||||||
@ -261,6 +320,35 @@ async def test_bug_report_action(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_bug_report_action_user_is_banned(
|
||||||
|
main_application: Application,
|
||||||
|
test_settings: AppSettings,
|
||||||
|
) -> None:
|
||||||
|
message = BotMessageFactory.create_instance(text="/bug_report")
|
||||||
|
user = message["from"]
|
||||||
|
UserFactory(
|
||||||
|
id=user["id"],
|
||||||
|
first_name=user["first_name"],
|
||||||
|
last_name=user["last_name"],
|
||||||
|
username=user["username"],
|
||||||
|
is_active=False,
|
||||||
|
ban_reason="test reason",
|
||||||
|
)
|
||||||
|
with (
|
||||||
|
mock.patch.object(telegram._message.Message, "reply_text") as mocked_reply_text,
|
||||||
|
mock.patch.object(telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)),
|
||||||
|
):
|
||||||
|
bot_update = BotUpdateFactory(message=message)
|
||||||
|
|
||||||
|
await main_application.bot_app.application.process_update(
|
||||||
|
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mocked_reply_text.call_args.kwargs["text"] == (
|
||||||
|
"You have banned for reason: *test reason*.\nPlease contact the /developer"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_get_developer_action(
|
async def test_get_developer_action(
|
||||||
main_application: Application,
|
main_application: Application,
|
||||||
test_settings: AppSettings,
|
test_settings: AppSettings,
|
||||||
@ -278,19 +366,28 @@ async def test_get_developer_action(
|
|||||||
assert mocked_reply_text.call_args.args == ("Автор бота: *Дмитрий Афанасьев*\n\nTg nickname: *Balshtg*",)
|
assert mocked_reply_text.call_args.args == ("Автор бота: *Дмитрий Афанасьев*\n\nTg nickname: *Balshtg*",)
|
||||||
|
|
||||||
|
|
||||||
async def test_ask_question_action(
|
async def test_ask_question_action_bot_user_not_exists(
|
||||||
dbsession: Session,
|
dbsession: Session,
|
||||||
main_application: Application,
|
main_application: Application,
|
||||||
test_settings: AppSettings,
|
test_settings: AppSettings,
|
||||||
) -> None:
|
) -> None:
|
||||||
ChatGptModelFactory.create_batch(size=3)
|
ChatGptModelFactory.create_batch(size=3)
|
||||||
|
users = dbsession.query(User).all()
|
||||||
|
users_question_count = dbsession.query(UserQuestionCount).all()
|
||||||
|
|
||||||
|
assert len(users) == 0
|
||||||
|
assert len(users_question_count) == 0
|
||||||
|
|
||||||
|
message = BotMessageFactory.create_instance(text="Привет!")
|
||||||
|
user = message["from"]
|
||||||
|
|
||||||
with mock.patch.object(
|
with mock.patch.object(
|
||||||
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
|
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
|
||||||
) as mocked_send_message, mocked_ask_question_api(
|
) as mocked_send_message, mocked_ask_question_api(
|
||||||
host=test_settings.GPT_BASE_HOST,
|
host=test_settings.GPT_BASE_HOST,
|
||||||
return_value=Response(status_code=httpx.codes.OK, text="Привет! Как я могу помочь вам сегодня?"),
|
return_value=Response(status_code=httpx.codes.OK, text="Привет! Как я могу помочь вам сегодня?"),
|
||||||
):
|
):
|
||||||
bot_update = BotUpdateFactory(message=BotMessageFactory.create_instance(text="Привет!"))
|
bot_update = BotUpdateFactory(message=message)
|
||||||
bot_update["message"].pop("entities")
|
bot_update["message"].pop("entities")
|
||||||
|
|
||||||
await main_application.bot_app.application.process_update(
|
await main_application.bot_app.application.process_update(
|
||||||
@ -313,6 +410,120 @@ async def test_ask_question_action(
|
|||||||
include=["text", "chat_id"],
|
include=["text", "chat_id"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
created_user = dbsession.query(User).filter_by(id=user["id"]).one()
|
||||||
|
assert created_user.username == user["username"]
|
||||||
|
|
||||||
|
created_user_question_count = dbsession.query(UserQuestionCount).filter_by(user_id=user["id"]).one()
|
||||||
|
assert created_user_question_count.question_count == 1
|
||||||
|
assert created_user_question_count.last_question_at - datetime.datetime.now() < datetime.timedelta( # noqa: DTZ005
|
||||||
|
seconds=2
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_ask_question_action_bot_user_already_exists(
|
||||||
|
dbsession: Session,
|
||||||
|
main_application: Application,
|
||||||
|
test_settings: AppSettings,
|
||||||
|
) -> None:
|
||||||
|
ChatGptModelFactory.create_batch(size=3)
|
||||||
|
message = BotMessageFactory.create_instance(text="Привет!")
|
||||||
|
user = message["from"]
|
||||||
|
existing_user = UserFactory(
|
||||||
|
id=user["id"], first_name=user["first_name"], last_name=user["last_name"], username=user["username"]
|
||||||
|
)
|
||||||
|
existing_user_question_count = UserQuestionCountFactory(user_id=existing_user.id).question_count
|
||||||
|
|
||||||
|
users = dbsession.query(User).all()
|
||||||
|
assert len(users) == 1
|
||||||
|
|
||||||
|
with mock.patch.object(
|
||||||
|
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
|
||||||
|
) as mocked_send_message, mocked_ask_question_api(
|
||||||
|
host=test_settings.GPT_BASE_HOST,
|
||||||
|
return_value=Response(status_code=httpx.codes.OK, text="Привет! Как я могу помочь вам сегодня?"),
|
||||||
|
):
|
||||||
|
bot_update = BotUpdateFactory(message=message)
|
||||||
|
bot_update["message"].pop("entities")
|
||||||
|
|
||||||
|
await main_application.bot_app.application.process_update(
|
||||||
|
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
|
||||||
|
)
|
||||||
|
assert_that(mocked_send_message.call_args_list[0].kwargs).is_equal_to(
|
||||||
|
{
|
||||||
|
"text": (
|
||||||
|
"Ответ в среднем занимает 10-15 секунд.\n- Список команд: /help\n- Сообщить об ошибке: /bug_report"
|
||||||
|
),
|
||||||
|
"chat_id": bot_update["message"]["chat"]["id"],
|
||||||
|
},
|
||||||
|
include=["text", "chat_id"],
|
||||||
|
)
|
||||||
|
assert_that(mocked_send_message.call_args_list[1].kwargs).is_equal_to(
|
||||||
|
{
|
||||||
|
"text": "Привет! Как я могу помочь вам сегодня?",
|
||||||
|
"chat_id": bot_update["message"]["chat"]["id"],
|
||||||
|
},
|
||||||
|
include=["text", "chat_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
users = dbsession.query(User).all()
|
||||||
|
assert len(users) == 1
|
||||||
|
|
||||||
|
updated_user_question_count = dbsession.query(UserQuestionCount).filter_by(user_id=user["id"]).one()
|
||||||
|
assert updated_user_question_count.question_count == existing_user_question_count + 1
|
||||||
|
assert updated_user_question_count.last_question_at - datetime.datetime.now() < datetime.timedelta( # noqa: DTZ005
|
||||||
|
seconds=2
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_ask_question_action_user_is_banned(
|
||||||
|
dbsession: Session,
|
||||||
|
main_application: Application,
|
||||||
|
test_settings: AppSettings,
|
||||||
|
) -> None:
|
||||||
|
ChatGptModelFactory.create_batch(size=3)
|
||||||
|
|
||||||
|
users_question_count = dbsession.query(UserQuestionCount).all()
|
||||||
|
assert len(users_question_count) == 0
|
||||||
|
|
||||||
|
message = BotMessageFactory.create_instance(text="Привет!")
|
||||||
|
user = message["from"]
|
||||||
|
UserFactory(
|
||||||
|
id=user["id"],
|
||||||
|
first_name=user["first_name"],
|
||||||
|
last_name=user["last_name"],
|
||||||
|
username=user["username"],
|
||||||
|
is_active=False,
|
||||||
|
ban_reason="test reason",
|
||||||
|
)
|
||||||
|
|
||||||
|
with mock.patch.object(
|
||||||
|
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
|
||||||
|
) as mocked_send_message, mocked_ask_question_api(
|
||||||
|
host=test_settings.GPT_BASE_HOST,
|
||||||
|
return_value=Response(status_code=httpx.codes.OK, text="Привет! Как я могу помочь вам сегодня?"),
|
||||||
|
assert_all_called=False,
|
||||||
|
):
|
||||||
|
bot_update = BotUpdateFactory(message=message)
|
||||||
|
bot_update["message"].pop("entities")
|
||||||
|
|
||||||
|
await main_application.bot_app.application.process_update(
|
||||||
|
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
|
||||||
|
)
|
||||||
|
assert_that(mocked_send_message.call_args_list[0].kwargs).is_equal_to(
|
||||||
|
{
|
||||||
|
"text": ("You have banned for reason: *test reason*.\nPlease contact the /developer"),
|
||||||
|
"chat_id": bot_update["message"]["chat"]["id"],
|
||||||
|
},
|
||||||
|
include=["text", "chat_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
created_user = dbsession.query(User).filter_by(id=user["id"]).one()
|
||||||
|
assert created_user.username == user["username"]
|
||||||
|
assert created_user.is_active is False
|
||||||
|
|
||||||
|
created_user_question_count = dbsession.query(UserQuestionCount).filter_by(user_id=user["id"]).scalar()
|
||||||
|
assert created_user_question_count is None
|
||||||
|
|
||||||
|
|
||||||
async def test_ask_question_action_not_success(
|
async def test_ask_question_action_not_success(
|
||||||
dbsession: Session,
|
dbsession: Session,
|
||||||
|
@ -0,0 +1,30 @@
|
|||||||
|
import pytest
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from constants import ChatGptModelsEnum
|
||||||
|
from core.bot.models.chatgpt import ChatGptModels
|
||||||
|
from core.bot.services import ChatGptService
|
||||||
|
|
||||||
|
pytestmark = [
|
||||||
|
pytest.mark.asyncio,
|
||||||
|
pytest.mark.enable_socket,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_models_update(dbsession: Session) -> None:
|
||||||
|
models = dbsession.query(ChatGptModels).all()
|
||||||
|
|
||||||
|
assert len(models) == 0
|
||||||
|
|
||||||
|
chatgpt_service = ChatGptService.build()
|
||||||
|
|
||||||
|
await chatgpt_service.update_chatgpt_models()
|
||||||
|
|
||||||
|
models = dbsession.query(ChatGptModels).all()
|
||||||
|
|
||||||
|
model_priorities = {model.model: model.priority for model in models}
|
||||||
|
|
||||||
|
assert len(models) == len(ChatGptModelsEnum.base_models_priority())
|
||||||
|
|
||||||
|
for model_priority in ChatGptModelsEnum.base_models_priority():
|
||||||
|
assert model_priorities[model_priority["model"]] == model_priority["priority"]
|
@ -1,13 +1,15 @@
|
|||||||
|
import uuid
|
||||||
|
|
||||||
import factory
|
import factory
|
||||||
|
|
||||||
from core.auth.models.users import User
|
from core.auth.models.users import AccessToken, User, UserQuestionCount
|
||||||
from tests.integration.factories.utils import BaseModelFactory
|
from tests.integration.factories.utils import BaseModelFactory
|
||||||
|
|
||||||
|
|
||||||
class UserFactory(BaseModelFactory):
|
class UserFactory(BaseModelFactory):
|
||||||
id = factory.Sequence(lambda n: n + 1)
|
id = factory.Sequence(lambda n: n + 1)
|
||||||
email = factory.Faker("email")
|
email = factory.Faker("email")
|
||||||
username = factory.Faker("user_name", locale="en_EN")
|
username = factory.Faker("user_name", locale="en")
|
||||||
first_name = factory.Faker("word")
|
first_name = factory.Faker("word")
|
||||||
last_name = factory.Faker("word")
|
last_name = factory.Faker("word")
|
||||||
ban_reason = factory.Faker("text", max_nb_chars=100)
|
ban_reason = factory.Faker("text", max_nb_chars=100)
|
||||||
@ -18,3 +20,21 @@ class UserFactory(BaseModelFactory):
|
|||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = User
|
model = User
|
||||||
|
|
||||||
|
|
||||||
|
class AccessTokenFactory(BaseModelFactory):
|
||||||
|
user_id = factory.Sequence(lambda n: n + 1)
|
||||||
|
token = factory.LazyAttribute(lambda o: str(uuid.uuid4()))
|
||||||
|
created_at = factory.Faker("past_datetime")
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = AccessToken
|
||||||
|
|
||||||
|
|
||||||
|
class UserQuestionCountFactory(BaseModelFactory):
|
||||||
|
user_id = factory.Sequence(lambda n: n + 1)
|
||||||
|
question_count = factory.Faker("random_int")
|
||||||
|
last_question_at = factory.Faker("past_datetime")
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = UserQuestionCount
|
||||||
|
@ -9,11 +9,11 @@ from settings.config import settings
|
|||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def mocked_ask_question_api(
|
def mocked_ask_question_api(
|
||||||
host: str, return_value: Response | None = None, side_effect: Any | None = None
|
host: str, return_value: Response | None = None, side_effect: Any | None = None, assert_all_called: bool = True
|
||||||
) -> Iterator[respx.MockRouter]:
|
) -> Iterator[respx.MockRouter]:
|
||||||
with respx.mock(
|
with respx.mock(
|
||||||
assert_all_mocked=True,
|
assert_all_mocked=True,
|
||||||
assert_all_called=True,
|
assert_all_called=assert_all_called,
|
||||||
base_url=host,
|
base_url=host,
|
||||||
) as respx_mock:
|
) as respx_mock:
|
||||||
ask_question_route = respx_mock.post(url=settings.chatgpt_backend_url, name="ask_question")
|
ask_question_route = respx_mock.post(url=settings.chatgpt_backend_url, name="ask_question")
|
||||||
|
69
poetry.lock
generated
69
poetry.lock
generated
@ -607,25 +607,25 @@ toml = ["tomli"]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cyclonedx-python-lib"
|
name = "cyclonedx-python-lib"
|
||||||
version = "5.2.0"
|
version = "6.3.0"
|
||||||
description = "Python library for CycloneDX"
|
description = "Python library for CycloneDX"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8,<4.0"
|
python-versions = ">=3.8,<4.0"
|
||||||
files = [
|
files = [
|
||||||
{file = "cyclonedx_python_lib-5.2.0-py3-none-any.whl", hash = "sha256:1b43065205cdc53490c825fcfbda73142b758aa40ca169c968e342e88d06734f"},
|
{file = "cyclonedx_python_lib-6.3.0-py3-none-any.whl", hash = "sha256:0e73c1036c2f7fc67adc28aef807e6b44340ea70202aab197fb06b20ea165de8"},
|
||||||
{file = "cyclonedx_python_lib-5.2.0.tar.gz", hash = "sha256:b9ebf2c0520721d2f8ee16aadc2bbb9d4e015862c84ab1691a49b177f3014d99"},
|
{file = "cyclonedx_python_lib-6.3.0.tar.gz", hash = "sha256:82f2489de3c0cadad5af1ad7fa6b6a185f985746370245d38769699c734533c6"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
license-expression = ">=30,<31"
|
license-expression = ">=30,<31"
|
||||||
packageurl-python = ">=0.11"
|
packageurl-python = ">=0.11"
|
||||||
py-serializable = ">=0.15,<0.16"
|
py-serializable = ">=0.16,<0.18"
|
||||||
sortedcontainers = ">=2.4.0,<3.0.0"
|
sortedcontainers = ">=2.4.0,<3.0.0"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
json-validation = ["jsonschema[format] (>=4.18,<5.0)"]
|
json-validation = ["jsonschema[format] (>=4.18,<5.0)"]
|
||||||
validation = ["jsonschema[format] (>=4.18,<5.0)", "lxml (>=4,<5)"]
|
validation = ["jsonschema[format] (>=4.18,<5.0)", "lxml (>=4,<6)"]
|
||||||
xml-validation = ["lxml (>=4,<5)"]
|
xml-validation = ["lxml (>=4,<6)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "decorator"
|
name = "decorator"
|
||||||
@ -761,19 +761,19 @@ typing = ["typing-extensions (>=4.8)"]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "flake8"
|
name = "flake8"
|
||||||
version = "6.1.0"
|
version = "7.0.0"
|
||||||
description = "the modular source code checker: pep8 pyflakes and co"
|
description = "the modular source code checker: pep8 pyflakes and co"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8.1"
|
python-versions = ">=3.8.1"
|
||||||
files = [
|
files = [
|
||||||
{file = "flake8-6.1.0-py2.py3-none-any.whl", hash = "sha256:ffdfce58ea94c6580c77888a86506937f9a1a227dfcd15f245d694ae20a6b6e5"},
|
{file = "flake8-7.0.0-py2.py3-none-any.whl", hash = "sha256:a6dfbb75e03252917f2473ea9653f7cd799c3064e54d4c8140044c5c065f53c3"},
|
||||||
{file = "flake8-6.1.0.tar.gz", hash = "sha256:d5b3857f07c030bdb5bf41c7f53799571d75c4491748a3adcd47de929e34cd23"},
|
{file = "flake8-7.0.0.tar.gz", hash = "sha256:33f96621059e65eec474169085dc92bf26e7b2d47366b70be2f67ab80dc25132"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
mccabe = ">=0.7.0,<0.8.0"
|
mccabe = ">=0.7.0,<0.8.0"
|
||||||
pycodestyle = ">=2.11.0,<2.12.0"
|
pycodestyle = ">=2.11.0,<2.12.0"
|
||||||
pyflakes = ">=3.1.0,<3.2.0"
|
pyflakes = ">=3.2.0,<3.3.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "flake8-aaa"
|
name = "flake8-aaa"
|
||||||
@ -954,25 +954,6 @@ files = [
|
|||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
flake8 = ">3.0.0"
|
flake8 = ">3.0.0"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "flake8-noqa"
|
|
||||||
version = "1.3.2"
|
|
||||||
description = "Flake8 noqa comment validation"
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3.7"
|
|
||||||
files = [
|
|
||||||
{file = "flake8-noqa-1.3.2.tar.gz", hash = "sha256:b12ddf7b02dedabaca0f807cb436ea7992cc0106cb6fa41e997ad45a0a3bf754"},
|
|
||||||
{file = "flake8_noqa-1.3.2-py3-none-any.whl", hash = "sha256:a2c139c4cc223f268fb262cd32a46fa72f509225d038058baa87c0ff8ac4d348"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
flake8 = ">=3.8.0,<7.0"
|
|
||||||
typing-extensions = ">=3.7.4.2"
|
|
||||||
|
|
||||||
[package.extras]
|
|
||||||
dev = ["flake8 (>=3.8.0,<6.0.0)", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake8-commas", "flake8-comprehensions", "flake8-continuation", "flake8-datetimez", "flake8-docstrings", "flake8-import-order", "flake8-literal", "flake8-modern-annotations", "flake8-noqa", "flake8-polyfill", "flake8-pyproject", "flake8-requirements", "flake8-typechecking-import", "flake8-use-fstring", "mypy", "pep8-naming"]
|
|
||||||
test = ["flake8-docstrings"]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "flake8-plugin-utils"
|
name = "flake8-plugin-utils"
|
||||||
version = "1.3.3"
|
version = "1.3.3"
|
||||||
@ -1410,13 +1391,13 @@ files = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ipython"
|
name = "ipython"
|
||||||
version = "8.19.0"
|
version = "8.20.0"
|
||||||
description = "IPython: Productive Interactive Computing"
|
description = "IPython: Productive Interactive Computing"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.10"
|
python-versions = ">=3.10"
|
||||||
files = [
|
files = [
|
||||||
{file = "ipython-8.19.0-py3-none-any.whl", hash = "sha256:2f55d59370f59d0d2b2212109fe0e6035cfea436b1c0e6150ad2244746272ec5"},
|
{file = "ipython-8.20.0-py3-none-any.whl", hash = "sha256:bc9716aad6f29f36c449e30821c9dd0c1c1a7b59ddcc26931685b87b4c569619"},
|
||||||
{file = "ipython-8.19.0.tar.gz", hash = "sha256:ac4da4ecf0042fb4e0ce57c60430c2db3c719fa8bdf92f8631d6bd8a5785d1f0"},
|
{file = "ipython-8.20.0.tar.gz", hash = "sha256:2f21bd3fc1d51550c89ee3944ae04bbc7bc79e129ea0937da6e6c68bfdbf117a"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@ -2089,18 +2070,18 @@ pip = "*"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pip-audit"
|
name = "pip-audit"
|
||||||
version = "2.6.2"
|
version = "2.6.3"
|
||||||
description = "A tool for scanning Python environments for known vulnerabilities"
|
description = "A tool for scanning Python environments for known vulnerabilities"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
files = [
|
files = [
|
||||||
{file = "pip_audit-2.6.2-py3-none-any.whl", hash = "sha256:ac3a4b6e977ef2c574aa8d19a5d71d12201bdb65bba2d67d9df49f53f0be5e7d"},
|
{file = "pip_audit-2.6.3-py3-none-any.whl", hash = "sha256:216983210db4a15393f9e80e4d24a805f5767e4c8e0c31fc70c336acc629613b"},
|
||||||
{file = "pip_audit-2.6.2.tar.gz", hash = "sha256:0bbd023a199a104b29f949f063a872d41113b5a9048285666820fa35a76a7794"},
|
{file = "pip_audit-2.6.3.tar.gz", hash = "sha256:bd796066f69684b2f4fc2c2b6d222589e23190db0bbde069cea5c2b0be2cc57d"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
CacheControl = {version = ">=0.13.0", extras = ["filecache"]}
|
CacheControl = {version = ">=0.13.0", extras = ["filecache"]}
|
||||||
cyclonedx-python-lib = ">=4,<6"
|
cyclonedx-python-lib = ">=5,<7"
|
||||||
html5lib = ">=1.1"
|
html5lib = ">=1.1"
|
||||||
packaging = ">=23.0.0"
|
packaging = ">=23.0.0"
|
||||||
pip-api = ">=0.0.28"
|
pip-api = ">=0.0.28"
|
||||||
@ -2112,7 +2093,7 @@ toml = ">=0.10"
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
dev = ["build", "bump (>=1.3.2)", "pip-audit[doc,lint,test]"]
|
dev = ["build", "bump (>=1.3.2)", "pip-audit[doc,lint,test]"]
|
||||||
doc = ["pdoc"]
|
doc = ["pdoc"]
|
||||||
lint = ["interrogate", "mypy", "ruff (<0.1.9)", "types-html5lib", "types-requests", "types-toml"]
|
lint = ["interrogate", "mypy", "ruff (<0.1.12)", "types-html5lib", "types-requests", "types-toml"]
|
||||||
test = ["coverage[toml] (>=7.0,!=7.3.3,<8.0)", "pretend", "pytest", "pytest-cov"]
|
test = ["coverage[toml] (>=7.0,!=7.3.3,<8.0)", "pretend", "pytest", "pytest-cov"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -2216,13 +2197,13 @@ tests = ["pytest"]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "py-serializable"
|
name = "py-serializable"
|
||||||
version = "0.15.0"
|
version = "0.17.1"
|
||||||
description = "Library for serializing and deserializing Python Objects to and from JSON and XML."
|
description = "Library for serializing and deserializing Python Objects to and from JSON and XML."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7,<4.0"
|
python-versions = ">=3.7,<4.0"
|
||||||
files = [
|
files = [
|
||||||
{file = "py-serializable-0.15.0.tar.gz", hash = "sha256:8fc41457d8ee5f5c5a12f41fd87bf1a4f2ecf9da39fee92059b728e78f320771"},
|
{file = "py-serializable-0.17.1.tar.gz", hash = "sha256:875bb9c01df77f563dfcd1e75bb4244b5596083d3aad4ccd3fb63e1f5a9d3e5f"},
|
||||||
{file = "py_serializable-0.15.0-py3-none-any.whl", hash = "sha256:d3f1201b33420c481aa83f7860c7bf2c2f036ba3ea82b6e15a96696457c36cd2"},
|
{file = "py_serializable-0.17.1-py3-none-any.whl", hash = "sha256:389c2254d912bec3a44acdac667c947d73c59325050d5ae66386e1ed7108a45a"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@ -2407,13 +2388,13 @@ resolved_reference = "996cec42e9621701edb83354232b2c0ca0121560"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyflakes"
|
name = "pyflakes"
|
||||||
version = "3.1.0"
|
version = "3.2.0"
|
||||||
description = "passive checker of Python programs"
|
description = "passive checker of Python programs"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
files = [
|
files = [
|
||||||
{file = "pyflakes-3.1.0-py2.py3-none-any.whl", hash = "sha256:4132f6d49cb4dae6819e5379898f2b8cce3c5f23994194c24b77d5da2e36f774"},
|
{file = "pyflakes-3.2.0-py2.py3-none-any.whl", hash = "sha256:84b5be138a2dfbb40689ca07e2152deb896a65c3a3e24c251c5c62489568074a"},
|
||||||
{file = "pyflakes-3.1.0.tar.gz", hash = "sha256:a0aae034c444db0071aa077972ba4768d40c830d9539fd45bf4cd3f8f6992efc"},
|
{file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -3732,4 +3713,4 @@ multidict = ">=4.0"
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.12"
|
python-versions = "^3.12"
|
||||||
content-hash = "3a78cb3473202950a11ebbe28e5662f6a1a57508a0e2703761c3c448744b0bed"
|
content-hash = "30a4b709eed24d6b55f78015994cdf1901d0d7279fefa86b37e9490e189a2c52"
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "chat_gpt_bot"
|
name = "chat_gpt_bot"
|
||||||
version = "1.4.1"
|
version = "1.5.1"
|
||||||
description = "Bot to integrated with Chat gpt"
|
description = "Bot to integrated with Chat gpt"
|
||||||
authors = ["Dmitry Afanasyev <Balshbox@gmail.com>"]
|
authors = ["Dmitry Afanasyev <Balshbox@gmail.com>"]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core>=1.7.0"]
|
requires = ["poetry-core>=1.7.1"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
@ -84,7 +84,6 @@ autoflake = "^2.2"
|
|||||||
flake8-aaa = "^0.17.0"
|
flake8-aaa = "^0.17.0"
|
||||||
flake8-variables-names = "^0.0.6"
|
flake8-variables-names = "^0.0.6"
|
||||||
flake8-deprecated = "^2.2.1"
|
flake8-deprecated = "^2.2.1"
|
||||||
flake8-noqa = "^1.3.2"
|
|
||||||
flake8-annotations-complexity = "^0.0.8"
|
flake8-annotations-complexity = "^0.0.8"
|
||||||
flake8-useless-assert = "^0.4.4"
|
flake8-useless-assert = "^0.4.4"
|
||||||
flake8-newspaper-style = "^0.4.3"
|
flake8-newspaper-style = "^0.4.3"
|
||||||
|
@ -2,12 +2,8 @@
|
|||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
if [ -f shared/${DB_NAME:-chatgpt.db} ]
|
alembic upgrade "head" && \
|
||||||
then
|
python3 core/bot/managment/update_gpt_models.py
|
||||||
alembic downgrade -1 && alembic upgrade "head"
|
|
||||||
else
|
|
||||||
alembic upgrade "head"
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "starting the bot"
|
echo "starting the bot"
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user