Compare commits

..

No commits in common. "28d01f551b2bbf38b48ca20fca62b282ae9633d6" and "826634221431a39df0a62e71be3ae4133c162354" have entirely different histories.

26 changed files with 194 additions and 760 deletions

View File

@ -110,7 +110,6 @@ Init alembic
cd bot_microservice cd bot_microservice
alembic revision --autogenerate -m 'create_chatgpt_table' alembic revision --autogenerate -m 'create_chatgpt_table'
alembic upgrade head alembic upgrade head
python3 core/bot/managment/update_gpt_models.py
``` ```
@ -131,7 +130,6 @@ Run migrations
```bash ```bash
cd ./bot_microservice # alembic root cd ./bot_microservice # alembic root
alembic --config ./alembic.ini upgrade head alembic --config ./alembic.ini upgrade head
python3 core/bot/managment/update_gpt_models.py
``` ```
**downgrade to 0001_create_chatgpt_table revision:** **downgrade to 0001_create_chatgpt_table revision:**
@ -168,4 +166,4 @@ alembic downgrade base
- [x] add sentry - [x] add sentry
- [x] add graylog integration and availability to log to file - [x] add graylog integration and availability to log to file
- [x] add user model - [x] add user model
- [x] add messages statistic - [ ] add messages statistic

View File

@ -1 +0,0 @@
BOT_ACCESS_API_HEADER = "BOT-API-KEY"

View File

@ -3,19 +3,13 @@ from starlette import status
from starlette.responses import JSONResponse, Response from starlette.responses import JSONResponse, Response
from telegram import Update from telegram import Update
from api.bot.deps import ( from api.bot.deps import get_bot_queue, get_chatgpt_service, get_update_from_request
get_access_to_bot_api_or_403,
get_bot_queue,
get_chatgpt_service,
get_update_from_request,
)
from api.bot.serializers import ( from api.bot.serializers import (
ChatGptModelSerializer, ChatGptModelSerializer,
ChatGptModelsPrioritySerializer, ChatGptModelsPrioritySerializer,
GETChatGptModelsSerializer, GETChatGptModelsSerializer,
LightChatGptModel, LightChatGptModel,
) )
from api.exceptions import PermissionMissingResponse
from core.bot.app import BotQueue from core.bot.app import BotQueue
from core.bot.services import ChatGptService from core.bot.services import ChatGptService
from settings.config import settings from settings.config import settings
@ -59,10 +53,6 @@ async def models_list(
@router.put( @router.put(
"/chatgpt/models/{model_id}/priority", "/chatgpt/models/{model_id}/priority",
name="bot:change_model_priority", name="bot:change_model_priority",
dependencies=[Depends(get_access_to_bot_api_or_403)],
responses={
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
},
response_class=Response, response_class=Response,
status_code=status.HTTP_202_ACCEPTED, status_code=status.HTTP_202_ACCEPTED,
summary="change gpt model priority", summary="change gpt model priority",
@ -79,10 +69,6 @@ async def change_model_priority(
@router.put( @router.put(
"/chatgpt/models/priority/reset", "/chatgpt/models/priority/reset",
name="bot:reset_models_priority", name="bot:reset_models_priority",
dependencies=[Depends(get_access_to_bot_api_or_403)],
responses={
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
},
response_class=Response, response_class=Response,
status_code=status.HTTP_202_ACCEPTED, status_code=status.HTTP_202_ACCEPTED,
summary="reset all model priority to default", summary="reset all model priority to default",
@ -97,11 +83,6 @@ async def reset_models_priority(
@router.post( @router.post(
"/chatgpt/models", "/chatgpt/models",
name="bot:add_new_model", name="bot:add_new_model",
dependencies=[Depends(get_access_to_bot_api_or_403)],
responses={
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
status.HTTP_201_CREATED: {"model": ChatGptModelSerializer},
},
response_model=ChatGptModelSerializer, response_model=ChatGptModelSerializer,
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
summary="add new model", summary="add new model",
@ -119,10 +100,6 @@ async def add_new_model(
@router.delete( @router.delete(
"/chatgpt/models/{model_id}", "/chatgpt/models/{model_id}",
name="bot:delete_gpt_model", name="bot:delete_gpt_model",
dependencies=[Depends(get_access_to_bot_api_or_403)],
responses={
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
},
response_class=Response, response_class=Response,
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
summary="delete gpt model", summary="delete gpt model",

View File

@ -1,17 +1,15 @@
from fastapi import Depends, Header, HTTPException from fastapi import Depends
from starlette import status
from starlette.requests import Request from starlette.requests import Request
from telegram import Update from telegram import Update
from api.auth.deps import get_user_service from api.auth.deps import get_user_service
from api.bot.constants import BOT_ACCESS_API_HEADER
from api.deps import get_database from api.deps import get_database
from core.auth.services import UserService from core.auth.services import UserService
from core.bot.app import BotApplication, BotQueue from core.bot.app import BotApplication, BotQueue
from core.bot.repository import ChatGPTRepository from core.bot.repository import ChatGPTRepository
from core.bot.services import ChatGptService from core.bot.services import ChatGptService
from infra.database.db_adapter import Database from infra.database.db_adapter import Database
from settings.config import AppSettings, get_settings, settings from settings.config import AppSettings, get_settings
def get_bot_app(request: Request) -> BotApplication: def get_bot_app(request: Request) -> BotApplication:
@ -42,13 +40,3 @@ def get_chatgpt_service(
user_service: UserService = Depends(get_user_service), user_service: UserService = Depends(get_user_service),
) -> ChatGptService: ) -> ChatGptService:
return ChatGptService(repository=chatgpt_repository, user_service=user_service) return ChatGptService(repository=chatgpt_repository, user_service=user_service)
async def get_access_to_bot_api_or_403(
bot_api_key: str | None = Header(None, alias=BOT_ACCESS_API_HEADER, description="Ключ доступа до API бота"),
user_service: UserService = Depends(get_user_service),
) -> None:
access_token = await user_service.get_user_access_token_by_username(settings.SUPERUSER)
if not access_token or access_token != bot_api_key:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Could not validate api header")

View File

@ -1,39 +1,11 @@
from typing import Any
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse
from starlette import status
from starlette.requests import Request from starlette.requests import Request
from api.base_schemas import BaseError, BaseResponse from api.base_schemas import BaseError, BaseResponse
class BaseAPIException(Exception): class BaseAPIException(Exception):
_content_type: str = "application/json" pass
model: type[BaseResponse] = BaseResponse
status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR
title: str | None = None
type: str | None = None
detail: str | None = None
instance: str | None = None
headers: dict[str, str] | None = None
def __init__(self, **ctx: Any) -> None:
self.__dict__ = ctx
@classmethod
def example(cls) -> dict[str, Any] | None:
if isinstance(cls.model.Config.schema_extra, dict): # type: ignore[attr-defined]
return cls.model.Config.schema_extra.get("example") # type: ignore[attr-defined]
return None
@classmethod
def response(cls) -> dict[str, Any]:
return {
"model": cls.model,
"content": {
cls._content_type: cls.model.Config.schema_extra, # type: ignore[attr-defined]
},
}
class InternalServerError(BaseError): class InternalServerError(BaseError):
@ -56,31 +28,6 @@ class InternalServerErrorResponse(BaseResponse):
} }
class PermissionMissing(BaseError):
pass
class PermissionMissingResponse(BaseResponse):
error: PermissionMissing
class Config:
json_schema_extra = {
"example": {
"status": 403,
"error": {
"type": "PermissionMissing",
"title": "Permission required for this endpoint is missing",
},
},
}
class PermissionMissingError(BaseAPIException):
model = PermissionMissingResponse
status_code = status.HTTP_403_FORBIDDEN
title: str = "Permission required for this endpoint is missing"
async def internal_server_error_handler(_request: Request, _exception: Exception) -> ORJSONResponse: async def internal_server_error_handler(_request: Request, _exception: Exception) -> ORJSONResponse:
error = InternalServerError(title="Something went wrong!", type="InternalServerError") error = InternalServerError(title="Something went wrong!", type="InternalServerError")
response = InternalServerErrorResponse(status=500, error=error).model_dump(exclude_unset=True) response = InternalServerErrorResponse(status=500, error=error).model_dump(exclude_unset=True)

View File

@ -1,4 +1,3 @@
import uuid
from datetime import datetime from datetime import datetime
from sqlalchemy import INTEGER, TIMESTAMP, VARCHAR, Boolean, ForeignKey, String from sqlalchemy import INTEGER, TIMESTAMP, VARCHAR, Boolean, ForeignKey, String
@ -19,13 +18,15 @@ class User(Base):
hashed_password: Mapped[str] = mapped_column(String(length=1024), nullable=False) hashed_password: Mapped[str] = mapped_column(String(length=1024), nullable=False)
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
is_superuser: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) is_superuser: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
created_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), nullable=False, default=datetime.now) created_at: Mapped[datetime] = mapped_column(
TIMESTAMP(timezone=True), index=True, nullable=False, default=datetime.now
)
user_question_count: Mapped["UserQuestionCount"] = relationship( user_question_count: Mapped["UserQuestionCount"] = relationship(
"UserQuestionCount", "UserQuestionCount",
primaryjoin="UserQuestionCount.user_id == User.id", primaryjoin="UserQuestionCount.user_id == User.id",
backref="user", backref="user",
lazy="noload", lazy="selectin",
uselist=False, uselist=False,
cascade="delete", cascade="delete",
) )
@ -36,12 +37,6 @@ class User(Base):
return self.user_question_count.question_count return self.user_question_count.question_count
return 0 return 0
@property
def last_question_at(self) -> str | None:
if self.user_question_count:
return self.user_question_count.last_question_at.strftime("%Y-%m-%d %H:%M:%S")
return None
@classmethod @classmethod
def build( def build(
cls, cls,
@ -73,27 +68,14 @@ class AccessToken(Base):
__tablename__ = "access_token" # type: ignore[assignment] __tablename__ = "access_token" # type: ignore[assignment]
user_id = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), nullable=False) user_id = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), nullable=False)
token: Mapped[str] = mapped_column(String(length=42), primary_key=True, default=lambda: str(uuid.uuid4())) token: Mapped[str] = mapped_column(String(length=42), primary_key=True)
created_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), nullable=False, default=datetime.now) created_at: Mapped[datetime] = mapped_column(
TIMESTAMP(timezone=True), index=True, nullable=False, default=datetime.now
user: Mapped["User"] = relationship(
"User",
backref="access_token",
lazy="selectin",
uselist=False,
cascade="expunge",
) )
@property
def username(self) -> str:
if self.user:
return self.user.username
return ""
class UserQuestionCount(Base): class UserQuestionCount(Base):
__tablename__ = "user_question_count" # type: ignore[assignment] __tablename__ = "user_question_count" # type: ignore[assignment]
user_id: Mapped[int] = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), primary_key=True) user_id: Mapped[int] = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), primary_key=True)
question_count: Mapped[int] = mapped_column(INTEGER, default=0, nullable=False) question_count: Mapped[int] = mapped_column(INTEGER, default=0, nullable=False)
last_question_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True), nullable=False, default=datetime.now)

View File

@ -1,11 +1,11 @@
from dataclasses import dataclass from dataclasses import dataclass
from sqlalchemy import func, select from sqlalchemy import select
from sqlalchemy.dialects.sqlite import insert from sqlalchemy.dialects.sqlite import insert
from sqlalchemy.orm import load_only from sqlalchemy.orm import load_only
from core.auth.dto import UserIsBannedDTO from core.auth.dto import UserIsBannedDTO
from core.auth.models.users import AccessToken, User, UserQuestionCount from core.auth.models.users import User, UserQuestionCount
from infra.database.db_adapter import Database from infra.database.db_adapter import Database
@ -69,18 +69,10 @@ class UserRepository:
UserQuestionCount.get_real_column_name( UserQuestionCount.get_real_column_name(
UserQuestionCount.question_count.key UserQuestionCount.question_count.key
): UserQuestionCount.question_count ): UserQuestionCount.question_count
+ 1, + 1
UserQuestionCount.get_real_column_name(UserQuestionCount.last_question_at.key): func.now(),
}, },
) )
) )
async with self.db.session() as session: async with self.db.session() as session:
await session.execute(query) await session.execute(query)
async def get_user_access_token(self, username: str | None) -> str | None:
query = select(AccessToken.token).join(AccessToken.user).where(User.username == username)
async with self.db.session() as session:
result = await session.execute(query)
return result.scalar()

View File

@ -62,9 +62,6 @@ class UserService:
async def check_user_is_banned(self, user_id: int) -> UserIsBannedDTO: async def check_user_is_banned(self, user_id: int) -> UserIsBannedDTO:
return await self.repository.check_user_is_banned(user_id) return await self.repository.check_user_is_banned(user_id)
async def get_user_access_token_by_username(self, username: str | None) -> str | None:
return await self.repository.get_user_access_token(username)
def check_user_is_banned(func: Any) -> Any: def check_user_is_banned(func: Any) -> Any:
@wraps(func) @wraps(func)

View File

@ -1,13 +0,0 @@
import asyncio
import sys
from pathlib import Path
from loguru import logger
if __name__ == "__main__":
sys.path.append(str(Path(__file__).parent.parent.parent.parent))
from core.bot.services import ChatGptService
chatgpt_service = ChatGptService.build()
asyncio.run(chatgpt_service.update_chatgpt_models())
logger.info("chatgpt models has been updated")

View File

@ -38,18 +38,6 @@ class ChatGPTRepository:
async with self.db.session() as session: async with self.db.session() as session:
await session.execute(query) await session.execute(query)
async def delete_all_chatgpt_models(self) -> None:
query = delete(ChatGptModels)
async with self.db.session() as session:
await session.execute(query)
async def bulk_insert_chatgpt_models(self, models_priority: list[dict[str, Any]]) -> None:
models = [ChatGptModels(**model_priority) for model_priority in models_priority]
async with self.db.session() as session:
session.add_all(models)
await session.commit()
async def add_chatgpt_model(self, model: str, priority: int) -> dict[str, str | int]: async def add_chatgpt_model(self, model: str, priority: int) -> dict[str, str | int]:
query = ( query = (
insert(ChatGptModels) insert(ChatGptModels)

View File

@ -14,7 +14,7 @@ from speech_recognition import (
UnknownValueError as SpeechRecognizerError, UnknownValueError as SpeechRecognizerError,
) )
from constants import AUDIO_SEGMENT_DURATION, ChatGptModelsEnum from constants import AUDIO_SEGMENT_DURATION
from core.auth.models.users import User from core.auth.models.users import User
from core.auth.repository import UserRepository from core.auth.repository import UserRepository
from core.auth.services import UserService from core.auth.services import UserService
@ -24,77 +24,6 @@ from infra.database.db_adapter import Database
from settings.config import settings from settings.config import settings
@dataclass
class ChatGptService:
repository: ChatGPTRepository
user_service: UserService
@classmethod
def build(cls) -> "ChatGptService":
db = Database(settings=settings)
repository = ChatGPTRepository(settings=settings, db=db)
user_repository = UserRepository(db=db)
user_service = UserService(repository=user_repository)
return ChatGptService(repository=repository, user_service=user_service)
async def get_chatgpt_models(self) -> Sequence[ChatGptModels]:
return await self.repository.get_chatgpt_models()
async def request_to_chatgpt(self, question: str | None) -> str:
question = question or "Привет!"
chatgpt_model = await self.get_current_chatgpt_model()
return await self.repository.ask_question(question=question, chatgpt_model=chatgpt_model)
async def request_to_chatgpt_microservice(self, question: str) -> Response:
chatgpt_model = await self.get_current_chatgpt_model()
return await self.repository.request_to_chatgpt_microservice(question=question, chatgpt_model=chatgpt_model)
async def get_current_chatgpt_model(self) -> str:
return await self.repository.get_current_chatgpt_model()
async def change_chatgpt_model_priority(self, model_id: int, priority: int) -> None:
return await self.repository.change_chatgpt_model_priority(model_id=model_id, priority=priority)
async def update_chatgpt_models(self) -> None:
await self.repository.delete_all_chatgpt_models()
models = ChatGptModelsEnum.base_models_priority()
await self.repository.bulk_insert_chatgpt_models(models)
async def reset_all_chatgpt_models_priority(self) -> None:
return await self.repository.reset_all_chatgpt_models_priority()
async def add_chatgpt_model(self, gpt_model: str, priority: int) -> dict[str, str | int]:
return await self.repository.add_chatgpt_model(model=gpt_model, priority=priority)
async def delete_chatgpt_model(self, model_id: int) -> None:
return await self.repository.delete_chatgpt_model(model_id=model_id)
async def get_or_create_bot_user(
self,
user_id: int,
email: str | None = None,
username: str | None = None,
first_name: str | None = None,
last_name: str | None = None,
ban_reason: str | None = None,
is_active: bool = True,
is_superuser: bool = False,
) -> User:
return await self.user_service.get_or_create_user_by_id(
user_id=user_id,
email=email,
username=username,
first_name=first_name,
last_name=last_name,
ban_reason=ban_reason,
is_active=is_active,
is_superuser=is_superuser,
)
async def update_bot_user_message_count(self, user_id: int) -> None:
await self.user_service.update_user_message_count(user_id)
class SpeechToTextService: class SpeechToTextService:
def __init__(self, filename: str) -> None: def __init__(self, filename: str) -> None:
self.filename = filename self.filename = filename
@ -158,3 +87,69 @@ class SpeechToTextService:
except SpeechRecognizerError as error: except SpeechRecognizerError as error:
logger.error("error recognizing text with google", error=error) logger.error("error recognizing text with google", error=error)
raise raise
@dataclass
class ChatGptService:
repository: ChatGPTRepository
user_service: UserService
@classmethod
def build(cls) -> "ChatGptService":
db = Database(settings=settings)
repository = ChatGPTRepository(settings=settings, db=db)
user_repository = UserRepository(db=db)
user_service = UserService(repository=user_repository)
return ChatGptService(repository=repository, user_service=user_service)
async def get_chatgpt_models(self) -> Sequence[ChatGptModels]:
return await self.repository.get_chatgpt_models()
async def request_to_chatgpt(self, question: str | None) -> str:
question = question or "Привет!"
chatgpt_model = await self.get_current_chatgpt_model()
return await self.repository.ask_question(question=question, chatgpt_model=chatgpt_model)
async def request_to_chatgpt_microservice(self, question: str) -> Response:
chatgpt_model = await self.get_current_chatgpt_model()
return await self.repository.request_to_chatgpt_microservice(question=question, chatgpt_model=chatgpt_model)
async def get_current_chatgpt_model(self) -> str:
return await self.repository.get_current_chatgpt_model()
async def change_chatgpt_model_priority(self, model_id: int, priority: int) -> None:
return await self.repository.change_chatgpt_model_priority(model_id=model_id, priority=priority)
async def reset_all_chatgpt_models_priority(self) -> None:
return await self.repository.reset_all_chatgpt_models_priority()
async def add_chatgpt_model(self, gpt_model: str, priority: int) -> dict[str, str | int]:
return await self.repository.add_chatgpt_model(model=gpt_model, priority=priority)
async def delete_chatgpt_model(self, model_id: int) -> None:
return await self.repository.delete_chatgpt_model(model_id=model_id)
async def get_or_create_bot_user(
self,
user_id: int,
email: str | None = None,
username: str | None = None,
first_name: str | None = None,
last_name: str | None = None,
ban_reason: str | None = None,
is_active: bool = True,
is_superuser: bool = False,
) -> User:
return await self.user_service.get_or_create_user_by_id(
user_id=user_id,
email=email,
username=username,
first_name=first_name,
last_name=last_name,
ban_reason=ban_reason,
is_active=is_active,
is_superuser=is_superuser,
)
async def update_bot_user_message_count(self, user_id: int) -> None:
await self.user_service.update_user_message_count(user_id)

View File

@ -1,11 +1,8 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from sqladmin import Admin, ModelView from sqladmin import Admin, ModelView
from sqlalchemy import Select, desc, select
from sqlalchemy.orm import contains_eager, load_only
from starlette.requests import Request
from core.auth.models.users import AccessToken, User, UserQuestionCount from core.auth.models.users import User
from core.bot.models.chatgpt import ChatGptModels from core.bot.models.chatgpt import ChatGptModels
from core.utils import build_uri from core.utils import build_uri
from settings.config import settings from settings.config import settings
@ -37,44 +34,12 @@ class UserAdmin(ModelView, model=User):
User.is_active, User.is_active,
User.ban_reason, User.ban_reason,
"question_count", "question_count",
"last_question_at",
User.created_at, User.created_at,
] ]
column_sortable_list = [User.created_at]
column_default_sort = ("created_at", True) column_default_sort = ("created_at", True)
form_widget_args = {"created_at": {"readonly": True}} form_widget_args = {"created_at": {"readonly": True}}
def list_query(self, request: Request) -> Select[tuple[User]]:
return (
select(User)
.options(
load_only(
User.id,
User.username,
User.first_name,
User.last_name,
User.is_active,
User.created_at,
)
)
.outerjoin(User.user_question_count)
.options(
contains_eager(User.user_question_count).options(
load_only(
UserQuestionCount.question_count,
UserQuestionCount.last_question_at,
)
)
)
).order_by(desc(UserQuestionCount.question_count))
class AccessTokenAdmin(ModelView, model=AccessToken):
name = "API access token"
name_plural = "API access tokens"
column_list = [AccessToken.user_id, "username", AccessToken.token, AccessToken.created_at]
form_widget_args = {"created_at": {"readonly": True}}
def create_admin(application: "Application") -> Admin: def create_admin(application: "Application") -> Admin:
admin = Admin( admin = Admin(
@ -86,5 +51,4 @@ def create_admin(application: "Application") -> Admin:
) )
admin.add_view(ChatGptAdmin) admin.add_view(ChatGptAdmin)
admin.add_view(UserAdmin) admin.add_view(UserAdmin)
admin.add_view(AccessTokenAdmin)
return admin return admin

View File

@ -10,8 +10,9 @@ from datetime import datetime
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op from alembic import op
from sqlalchemy import TIMESTAMP from sqlalchemy import TIMESTAMP
from sqlalchemy.dialects.sqlite import insert
from core.auth.models.users import AccessToken, User from core.auth.models.users import User
from core.auth.utils import create_password_hash from core.auth.utils import create_password_hash
from infra.database.deps import get_sync_session from infra.database.deps import get_sync_session
from settings.config import settings from settings.config import settings
@ -57,14 +58,8 @@ def upgrade() -> None:
return return
with get_sync_session() as session: with get_sync_session() as session:
hashed_password = create_password_hash(password.get_secret_value()) hashed_password = create_password_hash(password.get_secret_value())
user = User(username=username, hashed_password=hashed_password) query = insert(User).values({"username": username, "hashed_password": hashed_password})
session.add(user) session.execute(query)
session.flush()
session.refresh(user)
access_token = AccessToken(user_id=user.id)
session.add(access_token)
session.commit() session.commit()

View File

@ -0,0 +1,43 @@
"""create chatgpt models
Revision ID: 0004_add_chatgpt_models
Revises: 0003_create_user_question_count_table
Create Date: 2025-10-05 20:44:05.414977
"""
from loguru import logger
from sqlalchemy import select, text
from constants import ChatGptModelsEnum
from core.bot.models.chatgpt import ChatGptModels
from infra.database.deps import get_sync_session
# revision identifiers, used by Alembic.
revision = "0004_add_chatgpt_models"
down_revision = "0003_create_user_question_count_table"
branch_labels: str | None = None
depends_on: str | None = None
def upgrade() -> None:
with get_sync_session() as session:
query = select(ChatGptModels)
results = session.execute(query)
models = results.scalars().all()
if models:
return
models = []
for data in ChatGptModelsEnum.base_models_priority():
models.append(ChatGptModels(**data))
session.add_all(models)
session.commit()
def downgrade() -> None:
chatgpt_table_name = ChatGptModels.__tablename__
with get_sync_session() as session:
# Truncate doesn't exists for SQLite
session.execute(text(f"""DELETE FROM {chatgpt_table_name}""")) # noqa: S608
session.commit()
logger.info("chatgpt models table has been truncated", table=chatgpt_table_name)

View File

@ -1,29 +0,0 @@
"""add_last_question_at
Revision ID: 0004_add_last_question_at
Revises: 0003_create_user_question_count_table
Create Date: 2024-01-08 20:56:34.815976
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "0004_add_last_question_at"
down_revision = "0003_create_user_question_count_table"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_access_token_created_at", table_name="access_token")
op.add_column("user_question_count", sa.Column("last_question_at", sa.TIMESTAMP(timezone=True), nullable=False))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("user_question_count", "last_question_at")
op.create_index("ix_access_token_created_at", "access_token", ["created_at"], unique=False)
# ### end Alembic commands ###

View File

@ -22,7 +22,7 @@ class Application:
self.app = FastAPI( self.app = FastAPI(
title="Chat gpt bot", title="Chat gpt bot",
description="Bot for proxy to chat gpt in telegram", description="Bot for proxy to chat gpt in telegram",
version="1.1.12", version="0.0.3",
docs_url=build_uri([settings.api_prefix, "docs"]), docs_url=build_uri([settings.api_prefix, "docs"]),
redoc_url=build_uri([settings.api_prefix, "redocs"]), redoc_url=build_uri([settings.api_prefix, "redocs"]),
openapi_url=build_uri([settings.api_prefix, "openapi.json"]), openapi_url=build_uri([settings.api_prefix, "openapi.json"]),

View File

@ -11,8 +11,6 @@ TELEGRAM_API_TOKEN="123456789:AABBCCDDEEFFaabbccddeeff-1234567890"
# set to true to start with webhook. Else bot will start on polling method # set to true to start with webhook. Else bot will start on polling method
START_WITH_WEBHOOK="false" START_WITH_WEBHOOK="false"
SUPERUSER="Superuser"
# ==== domain settings ==== # ==== domain settings ====
DOMAIN="http://localhost" DOMAIN="http://localhost"
URL_PREFIX= URL_PREFIX=

View File

@ -6,9 +6,7 @@ from sqlalchemy import desc
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.bot.models.chatgpt import ChatGptModels from core.bot.models.chatgpt import ChatGptModels
from settings.config import AppSettings
from tests.integration.factories.bot import ChatGptModelFactory from tests.integration.factories.bot import ChatGptModelFactory
from tests.integration.factories.user import AccessTokenFactory, UserFactory
pytestmark = [ pytestmark = [
pytest.mark.asyncio, pytest.mark.asyncio,
@ -53,18 +51,11 @@ async def test_change_chatgpt_model_priority(
dbsession: Session, dbsession: Session,
rest_client: AsyncClient, rest_client: AsyncClient,
faker: Faker, faker: Faker,
test_settings: AppSettings,
) -> None: ) -> None:
model1 = ChatGptModelFactory(priority=0) model1 = ChatGptModelFactory(priority=0)
model2 = ChatGptModelFactory(priority=1) model2 = ChatGptModelFactory(priority=1)
priority = faker.random_int(min=2, max=7) priority = faker.random_int(min=2, max=7)
user = UserFactory(username=test_settings.SUPERUSER) response = await rest_client.put(url=f"/api/chatgpt/models/{model2.id}/priority", json={"priority": priority})
access_token = AccessTokenFactory(user_id=user.id)
response = await rest_client.put(
url=f"/api/chatgpt/models/{model2.id}/priority",
json={"priority": priority},
headers={"BOT-API-KEY": access_token.token},
)
assert response.status_code == 202 assert response.status_code == 202
upd_model1, upd_model2 = dbsession.query(ChatGptModels).order_by(ChatGptModels.priority).all() upd_model1, upd_model2 = dbsession.query(ChatGptModels).order_by(ChatGptModels.priority).all()
@ -75,53 +66,14 @@ async def test_change_chatgpt_model_priority(
assert upd_model2.priority == priority assert upd_model2.priority == priority
@pytest.mark.parametrize(
"bot_api_key_header",
[
pytest.param({"BOT-API-KEY": ""}, id="empty key in header"),
pytest.param({"BOT-API-KEY": "Hello World"}, id="incorrect token"),
pytest.param({"Hello-World": "lasdkj3-wer-weqwe_34"}, id="correct token but wrong api key"),
pytest.param({}, id="no api key header"),
],
)
async def test_cant_change_chatgpt_model_priority_with_wrong_api_key(
dbsession: Session,
rest_client: AsyncClient,
faker: Faker,
test_settings: AppSettings,
bot_api_key_header: dict[str, str],
) -> None:
ChatGptModelFactory(priority=0)
model2 = ChatGptModelFactory(priority=1)
priority = faker.random_int(min=2, max=7)
user = UserFactory(username=test_settings.SUPERUSER)
AccessTokenFactory(user_id=user.id, token="lasdkj3-wer-weqwe_34") # noqa: S106
response = await rest_client.put(
url=f"/api/chatgpt/models/{model2.id}/priority",
json={"priority": priority},
headers=bot_api_key_header,
)
assert response.status_code == 403
changed_model = dbsession.query(ChatGptModels).filter_by(id=model2.id).one()
assert changed_model.priority == model2.priority
async def test_reset_chatgpt_models_priority( async def test_reset_chatgpt_models_priority(
dbsession: Session, dbsession: Session,
rest_client: AsyncClient, rest_client: AsyncClient,
test_settings: AppSettings,
) -> None: ) -> None:
ChatGptModelFactory.create_batch(size=4) ChatGptModelFactory.create_batch(size=4)
ChatGptModelFactory(priority=42) ChatGptModelFactory(priority=42)
user = UserFactory(username=test_settings.SUPERUSER) response = await rest_client.put(url="/api/chatgpt/models/priority/reset")
access_token = AccessTokenFactory(user_id=user.id)
response = await rest_client.put(
url="/api/chatgpt/models/priority/reset",
headers={"BOT-API-KEY": access_token.token},
)
assert response.status_code == 202 assert response.status_code == 202
models = dbsession.query(ChatGptModels).all() models = dbsession.query(ChatGptModels).all()
@ -133,49 +85,14 @@ async def test_reset_chatgpt_models_priority(
assert model.priority == 0 assert model.priority == 0
@pytest.mark.parametrize(
"bot_api_key_header",
[
pytest.param({"BOT-API-KEY": ""}, id="empty key in header"),
pytest.param({"BOT-API-KEY": "Hello World"}, id="incorrect token"),
],
)
async def test_cant_reset_chatgpt_models_priority_with_wrong_api_key(
dbsession: Session, rest_client: AsyncClient, test_settings: AppSettings, bot_api_key_header: dict[str, str]
) -> None:
chat_gpt_models = ChatGptModelFactory.create_batch(size=2)
model_with_highest_priority = ChatGptModelFactory(priority=42)
priorities = sorted([model.priority for model in chat_gpt_models] + [model_with_highest_priority.priority])
user = UserFactory(username=test_settings.SUPERUSER)
AccessTokenFactory(user_id=user.id)
response = await rest_client.put(
url="/api/chatgpt/models/priority/reset",
headers=bot_api_key_header,
)
assert response.status_code == 403
models = dbsession.query(ChatGptModels).all()
changed_priorities = sorted([model.priority for model in models])
assert changed_priorities == priorities
async def test_create_new_chatgpt_model( async def test_create_new_chatgpt_model(
dbsession: Session, dbsession: Session,
rest_client: AsyncClient, rest_client: AsyncClient,
faker: Faker, faker: Faker,
test_settings: AppSettings,
) -> None: ) -> None:
ChatGptModelFactory.create_batch(size=2) ChatGptModelFactory.create_batch(size=2)
ChatGptModelFactory(priority=42) ChatGptModelFactory(priority=42)
user = UserFactory(username=test_settings.SUPERUSER)
access_token = AccessTokenFactory(user_id=user.id)
model_name = "new-gpt-model" model_name = "new-gpt-model"
model_priority = faker.random_int(min=1, max=5) model_priority = faker.random_int(min=1, max=5)
@ -188,7 +105,6 @@ async def test_create_new_chatgpt_model(
"model": model_name, "model": model_name,
"priority": model_priority, "priority": model_priority,
}, },
headers={"BOT-API-KEY": access_token.token},
) )
assert response.status_code == 201 assert response.status_code == 201
@ -205,47 +121,13 @@ async def test_create_new_chatgpt_model(
} }
async def test_cant_create_new_chatgpt_model_with_wrong_api_key(
dbsession: Session,
rest_client: AsyncClient,
faker: Faker,
test_settings: AppSettings,
) -> None:
ChatGptModelFactory.create_batch(size=2)
ChatGptModelFactory(priority=42)
user = UserFactory(username=test_settings.SUPERUSER)
AccessTokenFactory(user_id=user.id)
model_name = "new-gpt-model"
model_priority = faker.random_int(min=1, max=5)
models = dbsession.query(ChatGptModels).all()
assert len(models) == 3
response = await rest_client.post(
url="/api/chatgpt/models",
json={
"model": model_name,
"priority": model_priority,
},
)
assert response.status_code == 403
models = dbsession.query(ChatGptModels).all()
assert len(models) == 3
async def test_add_existing_chatgpt_model( async def test_add_existing_chatgpt_model(
dbsession: Session, dbsession: Session,
rest_client: AsyncClient, rest_client: AsyncClient,
faker: Faker, faker: Faker,
test_settings: AppSettings,
) -> None: ) -> None:
ChatGptModelFactory.create_batch(size=2) ChatGptModelFactory.create_batch(size=2)
model = ChatGptModelFactory(priority=42) model = ChatGptModelFactory(priority=42)
user = UserFactory(username=test_settings.SUPERUSER)
access_token = AccessTokenFactory(user_id=user.id)
model_name = model.model model_name = model.model
model_priority = faker.random_int(min=1, max=5) model_priority = faker.random_int(min=1, max=5)
@ -259,7 +141,6 @@ async def test_add_existing_chatgpt_model(
"model": model_name, "model": model_name,
"priority": model_priority, "priority": model_priority,
}, },
headers={"BOT-API-KEY": access_token.token},
) )
assert response.status_code == 201 assert response.status_code == 201
@ -270,48 +151,17 @@ async def test_add_existing_chatgpt_model(
async def test_delete_chatgpt_model( async def test_delete_chatgpt_model(
dbsession: Session, dbsession: Session,
rest_client: AsyncClient, rest_client: AsyncClient,
test_settings: AppSettings,
) -> None: ) -> None:
ChatGptModelFactory.create_batch(size=2) ChatGptModelFactory.create_batch(size=2)
model = ChatGptModelFactory(priority=42) model = ChatGptModelFactory(priority=42)
user = UserFactory(username=test_settings.SUPERUSER)
access_token = AccessTokenFactory(user_id=user.id)
models = dbsession.query(ChatGptModels).all() models = dbsession.query(ChatGptModels).all()
assert len(models) == 3 assert len(models) == 3
response = await rest_client.delete( response = await rest_client.delete(url=f"/api/chatgpt/models/{model.id}")
url=f"/api/chatgpt/models/{model.id}",
headers={"BOT-API-KEY": access_token.token},
)
assert response.status_code == 204 assert response.status_code == 204
models = dbsession.query(ChatGptModels).all() models = dbsession.query(ChatGptModels).all()
assert len(models) == 2 assert len(models) == 2
assert model not in models assert model not in models
async def test_cant_delete_chatgpt_model_with_wrong_api_key(
dbsession: Session,
rest_client: AsyncClient,
test_settings: AppSettings,
) -> None:
ChatGptModelFactory.create_batch(size=2)
model = ChatGptModelFactory(priority=42)
user = UserFactory(username=test_settings.SUPERUSER)
access_token = AccessTokenFactory(user_id=user.id)
models = dbsession.query(ChatGptModels).all()
assert len(models) == 3
response = await rest_client.delete(
url=f"/api/chatgpt/models/{model.id}",
headers={"ROOT-ACCESS": access_token.token},
)
assert response.status_code == 403
models = dbsession.query(ChatGptModels).all()
assert len(models) == 3

View File

@ -1,5 +1,4 @@
import asyncio import asyncio
import datetime
from unittest import mock from unittest import mock
import httpx import httpx
@ -12,7 +11,6 @@ from sqlalchemy.orm import Session
from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Update from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Update
from constants import BotStagesEnum from constants import BotStagesEnum
from core.auth.models.users import User, UserQuestionCount
from core.bot.app import BotApplication, BotQueue from core.bot.app import BotApplication, BotQueue
from main import Application from main import Application
from settings.config import AppSettings from settings.config import AppSettings
@ -24,7 +22,6 @@ from tests.integration.factories.bot import (
CallBackFactory, CallBackFactory,
ChatGptModelFactory, ChatGptModelFactory,
) )
from tests.integration.factories.user import UserFactory, UserQuestionCountFactory
from tests.integration.utils import mocked_ask_question_api from tests.integration.utils import mocked_ask_question_api
pytestmark = [ pytestmark = [
@ -116,33 +113,6 @@ async def test_help_command(
) )
async def test_help_command_user_is_banned(
main_application: Application,
test_settings: AppSettings,
) -> None:
message = BotMessageFactory.create_instance(text="/help")
user = message["from"]
UserFactory(
id=user["id"],
first_name=user["first_name"],
last_name=user["last_name"],
username=user["username"],
is_active=False,
ban_reason="test reason",
)
with mock.patch.object(
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
) as mocked_send_message:
bot_update = BotUpdateFactory(message=message)
await main_application.bot_app.application.process_update(
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
)
assert mocked_send_message.call_args.kwargs["text"] == (
"You have banned for reason: *test reason*.\nPlease contact the /developer"
)
async def test_start_entry( async def test_start_entry(
main_application: Application, main_application: Application,
test_settings: AppSettings, test_settings: AppSettings,
@ -262,35 +232,6 @@ async def test_website_callback_action(
assert mocked_reply_text.call_args.args == ("Веб версия: http://localhost/chat/",) assert mocked_reply_text.call_args.args == ("Веб версия: http://localhost/chat/",)
async def test_website_callback_action_user_is_banned(
main_application: Application,
test_settings: AppSettings,
) -> None:
message = BotMessageFactory.create_instance(text="Список основных команд:")
user = message["from"]
UserFactory(
id=user["id"],
first_name=user["first_name"],
last_name=user["last_name"],
username=user["username"],
is_active=False,
ban_reason="test reason",
)
with mock.patch.object(telegram._message.Message, "reply_text") as mocked_reply_text:
bot_update = BotCallBackQueryFactory(
message=message,
callback_query=CallBackFactory(data=BotStagesEnum.website),
)
await main_application.bot_app.application.process_update(
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
)
assert mocked_reply_text.call_args.kwargs["text"] == (
"You have banned for reason: *test reason*.\nPlease contact the /developer"
)
async def test_bug_report_action( async def test_bug_report_action(
main_application: Application, main_application: Application,
test_settings: AppSettings, test_settings: AppSettings,
@ -320,35 +261,6 @@ async def test_bug_report_action(
) )
async def test_bug_report_action_user_is_banned(
main_application: Application,
test_settings: AppSettings,
) -> None:
message = BotMessageFactory.create_instance(text="/bug_report")
user = message["from"]
UserFactory(
id=user["id"],
first_name=user["first_name"],
last_name=user["last_name"],
username=user["username"],
is_active=False,
ban_reason="test reason",
)
with (
mock.patch.object(telegram._message.Message, "reply_text") as mocked_reply_text,
mock.patch.object(telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)),
):
bot_update = BotUpdateFactory(message=message)
await main_application.bot_app.application.process_update(
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
)
assert mocked_reply_text.call_args.kwargs["text"] == (
"You have banned for reason: *test reason*.\nPlease contact the /developer"
)
async def test_get_developer_action( async def test_get_developer_action(
main_application: Application, main_application: Application,
test_settings: AppSettings, test_settings: AppSettings,
@ -366,28 +278,19 @@ async def test_get_developer_action(
assert mocked_reply_text.call_args.args == ("Автор бота: *Дмитрий Афанасьев*\n\nTg nickname: *Balshtg*",) assert mocked_reply_text.call_args.args == ("Автор бота: *Дмитрий Афанасьев*\n\nTg nickname: *Balshtg*",)
async def test_ask_question_action_bot_user_not_exists( async def test_ask_question_action(
dbsession: Session, dbsession: Session,
main_application: Application, main_application: Application,
test_settings: AppSettings, test_settings: AppSettings,
) -> None: ) -> None:
ChatGptModelFactory.create_batch(size=3) ChatGptModelFactory.create_batch(size=3)
users = dbsession.query(User).all()
users_question_count = dbsession.query(UserQuestionCount).all()
assert len(users) == 0
assert len(users_question_count) == 0
message = BotMessageFactory.create_instance(text="Привет!")
user = message["from"]
with mock.patch.object( with mock.patch.object(
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs) telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
) as mocked_send_message, mocked_ask_question_api( ) as mocked_send_message, mocked_ask_question_api(
host=test_settings.GPT_BASE_HOST, host=test_settings.GPT_BASE_HOST,
return_value=Response(status_code=httpx.codes.OK, text="Привет! Как я могу помочь вам сегодня?"), return_value=Response(status_code=httpx.codes.OK, text="Привет! Как я могу помочь вам сегодня?"),
): ):
bot_update = BotUpdateFactory(message=message) bot_update = BotUpdateFactory(message=BotMessageFactory.create_instance(text="Привет!"))
bot_update["message"].pop("entities") bot_update["message"].pop("entities")
await main_application.bot_app.application.process_update( await main_application.bot_app.application.process_update(
@ -410,120 +313,6 @@ async def test_ask_question_action_bot_user_not_exists(
include=["text", "chat_id"], include=["text", "chat_id"],
) )
created_user = dbsession.query(User).filter_by(id=user["id"]).one()
assert created_user.username == user["username"]
created_user_question_count = dbsession.query(UserQuestionCount).filter_by(user_id=user["id"]).one()
assert created_user_question_count.question_count == 1
assert created_user_question_count.last_question_at - datetime.datetime.now() < datetime.timedelta( # noqa: DTZ005
seconds=2
)
async def test_ask_question_action_bot_user_already_exists(
dbsession: Session,
main_application: Application,
test_settings: AppSettings,
) -> None:
ChatGptModelFactory.create_batch(size=3)
message = BotMessageFactory.create_instance(text="Привет!")
user = message["from"]
existing_user = UserFactory(
id=user["id"], first_name=user["first_name"], last_name=user["last_name"], username=user["username"]
)
existing_user_question_count = UserQuestionCountFactory(user_id=existing_user.id).question_count
users = dbsession.query(User).all()
assert len(users) == 1
with mock.patch.object(
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
) as mocked_send_message, mocked_ask_question_api(
host=test_settings.GPT_BASE_HOST,
return_value=Response(status_code=httpx.codes.OK, text="Привет! Как я могу помочь вам сегодня?"),
):
bot_update = BotUpdateFactory(message=message)
bot_update["message"].pop("entities")
await main_application.bot_app.application.process_update(
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
)
assert_that(mocked_send_message.call_args_list[0].kwargs).is_equal_to(
{
"text": (
"Ответ в среднем занимает 10-15 секунд.\n- Список команд: /help\n- Сообщить об ошибке: /bug_report"
),
"chat_id": bot_update["message"]["chat"]["id"],
},
include=["text", "chat_id"],
)
assert_that(mocked_send_message.call_args_list[1].kwargs).is_equal_to(
{
"text": "Привет! Как я могу помочь вам сегодня?",
"chat_id": bot_update["message"]["chat"]["id"],
},
include=["text", "chat_id"],
)
users = dbsession.query(User).all()
assert len(users) == 1
updated_user_question_count = dbsession.query(UserQuestionCount).filter_by(user_id=user["id"]).one()
assert updated_user_question_count.question_count == existing_user_question_count + 1
assert updated_user_question_count.last_question_at - datetime.datetime.now() < datetime.timedelta( # noqa: DTZ005
seconds=2
)
async def test_ask_question_action_user_is_banned(
dbsession: Session,
main_application: Application,
test_settings: AppSettings,
) -> None:
ChatGptModelFactory.create_batch(size=3)
users_question_count = dbsession.query(UserQuestionCount).all()
assert len(users_question_count) == 0
message = BotMessageFactory.create_instance(text="Привет!")
user = message["from"]
UserFactory(
id=user["id"],
first_name=user["first_name"],
last_name=user["last_name"],
username=user["username"],
is_active=False,
ban_reason="test reason",
)
with mock.patch.object(
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
) as mocked_send_message, mocked_ask_question_api(
host=test_settings.GPT_BASE_HOST,
return_value=Response(status_code=httpx.codes.OK, text="Привет! Как я могу помочь вам сегодня?"),
assert_all_called=False,
):
bot_update = BotUpdateFactory(message=message)
bot_update["message"].pop("entities")
await main_application.bot_app.application.process_update(
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
)
assert_that(mocked_send_message.call_args_list[0].kwargs).is_equal_to(
{
"text": ("You have banned for reason: *test reason*.\nPlease contact the /developer"),
"chat_id": bot_update["message"]["chat"]["id"],
},
include=["text", "chat_id"],
)
created_user = dbsession.query(User).filter_by(id=user["id"]).one()
assert created_user.username == user["username"]
assert created_user.is_active is False
created_user_question_count = dbsession.query(UserQuestionCount).filter_by(user_id=user["id"]).scalar()
assert created_user_question_count is None
async def test_ask_question_action_not_success( async def test_ask_question_action_not_success(
dbsession: Session, dbsession: Session,

View File

@ -1,30 +0,0 @@
import pytest
from sqlalchemy.orm import Session
from constants import ChatGptModelsEnum
from core.bot.models.chatgpt import ChatGptModels
from core.bot.services import ChatGptService
pytestmark = [
pytest.mark.asyncio,
pytest.mark.enable_socket,
]
async def test_models_update(dbsession: Session) -> None:
models = dbsession.query(ChatGptModels).all()
assert len(models) == 0
chatgpt_service = ChatGptService.build()
await chatgpt_service.update_chatgpt_models()
models = dbsession.query(ChatGptModels).all()
model_priorities = {model.model: model.priority for model in models}
assert len(models) == len(ChatGptModelsEnum.base_models_priority())
for model_priority in ChatGptModelsEnum.base_models_priority():
assert model_priorities[model_priority["model"]] == model_priority["priority"]

View File

@ -1,15 +1,13 @@
import uuid
import factory import factory
from core.auth.models.users import AccessToken, User, UserQuestionCount from core.auth.models.users import User
from tests.integration.factories.utils import BaseModelFactory from tests.integration.factories.utils import BaseModelFactory
class UserFactory(BaseModelFactory): class UserFactory(BaseModelFactory):
id = factory.Sequence(lambda n: n + 1) id = factory.Sequence(lambda n: n + 1)
email = factory.Faker("email") email = factory.Faker("email")
username = factory.Faker("user_name", locale="en") username = factory.Faker("user_name", locale="en_EN")
first_name = factory.Faker("word") first_name = factory.Faker("word")
last_name = factory.Faker("word") last_name = factory.Faker("word")
ban_reason = factory.Faker("text", max_nb_chars=100) ban_reason = factory.Faker("text", max_nb_chars=100)
@ -20,21 +18,3 @@ class UserFactory(BaseModelFactory):
class Meta: class Meta:
model = User model = User
class AccessTokenFactory(BaseModelFactory):
user_id = factory.Sequence(lambda n: n + 1)
token = factory.LazyAttribute(lambda o: str(uuid.uuid4()))
created_at = factory.Faker("past_datetime")
class Meta:
model = AccessToken
class UserQuestionCountFactory(BaseModelFactory):
user_id = factory.Sequence(lambda n: n + 1)
question_count = factory.Faker("random_int")
last_question_at = factory.Faker("past_datetime")
class Meta:
model = UserQuestionCount

View File

@ -9,11 +9,11 @@ from settings.config import settings
@contextmanager @contextmanager
def mocked_ask_question_api( def mocked_ask_question_api(
host: str, return_value: Response | None = None, side_effect: Any | None = None, assert_all_called: bool = True host: str, return_value: Response | None = None, side_effect: Any | None = None
) -> Iterator[respx.MockRouter]: ) -> Iterator[respx.MockRouter]:
with respx.mock( with respx.mock(
assert_all_mocked=True, assert_all_mocked=True,
assert_all_called=assert_all_called, assert_all_called=True,
base_url=host, base_url=host,
) as respx_mock: ) as respx_mock:
ask_question_route = respx_mock.post(url=settings.chatgpt_backend_url, name="ask_question") ask_question_route = respx_mock.post(url=settings.chatgpt_backend_url, name="ask_question")

69
poetry.lock generated
View File

@ -607,25 +607,25 @@ toml = ["tomli"]
[[package]] [[package]]
name = "cyclonedx-python-lib" name = "cyclonedx-python-lib"
version = "6.3.0" version = "5.2.0"
description = "Python library for CycloneDX" description = "Python library for CycloneDX"
optional = false optional = false
python-versions = ">=3.8,<4.0" python-versions = ">=3.8,<4.0"
files = [ files = [
{file = "cyclonedx_python_lib-6.3.0-py3-none-any.whl", hash = "sha256:0e73c1036c2f7fc67adc28aef807e6b44340ea70202aab197fb06b20ea165de8"}, {file = "cyclonedx_python_lib-5.2.0-py3-none-any.whl", hash = "sha256:1b43065205cdc53490c825fcfbda73142b758aa40ca169c968e342e88d06734f"},
{file = "cyclonedx_python_lib-6.3.0.tar.gz", hash = "sha256:82f2489de3c0cadad5af1ad7fa6b6a185f985746370245d38769699c734533c6"}, {file = "cyclonedx_python_lib-5.2.0.tar.gz", hash = "sha256:b9ebf2c0520721d2f8ee16aadc2bbb9d4e015862c84ab1691a49b177f3014d99"},
] ]
[package.dependencies] [package.dependencies]
license-expression = ">=30,<31" license-expression = ">=30,<31"
packageurl-python = ">=0.11" packageurl-python = ">=0.11"
py-serializable = ">=0.16,<0.18" py-serializable = ">=0.15,<0.16"
sortedcontainers = ">=2.4.0,<3.0.0" sortedcontainers = ">=2.4.0,<3.0.0"
[package.extras] [package.extras]
json-validation = ["jsonschema[format] (>=4.18,<5.0)"] json-validation = ["jsonschema[format] (>=4.18,<5.0)"]
validation = ["jsonschema[format] (>=4.18,<5.0)", "lxml (>=4,<6)"] validation = ["jsonschema[format] (>=4.18,<5.0)", "lxml (>=4,<5)"]
xml-validation = ["lxml (>=4,<6)"] xml-validation = ["lxml (>=4,<5)"]
[[package]] [[package]]
name = "decorator" name = "decorator"
@ -761,19 +761,19 @@ typing = ["typing-extensions (>=4.8)"]
[[package]] [[package]]
name = "flake8" name = "flake8"
version = "7.0.0" version = "6.1.0"
description = "the modular source code checker: pep8 pyflakes and co" description = "the modular source code checker: pep8 pyflakes and co"
optional = false optional = false
python-versions = ">=3.8.1" python-versions = ">=3.8.1"
files = [ files = [
{file = "flake8-7.0.0-py2.py3-none-any.whl", hash = "sha256:a6dfbb75e03252917f2473ea9653f7cd799c3064e54d4c8140044c5c065f53c3"}, {file = "flake8-6.1.0-py2.py3-none-any.whl", hash = "sha256:ffdfce58ea94c6580c77888a86506937f9a1a227dfcd15f245d694ae20a6b6e5"},
{file = "flake8-7.0.0.tar.gz", hash = "sha256:33f96621059e65eec474169085dc92bf26e7b2d47366b70be2f67ab80dc25132"}, {file = "flake8-6.1.0.tar.gz", hash = "sha256:d5b3857f07c030bdb5bf41c7f53799571d75c4491748a3adcd47de929e34cd23"},
] ]
[package.dependencies] [package.dependencies]
mccabe = ">=0.7.0,<0.8.0" mccabe = ">=0.7.0,<0.8.0"
pycodestyle = ">=2.11.0,<2.12.0" pycodestyle = ">=2.11.0,<2.12.0"
pyflakes = ">=3.2.0,<3.3.0" pyflakes = ">=3.1.0,<3.2.0"
[[package]] [[package]]
name = "flake8-aaa" name = "flake8-aaa"
@ -954,6 +954,25 @@ files = [
[package.dependencies] [package.dependencies]
flake8 = ">3.0.0" flake8 = ">3.0.0"
[[package]]
name = "flake8-noqa"
version = "1.3.2"
description = "Flake8 noqa comment validation"
optional = false
python-versions = ">=3.7"
files = [
{file = "flake8-noqa-1.3.2.tar.gz", hash = "sha256:b12ddf7b02dedabaca0f807cb436ea7992cc0106cb6fa41e997ad45a0a3bf754"},
{file = "flake8_noqa-1.3.2-py3-none-any.whl", hash = "sha256:a2c139c4cc223f268fb262cd32a46fa72f509225d038058baa87c0ff8ac4d348"},
]
[package.dependencies]
flake8 = ">=3.8.0,<7.0"
typing-extensions = ">=3.7.4.2"
[package.extras]
dev = ["flake8 (>=3.8.0,<6.0.0)", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake8-commas", "flake8-comprehensions", "flake8-continuation", "flake8-datetimez", "flake8-docstrings", "flake8-import-order", "flake8-literal", "flake8-modern-annotations", "flake8-noqa", "flake8-polyfill", "flake8-pyproject", "flake8-requirements", "flake8-typechecking-import", "flake8-use-fstring", "mypy", "pep8-naming"]
test = ["flake8-docstrings"]
[[package]] [[package]]
name = "flake8-plugin-utils" name = "flake8-plugin-utils"
version = "1.3.3" version = "1.3.3"
@ -1391,13 +1410,13 @@ files = [
[[package]] [[package]]
name = "ipython" name = "ipython"
version = "8.20.0" version = "8.19.0"
description = "IPython: Productive Interactive Computing" description = "IPython: Productive Interactive Computing"
optional = false optional = false
python-versions = ">=3.10" python-versions = ">=3.10"
files = [ files = [
{file = "ipython-8.20.0-py3-none-any.whl", hash = "sha256:bc9716aad6f29f36c449e30821c9dd0c1c1a7b59ddcc26931685b87b4c569619"}, {file = "ipython-8.19.0-py3-none-any.whl", hash = "sha256:2f55d59370f59d0d2b2212109fe0e6035cfea436b1c0e6150ad2244746272ec5"},
{file = "ipython-8.20.0.tar.gz", hash = "sha256:2f21bd3fc1d51550c89ee3944ae04bbc7bc79e129ea0937da6e6c68bfdbf117a"}, {file = "ipython-8.19.0.tar.gz", hash = "sha256:ac4da4ecf0042fb4e0ce57c60430c2db3c719fa8bdf92f8631d6bd8a5785d1f0"},
] ]
[package.dependencies] [package.dependencies]
@ -2070,18 +2089,18 @@ pip = "*"
[[package]] [[package]]
name = "pip-audit" name = "pip-audit"
version = "2.6.3" version = "2.6.2"
description = "A tool for scanning Python environments for known vulnerabilities" description = "A tool for scanning Python environments for known vulnerabilities"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "pip_audit-2.6.3-py3-none-any.whl", hash = "sha256:216983210db4a15393f9e80e4d24a805f5767e4c8e0c31fc70c336acc629613b"}, {file = "pip_audit-2.6.2-py3-none-any.whl", hash = "sha256:ac3a4b6e977ef2c574aa8d19a5d71d12201bdb65bba2d67d9df49f53f0be5e7d"},
{file = "pip_audit-2.6.3.tar.gz", hash = "sha256:bd796066f69684b2f4fc2c2b6d222589e23190db0bbde069cea5c2b0be2cc57d"}, {file = "pip_audit-2.6.2.tar.gz", hash = "sha256:0bbd023a199a104b29f949f063a872d41113b5a9048285666820fa35a76a7794"},
] ]
[package.dependencies] [package.dependencies]
CacheControl = {version = ">=0.13.0", extras = ["filecache"]} CacheControl = {version = ">=0.13.0", extras = ["filecache"]}
cyclonedx-python-lib = ">=5,<7" cyclonedx-python-lib = ">=4,<6"
html5lib = ">=1.1" html5lib = ">=1.1"
packaging = ">=23.0.0" packaging = ">=23.0.0"
pip-api = ">=0.0.28" pip-api = ">=0.0.28"
@ -2093,7 +2112,7 @@ toml = ">=0.10"
[package.extras] [package.extras]
dev = ["build", "bump (>=1.3.2)", "pip-audit[doc,lint,test]"] dev = ["build", "bump (>=1.3.2)", "pip-audit[doc,lint,test]"]
doc = ["pdoc"] doc = ["pdoc"]
lint = ["interrogate", "mypy", "ruff (<0.1.12)", "types-html5lib", "types-requests", "types-toml"] lint = ["interrogate", "mypy", "ruff (<0.1.9)", "types-html5lib", "types-requests", "types-toml"]
test = ["coverage[toml] (>=7.0,!=7.3.3,<8.0)", "pretend", "pytest", "pytest-cov"] test = ["coverage[toml] (>=7.0,!=7.3.3,<8.0)", "pretend", "pytest", "pytest-cov"]
[[package]] [[package]]
@ -2197,13 +2216,13 @@ tests = ["pytest"]
[[package]] [[package]]
name = "py-serializable" name = "py-serializable"
version = "0.17.1" version = "0.15.0"
description = "Library for serializing and deserializing Python Objects to and from JSON and XML." description = "Library for serializing and deserializing Python Objects to and from JSON and XML."
optional = false optional = false
python-versions = ">=3.7,<4.0" python-versions = ">=3.7,<4.0"
files = [ files = [
{file = "py-serializable-0.17.1.tar.gz", hash = "sha256:875bb9c01df77f563dfcd1e75bb4244b5596083d3aad4ccd3fb63e1f5a9d3e5f"}, {file = "py-serializable-0.15.0.tar.gz", hash = "sha256:8fc41457d8ee5f5c5a12f41fd87bf1a4f2ecf9da39fee92059b728e78f320771"},
{file = "py_serializable-0.17.1-py3-none-any.whl", hash = "sha256:389c2254d912bec3a44acdac667c947d73c59325050d5ae66386e1ed7108a45a"}, {file = "py_serializable-0.15.0-py3-none-any.whl", hash = "sha256:d3f1201b33420c481aa83f7860c7bf2c2f036ba3ea82b6e15a96696457c36cd2"},
] ]
[package.dependencies] [package.dependencies]
@ -2388,13 +2407,13 @@ resolved_reference = "996cec42e9621701edb83354232b2c0ca0121560"
[[package]] [[package]]
name = "pyflakes" name = "pyflakes"
version = "3.2.0" version = "3.1.0"
description = "passive checker of Python programs" description = "passive checker of Python programs"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "pyflakes-3.2.0-py2.py3-none-any.whl", hash = "sha256:84b5be138a2dfbb40689ca07e2152deb896a65c3a3e24c251c5c62489568074a"}, {file = "pyflakes-3.1.0-py2.py3-none-any.whl", hash = "sha256:4132f6d49cb4dae6819e5379898f2b8cce3c5f23994194c24b77d5da2e36f774"},
{file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"}, {file = "pyflakes-3.1.0.tar.gz", hash = "sha256:a0aae034c444db0071aa077972ba4768d40c830d9539fd45bf4cd3f8f6992efc"},
] ]
[[package]] [[package]]
@ -3713,4 +3732,4 @@ multidict = ">=4.0"
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.12" python-versions = "^3.12"
content-hash = "30a4b709eed24d6b55f78015994cdf1901d0d7279fefa86b37e9490e189a2c52" content-hash = "3a78cb3473202950a11ebbe28e5662f6a1a57508a0e2703761c3c448744b0bed"

View File

@ -1,11 +1,11 @@
[tool.poetry] [tool.poetry]
name = "chat_gpt_bot" name = "chat_gpt_bot"
version = "1.5.1" version = "1.4.1"
description = "Bot to integrated with Chat gpt" description = "Bot to integrated with Chat gpt"
authors = ["Dmitry Afanasyev <Balshbox@gmail.com>"] authors = ["Dmitry Afanasyev <Balshbox@gmail.com>"]
[build-system] [build-system]
requires = ["poetry-core>=1.7.1"] requires = ["poetry-core>=1.7.0"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.poetry.dependencies] [tool.poetry.dependencies]
@ -84,6 +84,7 @@ autoflake = "^2.2"
flake8-aaa = "^0.17.0" flake8-aaa = "^0.17.0"
flake8-variables-names = "^0.0.6" flake8-variables-names = "^0.0.6"
flake8-deprecated = "^2.2.1" flake8-deprecated = "^2.2.1"
flake8-noqa = "^1.3.2"
flake8-annotations-complexity = "^0.0.8" flake8-annotations-complexity = "^0.0.8"
flake8-useless-assert = "^0.4.4" flake8-useless-assert = "^0.4.4"
flake8-newspaper-style = "^0.4.3" flake8-newspaper-style = "^0.4.3"

View File

@ -2,8 +2,12 @@
set -e set -e
alembic upgrade "head" && \ if [ -f shared/${DB_NAME:-chatgpt.db} ]
python3 core/bot/managment/update_gpt_models.py then
alembic downgrade -1 && alembic upgrade "head"
else
alembic upgrade "head"
fi
echo "starting the bot" echo "starting the bot"