mirror of
https://github.com/Balshgit/gpt_chat_bot.git
synced 2026-02-03 11:40:39 +03:00
refactoring (#26)
This commit is contained in:
@@ -8,7 +8,7 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
f"/{settings.TELEGRAM_API_TOKEN}",
|
||||
f"/{settings.token_part}",
|
||||
name="bot:process_bot_updates",
|
||||
response_class=Response,
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
|
||||
13
bot_microservice/api/bot/deps.py
Normal file
13
bot_microservice/api/bot/deps.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from fastapi import Depends
|
||||
from starlette.requests import Request
|
||||
|
||||
from core.bot.services import ChatGptService
|
||||
from settings.config import AppSettings
|
||||
|
||||
|
||||
def get_settings(request: Request) -> AppSettings:
    """Return the settings object stored on the application state.

    Intended to be used as a FastAPI dependency: the settings are read from
    ``request.app.state.settings`` rather than from a module-level global.
    """
    app_state = request.app.state
    return app_state.settings
|
||||
|
||||
|
||||
def get_chat_gpt_service(settings: AppSettings = Depends(get_settings)) -> ChatGptService:
    """Build a ``ChatGptService`` for the model named in the injected settings.

    A new service instance is constructed on every call, using
    ``settings.GPT_MODEL`` as the model name.
    """
    model_name = settings.GPT_MODEL
    return ChatGptService(model_name)
|
||||
2
bot_microservice/api/exceptions.py
Normal file
2
bot_microservice/api/exceptions.py
Normal file
@@ -0,0 +1,2 @@
|
||||
class BaseAPIException(Exception):
    """Base class for exceptions raised by the API layer; adds no state or behavior."""
|
||||
@@ -1,11 +1,12 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from starlette import status
|
||||
from starlette.responses import Response
|
||||
|
||||
from api.bot.deps import get_chat_gpt_service
|
||||
from api.exceptions import BaseAPIException
|
||||
from constants import INVALID_GPT_REQUEST_MESSAGES
|
||||
from core.utils import ChatGptService
|
||||
from settings.config import settings
|
||||
from core.bot.services import ChatGptService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -30,8 +31,10 @@ async def healthcheck() -> ORJSONResponse:
|
||||
status.HTTP_200_OK: {"description": "Successful Response"},
|
||||
},
|
||||
)
|
||||
async def gpt_healthcheck(response: Response) -> Response:
|
||||
chatgpt_service = ChatGptService(chat_gpt_model=settings.GPT_MODEL)
|
||||
async def gpt_healthcheck(
|
||||
response: Response,
|
||||
chatgpt_service: ChatGptService = Depends(get_chat_gpt_service),
|
||||
) -> Response:
|
||||
data = chatgpt_service.build_request_data("Привет!")
|
||||
response.status_code = status.HTTP_200_OK
|
||||
try:
|
||||
@@ -41,7 +44,7 @@ async def gpt_healthcheck(response: Response) -> Response:
|
||||
for message in INVALID_GPT_REQUEST_MESSAGES:
|
||||
if message in chatgpt_response.text:
|
||||
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
except Exception:
|
||||
except BaseAPIException:
|
||||
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
return Response(status_code=response.status_code, content=None)
|
||||
|
||||
0
bot_microservice/core/bot/__init__.py
Normal file
0
bot_microservice/core/bot/__init__.py
Normal file
@@ -7,8 +7,8 @@ from telegram import InlineKeyboardMarkup, Update
|
||||
from telegram.ext import ContextTypes
|
||||
|
||||
from constants import BotEntryPoints
|
||||
from core.keyboards import main_keyboard
|
||||
from core.utils import ChatGptService, SpeechToTextService
|
||||
from core.bot.keyboards import main_keyboard
|
||||
from core.bot.services import ChatGptService, SpeechToTextService
|
||||
from settings.config import settings
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from telegram.ext import (
|
||||
)
|
||||
|
||||
from constants import BotEntryPoints, BotStagesEnum
|
||||
from core.commands import (
|
||||
from core.bot.commands import (
|
||||
about_bot,
|
||||
about_me,
|
||||
ask_question,
|
||||
138
bot_microservice/core/bot/services.py
Normal file
138
bot_microservice/core/bot/services.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import os
|
||||
import random
|
||||
import subprocess # noqa
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from httpx import AsyncClient, AsyncHTTPTransport, Response
|
||||
from loguru import logger
|
||||
from pydub import AudioSegment
|
||||
from speech_recognition import (
|
||||
AudioFile,
|
||||
Recognizer,
|
||||
UnknownValueError as SpeechRecognizerError,
|
||||
)
|
||||
|
||||
from constants import (
|
||||
AUDIO_SEGMENT_DURATION,
|
||||
CHAT_GPT_BASE_URI,
|
||||
INVALID_GPT_REQUEST_MESSAGES,
|
||||
)
|
||||
from settings.config import settings
|
||||
|
||||
|
||||
class SpeechToTextService:
    """Convert a voice-message file to text on a background thread.

    The source file is transcoded to WAV with ``ffmpeg``, cut into segments of
    ``AUDIO_SEGMENT_DURATION`` and every segment is sent to Google speech
    recognition with ``language="ru-RU"``. Recognised text is stored in
    ``text_parts`` keyed by segment index, and ``text_recognised`` is set to
    ``True`` once every segment has been recognised.
    """

    # Amount of audio taken again from the end of the previous segment at the
    # start of each following one (pydub slices in milliseconds). Presumably so
    # that a word cut at a segment boundary is still heard whole.
    SEGMENT_OVERLAP = 250

    def __init__(self, filename: str) -> None:
        """Prepare a recognizer and an executor for the audio file ``filename``."""
        self.executor = ThreadPoolExecutor()

        self.filename = filename
        self.recognizer = Recognizer()
        self.recognizer.energy_threshold = 50
        self.text_parts: dict[int, str] = {}
        self.text_recognised = False

    def get_text_from_audio(self) -> None:
        """Schedule ``worker`` on the executor and return without waiting for it."""
        self.executor.submit(self.worker)

    def worker(self) -> Any:
        """Transcode the source file to WAV, then recognise its text."""
        self._convert_file_to_wav()
        self._convert_audio_to_text()

    def _convert_audio_to_text(self) -> None:
        """Recognise the WAV file segment by segment and fill ``text_parts``.

        ``text_recognised`` is set only when every segment succeeded. The WAV
        file and the original file are removed afterwards in every case, so a
        failed recognition does not leave them behind.

        Raises:
            SpeechRecognizerError: re-raised from ``_recognize_by_google``.
        """
        wav_filename = f"{self.filename}.wav"

        speech = AudioSegment.from_wav(wav_filename)
        try:
            bounds = self._segment_bounds(len(speech), AUDIO_SEGMENT_DURATION, self.SEGMENT_OVERLAP)
            for index, (start, end) in enumerate(bounds):
                self.text_parts[index] = self._recognize_by_google(wav_filename, speech[start:end])

            self.text_recognised = True
        finally:
            # clean temp voice message main files
            self._remove_files(wav_filename, self.filename)

    @staticmethod
    def _segment_bounds(duration: int, segment_duration: int, overlap: int) -> list[tuple[int, int]]:
        """Return the ``(start, end)`` slice bounds of every segment to recognise.

        The first segment starts at 0; each later one starts ``overlap`` before
        its nominal start. The last segment ends at ``duration``. At least one
        segment is always returned, even for zero-length audio.
        """
        # Ceiling division: a duration that is an exact multiple of the segment
        # length must not yield an extra trailing piece holding only the overlap.
        pieces = max(1, -(-duration // segment_duration))
        bounds = []
        for index in range(pieces):
            start = index * segment_duration - overlap if index else 0
            end = min((index + 1) * segment_duration, duration)
            bounds.append((start, end))
        return bounds

    @staticmethod
    def _remove_files(*filenames: str) -> None:
        """Delete each file, logging (not raising) for files that do not exist."""
        for name in filenames:
            # Each file gets its own attempt so one missing file does not
            # prevent removal of the others.
            try:
                os.remove(name)
            except FileNotFoundError as error:
                logger.error("error temps files not deleted", error=error, filenames=[name])

    def _convert_file_to_wav(self) -> None:
        """Transcode ``self.filename`` to ``<filename>.wav`` with ffmpeg.

        Failures are logged and swallowed; the method never raises.
        """
        new_filename = self.filename + ".wav"
        cmd = ["ffmpeg", "-loglevel", "quiet", "-i", self.filename, "-vn", new_filename]
        try:
            # check=True turns a non-zero ffmpeg exit status into an exception,
            # so a failed conversion is logged as an error instead of a success.
            subprocess.run(args=cmd, check=True)  # noqa: S603
            logger.info("file has been converted to wav", filename=new_filename)
        except Exception as error:
            logger.error("cant convert voice", error=error, filename=self.filename)

    def _recognize_by_google(self, filename: str, sound_segment: AudioSegment) -> str:
        """Export ``sound_segment`` to a temporary WAV file and recognise it.

        The temporary file ``<filename>_tmp_part`` is removed whether or not
        recognition succeeds.

        Raises:
            SpeechRecognizerError: logged and re-raised when the recognizer
                cannot understand the audio.
        """
        tmp_filename = f"{filename}_tmp_part"
        sound_segment.export(tmp_filename, format="wav")
        try:
            with AudioFile(tmp_filename) as source:
                audio_text = self.recognizer.listen(source)
                return self.recognizer.recognize_google(audio_text, language="ru-RU")
        except SpeechRecognizerError as error:
            logger.error("error recognizing text with google", error=error)
            raise
        finally:
            # Runs on success, on SpeechRecognizerError and on any other error.
            os.remove(tmp_filename)
|
||||
|
||||
|
||||
class ChatGptService:
    """Client for the chat backend, bound to a single model name."""

    def __init__(self, chat_gpt_model: str) -> None:
        """Remember ``chat_gpt_model``; it is sent as ``"model"`` in every request."""
        self.chat_gpt_model = chat_gpt_model

    async def request_to_chatgpt(self, question: str | None) -> str:
        """Ask the chat backend ``question`` and return text to show the user.

        An empty or missing question is replaced by a greeting. The method does
        not raise; every failure is mapped to a message:

        * response text contains one of ``INVALID_GPT_REQUEST_MESSAGES`` ->
          that marker followed by the model name;
        * response status is not 200 -> a generic "try again" message;
        * any exception -> a generic failure message.

        Otherwise the raw response text is returned.
        """
        question = question or "Привет!"
        chat_gpt_request = self.build_request_data(question)
        try:
            response = await self.do_request(chat_gpt_request)
            status = response.status_code
            for invalid_marker in INVALID_GPT_REQUEST_MESSAGES:
                if invalid_marker in response.text:
                    # Report the model this request was built with, which is the
                    # instance's model and not necessarily the global setting.
                    message = f"{invalid_marker}: {self.chat_gpt_model}"
                    logger.info(message, data=chat_gpt_request)
                    return message
            if status != httpx.codes.OK:
                logger.info(f"got response status: {status} from chat api", data=chat_gpt_request)
                return "Что-то пошло не так, попробуйте еще раз или обратитесь к администратору"
            return response.text
        except Exception as error:
            logger.error("error get data from chat api", error=error)
            return "Вообще всё сломалось :("

    @staticmethod
    async def do_request(data: dict[str, Any]) -> Response:
        """POST ``data`` as JSON to ``CHAT_GPT_BASE_URI`` on ``settings.GPT_BASE_HOST``.

        Uses a transport configured with ``retries=3`` and a timeout of 50.
        Exceptions from httpx are not handled here.
        """
        transport = AsyncHTTPTransport(retries=3)
        async with AsyncClient(base_url=settings.GPT_BASE_HOST, transport=transport, timeout=50) as client:
            return await client.post(CHAT_GPT_BASE_URI, json=data, timeout=50)

    def build_request_data(self, question: str) -> dict[str, Any]:
        """Build the JSON payload for a single-question request.

        Every call generates a fresh ``conversation_id`` (UUID4 string) and a
        random 19-digit ``meta.id``; the conversation history is always empty.
        """
        return {
            "conversation_id": str(uuid4()),
            "action": "_ask",
            "model": self.chat_gpt_model,
            "jailbreak": "default",
            "meta": {
                "id": random.randint(10**18, 10**19 - 1),  # noqa: S311
                "content": {
                    "conversation": [],
                    "internet_access": False,
                    "content_type": "text",
                    "parts": [{"content": question, "role": "user"}],
                },
            },
        }
|
||||
@@ -1,28 +1,6 @@
|
||||
import os
|
||||
import random
|
||||
import subprocess # noqa
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from datetime import datetime, timedelta
|
||||
from functools import lru_cache, wraps
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from httpx import AsyncClient, AsyncHTTPTransport, Response
|
||||
from loguru import logger
|
||||
from pydub import AudioSegment
|
||||
from speech_recognition import (
|
||||
AudioFile,
|
||||
Recognizer,
|
||||
UnknownValueError as SpeechRecognizerError,
|
||||
)
|
||||
|
||||
from constants import (
|
||||
AUDIO_SEGMENT_DURATION,
|
||||
CHAT_GPT_BASE_URI,
|
||||
INVALID_GPT_REQUEST_MESSAGES,
|
||||
)
|
||||
from settings.config import settings
|
||||
|
||||
|
||||
def timed_cache(**timedelta_kwargs: Any) -> Any:
|
||||
@@ -44,118 +22,3 @@ def timed_cache(**timedelta_kwargs: Any) -> Any:
|
||||
return _wrapped
|
||||
|
||||
return _wrapper
|
||||
|
||||
|
||||
class SpeechToTextService:
|
||||
def __init__(self, filename: str) -> None:
|
||||
self.executor = ThreadPoolExecutor()
|
||||
|
||||
self.filename = filename
|
||||
self.recognizer = Recognizer()
|
||||
self.recognizer.energy_threshold = 50
|
||||
self.text_parts: dict[int, str] = {}
|
||||
self.text_recognised = False
|
||||
|
||||
def get_text_from_audio(self) -> None:
|
||||
self.executor.submit(self.worker)
|
||||
|
||||
def worker(self) -> Any:
|
||||
self._convert_file_to_wav()
|
||||
self._convert_audio_to_text()
|
||||
|
||||
def _convert_audio_to_text(self) -> None:
|
||||
wav_filename = f"{self.filename}.wav"
|
||||
|
||||
speech = AudioSegment.from_wav(wav_filename)
|
||||
speech_duration = len(speech)
|
||||
pieces = speech_duration // AUDIO_SEGMENT_DURATION + 1
|
||||
ending = speech_duration % AUDIO_SEGMENT_DURATION
|
||||
for i in range(pieces):
|
||||
if i == 0 and pieces == 1:
|
||||
sound_segment = speech[0:ending]
|
||||
elif i == 0:
|
||||
sound_segment = speech[0 : (i + 1) * AUDIO_SEGMENT_DURATION]
|
||||
elif i == (pieces - 1):
|
||||
sound_segment = speech[i * AUDIO_SEGMENT_DURATION - 250 : i * AUDIO_SEGMENT_DURATION + ending]
|
||||
else:
|
||||
sound_segment = speech[i * AUDIO_SEGMENT_DURATION - 250 : (i + 1) * AUDIO_SEGMENT_DURATION]
|
||||
self.text_parts[i] = self._recognize_by_google(wav_filename, sound_segment)
|
||||
|
||||
self.text_recognised = True
|
||||
|
||||
# clean temp voice message main files
|
||||
try:
|
||||
os.remove(wav_filename)
|
||||
os.remove(self.filename)
|
||||
except FileNotFoundError as error:
|
||||
logger.error("error temps files not deleted", error=error, filenames=[self.filename, self.filename])
|
||||
|
||||
def _convert_file_to_wav(self) -> None:
|
||||
new_filename = self.filename + ".wav"
|
||||
cmd = ["ffmpeg", "-loglevel", "quiet", "-i", self.filename, "-vn", new_filename]
|
||||
try:
|
||||
subprocess.run(args=cmd) # noqa: S603
|
||||
logger.info("file has been converted to wav", filename=new_filename)
|
||||
except Exception as error:
|
||||
logger.error("cant convert voice", error=error, filename=self.filename)
|
||||
|
||||
def _recognize_by_google(self, filename: str, sound_segment: AudioSegment) -> str:
|
||||
tmp_filename = f"{filename}_tmp_part"
|
||||
sound_segment.export(tmp_filename, format="wav")
|
||||
with AudioFile(tmp_filename) as source:
|
||||
audio_text = self.recognizer.listen(source)
|
||||
try:
|
||||
text = self.recognizer.recognize_google(audio_text, language="ru-RU")
|
||||
os.remove(tmp_filename)
|
||||
return text
|
||||
except SpeechRecognizerError as error:
|
||||
os.remove(tmp_filename)
|
||||
logger.error("error recognizing text with google", error=error)
|
||||
raise error
|
||||
|
||||
|
||||
class ChatGptService:
|
||||
def __init__(self, chat_gpt_model: str) -> None:
|
||||
self.chat_gpt_model = chat_gpt_model
|
||||
|
||||
async def request_to_chatgpt(self, question: str | None) -> str:
|
||||
question = question or "Привет!"
|
||||
chat_gpt_request = self.build_request_data(question)
|
||||
try:
|
||||
response = await self.do_request(chat_gpt_request)
|
||||
status = response.status_code
|
||||
for message in INVALID_GPT_REQUEST_MESSAGES:
|
||||
if message in response.text:
|
||||
message = f"{message}: {settings.GPT_MODEL}"
|
||||
logger.info(message, data=chat_gpt_request)
|
||||
return message
|
||||
if status != httpx.codes.OK:
|
||||
logger.info(f"got response status: {status} from chat api", data=chat_gpt_request)
|
||||
return "Что-то пошло не так, попробуйте еще раз или обратитесь к администратору"
|
||||
return response.text
|
||||
except Exception as error:
|
||||
logger.error("error get data from chat api", error=error)
|
||||
return "Вообще всё сломалось :("
|
||||
|
||||
@staticmethod
|
||||
async def do_request(data: dict[str, Any]) -> Response:
|
||||
transport = AsyncHTTPTransport(retries=3)
|
||||
async with AsyncClient(base_url=settings.GPT_BASE_HOST, transport=transport, timeout=50) as client:
|
||||
return await client.post(CHAT_GPT_BASE_URI, json=data, timeout=50)
|
||||
|
||||
def build_request_data(self, question: str) -> dict[str, Any]:
|
||||
return {
|
||||
"conversation_id": str(uuid4()),
|
||||
"action": "_ask",
|
||||
"model": self.chat_gpt_model,
|
||||
"jailbreak": "default",
|
||||
"meta": {
|
||||
"id": random.randint(10**18, 10**19 - 1), # noqa: S311
|
||||
"content": {
|
||||
"conversation": [],
|
||||
"internet_access": False,
|
||||
"content_type": "text",
|
||||
"parts": [{"content": question, "role": "user"}],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
0
bot_microservice/infra/__init__.py
Normal file
0
bot_microservice/infra/__init__.py
Normal file
@@ -31,10 +31,13 @@ class InterceptHandler(logging.Handler):
|
||||
frame = cast(FrameType, frame.f_back)
|
||||
depth += 1
|
||||
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(
|
||||
level,
|
||||
record.getMessage().replace(settings.TELEGRAM_API_TOKEN, "TELEGRAM_API_TOKEN".center(24, "*")),
|
||||
)
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(level, self._scrap_sensitive_info(record))
|
||||
|
||||
@staticmethod
def _scrap_sensitive_info(record: logging.LogRecord) -> str:
    """Return the record's formatted message with the Telegram API token masked.

    Every occurrence of ``settings.TELEGRAM_API_TOKEN`` is replaced by the
    text ``TELEGRAM_API_TOKEN`` centred in 24 characters and padded with ``*``.
    """
    message = record.getMessage()
    token = settings.TELEGRAM_API_TOKEN
    if not token:
        # Replacing an empty string would insert the mask between every character.
        return message
    # str.replace returns a new string and leaves ``message`` untouched, so the
    # result itself must be returned for the token to be masked.
    return message.replace(token, "TELEGRAM_API_TOKEN".center(24, "*"))
|
||||
|
||||
|
||||
def configure_logging(
|
||||
@@ -6,9 +6,9 @@ from fastapi import FastAPI
|
||||
from fastapi.responses import UJSONResponse
|
||||
|
||||
from constants import LogLevelEnum
|
||||
from core.bot import BotApplication, BotQueue
|
||||
from core.handlers import bot_event_handlers
|
||||
from core.logging import configure_logging
|
||||
from core.bot.app import BotApplication, BotQueue
|
||||
from core.bot.handlers import bot_event_handlers
|
||||
from infra.logging_conf import configure_logging
|
||||
from routers import api_router
|
||||
from settings.config import AppSettings, get_settings
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from functools import cached_property
|
||||
from functools import cached_property, lru_cache
|
||||
from os import environ
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -83,14 +83,19 @@ class AppSettings(SentrySettings, BaseSettings):
|
||||
return "/" + "/".join([self.URL_PREFIX.strip("/"), API_PREFIX.strip("/")])
|
||||
return API_PREFIX
|
||||
|
||||
@cached_property
def token_part(self) -> str:
    """Return characters 15 to 29 (inclusive) of ``TELEGRAM_API_TOKEN``.

    The value is computed once per settings instance and then cached.
    """
    token_window = slice(15, 30)
    return self.TELEGRAM_API_TOKEN[token_window]
|
||||
|
||||
@cached_property
|
||||
def bot_webhook_url(self) -> str:
|
||||
return "/".join([self.api_prefix, self.TELEGRAM_API_TOKEN])
|
||||
return "/".join([self.api_prefix, self.token_part])
|
||||
|
||||
class Config:
|
||||
case_sensitive = True
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
def get_settings() -> AppSettings:
    """Create the application settings on first call and return the cached instance afterwards."""
    app_settings = AppSettings()
    return app_settings
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from httpx import AsyncClient, Response
|
||||
from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Update
|
||||
|
||||
from constants import BotStagesEnum
|
||||
from core.bot import BotApplication, BotQueue
|
||||
from core.bot.app import BotApplication, BotQueue
|
||||
from main import Application
|
||||
from settings.config import AppSettings, settings
|
||||
from tests.integration.bot.networking import MockedRequest
|
||||
@@ -37,7 +37,7 @@ async def test_bot_webhook_endpoint(
|
||||
main_application: Application,
|
||||
) -> None:
|
||||
bot_update = BotUpdateFactory(message=BotMessageFactory.create_instance(text="/help"))
|
||||
response = await rest_client.post(url="/api/123456789:AABBCCDDEEFFaabbccddeeff-1234567890", json=bot_update)
|
||||
response = await rest_client.post(url="/api/CDDEEFFaabbccdd", json=bot_update)
|
||||
assert response.status_code == 202
|
||||
update = await main_application.fastapi_app.state._state["queue"].queue.get()
|
||||
update = update.to_dict()
|
||||
|
||||
@@ -14,8 +14,8 @@ from pytest_asyncio.plugin import SubRequest
|
||||
from telegram import Bot, User
|
||||
from telegram.ext import Application, ApplicationBuilder, Defaults, ExtBot
|
||||
|
||||
from core.bot import BotApplication
|
||||
from core.handlers import bot_event_handlers
|
||||
from core.bot.app import BotApplication
|
||||
from core.bot.handlers import bot_event_handlers
|
||||
from main import Application as AppApplication
|
||||
from settings.config import AppSettings, get_settings
|
||||
from tests.integration.bot.networking import NonchalantHttpxRequest
|
||||
|
||||
@@ -3,6 +3,7 @@ import pytest
|
||||
from faker import Faker
|
||||
from httpx import AsyncClient, Response
|
||||
|
||||
from api.exceptions import BaseAPIException
|
||||
from settings.config import AppSettings
|
||||
from tests.integration.utils import mocked_ask_question_api
|
||||
|
||||
@@ -50,7 +51,7 @@ async def test_bot_healthcheck_not_ok(
|
||||
) -> None:
|
||||
with mocked_ask_question_api(
|
||||
host=test_settings.GPT_BASE_HOST,
|
||||
side_effect=Exception(),
|
||||
side_effect=BaseAPIException(),
|
||||
):
|
||||
response = await rest_client.get("/api/bot-healthcheck")
|
||||
assert response.status_code == httpx.codes.INTERNAL_SERVER_ERROR
|
||||
Reference in New Issue
Block a user