add testing database and chatgpt factories (#28)

* add testing database and chatgpt factories

* include lint job to develop stage

* reformat audioconverter save files to tmp directory

* add api tests

* update README.md
This commit is contained in:
Dmitry Afanasyev 2023-10-08 04:43:24 +03:00 committed by GitHub
parent 23031b0777
commit beb32fb0b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 434 additions and 255 deletions

View File

@ -3,7 +3,7 @@ name: lint
on:
push:
branches-ignore:
- develop
- stage
tags-ignore:
- "*"
pull_request:

View File

@ -3,7 +3,7 @@ name: test
on:
push:
branches-ignore:
- develop
- stage
tags-ignore:
- "*"
pull_request:

View File

@ -134,10 +134,11 @@ alembic --config ./alembic.ini downgrade 389018a3e0f0
- [x] add Database and models
- [x] add alembic migrations
- [] add models priority and their rotation
- [x] add models priority
- [] and models rotation
- [x] add update model priority endpoint
- [] add more tests for gpt model selection
- [x] add more tests for gpt model selection
- [] add authorisation for api
- [] reformat conftest.py file
- [x] reformat conftest.py file
- [x] Add sentry
- [x] Add graylog integration and availability to log to file

View File

@ -33,7 +33,7 @@ async def process_bot_updates(
@router.get(
"/models",
"/chatgpt/models",
name="bot:models_list",
response_class=JSONResponse,
response_model=list[ChatGptModelSerializer],
@ -50,8 +50,8 @@ async def models_list(
)
@router.post(
"/models/{model_id}/priority",
@router.put(
"/chatgpt/models/{model_id}/priority",
name="bot:change_model_priority",
response_class=Response,
status_code=status.HTTP_202_ACCEPTED,
@ -66,8 +66,22 @@ async def change_model_priority(
await chatgpt_service.change_chatgpt_model_priority(model_id=model_id, priority=gpt_model.priority)
@router.put(
"/chatgpt/models/priority/reset",
name="bot:reset_models_priority",
response_class=Response,
status_code=status.HTTP_202_ACCEPTED,
summary="reset all model priority to default",
)
async def reset_models_priority(
chatgpt_service: ChatGptService = Depends(get_chatgpt_service),
) -> None:
"""Сбросить приоритеты у всех моделей на дефолтное значение - 0"""
await chatgpt_service.reset_all_chatgpt_models_priority()
@router.post(
"/models",
"/chatgpt/models",
name="bot:add_new_model",
response_model=ChatGptModelSerializer,
status_code=status.HTTP_201_CREATED,
@ -84,7 +98,7 @@ async def add_new_model(
@router.delete(
"/models/{model_id}",
"/chatgpt/models/{model_id}",
name="bot:delete_gpt_model",
response_class=Response,
status_code=status.HTTP_204_NO_CONTENT,

View File

@ -4,7 +4,7 @@ from telegram import Update
from core.bot.app import BotApplication, BotQueue
from core.bot.repository import ChatGPTRepository
from core.bot.services import ChatGptService, SpeechToTextService
from core.bot.services import ChatGptService
from infra.database.db_adapter import Database
from settings.config import AppSettings
@ -36,10 +36,6 @@ def get_chat_gpt_repository(
return ChatGPTRepository(settings=settings, db=db)
def get_speech_to_text_service() -> SpeechToTextService:
return SpeechToTextService()
def new_bot_queue(bot_app: BotApplication = Depends(get_bot_app)) -> BotQueue:
return BotQueue(bot_app=bot_app)

View File

@ -32,8 +32,10 @@ async def about_me(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
async def about_bot(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
if not update.effective_message:
return None
chat_gpt_service = ChatGptService.build()
model = await chat_gpt_service.get_current_chatgpt_model()
await update.effective_message.reply_text(
f"Бот использует бесплатную модель {settings.GPT_MODEL} для ответов на вопросы. "
f"Бот использует бесплатную модель {model} для ответов на вопросы. "
f"\nПринимает запросы на разных языках.\n\nБот так же умеет переводить русские голосовые сообщения в текст. "
f"Просто пришлите голосовуху и получите поток сознания в виде текста, но без знаков препинания",
parse_mode="Markdown",
@ -87,9 +89,9 @@ async def voice_recognize(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
logger.info("file has been saved", filename=tmpfile.name)
speech_to_text_service = SpeechToTextService()
speech_to_text_service = SpeechToTextService(filename=tmpfile.name)
speech_to_text_service.get_text_from_audio(filename=tmpfile.name)
speech_to_text_service.get_text_from_audio()
part = 0
while speech_to_text_service.text_parts or not speech_to_text_service.text_recognised:

View File

@ -28,14 +28,15 @@ class ChatGPTRepository:
return result.scalars().all()
async def change_chatgpt_model_priority(self, model_id: int, priority: int) -> None:
current_model = await self.get_current_chatgpt_model()
reset_priority_query = update(ChatGpt).values(priority=0).filter(ChatGpt.model == current_model)
set_new_priority_query = update(ChatGpt).values(priority=priority).filter(ChatGpt.model == model_id)
query = update(ChatGpt).values(priority=priority).filter(ChatGpt.id == model_id)
async with self.db.get_transaction_session() as session:
await session.execute(reset_priority_query)
await session.execute(set_new_priority_query)
await session.execute(query)
async def reset_all_chatgpt_models_priority(self) -> None:
query = update(ChatGpt).values(priority=0)
async with self.db.session() as session:
await session.execute(query)
async def add_chatgpt_model(self, model: str, priority: int) -> dict[str, str | int]:
query = (

View File

@ -1,5 +1,6 @@
import os
import subprocess # noqa
import tempfile
from concurrent.futures.thread import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Sequence
@ -21,22 +22,23 @@ from settings.config import settings
class SpeechToTextService:
def __init__(self) -> None:
def __init__(self, filename: str) -> None:
self.filename = filename
self.executor = ThreadPoolExecutor()
self.recognizer = Recognizer()
self.recognizer.energy_threshold = 50
self.text_parts: dict[int, str] = {}
self.text_recognised = False
def get_text_from_audio(self, filename: str) -> None:
self.executor.submit(self.worker, filename=filename)
def get_text_from_audio(self) -> None:
self.executor.submit(self._worker)
def worker(self, filename: str) -> Any:
self._convert_file_to_wav(filename)
self._convert_audio_to_text(filename)
def _worker(self) -> Any:
self._convert_file_to_wav()
self._convert_audio_to_text()
def _convert_audio_to_text(self, filename: str) -> None:
wav_filename = f"{filename}.wav"
def _convert_audio_to_text(self) -> None:
wav_filename = f"{self.filename}.wav"
speech = AudioSegment.from_wav(wav_filename)
speech_duration = len(speech)
@ -51,40 +53,38 @@ class SpeechToTextService:
sound_segment = speech[i * AUDIO_SEGMENT_DURATION - 250 : i * AUDIO_SEGMENT_DURATION + ending]
else:
sound_segment = speech[i * AUDIO_SEGMENT_DURATION - 250 : (i + 1) * AUDIO_SEGMENT_DURATION]
self.text_parts[i] = self._recognize_by_google(wav_filename, sound_segment)
self.text_parts[i] = self._recognize_by_google(sound_segment)
self.text_recognised = True
# clean temp voice message main files
try:
os.remove(wav_filename)
os.remove(filename)
os.remove(self.filename)
except FileNotFoundError as error:
logger.error("error temps files not deleted", error=error, filenames=[filename, wav_filename])
logger.error("error temps files not deleted", error=error, filenames=[self.filename, wav_filename])
@staticmethod
def _convert_file_to_wav(filename: str) -> None:
new_filename = filename + ".wav"
cmd = ["ffmpeg", "-loglevel", "quiet", "-i", filename, "-vn", new_filename]
def _convert_file_to_wav(self) -> None:
new_filename = self.filename + ".wav"
cmd = ["ffmpeg", "-loglevel", "quiet", "-i", self.filename, "-vn", new_filename]
try:
subprocess.run(args=cmd) # noqa: S603
logger.info("file has been converted to wav", filename=new_filename)
except Exception as error:
logger.error("cant convert voice", error=error, filename=filename)
logger.error("cant convert voice", error=error, filename=self.filename)
def _recognize_by_google(self, filename: str, sound_segment: AudioSegment) -> str:
tmp_filename = f"{filename}_tmp_part"
sound_segment.export(tmp_filename, format="wav")
with AudioFile(tmp_filename) as source:
audio_text = self.recognizer.listen(source)
try:
text = self.recognizer.recognize_google(audio_text, language="ru-RU")
os.remove(tmp_filename)
return text
except SpeechRecognizerError as error:
os.remove(tmp_filename)
logger.error("error recognizing text with google", error=error)
raise error
def _recognize_by_google(self, sound_segment: AudioSegment) -> str:
with tempfile.NamedTemporaryFile(delete=True) as tmpfile:
tmpfile.write(sound_segment.raw_data)
sound_segment.export(tmpfile, format="wav")
with AudioFile(tmpfile) as source:
audio_text = self.recognizer.listen(source)
try:
text = self.recognizer.recognize_google(audio_text, language="ru-RU")
return text
except SpeechRecognizerError as error:
logger.error("error recognizing text with google", error=error)
raise error
@dataclass
@ -109,6 +109,9 @@ class ChatGptService:
async def change_chatgpt_model_priority(self, model_id: int, priority: int) -> None:
return await self.repository.change_chatgpt_model_priority(model_id=model_id, priority=priority)
async def reset_all_chatgpt_models_priority(self) -> None:
return await self.repository.reset_all_chatgpt_models_priority()
async def add_chatgpt_model(self, gpt_model: str, priority: int) -> dict[str, str | int]:
return await self.repository.add_chatgpt_model(model=gpt_model, priority=priority)

View File

@ -57,7 +57,7 @@ def _setup_db(app: FastAPI, settings: AppSettings) -> None:
:param app: fastAPI application.
"""
engine = create_async_engine(
str(settings.db_url),
str(settings.async_db_url),
echo=settings.DB_ECHO,
execution_options={"isolation_level": "AUTOCOMMIT"},
)

View File

@ -6,6 +6,7 @@ from pathlib import Path
from typing import AsyncGenerator
from loguru import logger
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
@ -13,17 +14,17 @@ from sqlalchemy.ext.asyncio import (
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import Session, scoped_session, sessionmaker
from settings.config import AppSettings
class Database:
def __init__(self, settings: AppSettings) -> None:
self.db_connect_url = settings.db_url
self.echo_logs = settings.DB_ECHO
self.db_file = settings.DB_FILE
self._engine: AsyncEngine = create_async_engine(
str(settings.db_url),
self.db_file = settings.db_file
self._async_engine: AsyncEngine = create_async_engine(
str(settings.async_db_url),
echo=settings.DB_ECHO,
execution_options={"isolation_level": "AUTOCOMMIT"},
)
@ -32,10 +33,23 @@ class Database:
autoflush=False,
class_=AsyncSession,
expire_on_commit=False,
bind=self._engine,
bind=self._async_engine,
),
scopefunc=current_task,
)
self._sync_engine = create_engine(str(settings.sync_db_url), echo=settings.DB_ECHO)
self._sync_session_factory = scoped_session(sessionmaker(self._sync_engine))
def get_sync_db_session(self) -> Session:
session: Session = self._sync_session_factory()
try:
return session
except Exception as err:
session.rollback()
raise err
finally:
session.commit()
session.close()
@asynccontextmanager
async def session(self) -> AsyncGenerator[AsyncSession, None]:
@ -70,7 +84,7 @@ class Database:
load_all_models()
try:
async with self._engine.begin() as connection:
async with self._async_engine.begin() as connection:
await connection.run_sync(meta.create_all)
logger.info("all migrations are applied")

View File

@ -38,7 +38,7 @@ async def run_migrations_offline() -> None:
"""
context.configure(
url=str(settings.db_url),
url=str(settings.async_db_url),
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
@ -69,7 +69,7 @@ async def run_migrations_online() -> None:
and associate a connection with the context.
"""
connectable = create_async_engine(str(settings.db_url))
connectable = create_async_engine(str(settings.async_db_url))
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)

View File

@ -19,7 +19,7 @@ down_revision = "eb78565abec7"
branch_labels: str | None = None
depends_on: str | None = None
engine = create_engine(str(settings.db_url), echo=settings.DB_ECHO)
engine = create_engine(str(settings.async_db_url), echo=settings.DB_ECHO)
session_factory = sessionmaker(engine)

View File

@ -4,6 +4,8 @@ STAGE="runtests"
APP_HOST="0.0.0.0"
APP_PORT="8000"
DB_NAME="test_chatgpt.db"
# ==== telegram settings ====
TELEGRAM_API_TOKEN="123456789:AABBCCDDEEFFaabbccddeeff-1234567890"
# set to true to start with webhook. Else bot will start on polling method

View File

@ -4,6 +4,8 @@ STAGE="runtests"
APP_HOST="0.0.0.0"
APP_PORT="8000"
DB_NAME="test_chatgpt.db"
# ==== telegram settings ====
TELEGRAM_API_TOKEN="123456789:AABBCCDDEEFFaabbccddeeff-1234567890"
# set to true to start with webhook. Else bot will start on polling method

View File

@ -32,30 +32,7 @@ URL_PREFIX="/gpt"
# ==== gpt settings ====
GPT_BASE_HOST="http://chat_service:8858"
GPT_MODEL="gpt-3.5-turbo-stream-DeepAi"
# ==== other settings ====
USER="web"
TZ="Europe/Moscow"
# "gpt-3.5-turbo-stream-openai"
# "gpt-3.5-turbo-Aichat"
# "gpt-4-ChatgptAi"
# "gpt-3.5-turbo-weWordle"
# "gpt-3.5-turbo-acytoo"
# "gpt-3.5-turbo-stream-DeepAi"
# "gpt-3.5-turbo-stream-H2o"
# "gpt-3.5-turbo-stream-yqcloud"
# "gpt-OpenAssistant-stream-HuggingChat"
# "gpt-4-turbo-stream-you"
# "gpt-3.5-turbo-AItianhu"
# "gpt-3-stream-binjie"
# "gpt-3.5-turbo-stream-CodeLinkAva"
# "gpt-4-stream-ChatBase"
# "gpt-3.5-turbo-stream-aivvm"
# "gpt-3.5-turbo-16k-stream-Ylokh"
# "gpt-3.5-turbo-stream-Vitalentum"
# "gpt-3.5-turbo-stream-GptGo"
# "gpt-3.5-turbo-stream-AItianhuSpace"
# "gpt-3.5-turbo-stream-Aibn"
# "gpt-3.5-turbo-ChatgptDuo"

View File

@ -53,7 +53,7 @@ class AppSettings(SentrySettings, BaseSettings):
DOMAIN: str = "https://localhost"
URL_PREFIX: str = ""
DB_FILE: Path = SHARED_DIR / "chat_gpt.db"
DB_NAME: str = "chatgpt.db"
DB_ECHO: bool = False
# ==== gpt settings ====
@ -96,15 +96,21 @@ class AppSettings(SentrySettings, BaseSettings):
return "/".join([self.api_prefix, self.token_part])
@cached_property
def db_url(self) -> URL:
"""
Assemble database URL from settings.
def db_file(self) -> Path:
return SHARED_DIR / self.DB_NAME
:return: database URL.
"""
@cached_property
def async_db_url(self) -> URL:
return URL.build(
scheme="sqlite+aiosqlite",
path=f"///{self.DB_FILE}",
path=f"///{self.db_file}",
)
@cached_property
def sync_db_url(self) -> URL:
return URL.build(
scheme="sqlite",
path=f"///{self.db_file}",
)
class Config:

View File

@ -1,10 +1,10 @@
from typing import Any, Callable, Optional
from typing import Any
import pytest
from httpx import AsyncClient, Response
from telegram._utils.defaultvalue import DEFAULT_NONE
from telegram._utils.types import ODVInput
from telegram.error import BadRequest, RetryAfter, TimedOut
from telegram.error import RetryAfter, TimedOut
from telegram.request import HTTPXRequest, RequestData
@ -17,7 +17,7 @@ class NonchalantHttpxRequest(HTTPXRequest):
self,
url: str,
method: str,
request_data: Optional[RequestData] = None,
request_data: RequestData | None = None,
read_timeout: ODVInput[float] = DEFAULT_NONE,
write_timeout: ODVInput[float] = DEFAULT_NONE,
connect_timeout: ODVInput[float] = DEFAULT_NONE,
@ -39,29 +39,6 @@ class NonchalantHttpxRequest(HTTPXRequest):
pytest.xfail(f"Ignoring TimedOut error: {e}")
async def expect_bad_request(func: Callable[..., Any], message: str, reason: str) -> Callable[..., Any]:
"""
Wrapper for testing bot functions expected to result in an :class:`telegram.error.BadRequest`.
Makes it XFAIL, if the specified error message is present.
Args:
func: The awaitable to be executed.
message: The expected message of the bad request error. If another message is present,
the error will be reraised.
reason: Explanation for the XFAIL.
Returns:
On success, returns the return value of :attr:`func`
"""
try:
return await func()
except BadRequest as e:
if message in str(e):
pytest.xfail(f"{reason}. {e}")
else:
raise e
async def send_webhook_message(
ip: str,
port: int,

View File

@ -0,0 +1,168 @@
import pytest
from assertpy import assert_that
from faker import Faker
from httpx import AsyncClient
from sqlalchemy import desc
from sqlalchemy.orm import Session
from core.bot.models.chat_gpt import ChatGpt
from tests.integration.factories.bot import ChatGptModelFactory
pytestmark = [
pytest.mark.asyncio,
pytest.mark.enable_socket,
]
async def test_get_chatgpt_models(
dbsession: Session,
rest_client: AsyncClient,
) -> None:
model1 = ChatGptModelFactory(priority=0)
model2 = ChatGptModelFactory(priority=42)
model3 = ChatGptModelFactory(priority=1)
response = await rest_client.get(url="/api/chatgpt/models")
assert response.status_code == 200
data = response.json()["data"]
assert_that(data).is_equal_to(
[
{
"id": model2.id,
"model": model2.model,
"priority": model2.priority,
},
{
"id": model3.id,
"model": model3.model,
"priority": model3.priority,
},
{
"id": model1.id,
"model": model1.model,
"priority": model1.priority,
},
]
)
async def test_change_chagpt_model_priority(
dbsession: Session,
rest_client: AsyncClient,
faker: Faker,
) -> None:
model1 = ChatGptModelFactory(priority=0)
model2 = ChatGptModelFactory(priority=1)
priority = faker.random_int(min=2, max=7)
response = await rest_client.put(url=f"/api/chatgpt/models/{model2.id}/priority", json={"priority": priority})
assert response.status_code == 202
upd_model1, upd_model2 = dbsession.query(ChatGpt).order_by(ChatGpt.priority).all()
assert model1.model == upd_model1.model
assert model2.model == upd_model2.model
updated_from_db_model = dbsession.get(ChatGpt, model2.id)
assert updated_from_db_model.priority == priority # type: ignore[union-attr]
async def test_reset_chatgpt_models_priority(
dbsession: Session,
rest_client: AsyncClient,
) -> None:
ChatGptModelFactory.create_batch(size=4)
ChatGptModelFactory(priority=42)
response = await rest_client.put(url="/api/chatgpt/models/priority/reset")
assert response.status_code == 202
models = dbsession.query(ChatGpt).all()
assert len(models) == 5
models = dbsession.query(ChatGpt).all()
for model in models:
assert model.priority == 0
async def test_create_new_chatgpt_model(
dbsession: Session,
rest_client: AsyncClient,
faker: Faker,
) -> None:
ChatGptModelFactory.create_batch(size=2)
ChatGptModelFactory(priority=42)
model_name = "new-gpt-model"
model_priority = faker.random_int(min=1, max=5)
models = dbsession.query(ChatGpt).all()
assert len(models) == 3
response = await rest_client.post(
url="/api/chatgpt/models",
json={
"model": model_name,
"priority": model_priority,
},
)
assert response.status_code == 201
models = dbsession.query(ChatGpt).all()
assert len(models) == 4
latest_model = dbsession.query(ChatGpt).order_by(desc(ChatGpt.id)).limit(1).one()
assert latest_model.model == model_name
assert latest_model.priority == model_priority
assert response.json() == {
"model": model_name,
"priority": model_priority,
}
async def test_add_existing_chatgpt_model(
dbsession: Session,
rest_client: AsyncClient,
faker: Faker,
) -> None:
ChatGptModelFactory.create_batch(size=2)
model = ChatGptModelFactory(priority=42)
model_name = model.model
model_priority = faker.random_int(min=1, max=5)
models = dbsession.query(ChatGpt).all()
assert len(models) == 3
response = await rest_client.post(
url="/api/chatgpt/models",
json={
"model": model_name,
"priority": model_priority,
},
)
assert response.status_code == 201
models = dbsession.query(ChatGpt).all()
assert len(models) == 3
async def test_delete_chatgpt_model(
dbsession: Session,
rest_client: AsyncClient,
) -> None:
ChatGptModelFactory.create_batch(size=2)
model = ChatGptModelFactory(priority=42)
models = dbsession.query(ChatGpt).all()
assert len(models) == 3
response = await rest_client.delete(url=f"/api/chatgpt/models/{model.id}")
assert response.status_code == 204
models = dbsession.query(ChatGpt).all()
assert len(models) == 2
assert model not in models

View File

@ -8,18 +8,20 @@ import telegram
from assertpy import assert_that
from faker import Faker
from httpx import AsyncClient, Response
from sqlalchemy.orm import Session
from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Update
from constants import BotStagesEnum
from core.bot.app import BotApplication, BotQueue
from main import Application
from settings.config import AppSettings, settings
from settings.config import AppSettings
from tests.integration.bot.networking import MockedRequest
from tests.integration.factories.bot import (
BotCallBackQueryFactory,
BotMessageFactory,
BotUpdateFactory,
CallBackFactory,
ChatGptModelFactory,
)
from tests.integration.utils import mocked_ask_question_api
@ -62,6 +64,22 @@ async def test_bot_queue(
assert bot_queue.queue.empty()
async def test_no_update_message(
main_application: Application,
test_settings: AppSettings,
) -> None:
with mock.patch.object(
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
) as mocked_send_message:
bot_update = BotUpdateFactory(message=None)
await main_application.bot_app.application.process_update(
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
)
assert mocked_send_message.called is False
async def test_help_command(
main_application: Application,
test_settings: AppSettings,
@ -150,9 +168,12 @@ async def test_about_me_callback_action(
async def test_about_bot_callback_action(
dbsession: Session,
main_application: Application,
test_settings: AppSettings,
) -> None:
ChatGptModelFactory(priority=0)
model_with_highest_priority = ChatGptModelFactory(priority=1)
with mock.patch.object(telegram._message.Message, "reply_text") as mocked_reply_text:
bot_update = BotCallBackQueryFactory(
message=BotMessageFactory.create_instance(text="Список основных команд:"),
@ -164,7 +185,7 @@ async def test_about_bot_callback_action(
)
assert mocked_reply_text.call_args.args == (
f"Бот использует бесплатную модель {settings.GPT_MODEL} для ответов на вопросы. "
f"Бот использует бесплатную модель {model_with_highest_priority.model} для ответов на вопросы. "
f"\nПринимает запросы на разных языках.\n\nБот так же умеет переводить русские голосовые сообщения "
f"в текст. Просто пришлите голосовуху и получите поток сознания в виде текста, но без знаков препинания",
)
@ -189,9 +210,11 @@ async def test_website_callback_action(
async def test_ask_question_action(
dbsession: Session,
main_application: Application,
test_settings: AppSettings,
) -> None:
ChatGptModelFactory.create_batch(size=3)
with mock.patch.object(
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
) as mocked_send_message, mocked_ask_question_api(
@ -214,9 +237,11 @@ async def test_ask_question_action(
async def test_ask_question_action_not_success(
dbsession: Session,
main_application: Application,
test_settings: AppSettings,
) -> None:
ChatGptModelFactory.create_batch(size=3)
with mock.patch.object(
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
) as mocked_send_message, mocked_ask_question_api(
@ -238,9 +263,11 @@ async def test_ask_question_action_not_success(
async def test_ask_question_action_critical_error(
dbsession: Session,
main_application: Application,
test_settings: AppSettings,
) -> None:
ChatGptModelFactory.create_batch(size=3)
with mock.patch.object(
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
) as mocked_send_message, mocked_ask_question_api(
@ -260,19 +287,3 @@ async def test_ask_question_action_critical_error(
},
include=["text", "chat_id"],
)
async def test_no_update_message(
main_application: Application,
test_settings: AppSettings,
) -> None:
with mock.patch.object(
telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)
) as mocked_send_message:
bot_update = BotUpdateFactory(message=None)
await main_application.bot_app.application.process_update(
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
)
assert mocked_send_message.called is False

View File

@ -1,21 +1,20 @@
"""This module contains subclasses of classes from the python-telegram-bot library that
modify behavior of the respective parent classes in order to make them easier to use in the
pytest framework. A common change is to allow monkeypatching of the class members by not
enforcing slots in the subclasses."""
import asyncio
from asyncio import AbstractEventLoop
from datetime import tzinfo
from typing import Any, AsyncGenerator
from typing import Any, AsyncGenerator, Generator
import pytest
import pytest_asyncio
from httpx import AsyncClient
from pytest_asyncio.plugin import SubRequest
from sqlalchemy import Engine, create_engine
from sqlalchemy.orm import Session, sessionmaker
from telegram import Bot, User
from telegram.ext import Application, ApplicationBuilder, Defaults, ExtBot
from telegram.ext import Application, ApplicationBuilder, ExtBot
from core.bot.app import BotApplication
from core.bot.handlers import bot_event_handlers
from infra.database.db_adapter import Database
from infra.database.meta import meta
from main import Application as AppApplication
from settings.config import AppSettings, get_settings
from tests.integration.bot.networking import NonchalantHttpxRequest
@ -27,6 +26,55 @@ def test_settings() -> AppSettings:
return get_settings()
@pytest.fixture(scope="session")
def engine(test_settings: AppSettings) -> Generator[Engine, None, None]:
"""
Create engine and databases.
:yield: new engine.
"""
engine: Engine = create_engine(
str(test_settings.sync_db_url),
echo=test_settings.DB_ECHO,
isolation_level="AUTOCOMMIT",
)
try:
yield engine
finally:
engine.dispose()
@pytest_asyncio.fixture(scope="function")
def dbsession(engine: Engine) -> Generator[Session, None, None]:
"""
Get session to database.
Fixture that returns a SQLAlchemy session with a SAVEPOINT, and the rollback to it
after the test completes.
:param engine: current engine.
:yields: async session.
"""
connection = engine.connect()
trans = connection.begin()
session_maker = sessionmaker(
connection,
expire_on_commit=False,
)
session = session_maker()
try:
meta.create_all(engine)
yield session
finally:
meta.drop_all(engine)
session.close()
trans.rollback()
connection.close()
class PytestExtBot(ExtBot): # type: ignore
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
@ -101,7 +149,7 @@ def _get_bot_user(token: str) -> User:
# Redefine the event_loop fixture to have a session scope. Otherwise `bot` fixture can't be
# session. See https://github.com/pytest-dev/pytest-asyncio/issues/68 for more details.
@pytest.fixture(scope="session")
@pytest.fixture(scope="session", autouse=True)
def event_loop(request: SubRequest) -> AbstractEventLoop:
"""
Пересоздаем луп для изоляции тестов. В основном нужно для запуска юнит тестов
@ -137,92 +185,6 @@ async def bot(bot_info: dict[str, Any], bot_application: Any) -> AsyncGenerator[
yield _bot
@pytest.fixture()
def one_time_bot(bot_info: dict[str, Any], bot_application: Any) -> PytestExtBot:
"""A function scoped bot since the session bot would shutdown when `async with app` finishes"""
bot = make_bot(bot_info)
bot.application = bot_application
return bot
@pytest_asyncio.fixture(scope="session")
async def cdc_bot(bot_info: dict[str, Any], bot_application: Any) -> AsyncGenerator[PytestExtBot, None]:
"""Makes an ExtBot instance with the given bot_info that uses arbitrary callback_data"""
async with make_bot(bot_info, arbitrary_callback_data=True) as _bot:
_bot.application = bot_application
yield _bot
@pytest_asyncio.fixture(scope="session")
async def raw_bot(bot_info: dict[str, Any], bot_application: Any) -> AsyncGenerator[PytestBot, None]:
"""Makes an regular Bot instance with the given bot_info"""
async with PytestBot(
bot_info["token"],
private_key=None,
request=NonchalantHttpxRequest(8),
get_updates_request=NonchalantHttpxRequest(1),
) as _bot:
_bot.application = bot_application
yield _bot
# Here we store the default bots so that we don't have to create them again and again.
# They are initialized but not shutdown on pytest_sessionfinish because it is causing
# problems with the event loop (Event loop is closed).
_default_bots: dict[Defaults, PytestExtBot] = {}
@pytest_asyncio.fixture(scope="session")
async def default_bot(request: SubRequest, bot_info: dict[str, Any]) -> PytestExtBot:
param = request.param if hasattr(request, "param") else {}
defaults = Defaults(**param)
# If the bot is already created, return it. Else make a new one.
default_bot = _default_bots.get(defaults)
if default_bot is None:
default_bot = make_bot(bot_info, defaults=defaults)
await default_bot.initialize()
_default_bots[defaults] = default_bot # Defaults object is hashable
return default_bot
@pytest_asyncio.fixture(scope="session")
async def tz_bot(timezone: tzinfo, bot_info: dict[str, Any]) -> PytestExtBot:
defaults = Defaults(tzinfo=timezone)
try: # If the bot is already created, return it. Saves time since get_me is not called again.
return _default_bots[defaults]
except KeyError:
default_bot = make_bot(bot_info, defaults=defaults)
await default_bot.initialize()
_default_bots[defaults] = default_bot
return default_bot
@pytest.fixture(scope="session")
def chat_id(bot_info: dict[str, Any]) -> int:
return bot_info["chat_id"]
@pytest.fixture(scope="session")
def super_group_id(bot_info: dict[str, Any]) -> int:
return bot_info["super_group_id"]
@pytest.fixture(scope="session")
def forum_group_id(bot_info: dict[str, Any]) -> int:
return int(bot_info["forum_group_id"])
@pytest.fixture(scope="session")
def channel_id(bot_info: dict[str, Any]) -> int:
return bot_info["channel_id"]
@pytest.fixture(scope="session")
def provider_token(bot_info: dict[str, Any]) -> str:
return bot_info["payment_provider_token"]
@pytest_asyncio.fixture(scope="session")
async def main_application(
bot_application: PytestApplication, test_settings: AppSettings
@ -235,7 +197,10 @@ async def main_application(
bot_app.application.bot = make_bot(BotInfoFactory())
bot_app.application.bot._bot_user = BotUserFactory()
fast_api_app = AppApplication(settings=test_settings, bot_app=bot_app)
database = Database(test_settings)
await database.create_database()
yield fast_api_app
await database.drop_database()
@pytest_asyncio.fixture()

View File

@ -1,17 +1,26 @@
import string
import time
from typing import Any
from typing import Any, NamedTuple
import factory
import factory.fuzzy
from faker import Faker
from constants import BotStagesEnum
from tests.integration.factories.models import Chat, User
from core.bot.models.chat_gpt import ChatGpt
from tests.integration.factories.utils import BaseModelFactory
faker = Faker("ru_RU")
class User(NamedTuple):
id: int
is_bot: bool
first_name: str | None
last_name: str | None
username: str | None
language_code: str
class BotUserFactory(factory.Factory):
id = factory.Sequence(lambda n: 1000 + n)
is_bot = False
@ -24,6 +33,14 @@ class BotUserFactory(factory.Factory):
model = User
class Chat(NamedTuple):
id: int
first_name: str | None
last_name: str | None
username: str
type: str
class BotChatFactory(factory.Factory):
id = factory.Sequence(lambda n: 1 + n)
first_name = factory.Faker("first_name")
@ -35,6 +52,15 @@ class BotChatFactory(factory.Factory):
model = Chat
class ChatGptModelFactory(BaseModelFactory):
id = factory.Sequence(lambda n: n + 1)
model = factory.Faker("word")
priority = factory.Faker("random_int", min=0, max=42)
class Meta:
model = ChatGpt
class BotInfoFactory(factory.DictFactory):
token = factory.Faker(
"bothify", text="#########:??????????????????????????-#????????#?", letters=string.ascii_letters
@ -67,6 +93,7 @@ class BotMessageFactory(factory.DictFactory):
date = time.time()
text = factory.Faker("text")
entities = factory.LazyFunction(lambda: [BotEntitleFactory()])
voice = None
@classmethod
def create_instance(cls, **kwargs: Any) -> dict[str, Any]:
@ -94,3 +121,13 @@ class CallBackFactory(factory.DictFactory):
class BotCallBackQueryFactory(factory.DictFactory):
update_id = factory.Faker("random_int", min=10**8, max=10**9 - 1)
callback_query = factory.LazyFunction(lambda: BotMessageFactory.create_instance())
class BotVoiceFactory(factory.DictFactory):
duration = factory.Faker("random_int", min=1, max=700)
file_id = factory.Faker(
"lexify", text="????????????????????????????????????????????????????????????????????????", locale="en_US"
)
file_size = factory.Faker("random_int")
file_unique_id = factory.Faker("lexify", text="???????????????", locale="en_US")
mime_type = "audio/ogg"

View File

@ -1,18 +0,0 @@
from typing import NamedTuple
class User(NamedTuple):
id: int
is_bot: bool
first_name: str | None
last_name: str | None
username: str | None
language_code: str
class Chat(NamedTuple):
id: int
first_name: str | None
last_name: str | None
username: str
type: str

View File

@ -0,0 +1,13 @@
import factory
from infra.database.db_adapter import Database
from settings.config import settings
database = Database(settings)
class BaseModelFactory(factory.alchemy.SQLAlchemyModelFactory):
class Meta:
abstract = True
sqlalchemy_session_persistence = "commit"
sqlalchemy_session = database.get_sync_db_session()

View File

@ -2,9 +2,12 @@ import httpx
import pytest
from faker import Faker
from httpx import AsyncClient, Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from api.exceptions import BaseAPIException
from settings.config import AppSettings
from tests.integration.factories.bot import ChatGptModelFactory
from tests.integration.utils import mocked_ask_question_api
pytestmark = [
@ -22,9 +25,11 @@ async def test_bot_updates(rest_client: AsyncClient) -> None:
async def test_bot_healthcheck_is_ok(
dbsession: Session,
rest_client: AsyncClient,
test_settings: AppSettings,
) -> None:
ChatGptModelFactory.create_batch(size=3)
with mocked_ask_question_api(
host=test_settings.GPT_BASE_HOST,
return_value=Response(status_code=httpx.codes.OK, text="Привет! Как я могу помочь вам сегодня?"),
@ -35,8 +40,9 @@ async def test_bot_healthcheck_is_ok(
@pytest.mark.parametrize("text", ["Invalid request model", "return unexpected http status code"])
async def test_bot_healthcheck_invalid_request_model(
rest_client: AsyncClient, test_settings: AppSettings, text: str
dbsession: AsyncSession, rest_client: AsyncClient, test_settings: AppSettings, text: str
) -> None:
ChatGptModelFactory.create_batch(size=3)
with mocked_ask_question_api(
host=test_settings.GPT_BASE_HOST,
return_value=Response(status_code=httpx.codes.OK, text=text),
@ -46,9 +52,11 @@ async def test_bot_healthcheck_invalid_request_model(
async def test_bot_healthcheck_not_ok(
dbsession: Session,
rest_client: AsyncClient,
test_settings: AppSettings,
) -> None:
ChatGptModelFactory.create_batch(size=3)
with mocked_ask_question_api(
host=test_settings.GPT_BASE_HOST,
side_effect=BaseAPIException(),

View File

@ -24,7 +24,7 @@ services:
env_file:
- bot_microservice/settings/.env
volumes:
- ./bot_microservice/settings/.env:/app/settings/.env:ro
- ./bot_microservice/settings:/app/settings:ro
- /etc/localtime:/etc/localtime:ro
networks:
chat-gpt-network: