mirror of https://github.com/Balshgit/gpt_chat_bot.git (synced 2026-02-03 11:40:39 +03:00)
update chat service (#31)
* rename chatgpt service
* add zeus tool for new provider
* update chat service
* update README.md
@@ -30,7 +30,7 @@ def get_database(settings: AppSettings = Depends(get_settings)) -> Database:
     return Database(settings=settings)


-def get_chat_gpt_repository(
+def get_chatgpt_repository(
     db: Database = Depends(get_database), settings: AppSettings = Depends(get_settings)
 ) -> ChatGPTRepository:
     return ChatGPTRepository(settings=settings, db=db)
@@ -41,6 +41,6 @@ def new_bot_queue(bot_app: BotApplication = Depends(get_bot_app)) -> BotQueue:


 def get_chatgpt_service(
-    chat_gpt_repository: ChatGPTRepository = Depends(get_chat_gpt_repository),
+    chatgpt_repository: ChatGPTRepository = Depends(get_chatgpt_repository),
 ) -> ChatGptService:
-    return ChatGptService(repository=chat_gpt_repository)
+    return ChatGptService(repository=chatgpt_repository)
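For context, the renamed providers keep the usual FastAPI `Depends` chain: database → repository → service. A minimal, self-contained sketch of that wiring, with simplified stand-ins for `Database`, `ChatGPTRepository`, and `ChatGptService` (the real classes also receive an `AppSettings` object):

```python
from dataclasses import dataclass

from fastapi import Depends, FastAPI


# Simplified stand-ins; the real classes also take AppSettings.
@dataclass
class Database:
    dsn: str = "sqlite://"


@dataclass
class ChatGPTRepository:
    db: Database


@dataclass
class ChatGptService:
    repository: ChatGPTRepository


def get_database() -> Database:
    return Database()


def get_chatgpt_repository(db: Database = Depends(get_database)) -> ChatGPTRepository:
    return ChatGPTRepository(db=db)


def get_chatgpt_service(
    chatgpt_repository: ChatGPTRepository = Depends(get_chatgpt_repository),
) -> ChatGptService:
    return ChatGptService(repository=chatgpt_repository)


app = FastAPI()


@app.get("/health")
async def health(service: ChatGptService = Depends(get_chatgpt_service)) -> dict[str, str]:
    # FastAPI resolves the whole provider chain on each request.
    return {"service": type(service).__name__}
```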
@@ -3,7 +3,7 @@ from enum import StrEnum, unique
 AUDIO_SEGMENT_DURATION = 120 * 1000

 API_PREFIX = "/api"
-CHAT_GPT_BASE_URI = "/backend-api/v2/conversation"
+CHATGPT_BASE_URI = "/backend-api/v2/conversation"
 INVALID_GPT_REQUEST_MESSAGES = ("Invalid request model", "return unexpected http status code")

@@ -31,16 +31,12 @@ class LogLevelEnum(StrEnum):
 @unique
 class ChatGptModelsEnum(StrEnum):
     gpt_3_5_turbo_stream_openai = "gpt-3.5-turbo-stream-openai"
-    gpt_3_5_turbo_Aichat = "gpt-3.5-turbo-Aichat"
     gpt_4_ChatgptAi = "gpt-4-ChatgptAi"
-    gpt_3_5_turbo_weWordle = "gpt-3.5-turbo-weWordle"
-    gpt_3_5_turbo_acytoo = "gpt-3.5-turbo-acytoo"
     gpt_3_5_turbo_stream_DeepAi = "gpt-3.5-turbo-stream-DeepAi"
-    gpt_3_5_turbo_stream_H2o = "gpt-3.5-turbo-stream-H2o"
     gpt_3_5_turbo_stream_yqcloud = "gpt-3.5-turbo-stream-yqcloud"
     gpt_OpenAssistant_stream_HuggingChat = "gpt-OpenAssistant-stream-HuggingChat"
     gpt_4_turbo_stream_you = "gpt-4-turbo-stream-you"
     gpt_3_5_turbo_AItianhu = "gpt-3.5-turbo-AItianhu"
     gpt_3_stream_binjie = "gpt-3-stream-binjie"
     gpt_3_5_turbo_stream_CodeLinkAva = "gpt-3.5-turbo-stream-CodeLinkAva"
     gpt_4_stream_ChatBase = "gpt-4-stream-ChatBase"
@@ -48,14 +44,15 @@ class ChatGptModelsEnum(StrEnum):
     gpt_3_5_turbo_16k_stream_Ylokh = "gpt-3.5-turbo-16k-stream-Ylokh"
     gpt_3_5_turbo_stream_Vitalentum = "gpt-3.5-turbo-stream-Vitalentum"
     gpt_3_5_turbo_stream_GptGo = "gpt-3.5-turbo-stream-GptGo"
-    gpt_3_5_turbo_stream_AItianhuSpace = "gpt-3.5-turbo-stream-AItianhuSpace"
     gpt_3_5_turbo_stream_Aibn = "gpt-3.5-turbo-stream-Aibn"
     gpt_3_5_turbo_ChatgptDuo = "gpt-3.5-turbo-ChatgptDuo"
     gpt_3_5_turbo_stream_FreeGpt = "gpt-3.5-turbo-stream-FreeGpt"
-    gpt_3_5_turbo_stream_ChatForAi = "gpt-3.5-turbo-stream-ChatForAi"
     gpt_3_5_turbo_stream_Cromicle = "gpt-3.5-turbo-stream-Cromicle"
     gpt_4_stream_Chatgpt4Online = "gpt-4-stream-Chatgpt4Online"
     gpt_3_5_turbo_stream_gptalk = "gpt-3.5-turbo-stream-gptalk"
+    gpt_3_5_turbo_stream_ChatgptDemo = "gpt-3.5-turbo-stream-ChatgptDemo"
+    gpt_3_5_turbo_stream_H2o = "gpt-3.5-turbo-stream-H2o"
+    gpt_3_5_turbo_stream_gptforlove = "gpt-3.5-turbo-stream-gptforlove"

     @classmethod
     def values(cls) -> set[str]:
@@ -64,9 +61,6 @@ class ChatGptModelsEnum(StrEnum):
     @staticmethod
     def _deprecated() -> set[str]:
         return {
-            "gpt-3.5-turbo-Aichat",
-            "gpt-3.5-turbo-stream-ChatForAi",
-            "gpt-3.5-turbo-stream-AItianhuSpace",
-            "gpt-3.5-turbo-AItianhu",
-            "gpt-3.5-turbo-acytoo",
+            "gpt-3.5-turbo-stream-H2o",
+            "gpt-3.5-turbo-stream-gptforlove",
         }
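The head of `values()` is cut off by the hunk, but the visible pair of methods suggests the pattern: retired providers stay in the enum while `_deprecated()` filters them out of the active set. A plausible, runnable sketch of that idea (the subtraction in `values()` is an assumption, not the repo's verbatim body):

```python
from enum import StrEnum, unique


@unique
class ChatGptModelsEnum(StrEnum):
    gpt_3_5_turbo_stream_DeepAi = "gpt-3.5-turbo-stream-DeepAi"
    gpt_3_5_turbo_stream_H2o = "gpt-3.5-turbo-stream-H2o"

    @classmethod
    def values(cls) -> set[str]:
        # Active models = all members minus the deprecated ones.
        return {member.value for member in cls} - cls._deprecated()

    @staticmethod
    def _deprecated() -> set[str]:
        return {"gpt-3.5-turbo-stream-H2o"}


assert ChatGptModelsEnum.values() == {"gpt-3.5-turbo-stream-DeepAi"}
```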
@@ -32,8 +32,8 @@ async def about_me(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
 async def about_bot(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
     if not update.effective_message:
         return None
-    chat_gpt_service = ChatGptService.build()
-    model = await chat_gpt_service.get_current_chatgpt_model()
+    chatgpt_service = ChatGptService.build()
+    model = await chatgpt_service.get_current_chatgpt_model()
     await update.effective_message.reply_text(
         f"Бот использует бесплатную модель {model} для ответов на вопросы. "
@@ -69,9 +69,9 @@ async def ask_question(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:

     await update.message.reply_text("Пожалуйста подождите, ответ в среднем занимает 10-15 секунд")

-    chat_gpt_service = ChatGptService.build()
+    chatgpt_service = ChatGptService.build()
     logger.warning("question asked", user=update.message.from_user, question=update.message.text)
-    answer = await chat_gpt_service.request_to_chatgpt(question=update.message.text)
+    answer = await chatgpt_service.request_to_chatgpt(question=update.message.text)
     await update.message.reply_text(answer)

@@ -9,7 +9,7 @@ from loguru import logger
 from sqlalchemy import delete, desc, select, update
 from sqlalchemy.dialects.sqlite import insert

-from constants import CHAT_GPT_BASE_URI, INVALID_GPT_REQUEST_MESSAGES
+from constants import CHATGPT_BASE_URI, INVALID_GPT_REQUEST_MESSAGES
 from core.bot.models.chat_gpt import ChatGpt
 from infra.database.db_adapter import Database
 from settings.config import AppSettings
@@ -64,14 +64,14 @@ class ChatGPTRepository:
         result = await session.execute(query)
         return result.scalar_one()

-    async def ask_question(self, question: str, chat_gpt_model: str) -> str:
+    async def ask_question(self, question: str, chatgpt_model: str) -> str:
         try:
-            response = await self.request_to_chatgpt_microservice(question=question, chat_gpt_model=chat_gpt_model)
+            response = await self.request_to_chatgpt_microservice(question=question, chatgpt_model=chatgpt_model)
             status = response.status_code
             for message in INVALID_GPT_REQUEST_MESSAGES:
                 if message in response.text:
-                    message = f"{message}: {chat_gpt_model}"
-                    logger.info(message, question=question, chat_gpt_model=chat_gpt_model)
+                    message = f"{message}: {chatgpt_model}"
+                    logger.info(message, question=question, chatgpt_model=chatgpt_model)
                     return message
             if status != httpx.codes.OK:
                 logger.info(f"got response status: {status} from chat api", response.text)
@@ -81,19 +81,19 @@ class ChatGPTRepository:
             logger.error("error get data from chat api", error=error)
             return "Вообще всё сломалось :("

-    async def request_to_chatgpt_microservice(self, question: str, chat_gpt_model: str) -> Response:
-        data = self._build_request_data(question=question, chat_gpt_model=chat_gpt_model)
+    async def request_to_chatgpt_microservice(self, question: str, chatgpt_model: str) -> Response:
+        data = self._build_request_data(question=question, chatgpt_model=chatgpt_model)

         transport = AsyncHTTPTransport(retries=3)
         async with AsyncClient(base_url=self.settings.GPT_BASE_HOST, transport=transport, timeout=50) as client:
-            return await client.post(CHAT_GPT_BASE_URI, json=data, timeout=50)
+            return await client.post(CHATGPT_BASE_URI, json=data, timeout=50)

     @staticmethod
-    def _build_request_data(*, question: str, chat_gpt_model: str) -> dict[str, Any]:
+    def _build_request_data(*, question: str, chatgpt_model: str) -> dict[str, Any]:
         return {
             "conversation_id": str(uuid4()),
             "action": "_ask",
-            "model": chat_gpt_model,
+            "model": chatgpt_model,
             "jailbreak": "default",
             "meta": {
                 "id": random.randint(10**18, 10**19 - 1),  # noqa: S311
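Pulled together, the renamed repository methods amount to an httpx POST with transport-level retries. A sketch under the assumptions that `GPT_BASE_HOST` and `CHATGPT_BASE_URI` carry the values shown above; the hunk truncates inside `meta` (the question itself presumably travels in the omitted tail), so only the visible fields appear here:

```python
import random
from typing import Any
from uuid import uuid4

from httpx import AsyncClient, AsyncHTTPTransport, Response

GPT_BASE_HOST = "http://chatgpt_chat_service:8858"
CHATGPT_BASE_URI = "/backend-api/v2/conversation"


def build_request_data(*, chatgpt_model: str) -> dict[str, Any]:
    # Mirrors the visible part of _build_request_data; the diff cuts off
    # inside "meta", so only the fields shown in the hunk are reproduced.
    return {
        "conversation_id": str(uuid4()),
        "action": "_ask",
        "model": chatgpt_model,
        "jailbreak": "default",
        "meta": {"id": random.randint(10**18, 10**19 - 1)},  # noqa: S311
    }


async def ask(chatgpt_model: str) -> Response:
    data = build_request_data(chatgpt_model=chatgpt_model)
    transport = AsyncHTTPTransport(retries=3)  # retries failed connection attempts
    async with AsyncClient(base_url=GPT_BASE_HOST, transport=transport, timeout=50) as client:
        return await client.post(CHATGPT_BASE_URI, json=data, timeout=50)


# asyncio.run(ask("gpt-3.5-turbo-stream-DeepAi"))  # needs the microservice reachable
```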
@@ -96,12 +96,12 @@ class ChatGptService:

     async def request_to_chatgpt(self, question: str | None) -> str:
         question = question or "Привет!"
-        chat_gpt_model = await self.get_current_chatgpt_model()
-        return await self.repository.ask_question(question=question, chat_gpt_model=chat_gpt_model)
+        chatgpt_model = await self.get_current_chatgpt_model()
+        return await self.repository.ask_question(question=question, chatgpt_model=chatgpt_model)

     async def request_to_chatgpt_microservice(self, question: str) -> Response:
-        chat_gpt_model = await self.get_current_chatgpt_model()
-        return await self.repository.request_to_chatgpt_microservice(question=question, chat_gpt_model=chat_gpt_model)
+        chatgpt_model = await self.get_current_chatgpt_model()
+        return await self.repository.request_to_chatgpt_microservice(question=question, chatgpt_model=chatgpt_model)

     async def get_current_chatgpt_model(self) -> str:
         return await self.repository.get_current_chatgpt_model()
@@ -1,18 +1,26 @@
 from datetime import datetime, timedelta
-from functools import lru_cache, wraps
+from functools import cache, wraps
 from inspect import cleandoc
-from typing import Any
+from typing import Any, Callable


-def timed_cache(**timedelta_kwargs: Any) -> Any:
-    def _wrapper(func: Any) -> Any:
-        update_delta = timedelta(**timedelta_kwargs)
+def timed_lru_cache(
+    microseconds: int = 0,
+    milliseconds: int = 0,
+    seconds: int = 0,
+    minutes: int = 0,
+    hours: int = 0,
+) -> Any:
+    def _wrapper(func: Any) -> Callable[[Any], Any]:
+        update_delta = timedelta(
+            microseconds=microseconds, milliseconds=milliseconds, seconds=seconds, minutes=minutes, hours=hours
+        )
         next_update = datetime.utcnow() + update_delta
-        # Apply @lru_cache to f with no cache size limit
-        cached_func = lru_cache(None)(func)
+
+        cached_func = cache(func)

         @wraps(func)
-        def _wrapped(*args: Any, **kwargs: Any) -> Any:
+        def _wrapped(*args: Any, **kwargs: Any) -> Callable[[Any], Any]:
             nonlocal next_update
             now = datetime.utcnow()
             if now >= next_update:
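The hunk ends inside `_wrapped`, so the expiry branch is not shown. The standard completion of this recipe clears the cache once the deadline passes and schedules the next one; a runnable sketch with that branch filled in as an assumption:

```python
from datetime import datetime, timedelta
from functools import cache, wraps
from typing import Any, Callable


def timed_lru_cache(
    microseconds: int = 0,
    milliseconds: int = 0,
    seconds: int = 0,
    minutes: int = 0,
    hours: int = 0,
) -> Callable[..., Any]:
    def _wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
        update_delta = timedelta(
            microseconds=microseconds, milliseconds=milliseconds, seconds=seconds, minutes=minutes, hours=hours
        )
        next_update = datetime.utcnow() + update_delta

        cached_func = cache(func)  # unbounded cache, like lru_cache(None)

        @wraps(func)
        def _wrapped(*args: Any, **kwargs: Any) -> Any:
            nonlocal next_update
            now = datetime.utcnow()
            if now >= next_update:
                # Assumed completion: drop stale entries, schedule the next expiry.
                cached_func.cache_clear()
                next_update = now + update_delta
            return cached_func(*args, **kwargs)

        return _wrapped

    return _wrapper


@timed_lru_cache(seconds=1)
def slow_square(x: int) -> int:
    return x * x


print(slow_square(4))  # computed once, then cached until the deadline passes
```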
@@ -105,10 +105,11 @@ def configure_logging(
         {**base_loguru_handler, "colorize": True, "sink": sys.stdout},
     ]

-    if settings.GRAYLOG_HOST and settings.GRAYLOG_PORT:
+    if settings.ENABLE_GRAYLOG:
         graylog_handler = graypy.GELFUDPHandler(settings.GRAYLOG_HOST, settings.GRAYLOG_PORT)
         base_config_handlers.append(graylog_handler)
+        loguru_handlers.append({**base_loguru_handler, "sink": graylog_handler})

     if log_to_file:
         file_path = DIR_LOGS / log_to_file
         if not os.path.exists(log_to_file):
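With the new flag, Graylog wiring becomes a single conditional, and the GELF handler now also feeds loguru. A minimal sketch, assuming a settings object exposing the three Graylog fields (loguru accepts stdlib `logging.Handler` objects as sinks):

```python
import logging

import graypy
from loguru import logger


class Settings:
    # Stand-in for the pydantic settings object.
    ENABLE_GRAYLOG = True
    GRAYLOG_HOST = "graylog.local"
    GRAYLOG_PORT = 12201


settings = Settings()
base_config_handlers: list[logging.Handler] = []

if settings.ENABLE_GRAYLOG:
    # GELF over UDP; the same handler serves stdlib logging and loguru.
    graylog_handler = graypy.GELFUDPHandler(settings.GRAYLOG_HOST, settings.GRAYLOG_PORT)
    base_config_handlers.append(graylog_handler)
    logger.add(graylog_handler)
```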
@@ -42,7 +42,7 @@ class Application:
             log_to_file=settings.LOG_TO_FILE,
         )

-        if settings.SENTRY_DSN is not None:
+        if settings.ENABLE_SENTRY:
            sentry_sdk.init(
                dsn=settings.SENTRY_DSN,
                environment=settings.DEPLOY_ENVIRONMENT,
@@ -10,6 +10,7 @@ RELOAD="true"
 DEBUG="true"

 # ==== sentry ====
+ENABLE_SENTRY="false"
 SENTRY_DSN=
 SENTRY_TRACES_SAMPLE_RATE="0.95"
 DEPLOY_ENVIRONMENT="stage"
@@ -17,8 +18,11 @@ DEPLOY_ENVIRONMENT="stage"
 # ==== logs ====:
 ENABLE_JSON_LOGS="true"
 ENABLE_SENTRY_LOGS="false"
+
+ENABLE_GRAYLOG="false"
 GRAYLOG_HOST=
 GRAYLOG_PORT=
+
 LOG_TO_FILE="example.log"

 # ==== telegram settings ====
@@ -31,7 +35,7 @@ DOMAIN="https://mydomain.com"
 URL_PREFIX="/gpt"

 # ==== gpt settings ====
-GPT_BASE_HOST="http://chat_service:8858"
+GPT_BASE_HOST="http://chatgpt_chat_service:8858"

 # ==== other settings ====
 USER="web"
@@ -29,12 +29,36 @@ load_dotenv(env_path, override=True)


 class SentrySettings(BaseSettings):
+    ENABLE_SENTRY: bool = False
     SENTRY_DSN: str | None = None
     DEPLOY_ENVIRONMENT: str | None = None
     SENTRY_TRACES_SAMPLE_RATE: float = 0.95

+    @model_validator(mode="after")
+    def validate_sentry_enabled(self) -> "SentrySettings":
+        if self.ENABLE_SENTRY and not self.SENTRY_DSN:
+            raise RuntimeError("sentry dsn must be set")
+        return self
+

-class AppSettings(SentrySettings, BaseSettings):
+class LoggingSettings(BaseSettings):
+    ENABLE_JSON_LOGS: bool = True
+    ENABLE_SENTRY_LOGS: bool = False
+
+    ENABLE_GRAYLOG: bool = False
+    GRAYLOG_HOST: str | None = None
+    GRAYLOG_PORT: int | None = None
+
+    LOG_TO_FILE: str | None = None
+
+    @model_validator(mode="after")
+    def validate_graylog_enabled(self) -> "LoggingSettings":
+        if self.ENABLE_GRAYLOG and not all([self.GRAYLOG_HOST, self.GRAYLOG_PORT]):
+            raise RuntimeError("graylog host and port must be set")
+        return self
+
+
+class AppSettings(SentrySettings, LoggingSettings, BaseSettings):
     """Application settings."""

     PROJECT_NAME: str = "chat gpt bot"
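The refactor moves logging fields into their own mixin and makes misconfiguration fail at startup: turning a flag on without its backing fields now raises. A self-contained sketch of the same pydantic v2 pattern:

```python
from pydantic import model_validator
from pydantic_settings import BaseSettings


class LoggingSettings(BaseSettings):
    ENABLE_GRAYLOG: bool = False
    GRAYLOG_HOST: str | None = None
    GRAYLOG_PORT: int | None = None

    @model_validator(mode="after")
    def validate_graylog_enabled(self) -> "LoggingSettings":
        # Fail fast when the flag is on but its backing fields are missing.
        if self.ENABLE_GRAYLOG and not all([self.GRAYLOG_HOST, self.GRAYLOG_PORT]):
            raise RuntimeError("graylog host and port must be set")
        return self


LoggingSettings()  # fine: the flag is off

try:
    LoggingSettings(ENABLE_GRAYLOG=True)  # flag on, host/port missing
except Exception as exc:  # the RuntimeError surfaces at construction time
    print(exc)
```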
@@ -58,13 +82,7 @@ class AppSettings(SentrySettings, BaseSettings):

     # ==== gpt settings ====
     GPT_MODEL: str = "gpt-3.5-turbo-stream-DeepAi"
-    GPT_BASE_HOST: str = "http://chat_service:8858"
-
-    ENABLE_JSON_LOGS: bool = True
-    ENABLE_SENTRY_LOGS: bool = False
-    GRAYLOG_HOST: str | None = None
-    GRAYLOG_PORT: int | None = None
-    LOG_TO_FILE: str | None = None
+    GPT_BASE_HOST: str = "http://chatgpt_chat_service:8858"

     @model_validator(mode="before")  # type: ignore[arg-type]
     def validate_boolean_fields(self) -> Any:
@@ -75,6 +93,8 @@ class AppSettings(SentrySettings, BaseSettings):
             "START_WITH_WEBHOOK",
             "RELOAD",
             "DEBUG",
+            "ENABLE_GRAYLOG",
+            "ENABLE_SENTRY",
         ):
             setting_value: str | None = values_dict.get(value)
             if setting_value and setting_value.lower() == "false":
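The before-mode validator runs on the raw input mapping and normalizes quoted booleans from the .env file before field parsing; the hunk is truncated, but the visible branch maps the string "false" to a real False. A sketch of that mechanism (the documented pydantic v2 form uses an explicit classmethod, and the in-place rewrite is an assumed completion of the truncated body):

```python
from typing import Any

from pydantic import model_validator
from pydantic_settings import BaseSettings


class AppSettings(BaseSettings):
    DEBUG: bool = False
    ENABLE_SENTRY: bool = False

    @model_validator(mode="before")
    @classmethod
    def validate_boolean_fields(cls, values_dict: dict[str, Any]) -> dict[str, Any]:
        # Assumed completion: rewrite "false" strings before field parsing.
        for value in ("DEBUG", "ENABLE_SENTRY"):
            setting_value = values_dict.get(value)
            if isinstance(setting_value, str) and setting_value.lower() == "false":
                values_dict[value] = False
        return values_dict


print(AppSettings(DEBUG="false").DEBUG)  # False
```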
@@ -47,7 +47,7 @@ async def test_get_chatgpt_models(
     )


-async def test_change_chagpt_model_priority(
+async def test_change_chatgpt_model_priority(
     dbsession: Session,
     rest_client: AsyncClient,
     faker: Faker,
@@ -61,10 +61,9 @@ async def test_change_chatgpt_model_priority(
     upd_model1, upd_model2 = dbsession.query(ChatGpt).order_by(ChatGpt.priority).all()

     assert model1.model == upd_model1.model
     assert model1.priority == upd_model1.priority
     assert model2.model == upd_model2.model
-
-    updated_from_db_model = dbsession.get(ChatGpt, model2.id)
-    assert updated_from_db_model.priority == priority  # type: ignore[union-attr]
+    assert upd_model2.priority == priority
+

 async def test_reset_chatgpt_models_priority(
@@ -4,7 +4,7 @@ from typing import Any, Iterator
 import respx
 from httpx import Response

-from constants import CHAT_GPT_BASE_URI
+from constants import CHATGPT_BASE_URI


 @contextmanager
@@ -16,7 +16,7 @@ def mocked_ask_question_api(
         assert_all_called=True,
         base_url=host,
     ) as respx_mock:
-        ask_question_route = respx_mock.post(url=CHAT_GPT_BASE_URI, name="ask_question")
+        ask_question_route = respx_mock.post(url=CHATGPT_BASE_URI, name="ask_question")
         ask_question_route.return_value = return_value
         ask_question_route.side_effect = side_effect
         yield respx_mock
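A usage sketch of the updated fixture: respx intercepts the POST to `CHATGPT_BASE_URI`, so client code runs against a canned `Response`. The helper below is a trimmed version of the one in the hunk (`side_effect` omitted), and the host is a placeholder:

```python
import asyncio
from contextlib import contextmanager
from typing import Iterator

import respx
from httpx import AsyncClient, Response

CHATGPT_BASE_URI = "/backend-api/v2/conversation"


@contextmanager
def mocked_ask_question_api(host: str, return_value: Response) -> Iterator[respx.MockRouter]:
    with respx.mock(assert_all_called=True, base_url=host) as respx_mock:
        ask_question_route = respx_mock.post(url=CHATGPT_BASE_URI, name="ask_question")
        ask_question_route.return_value = return_value
        yield respx_mock


async def main() -> None:
    host = "http://chatgpt_chat_service:8858"
    with mocked_ask_question_api(host, Response(200, text="pong")):
        async with AsyncClient(base_url=host) as client:
            response = await client.post(CHATGPT_BASE_URI, json={"question": "ping"})
            assert response.text == "pong"


asyncio.run(main())
```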