mirror of
https://github.com/Balshgit/gpt_chat_bot.git
synced 2026-02-03 11:40:39 +03:00
add database and migration logic (#27)
* update chat_microservice * reformat logger_conf * add database * add service and repository logic * fix constants gpt base url * add models endpoints
This commit is contained in:
52
bot_microservice/alembic.ini
Normal file
52
bot_microservice/alembic.ini
Normal file
@@ -0,0 +1,52 @@
|
||||
[alembic]
|
||||
script_location = infra/database/migrations
|
||||
file_template = %%(year)d-%%(month).2d-%%(day).2d-%%(hour).2d-%%(minute).2d_%%(rev)s
|
||||
prepend_sys_path = .
|
||||
output_encoding = utf-8
|
||||
|
||||
[post_write_hooks]
|
||||
hooks = black,autoflake,isort
|
||||
|
||||
black.type = console_scripts
|
||||
black.entrypoint = black
|
||||
|
||||
autoflake.type = console_scripts
|
||||
autoflake.entrypoint = autoflake
|
||||
|
||||
isort.type = console_scripts
|
||||
isort.entrypoint = isort
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
@@ -1,7 +1,17 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Body, Depends, Path
|
||||
from starlette import status
|
||||
from starlette.responses import Response
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from telegram import Update
|
||||
|
||||
from api.bot.serializers import (
|
||||
ChatGptModelSerializer,
|
||||
ChatGptModelsPrioritySerializer,
|
||||
GETChatGptModelsSerializer,
|
||||
LightChatGptModel,
|
||||
)
|
||||
from api.deps import get_bot_queue, get_chatgpt_service, get_update_from_request
|
||||
from core.bot.app import BotQueue
|
||||
from core.bot.services import ChatGptService
|
||||
from settings.config import settings
|
||||
|
||||
router = APIRouter()
|
||||
@@ -15,5 +25,74 @@ router = APIRouter()
|
||||
summary="process bot updates",
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def process_bot_updates(request: Request) -> None:
|
||||
await request.app.state.queue.put_updates_on_queue(request)
|
||||
async def process_bot_updates(
|
||||
tg_update: Update = Depends(get_update_from_request),
|
||||
queue: BotQueue = Depends(get_bot_queue),
|
||||
) -> None:
|
||||
await queue.put_updates_on_queue(tg_update)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/models",
|
||||
name="bot:models_list",
|
||||
response_class=JSONResponse,
|
||||
response_model=list[ChatGptModelSerializer],
|
||||
status_code=status.HTTP_200_OK,
|
||||
summary="list of models",
|
||||
)
|
||||
async def models_list(
|
||||
chatgpt_service: ChatGptService = Depends(get_chatgpt_service),
|
||||
) -> JSONResponse:
|
||||
"""Получить список всех моделей"""
|
||||
models = await chatgpt_service.get_chatgpt_models()
|
||||
return JSONResponse(
|
||||
content=GETChatGptModelsSerializer(data=models).model_dump(), status_code=status.HTTP_200_OK # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/models/{model_id}/priority",
|
||||
name="bot:change_model_priority",
|
||||
response_class=Response,
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="change gpt model priority",
|
||||
)
|
||||
async def change_model_priority(
|
||||
model_id: int = Path(..., gt=0, description="Id модели для обновления приореитета"),
|
||||
chatgpt_service: ChatGptService = Depends(get_chatgpt_service),
|
||||
gpt_model: ChatGptModelsPrioritySerializer = Body(...),
|
||||
) -> None:
|
||||
"""Изменить приоритет модели в выдаче"""
|
||||
await chatgpt_service.change_chatgpt_model_priority(model_id=model_id, priority=gpt_model.priority)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/models",
|
||||
name="bot:add_new_model",
|
||||
response_model=ChatGptModelSerializer,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="add new model",
|
||||
)
|
||||
async def add_new_model(
|
||||
chatgpt_service: ChatGptService = Depends(get_chatgpt_service),
|
||||
gpt_model: LightChatGptModel = Body(...),
|
||||
) -> JSONResponse:
|
||||
"""Добавить новую модель"""
|
||||
model = await chatgpt_service.add_chatgpt_model(gpt_model=gpt_model.model, priority=gpt_model.priority)
|
||||
|
||||
return JSONResponse(content=model, status_code=status.HTTP_201_CREATED)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/models/{model_id}",
|
||||
name="bot:delete_gpt_model",
|
||||
response_class=Response,
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="delete gpt model",
|
||||
)
|
||||
async def delete_model(
|
||||
model_id: int = Path(..., gt=0, description="Id модели для удаления"),
|
||||
chatgpt_service: ChatGptService = Depends(get_chatgpt_service),
|
||||
) -> None:
|
||||
"""Удалить gpt модель"""
|
||||
await chatgpt_service.delete_chatgpt_model(model_id=model_id)
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from fastapi import Depends
|
||||
from starlette.requests import Request
|
||||
|
||||
from core.bot.services import ChatGptService
|
||||
from settings.config import AppSettings
|
||||
|
||||
|
||||
def get_settings(request: Request) -> AppSettings:
|
||||
return request.app.state.settings
|
||||
|
||||
|
||||
def get_chat_gpt_service(settings: AppSettings = Depends(get_settings)) -> ChatGptService:
|
||||
return ChatGptService(settings.GPT_MODEL)
|
||||
24
bot_microservice/api/bot/serializers.py
Normal file
24
bot_microservice/api/bot/serializers.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class LightChatGptModel(BaseModel):
|
||||
model: str = Field(..., title="Chat Gpt model")
|
||||
priority: int = Field(default=0, ge=0, title="Приоритет модели")
|
||||
|
||||
|
||||
class ChatGptModelsPrioritySerializer(BaseModel):
|
||||
priority: int = Field(default=0, ge=0, title="Приоритет модели")
|
||||
|
||||
|
||||
class ChatGptModelSerializer(BaseModel):
|
||||
id: int = Field(..., gt=0, title="Id модели")
|
||||
model: str = Field(..., title="Chat Gpt model")
|
||||
priority: int = Field(..., ge=0, title="Приоритет модели")
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class GETChatGptModelsSerializer(BaseModel):
|
||||
data: list[ChatGptModelSerializer] = Field(..., title="Список всех моделей")
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
50
bot_microservice/api/deps.py
Normal file
50
bot_microservice/api/deps.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from fastapi import Depends
|
||||
from starlette.requests import Request
|
||||
from telegram import Update
|
||||
|
||||
from core.bot.app import BotApplication, BotQueue
|
||||
from core.bot.repository import ChatGPTRepository
|
||||
from core.bot.services import ChatGptService, SpeechToTextService
|
||||
from infra.database.db_adapter import Database
|
||||
from settings.config import AppSettings
|
||||
|
||||
|
||||
def get_settings(request: Request) -> AppSettings:
|
||||
return request.app.state.settings
|
||||
|
||||
|
||||
def get_bot_app(request: Request) -> BotApplication:
|
||||
return request.app.state.bot_app
|
||||
|
||||
|
||||
def get_bot_queue(request: Request) -> BotQueue:
|
||||
return request.app.state.queue
|
||||
|
||||
|
||||
async def get_update_from_request(request: Request, bot_app: BotApplication = Depends(get_bot_app)) -> Update | None:
|
||||
data = await request.json()
|
||||
return Update.de_json(data, bot_app.bot)
|
||||
|
||||
|
||||
def get_database(settings: AppSettings = Depends(get_settings)) -> Database:
|
||||
return Database(settings=settings)
|
||||
|
||||
|
||||
def get_chat_gpt_repository(
|
||||
db: Database = Depends(get_database), settings: AppSettings = Depends(get_settings)
|
||||
) -> ChatGPTRepository:
|
||||
return ChatGPTRepository(settings=settings, db=db)
|
||||
|
||||
|
||||
def get_speech_to_text_service() -> SpeechToTextService:
|
||||
return SpeechToTextService()
|
||||
|
||||
|
||||
def new_bot_queue(bot_app: BotApplication = Depends(get_bot_app)) -> BotQueue:
|
||||
return BotQueue(bot_app=bot_app)
|
||||
|
||||
|
||||
def get_chatgpt_service(
|
||||
chat_gpt_repository: ChatGPTRepository = Depends(get_chat_gpt_repository),
|
||||
) -> ChatGptService:
|
||||
return ChatGptService(repository=chat_gpt_repository)
|
||||
@@ -3,7 +3,7 @@ from fastapi.responses import ORJSONResponse
|
||||
from starlette import status
|
||||
from starlette.responses import Response
|
||||
|
||||
from api.bot.deps import get_chat_gpt_service
|
||||
from api.deps import get_chatgpt_service
|
||||
from api.exceptions import BaseAPIException
|
||||
from constants import INVALID_GPT_REQUEST_MESSAGES
|
||||
from core.bot.services import ChatGptService
|
||||
@@ -33,12 +33,11 @@ async def healthcheck() -> ORJSONResponse:
|
||||
)
|
||||
async def gpt_healthcheck(
|
||||
response: Response,
|
||||
chatgpt_service: ChatGptService = Depends(get_chat_gpt_service),
|
||||
chatgpt_service: ChatGptService = Depends(get_chatgpt_service),
|
||||
) -> Response:
|
||||
data = chatgpt_service.build_request_data("Привет!")
|
||||
response.status_code = status.HTTP_200_OK
|
||||
try:
|
||||
chatgpt_response = await chatgpt_service.do_request(data)
|
||||
chatgpt_response = await chatgpt_service.request_to_chatgpt_microservice(question="Привет!")
|
||||
if chatgpt_response.status_code != status.HTTP_200_OK:
|
||||
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
for message in INVALID_GPT_REQUEST_MESSAGES:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from enum import StrEnum
|
||||
from enum import StrEnum, unique
|
||||
|
||||
AUDIO_SEGMENT_DURATION = 120 * 1000
|
||||
|
||||
@@ -26,3 +26,35 @@ class LogLevelEnum(StrEnum):
|
||||
INFO = "info"
|
||||
DEBUG = "debug"
|
||||
NOTSET = ""
|
||||
|
||||
|
||||
@unique
|
||||
class ChatGptModelsEnum(StrEnum):
|
||||
gpt_3_5_turbo_stream_openai = "gpt-3.5-turbo-stream-openai"
|
||||
gpt_3_5_turbo_Aichat = "gpt-3.5-turbo-Aichat"
|
||||
gpt_4_ChatgptAi = "gpt-4-ChatgptAi"
|
||||
gpt_3_5_turbo_weWordle = "gpt-3.5-turbo-weWordle"
|
||||
gpt_3_5_turbo_acytoo = "gpt-3.5-turbo-acytoo"
|
||||
gpt_3_5_turbo_stream_DeepAi = "gpt-3.5-turbo-stream-DeepAi"
|
||||
gpt_3_5_turbo_stream_H2o = "gpt-3.5-turbo-stream-H2o"
|
||||
gpt_3_5_turbo_stream_yqcloud = "gpt-3.5-turbo-stream-yqcloud"
|
||||
gpt_OpenAssistant_stream_HuggingChat = "gpt-OpenAssistant-stream-HuggingChat"
|
||||
gpt_4_turbo_stream_you = "gpt-4-turbo-stream-you"
|
||||
gpt_3_5_turbo_AItianhu = "gpt-3.5-turbo-AItianhu"
|
||||
gpt_3_stream_binjie = "gpt-3-stream-binjie"
|
||||
gpt_3_5_turbo_stream_CodeLinkAva = "gpt-3.5-turbo-stream-CodeLinkAva"
|
||||
gpt_4_stream_ChatBase = "gpt-4-stream-ChatBase"
|
||||
gpt_3_5_turbo_stream_aivvm = "gpt-3.5-turbo-stream-aivvm"
|
||||
gpt_3_5_turbo_16k_stream_Ylokh = "gpt-3.5-turbo-16k-stream-Ylokh"
|
||||
gpt_3_5_turbo_stream_Vitalentum = "gpt-3.5-turbo-stream-Vitalentum"
|
||||
gpt_3_5_turbo_stream_GptGo = "gpt-3.5-turbo-stream-GptGo"
|
||||
gpt_3_5_turbo_stream_AItianhuSpace = "gpt-3.5-turbo-stream-AItianhuSpace"
|
||||
gpt_3_5_turbo_stream_Aibn = "gpt-3.5-turbo-stream-Aibn"
|
||||
gpt_3_5_turbo_ChatgptDuo = "gpt-3.5-turbo-ChatgptDuo"
|
||||
gpt_3_5_turbo_stream_FreeGpt = "gpt-3.5-turbo-stream-FreeGpt"
|
||||
gpt_3_5_turbo_stream_ChatForAi = "gpt-3.5-turbo-stream-ChatForAi"
|
||||
gpt_3_5_turbo_stream_Cromicle = "gpt-3.5-turbo-stream-Cromicle"
|
||||
|
||||
@classmethod
|
||||
def values(cls) -> set[str]:
|
||||
return set(map(str, set(ChatGptModelsEnum)))
|
||||
|
||||
@@ -6,7 +6,7 @@ from functools import cached_property
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request, Response
|
||||
from fastapi import Response
|
||||
from loguru import logger
|
||||
from telegram import Bot, Update
|
||||
from telegram.ext import Application
|
||||
@@ -68,14 +68,11 @@ class BotQueue:
|
||||
bot_app: BotApplication
|
||||
queue: Queue = asyncio.Queue() # type: ignore[type-arg]
|
||||
|
||||
async def put_updates_on_queue(self, request: Request) -> Response:
|
||||
async def put_updates_on_queue(self, tg_update: Update) -> Response:
|
||||
"""
|
||||
Listen /{URL_PREFIX}/{API_PREFIX}/{TELEGRAM_WEB_TOKEN} path and proxy post request to bot
|
||||
"""
|
||||
data = await request.json()
|
||||
tg_update = Update.de_json(data=data, bot=self.bot_app.application.bot)
|
||||
self.queue.put_nowait(tg_update)
|
||||
|
||||
return Response(status_code=HTTPStatus.ACCEPTED)
|
||||
|
||||
async def get_updates_from_queue(self) -> None:
|
||||
|
||||
@@ -67,7 +67,7 @@ async def ask_question(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No
|
||||
|
||||
await update.message.reply_text("Пожалуйста подождите, ответ в среднем занимает 10-15 секунд")
|
||||
|
||||
chat_gpt_service = ChatGptService(chat_gpt_model=settings.GPT_MODEL)
|
||||
chat_gpt_service = ChatGptService.build()
|
||||
logger.warning("question asked", user=update.message.from_user, question=update.message.text)
|
||||
answer = await chat_gpt_service.request_to_chatgpt(question=update.message.text)
|
||||
await update.message.reply_text(answer)
|
||||
@@ -87,9 +87,9 @@ async def voice_recognize(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
|
||||
|
||||
logger.info("file has been saved", filename=tmpfile.name)
|
||||
|
||||
speech_to_text_service = SpeechToTextService(filename=tmpfile.name)
|
||||
speech_to_text_service = SpeechToTextService()
|
||||
|
||||
speech_to_text_service.get_text_from_audio()
|
||||
speech_to_text_service.get_text_from_audio(filename=tmpfile.name)
|
||||
|
||||
part = 0
|
||||
while speech_to_text_service.text_parts or not speech_to_text_service.text_recognised:
|
||||
|
||||
0
bot_microservice/core/bot/models/__init__.py
Normal file
0
bot_microservice/core/bot/models/__init__.py
Normal file
12
bot_microservice/core/bot/models/chat_gpt.py
Normal file
12
bot_microservice/core/bot/models/chat_gpt.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from sqlalchemy import INTEGER, SMALLINT, VARCHAR
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from infra.database.base import Base
|
||||
|
||||
__slots__ = ("ChatGpt",)
|
||||
|
||||
|
||||
class ChatGpt(Base):
|
||||
id: Mapped[int] = mapped_column("id", INTEGER(), primary_key=True, autoincrement=True)
|
||||
model: Mapped[str] = mapped_column("model", VARCHAR(length=256), nullable=False, unique=True)
|
||||
priority: Mapped[int] = mapped_column("priority", SMALLINT(), default=0)
|
||||
106
bot_microservice/core/bot/repository.py
Normal file
106
bot_microservice/core/bot/repository.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Sequence
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from httpx import AsyncClient, AsyncHTTPTransport, Response
|
||||
from loguru import logger
|
||||
from sqlalchemy import delete, desc, select, update
|
||||
from sqlalchemy.dialects.sqlite import insert
|
||||
|
||||
from constants import CHAT_GPT_BASE_URI, INVALID_GPT_REQUEST_MESSAGES
|
||||
from core.bot.models.chat_gpt import ChatGpt
|
||||
from infra.database.db_adapter import Database
|
||||
from settings.config import AppSettings
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatGPTRepository:
|
||||
settings: AppSettings
|
||||
db: Database
|
||||
|
||||
async def get_chatgpt_models(self) -> Sequence[ChatGpt]:
|
||||
query = select(ChatGpt).order_by(desc(ChatGpt.priority))
|
||||
|
||||
async with self.db.session() as session:
|
||||
result = await session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def change_chatgpt_model_priority(self, model_id: int, priority: int) -> None:
|
||||
current_model = await self.get_current_chatgpt_model()
|
||||
|
||||
reset_priority_query = update(ChatGpt).values(priority=0).filter(ChatGpt.model == current_model)
|
||||
set_new_priority_query = update(ChatGpt).values(priority=priority).filter(ChatGpt.model == model_id)
|
||||
|
||||
async with self.db.get_transaction_session() as session:
|
||||
await session.execute(reset_priority_query)
|
||||
await session.execute(set_new_priority_query)
|
||||
|
||||
async def add_chatgpt_model(self, model: str, priority: int) -> dict[str, str | int]:
|
||||
query = (
|
||||
insert(ChatGpt)
|
||||
.values(
|
||||
{ChatGpt.model: model, ChatGpt.priority: priority},
|
||||
)
|
||||
.prefix_with("OR IGNORE")
|
||||
)
|
||||
async with self.db.session() as session:
|
||||
await session.execute(query)
|
||||
await session.commit()
|
||||
return {"model": model, "priority": priority}
|
||||
|
||||
async def delete_chatgpt_model(self, model_id: int) -> None:
|
||||
query = delete(ChatGpt).filter_by(id=model_id)
|
||||
|
||||
async with self.db.session() as session:
|
||||
await session.execute(query)
|
||||
|
||||
async def get_current_chatgpt_model(self) -> str:
|
||||
query = select(ChatGpt.model).order_by(desc(ChatGpt.priority)).limit(1)
|
||||
|
||||
async with self.db.session() as session:
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one()
|
||||
|
||||
async def ask_question(self, question: str, chat_gpt_model: str) -> str:
|
||||
try:
|
||||
response = await self.request_to_chatgpt_microservice(question=question, chat_gpt_model=chat_gpt_model)
|
||||
status = response.status_code
|
||||
for message in INVALID_GPT_REQUEST_MESSAGES:
|
||||
if message in response.text:
|
||||
message = f"{message}: {chat_gpt_model}"
|
||||
logger.info(message, question=question, chat_gpt_model=chat_gpt_model)
|
||||
return message
|
||||
if status != httpx.codes.OK:
|
||||
logger.info(f"got response status: {status} from chat api", response.text)
|
||||
return "Что-то пошло не так, попробуйте еще раз или обратитесь к администратору"
|
||||
return response.text
|
||||
except Exception as error:
|
||||
logger.error("error get data from chat api", error=error)
|
||||
return "Вообще всё сломалось :("
|
||||
|
||||
async def request_to_chatgpt_microservice(self, question: str, chat_gpt_model: str) -> Response:
|
||||
data = self._build_request_data(question=question, chat_gpt_model=chat_gpt_model)
|
||||
|
||||
transport = AsyncHTTPTransport(retries=3)
|
||||
async with AsyncClient(base_url=self.settings.GPT_BASE_HOST, transport=transport, timeout=50) as client:
|
||||
return await client.post(CHAT_GPT_BASE_URI, json=data, timeout=50)
|
||||
|
||||
@staticmethod
|
||||
def _build_request_data(*, question: str, chat_gpt_model: str) -> dict[str, Any]:
|
||||
return {
|
||||
"conversation_id": str(uuid4()),
|
||||
"action": "_ask",
|
||||
"model": chat_gpt_model,
|
||||
"jailbreak": "default",
|
||||
"meta": {
|
||||
"id": random.randint(10**18, 10**19 - 1), # noqa: S311
|
||||
"content": {
|
||||
"conversation": [],
|
||||
"internet_access": False,
|
||||
"content_type": "text",
|
||||
"parts": [{"content": question, "role": "user"}],
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -1,12 +1,10 @@
|
||||
import os
|
||||
import random
|
||||
import subprocess # noqa
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Sequence
|
||||
|
||||
import httpx
|
||||
from httpx import AsyncClient, AsyncHTTPTransport, Response
|
||||
from httpx import Response
|
||||
from loguru import logger
|
||||
from pydub import AudioSegment
|
||||
from speech_recognition import (
|
||||
@@ -15,33 +13,30 @@ from speech_recognition import (
|
||||
UnknownValueError as SpeechRecognizerError,
|
||||
)
|
||||
|
||||
from constants import (
|
||||
AUDIO_SEGMENT_DURATION,
|
||||
CHAT_GPT_BASE_URI,
|
||||
INVALID_GPT_REQUEST_MESSAGES,
|
||||
)
|
||||
from constants import AUDIO_SEGMENT_DURATION
|
||||
from core.bot.models.chat_gpt import ChatGpt
|
||||
from core.bot.repository import ChatGPTRepository
|
||||
from infra.database.db_adapter import Database
|
||||
from settings.config import settings
|
||||
|
||||
|
||||
class SpeechToTextService:
|
||||
def __init__(self, filename: str) -> None:
|
||||
def __init__(self) -> None:
|
||||
self.executor = ThreadPoolExecutor()
|
||||
|
||||
self.filename = filename
|
||||
self.recognizer = Recognizer()
|
||||
self.recognizer.energy_threshold = 50
|
||||
self.text_parts: dict[int, str] = {}
|
||||
self.text_recognised = False
|
||||
|
||||
def get_text_from_audio(self) -> None:
|
||||
self.executor.submit(self.worker)
|
||||
def get_text_from_audio(self, filename: str) -> None:
|
||||
self.executor.submit(self.worker, filename=filename)
|
||||
|
||||
def worker(self) -> Any:
|
||||
self._convert_file_to_wav()
|
||||
self._convert_audio_to_text()
|
||||
def worker(self, filename: str) -> Any:
|
||||
self._convert_file_to_wav(filename)
|
||||
self._convert_audio_to_text(filename)
|
||||
|
||||
def _convert_audio_to_text(self) -> None:
|
||||
wav_filename = f"{self.filename}.wav"
|
||||
def _convert_audio_to_text(self, filename: str) -> None:
|
||||
wav_filename = f"{filename}.wav"
|
||||
|
||||
speech = AudioSegment.from_wav(wav_filename)
|
||||
speech_duration = len(speech)
|
||||
@@ -63,18 +58,19 @@ class SpeechToTextService:
|
||||
# clean temp voice message main files
|
||||
try:
|
||||
os.remove(wav_filename)
|
||||
os.remove(self.filename)
|
||||
os.remove(filename)
|
||||
except FileNotFoundError as error:
|
||||
logger.error("error temps files not deleted", error=error, filenames=[self.filename, self.filename])
|
||||
logger.error("error temps files not deleted", error=error, filenames=[filename, wav_filename])
|
||||
|
||||
def _convert_file_to_wav(self) -> None:
|
||||
new_filename = self.filename + ".wav"
|
||||
cmd = ["ffmpeg", "-loglevel", "quiet", "-i", self.filename, "-vn", new_filename]
|
||||
@staticmethod
|
||||
def _convert_file_to_wav(filename: str) -> None:
|
||||
new_filename = filename + ".wav"
|
||||
cmd = ["ffmpeg", "-loglevel", "quiet", "-i", filename, "-vn", new_filename]
|
||||
try:
|
||||
subprocess.run(args=cmd) # noqa: S603
|
||||
logger.info("file has been converted to wav", filename=new_filename)
|
||||
except Exception as error:
|
||||
logger.error("cant convert voice", error=error, filename=self.filename)
|
||||
logger.error("cant convert voice", error=error, filename=filename)
|
||||
|
||||
def _recognize_by_google(self, filename: str, sound_segment: AudioSegment) -> str:
|
||||
tmp_filename = f"{filename}_tmp_part"
|
||||
@@ -91,48 +87,36 @@ class SpeechToTextService:
|
||||
raise error
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatGptService:
|
||||
def __init__(self, chat_gpt_model: str) -> None:
|
||||
self.chat_gpt_model = chat_gpt_model
|
||||
repository: ChatGPTRepository
|
||||
|
||||
async def get_chatgpt_models(self) -> Sequence[ChatGpt]:
|
||||
return await self.repository.get_chatgpt_models()
|
||||
|
||||
async def request_to_chatgpt(self, question: str | None) -> str:
|
||||
question = question or "Привет!"
|
||||
chat_gpt_request = self.build_request_data(question)
|
||||
try:
|
||||
response = await self.do_request(chat_gpt_request)
|
||||
status = response.status_code
|
||||
for message in INVALID_GPT_REQUEST_MESSAGES:
|
||||
if message in response.text:
|
||||
message = f"{message}: {settings.GPT_MODEL}"
|
||||
logger.info(message, data=chat_gpt_request)
|
||||
return message
|
||||
if status != httpx.codes.OK:
|
||||
logger.info(f"got response status: {status} from chat api", data=chat_gpt_request)
|
||||
return "Что-то пошло не так, попробуйте еще раз или обратитесь к администратору"
|
||||
return response.text
|
||||
except Exception as error:
|
||||
logger.error("error get data from chat api", error=error)
|
||||
return "Вообще всё сломалось :("
|
||||
chat_gpt_model = await self.get_current_chatgpt_model()
|
||||
return await self.repository.ask_question(question=question, chat_gpt_model=chat_gpt_model)
|
||||
|
||||
@staticmethod
|
||||
async def do_request(data: dict[str, Any]) -> Response:
|
||||
transport = AsyncHTTPTransport(retries=3)
|
||||
async with AsyncClient(base_url=settings.GPT_BASE_HOST, transport=transport, timeout=50) as client:
|
||||
return await client.post(CHAT_GPT_BASE_URI, json=data, timeout=50)
|
||||
async def request_to_chatgpt_microservice(self, question: str) -> Response:
|
||||
chat_gpt_model = await self.get_current_chatgpt_model()
|
||||
return await self.repository.request_to_chatgpt_microservice(question=question, chat_gpt_model=chat_gpt_model)
|
||||
|
||||
def build_request_data(self, question: str) -> dict[str, Any]:
|
||||
return {
|
||||
"conversation_id": str(uuid4()),
|
||||
"action": "_ask",
|
||||
"model": self.chat_gpt_model,
|
||||
"jailbreak": "default",
|
||||
"meta": {
|
||||
"id": random.randint(10**18, 10**19 - 1), # noqa: S311
|
||||
"content": {
|
||||
"conversation": [],
|
||||
"internet_access": False,
|
||||
"content_type": "text",
|
||||
"parts": [{"content": question, "role": "user"}],
|
||||
},
|
||||
},
|
||||
}
|
||||
async def get_current_chatgpt_model(self) -> str:
|
||||
return await self.repository.get_current_chatgpt_model()
|
||||
|
||||
async def change_chatgpt_model_priority(self, model_id: int, priority: int) -> None:
|
||||
return await self.repository.change_chatgpt_model_priority(model_id=model_id, priority=priority)
|
||||
|
||||
async def add_chatgpt_model(self, gpt_model: str, priority: int) -> dict[str, str | int]:
|
||||
return await self.repository.add_chatgpt_model(model=gpt_model, priority=priority)
|
||||
|
||||
async def delete_chatgpt_model(self, model_id: int) -> None:
|
||||
return await self.repository.delete_chatgpt_model(model_id=model_id)
|
||||
|
||||
@classmethod
|
||||
def build(cls) -> "ChatGptService":
|
||||
db = Database(settings=settings)
|
||||
repository = ChatGPTRepository(settings=settings, db=db)
|
||||
return ChatGptService(repository=repository)
|
||||
|
||||
73
bot_microservice/core/lifetime.py
Normal file
73
bot_microservice/core/lifetime.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from asyncio import current_task
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from fastapi import FastAPI
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_scoped_session,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from settings.config import AppSettings
|
||||
|
||||
|
||||
def startup(app: FastAPI, settings: AppSettings) -> Callable[[], Awaitable[None]]:
|
||||
"""
|
||||
Actions to run on application startup.
|
||||
|
||||
This function use fastAPI app to store data,
|
||||
such as db_engine.
|
||||
|
||||
:param app: the fastAPI application.
|
||||
:param settings: app settings
|
||||
:return: function that actually performs actions.
|
||||
|
||||
"""
|
||||
|
||||
async def _startup() -> None:
|
||||
_setup_db(app, settings)
|
||||
|
||||
return _startup
|
||||
|
||||
|
||||
def shutdown(app: FastAPI) -> Callable[[], Awaitable[None]]:
|
||||
"""
|
||||
Actions to run on application's shutdown.
|
||||
|
||||
:param app: fastAPI application.
|
||||
:return: function that actually performs actions.
|
||||
|
||||
"""
|
||||
|
||||
async def _shutdown() -> None:
|
||||
await app.state.db_engine.dispose()
|
||||
|
||||
return _shutdown
|
||||
|
||||
|
||||
def _setup_db(app: FastAPI, settings: AppSettings) -> None:
|
||||
"""
|
||||
Create connection to the database.
|
||||
|
||||
This function creates SQLAlchemy engine instance,
|
||||
session_factory for creating sessions
|
||||
and stores them in the application's state property.
|
||||
|
||||
:param app: fastAPI application.
|
||||
"""
|
||||
engine = create_async_engine(
|
||||
str(settings.db_url),
|
||||
echo=settings.DB_ECHO,
|
||||
execution_options={"isolation_level": "AUTOCOMMIT"},
|
||||
)
|
||||
session_factory = async_scoped_session(
|
||||
async_sessionmaker(
|
||||
engine,
|
||||
expire_on_commit=False,
|
||||
class_=AsyncSession,
|
||||
),
|
||||
scopefunc=current_task,
|
||||
)
|
||||
app.state.db_engine = engine
|
||||
app.state.db_session_factory = session_factory
|
||||
@@ -1,5 +1,6 @@
|
||||
from datetime import datetime, timedelta
|
||||
from functools import lru_cache, wraps
|
||||
from inspect import cleandoc
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -22,3 +23,9 @@ def timed_cache(**timedelta_kwargs: Any) -> Any:
|
||||
return _wrapped
|
||||
|
||||
return _wrapper
|
||||
|
||||
|
||||
def clean_doc(cls: Any) -> str | None:
|
||||
if cls.__doc__ is None:
|
||||
return None
|
||||
return cleandoc(cls.__doc__)
|
||||
|
||||
0
bot_microservice/infra/database/__init__.py
Normal file
0
bot_microservice/infra/database/__init__.py
Normal file
34
bot_microservice/infra/database/base.py
Normal file
34
bot_microservice/infra/database/base.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from sqlalchemy import Table, inspect
|
||||
from sqlalchemy.orm import as_declarative, declared_attr
|
||||
|
||||
from infra.database.meta import meta
|
||||
|
||||
|
||||
@as_declarative(metadata=meta)
|
||||
class Base:
|
||||
"""
|
||||
Base for all models.
|
||||
|
||||
It has some type definitions to
|
||||
enhance autocompletion.
|
||||
"""
|
||||
|
||||
# Generate __tablename__ automatically
|
||||
@declared_attr
|
||||
def __tablename__(self) -> str:
|
||||
return self.__name__.lower()
|
||||
|
||||
__table__: Table
|
||||
|
||||
@classmethod
|
||||
def get_real_column_name(cls, attr_name: str) -> str:
|
||||
return getattr(inspect(cls).c, attr_name).name # type: ignore
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__repr__()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
try:
|
||||
return f"{self.__class__.__name__}(id={self.id})" # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
return super().__repr__()
|
||||
102
bot_microservice/infra/database/db_adapter.py
Normal file
102
bot_microservice/infra/database/db_adapter.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import os
|
||||
import pkgutil
|
||||
from asyncio import current_task
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_scoped_session,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from settings.config import AppSettings
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, settings: AppSettings) -> None:
|
||||
self.db_connect_url = settings.db_url
|
||||
self.echo_logs = settings.DB_ECHO
|
||||
self.db_file = settings.DB_FILE
|
||||
self._engine: AsyncEngine = create_async_engine(
|
||||
str(settings.db_url),
|
||||
echo=settings.DB_ECHO,
|
||||
execution_options={"isolation_level": "AUTOCOMMIT"},
|
||||
)
|
||||
self._async_session_factory = async_scoped_session(
|
||||
async_sessionmaker(
|
||||
autoflush=False,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
bind=self._engine,
|
||||
),
|
||||
scopefunc=current_task,
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||
session: AsyncSession = self._async_session_factory()
|
||||
|
||||
async with session:
|
||||
try:
|
||||
yield session
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_transaction_session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||
async with self._async_session_factory() as session, session.begin():
|
||||
try:
|
||||
yield session
|
||||
except Exception as error:
|
||||
await session.rollback()
|
||||
raise error
|
||||
|
||||
async def create_database(self) -> None:
|
||||
"""
|
||||
Create a test database.
|
||||
|
||||
:param engine: Async engine for database creation
|
||||
:param db_path: path to sqlite file
|
||||
|
||||
"""
|
||||
if not self.db_file.exists():
|
||||
from infra.database.meta import meta
|
||||
|
||||
load_all_models()
|
||||
try:
|
||||
async with self._engine.begin() as connection:
|
||||
await connection.run_sync(meta.create_all)
|
||||
|
||||
logger.info("all migrations are applied")
|
||||
except Exception as err:
|
||||
logger.error("Cant run migrations", err=err)
|
||||
|
||||
async def drop_database(self) -> None:
|
||||
"""
|
||||
Drop current database.
|
||||
|
||||
:param path: Delete sqlite database file
|
||||
|
||||
"""
|
||||
if self.db_file.exists():
|
||||
os.remove(self.db_file)
|
||||
|
||||
|
||||
def load_all_models() -> None:
|
||||
"""Load all models from this folder."""
|
||||
package_dir = Path(__file__).resolve().parent.parent
|
||||
package_dir = package_dir.joinpath("core")
|
||||
modules = pkgutil.walk_packages(path=[str(package_dir)], prefix="core.")
|
||||
models_packages = [module for module in modules if module.ispkg and "models" in module.name]
|
||||
for module in models_packages:
|
||||
model_pkgs = pkgutil.walk_packages(
|
||||
path=[os.path.join(str(module.module_finder.path), "models")], prefix=f"{module.name}." # type: ignore
|
||||
)
|
||||
for model_pkg in model_pkgs:
|
||||
__import__(model_pkg.name)
|
||||
20
bot_microservice/infra/database/deps.py
Normal file
20
bot_microservice/infra/database/deps.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from starlette.requests import Request
|
||||
|
||||
|
||||
async def get_db_session(request: Request) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Create and get database session.
|
||||
|
||||
:param request: current request.
|
||||
:yield: database session.
|
||||
"""
|
||||
session: AsyncSession = request.app.state.db_session_factory()
|
||||
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.commit()
|
||||
await session.close()
|
||||
3
bot_microservice/infra/database/meta.py
Normal file
3
bot_microservice/infra/database/meta.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from sqlalchemy import MetaData
|
||||
|
||||
meta = MetaData()
|
||||
1
bot_microservice/infra/database/migrations/__init__.py
Normal file
1
bot_microservice/infra/database/migrations/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Alembic migraions."""
|
||||
81
bot_microservice/infra/database/migrations/env.py
Normal file
81
bot_microservice/infra/database/migrations/env.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy.ext.asyncio.engine import create_async_engine
|
||||
from sqlalchemy.future import Connection
|
||||
|
||||
from infra.database.db_adapter import load_all_models
|
||||
from infra.database.meta import meta
|
||||
from settings.config import settings
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# for 'autogenerate' support from myapp import mymodel
|
||||
load_all_models()
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
target_metadata = meta
|
||||
|
||||
|
||||
async def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
context.configure(
|
||||
url=str(settings.db_url),
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
"""
|
||||
Run actual sync migrations.
|
||||
|
||||
:param connection: connection to the database.
|
||||
|
||||
"""
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_migrations_online() -> None:
|
||||
"""
|
||||
Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = create_async_engine(str(settings.db_url))
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
asyncio.run(run_migrations_offline())
|
||||
else:
|
||||
asyncio.run(run_migrations_online())
|
||||
24
bot_microservice/infra/database/migrations/script.py.mako
Normal file
24
bot_microservice/infra/database/migrations/script.py.mako
Normal file
@@ -0,0 +1,24 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = ${repr(up_revision)}
|
||||
down_revision = ${repr(down_revision)}
|
||||
branch_labels = ${repr(branch_labels)}
|
||||
depends_on = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,34 @@
|
||||
"""initial commit
|
||||
|
||||
Revision ID: eb78565abec7
|
||||
Revises:
|
||||
Create Date: 2023-10-05 18:28:30.915361
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "eb78565abec7"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"chatgpt",
|
||||
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||
sa.Column("model", sa.VARCHAR(length=256), nullable=False),
|
||||
sa.Column("priority", sa.SMALLINT(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("model"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("chatgpt")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,49 @@
|
||||
"""create chat gpt models
|
||||
|
||||
Revision ID: c2e443941930
|
||||
Revises: eb78565abec7
|
||||
Create Date: 2025-10-05 20:44:05.414977
|
||||
|
||||
"""
|
||||
|
||||
from sqlalchemy import create_engine, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from constants import ChatGptModelsEnum
|
||||
from core.bot.models.chat_gpt import ChatGpt
|
||||
from settings.config import settings
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c2e443941930"
|
||||
down_revision = "eb78565abec7"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | None = None
|
||||
|
||||
engine = create_engine(str(settings.db_url), echo=settings.DB_ECHO)
|
||||
session_factory = sessionmaker(engine)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
with session_factory() as session:
|
||||
query = select(ChatGpt)
|
||||
results = session.execute(query)
|
||||
models = results.scalars().all()
|
||||
|
||||
if models:
|
||||
return None
|
||||
models = []
|
||||
for model in ChatGptModelsEnum:
|
||||
priority = 0 if model != "gpt-3.5-turbo-stream-FreeGpt" else 1
|
||||
fields = {"model": model, "priority": priority}
|
||||
models.append(ChatGpt(**fields))
|
||||
session.add_all(models)
|
||||
session.commit()
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
with session_factory() as session:
|
||||
session.execute(f"""TRUNCATE TABLE {ChatGpt.__tablename__}""")
|
||||
session.commit()
|
||||
|
||||
|
||||
engine.dispose()
|
||||
@@ -17,6 +17,56 @@ else:
|
||||
Record = dict[str, Any]
|
||||
|
||||
|
||||
class Formatter:
|
||||
@staticmethod
|
||||
def json_formatter(record: Record) -> str:
|
||||
# Обрезаем `\n` в конце логов, т.к. в json формате переносы не нужны
|
||||
return Formatter.scrap_sensitive_info(record.get("message", "").strip())
|
||||
|
||||
@staticmethod
|
||||
def sentry_formatter(record: Record) -> str:
|
||||
if message := record.get("message", ""):
|
||||
record["message"] = Formatter.scrap_sensitive_info(message)
|
||||
return "{name}:{function} {message}"
|
||||
|
||||
@staticmethod
|
||||
def text_formatter(record: Record) -> str:
|
||||
# WARNING !!!
|
||||
# Функция должна возвращать строку, которая содержит только шаблоны для форматирования.
|
||||
# Если в строку прокидывать значения из record (или еще откуда-либо),
|
||||
# то loguru может принять их за f-строки и попытается обработать, что приведет к ошибке.
|
||||
# Например, если нужно достать какое-то значение из поля extra, вместо того чтобы прокидывать его в строку
|
||||
# формата, нужно прокидывать подстроку вида {extra[тут_ключ]}
|
||||
|
||||
if message := record.get("message", ""):
|
||||
record["message"] = Formatter.scrap_sensitive_info(message)
|
||||
|
||||
# Стандартный формат loguru. Задается через env LOGURU_FORMAT
|
||||
format_ = (
|
||||
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
||||
"<level>{level: <8}</level> | "
|
||||
"<magenta>{name}</magenta>:<magenta>{function}</magenta>:<magenta>{line}</magenta> - "
|
||||
"<level>{message}</level>"
|
||||
)
|
||||
|
||||
# Добавляем мета параметры по типу user_id, art_id, которые передаются через logger.bind(...)
|
||||
extra = record["extra"]
|
||||
if extra:
|
||||
formatted = ", ".join(f"{key}" + "={extra[" + str(key) + "]}" for key, value in extra.items())
|
||||
format_ += f" - <cyan>{formatted}</cyan>"
|
||||
|
||||
format_ += "\n"
|
||||
|
||||
if record["exception"] is not None:
|
||||
format_ += "{exception}\n"
|
||||
|
||||
return format_
|
||||
|
||||
@staticmethod
|
||||
def scrap_sensitive_info(message: str) -> str:
|
||||
return message.replace(settings.TELEGRAM_API_TOKEN, "TELEGRAM_API_TOKEN".center(24, "*"))
|
||||
|
||||
|
||||
class InterceptHandler(logging.Handler):
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
# Get corresponding Loguru level if it exists
|
||||
@@ -31,13 +81,9 @@ class InterceptHandler(logging.Handler):
|
||||
frame = cast(FrameType, frame.f_back)
|
||||
depth += 1
|
||||
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(level, self._scrap_sensitive_info(record))
|
||||
|
||||
@staticmethod
|
||||
def _scrap_sensitive_info(record: logging.LogRecord) -> str:
|
||||
message = record.getMessage()
|
||||
message.replace(settings.TELEGRAM_API_TOKEN, "TELEGRAM_API_TOKEN".center(24, "*"))
|
||||
return message
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(
|
||||
level, Formatter.scrap_sensitive_info(record.getMessage())
|
||||
)
|
||||
|
||||
|
||||
def configure_logging(
|
||||
@@ -45,7 +91,7 @@ def configure_logging(
|
||||
) -> None:
|
||||
intercept_handler = InterceptHandler()
|
||||
|
||||
formatter = _json_formatter if enable_json_logs else _text_formatter
|
||||
formatter = Formatter.json_formatter if enable_json_logs else Formatter.text_formatter
|
||||
|
||||
base_config_handlers = [intercept_handler]
|
||||
|
||||
@@ -64,10 +110,10 @@ def configure_logging(
|
||||
base_config_handlers.append(graylog_handler)
|
||||
loguru_handlers.append({**base_loguru_handler, "sink": graylog_handler})
|
||||
if log_to_file:
|
||||
file_path = os.path.join(DIR_LOGS, log_to_file)
|
||||
file_path = DIR_LOGS / log_to_file
|
||||
if not os.path.exists(log_to_file):
|
||||
with open(file_path, 'w') as f:
|
||||
f.write('')
|
||||
with open(file_path, "w") as f:
|
||||
f.write("")
|
||||
loguru_handlers.append({**base_loguru_handler, "sink": file_path})
|
||||
|
||||
logging.basicConfig(handlers=base_config_handlers, level=level.name)
|
||||
@@ -78,42 +124,4 @@ def configure_logging(
|
||||
# https://forum.sentry.io/t/changing-issue-title-when-logging-with-traceback/446
|
||||
if enable_sentry_logs:
|
||||
handler = EventHandler(level=logging.WARNING)
|
||||
logger.add(handler, diagnose=True, level=logging.WARNING, format=_sentry_formatter)
|
||||
|
||||
|
||||
def _json_formatter(record: Record) -> str:
|
||||
# Обрезаем `\n` в конце логов, т.к. в json формате переносы не нужны
|
||||
return record.get("message", "").strip()
|
||||
|
||||
|
||||
def _sentry_formatter(record: Record) -> str:
|
||||
return "{name}:{function} {message}"
|
||||
|
||||
|
||||
def _text_formatter(record: Record) -> str:
|
||||
# WARNING !!!
|
||||
# Функция должна возвращать строку, которая содержит только шаблоны для форматирования.
|
||||
# Если в строку прокидывать значения из record (или еще откуда-либо),
|
||||
# то loguru может принять их за f-строки и попытается обработать, что приведет к ошибке.
|
||||
# Например, если нужно достать какое-то значение из поля extra, вместо того чтобы прокидывать его в строку формата,
|
||||
# нужно прокидывать подстроку вида {extra[тут_ключ]}
|
||||
|
||||
# Стандартный формат loguru. Задается через env LOGURU_FORMAT
|
||||
format_ = (
|
||||
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
||||
"<level>{level: <8}</level> | "
|
||||
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
)
|
||||
|
||||
# Добавляем мета параметры по типу user_id, art_id, которые передаются через logger.bind(...)
|
||||
extra = record["extra"]
|
||||
if extra:
|
||||
formatted = ", ".join(f"{key}" + "={extra[" + str(key) + "]}" for key, value in extra.items())
|
||||
format_ += f" - <cyan>{formatted}</cyan>"
|
||||
|
||||
format_ += "\n"
|
||||
|
||||
if record["exception"] is not None:
|
||||
format_ += "{exception}\n"
|
||||
|
||||
return format_
|
||||
logger.add(handler, diagnose=True, level=logging.WARNING, format=Formatter.sentry_formatter)
|
||||
|
||||
@@ -8,6 +8,7 @@ from fastapi.responses import UJSONResponse
|
||||
from constants import LogLevelEnum
|
||||
from core.bot.app import BotApplication, BotQueue
|
||||
from core.bot.handlers import bot_event_handlers
|
||||
from core.lifetime import shutdown, startup
|
||||
from infra.logging_conf import configure_logging
|
||||
from routers import api_router
|
||||
from settings.config import AppSettings, get_settings
|
||||
@@ -26,8 +27,12 @@ class Application:
|
||||
)
|
||||
self.app.state.settings = settings
|
||||
self.app.state.queue = BotQueue(bot_app=bot_app)
|
||||
self.app.state.bot_app = bot_app
|
||||
self.bot_app = bot_app
|
||||
|
||||
self.app.on_event("startup")(startup(self.app, settings))
|
||||
self.app.on_event("shutdown")(shutdown(self.app))
|
||||
|
||||
self.app.include_router(api_router)
|
||||
self.configure_hooks()
|
||||
configure_logging(
|
||||
@@ -51,18 +56,18 @@ class Application:
|
||||
|
||||
def configure_hooks(self) -> None:
|
||||
if self.bot_app.start_with_webhook:
|
||||
self.app.add_event_handler("startup", self._on_start_up)
|
||||
self.app.add_event_handler("startup", self._bot_start_up)
|
||||
else:
|
||||
self.app.add_event_handler("startup", self.bot_app.polling)
|
||||
|
||||
self.app.add_event_handler("shutdown", self._on_shutdown)
|
||||
self.app.add_event_handler("shutdown", self._bot_shutdown)
|
||||
|
||||
async def _on_start_up(self) -> None:
|
||||
async def _bot_start_up(self) -> None:
|
||||
await self.bot_app.set_webhook()
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.create_task(self.app.state.queue.get_updates_from_queue())
|
||||
|
||||
async def _on_shutdown(self) -> None:
|
||||
async def _bot_shutdown(self) -> None:
|
||||
await asyncio.gather(self.bot_app.delete_webhook(), self.bot_app.shutdown())
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
from yarl import URL
|
||||
|
||||
from constants import API_PREFIX
|
||||
|
||||
@@ -52,6 +53,9 @@ class AppSettings(SentrySettings, BaseSettings):
|
||||
DOMAIN: str = "https://localhost"
|
||||
URL_PREFIX: str = ""
|
||||
|
||||
DB_FILE: Path = SHARED_DIR / "chat_gpt.db"
|
||||
DB_ECHO: bool = False
|
||||
|
||||
# ==== gpt settings ====
|
||||
GPT_MODEL: str = "gpt-3.5-turbo-stream-DeepAi"
|
||||
GPT_BASE_HOST: str = "http://chat_service:8858"
|
||||
@@ -91,6 +95,18 @@ class AppSettings(SentrySettings, BaseSettings):
|
||||
def bot_webhook_url(self) -> str:
|
||||
return "/".join([self.api_prefix, self.token_part])
|
||||
|
||||
@cached_property
|
||||
def db_url(self) -> URL:
|
||||
"""
|
||||
Assemble database URL from settings.
|
||||
|
||||
:return: database URL.
|
||||
"""
|
||||
return URL.build(
|
||||
scheme="sqlite+aiosqlite",
|
||||
path=f"///{self.DB_FILE}",
|
||||
)
|
||||
|
||||
class Config:
|
||||
case_sensitive = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user