mirror of
https://github.com/Balshgit/gpt_chat_bot.git
synced 2026-02-03 11:40:39 +03:00
add speech to text command (#6)
This commit is contained in:
@@ -29,13 +29,15 @@ class BotApplication:
|
||||
self._add_handlers()
|
||||
|
||||
async def set_webhook(self) -> None:
    """Initialize the application and register the webhook when none is set.

    The application is initialized *before* any Bot API call is issued, so
    the bot is never used while its initialization is still in flight.
    ``set_webhook`` is only called when Telegram reports an empty webhook
    URL; the webhook info is then fetched again so the log entry describes
    the registration that was just made.
    """
    await self.application.initialize()
    webhook_info = await self.application.bot.get_webhook_info()
    if not webhook_info.url:
        # NOTE(review): a webhook that already points at a different URL is
        # left untouched -- confirm that this is intended.
        await self.application.bot.set_webhook(url=self.webhook_url)
        webhook_info = await self.application.bot.get_webhook_info()
    logger.info("webhook is set", ip_address=webhook_info.ip_address)
|
||||
|
||||
async def delete_webhook(self) -> None:
    """Delete the bot's webhook and log it when the API call reports success."""
    webhook_removed = await self.application.bot.delete_webhook()
    if not webhook_removed:
        return
    logger.info("webhook has been deleted")
|
||||
|
||||
async def polling(self) -> None:
|
||||
await self.application.initialize()
|
||||
@@ -73,5 +75,5 @@ class BotQueue:
|
||||
async def get_updates_from_queue(self) -> None:
    """Forward every queued update to the application in its own task.

    Runs until the coroutine is cancelled. Each update taken from
    ``self.queue`` is processed in a separate task, so one slow update does
    not hold back the ones queued behind it.
    """
    # The event loop keeps only weak references to tasks. Holding a strong
    # reference until the task finishes prevents it from being garbage
    # collected in the middle of processing an update.
    running_tasks: set[asyncio.Task[None]] = set()
    while True:
        update = await self.queue.get()
        task = asyncio.create_task(self.bot_app.application.process_update(update))
        running_tasks.add(task)
        task.add_done_callback(running_tasks.discard)
        # Yield once so the freshly created task gets a chance to start.
        await sleep(0)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import asyncio
|
||||
import random
|
||||
import tempfile
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from constants import CHAT_GPT_BASE_URL
|
||||
from core.utils import convert_file_to_wav
|
||||
from core.utils import SpeechToTextService
|
||||
from httpx import AsyncClient, AsyncHTTPTransport
|
||||
from loguru import logger
|
||||
from telegram import Update
|
||||
@@ -14,19 +15,20 @@ from telegram.ext import ContextTypes
|
||||
async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Reply to the /help command; updates that carry no message are ignored."""
    message = update.message
    if message:
        reply_options = {
            "disable_notification": True,
            "api_kwargs": {"text": "Hello World"},
        }
        await message.reply_text("Help!", **reply_options)
    return None
|
||||
|
||||
|
||||
async def ask_question(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
await update.message.reply_text( # type: ignore[union-attr]
|
||||
"Пожалуйста подождите, ответ в среднем занимает 10-15 секунд"
|
||||
)
|
||||
if not update.message:
|
||||
return None
|
||||
|
||||
await update.message.reply_text("Пожалуйста подождите, ответ в среднем занимает 10-15 секунд")
|
||||
|
||||
chat_gpt_request = {
|
||||
"conversation_id": str(uuid4()),
|
||||
@@ -39,36 +41,51 @@ async def ask_question(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No
|
||||
"conversation": [],
|
||||
"internet_access": False,
|
||||
"content_type": "text",
|
||||
"parts": [{"content": update.message.text, "role": "user"}], # type: ignore[union-attr]
|
||||
"parts": [{"content": update.message.text, "role": "user"}],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
transport = AsyncHTTPTransport(retries=1)
|
||||
async with AsyncClient(transport=transport) as client:
|
||||
transport = AsyncHTTPTransport(retries=3)
|
||||
async with AsyncClient(transport=transport, timeout=50) as client:
|
||||
try:
|
||||
response = await client.post(CHAT_GPT_BASE_URL, json=chat_gpt_request)
|
||||
response = await client.post(CHAT_GPT_BASE_URL, json=chat_gpt_request, timeout=50)
|
||||
status = response.status_code
|
||||
if status != httpx.codes.OK:
|
||||
logger.info(f'got response status: {status} from chat api', data=chat_gpt_request)
|
||||
await update.message.reply_text( # type: ignore[union-attr]
|
||||
await update.message.reply_text(
|
||||
"Что-то пошло не так, попробуйте еще раз или обратитесь к администратору"
|
||||
)
|
||||
return
|
||||
|
||||
data = response.json()
|
||||
await update.message.reply_text(data) # type: ignore[union-attr]
|
||||
await update.message.reply_text(response.text)
|
||||
except Exception as error:
|
||||
logger.error("error get data from chat api", error=error)
|
||||
await update.message.reply_text("Вообще всё сломалось :(") # type: ignore[union-attr]
|
||||
await update.message.reply_text("Вообще всё сломалось :(")
|
||||
|
||||
|
||||
async def voice_recognize(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Transcribe a voice message and reply with the text, part by part.

    The voice file is downloaded into a temporary file and handed to a
    ``SpeechToTextService``, which fills ``text_parts`` (keyed by consecutive
    part numbers) and finally sets ``text_recognised``. This handler polls
    that state and sends each part to the chat as soon as it is available.

    Updates without a message are ignored; a message without a voice
    attachment only receives the "please wait" notice.
    """
    poll_interval = 5  # seconds to wait when the next part is not ready yet
    max_idle_polls = 60  # stop after ~5 minutes without any new part

    message = update.message
    if not message:
        return None
    await message.reply_text("Пожалуйста, ожидайте :)\nТрехминутная запись обрабатывается примерно 30 секунд")
    if not message.voice:
        return None

    sound_file = await message.voice.get_file()
    sound_bytes = await sound_file.download_as_bytearray()
    # delete=False: the file must outlive this ``with`` block because the
    # service reads it by name afterwards.
    with tempfile.NamedTemporaryFile(delete=False) as tmpfile:
        tmpfile.write(sound_bytes)

    logger.info('file has been saved', filename=tmpfile.name)

    speech_to_text_service = SpeechToTextService(filename=tmpfile.name)
    speech_to_text_service.get_text_from_audio()

    part = 0
    idle_polls = 0
    while True:
        # Read the completion flag BEFORE looking for the next part: the
        # service stores parts first and sets the flag last, so a flag seen
        # as set here guarantees that no further part can still appear.
        finished = speech_to_text_service.text_recognised
        if part in speech_to_text_service.text_parts:
            text = speech_to_text_service.text_parts.pop(part)
            part += 1
            idle_polls = 0
            # An empty part must not be sent, but it must still be consumed,
            # otherwise the loop would wait on it forever.
            if text:
                await message.reply_text(text)
            continue  # the next part may already be waiting, do not sleep
        if finished:
            break
        if idle_polls >= max_idle_polls:
            # Guards against a worker that stopped without ever setting
            # ``text_recognised``; without this the handler would never end.
            logger.error("speech recognition made no progress, giving up", filename=tmpfile.name, part=part)
            await message.reply_text("Что-то пошло не так, попробуйте еще раз или обратитесь к администратору")
            break
        idle_polls += 1
        await asyncio.sleep(poll_interval)
    return None
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, cast
|
||||
from constants import LogLevelEnum
|
||||
from loguru import logger
|
||||
from sentry_sdk.integrations.logging import EventHandler
|
||||
from settings.config import get_settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from loguru import Record
|
||||
@@ -13,6 +14,9 @@ else:
|
||||
Record = dict[str, Any]
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class InterceptHandler(logging.Handler):
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
# Get corresponding Loguru level if it exists
|
||||
@@ -29,7 +33,7 @@ class InterceptHandler(logging.Handler):
|
||||
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(
|
||||
level,
|
||||
record.getMessage(),
|
||||
record.getMessage().replace(settings.TELEGRAM_API_TOKEN, "TELEGRAM_API_TOKEN".center(24, '*')),
|
||||
)
|
||||
|
||||
|
||||
@@ -97,6 +101,3 @@ def _text_formatter(record: Record) -> str:
|
||||
format_ += "{exception}\n"
|
||||
|
||||
return format_
|
||||
|
||||
|
||||
configure_logging(level=LogLevelEnum.DEBUG, enable_json_logs=True, enable_sentry_logs=True)
|
||||
|
||||
@@ -1,9 +1,18 @@
|
||||
import os
|
||||
import subprocess # noqa
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from datetime import datetime, timedelta
|
||||
from functools import lru_cache, wraps
|
||||
from typing import Any
|
||||
|
||||
from constants import AUDIO_SEGMENT_DURATION
|
||||
from loguru import logger
|
||||
from pydub import AudioSegment
|
||||
from speech_recognition import (
|
||||
AudioFile,
|
||||
Recognizer,
|
||||
UnknownValueError as SpeechRecognizerError,
|
||||
)
|
||||
|
||||
|
||||
def timed_cache(**timedelta_kwargs: Any) -> Any:
|
||||
@@ -27,13 +36,69 @@ def timed_cache(**timedelta_kwargs: Any) -> Any:
|
||||
return _wrapper
|
||||
|
||||
|
||||
class SpeechToTextService:
    """Recognise speech from a voice file in a background thread.

    ``get_text_from_audio`` schedules ``worker`` on a thread pool and returns
    immediately. The worker converts the source file to WAV with ffmpeg,
    cuts it into segments of ``AUDIO_SEGMENT_DURATION`` and recognises each
    segment with the Google recogniser.

    Results are published through two attributes:

    * ``text_parts`` maps consecutive part numbers (0, 1, 2, ...) to
      recognised text. Segments in which no speech could be recognised are
      skipped, so the numbering never has gaps and never holds empty text.
    * ``text_recognised`` becomes ``True`` once the worker has finished,
      whether it succeeded or failed, so pollers can always terminate.
    """

    # Overlap between neighbouring segments so a word cut at a boundary is
    # heard in full by at least one of them. pydub measures lengths and
    # slice indices in milliseconds.
    SEGMENT_OVERLAP = 250

    def __init__(self, filename: str) -> None:
        """Prepare recognition of the audio file at ``filename``."""
        self.executor = ThreadPoolExecutor()

        self.filename = filename
        self.recognizer = Recognizer()
        self.recognizer.energy_threshold = 50
        self.text_parts: dict[int, str] = {}
        self.text_recognised = False

    def get_text_from_audio(self) -> None:
        """Start recognition in the background; poll the attributes for results."""
        self.executor.submit(self.worker)

    def worker(self) -> Any:
        """Run conversion and recognition; always mark the job as finished.

        The future returned by ``submit`` is discarded, so an exception
        raised here would otherwise vanish silently. It is logged first and
        then re-raised.
        """
        try:
            self._convert_file_to_wav()
            self._convert_audio_to_text()
        except Exception as error:
            logger.error("speech to text worker failed", error=error, filename=self.filename)
            raise
        finally:
            self.text_recognised = True

    def _convert_audio_to_text(self) -> None:
        """Recognise the WAV file segment by segment and publish the parts.

        Temporary files are removed and ``text_recognised`` is set even when
        loading or recognition fails.
        """
        wav_filename = f'{self.filename}.wav'
        try:
            speech = AudioSegment.from_wav(wav_filename)
            speech_duration = len(speech)
            part = 0
            # Stepping over the real duration avoids a bogus trailing
            # segment (overlap only) when the duration is an exact multiple
            # of the segment length, and yields no segment for empty audio.
            for start in range(0, speech_duration, AUDIO_SEGMENT_DURATION):
                begin = max(start - self.SEGMENT_OVERLAP, 0)
                end = min(start + AUDIO_SEGMENT_DURATION, speech_duration)
                try:
                    text = self._recognize_by_google(wav_filename, speech[begin:end])
                except SpeechRecognizerError:
                    # Already logged by _recognize_by_google. One
                    # unintelligible segment must not abort the whole job.
                    continue
                if text:
                    self.text_parts[part] = text
                    part += 1
        finally:
            self.text_recognised = True
            self._remove_temp_files(wav_filename, self.filename)

    def _remove_temp_files(self, *filenames: str) -> None:
        """Delete each temporary file, logging the ones that cannot be removed."""
        for filename in filenames:
            try:
                os.remove(filename)
            except OSError as error:
                logger.error("error temps files not deleted", error=error, filenames=[filename])

    def _convert_file_to_wav(self) -> None:
        """Convert the source file to ``<filename>.wav`` with ffmpeg.

        Failures are logged, not raised. Success is only logged when ffmpeg
        actually exited with status 0.
        """
        new_filename = self.filename + '.wav'
        cmd = ['ffmpeg', '-loglevel', 'quiet', '-i', self.filename, '-vn', new_filename]
        try:
            # check=True turns a non-zero ffmpeg exit status into an error.
            subprocess.run(args=cmd, check=True)  # noqa: S603
        except (OSError, subprocess.SubprocessError) as error:
            logger.error("cant convert voice", error=error, filename=self.filename)
        else:
            logger.info("file has been converted to wav", filename=new_filename)

    def _recognize_by_google(self, filename: str, sound_segment: AudioSegment) -> str:
        """Recognise Russian speech in one segment and return the text.

        The segment is exported to ``<filename>_tmp_part``, which is removed
        again no matter how recognition ends.

        Raises:
            SpeechRecognizerError: if no speech could be recognised.
        """
        tmp_filename = f"{filename}_tmp_part"
        try:
            sound_segment.export(tmp_filename, format="wav")
            with AudioFile(tmp_filename) as source:
                audio_text = self.recognizer.listen(source)
            return self.recognizer.recognize_google(audio_text, language='ru-RU')
        except SpeechRecognizerError as error:
            logger.error("error recognizing text with google", error=error)
            raise
        finally:
            try:
                os.remove(tmp_filename)
            except FileNotFoundError:
                # The export itself failed, so there is nothing to clean up.
                pass
|
||||
Reference in New Issue
Block a user