mirror of
https://github.com/Balshgit/gpt_chat_bot.git
synced 2025-12-16 21:20:39 +03:00
add database and migration logic (#27)
* update chat_microservice * reformat logger_conf * add database * add service and repository logic * fix constants gpt base url * add models endpoints
This commit is contained in:
@@ -6,7 +6,7 @@ from functools import cached_property
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request, Response
|
||||
from fastapi import Response
|
||||
from loguru import logger
|
||||
from telegram import Bot, Update
|
||||
from telegram.ext import Application
|
||||
@@ -68,14 +68,11 @@ class BotQueue:
|
||||
bot_app: BotApplication
|
||||
queue: Queue = asyncio.Queue() # type: ignore[type-arg]
|
||||
|
||||
async def put_updates_on_queue(self, request: Request) -> Response:
|
||||
async def put_updates_on_queue(self, tg_update: Update) -> Response:
|
||||
"""
|
||||
Listen /{URL_PREFIX}/{API_PREFIX}/{TELEGRAM_WEB_TOKEN} path and proxy post request to bot
|
||||
"""
|
||||
data = await request.json()
|
||||
tg_update = Update.de_json(data=data, bot=self.bot_app.application.bot)
|
||||
self.queue.put_nowait(tg_update)
|
||||
|
||||
return Response(status_code=HTTPStatus.ACCEPTED)
|
||||
|
||||
async def get_updates_from_queue(self) -> None:
|
||||
|
||||
@@ -67,7 +67,7 @@ async def ask_question(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No
|
||||
|
||||
await update.message.reply_text("Пожалуйста подождите, ответ в среднем занимает 10-15 секунд")
|
||||
|
||||
chat_gpt_service = ChatGptService(chat_gpt_model=settings.GPT_MODEL)
|
||||
chat_gpt_service = ChatGptService.build()
|
||||
logger.warning("question asked", user=update.message.from_user, question=update.message.text)
|
||||
answer = await chat_gpt_service.request_to_chatgpt(question=update.message.text)
|
||||
await update.message.reply_text(answer)
|
||||
@@ -87,9 +87,9 @@ async def voice_recognize(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
|
||||
|
||||
logger.info("file has been saved", filename=tmpfile.name)
|
||||
|
||||
speech_to_text_service = SpeechToTextService(filename=tmpfile.name)
|
||||
speech_to_text_service = SpeechToTextService()
|
||||
|
||||
speech_to_text_service.get_text_from_audio()
|
||||
speech_to_text_service.get_text_from_audio(filename=tmpfile.name)
|
||||
|
||||
part = 0
|
||||
while speech_to_text_service.text_parts or not speech_to_text_service.text_recognised:
|
||||
|
||||
0
bot_microservice/core/bot/models/__init__.py
Normal file
0
bot_microservice/core/bot/models/__init__.py
Normal file
12
bot_microservice/core/bot/models/chat_gpt.py
Normal file
12
bot_microservice/core/bot/models/chat_gpt.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from sqlalchemy import INTEGER, SMALLINT, VARCHAR
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from infra.database.base import Base
|
||||
|
||||
# Public API of this module. A module-level ``__slots__`` has no effect in
# Python; ``__all__`` is the name that controls ``from module import *``.
__all__ = ("ChatGpt",)
|
||||
|
||||
|
||||
class ChatGpt(Base):
    """ORM model describing one ChatGPT model name together with its priority."""

    # surrogate integer primary key, auto-incremented
    id: Mapped[int] = mapped_column(
        "id",
        INTEGER(),
        autoincrement=True,
        primary_key=True,
    )
    # model name; required and unique, at most 256 characters
    model: Mapped[str] = mapped_column(
        "model",
        VARCHAR(length=256),
        unique=True,
        nullable=False,
    )
    # ordering weight, defaults to 0
    priority: Mapped[int] = mapped_column(
        "priority",
        SMALLINT(),
        default=0,
    )
|
||||
106
bot_microservice/core/bot/repository.py
Normal file
106
bot_microservice/core/bot/repository.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Sequence
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from httpx import AsyncClient, AsyncHTTPTransport, Response
|
||||
from loguru import logger
|
||||
from sqlalchemy import delete, desc, select, update
|
||||
from sqlalchemy.dialects.sqlite import insert
|
||||
|
||||
from constants import CHAT_GPT_BASE_URI, INVALID_GPT_REQUEST_MESSAGES
|
||||
from core.bot.models.chat_gpt import ChatGpt
|
||||
from infra.database.db_adapter import Database
|
||||
from settings.config import AppSettings
|
||||
|
||||
|
||||
@dataclass
class ChatGPTRepository:
    """
    Data-access layer for stored ChatGPT models and for the chat API microservice.

    Database operations go through ``db``; HTTP requests are sent to
    ``settings.GPT_BASE_HOST``.
    """

    # application settings; ``GPT_BASE_HOST`` is read when the HTTP client is built
    settings: AppSettings
    # database adapter exposing ``session()`` and ``get_transaction_session()``
    db: Database

    async def get_chatgpt_models(self) -> Sequence[ChatGpt]:
        """Return all stored models ordered by descending priority."""
        query = select(ChatGpt).order_by(desc(ChatGpt.priority))

        async with self.db.session() as session:
            result = await session.execute(query)
            return result.scalars().all()

    async def change_chatgpt_model_priority(self, model_id: int, priority: int) -> None:
        """
        Reset the priority of the current top model to 0 and assign ``priority`` to ``model_id``.

        Both updates are executed inside one transaction session.

        :param model_id: primary key (``ChatGpt.id``) of the model to re-prioritise.
        :param priority: new priority value for that model.
        :raises: whatever ``get_current_chatgpt_model`` raises when no model is stored.
        """
        current_model = await self.get_current_chatgpt_model()

        reset_priority_query = update(ChatGpt).values(priority=0).filter(ChatGpt.model == current_model)
        # model_id is the integer primary key, so it has to be matched against
        # ChatGpt.id; comparing it with the model-name column never matches a row.
        set_new_priority_query = update(ChatGpt).values(priority=priority).filter(ChatGpt.id == model_id)

        async with self.db.get_transaction_session() as session:
            await session.execute(reset_priority_query)
            await session.execute(set_new_priority_query)

    async def add_chatgpt_model(self, model: str, priority: int) -> dict[str, str | int]:
        """
        Insert a model row, ignoring the insert when it conflicts with an existing row.

        :param model: model name to store.
        :param priority: priority to store with the model.
        :return: the ``model`` and ``priority`` values that were passed in.
        """
        query = (
            insert(ChatGpt)
            .values(
                {ChatGpt.model: model, ChatGpt.priority: priority},
            )
            .prefix_with("OR IGNORE")
        )
        async with self.db.session() as session:
            await session.execute(query)
            await session.commit()
            return {"model": model, "priority": priority}

    async def delete_chatgpt_model(self, model_id: int) -> None:
        """
        Delete the model row whose primary key equals ``model_id``.

        :param model_id: primary key of the model to delete.
        """
        query = delete(ChatGpt).filter_by(id=model_id)

        async with self.db.session() as session:
            await session.execute(query)
            # commit explicitly, the same way add_chatgpt_model does on this session type
            await session.commit()

    async def get_current_chatgpt_model(self) -> str:
        """
        Return the name of the model with the highest priority.

        ``scalar_one`` is used, so an empty table results in an exception
        rather than ``None``.
        """
        query = select(ChatGpt.model).order_by(desc(ChatGpt.priority)).limit(1)

        async with self.db.session() as session:
            result = await session.execute(query)
            return result.scalar_one()

    async def ask_question(self, question: str, chat_gpt_model: str) -> str:
        """
        Send ``question`` to the chat API and return the text to show to the user.

        Never raises: any exception is logged and replaced by a fallback message.

        :param question: text of the user question.
        :param chat_gpt_model: model name to put into the request.
        :return: the response body on success; otherwise a known invalid-request
            message suffixed with the model name, or a generic error message.
        """
        try:
            response = await self.request_to_chatgpt_microservice(question=question, chat_gpt_model=chat_gpt_model)
            status = response.status_code
            for invalid_message in INVALID_GPT_REQUEST_MESSAGES:
                if invalid_message in response.text:
                    reply = f"{invalid_message}: {chat_gpt_model}"
                    logger.info(reply, question=question, chat_gpt_model=chat_gpt_model)
                    return reply
            if status != httpx.codes.OK:
                # pass the body as a keyword: a positional argument is only used to
                # fill "{}" placeholders, and this message has none, so it was dropped
                logger.info(f"got response status: {status} from chat api", response_text=response.text)
                return "Что-то пошло не так, попробуйте еще раз или обратитесь к администратору"
            return response.text
        except Exception as error:
            # boundary handler: the bot must always answer the user with some text
            logger.error("error get data from chat api", error=error)
            return "Вообще всё сломалось :("

    async def request_to_chatgpt_microservice(self, question: str, chat_gpt_model: str) -> Response:
        """
        POST the question to ``CHAT_GPT_BASE_URI`` on ``settings.GPT_BASE_HOST``.

        The transport is configured with 3 retries and a timeout of 50 is applied.

        :param question: text of the user question.
        :param chat_gpt_model: model name to put into the request.
        :return: the raw HTTP response.
        """
        data = self._build_request_data(question=question, chat_gpt_model=chat_gpt_model)

        transport = AsyncHTTPTransport(retries=3)
        async with AsyncClient(base_url=self.settings.GPT_BASE_HOST, transport=transport, timeout=50) as client:
            return await client.post(CHAT_GPT_BASE_URI, json=data, timeout=50)

    @staticmethod
    def _build_request_data(*, question: str, chat_gpt_model: str) -> dict[str, Any]:
        """
        Build the JSON payload for a single-question request.

        A fresh ``conversation_id`` (uuid4) and a random 19-digit ``meta.id``
        are generated on every call.
        """
        return {
            "conversation_id": str(uuid4()),
            "action": "_ask",
            "model": chat_gpt_model,
            "jailbreak": "default",
            "meta": {
                "id": random.randint(10**18, 10**19 - 1),  # noqa: S311
                "content": {
                    "conversation": [],
                    "internet_access": False,
                    "content_type": "text",
                    "parts": [{"content": question, "role": "user"}],
                },
            },
        }
|
||||
@@ -1,12 +1,10 @@
|
||||
import os
|
||||
import random
|
||||
import subprocess # noqa
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Sequence
|
||||
|
||||
import httpx
|
||||
from httpx import AsyncClient, AsyncHTTPTransport, Response
|
||||
from httpx import Response
|
||||
from loguru import logger
|
||||
from pydub import AudioSegment
|
||||
from speech_recognition import (
|
||||
@@ -15,33 +13,30 @@ from speech_recognition import (
|
||||
UnknownValueError as SpeechRecognizerError,
|
||||
)
|
||||
|
||||
from constants import (
|
||||
AUDIO_SEGMENT_DURATION,
|
||||
CHAT_GPT_BASE_URI,
|
||||
INVALID_GPT_REQUEST_MESSAGES,
|
||||
)
|
||||
from constants import AUDIO_SEGMENT_DURATION
|
||||
from core.bot.models.chat_gpt import ChatGpt
|
||||
from core.bot.repository import ChatGPTRepository
|
||||
from infra.database.db_adapter import Database
|
||||
from settings.config import settings
|
||||
|
||||
|
||||
class SpeechToTextService:
|
||||
def __init__(self, filename: str) -> None:
|
||||
def __init__(self) -> None:
|
||||
self.executor = ThreadPoolExecutor()
|
||||
|
||||
self.filename = filename
|
||||
self.recognizer = Recognizer()
|
||||
self.recognizer.energy_threshold = 50
|
||||
self.text_parts: dict[int, str] = {}
|
||||
self.text_recognised = False
|
||||
|
||||
def get_text_from_audio(self) -> None:
|
||||
self.executor.submit(self.worker)
|
||||
def get_text_from_audio(self, filename: str) -> None:
|
||||
self.executor.submit(self.worker, filename=filename)
|
||||
|
||||
def worker(self) -> Any:
|
||||
self._convert_file_to_wav()
|
||||
self._convert_audio_to_text()
|
||||
def worker(self, filename: str) -> Any:
|
||||
self._convert_file_to_wav(filename)
|
||||
self._convert_audio_to_text(filename)
|
||||
|
||||
def _convert_audio_to_text(self) -> None:
|
||||
wav_filename = f"{self.filename}.wav"
|
||||
def _convert_audio_to_text(self, filename: str) -> None:
|
||||
wav_filename = f"{filename}.wav"
|
||||
|
||||
speech = AudioSegment.from_wav(wav_filename)
|
||||
speech_duration = len(speech)
|
||||
@@ -63,18 +58,19 @@ class SpeechToTextService:
|
||||
# clean temp voice message main files
|
||||
try:
|
||||
os.remove(wav_filename)
|
||||
os.remove(self.filename)
|
||||
os.remove(filename)
|
||||
except FileNotFoundError as error:
|
||||
logger.error("error temps files not deleted", error=error, filenames=[self.filename, self.filename])
|
||||
logger.error("error temps files not deleted", error=error, filenames=[filename, wav_filename])
|
||||
|
||||
def _convert_file_to_wav(self) -> None:
|
||||
new_filename = self.filename + ".wav"
|
||||
cmd = ["ffmpeg", "-loglevel", "quiet", "-i", self.filename, "-vn", new_filename]
|
||||
@staticmethod
|
||||
def _convert_file_to_wav(filename: str) -> None:
|
||||
new_filename = filename + ".wav"
|
||||
cmd = ["ffmpeg", "-loglevel", "quiet", "-i", filename, "-vn", new_filename]
|
||||
try:
|
||||
subprocess.run(args=cmd) # noqa: S603
|
||||
logger.info("file has been converted to wav", filename=new_filename)
|
||||
except Exception as error:
|
||||
logger.error("cant convert voice", error=error, filename=self.filename)
|
||||
logger.error("cant convert voice", error=error, filename=filename)
|
||||
|
||||
def _recognize_by_google(self, filename: str, sound_segment: AudioSegment) -> str:
|
||||
tmp_filename = f"{filename}_tmp_part"
|
||||
@@ -91,48 +87,36 @@ class SpeechToTextService:
|
||||
raise error
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatGptService:
|
||||
def __init__(self, chat_gpt_model: str) -> None:
|
||||
self.chat_gpt_model = chat_gpt_model
|
||||
repository: ChatGPTRepository
|
||||
|
||||
async def get_chatgpt_models(self) -> Sequence[ChatGpt]:
|
||||
return await self.repository.get_chatgpt_models()
|
||||
|
||||
async def request_to_chatgpt(self, question: str | None) -> str:
|
||||
question = question or "Привет!"
|
||||
chat_gpt_request = self.build_request_data(question)
|
||||
try:
|
||||
response = await self.do_request(chat_gpt_request)
|
||||
status = response.status_code
|
||||
for message in INVALID_GPT_REQUEST_MESSAGES:
|
||||
if message in response.text:
|
||||
message = f"{message}: {settings.GPT_MODEL}"
|
||||
logger.info(message, data=chat_gpt_request)
|
||||
return message
|
||||
if status != httpx.codes.OK:
|
||||
logger.info(f"got response status: {status} from chat api", data=chat_gpt_request)
|
||||
return "Что-то пошло не так, попробуйте еще раз или обратитесь к администратору"
|
||||
return response.text
|
||||
except Exception as error:
|
||||
logger.error("error get data from chat api", error=error)
|
||||
return "Вообще всё сломалось :("
|
||||
chat_gpt_model = await self.get_current_chatgpt_model()
|
||||
return await self.repository.ask_question(question=question, chat_gpt_model=chat_gpt_model)
|
||||
|
||||
@staticmethod
|
||||
async def do_request(data: dict[str, Any]) -> Response:
|
||||
transport = AsyncHTTPTransport(retries=3)
|
||||
async with AsyncClient(base_url=settings.GPT_BASE_HOST, transport=transport, timeout=50) as client:
|
||||
return await client.post(CHAT_GPT_BASE_URI, json=data, timeout=50)
|
||||
async def request_to_chatgpt_microservice(self, question: str) -> Response:
|
||||
chat_gpt_model = await self.get_current_chatgpt_model()
|
||||
return await self.repository.request_to_chatgpt_microservice(question=question, chat_gpt_model=chat_gpt_model)
|
||||
|
||||
def build_request_data(self, question: str) -> dict[str, Any]:
|
||||
return {
|
||||
"conversation_id": str(uuid4()),
|
||||
"action": "_ask",
|
||||
"model": self.chat_gpt_model,
|
||||
"jailbreak": "default",
|
||||
"meta": {
|
||||
"id": random.randint(10**18, 10**19 - 1), # noqa: S311
|
||||
"content": {
|
||||
"conversation": [],
|
||||
"internet_access": False,
|
||||
"content_type": "text",
|
||||
"parts": [{"content": question, "role": "user"}],
|
||||
},
|
||||
},
|
||||
}
|
||||
async def get_current_chatgpt_model(self) -> str:
|
||||
return await self.repository.get_current_chatgpt_model()
|
||||
|
||||
async def change_chatgpt_model_priority(self, model_id: int, priority: int) -> None:
|
||||
return await self.repository.change_chatgpt_model_priority(model_id=model_id, priority=priority)
|
||||
|
||||
async def add_chatgpt_model(self, gpt_model: str, priority: int) -> dict[str, str | int]:
|
||||
return await self.repository.add_chatgpt_model(model=gpt_model, priority=priority)
|
||||
|
||||
async def delete_chatgpt_model(self, model_id: int) -> None:
|
||||
return await self.repository.delete_chatgpt_model(model_id=model_id)
|
||||
|
||||
@classmethod
|
||||
def build(cls) -> "ChatGptService":
|
||||
db = Database(settings=settings)
|
||||
repository = ChatGPTRepository(settings=settings, db=db)
|
||||
return ChatGptService(repository=repository)
|
||||
|
||||
73
bot_microservice/core/lifetime.py
Normal file
73
bot_microservice/core/lifetime.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from asyncio import current_task
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from fastapi import FastAPI
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_scoped_session,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from settings.config import AppSettings
|
||||
|
||||
|
||||
def startup(app: FastAPI, settings: AppSettings) -> Callable[[], Awaitable[None]]:
    """
    Build the coroutine function to run when the application starts.

    The returned callable sets up the database through ``_setup_db``, which
    keeps its objects (such as db_engine) on the fastAPI app.

    :param app: the fastAPI application.
    :param settings: app settings
    :return: function that actually performs actions.
    """

    async def _startup() -> None:
        _setup_db(app=app, settings=settings)

    return _startup
|
||||
|
||||
|
||||
def shutdown(app: FastAPI) -> Callable[[], Awaitable[None]]:
    """
    Build the coroutine function to run when the application shuts down.

    The returned callable disposes the engine stored in ``app.state.db_engine``.

    :param app: fastAPI application.
    :return: function that actually performs actions.
    """

    async def _shutdown() -> None:
        db_engine = app.state.db_engine
        await db_engine.dispose()

    return _shutdown
|
||||
|
||||
|
||||
def _setup_db(app: FastAPI, settings: AppSettings) -> None:
    """
    Create the database engine and session factory and attach them to the app.

    The async engine is built from ``settings.db_url`` with ``AUTOCOMMIT``
    isolation level; the session factory is scoped by the current asyncio task.
    Both are stored on ``app.state`` as ``db_engine`` and ``db_session_factory``.

    :param app: fastAPI application.
    :param settings: app settings providing ``db_url`` and ``DB_ECHO``.
    """
    db_engine = create_async_engine(
        str(settings.db_url),
        echo=settings.DB_ECHO,
        execution_options={"isolation_level": "AUTOCOMMIT"},
    )
    session_maker = async_sessionmaker(
        db_engine,
        expire_on_commit=False,
        class_=AsyncSession,
    )
    scoped_factory = async_scoped_session(session_maker, scopefunc=current_task)

    app.state.db_engine = db_engine
    app.state.db_session_factory = scoped_factory
|
||||
@@ -1,5 +1,6 @@
|
||||
from datetime import datetime, timedelta
|
||||
from functools import lru_cache, wraps
|
||||
from inspect import cleandoc
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -22,3 +23,9 @@ def timed_cache(**timedelta_kwargs: Any) -> Any:
|
||||
return _wrapped
|
||||
|
||||
return _wrapper
|
||||
|
||||
|
||||
def clean_doc(cls: Any) -> str | None:
|
||||
if cls.__doc__ is None:
|
||||
return None
|
||||
return cleandoc(cls.__doc__)
|
||||
|
||||
Reference in New Issue
Block a user