mirror of
https://github.com/Balshgit/gpt_chat_bot.git
synced 2026-02-03 11:40:39 +03:00
refactoring (#26)
This commit is contained in:
0
bot_microservice/core/bot/__init__.py
Normal file
0
bot_microservice/core/bot/__init__.py
Normal file
85
bot_microservice/core/bot/app.py
Normal file
85
bot_microservice/core/bot/app.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import asyncio
|
||||
import os
|
||||
from asyncio import Queue, sleep
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request, Response
|
||||
from loguru import logger
|
||||
from telegram import Bot, Update
|
||||
from telegram.ext import Application
|
||||
|
||||
from settings.config import AppSettings
|
||||
|
||||
|
||||
class BotApplication:
|
||||
def __init__(
|
||||
self,
|
||||
settings: AppSettings,
|
||||
handlers: list[Any],
|
||||
) -> None:
|
||||
self.application: Application = ( # type: ignore[type-arg]
|
||||
Application.builder().token(token=settings.TELEGRAM_API_TOKEN).build()
|
||||
)
|
||||
self.handlers = handlers
|
||||
self.settings = settings
|
||||
self.start_with_webhook = settings.START_WITH_WEBHOOK
|
||||
self._add_handlers()
|
||||
|
||||
@property
|
||||
def bot(self) -> Bot:
|
||||
return self.application.bot
|
||||
|
||||
async def set_webhook(self) -> None:
|
||||
_, webhook_info = await asyncio.gather(self.application.initialize(), self.application.bot.get_webhook_info())
|
||||
if not webhook_info.url:
|
||||
await self.application.bot.set_webhook(url=self.webhook_url)
|
||||
webhook_info = await self.application.bot.get_webhook_info()
|
||||
logger.info("webhook is set", ip_address=webhook_info.ip_address)
|
||||
|
||||
async def delete_webhook(self) -> None:
|
||||
if await self.application.bot.delete_webhook():
|
||||
logger.info("webhook has been deleted")
|
||||
|
||||
async def polling(self) -> None:
|
||||
if self.settings.STAGE == "runtests":
|
||||
return None
|
||||
await self.application.initialize()
|
||||
await self.application.start()
|
||||
await self.application.updater.start_polling() # type: ignore
|
||||
logger.info("bot started in polling mode")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await self.application.updater.shutdown() # type: ignore
|
||||
|
||||
@cached_property
|
||||
def webhook_url(self) -> str:
|
||||
return os.path.join(self.settings.DOMAIN.strip("/"), self.settings.bot_webhook_url.strip("/"))
|
||||
|
||||
def _add_handlers(self) -> None:
|
||||
for handler in self.handlers:
|
||||
self.application.add_handler(handler)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotQueue:
|
||||
bot_app: BotApplication
|
||||
queue: Queue = asyncio.Queue() # type: ignore[type-arg]
|
||||
|
||||
async def put_updates_on_queue(self, request: Request) -> Response:
|
||||
"""
|
||||
Listen /{URL_PREFIX}/{API_PREFIX}/{TELEGRAM_WEB_TOKEN} path and proxy post request to bot
|
||||
"""
|
||||
data = await request.json()
|
||||
tg_update = Update.de_json(data=data, bot=self.bot_app.application.bot)
|
||||
self.queue.put_nowait(tg_update)
|
||||
|
||||
return Response(status_code=HTTPStatus.ACCEPTED)
|
||||
|
||||
async def get_updates_from_queue(self) -> None:
|
||||
while True:
|
||||
update = await self.queue.get()
|
||||
asyncio.create_task(self.bot_app.application.process_update(update))
|
||||
await sleep(0)
|
||||
100
bot_microservice/core/bot/commands.py
Normal file
100
bot_microservice/core/bot/commands.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import asyncio
|
||||
import tempfile
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from loguru import logger
|
||||
from telegram import InlineKeyboardMarkup, Update
|
||||
from telegram.ext import ContextTypes
|
||||
|
||||
from constants import BotEntryPoints
|
||||
from core.bot.keyboards import main_keyboard
|
||||
from core.bot.services import ChatGptService, SpeechToTextService
|
||||
from settings.config import settings
|
||||
|
||||
|
||||
async def main_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> str:
|
||||
"""Send message on `/start`."""
|
||||
if not update.message:
|
||||
return BotEntryPoints.end
|
||||
reply_markup = InlineKeyboardMarkup(main_keyboard)
|
||||
await update.message.reply_text("Выберете команду:", reply_markup=reply_markup)
|
||||
return BotEntryPoints.start_routes
|
||||
|
||||
|
||||
async def about_me(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
if not update.effective_message:
|
||||
return None
|
||||
await update.effective_message.reply_text(
|
||||
"Автор бота: *Дмитрий Афанасьев*\n\nTg nickname: *Balshtg*", parse_mode="MarkdownV2"
|
||||
)
|
||||
|
||||
|
||||
async def about_bot(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
if not update.effective_message:
|
||||
return None
|
||||
await update.effective_message.reply_text(
|
||||
f"Бот использует бесплатную модель {settings.GPT_MODEL} для ответов на вопросы. "
|
||||
f"\nПринимает запросы на разных языках.\n\nБот так же умеет переводить русские голосовые сообщения в текст. "
|
||||
f"Просто пришлите голосовуху и получите поток сознания в виде текста, но без знаков препинания",
|
||||
parse_mode="Markdown",
|
||||
)
|
||||
|
||||
|
||||
async def website(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
if not update.effective_message:
|
||||
return None
|
||||
website = urljoin(settings.DOMAIN, f"{settings.URL_PREFIX}/chat/")
|
||||
await update.effective_message.reply_text(f"Веб версия: {website}")
|
||||
|
||||
|
||||
async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Send a message when the command /help is issued."""
|
||||
|
||||
if not update.effective_message:
|
||||
return None
|
||||
reply_markup = InlineKeyboardMarkup(main_keyboard)
|
||||
await update.effective_message.reply_text(
|
||||
"Help!",
|
||||
disable_notification=True,
|
||||
api_kwargs={"text": "Список основных команд:"},
|
||||
reply_markup=reply_markup,
|
||||
)
|
||||
|
||||
|
||||
async def ask_question(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
if not update.message:
|
||||
return None
|
||||
|
||||
await update.message.reply_text("Пожалуйста подождите, ответ в среднем занимает 10-15 секунд")
|
||||
|
||||
chat_gpt_service = ChatGptService(chat_gpt_model=settings.GPT_MODEL)
|
||||
logger.warning("question asked", user=update.message.from_user, question=update.message.text)
|
||||
answer = await chat_gpt_service.request_to_chatgpt(question=update.message.text)
|
||||
await update.message.reply_text(answer)
|
||||
|
||||
|
||||
async def voice_recognize(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
if not update.message:
|
||||
return None
|
||||
await update.message.reply_text("Пожалуйста, ожидайте :)\nТрехминутная запись обрабатывается примерно 30 секунд")
|
||||
if not update.message.voice:
|
||||
return None
|
||||
|
||||
sound_file = await update.message.voice.get_file()
|
||||
sound_bytes = await sound_file.download_as_bytearray()
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tmpfile:
|
||||
tmpfile.write(sound_bytes)
|
||||
|
||||
logger.info("file has been saved", filename=tmpfile.name)
|
||||
|
||||
speech_to_text_service = SpeechToTextService(filename=tmpfile.name)
|
||||
|
||||
speech_to_text_service.get_text_from_audio()
|
||||
|
||||
part = 0
|
||||
while speech_to_text_service.text_parts or not speech_to_text_service.text_recognised:
|
||||
if text := speech_to_text_service.text_parts.get(part):
|
||||
speech_to_text_service.text_parts.pop(part)
|
||||
await update.message.reply_text(text)
|
||||
part += 1
|
||||
await asyncio.sleep(5)
|
||||
54
bot_microservice/core/bot/handlers.py
Normal file
54
bot_microservice/core/bot/handlers.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from telegram.ext import (
|
||||
CallbackQueryHandler,
|
||||
CommandHandler,
|
||||
ConversationHandler,
|
||||
MessageHandler,
|
||||
filters,
|
||||
)
|
||||
|
||||
from constants import BotEntryPoints, BotStagesEnum
|
||||
from core.bot.commands import (
|
||||
about_bot,
|
||||
about_me,
|
||||
ask_question,
|
||||
help_command,
|
||||
main_command,
|
||||
voice_recognize,
|
||||
website,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotEventHandlers:
|
||||
handlers: list[Any] = field(default_factory=list[Any])
|
||||
|
||||
def add_handler(self, handler: Any) -> None:
|
||||
self.handlers.append(handler)
|
||||
|
||||
|
||||
bot_event_handlers = BotEventHandlers()
|
||||
|
||||
bot_event_handlers.add_handler(CommandHandler("help", help_command))
|
||||
bot_event_handlers.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, ask_question))
|
||||
bot_event_handlers.add_handler(MessageHandler(filters.VOICE | filters.AUDIO, voice_recognize))
|
||||
bot_event_handlers.add_handler(
|
||||
ConversationHandler(
|
||||
entry_points=[CommandHandler("start", main_command)],
|
||||
states={
|
||||
BotEntryPoints.start_routes: [
|
||||
CallbackQueryHandler(about_me, pattern="^" + BotStagesEnum.about_me + "$"),
|
||||
CallbackQueryHandler(website, pattern="^" + BotStagesEnum.website + "$"),
|
||||
CallbackQueryHandler(help_command, pattern="^" + BotStagesEnum.help + "$"),
|
||||
CallbackQueryHandler(about_bot, pattern="^" + BotStagesEnum.about_bot + "$"),
|
||||
],
|
||||
},
|
||||
fallbacks=[CommandHandler("start", main_command)],
|
||||
)
|
||||
)
|
||||
bot_event_handlers.add_handler(CallbackQueryHandler(about_me, pattern="^" + BotStagesEnum.about_me + "$"))
|
||||
bot_event_handlers.add_handler(CallbackQueryHandler(website, pattern="^" + BotStagesEnum.website + "$"))
|
||||
bot_event_handlers.add_handler(CallbackQueryHandler(help_command, pattern="^" + BotStagesEnum.help + "$"))
|
||||
bot_event_handlers.add_handler(CallbackQueryHandler(about_bot, pattern="^" + BotStagesEnum.about_bot + "$"))
|
||||
14
bot_microservice/core/bot/keyboards.py
Normal file
14
bot_microservice/core/bot/keyboards.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from telegram import InlineKeyboardButton
|
||||
|
||||
from constants import BotStagesEnum
|
||||
|
||||
main_keyboard = (
|
||||
(
|
||||
InlineKeyboardButton("Обо мне", callback_data=str(BotStagesEnum.about_me)),
|
||||
InlineKeyboardButton("Веб версия", callback_data=str(BotStagesEnum.website)),
|
||||
),
|
||||
(
|
||||
InlineKeyboardButton("Помощь", callback_data=str(BotStagesEnum.help)),
|
||||
InlineKeyboardButton("О боте", callback_data=str(BotStagesEnum.about_bot)),
|
||||
),
|
||||
)
|
||||
138
bot_microservice/core/bot/services.py
Normal file
138
bot_microservice/core/bot/services.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import os
|
||||
import random
|
||||
import subprocess # noqa
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from httpx import AsyncClient, AsyncHTTPTransport, Response
|
||||
from loguru import logger
|
||||
from pydub import AudioSegment
|
||||
from speech_recognition import (
|
||||
AudioFile,
|
||||
Recognizer,
|
||||
UnknownValueError as SpeechRecognizerError,
|
||||
)
|
||||
|
||||
from constants import (
|
||||
AUDIO_SEGMENT_DURATION,
|
||||
CHAT_GPT_BASE_URI,
|
||||
INVALID_GPT_REQUEST_MESSAGES,
|
||||
)
|
||||
from settings.config import settings
|
||||
|
||||
|
||||
class SpeechToTextService:
|
||||
def __init__(self, filename: str) -> None:
|
||||
self.executor = ThreadPoolExecutor()
|
||||
|
||||
self.filename = filename
|
||||
self.recognizer = Recognizer()
|
||||
self.recognizer.energy_threshold = 50
|
||||
self.text_parts: dict[int, str] = {}
|
||||
self.text_recognised = False
|
||||
|
||||
def get_text_from_audio(self) -> None:
|
||||
self.executor.submit(self.worker)
|
||||
|
||||
def worker(self) -> Any:
|
||||
self._convert_file_to_wav()
|
||||
self._convert_audio_to_text()
|
||||
|
||||
def _convert_audio_to_text(self) -> None:
|
||||
wav_filename = f"{self.filename}.wav"
|
||||
|
||||
speech = AudioSegment.from_wav(wav_filename)
|
||||
speech_duration = len(speech)
|
||||
pieces = speech_duration // AUDIO_SEGMENT_DURATION + 1
|
||||
ending = speech_duration % AUDIO_SEGMENT_DURATION
|
||||
for i in range(pieces):
|
||||
if i == 0 and pieces == 1:
|
||||
sound_segment = speech[0:ending]
|
||||
elif i == 0:
|
||||
sound_segment = speech[0 : (i + 1) * AUDIO_SEGMENT_DURATION]
|
||||
elif i == (pieces - 1):
|
||||
sound_segment = speech[i * AUDIO_SEGMENT_DURATION - 250 : i * AUDIO_SEGMENT_DURATION + ending]
|
||||
else:
|
||||
sound_segment = speech[i * AUDIO_SEGMENT_DURATION - 250 : (i + 1) * AUDIO_SEGMENT_DURATION]
|
||||
self.text_parts[i] = self._recognize_by_google(wav_filename, sound_segment)
|
||||
|
||||
self.text_recognised = True
|
||||
|
||||
# clean temp voice message main files
|
||||
try:
|
||||
os.remove(wav_filename)
|
||||
os.remove(self.filename)
|
||||
except FileNotFoundError as error:
|
||||
logger.error("error temps files not deleted", error=error, filenames=[self.filename, self.filename])
|
||||
|
||||
def _convert_file_to_wav(self) -> None:
|
||||
new_filename = self.filename + ".wav"
|
||||
cmd = ["ffmpeg", "-loglevel", "quiet", "-i", self.filename, "-vn", new_filename]
|
||||
try:
|
||||
subprocess.run(args=cmd) # noqa: S603
|
||||
logger.info("file has been converted to wav", filename=new_filename)
|
||||
except Exception as error:
|
||||
logger.error("cant convert voice", error=error, filename=self.filename)
|
||||
|
||||
def _recognize_by_google(self, filename: str, sound_segment: AudioSegment) -> str:
|
||||
tmp_filename = f"{filename}_tmp_part"
|
||||
sound_segment.export(tmp_filename, format="wav")
|
||||
with AudioFile(tmp_filename) as source:
|
||||
audio_text = self.recognizer.listen(source)
|
||||
try:
|
||||
text = self.recognizer.recognize_google(audio_text, language="ru-RU")
|
||||
os.remove(tmp_filename)
|
||||
return text
|
||||
except SpeechRecognizerError as error:
|
||||
os.remove(tmp_filename)
|
||||
logger.error("error recognizing text with google", error=error)
|
||||
raise error
|
||||
|
||||
|
||||
class ChatGptService:
|
||||
def __init__(self, chat_gpt_model: str) -> None:
|
||||
self.chat_gpt_model = chat_gpt_model
|
||||
|
||||
async def request_to_chatgpt(self, question: str | None) -> str:
|
||||
question = question or "Привет!"
|
||||
chat_gpt_request = self.build_request_data(question)
|
||||
try:
|
||||
response = await self.do_request(chat_gpt_request)
|
||||
status = response.status_code
|
||||
for message in INVALID_GPT_REQUEST_MESSAGES:
|
||||
if message in response.text:
|
||||
message = f"{message}: {settings.GPT_MODEL}"
|
||||
logger.info(message, data=chat_gpt_request)
|
||||
return message
|
||||
if status != httpx.codes.OK:
|
||||
logger.info(f"got response status: {status} from chat api", data=chat_gpt_request)
|
||||
return "Что-то пошло не так, попробуйте еще раз или обратитесь к администратору"
|
||||
return response.text
|
||||
except Exception as error:
|
||||
logger.error("error get data from chat api", error=error)
|
||||
return "Вообще всё сломалось :("
|
||||
|
||||
@staticmethod
|
||||
async def do_request(data: dict[str, Any]) -> Response:
|
||||
transport = AsyncHTTPTransport(retries=3)
|
||||
async with AsyncClient(base_url=settings.GPT_BASE_HOST, transport=transport, timeout=50) as client:
|
||||
return await client.post(CHAT_GPT_BASE_URI, json=data, timeout=50)
|
||||
|
||||
def build_request_data(self, question: str) -> dict[str, Any]:
|
||||
return {
|
||||
"conversation_id": str(uuid4()),
|
||||
"action": "_ask",
|
||||
"model": self.chat_gpt_model,
|
||||
"jailbreak": "default",
|
||||
"meta": {
|
||||
"id": random.randint(10**18, 10**19 - 1), # noqa: S311
|
||||
"content": {
|
||||
"conversation": [],
|
||||
"internet_access": False,
|
||||
"content_type": "text",
|
||||
"parts": [{"content": question, "role": "user"}],
|
||||
},
|
||||
},
|
||||
}
|
||||
Reference in New Issue
Block a user