add ban user action (#77)

* add ban user action

* fix tests

* send message through update.effective_message
This commit is contained in:
Dmitry Afanasyev 2024-01-07 16:08:23 +03:00 committed by GitHub
parent 1e79c981c2
commit 8266342214
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 89 additions and 12 deletions

View File

@ -26,6 +26,7 @@ class BotCommands(StrEnum):
help = "help" help = "help"
bug_report = "bug_report" bug_report = "bug_report"
website = "website" website = "website"
developer = "developer"
class BotEntryPoints(StrEnum): class BotEntryPoints(StrEnum):

View File

@ -28,6 +28,7 @@ class User(Base):
backref="user", backref="user",
lazy="selectin", lazy="selectin",
uselist=False, uselist=False,
cascade="delete",
) )
@property @property

View File

@ -1,6 +1,13 @@
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from functools import wraps
from typing import Any
from loguru import logger
from telegram import Update
from telegram.ext import ContextTypes
from constants import BotCommands
from core.auth.dto import UserIsBannedDTO from core.auth.dto import UserIsBannedDTO
from core.auth.models.users import User from core.auth.models.users import User
from core.auth.repository import UserRepository from core.auth.repository import UserRepository
@ -54,3 +61,31 @@ class UserService:
async def check_user_is_banned(self, user_id: int) -> UserIsBannedDTO: async def check_user_is_banned(self, user_id: int) -> UserIsBannedDTO:
return await self.repository.check_user_is_banned(user_id) return await self.repository.check_user_is_banned(user_id)
def check_user_is_banned(func: Any) -> Any:
@wraps(func)
async def wrapper(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
if not update.effective_message:
logger.error('no effective message', update=update, context=context)
return
if not update.effective_user:
logger.error('no effective user', update=update, context=context)
await update.effective_message.reply_text(
"Бот не смог определить пользователя. :(\nОб ошибке уже сообщено."
)
return
user_service = UserService.build() # noqa: NEW100
user_status = await user_service.check_user_is_banned(update.effective_user.id)
if user_status.is_banned:
await update.effective_message.reply_text(
text=f"You have banned for reason: *{user_status.ban_reason}*."
f"\nPlease contact the /{BotCommands.developer}",
parse_mode="Markdown",
)
else:
await func(update, context)
return wrapper

View File

@ -7,6 +7,7 @@ from telegram import InlineKeyboardMarkup, Update
from telegram.ext import ContextTypes from telegram.ext import ContextTypes
from constants import BotCommands, BotEntryPoints from constants import BotCommands, BotEntryPoints
from core.auth.services import check_user_is_banned
from core.bot.app import get_bot from core.bot.app import get_bot
from core.bot.keyboards import main_keyboard from core.bot.keyboards import main_keyboard
from core.bot.services import ChatGptService, SpeechToTextService from core.bot.services import ChatGptService, SpeechToTextService
@ -42,6 +43,7 @@ async def about_bot(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
) )
@check_user_is_banned
async def website(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def website(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
if not update.effective_message: if not update.effective_message:
return return
@ -49,6 +51,7 @@ async def website(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
await update.effective_message.reply_text(f"Веб версия: {website}") await update.effective_message.reply_text(f"Веб версия: {website}")
@check_user_is_banned
async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Send a message when the command /help is issued.""" """Send a message when the command /help is issued."""
@ -63,6 +66,7 @@ async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No
) )
@check_user_is_banned
async def bug_report(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def bug_report(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Send a message when the command /bug-report is issued.""" """Send a message when the command /bug-report is issued."""
@ -88,15 +92,11 @@ async def github(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
) )
@check_user_is_banned
async def ask_question(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def ask_question(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
if not update.message: if not update.message:
return return
if not update.effective_user:
logger.error('no effective user', update=update, context=context)
await update.message.reply_text("Бот не смог определить пользователя. :(\nОб ошибке уже сообщено.")
return
await update.message.reply_text( await update.message.reply_text(
f"Ответ в среднем занимает 10-15 секунд.\n" f"Ответ в среднем занимает 10-15 секунд.\n"
f"- Список команд: /{BotCommands.help}\n" f"- Список команд: /{BotCommands.help}\n"
@ -108,10 +108,10 @@ async def ask_question(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No
answer, user = await asyncio.gather( answer, user = await asyncio.gather(
chatgpt_service.request_to_chatgpt(question=update.message.text), chatgpt_service.request_to_chatgpt(question=update.message.text),
chatgpt_service.get_or_create_bot_user( chatgpt_service.get_or_create_bot_user(
user_id=update.effective_user.id, user_id=update.effective_user.id, # type: ignore[union-attr]
username=update.effective_user.username, username=update.effective_user.username, # type: ignore[union-attr]
first_name=update.effective_user.first_name, first_name=update.effective_user.first_name, # type: ignore[union-attr]
last_name=update.effective_user.last_name, last_name=update.effective_user.last_name, # type: ignore[union-attr]
), ),
) )
await asyncio.gather(update.message.reply_text(answer), chatgpt_service.update_bot_user_message_count(user.id)) await asyncio.gather(update.message.reply_text(answer), chatgpt_service.update_bot_user_message_count(user.id))

View File

@ -36,6 +36,7 @@ bot_event_handlers = BotEventHandlers()
bot_event_handlers.add_handler(CommandHandler(BotCommands.help, help_command)) bot_event_handlers.add_handler(CommandHandler(BotCommands.help, help_command))
bot_event_handlers.add_handler(CommandHandler(BotCommands.website, website)) bot_event_handlers.add_handler(CommandHandler(BotCommands.website, website))
bot_event_handlers.add_handler(CommandHandler(BotCommands.bug_report, bug_report)) bot_event_handlers.add_handler(CommandHandler(BotCommands.bug_report, bug_report))
bot_event_handlers.add_handler(CommandHandler(BotCommands.developer, about_me))
bot_event_handlers.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, ask_question)) bot_event_handlers.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, ask_question))
bot_event_handlers.add_handler(MessageHandler(filters.VOICE | filters.AUDIO, voice_recognize)) bot_event_handlers.add_handler(MessageHandler(filters.VOICE | filters.AUDIO, voice_recognize))

View File

@ -12,6 +12,8 @@ if TYPE_CHECKING:
class ChatGptAdmin(ModelView, model=ChatGptModels): class ChatGptAdmin(ModelView, model=ChatGptModels):
name = "ChatGPT model"
name_plural = "ChatGPT models"
column_list = [ChatGptModels.id, ChatGptModels.model, ChatGptModels.priority] column_list = [ChatGptModels.id, ChatGptModels.model, ChatGptModels.priority]
column_sortable_list = [ChatGptModels.priority] column_sortable_list = [ChatGptModels.priority]
column_default_sort = ("priority", True) column_default_sort = ("priority", True)
@ -22,6 +24,8 @@ class ChatGptAdmin(ModelView, model=ChatGptModels):
class UserAdmin(ModelView, model=User): class UserAdmin(ModelView, model=User):
name = "User"
name_plural = "Users"
column_list = [ column_list = [
User.id, User.id,
User.username, User.username,

View File

@ -57,7 +57,7 @@ def engine(test_settings: AppSettings) -> Generator[Engine, None, None]:
engine.dispose() engine.dispose()
@pytest.fixture() @pytest.fixture(autouse=True)
def dbsession(engine: Engine) -> Generator[Session, None, None]: def dbsession(engine: Engine) -> Generator[Session, None, None]:
""" """
Get session to database. Get session to database.
@ -69,7 +69,6 @@ def dbsession(engine: Engine) -> Generator[Session, None, None]:
:yields: async session. :yields: async session.
""" """
connection = engine.connect() connection = engine.connect()
trans = connection.begin()
session_maker = sessionmaker( session_maker = sessionmaker(
connection, connection,
@ -83,7 +82,6 @@ def dbsession(engine: Engine) -> Generator[Session, None, None]:
finally: finally:
meta.drop_all(engine) meta.drop_all(engine)
session.close() session.close()
trans.rollback()
connection.close() connection.close()

View File

@ -261,6 +261,23 @@ async def test_bug_report_action(
) )
async def test_get_developer_action(
main_application: Application,
test_settings: AppSettings,
) -> None:
with (
mock.patch.object(telegram._message.Message, "reply_text") as mocked_reply_text,
mock.patch.object(telegram._bot.Bot, "send_message", return_value=lambda *args, **kwargs: (args, kwargs)),
):
bot_update = BotUpdateFactory(message=BotMessageFactory.create_instance(text="/developer"))
await main_application.bot_app.application.process_update(
update=Update.de_json(data=bot_update, bot=main_application.bot_app.bot)
)
assert mocked_reply_text.call_args.args == ("Автор бота: *Дмитрий Афанасьев*\n\nTg nickname: *Balshtg*",)
async def test_ask_question_action( async def test_ask_question_action(
dbsession: Session, dbsession: Session,
main_application: Application, main_application: Application,

View File

@ -0,0 +1,20 @@
import factory
from core.auth.models.users import User
from tests.integration.factories.utils import BaseModelFactory
class UserFactory(BaseModelFactory):
id = factory.Sequence(lambda n: n + 1)
email = factory.Faker("email")
username = factory.Faker("user_name", locale="en_EN")
first_name = factory.Faker("word")
last_name = factory.Faker("word")
ban_reason = factory.Faker("text", max_nb_chars=100)
hashed_password = factory.Faker("word")
is_active = True
is_superuser = False
created_at = factory.Faker("past_datetime")
class Meta:
model = User