refactoring (#26)

Dmitry Afanasyev
2023-10-03 23:30:19 +03:00
committed by GitHub
parent 482e1fdda1
commit c401e1006c
22 changed files with 423 additions and 395 deletions


@@ -7,8 +7,8 @@ from telegram import InlineKeyboardMarkup, Update
from telegram.ext import ContextTypes
from constants import BotEntryPoints
from core.keyboards import main_keyboard
from core.utils import ChatGptService, SpeechToTextService
from core.bot.keyboards import main_keyboard
from core.bot.services import ChatGptService, SpeechToTextService
from settings.config import settings


@@ -10,7 +10,7 @@ from telegram.ext import (
)
from constants import BotEntryPoints, BotStagesEnum
from core.commands import (
from core.bot.commands import (
about_bot,
about_me,
ask_question,

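The two hunks above only change import paths: the keyboards, the services, and the command handlers move from the top-level core package into a dedicated core.bot package. A minimal sketch of what a handler module's imports could look like after the move (only names visible in the diff are used; combining them into one module is an assumption):

# Hypothetical handler module after the refactor: bot-related code now lives
# under core.bot instead of the top-level core package.
from core.bot.commands import about_bot, about_me, ask_question
from core.bot.keyboards import main_keyboard
from core.bot.services import ChatGptService, SpeechToTextService
from settings.config import settings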

@@ -0,0 +1,138 @@
import os
import random
import subprocess # noqa
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Any
from uuid import uuid4
import httpx
from httpx import AsyncClient, AsyncHTTPTransport, Response
from loguru import logger
from pydub import AudioSegment
from speech_recognition import (
AudioFile,
Recognizer,
UnknownValueError as SpeechRecognizerError,
)
from constants import (
AUDIO_SEGMENT_DURATION,
CHAT_GPT_BASE_URI,
INVALID_GPT_REQUEST_MESSAGES,
)
from settings.config import settings
class SpeechToTextService:
def __init__(self, filename: str) -> None:
self.executor = ThreadPoolExecutor()
self.filename = filename
self.recognizer = Recognizer()
self.recognizer.energy_threshold = 50
self.text_parts: dict[int, str] = {}
self.text_recognised = False
def get_text_from_audio(self) -> None:
self.executor.submit(self.worker)
def worker(self) -> Any:
self._convert_file_to_wav()
self._convert_audio_to_text()
def _convert_audio_to_text(self) -> None:
wav_filename = f"{self.filename}.wav"
speech = AudioSegment.from_wav(wav_filename)
speech_duration = len(speech)
pieces = speech_duration // AUDIO_SEGMENT_DURATION + 1
ending = speech_duration % AUDIO_SEGMENT_DURATION
for i in range(pieces):
if i == 0 and pieces == 1:
sound_segment = speech[0:ending]
elif i == 0:
sound_segment = speech[0 : (i + 1) * AUDIO_SEGMENT_DURATION]
elif i == (pieces - 1):
sound_segment = speech[i * AUDIO_SEGMENT_DURATION - 250 : i * AUDIO_SEGMENT_DURATION + ending]
else:
sound_segment = speech[i * AUDIO_SEGMENT_DURATION - 250 : (i + 1) * AUDIO_SEGMENT_DURATION]
self.text_parts[i] = self._recognize_by_google(wav_filename, sound_segment)
self.text_recognised = True
        # clean up the temporary voice-message files
        try:
            os.remove(wav_filename)
            os.remove(self.filename)
        except FileNotFoundError as error:
            logger.error("temp files were not deleted", error=error, filenames=[self.filename, wav_filename])
def _convert_file_to_wav(self) -> None:
new_filename = self.filename + ".wav"
cmd = ["ffmpeg", "-loglevel", "quiet", "-i", self.filename, "-vn", new_filename]
try:
subprocess.run(args=cmd) # noqa: S603
logger.info("file has been converted to wav", filename=new_filename)
except Exception as error:
logger.error("cant convert voice", error=error, filename=self.filename)
def _recognize_by_google(self, filename: str, sound_segment: AudioSegment) -> str:
tmp_filename = f"{filename}_tmp_part"
sound_segment.export(tmp_filename, format="wav")
with AudioFile(tmp_filename) as source:
audio_text = self.recognizer.listen(source)
try:
text = self.recognizer.recognize_google(audio_text, language="ru-RU")
os.remove(tmp_filename)
return text
except SpeechRecognizerError as error:
os.remove(tmp_filename)
logger.error("error recognizing text with google", error=error)
raise error
class ChatGptService:
def __init__(self, chat_gpt_model: str) -> None:
self.chat_gpt_model = chat_gpt_model
async def request_to_chatgpt(self, question: str | None) -> str:
question = question or "Привет!"
chat_gpt_request = self.build_request_data(question)
try:
response = await self.do_request(chat_gpt_request)
status = response.status_code
for message in INVALID_GPT_REQUEST_MESSAGES:
if message in response.text:
message = f"{message}: {settings.GPT_MODEL}"
logger.info(message, data=chat_gpt_request)
return message
if status != httpx.codes.OK:
logger.info(f"got response status: {status} from chat api", data=chat_gpt_request)
return "Что-то пошло не так, попробуйте еще раз или обратитесь к администратору"
return response.text
except Exception as error:
logger.error("error get data from chat api", error=error)
return "Вообще всё сломалось :("
@staticmethod
async def do_request(data: dict[str, Any]) -> Response:
transport = AsyncHTTPTransport(retries=3)
async with AsyncClient(base_url=settings.GPT_BASE_HOST, transport=transport, timeout=50) as client:
return await client.post(CHAT_GPT_BASE_URI, json=data, timeout=50)
def build_request_data(self, question: str) -> dict[str, Any]:
return {
"conversation_id": str(uuid4()),
"action": "_ask",
"model": self.chat_gpt_model,
"jailbreak": "default",
"meta": {
"id": random.randint(10**18, 10**19 - 1), # noqa: S311
"content": {
"conversation": [],
"internet_access": False,
"content_type": "text",
"parts": [{"content": question, "role": "user"}],
},
},
}
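
For context, a sketch of how the relocated services might be wired into handlers, assuming python-telegram-bot v20-style async handlers. The handler names, the temp path, and the reply flow are assumptions for illustration, not code from this commit:

# Illustrative only: calling SpeechToTextService and ChatGptService from handlers.
from telegram import Update
from telegram.ext import ContextTypes

from core.bot.services import ChatGptService, SpeechToTextService
from settings.config import settings


async def handle_voice(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    voice_file = await update.message.voice.get_file()
    filename = f"/tmp/{voice_file.file_unique_id}"  # assumed temp location
    await voice_file.download_to_drive(custom_path=filename)
    speech_to_text = SpeechToTextService(filename=filename)
    speech_to_text.get_text_from_audio()  # transcription runs in the thread pool
    # recognised chunks appear in speech_to_text.text_parts as they become ready


async def handle_question(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    chat_gpt = ChatGptService(chat_gpt_model=settings.GPT_MODEL)
    answer = await chat_gpt.request_to_chatgpt(question=update.message.text)
    await update.message.reply_text(answer)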


@@ -1,116 +0,0 @@
import logging
import os
import sys
from types import FrameType
from typing import TYPE_CHECKING, Any, cast
import graypy
from loguru import logger
from sentry_sdk.integrations.logging import EventHandler
from constants import LogLevelEnum
from settings.config import DIR_LOGS, settings
if TYPE_CHECKING:
from loguru import Record
else:
Record = dict[str, Any]
class InterceptHandler(logging.Handler):
def emit(self, record: logging.LogRecord) -> None:
# Get corresponding Loguru level if it exists
try:
level = logger.level(record.levelname).name
except ValueError:
level = str(record.levelno)
# Find caller from where originated the logged message
frame, depth = logging.currentframe(), 2
while frame.f_code.co_filename == logging.__file__:
frame = cast(FrameType, frame.f_back)
depth += 1
logger.opt(depth=depth, exception=record.exc_info).log(
level,
record.getMessage().replace(settings.TELEGRAM_API_TOKEN, "TELEGRAM_API_TOKEN".center(24, "*")),
)
def configure_logging(
*, level: LogLevelEnum, enable_json_logs: bool, enable_sentry_logs: bool, log_to_file: str | None = None
) -> None:
intercept_handler = InterceptHandler()
formatter = _json_formatter if enable_json_logs else _text_formatter
base_config_handlers = [intercept_handler]
base_loguru_handler = {
"level": level.name,
"serialize": enable_json_logs,
"format": formatter,
"colorize": False,
}
loguru_handlers = [
{**base_loguru_handler, "colorize": True, "sink": sys.stdout},
]
if settings.GRAYLOG_HOST and settings.GRAYLOG_PORT:
graylog_handler = graypy.GELFUDPHandler(settings.GRAYLOG_HOST, settings.GRAYLOG_PORT)
base_config_handlers.append(graylog_handler)
loguru_handlers.append({**base_loguru_handler, "sink": graylog_handler})
if log_to_file:
file_path = os.path.join(DIR_LOGS, log_to_file)
if not os.path.exists(log_to_file):
with open(file_path, 'w') as f:
f.write('')
loguru_handlers.append({**base_loguru_handler, "sink": file_path})
logging.basicConfig(handlers=base_config_handlers, level=level.name)
logger.configure(handlers=loguru_handlers)
    # sentry sdk does not work with loguru out of the box, a handler has to be added
# https://github.com/getsentry/sentry-python/issues/653#issuecomment-788854865
# https://forum.sentry.io/t/changing-issue-title-when-logging-with-traceback/446
if enable_sentry_logs:
handler = EventHandler(level=logging.WARNING)
logger.add(handler, diagnose=True, level=logging.WARNING, format=_sentry_formatter)
def _json_formatter(record: Record) -> str:
    # Strip the trailing `\n` from the log message: line breaks are not needed in JSON output
return record.get("message", "").strip()
def _sentry_formatter(record: Record) -> str:
return "{name}:{function} {message}"
def _text_formatter(record: Record) -> str:
    # WARNING !!!
    # This function must return a string that contains only formatting placeholders.
    # If values from record (or anywhere else) are interpolated into the returned string,
    # loguru may treat them as format fields and try to render them, which leads to an error.
    # For example, to pull a value out of the extra field, do not interpolate it into the
    # format string; pass a placeholder of the form {extra[key_name]} instead.
    # Standard loguru format; it is set via the LOGURU_FORMAT env variable.
format_ = (
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
"<level>{level: <8}</level> | "
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
)
    # Append meta parameters such as user_id, art_id that are passed via logger.bind(...)
extra = record["extra"]
if extra:
formatted = ", ".join(f"{key}" + "={extra[" + str(key) + "]}" for key, value in extra.items())
format_ += f" - <cyan>{formatted}</cyan>"
format_ += "\n"
if record["exception"] is not None:
format_ += "{exception}\n"
return format_
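
The deleted module above routes everything emitted through the standard logging module into loguru via InterceptHandler before configuring the sinks. A self-contained sketch of that interception pattern, stripped of this project's settings (the level and sink below are assumptions), looks like:

# Minimal sketch of the stdlib-logging -> loguru bridge used by the removed module.
import logging
import sys
from types import FrameType
from typing import cast

from loguru import logger


class InterceptHandler(logging.Handler):
    def emit(self, record: logging.LogRecord) -> None:
        # Map the stdlib level name to a loguru level if one is registered.
        try:
            level: str | int = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno
        # Walk up the stack so loguru reports the original caller, not this handler.
        frame, depth = logging.currentframe(), 2
        while frame.f_code.co_filename == logging.__file__:
            frame = cast(FrameType, frame.f_back)
            depth += 1
        logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())


logging.basicConfig(handlers=[InterceptHandler()], level=logging.INFO)
logger.configure(handlers=[{"sink": sys.stdout, "level": "INFO"}])

logging.getLogger("httpx").warning("stdlib records are now rendered by loguru")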


@@ -1,28 +1,6 @@
import os
import random
import subprocess # noqa
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime, timedelta
from functools import lru_cache, wraps
from typing import Any
from uuid import uuid4
import httpx
from httpx import AsyncClient, AsyncHTTPTransport, Response
from loguru import logger
from pydub import AudioSegment
from speech_recognition import (
AudioFile,
Recognizer,
UnknownValueError as SpeechRecognizerError,
)
from constants import (
AUDIO_SEGMENT_DURATION,
CHAT_GPT_BASE_URI,
INVALID_GPT_REQUEST_MESSAGES,
)
from settings.config import settings
def timed_cache(**timedelta_kwargs: Any) -> Any:
@@ -44,118 +22,3 @@ def timed_cache(**timedelta_kwargs: Any) -> Any:
return _wrapped
return _wrapper
class SpeechToTextService:
def __init__(self, filename: str) -> None:
self.executor = ThreadPoolExecutor()
self.filename = filename
self.recognizer = Recognizer()
self.recognizer.energy_threshold = 50
self.text_parts: dict[int, str] = {}
self.text_recognised = False
def get_text_from_audio(self) -> None:
self.executor.submit(self.worker)
def worker(self) -> Any:
self._convert_file_to_wav()
self._convert_audio_to_text()
def _convert_audio_to_text(self) -> None:
wav_filename = f"{self.filename}.wav"
speech = AudioSegment.from_wav(wav_filename)
speech_duration = len(speech)
pieces = speech_duration // AUDIO_SEGMENT_DURATION + 1
ending = speech_duration % AUDIO_SEGMENT_DURATION
for i in range(pieces):
if i == 0 and pieces == 1:
sound_segment = speech[0:ending]
elif i == 0:
sound_segment = speech[0 : (i + 1) * AUDIO_SEGMENT_DURATION]
elif i == (pieces - 1):
sound_segment = speech[i * AUDIO_SEGMENT_DURATION - 250 : i * AUDIO_SEGMENT_DURATION + ending]
else:
sound_segment = speech[i * AUDIO_SEGMENT_DURATION - 250 : (i + 1) * AUDIO_SEGMENT_DURATION]
self.text_parts[i] = self._recognize_by_google(wav_filename, sound_segment)
self.text_recognised = True
# clean temp voice message main files
try:
os.remove(wav_filename)
os.remove(self.filename)
except FileNotFoundError as error:
logger.error("error temps files not deleted", error=error, filenames=[self.filename, self.filename])
def _convert_file_to_wav(self) -> None:
new_filename = self.filename + ".wav"
cmd = ["ffmpeg", "-loglevel", "quiet", "-i", self.filename, "-vn", new_filename]
try:
subprocess.run(args=cmd) # noqa: S603
logger.info("file has been converted to wav", filename=new_filename)
except Exception as error:
logger.error("cant convert voice", error=error, filename=self.filename)
def _recognize_by_google(self, filename: str, sound_segment: AudioSegment) -> str:
tmp_filename = f"{filename}_tmp_part"
sound_segment.export(tmp_filename, format="wav")
with AudioFile(tmp_filename) as source:
audio_text = self.recognizer.listen(source)
try:
text = self.recognizer.recognize_google(audio_text, language="ru-RU")
os.remove(tmp_filename)
return text
except SpeechRecognizerError as error:
os.remove(tmp_filename)
logger.error("error recognizing text with google", error=error)
raise error
class ChatGptService:
def __init__(self, chat_gpt_model: str) -> None:
self.chat_gpt_model = chat_gpt_model
async def request_to_chatgpt(self, question: str | None) -> str:
question = question or "Привет!"
chat_gpt_request = self.build_request_data(question)
try:
response = await self.do_request(chat_gpt_request)
status = response.status_code
for message in INVALID_GPT_REQUEST_MESSAGES:
if message in response.text:
message = f"{message}: {settings.GPT_MODEL}"
logger.info(message, data=chat_gpt_request)
return message
if status != httpx.codes.OK:
logger.info(f"got response status: {status} from chat api", data=chat_gpt_request)
return "Что-то пошло не так, попробуйте еще раз или обратитесь к администратору"
return response.text
except Exception as error:
logger.error("error get data from chat api", error=error)
return "Вообще всё сломалось :("
@staticmethod
async def do_request(data: dict[str, Any]) -> Response:
transport = AsyncHTTPTransport(retries=3)
async with AsyncClient(base_url=settings.GPT_BASE_HOST, transport=transport, timeout=50) as client:
return await client.post(CHAT_GPT_BASE_URI, json=data, timeout=50)
def build_request_data(self, question: str) -> dict[str, Any]:
return {
"conversation_id": str(uuid4()),
"action": "_ask",
"model": self.chat_gpt_model,
"jailbreak": "default",
"meta": {
"id": random.randint(10**18, 10**19 - 1), # noqa: S311
"content": {
"conversation": [],
"internet_access": False,
"content_type": "text",
"parts": [{"content": question, "role": "user"}],
},
},
}
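
The context lines of the last hunk (the timed_cache signature and the trailing return _wrapped / return _wrapper) show that a TTL-cache decorator stays behind in the old utils module while the service classes move out. Its body is not part of this diff; a common shape for such a decorator, given the lru_cache, wraps, datetime, and timedelta imports that remain, would be the following assumed sketch (not the repo's actual code):

# Assumed sketch of a TTL cache decorator; the real timed_cache body is not in this diff.
from datetime import datetime, timedelta
from functools import lru_cache, wraps
from typing import Any


def timed_cache(**timedelta_kwargs: Any) -> Any:
    def _wrapper(func: Any) -> Any:
        update_delta = timedelta(**timedelta_kwargs)
        next_update = datetime.utcnow() + update_delta
        cached_func = lru_cache(maxsize=None)(func)

        @wraps(func)
        def _wrapped(*args: Any, **kwargs: Any) -> Any:
            nonlocal next_update
            now = datetime.utcnow()
            if now >= next_update:
                cached_func.cache_clear()  # drop stale entries once the TTL expires
                next_update = now + update_delta
            return cached_func(*args, **kwargs)

        return _wrapped

    return _wrapper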