refactoring (#26)

This commit is contained in:
Dmitry Afanasyev
2023-10-03 23:30:19 +03:00
committed by GitHub
parent 482e1fdda1
commit c401e1006c
22 changed files with 423 additions and 395 deletions

View File

View File

@@ -0,0 +1,85 @@
import asyncio
import os
from asyncio import Queue, sleep
from dataclasses import dataclass, field
from functools import cached_property
from http import HTTPStatus
from typing import Any

from fastapi import Request, Response
from loguru import logger
from telegram import Bot, Update
from telegram.ext import Application

from settings.config import AppSettings
class BotApplication:
    """Wrap a python-telegram-bot ``Application`` together with the project settings.

    The application is built from ``settings.TELEGRAM_API_TOKEN`` and every handler passed
    in ``handlers`` is registered on it during construction.
    """

    def __init__(
        self,
        settings: AppSettings,
        handlers: list[Any],
    ) -> None:
        self.application: Application = (  # type: ignore[type-arg]
            Application.builder().token(token=settings.TELEGRAM_API_TOKEN).build()
        )
        self.handlers = handlers
        self.settings = settings
        self.start_with_webhook = settings.START_WITH_WEBHOOK
        self._add_handlers()

    @property
    def bot(self) -> Bot:
        """Return the ``Bot`` owned by the wrapped application."""
        return self.application.bot

    async def set_webhook(self) -> None:
        """Initialise the application and register the webhook when none is set yet."""
        _, webhook_info = await asyncio.gather(self.application.initialize(), self.application.bot.get_webhook_info())
        if not webhook_info.url:
            await self.application.bot.set_webhook(url=self.webhook_url)
            # Re-read the info so the log line below reflects the freshly registered webhook.
            webhook_info = await self.application.bot.get_webhook_info()
        logger.info("webhook is set", ip_address=webhook_info.ip_address)

    async def delete_webhook(self) -> None:
        """Delete the webhook and log when Telegram confirms the deletion."""
        if await self.application.bot.delete_webhook():
            logger.info("webhook has been deleted")

    async def polling(self) -> None:
        """Start the application in long-polling mode; does nothing when ``STAGE`` is ``"runtests"``."""
        if self.settings.STAGE == "runtests":
            return None
        await self.application.initialize()
        await self.application.start()
        await self.application.updater.start_polling()  # type: ignore
        logger.info("bot started in polling mode")

    async def shutdown(self) -> None:
        """Shut down the application's updater."""
        # NOTE(review): only the updater is shut down here; whether the application itself is
        # stopped elsewhere cannot be seen from this module - confirm.
        await self.application.updater.shutdown()  # type: ignore

    @cached_property
    def webhook_url(self) -> str:
        """Return ``DOMAIN`` and ``bot_webhook_url`` joined by exactly one forward slash."""
        domain = self.settings.DOMAIN.strip("/")
        path = self.settings.bot_webhook_url.strip("/")
        # URLs always use "/"; os.path.join would insert the OS path separator (a backslash on Windows).
        return f"{domain}/{path}"

    def _add_handlers(self) -> None:
        """Register every configured handler on the wrapped application, in list order."""
        for handler in self.handlers:
            self.application.add_handler(handler)
@dataclass
class BotQueue:
    """Buffer incoming Telegram updates in a queue and feed them to the bot application."""

    bot_app: BotApplication
    # A factory gives every BotQueue its own queue. A plain ``asyncio.Queue()`` default is created
    # once at import time and silently shared by all instances.
    queue: Queue = field(default_factory=asyncio.Queue)  # type: ignore[type-arg]
    # Strong references to in-flight processing tasks: the event loop keeps only weak references,
    # so an otherwise unreferenced task may be garbage-collected before it finishes.
    _tasks: set[Any] = field(default_factory=set, init=False, repr=False, compare=False)

    async def put_updates_on_queue(self, request: Request) -> Response:
        """
        Listen /{URL_PREFIX}/{API_PREFIX}/{TELEGRAM_WEB_TOKEN} path and proxy post request to bot

        The JSON body is parsed into a telegram ``Update`` and queued without waiting;
        the response is always ``202 Accepted``.
        """
        data = await request.json()
        tg_update = Update.de_json(data=data, bot=self.bot_app.application.bot)
        self.queue.put_nowait(tg_update)
        return Response(status_code=HTTPStatus.ACCEPTED)

    async def get_updates_from_queue(self) -> None:
        """Run forever: take updates off the queue and process each one in its own task."""
        while True:
            update = await self.queue.get()
            task = asyncio.create_task(self.bot_app.application.process_update(update))
            self._tasks.add(task)
            # Drop the reference as soon as the task completes so the set does not grow.
            task.add_done_callback(self._tasks.discard)
            # Yield to the event loop so the freshly created task gets a chance to start.
            await sleep(0)

View File

@@ -0,0 +1,100 @@
import asyncio
import tempfile
from urllib.parse import urljoin
from loguru import logger
from telegram import InlineKeyboardMarkup, Update
from telegram.ext import ContextTypes
from constants import BotEntryPoints
from core.bot.keyboards import main_keyboard
from core.bot.services import ChatGptService, SpeechToTextService
from settings.config import settings
async def main_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> str:
    """Answer `/start` with the main inline keyboard.

    Returns ``BotEntryPoints.start_routes`` once the keyboard has been sent, or
    ``BotEntryPoints.end`` when the update carries no message.
    """
    message = update.message
    if message:
        keyboard = InlineKeyboardMarkup(main_keyboard)
        await message.reply_text("Выберете команду:", reply_markup=keyboard)
        return BotEntryPoints.start_routes
    return BotEntryPoints.end
async def about_me(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Reply with the bot author's name and Telegram nickname (MarkdownV2 formatted)."""
    message = update.effective_message
    if message:
        author_info = "Автор бота: *Дмитрий Афанасьев*\n\nTg nickname: *Balshtg*"
        await message.reply_text(author_info, parse_mode="MarkdownV2")
    return None
async def about_bot(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Reply with a short description of the bot, naming the configured GPT model."""
    message = update.effective_message
    if message is None:
        return None
    # Only the first fragment interpolates a value; the remaining ones are plain literals.
    description = (
        f"Бот использует бесплатную модель {settings.GPT_MODEL} для ответов на вопросы. "
        "\nПринимает запросы на разных языках.\n\nБот так же умеет переводить русские голосовые сообщения в текст. "
        "Просто пришлите голосовуху и получите поток сознания в виде текста, но без знаков препинания"
    )
    await message.reply_text(description, parse_mode="Markdown")
async def website(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Reply with the link to the web chat, built from ``DOMAIN`` and ``URL_PREFIX``."""
    message = update.effective_message
    if message:
        chat_path = f"{settings.URL_PREFIX}/chat/"
        chat_url = urljoin(settings.DOMAIN, chat_path)
        await message.reply_text(f"Веб версия: {chat_url}")
    return None
async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Reply to `/help` (or the help button) with the main inline keyboard, without a notification."""
    message = update.effective_message
    if message is None:
        return None
    keyboard = InlineKeyboardMarkup(main_keyboard)
    # The extra "text" entry is passed through ``api_kwargs`` exactly as in the original call.
    extra_api_arguments = {"text": "Список основных команд:"}
    await message.reply_text(
        "Help!",
        disable_notification=True,
        api_kwargs=extra_api_arguments,
        reply_markup=keyboard,
    )
async def ask_question(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Forward the user's text to the chat GPT service and reply with its answer."""
    message = update.message
    if message is None:
        return None
    await message.reply_text("Пожалуйста подождите, ответ в среднем занимает 10-15 секунд")
    service = ChatGptService(chat_gpt_model=settings.GPT_MODEL)
    question = message.text
    logger.warning("question asked", user=message.from_user, question=question)
    await message.reply_text(await service.request_to_chatgpt(question=question))
async def voice_recognize(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Transcribe a voice or audio message and reply with the recognised text part by part.

    The attachment is downloaded into a temporary file and handed to ``SpeechToTextService``,
    which recognises it in the background. This coroutine polls the service and sends every
    finished part in order until recognition is reported complete.
    """
    message = update.message
    if not message:
        return None
    # Audio attachments are accepted as well as voice notes, so an audio message is not
    # answered with "please wait" and then silently dropped.
    attachment = message.voice or message.audio
    if not attachment:
        return None
    await message.reply_text("Пожалуйста, ожидайте :)\nТрехминутная запись обрабатывается примерно 30 секунд")

    sound_file = await attachment.get_file()
    sound_bytes = await sound_file.download_as_bytearray()
    with tempfile.NamedTemporaryFile(delete=False) as tmpfile:
        tmpfile.write(sound_bytes)
    # The file is closed (and therefore flushed) before the background worker starts reading it.
    logger.info("file has been saved", filename=tmpfile.name)

    speech_to_text_service = SpeechToTextService(filename=tmpfile.name)
    speech_to_text_service.get_text_from_audio()

    part = 0
    # NOTE(review): termination relies on the service eventually setting ``text_recognised``.
    while speech_to_text_service.text_parts or not speech_to_text_service.text_recognised:
        if part in speech_to_text_service.text_parts:
            # Consume the part even when its text is empty; testing the text for truthiness
            # left an empty part in the dict forever and the loop never ended.
            text = speech_to_text_service.text_parts.pop(part)
            part += 1
            if text:
                await message.reply_text(text)
        await asyncio.sleep(5)

View File

@@ -0,0 +1,54 @@
from dataclasses import dataclass, field
from typing import Any
from telegram.ext import (
CallbackQueryHandler,
CommandHandler,
ConversationHandler,
MessageHandler,
filters,
)
from constants import BotEntryPoints, BotStagesEnum
from core.bot.commands import (
about_bot,
about_me,
ask_question,
help_command,
main_command,
voice_recognize,
website,
)
@dataclass
class BotEventHandlers:
    """Ordered collection of handler objects; each instance starts with its own empty list."""

    # ``list`` is the idiomatic factory; calling the generic alias ``list[Any]`` builds the same plain list.
    handlers: list[Any] = field(default_factory=list)

    def add_handler(self, handler: Any) -> None:
        """Append ``handler`` to the end of the collection."""
        self.handlers.append(handler)
bot_event_handlers = BotEventHandlers()


def _stage_callback_handlers() -> list[Any]:
    """Build one callback-query handler per bot stage, in the order listed below."""
    stage_callbacks = (
        (about_me, BotStagesEnum.about_me),
        (website, BotStagesEnum.website),
        (help_command, BotStagesEnum.help),
        (about_bot, BotStagesEnum.about_bot),
    )
    return [CallbackQueryHandler(callback, pattern="^" + stage + "$") for callback, stage in stage_callbacks]


# Registration order is preserved: help command, free text, voice/audio, the /start conversation,
# and finally the stand-alone stage callbacks.
for _handler in (
    CommandHandler("help", help_command),
    MessageHandler(filters.TEXT & ~filters.COMMAND, ask_question),
    MessageHandler(filters.VOICE | filters.AUDIO, voice_recognize),
    ConversationHandler(
        entry_points=[CommandHandler("start", main_command)],
        states={BotEntryPoints.start_routes: _stage_callback_handlers()},
        fallbacks=[CommandHandler("start", main_command)],
    ),
    *_stage_callback_handlers(),
):
    bot_event_handlers.add_handler(_handler)

View File

@@ -0,0 +1,14 @@
from telegram import InlineKeyboardButton
from constants import BotStagesEnum
def _stage_button(label: str, stage: BotStagesEnum) -> InlineKeyboardButton:
    """Create an inline button whose callback data is the string form of ``stage``."""
    return InlineKeyboardButton(label, callback_data=str(stage))


# Two rows of two buttons each.
main_keyboard = (
    (_stage_button("Обо мне", BotStagesEnum.about_me), _stage_button("Веб версия", BotStagesEnum.website)),
    (_stage_button("Помощь", BotStagesEnum.help), _stage_button("О боте", BotStagesEnum.about_bot)),
)

View File

@@ -0,0 +1,138 @@
import os
import random
import subprocess # noqa
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Any
from uuid import uuid4
import httpx
from httpx import AsyncClient, AsyncHTTPTransport, Response
from loguru import logger
from pydub import AudioSegment
from speech_recognition import (
AudioFile,
Recognizer,
UnknownValueError as SpeechRecognizerError,
)
from constants import (
AUDIO_SEGMENT_DURATION,
CHAT_GPT_BASE_URI,
INVALID_GPT_REQUEST_MESSAGES,
)
from settings.config import settings
class SpeechToTextService:
    """Convert an audio file to text in a background thread, segment by segment.

    Recognised segments are published in ``text_parts`` keyed by segment index, and
    ``text_recognised`` becomes ``True`` once processing has finished, successfully or not,
    so a caller polling these attributes always terminates.
    """

    def __init__(self, filename: str) -> None:
        self.executor = ThreadPoolExecutor()
        self.filename = filename
        self.recognizer = Recognizer()
        self.recognizer.energy_threshold = 50
        self.text_parts: dict[int, str] = {}
        self.text_recognised = False

    def get_text_from_audio(self) -> None:
        """Start recognition on the executor and return immediately."""
        self.executor.submit(self.worker)

    def worker(self) -> Any:
        """Convert the source file to wav and recognise it.

        Any failure is logged and re-raised; ``text_recognised`` is set in every case.
        """
        try:
            self._convert_file_to_wav()
            self._convert_audio_to_text()
        except Exception as error:
            # The future returned by ``submit`` is discarded, so log here or the failure is invisible.
            logger.error("error recognizing voice message", error=error, filename=self.filename)
            raise
        finally:
            # Always release pollers, even when conversion or recognition failed.
            self.text_recognised = True

    def _convert_audio_to_text(self) -> None:
        """Split the wav file into segments, recognise each, then remove the temporary files."""
        wav_filename = f"{self.filename}.wav"
        overlap = 250  # each later segment starts this much before its nominal boundary
        try:
            speech = AudioSegment.from_wav(wav_filename)
            speech_duration = len(speech)
            # Ceiling division: a duration that is an exact multiple of the segment length must not
            # produce an extra trailing segment consisting only of the overlap.
            pieces = max(1, -(-speech_duration // AUDIO_SEGMENT_DURATION))
            for i in range(pieces):
                start = 0 if i == 0 else i * AUDIO_SEGMENT_DURATION - overlap
                end = min((i + 1) * AUDIO_SEGMENT_DURATION, speech_duration)
                self.text_parts[i] = self._recognize_by_google(wav_filename, speech[start:end])
        finally:
            self.text_recognised = True
            # clean temp voice message main files
            self._remove_temp_files(wav_filename, self.filename)

    def _remove_temp_files(self, *filenames: str) -> None:
        """Delete each file independently so one missing file does not keep the others on disk."""
        for filename in filenames:
            try:
                os.remove(filename)
            except OSError as error:
                logger.error("error temps files not deleted", error=error, filenames=[filename])

    def _convert_file_to_wav(self) -> None:
        """Run ffmpeg to write ``<filename>.wav`` next to the source file; failures are logged only."""
        new_filename = self.filename + ".wav"
        cmd = ["ffmpeg", "-loglevel", "quiet", "-i", self.filename, "-vn", new_filename]
        try:
            # check=True turns a non-zero ffmpeg exit into an exception instead of a false success log.
            subprocess.run(args=cmd, check=True)  # noqa: S603
            logger.info("file has been converted to wav", filename=new_filename)
        except Exception as error:
            logger.error("cant convert voice", error=error, filename=self.filename)

    def _recognize_by_google(self, filename: str, sound_segment: AudioSegment) -> str:
        """Recognise one segment as Russian speech.

        The segment is exported to ``<filename>_tmp_part``, which is removed on every exit path.
        Raises ``SpeechRecognizerError`` (logged first) when the recogniser raises it.
        """
        tmp_filename = f"{filename}_tmp_part"
        try:
            sound_segment.export(tmp_filename, format="wav")
            with AudioFile(tmp_filename) as source:
                audio_text = self.recognizer.listen(source)
            return self.recognizer.recognize_google(audio_text, language="ru-RU")
        except SpeechRecognizerError as error:
            logger.error("error recognizing text with google", error=error)
            raise
        finally:
            try:
                os.remove(tmp_filename)
            except FileNotFoundError:
                # The export failed before the file was created; nothing to clean up.
                pass
class ChatGptService:
    """Send questions to the chat API using the model given at construction time."""

    def __init__(self, chat_gpt_model: str) -> None:
        self.chat_gpt_model = chat_gpt_model

    async def request_to_chatgpt(self, question: str | None) -> str:
        """Ask the chat API and return its answer, or a user-facing error message.

        An empty or missing question is replaced by a greeting. This method does not raise:
        every failure is logged and turned into a message for the user.
        """
        question = question or "Привет!"
        chat_gpt_request = self.build_request_data(question)
        try:
            response = await self.do_request(chat_gpt_request)
            status = response.status_code
            for invalid_message in INVALID_GPT_REQUEST_MESSAGES:
                if invalid_message in response.text:
                    # Report the model this service actually sent, not the global setting.
                    message = f"{invalid_message}: {self.chat_gpt_model}"
                    logger.info(message, data=chat_gpt_request)
                    return message
            if status != httpx.codes.OK:
                logger.info(f"got response status: {status} from chat api", data=chat_gpt_request)
                return "Что-то пошло не так, попробуйте еще раз или обратитесь к администратору"
            return response.text
        except Exception as error:
            logger.error("error get data from chat api", error=error)
            return "Вообще всё сломалось :("

    @staticmethod
    async def do_request(data: dict[str, Any]) -> Response:
        """POST ``data`` as JSON to ``CHAT_GPT_BASE_URI`` on ``settings.GPT_BASE_HOST``."""
        transport = AsyncHTTPTransport(retries=3)
        async with AsyncClient(base_url=settings.GPT_BASE_HOST, transport=transport, timeout=50) as client:
            return await client.post(CHAT_GPT_BASE_URI, json=data, timeout=50)

    def build_request_data(self, question: str) -> dict[str, Any]:
        """Build the request payload: a fresh conversation id, a random 19-digit id and one user message."""
        return {
            "conversation_id": str(uuid4()),
            "action": "_ask",
            "model": self.chat_gpt_model,
            "jailbreak": "default",
            "meta": {
                "id": random.randint(10**18, 10**19 - 1),  # noqa: S311
                "content": {
                    "conversation": [],
                    "internet_access": False,
                    "content_type": "text",
                    "parts": [{"content": question, "role": "user"}],
                },
            },
        }