add database and migration logic (#27)

* update chat_microservice

* reformat logger_conf

* add database

* add service and repository logic

* fix constants gpt base url

* add models endpoints
This commit is contained in:
Dmitry Afanasyev
2023-10-07 00:04:12 +03:00
committed by GitHub
parent c401e1006c
commit 23031b0777
37 changed files with 1785 additions and 487 deletions

View File

@@ -0,0 +1,52 @@
[alembic]
script_location = infra/database/migrations
file_template = %%(year)d-%%(month).2d-%%(day).2d-%%(hour).2d-%%(minute).2d_%%(rev)s
prepend_sys_path = .
output_encoding = utf-8
[post_write_hooks]
hooks = black,autoflake,isort
black.type = console_scripts
black.entrypoint = black
autoflake.type = console_scripts
autoflake.entrypoint = autoflake
isort.type = console_scripts
isort.entrypoint = isort
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -1,7 +1,17 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter, Body, Depends, Path
from starlette import status
from starlette.responses import Response
from starlette.responses import JSONResponse, Response
from telegram import Update
from api.bot.serializers import (
ChatGptModelSerializer,
ChatGptModelsPrioritySerializer,
GETChatGptModelsSerializer,
LightChatGptModel,
)
from api.deps import get_bot_queue, get_chatgpt_service, get_update_from_request
from core.bot.app import BotQueue
from core.bot.services import ChatGptService
from settings.config import settings
router = APIRouter()
@@ -15,5 +25,74 @@ router = APIRouter()
summary="process bot updates",
include_in_schema=False,
)
async def process_bot_updates(request: Request) -> None:
await request.app.state.queue.put_updates_on_queue(request)
async def process_bot_updates(
tg_update: Update = Depends(get_update_from_request),
queue: BotQueue = Depends(get_bot_queue),
) -> None:
await queue.put_updates_on_queue(tg_update)
@router.get(
    "/models",
    name="bot:models_list",
    response_class=JSONResponse,
    # The handler returns the {"data": [...]} envelope, so the documented schema
    # must be the envelope serializer rather than a bare list of models.
    response_model=GETChatGptModelsSerializer,
    status_code=status.HTTP_200_OK,
    summary="list of models",
)
async def models_list(
    chatgpt_service: ChatGptService = Depends(get_chatgpt_service),
) -> JSONResponse:
    """Return all registered GPT models wrapped in a ``data`` envelope."""
    models = await chatgpt_service.get_chatgpt_models()
    # A Response object is returned directly, so the body below is sent as-is.
    return JSONResponse(
        content=GETChatGptModelsSerializer(data=models).model_dump(), status_code=status.HTTP_200_OK  # type: ignore
    )
@router.post(
"/models/{model_id}/priority",
name="bot:change_model_priority",
response_class=Response,
status_code=status.HTTP_202_ACCEPTED,
summary="change gpt model priority",
)
async def change_model_priority(
model_id: int = Path(..., gt=0, description="Id модели для обновления приореитета"),
chatgpt_service: ChatGptService = Depends(get_chatgpt_service),
gpt_model: ChatGptModelsPrioritySerializer = Body(...),
) -> None:
"""Изменить приоритет модели в выдаче"""
await chatgpt_service.change_chatgpt_model_priority(model_id=model_id, priority=gpt_model.priority)
@router.post(
    "/models",
    name="bot:add_new_model",
    # The service returns only the "model" and "priority" keys (no "id"), which is
    # the LightChatGptModel shape; ChatGptModelSerializer would advertise a required id.
    response_model=LightChatGptModel,
    status_code=status.HTTP_201_CREATED,
    summary="add new model",
)
async def add_new_model(
    chatgpt_service: ChatGptService = Depends(get_chatgpt_service),
    gpt_model: LightChatGptModel = Body(...),
) -> JSONResponse:
    """Add a new GPT model with the requested priority and echo the stored values back."""
    model = await chatgpt_service.add_chatgpt_model(gpt_model=gpt_model.model, priority=gpt_model.priority)
    return JSONResponse(content=model, status_code=status.HTTP_201_CREATED)
@router.delete(
"/models/{model_id}",
name="bot:delete_gpt_model",
response_class=Response,
status_code=status.HTTP_204_NO_CONTENT,
summary="delete gpt model",
)
async def delete_model(
model_id: int = Path(..., gt=0, description="Id модели для удаления"),
chatgpt_service: ChatGptService = Depends(get_chatgpt_service),
) -> None:
"""Удалить gpt модель"""
await chatgpt_service.delete_chatgpt_model(model_id=model_id)

View File

@@ -1,13 +0,0 @@
from fastapi import Depends
from starlette.requests import Request
from core.bot.services import ChatGptService
from settings.config import AppSettings
def get_settings(request: Request) -> AppSettings:
return request.app.state.settings
def get_chat_gpt_service(settings: AppSettings = Depends(get_settings)) -> ChatGptService:
return ChatGptService(settings.GPT_MODEL)

View File

@@ -0,0 +1,24 @@
from pydantic import BaseModel, ConfigDict, Field
class LightChatGptModel(BaseModel):
model: str = Field(..., title="Chat Gpt model")
priority: int = Field(default=0, ge=0, title="Приоритет модели")
class ChatGptModelsPrioritySerializer(BaseModel):
priority: int = Field(default=0, ge=0, title="Приоритет модели")
class ChatGptModelSerializer(BaseModel):
id: int = Field(..., gt=0, title="Id модели")
model: str = Field(..., title="Chat Gpt model")
priority: int = Field(..., ge=0, title="Приоритет модели")
model_config = ConfigDict(from_attributes=True)
class GETChatGptModelsSerializer(BaseModel):
data: list[ChatGptModelSerializer] = Field(..., title="Список всех моделей")
model_config = ConfigDict(from_attributes=True)

View File

@@ -0,0 +1,50 @@
from fastapi import Depends
from starlette.requests import Request
from telegram import Update
from core.bot.app import BotApplication, BotQueue
from core.bot.repository import ChatGPTRepository
from core.bot.services import ChatGptService, SpeechToTextService
from infra.database.db_adapter import Database
from settings.config import AppSettings
def get_settings(request: Request) -> AppSettings:
return request.app.state.settings
def get_bot_app(request: Request) -> BotApplication:
return request.app.state.bot_app
def get_bot_queue(request: Request) -> BotQueue:
return request.app.state.queue
async def get_update_from_request(request: Request, bot_app: BotApplication = Depends(get_bot_app)) -> Update | None:
data = await request.json()
return Update.de_json(data, bot_app.bot)
def get_database(settings: AppSettings = Depends(get_settings)) -> Database:
return Database(settings=settings)
def get_chat_gpt_repository(
db: Database = Depends(get_database), settings: AppSettings = Depends(get_settings)
) -> ChatGPTRepository:
return ChatGPTRepository(settings=settings, db=db)
def get_speech_to_text_service() -> SpeechToTextService:
return SpeechToTextService()
def new_bot_queue(bot_app: BotApplication = Depends(get_bot_app)) -> BotQueue:
return BotQueue(bot_app=bot_app)
def get_chatgpt_service(
chat_gpt_repository: ChatGPTRepository = Depends(get_chat_gpt_repository),
) -> ChatGptService:
return ChatGptService(repository=chat_gpt_repository)

View File

@@ -3,7 +3,7 @@ from fastapi.responses import ORJSONResponse
from starlette import status
from starlette.responses import Response
from api.bot.deps import get_chat_gpt_service
from api.deps import get_chatgpt_service
from api.exceptions import BaseAPIException
from constants import INVALID_GPT_REQUEST_MESSAGES
from core.bot.services import ChatGptService
@@ -33,12 +33,11 @@ async def healthcheck() -> ORJSONResponse:
)
async def gpt_healthcheck(
response: Response,
chatgpt_service: ChatGptService = Depends(get_chat_gpt_service),
chatgpt_service: ChatGptService = Depends(get_chatgpt_service),
) -> Response:
data = chatgpt_service.build_request_data("Привет!")
response.status_code = status.HTTP_200_OK
try:
chatgpt_response = await chatgpt_service.do_request(data)
chatgpt_response = await chatgpt_service.request_to_chatgpt_microservice(question="Привет!")
if chatgpt_response.status_code != status.HTTP_200_OK:
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
for message in INVALID_GPT_REQUEST_MESSAGES:

View File

@@ -1,4 +1,4 @@
from enum import StrEnum
from enum import StrEnum, unique
AUDIO_SEGMENT_DURATION = 120 * 1000
@@ -26,3 +26,35 @@ class LogLevelEnum(StrEnum):
INFO = "info"
DEBUG = "debug"
NOTSET = ""
@unique
class ChatGptModelsEnum(StrEnum):
gpt_3_5_turbo_stream_openai = "gpt-3.5-turbo-stream-openai"
gpt_3_5_turbo_Aichat = "gpt-3.5-turbo-Aichat"
gpt_4_ChatgptAi = "gpt-4-ChatgptAi"
gpt_3_5_turbo_weWordle = "gpt-3.5-turbo-weWordle"
gpt_3_5_turbo_acytoo = "gpt-3.5-turbo-acytoo"
gpt_3_5_turbo_stream_DeepAi = "gpt-3.5-turbo-stream-DeepAi"
gpt_3_5_turbo_stream_H2o = "gpt-3.5-turbo-stream-H2o"
gpt_3_5_turbo_stream_yqcloud = "gpt-3.5-turbo-stream-yqcloud"
gpt_OpenAssistant_stream_HuggingChat = "gpt-OpenAssistant-stream-HuggingChat"
gpt_4_turbo_stream_you = "gpt-4-turbo-stream-you"
gpt_3_5_turbo_AItianhu = "gpt-3.5-turbo-AItianhu"
gpt_3_stream_binjie = "gpt-3-stream-binjie"
gpt_3_5_turbo_stream_CodeLinkAva = "gpt-3.5-turbo-stream-CodeLinkAva"
gpt_4_stream_ChatBase = "gpt-4-stream-ChatBase"
gpt_3_5_turbo_stream_aivvm = "gpt-3.5-turbo-stream-aivvm"
gpt_3_5_turbo_16k_stream_Ylokh = "gpt-3.5-turbo-16k-stream-Ylokh"
gpt_3_5_turbo_stream_Vitalentum = "gpt-3.5-turbo-stream-Vitalentum"
gpt_3_5_turbo_stream_GptGo = "gpt-3.5-turbo-stream-GptGo"
gpt_3_5_turbo_stream_AItianhuSpace = "gpt-3.5-turbo-stream-AItianhuSpace"
gpt_3_5_turbo_stream_Aibn = "gpt-3.5-turbo-stream-Aibn"
gpt_3_5_turbo_ChatgptDuo = "gpt-3.5-turbo-ChatgptDuo"
gpt_3_5_turbo_stream_FreeGpt = "gpt-3.5-turbo-stream-FreeGpt"
gpt_3_5_turbo_stream_ChatForAi = "gpt-3.5-turbo-stream-ChatForAi"
gpt_3_5_turbo_stream_Cromicle = "gpt-3.5-turbo-stream-Cromicle"
@classmethod
def values(cls) -> set[str]:
return set(map(str, set(ChatGptModelsEnum)))

View File

@@ -6,7 +6,7 @@ from functools import cached_property
from http import HTTPStatus
from typing import Any
from fastapi import Request, Response
from fastapi import Response
from loguru import logger
from telegram import Bot, Update
from telegram.ext import Application
@@ -68,14 +68,11 @@ class BotQueue:
bot_app: BotApplication
queue: Queue = asyncio.Queue() # type: ignore[type-arg]
async def put_updates_on_queue(self, request: Request) -> Response:
async def put_updates_on_queue(self, tg_update: Update) -> Response:
"""
Listen /{URL_PREFIX}/{API_PREFIX}/{TELEGRAM_WEB_TOKEN} path and proxy post request to bot
"""
data = await request.json()
tg_update = Update.de_json(data=data, bot=self.bot_app.application.bot)
self.queue.put_nowait(tg_update)
return Response(status_code=HTTPStatus.ACCEPTED)
async def get_updates_from_queue(self) -> None:

View File

@@ -67,7 +67,7 @@ async def ask_question(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No
await update.message.reply_text("Пожалуйста подождите, ответ в среднем занимает 10-15 секунд")
chat_gpt_service = ChatGptService(chat_gpt_model=settings.GPT_MODEL)
chat_gpt_service = ChatGptService.build()
logger.warning("question asked", user=update.message.from_user, question=update.message.text)
answer = await chat_gpt_service.request_to_chatgpt(question=update.message.text)
await update.message.reply_text(answer)
@@ -87,9 +87,9 @@ async def voice_recognize(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
logger.info("file has been saved", filename=tmpfile.name)
speech_to_text_service = SpeechToTextService(filename=tmpfile.name)
speech_to_text_service = SpeechToTextService()
speech_to_text_service.get_text_from_audio()
speech_to_text_service.get_text_from_audio(filename=tmpfile.name)
part = 0
while speech_to_text_service.text_parts or not speech_to_text_service.text_recognised:

View File

@@ -0,0 +1,12 @@
from sqlalchemy import INTEGER, SMALLINT, VARCHAR
from sqlalchemy.orm import Mapped, mapped_column
from infra.database.base import Base
__slots__ = ("ChatGpt",)
class ChatGpt(Base):
id: Mapped[int] = mapped_column("id", INTEGER(), primary_key=True, autoincrement=True)
model: Mapped[str] = mapped_column("model", VARCHAR(length=256), nullable=False, unique=True)
priority: Mapped[int] = mapped_column("priority", SMALLINT(), default=0)

View File

@@ -0,0 +1,106 @@
import random
from dataclasses import dataclass
from typing import Any, Sequence
from uuid import uuid4
import httpx
from httpx import AsyncClient, AsyncHTTPTransport, Response
from loguru import logger
from sqlalchemy import delete, desc, select, update
from sqlalchemy.dialects.sqlite import insert
from constants import CHAT_GPT_BASE_URI, INVALID_GPT_REQUEST_MESSAGES
from core.bot.models.chat_gpt import ChatGpt
from infra.database.db_adapter import Database
from settings.config import AppSettings
@dataclass
class ChatGPTRepository:
    """Persistence and transport layer for chat models.

    Reads and writes ``ChatGpt`` rows through ``db`` and sends questions to the
    chat microservice located at ``settings.GPT_BASE_HOST``.
    """

    settings: AppSettings
    db: Database

    async def get_chatgpt_models(self) -> Sequence[ChatGpt]:
        """Return every stored model ordered by descending priority."""
        query = select(ChatGpt).order_by(desc(ChatGpt.priority))

        async with self.db.session() as session:
            result = await session.execute(query)
            return result.scalars().all()

    async def change_chatgpt_model_priority(self, model_id: int, priority: int) -> None:
        """Reset the current top model to priority 0 and assign ``priority`` to ``model_id``.

        Both updates are executed inside one transaction session.

        :param model_id: primary key of the model that receives the new priority.
        :param priority: priority value to store for that model.
        """
        current_model = await self.get_current_chatgpt_model()

        reset_priority_query = update(ChatGpt).values(priority=0).filter(ChatGpt.model == current_model)
        # ``model_id`` is the primary key, so it has to be matched against ``ChatGpt.id``;
        # comparing it with the ``model`` name column never selects the intended row.
        set_new_priority_query = update(ChatGpt).values(priority=priority).filter(ChatGpt.id == model_id)

        async with self.db.get_transaction_session() as session:
            await session.execute(reset_priority_query)
            await session.execute(set_new_priority_query)

    async def add_chatgpt_model(self, model: str, priority: int) -> dict[str, str | int]:
        """Insert a model row (``INSERT OR IGNORE``) and return the submitted values.

        :param model: model name to store.
        :param priority: priority to store for the model.
        :return: dict with the ``model`` and ``priority`` that were passed in.
        """
        query = (
            insert(ChatGpt)
            .values(
                {ChatGpt.model: model, ChatGpt.priority: priority},
            )
            .prefix_with("OR IGNORE")
        )
        async with self.db.session() as session:
            await session.execute(query)
            await session.commit()
            return {"model": model, "priority": priority}

    async def delete_chatgpt_model(self, model_id: int) -> None:
        """Delete the model row whose primary key equals ``model_id``."""
        query = delete(ChatGpt).filter_by(id=model_id)

        async with self.db.session() as session:
            await session.execute(query)

    async def get_current_chatgpt_model(self) -> str:
        """Return the name of the model with the highest priority.

        ``scalar_one`` is used, so an empty table raises instead of returning ``None``.
        """
        query = select(ChatGpt.model).order_by(desc(ChatGpt.priority)).limit(1)

        async with self.db.session() as session:
            result = await session.execute(query)
            return result.scalar_one()

    async def ask_question(self, question: str, chat_gpt_model: str) -> str:
        """Send ``question`` to the chat microservice and return a user-facing answer.

        Any exception is logged and converted into a fallback message, so this
        method always returns a string.

        :param question: text of the user's question.
        :param chat_gpt_model: model name forwarded to the microservice.
        :return: response text, or an error message for known failure cases.
        """
        try:
            response = await self.request_to_chatgpt_microservice(question=question, chat_gpt_model=chat_gpt_model)
            status = response.status_code
            # Known error markers are checked before the status code.
            for message in INVALID_GPT_REQUEST_MESSAGES:
                if message in response.text:
                    message = f"{message}: {chat_gpt_model}"
                    logger.info(message, question=question, chat_gpt_model=chat_gpt_model)
                    return message
            if status != httpx.codes.OK:
                # Pass the body as a keyword so it is kept as log context; a positional
                # argument is only consumed by ``{}`` placeholders, which this message lacks.
                logger.info(f"got response status: {status} from chat api", response_text=response.text)
                return "Что-то пошло не так, попробуйте еще раз или обратитесь к администратору"
            return response.text
        except Exception as error:
            logger.error("error get data from chat api", error=error)
            return "Вообще всё сломалось :("

    async def request_to_chatgpt_microservice(self, question: str, chat_gpt_model: str) -> Response:
        """POST the question payload to the chat microservice and return the raw response.

        :param question: text of the user's question.
        :param chat_gpt_model: model name placed into the request payload.
        """
        data = self._build_request_data(question=question, chat_gpt_model=chat_gpt_model)

        transport = AsyncHTTPTransport(retries=3)
        async with AsyncClient(base_url=self.settings.GPT_BASE_HOST, transport=transport, timeout=50) as client:
            return await client.post(CHAT_GPT_BASE_URI, json=data, timeout=50)

    @staticmethod
    def _build_request_data(*, question: str, chat_gpt_model: str) -> dict[str, Any]:
        """Build the JSON payload for a single-question conversation.

        A fresh ``conversation_id`` (UUID4) and a random 19-digit ``meta.id`` are
        generated on every call.
        """
        return {
            "conversation_id": str(uuid4()),
            "action": "_ask",
            "model": chat_gpt_model,
            "jailbreak": "default",
            "meta": {
                "id": random.randint(10**18, 10**19 - 1),  # noqa: S311
                "content": {
                    "conversation": [],
                    "internet_access": False,
                    "content_type": "text",
                    "parts": [{"content": question, "role": "user"}],
                },
            },
        }

View File

@@ -1,12 +1,10 @@
import os
import random
import subprocess # noqa
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Any
from uuid import uuid4
from dataclasses import dataclass
from typing import Any, Sequence
import httpx
from httpx import AsyncClient, AsyncHTTPTransport, Response
from httpx import Response
from loguru import logger
from pydub import AudioSegment
from speech_recognition import (
@@ -15,33 +13,30 @@ from speech_recognition import (
UnknownValueError as SpeechRecognizerError,
)
from constants import (
AUDIO_SEGMENT_DURATION,
CHAT_GPT_BASE_URI,
INVALID_GPT_REQUEST_MESSAGES,
)
from constants import AUDIO_SEGMENT_DURATION
from core.bot.models.chat_gpt import ChatGpt
from core.bot.repository import ChatGPTRepository
from infra.database.db_adapter import Database
from settings.config import settings
class SpeechToTextService:
def __init__(self, filename: str) -> None:
def __init__(self) -> None:
self.executor = ThreadPoolExecutor()
self.filename = filename
self.recognizer = Recognizer()
self.recognizer.energy_threshold = 50
self.text_parts: dict[int, str] = {}
self.text_recognised = False
def get_text_from_audio(self) -> None:
self.executor.submit(self.worker)
def get_text_from_audio(self, filename: str) -> None:
self.executor.submit(self.worker, filename=filename)
def worker(self) -> Any:
self._convert_file_to_wav()
self._convert_audio_to_text()
def worker(self, filename: str) -> Any:
self._convert_file_to_wav(filename)
self._convert_audio_to_text(filename)
def _convert_audio_to_text(self) -> None:
wav_filename = f"{self.filename}.wav"
def _convert_audio_to_text(self, filename: str) -> None:
wav_filename = f"{filename}.wav"
speech = AudioSegment.from_wav(wav_filename)
speech_duration = len(speech)
@@ -63,18 +58,19 @@ class SpeechToTextService:
# clean temp voice message main files
try:
os.remove(wav_filename)
os.remove(self.filename)
os.remove(filename)
except FileNotFoundError as error:
logger.error("error temps files not deleted", error=error, filenames=[self.filename, self.filename])
logger.error("error temps files not deleted", error=error, filenames=[filename, wav_filename])
def _convert_file_to_wav(self) -> None:
new_filename = self.filename + ".wav"
cmd = ["ffmpeg", "-loglevel", "quiet", "-i", self.filename, "-vn", new_filename]
@staticmethod
def _convert_file_to_wav(filename: str) -> None:
new_filename = filename + ".wav"
cmd = ["ffmpeg", "-loglevel", "quiet", "-i", filename, "-vn", new_filename]
try:
subprocess.run(args=cmd) # noqa: S603
logger.info("file has been converted to wav", filename=new_filename)
except Exception as error:
logger.error("cant convert voice", error=error, filename=self.filename)
logger.error("cant convert voice", error=error, filename=filename)
def _recognize_by_google(self, filename: str, sound_segment: AudioSegment) -> str:
tmp_filename = f"{filename}_tmp_part"
@@ -91,48 +87,36 @@ class SpeechToTextService:
raise error
@dataclass
class ChatGptService:
def __init__(self, chat_gpt_model: str) -> None:
self.chat_gpt_model = chat_gpt_model
repository: ChatGPTRepository
async def get_chatgpt_models(self) -> Sequence[ChatGpt]:
return await self.repository.get_chatgpt_models()
async def request_to_chatgpt(self, question: str | None) -> str:
question = question or "Привет!"
chat_gpt_request = self.build_request_data(question)
try:
response = await self.do_request(chat_gpt_request)
status = response.status_code
for message in INVALID_GPT_REQUEST_MESSAGES:
if message in response.text:
message = f"{message}: {settings.GPT_MODEL}"
logger.info(message, data=chat_gpt_request)
return message
if status != httpx.codes.OK:
logger.info(f"got response status: {status} from chat api", data=chat_gpt_request)
return "Что-то пошло не так, попробуйте еще раз или обратитесь к администратору"
return response.text
except Exception as error:
logger.error("error get data from chat api", error=error)
return "Вообще всё сломалось :("
chat_gpt_model = await self.get_current_chatgpt_model()
return await self.repository.ask_question(question=question, chat_gpt_model=chat_gpt_model)
@staticmethod
async def do_request(data: dict[str, Any]) -> Response:
transport = AsyncHTTPTransport(retries=3)
async with AsyncClient(base_url=settings.GPT_BASE_HOST, transport=transport, timeout=50) as client:
return await client.post(CHAT_GPT_BASE_URI, json=data, timeout=50)
async def request_to_chatgpt_microservice(self, question: str) -> Response:
chat_gpt_model = await self.get_current_chatgpt_model()
return await self.repository.request_to_chatgpt_microservice(question=question, chat_gpt_model=chat_gpt_model)
def build_request_data(self, question: str) -> dict[str, Any]:
return {
"conversation_id": str(uuid4()),
"action": "_ask",
"model": self.chat_gpt_model,
"jailbreak": "default",
"meta": {
"id": random.randint(10**18, 10**19 - 1), # noqa: S311
"content": {
"conversation": [],
"internet_access": False,
"content_type": "text",
"parts": [{"content": question, "role": "user"}],
},
},
}
async def get_current_chatgpt_model(self) -> str:
return await self.repository.get_current_chatgpt_model()
async def change_chatgpt_model_priority(self, model_id: int, priority: int) -> None:
return await self.repository.change_chatgpt_model_priority(model_id=model_id, priority=priority)
async def add_chatgpt_model(self, gpt_model: str, priority: int) -> dict[str, str | int]:
return await self.repository.add_chatgpt_model(model=gpt_model, priority=priority)
async def delete_chatgpt_model(self, model_id: int) -> None:
return await self.repository.delete_chatgpt_model(model_id=model_id)
@classmethod
def build(cls) -> "ChatGptService":
db = Database(settings=settings)
repository = ChatGPTRepository(settings=settings, db=db)
return ChatGptService(repository=repository)

View File

@@ -0,0 +1,73 @@
from asyncio import current_task
from typing import Awaitable, Callable
from fastapi import FastAPI
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_scoped_session,
async_sessionmaker,
create_async_engine,
)
from settings.config import AppSettings
def startup(app: FastAPI, settings: AppSettings) -> Callable[[], Awaitable[None]]:
"""
Actions to run on application startup.
This function use fastAPI app to store data,
such as db_engine.
:param app: the fastAPI application.
:param settings: app settings
:return: function that actually performs actions.
"""
async def _startup() -> None:
_setup_db(app, settings)
return _startup
def shutdown(app: FastAPI) -> Callable[[], Awaitable[None]]:
"""
Actions to run on application's shutdown.
:param app: fastAPI application.
:return: function that actually performs actions.
"""
async def _shutdown() -> None:
await app.state.db_engine.dispose()
return _shutdown
def _setup_db(app: FastAPI, settings: AppSettings) -> None:
"""
Create connection to the database.
This function creates SQLAlchemy engine instance,
session_factory for creating sessions
and stores them in the application's state property.
:param app: fastAPI application.
"""
engine = create_async_engine(
str(settings.db_url),
echo=settings.DB_ECHO,
execution_options={"isolation_level": "AUTOCOMMIT"},
)
session_factory = async_scoped_session(
async_sessionmaker(
engine,
expire_on_commit=False,
class_=AsyncSession,
),
scopefunc=current_task,
)
app.state.db_engine = engine
app.state.db_session_factory = session_factory

View File

@@ -1,5 +1,6 @@
from datetime import datetime, timedelta
from functools import lru_cache, wraps
from inspect import cleandoc
from typing import Any
@@ -22,3 +23,9 @@ def timed_cache(**timedelta_kwargs: Any) -> Any:
return _wrapped
return _wrapper
def clean_doc(cls: Any) -> str | None:
if cls.__doc__ is None:
return None
return cleandoc(cls.__doc__)

View File

@@ -0,0 +1,34 @@
from sqlalchemy import Table, inspect
from sqlalchemy.orm import as_declarative, declared_attr
from infra.database.meta import meta
@as_declarative(metadata=meta)
class Base:
"""
Base for all models.
It has some type definitions to
enhance autocompletion.
"""
# Generate __tablename__ automatically
@declared_attr
def __tablename__(self) -> str:
return self.__name__.lower()
__table__: Table
@classmethod
def get_real_column_name(cls, attr_name: str) -> str:
return getattr(inspect(cls).c, attr_name).name # type: ignore
def __str__(self) -> str:
return self.__repr__()
def __repr__(self) -> str:
try:
return f"{self.__class__.__name__}(id={self.id})" # type: ignore[attr-defined]
except AttributeError:
return super().__repr__()

View File

@@ -0,0 +1,102 @@
import os
import pkgutil
from asyncio import current_task
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncGenerator
from loguru import logger
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_scoped_session,
async_sessionmaker,
create_async_engine,
)
from settings.config import AppSettings
class Database:
    """Owns the async SQLAlchemy engine and a task-scoped session factory built from settings."""

    def __init__(self, settings: AppSettings) -> None:
        """Create the engine and the scoped session factory.

        :param settings: application settings; ``db_url``, ``DB_ECHO`` and ``DB_FILE`` are read.
        """
        self.db_connect_url = settings.db_url
        self.echo_logs = settings.DB_ECHO
        self.db_file = settings.DB_FILE
        self._engine: AsyncEngine = create_async_engine(
            str(settings.db_url),
            echo=settings.DB_ECHO,
            execution_options={"isolation_level": "AUTOCOMMIT"},
        )
        # Sessions are scoped to the currently running asyncio task.
        self._async_session_factory = async_scoped_session(
            async_sessionmaker(
                autoflush=False,
                class_=AsyncSession,
                expire_on_commit=False,
                bind=self._engine,
            ),
            scopefunc=current_task,
        )

    @asynccontextmanager
    async def session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield a session; roll it back and re-raise if the caller's block fails."""
        session: AsyncSession = self._async_session_factory()

        async with session:
            try:
                yield session
            except Exception:
                await session.rollback()
                raise

    @asynccontextmanager
    async def get_transaction_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield a session inside ``session.begin()``; roll back and re-raise on failure."""
        async with self._async_session_factory() as session, session.begin():
            try:
                yield session
            except Exception:
                await session.rollback()
                # Bare raise keeps the original exception and traceback untouched.
                raise

    async def create_database(self) -> None:
        """Create all tables from the shared metadata when the database file is missing.

        Nothing happens if ``self.db_file`` already exists. Failures are logged
        and not re-raised.
        """
        if not self.db_file.exists():
            # Imported lazily, at the point of use.
            from infra.database.meta import meta

            # Models must be imported so their tables are registered on ``meta``.
            load_all_models()
            try:
                async with self._engine.begin() as connection:
                    await connection.run_sync(meta.create_all)

                logger.info("all migrations are applied")
            except Exception as err:
                logger.error("Cant run migrations", err=err)

    async def drop_database(self) -> None:
        """Delete the database file at ``self.db_file`` if it exists."""
        if self.db_file.exists():
            os.remove(self.db_file)
def load_all_models() -> None:
"""Load all models from this folder."""
package_dir = Path(__file__).resolve().parent.parent
package_dir = package_dir.joinpath("core")
modules = pkgutil.walk_packages(path=[str(package_dir)], prefix="core.")
models_packages = [module for module in modules if module.ispkg and "models" in module.name]
for module in models_packages:
model_pkgs = pkgutil.walk_packages(
path=[os.path.join(str(module.module_finder.path), "models")], prefix=f"{module.name}." # type: ignore
)
for model_pkg in model_pkgs:
__import__(model_pkg.name)

View File

@@ -0,0 +1,20 @@
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.requests import Request
async def get_db_session(request: Request) -> AsyncGenerator[AsyncSession, None]:
"""
Create and get database session.
:param request: current request.
:yield: database session.
"""
session: AsyncSession = request.app.state.db_session_factory()
try:
yield session
finally:
await session.commit()
await session.close()

View File

@@ -0,0 +1,3 @@
from sqlalchemy import MetaData
meta = MetaData()

View File

@@ -0,0 +1 @@
"""Alembic migraions."""

View File

@@ -0,0 +1,81 @@
import asyncio
from logging.config import fileConfig
from alembic import context
from sqlalchemy.ext.asyncio.engine import create_async_engine
from sqlalchemy.future import Connection
from infra.database.db_adapter import load_all_models
from infra.database.meta import meta
from settings.config import settings
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# for 'autogenerate' support from myapp import mymodel
load_all_models()
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
target_metadata = meta
async def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
context.configure(
url=str(settings.db_url),
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
"""
Run actual sync migrations.
:param connection: connection to the database.
"""
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
async def run_migrations_online() -> None:
"""
Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = create_async_engine(str(settings.db_url))
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
if context.is_offline_mode():
asyncio.run(run_migrations_offline())
else:
asyncio.run(run_migrations_online())

View File

@@ -0,0 +1,24 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,34 @@
"""initial commit
Revision ID: eb78565abec7
Revises:
Create Date: 2023-10-05 18:28:30.915361
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "eb78565abec7"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"chatgpt",
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column("model", sa.VARCHAR(length=256), nullable=False),
sa.Column("priority", sa.SMALLINT(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("model"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("chatgpt")
# ### end Alembic commands ###

View File

@@ -0,0 +1,49 @@
"""create chat gpt models
Revision ID: c2e443941930
Revises: eb78565abec7
Create Date: 2025-10-05 20:44:05.414977
"""
from sqlalchemy import create_engine, select
from sqlalchemy.orm import sessionmaker
from constants import ChatGptModelsEnum
from core.bot.models.chat_gpt import ChatGpt
from settings.config import settings
# revision identifiers, used by Alembic.
revision = "c2e443941930"
down_revision = "eb78565abec7"
branch_labels: str | None = None
depends_on: str | None = None
engine = create_engine(str(settings.db_url), echo=settings.DB_ECHO)
session_factory = sessionmaker(engine)
def upgrade() -> None:
with session_factory() as session:
query = select(ChatGpt)
results = session.execute(query)
models = results.scalars().all()
if models:
return None
models = []
for model in ChatGptModelsEnum:
priority = 0 if model != "gpt-3.5-turbo-stream-FreeGpt" else 1
fields = {"model": model, "priority": priority}
models.append(ChatGpt(**fields))
session.add_all(models)
session.commit()
def downgrade() -> None:
    """Remove every row from the ChatGpt table.

    SQLite has no ``TRUNCATE TABLE`` statement and SQLAlchemy 2.x does not execute
    plain strings, so a ``DELETE`` construct is issued instead.
    """
    from sqlalchemy import delete  # local import: only needed when downgrading

    with session_factory() as session:
        session.execute(delete(ChatGpt))
        session.commit()
engine.dispose()

View File

@@ -17,6 +17,56 @@ else:
Record = dict[str, Any]
class Formatter:
    """Loguru format callables and the secret-masking helper they share."""

    @staticmethod
    def json_formatter(record: Record) -> str:
        """Return the masked, stripped message as a format template for JSON logs.

        :param record: loguru record; only its ``message`` entry is read.
        :return: the message with the Telegram token masked and literal braces
            doubled, so the result is safe to use as a format template.
        """
        # Strip the trailing `\n` from the log line: JSON output needs no line breaks.
        message = Formatter.scrap_sensitive_info(record.get("message", "").strip())
        # The returned string is treated as a format template (see the warning in
        # text_formatter), so braces coming from the message itself must be escaped,
        # otherwise they are parsed as replacement fields.
        return message.replace("{", "{{").replace("}", "}}")

    @staticmethod
    def sentry_formatter(record: Record) -> str:
        """Mask the token inside ``record["message"]`` and return the Sentry template.

        :param record: loguru record; its ``message`` entry is rewritten in place.
        :return: the constant template ``"{name}:{function} {message}"``.
        """
        if message := record.get("message", ""):
            record["message"] = Formatter.scrap_sensitive_info(message)
        return "{name}:{function} {message}"

    @staticmethod
    def text_formatter(record: Record) -> str:
        """Mask the token inside ``record["message"]`` and build the text template.

        :param record: loguru record; ``message`` is rewritten in place, ``extra``
            and ``exception`` decide which optional template parts are appended.
        :return: a format template ending with a newline.
        """
        # WARNING !!!
        # This function must return a string that contains only formatting templates.
        # If values from the record (or from anywhere else) are put into the string,
        # loguru may treat them as f-strings and try to process them, which causes an error.
        # For example, to show a value from the `extra` field, do not put the value itself
        # into the format string; put a substring of the form {extra[key_goes_here]} instead.
        if message := record.get("message", ""):
            record["message"] = Formatter.scrap_sensitive_info(message)

        # Standard loguru format. It is set through the LOGURU_FORMAT env variable.
        format_ = (
            "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
            "<level>{level: <8}</level> | "
            "<magenta>{name}</magenta>:<magenta>{function}</magenta>:<magenta>{line}</magenta> - "
            "<level>{message}</level>"
        )

        # Append meta parameters such as user_id, art_id that are passed through logger.bind(...)
        extra = record["extra"]
        if extra:
            # Only the keys are needed: each one becomes a `key={extra[key]}` placeholder.
            formatted = ", ".join(f"{key}={{extra[{key}]}}" for key in extra)
            format_ += f" - <cyan>{formatted}</cyan>"

        format_ += "\n"

        if record["exception"] is not None:
            format_ += "{exception}\n"

        return format_

    @staticmethod
    def scrap_sensitive_info(message: str) -> str:
        """Replace every occurrence of the Telegram API token in ``message`` with a mask.

        :param message: text that may contain the token.
        :return: the text with the token replaced by ``***TELEGRAM_API_TOKEN***``;
            returned unchanged when no token is configured.
        """
        token = settings.TELEGRAM_API_TOKEN
        # str.replace("", mask) would insert the mask between every character,
        # so an empty token must leave the message untouched.
        if not token:
            return message
        return message.replace(token, "TELEGRAM_API_TOKEN".center(24, "*"))
class InterceptHandler(logging.Handler):
def emit(self, record: logging.LogRecord) -> None:
# Get corresponding Loguru level if it exists
@@ -31,13 +81,9 @@ class InterceptHandler(logging.Handler):
frame = cast(FrameType, frame.f_back)
depth += 1
logger.opt(depth=depth, exception=record.exc_info).log(level, self._scrap_sensitive_info(record))
@staticmethod
def _scrap_sensitive_info(record: logging.LogRecord) -> str:
message = record.getMessage()
message.replace(settings.TELEGRAM_API_TOKEN, "TELEGRAM_API_TOKEN".center(24, "*"))
return message
logger.opt(depth=depth, exception=record.exc_info).log(
level, Formatter.scrap_sensitive_info(record.getMessage())
)
def configure_logging(
@@ -45,7 +91,7 @@ def configure_logging(
) -> None:
intercept_handler = InterceptHandler()
formatter = _json_formatter if enable_json_logs else _text_formatter
formatter = Formatter.json_formatter if enable_json_logs else Formatter.text_formatter
base_config_handlers = [intercept_handler]
@@ -64,10 +110,10 @@ def configure_logging(
base_config_handlers.append(graylog_handler)
loguru_handlers.append({**base_loguru_handler, "sink": graylog_handler})
if log_to_file:
file_path = os.path.join(DIR_LOGS, log_to_file)
file_path = DIR_LOGS / log_to_file
if not os.path.exists(log_to_file):
with open(file_path, 'w') as f:
f.write('')
with open(file_path, "w") as f:
f.write("")
loguru_handlers.append({**base_loguru_handler, "sink": file_path})
logging.basicConfig(handlers=base_config_handlers, level=level.name)
@@ -78,42 +124,4 @@ def configure_logging(
# https://forum.sentry.io/t/changing-issue-title-when-logging-with-traceback/446
if enable_sentry_logs:
handler = EventHandler(level=logging.WARNING)
logger.add(handler, diagnose=True, level=logging.WARNING, format=_sentry_formatter)
def _json_formatter(record: Record) -> str:
    """Return the stripped log message as a format template for JSON logs.

    :param record: loguru record; only its ``message`` entry is read.
    :return: the message without surrounding whitespace and with literal
        braces doubled, so it is safe to use as a format template.
    """
    # Strip the trailing `\n` from the log line: JSON output needs no line breaks.
    message = record.get("message", "").strip()
    # The returned string is treated as a format template, so braces that come
    # from the message itself must be escaped.
    return message.replace("{", "{{").replace("}", "}}")
def _sentry_formatter(record: Record) -> str:
    """Return the constant format template used for Sentry events.

    :param record: loguru record (not inspected).
    :return: the template ``"{name}:{function} {message}"``.
    """
    sentry_template = "{name}:{function} {message}"
    return sentry_template
def _text_formatter(record: Record) -> str:
    """Build the loguru format template for plain-text logs.

    :param record: loguru record; ``extra`` and ``exception`` decide which
        optional template parts are appended.
    :return: a format template ending with a newline.
    """
    # WARNING !!!
    # This function must return a string that contains only formatting templates.
    # If values from the record (or from anywhere else) are put into the string,
    # loguru may treat them as f-strings and try to process them, which causes an error.
    # For example, to show a value from the `extra` field, do not put the value itself
    # into the format string; put a substring of the form {extra[key_goes_here]} instead.

    # Standard loguru format. It is set through the LOGURU_FORMAT env variable.
    template_parts = [
        "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | ",
        "<level>{level: <8}</level> | ",
        "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
    ]

    # Append meta parameters such as user_id, art_id that are passed through logger.bind(...)
    bound = record["extra"]
    if bound:
        placeholders = ", ".join(f"{key}={{extra[{key}]}}" for key in bound)
        template_parts.append(f" - <cyan>{placeholders}</cyan>")

    template_parts.append("\n")

    if record["exception"] is not None:
        template_parts.append("{exception}\n")

    return "".join(template_parts)
logger.add(handler, diagnose=True, level=logging.WARNING, format=Formatter.sentry_formatter)

View File

@@ -8,6 +8,7 @@ from fastapi.responses import UJSONResponse
from constants import LogLevelEnum
from core.bot.app import BotApplication, BotQueue
from core.bot.handlers import bot_event_handlers
from core.lifetime import shutdown, startup
from infra.logging_conf import configure_logging
from routers import api_router
from settings.config import AppSettings, get_settings
@@ -26,8 +27,12 @@ class Application:
)
self.app.state.settings = settings
self.app.state.queue = BotQueue(bot_app=bot_app)
self.app.state.bot_app = bot_app
self.bot_app = bot_app
self.app.on_event("startup")(startup(self.app, settings))
self.app.on_event("shutdown")(shutdown(self.app))
self.app.include_router(api_router)
self.configure_hooks()
configure_logging(
@@ -51,18 +56,18 @@ class Application:
# Register the bot's startup and shutdown callbacks on self.app.
def configure_hooks(self) -> None:
# The startup callback depends on self.bot_app.start_with_webhook:
# a webhook start-up coroutine when it is set, self.bot_app.polling otherwise.
if self.bot_app.start_with_webhook:
# NOTE(review): both `_on_start_up` and `_bot_start_up` are registered here; this looks like the old
# and new line of a rename shown together - confirm which one exists in the live file.
self.app.add_event_handler("startup", self._on_start_up)
self.app.add_event_handler("startup", self._bot_start_up)
else:
self.app.add_event_handler("startup", self.bot_app.polling)
# NOTE(review): same pairing for the shutdown callback (`_on_shutdown` / `_bot_shutdown`).
self.app.add_event_handler("shutdown", self._on_shutdown)
self.app.add_event_handler("shutdown", self._bot_shutdown)
# Startup callback: sets the bot webhook, then schedules the update-queue consumer as a task.
# NOTE(review): two `async def` headers appear back to back (`_on_start_up`, `_bot_start_up`); this looks
# like the old and new name of the same coroutine - confirm which one the live file keeps.
async def _on_start_up(self) -> None:
async def _bot_start_up(self) -> None:
await self.bot_app.set_webhook()
# The consumer is scheduled with create_task and is not awaited here.
loop = asyncio.get_event_loop()
loop.create_task(self.app.state.queue.get_updates_from_queue())
# Shutdown callback: runs the webhook deletion and the bot application's shutdown concurrently.
# NOTE(review): two `async def` headers appear back to back (`_on_shutdown`, `_bot_shutdown`); this looks
# like the old and new name of the same coroutine - confirm which one the live file keeps.
async def _on_shutdown(self) -> None:
async def _bot_shutdown(self) -> None:
await asyncio.gather(self.bot_app.delete_webhook(), self.bot_app.shutdown())

View File

@@ -6,6 +6,7 @@ from typing import Any
from dotenv import load_dotenv
from pydantic import model_validator
from pydantic_settings import BaseSettings
from yarl import URL
from constants import API_PREFIX
@@ -52,6 +53,9 @@ class AppSettings(SentrySettings, BaseSettings):
DOMAIN: str = "https://localhost"
URL_PREFIX: str = ""
DB_FILE: Path = SHARED_DIR / "chat_gpt.db"
DB_ECHO: bool = False
# ==== gpt settings ====
GPT_MODEL: str = "gpt-3.5-turbo-stream-DeepAi"
GPT_BASE_HOST: str = "http://chat_service:8858"
@@ -91,6 +95,18 @@ class AppSettings(SentrySettings, BaseSettings):
def bot_webhook_url(self) -> str:
    """Return ``api_prefix`` and ``token_part`` joined with a single ``/``."""
    segments = (self.api_prefix, self.token_part)
    return "/".join(segments)
@cached_property
def db_url(self) -> URL:
    """
    Build the database URL from the configured database file.

    The URL uses the ``sqlite+aiosqlite`` scheme and the path ``///<DB_FILE>``.

    :return: database URL.
    """
    sqlite_path = f"///{self.DB_FILE}"
    return URL.build(scheme="sqlite+aiosqlite", path=sqlite_path)
class Config:
case_sensitive = True