diff --git a/bot_microservice/api/bot/constants.py b/bot_microservice/api/bot/constants.py new file mode 100644 index 0000000..7d05c76 --- /dev/null +++ b/bot_microservice/api/bot/constants.py @@ -0,0 +1 @@ +BOT_ACCESS_API_HEADER = "BOT-API-KEY" diff --git a/bot_microservice/api/bot/controllers.py b/bot_microservice/api/bot/controllers.py index 6d7a4b8..dab2c8e 100644 --- a/bot_microservice/api/bot/controllers.py +++ b/bot_microservice/api/bot/controllers.py @@ -3,13 +3,19 @@ from starlette import status from starlette.responses import JSONResponse, Response from telegram import Update -from api.bot.deps import get_bot_queue, get_chatgpt_service, get_update_from_request +from api.bot.deps import ( + get_access_to_bot_api_or_403, + get_bot_queue, + get_chatgpt_service, + get_update_from_request, +) from api.bot.serializers import ( ChatGptModelSerializer, ChatGptModelsPrioritySerializer, GETChatGptModelsSerializer, LightChatGptModel, ) +from api.exceptions import PermissionMissingResponse from core.bot.app import BotQueue from core.bot.services import ChatGptService from settings.config import settings @@ -53,6 +59,10 @@ async def models_list( @router.put( "/chatgpt/models/{model_id}/priority", name="bot:change_model_priority", + dependencies=[Depends(get_access_to_bot_api_or_403)], + responses={ + status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse}, + }, response_class=Response, status_code=status.HTTP_202_ACCEPTED, summary="change gpt model priority", @@ -69,6 +79,10 @@ async def change_model_priority( @router.put( "/chatgpt/models/priority/reset", name="bot:reset_models_priority", + dependencies=[Depends(get_access_to_bot_api_or_403)], + responses={ + status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse}, + }, response_class=Response, status_code=status.HTTP_202_ACCEPTED, summary="reset all model priority to default", @@ -83,6 +97,11 @@ async def reset_models_priority( @router.post( "/chatgpt/models", name="bot:add_new_model", + dependencies=[Depends(get_access_to_bot_api_or_403)], + responses={ + status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse}, + status.HTTP_201_CREATED: {"model": ChatGptModelSerializer}, + }, response_model=ChatGptModelSerializer, status_code=status.HTTP_201_CREATED, summary="add new model", @@ -100,6 +119,10 @@ async def add_new_model( @router.delete( "/chatgpt/models/{model_id}", name="bot:delete_gpt_model", + dependencies=[Depends(get_access_to_bot_api_or_403)], + responses={ + status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse}, + }, response_class=Response, status_code=status.HTTP_204_NO_CONTENT, summary="delete gpt model", diff --git a/bot_microservice/api/bot/deps.py b/bot_microservice/api/bot/deps.py index 0d72e62..b3ab313 100644 --- a/bot_microservice/api/bot/deps.py +++ b/bot_microservice/api/bot/deps.py @@ -1,15 +1,17 @@ -from fastapi import Depends +from fastapi import Depends, Header, HTTPException +from starlette import status from starlette.requests import Request from telegram import Update from api.auth.deps import get_user_service +from api.bot.constants import BOT_ACCESS_API_HEADER from api.deps import get_database from core.auth.services import UserService from core.bot.app import BotApplication, BotQueue from core.bot.repository import ChatGPTRepository from core.bot.services import ChatGptService from infra.database.db_adapter import Database -from settings.config import AppSettings, get_settings +from settings.config import AppSettings, get_settings, settings def get_bot_app(request: Request) -> BotApplication: @@ -40,3 +42,13 @@ def get_chatgpt_service( user_service: UserService = Depends(get_user_service), ) -> ChatGptService: return ChatGptService(repository=chatgpt_repository, user_service=user_service) + + +async def get_access_to_bot_api_or_403( + bot_api_key: str | None = Header(None, alias=BOT_ACCESS_API_HEADER, description="Ключ доступа до API бота"), + user_service: UserService = Depends(get_user_service), +) -> None: + access_token = await user_service.get_user_access_token_by_username(settings.SUPERUSER) + + if not access_token or access_token != bot_api_key: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Could not validate api header") diff --git a/bot_microservice/api/exceptions.py b/bot_microservice/api/exceptions.py index 114ef03..20daa54 100644 --- a/bot_microservice/api/exceptions.py +++ b/bot_microservice/api/exceptions.py @@ -1,11 +1,39 @@ +from typing import Any + from fastapi.responses import ORJSONResponse +from starlette import status from starlette.requests import Request from api.base_schemas import BaseError, BaseResponse class BaseAPIException(Exception): - pass + _content_type: str = "application/json" + model: type[BaseResponse] = BaseResponse + status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR + title: str | None = None + type: str | None = None + detail: str | None = None + instance: str | None = None + headers: dict[str, str] | None = None + + def __init__(self, **ctx: Any) -> None: + self.__dict__ = ctx + + @classmethod + def example(cls) -> dict[str, Any] | None: + if isinstance(cls.model.Config.schema_extra, dict): # type: ignore[attr-defined] + return cls.model.Config.schema_extra.get("example") # type: ignore[attr-defined] + return None + + @classmethod + def response(cls) -> dict[str, Any]: + return { + "model": cls.model, + "content": { + cls._content_type: cls.model.Config.schema_extra, # type: ignore[attr-defined] + }, + } class InternalServerError(BaseError): @@ -28,6 +56,31 @@ class InternalServerErrorResponse(BaseResponse): } +class PermissionMissing(BaseError): + pass + + +class PermissionMissingResponse(BaseResponse): + error: PermissionMissing + + class Config: + json_schema_extra = { + "example": { + "status": 403, + "error": { + "type": "PermissionMissing", + "title": "Permission required for this endpoint is missing", + }, + }, + } + + +class PermissionMissingError(BaseAPIException): + model = PermissionMissingResponse + status_code = status.HTTP_403_FORBIDDEN + title: str = "Permission required for this endpoint is missing" + + async def internal_server_error_handler(_request: Request, _exception: Exception) -> ORJSONResponse: error = InternalServerError(title="Something went wrong!", type="InternalServerError") response = InternalServerErrorResponse(status=500, error=error).model_dump(exclude_unset=True) diff --git a/bot_microservice/core/auth/models/users.py b/bot_microservice/core/auth/models/users.py index daf79c9..7e27f55 100644 --- a/bot_microservice/core/auth/models/users.py +++ b/bot_microservice/core/auth/models/users.py @@ -1,3 +1,4 @@ +import uuid from datetime import datetime from sqlalchemy import INTEGER, TIMESTAMP, VARCHAR, Boolean, ForeignKey, String @@ -26,7 +27,7 @@ class User(Base): "UserQuestionCount", primaryjoin="UserQuestionCount.user_id == User.id", backref="user", - lazy="selectin", + lazy="noload", uselist=False, cascade="delete", ) @@ -68,11 +69,25 @@ class AccessToken(Base): __tablename__ = "access_token" # type: ignore[assignment] user_id = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), nullable=False) - token: Mapped[str] = mapped_column(String(length=42), primary_key=True) + token: Mapped[str] = mapped_column(String(length=42), primary_key=True, default=lambda: str(uuid.uuid4())) created_at: Mapped[datetime] = mapped_column( TIMESTAMP(timezone=True), index=True, nullable=False, default=datetime.now ) + user: Mapped["User"] = relationship( + "User", + backref="access_token", + lazy="selectin", + uselist=False, + cascade="expunge", + ) + + @property + def username(self) -> str: + if self.user: + return self.user.username + return "" + class UserQuestionCount(Base): __tablename__ = "user_question_count" # type: ignore[assignment] diff --git a/bot_microservice/core/auth/repository.py b/bot_microservice/core/auth/repository.py index db479ae..967d5a0 100644 --- a/bot_microservice/core/auth/repository.py +++ b/bot_microservice/core/auth/repository.py @@ -5,7 +5,7 @@ from sqlalchemy.dialects.sqlite import insert from sqlalchemy.orm import load_only from core.auth.dto import UserIsBannedDTO -from core.auth.models.users import User, UserQuestionCount +from core.auth.models.users import AccessToken, User, UserQuestionCount from infra.database.db_adapter import Database @@ -76,3 +76,10 @@ class UserRepository: async with self.db.session() as session: await session.execute(query) + + async def get_user_access_token(self, username: str | None) -> str | None: + query = select(AccessToken.token).join(AccessToken.user).where(User.username == username) + + async with self.db.session() as session: + result = await session.execute(query) + return result.scalar() diff --git a/bot_microservice/core/auth/services.py b/bot_microservice/core/auth/services.py index 929bbbc..38212f1 100644 --- a/bot_microservice/core/auth/services.py +++ b/bot_microservice/core/auth/services.py @@ -62,6 +62,9 @@ class UserService: async def check_user_is_banned(self, user_id: int) -> UserIsBannedDTO: return await self.repository.check_user_is_banned(user_id) + async def get_user_access_token_by_username(self, username: str | None) -> str | None: + return await self.repository.get_user_access_token(username) + def check_user_is_banned(func: Any) -> Any: @wraps(func) diff --git a/bot_microservice/infra/admin.py b/bot_microservice/infra/admin.py index 8174342..3859f08 100644 --- a/bot_microservice/infra/admin.py +++ b/bot_microservice/infra/admin.py @@ -1,8 +1,11 @@ from typing import TYPE_CHECKING from sqladmin import Admin, ModelView +from sqlalchemy import Select, desc, select +from sqlalchemy.orm import contains_eager, load_only +from starlette.requests import Request -from core.auth.models.users import User +from core.auth.models.users import AccessToken, User, UserQuestionCount from core.bot.models.chatgpt import ChatGptModels from core.utils import build_uri from settings.config import settings @@ -36,10 +39,34 @@ class UserAdmin(ModelView, model=User): "question_count", User.created_at, ] - column_sortable_list = [User.created_at] + column_default_sort = ("created_at", True) form_widget_args = {"created_at": {"readonly": True}} + def list_query(self, request: Request) -> Select[tuple[User]]: + return ( + select(User) + .options( + load_only( + User.id, + User.username, + User.first_name, + User.last_name, + User.is_active, + User.created_at, + ) + ) + .outerjoin(User.user_question_count) + .options(contains_eager(User.user_question_count).options(load_only(UserQuestionCount.question_count))) + ).order_by(desc(UserQuestionCount.question_count)) + + +class AccessTokenAdmin(ModelView, model=AccessToken): + name = "API access token" + name_plural = "API access tokens" + column_list = [AccessToken.user_id, "username", AccessToken.token, AccessToken.created_at] + form_widget_args = {"created_at": {"readonly": True}} + def create_admin(application: "Application") -> Admin: admin = Admin( @@ -51,4 +78,5 @@ def create_admin(application: "Application") -> Admin: ) admin.add_view(ChatGptAdmin) admin.add_view(UserAdmin) + admin.add_view(AccessTokenAdmin) return admin diff --git a/bot_microservice/infra/database/migrations/versions/0002_create_auth_tables.py b/bot_microservice/infra/database/migrations/versions/0002_create_auth_tables.py index 4507585..934c981 100644 --- a/bot_microservice/infra/database/migrations/versions/0002_create_auth_tables.py +++ b/bot_microservice/infra/database/migrations/versions/0002_create_auth_tables.py @@ -10,9 +10,8 @@ from datetime import datetime import sqlalchemy as sa from alembic import op from sqlalchemy import TIMESTAMP -from sqlalchemy.dialects.sqlite import insert -from core.auth.models.users import User +from core.auth.models.users import AccessToken, User from core.auth.utils import create_password_hash from infra.database.deps import get_sync_session from settings.config import settings @@ -58,8 +57,14 @@ def upgrade() -> None: return with get_sync_session() as session: hashed_password = create_password_hash(password.get_secret_value()) - query = insert(User).values({"username": username, "hashed_password": hashed_password}) - session.execute(query) + user = User(username=username, hashed_password=hashed_password) + session.add(user) + session.flush() + session.refresh(user) + + access_token = AccessToken(user_id=user.id) + session.add(access_token) + session.commit() diff --git a/bot_microservice/settings/.env.runtests b/bot_microservice/settings/.env.runtests index 9830aa1..e43d998 100644 --- a/bot_microservice/settings/.env.runtests +++ b/bot_microservice/settings/.env.runtests @@ -11,6 +11,8 @@ TELEGRAM_API_TOKEN="123456789:AABBCCDDEEFFaabbccddeeff-1234567890" # set to true to start with webhook. Else bot will start on polling method START_WITH_WEBHOOK="false" +SUPERUSER="Superuser" + # ==== domain settings ==== DOMAIN="http://localhost" URL_PREFIX= diff --git a/bot_microservice/tests/integration/bot/test_bot_api.py b/bot_microservice/tests/integration/bot/test_bot_api.py index b6e6541..4516f99 100644 --- a/bot_microservice/tests/integration/bot/test_bot_api.py +++ b/bot_microservice/tests/integration/bot/test_bot_api.py @@ -6,7 +6,9 @@ from sqlalchemy import desc from sqlalchemy.orm import Session from core.bot.models.chatgpt import ChatGptModels +from settings.config import AppSettings from tests.integration.factories.bot import ChatGptModelFactory +from tests.integration.factories.user import AccessTokenFactory, UserFactory pytestmark = [ pytest.mark.asyncio, @@ -51,11 +53,18 @@ async def test_change_chatgpt_model_priority( dbsession: Session, rest_client: AsyncClient, faker: Faker, + test_settings: AppSettings, ) -> None: model1 = ChatGptModelFactory(priority=0) model2 = ChatGptModelFactory(priority=1) priority = faker.random_int(min=2, max=7) - response = await rest_client.put(url=f"/api/chatgpt/models/{model2.id}/priority", json={"priority": priority}) + user = UserFactory(username=test_settings.SUPERUSER) + access_token = AccessTokenFactory(user_id=user.id) + response = await rest_client.put( + url=f"/api/chatgpt/models/{model2.id}/priority", + json={"priority": priority}, + headers={"BOT-API-KEY": access_token.token}, + ) assert response.status_code == 202 upd_model1, upd_model2 = dbsession.query(ChatGptModels).order_by(ChatGptModels.priority).all() @@ -69,11 +78,18 @@ async def test_change_chatgpt_model_priority( async def test_reset_chatgpt_models_priority( dbsession: Session, rest_client: AsyncClient, + test_settings: AppSettings, ) -> None: ChatGptModelFactory.create_batch(size=4) ChatGptModelFactory(priority=42) - response = await rest_client.put(url="/api/chatgpt/models/priority/reset") + user = UserFactory(username=test_settings.SUPERUSER) + access_token = AccessTokenFactory(user_id=user.id) + + response = await rest_client.put( + url="/api/chatgpt/models/priority/reset", + headers={"BOT-API-KEY": access_token.token}, + ) assert response.status_code == 202 models = dbsession.query(ChatGptModels).all() @@ -89,10 +105,14 @@ async def test_create_new_chatgpt_model( dbsession: Session, rest_client: AsyncClient, faker: Faker, + test_settings: AppSettings, ) -> None: ChatGptModelFactory.create_batch(size=2) ChatGptModelFactory(priority=42) + user = UserFactory(username=test_settings.SUPERUSER) + access_token = AccessTokenFactory(user_id=user.id) + model_name = "new-gpt-model" model_priority = faker.random_int(min=1, max=5) @@ -105,6 +125,7 @@ async def test_create_new_chatgpt_model( "model": model_name, "priority": model_priority, }, + headers={"BOT-API-KEY": access_token.token}, ) assert response.status_code == 201 @@ -125,9 +146,12 @@ async def test_add_existing_chatgpt_model( dbsession: Session, rest_client: AsyncClient, faker: Faker, + test_settings: AppSettings, ) -> None: ChatGptModelFactory.create_batch(size=2) model = ChatGptModelFactory(priority=42) + user = UserFactory(username=test_settings.SUPERUSER) + access_token = AccessTokenFactory(user_id=user.id) model_name = model.model model_priority = faker.random_int(min=1, max=5) @@ -141,6 +165,7 @@ async def test_add_existing_chatgpt_model( "model": model_name, "priority": model_priority, }, + headers={"BOT-API-KEY": access_token.token}, ) assert response.status_code == 201 @@ -151,14 +176,21 @@ async def test_add_existing_chatgpt_model( async def test_delete_chatgpt_model( dbsession: Session, rest_client: AsyncClient, + test_settings: AppSettings, ) -> None: ChatGptModelFactory.create_batch(size=2) model = ChatGptModelFactory(priority=42) + user = UserFactory(username=test_settings.SUPERUSER) + access_token = AccessTokenFactory(user_id=user.id) + models = dbsession.query(ChatGptModels).all() assert len(models) == 3 - response = await rest_client.delete(url=f"/api/chatgpt/models/{model.id}") + response = await rest_client.delete( + url=f"/api/chatgpt/models/{model.id}", + headers={"BOT-API-KEY": access_token.token}, + ) assert response.status_code == 204 models = dbsession.query(ChatGptModels).all() diff --git a/bot_microservice/tests/integration/factories/user.py b/bot_microservice/tests/integration/factories/user.py index 97eab2b..3b122c7 100644 --- a/bot_microservice/tests/integration/factories/user.py +++ b/bot_microservice/tests/integration/factories/user.py @@ -1,13 +1,15 @@ +import uuid + import factory -from core.auth.models.users import User +from core.auth.models.users import AccessToken, User from tests.integration.factories.utils import BaseModelFactory class UserFactory(BaseModelFactory): id = factory.Sequence(lambda n: n + 1) email = factory.Faker("email") - username = factory.Faker("user_name", locale="en_EN") + username = factory.Faker("user_name", locale="en") first_name = factory.Faker("word") last_name = factory.Faker("word") ban_reason = factory.Faker("text", max_nb_chars=100) @@ -18,3 +20,12 @@ class UserFactory(BaseModelFactory): class Meta: model = User + + +class AccessTokenFactory(BaseModelFactory): + user_id = factory.Sequence(lambda n: n + 1) + token = factory.LazyAttribute(lambda o: str(uuid.uuid4())) + created_at = factory.Faker("past_datetime") + + class Meta: + model = AccessToken