close dangerous api methods under api auth (#78)

* close dangerous api methods under api auth

* rename access_token method
This commit is contained in:
Dmitry Afanasyev 2024-01-07 20:06:02 +03:00 committed by GitHub
parent 8266342214
commit de55d873f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 210 additions and 18 deletions

View File

@ -0,0 +1 @@
BOT_ACCESS_API_HEADER = "BOT-API-KEY"

View File

@ -3,13 +3,19 @@ from starlette import status
from starlette.responses import JSONResponse, Response from starlette.responses import JSONResponse, Response
from telegram import Update from telegram import Update
from api.bot.deps import get_bot_queue, get_chatgpt_service, get_update_from_request from api.bot.deps import (
get_access_to_bot_api_or_403,
get_bot_queue,
get_chatgpt_service,
get_update_from_request,
)
from api.bot.serializers import ( from api.bot.serializers import (
ChatGptModelSerializer, ChatGptModelSerializer,
ChatGptModelsPrioritySerializer, ChatGptModelsPrioritySerializer,
GETChatGptModelsSerializer, GETChatGptModelsSerializer,
LightChatGptModel, LightChatGptModel,
) )
from api.exceptions import PermissionMissingResponse
from core.bot.app import BotQueue from core.bot.app import BotQueue
from core.bot.services import ChatGptService from core.bot.services import ChatGptService
from settings.config import settings from settings.config import settings
@ -53,6 +59,10 @@ async def models_list(
@router.put( @router.put(
"/chatgpt/models/{model_id}/priority", "/chatgpt/models/{model_id}/priority",
name="bot:change_model_priority", name="bot:change_model_priority",
dependencies=[Depends(get_access_to_bot_api_or_403)],
responses={
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
},
response_class=Response, response_class=Response,
status_code=status.HTTP_202_ACCEPTED, status_code=status.HTTP_202_ACCEPTED,
summary="change gpt model priority", summary="change gpt model priority",
@ -69,6 +79,10 @@ async def change_model_priority(
@router.put( @router.put(
"/chatgpt/models/priority/reset", "/chatgpt/models/priority/reset",
name="bot:reset_models_priority", name="bot:reset_models_priority",
dependencies=[Depends(get_access_to_bot_api_or_403)],
responses={
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
},
response_class=Response, response_class=Response,
status_code=status.HTTP_202_ACCEPTED, status_code=status.HTTP_202_ACCEPTED,
summary="reset all model priority to default", summary="reset all model priority to default",
@ -83,6 +97,11 @@ async def reset_models_priority(
@router.post( @router.post(
"/chatgpt/models", "/chatgpt/models",
name="bot:add_new_model", name="bot:add_new_model",
dependencies=[Depends(get_access_to_bot_api_or_403)],
responses={
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
status.HTTP_201_CREATED: {"model": ChatGptModelSerializer},
},
response_model=ChatGptModelSerializer, response_model=ChatGptModelSerializer,
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
summary="add new model", summary="add new model",
@ -100,6 +119,10 @@ async def add_new_model(
@router.delete( @router.delete(
"/chatgpt/models/{model_id}", "/chatgpt/models/{model_id}",
name="bot:delete_gpt_model", name="bot:delete_gpt_model",
dependencies=[Depends(get_access_to_bot_api_or_403)],
responses={
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
},
response_class=Response, response_class=Response,
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
summary="delete gpt model", summary="delete gpt model",

View File

@ -1,15 +1,17 @@
from fastapi import Depends from fastapi import Depends, Header, HTTPException
from starlette import status
from starlette.requests import Request from starlette.requests import Request
from telegram import Update from telegram import Update
from api.auth.deps import get_user_service from api.auth.deps import get_user_service
from api.bot.constants import BOT_ACCESS_API_HEADER
from api.deps import get_database from api.deps import get_database
from core.auth.services import UserService from core.auth.services import UserService
from core.bot.app import BotApplication, BotQueue from core.bot.app import BotApplication, BotQueue
from core.bot.repository import ChatGPTRepository from core.bot.repository import ChatGPTRepository
from core.bot.services import ChatGptService from core.bot.services import ChatGptService
from infra.database.db_adapter import Database from infra.database.db_adapter import Database
from settings.config import AppSettings, get_settings from settings.config import AppSettings, get_settings, settings
def get_bot_app(request: Request) -> BotApplication: def get_bot_app(request: Request) -> BotApplication:
@ -40,3 +42,13 @@ def get_chatgpt_service(
user_service: UserService = Depends(get_user_service), user_service: UserService = Depends(get_user_service),
) -> ChatGptService: ) -> ChatGptService:
return ChatGptService(repository=chatgpt_repository, user_service=user_service) return ChatGptService(repository=chatgpt_repository, user_service=user_service)
async def get_access_to_bot_api_or_403(
bot_api_key: str | None = Header(None, alias=BOT_ACCESS_API_HEADER, description="Ключ доступа до API бота"),
user_service: UserService = Depends(get_user_service),
) -> None:
access_token = await user_service.get_user_access_token_by_username(settings.SUPERUSER)
if not access_token or access_token != bot_api_key:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Could not validate api header")

View File

@ -1,11 +1,39 @@
from typing import Any
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse
from starlette import status
from starlette.requests import Request from starlette.requests import Request
from api.base_schemas import BaseError, BaseResponse from api.base_schemas import BaseError, BaseResponse
class BaseAPIException(Exception): class BaseAPIException(Exception):
pass _content_type: str = "application/json"
model: type[BaseResponse] = BaseResponse
status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR
title: str | None = None
type: str | None = None
detail: str | None = None
instance: str | None = None
headers: dict[str, str] | None = None
def __init__(self, **ctx: Any) -> None:
self.__dict__ = ctx
@classmethod
def example(cls) -> dict[str, Any] | None:
if isinstance(cls.model.Config.schema_extra, dict): # type: ignore[attr-defined]
return cls.model.Config.schema_extra.get("example") # type: ignore[attr-defined]
return None
@classmethod
def response(cls) -> dict[str, Any]:
return {
"model": cls.model,
"content": {
cls._content_type: cls.model.Config.schema_extra, # type: ignore[attr-defined]
},
}
class InternalServerError(BaseError): class InternalServerError(BaseError):
@ -28,6 +56,31 @@ class InternalServerErrorResponse(BaseResponse):
} }
class PermissionMissing(BaseError):
pass
class PermissionMissingResponse(BaseResponse):
error: PermissionMissing
class Config:
json_schema_extra = {
"example": {
"status": 403,
"error": {
"type": "PermissionMissing",
"title": "Permission required for this endpoint is missing",
},
},
}
class PermissionMissingError(BaseAPIException):
model = PermissionMissingResponse
status_code = status.HTTP_403_FORBIDDEN
title: str = "Permission required for this endpoint is missing"
async def internal_server_error_handler(_request: Request, _exception: Exception) -> ORJSONResponse: async def internal_server_error_handler(_request: Request, _exception: Exception) -> ORJSONResponse:
error = InternalServerError(title="Something went wrong!", type="InternalServerError") error = InternalServerError(title="Something went wrong!", type="InternalServerError")
response = InternalServerErrorResponse(status=500, error=error).model_dump(exclude_unset=True) response = InternalServerErrorResponse(status=500, error=error).model_dump(exclude_unset=True)

View File

@ -1,3 +1,4 @@
import uuid
from datetime import datetime from datetime import datetime
from sqlalchemy import INTEGER, TIMESTAMP, VARCHAR, Boolean, ForeignKey, String from sqlalchemy import INTEGER, TIMESTAMP, VARCHAR, Boolean, ForeignKey, String
@ -26,7 +27,7 @@ class User(Base):
"UserQuestionCount", "UserQuestionCount",
primaryjoin="UserQuestionCount.user_id == User.id", primaryjoin="UserQuestionCount.user_id == User.id",
backref="user", backref="user",
lazy="selectin", lazy="noload",
uselist=False, uselist=False,
cascade="delete", cascade="delete",
) )
@ -68,11 +69,25 @@ class AccessToken(Base):
__tablename__ = "access_token" # type: ignore[assignment] __tablename__ = "access_token" # type: ignore[assignment]
user_id = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), nullable=False) user_id = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), nullable=False)
token: Mapped[str] = mapped_column(String(length=42), primary_key=True) token: Mapped[str] = mapped_column(String(length=42), primary_key=True, default=lambda: str(uuid.uuid4()))
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
TIMESTAMP(timezone=True), index=True, nullable=False, default=datetime.now TIMESTAMP(timezone=True), index=True, nullable=False, default=datetime.now
) )
user: Mapped["User"] = relationship(
"User",
backref="access_token",
lazy="selectin",
uselist=False,
cascade="expunge",
)
@property
def username(self) -> str:
if self.user:
return self.user.username
return ""
class UserQuestionCount(Base): class UserQuestionCount(Base):
__tablename__ = "user_question_count" # type: ignore[assignment] __tablename__ = "user_question_count" # type: ignore[assignment]

View File

@ -5,7 +5,7 @@ from sqlalchemy.dialects.sqlite import insert
from sqlalchemy.orm import load_only from sqlalchemy.orm import load_only
from core.auth.dto import UserIsBannedDTO from core.auth.dto import UserIsBannedDTO
from core.auth.models.users import User, UserQuestionCount from core.auth.models.users import AccessToken, User, UserQuestionCount
from infra.database.db_adapter import Database from infra.database.db_adapter import Database
@ -76,3 +76,10 @@ class UserRepository:
async with self.db.session() as session: async with self.db.session() as session:
await session.execute(query) await session.execute(query)
async def get_user_access_token(self, username: str | None) -> str | None:
query = select(AccessToken.token).join(AccessToken.user).where(User.username == username)
async with self.db.session() as session:
result = await session.execute(query)
return result.scalar()

View File

@ -62,6 +62,9 @@ class UserService:
async def check_user_is_banned(self, user_id: int) -> UserIsBannedDTO: async def check_user_is_banned(self, user_id: int) -> UserIsBannedDTO:
return await self.repository.check_user_is_banned(user_id) return await self.repository.check_user_is_banned(user_id)
async def get_user_access_token_by_username(self, username: str | None) -> str | None:
return await self.repository.get_user_access_token(username)
def check_user_is_banned(func: Any) -> Any: def check_user_is_banned(func: Any) -> Any:
@wraps(func) @wraps(func)

View File

@ -1,8 +1,11 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from sqladmin import Admin, ModelView from sqladmin import Admin, ModelView
from sqlalchemy import Select, desc, select
from sqlalchemy.orm import contains_eager, load_only
from starlette.requests import Request
from core.auth.models.users import User from core.auth.models.users import AccessToken, User, UserQuestionCount
from core.bot.models.chatgpt import ChatGptModels from core.bot.models.chatgpt import ChatGptModels
from core.utils import build_uri from core.utils import build_uri
from settings.config import settings from settings.config import settings
@ -36,10 +39,34 @@ class UserAdmin(ModelView, model=User):
"question_count", "question_count",
User.created_at, User.created_at,
] ]
column_sortable_list = [User.created_at]
column_default_sort = ("created_at", True) column_default_sort = ("created_at", True)
form_widget_args = {"created_at": {"readonly": True}} form_widget_args = {"created_at": {"readonly": True}}
def list_query(self, request: Request) -> Select[tuple[User]]:
return (
select(User)
.options(
load_only(
User.id,
User.username,
User.first_name,
User.last_name,
User.is_active,
User.created_at,
)
)
.outerjoin(User.user_question_count)
.options(contains_eager(User.user_question_count).options(load_only(UserQuestionCount.question_count)))
).order_by(desc(UserQuestionCount.question_count))
class AccessTokenAdmin(ModelView, model=AccessToken):
name = "API access token"
name_plural = "API access tokens"
column_list = [AccessToken.user_id, "username", AccessToken.token, AccessToken.created_at]
form_widget_args = {"created_at": {"readonly": True}}
def create_admin(application: "Application") -> Admin: def create_admin(application: "Application") -> Admin:
admin = Admin( admin = Admin(
@ -51,4 +78,5 @@ def create_admin(application: "Application") -> Admin:
) )
admin.add_view(ChatGptAdmin) admin.add_view(ChatGptAdmin)
admin.add_view(UserAdmin) admin.add_view(UserAdmin)
admin.add_view(AccessTokenAdmin)
return admin return admin

View File

@ -10,9 +10,8 @@ from datetime import datetime
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op from alembic import op
from sqlalchemy import TIMESTAMP from sqlalchemy import TIMESTAMP
from sqlalchemy.dialects.sqlite import insert
from core.auth.models.users import User from core.auth.models.users import AccessToken, User
from core.auth.utils import create_password_hash from core.auth.utils import create_password_hash
from infra.database.deps import get_sync_session from infra.database.deps import get_sync_session
from settings.config import settings from settings.config import settings
@ -58,8 +57,14 @@ def upgrade() -> None:
return return
with get_sync_session() as session: with get_sync_session() as session:
hashed_password = create_password_hash(password.get_secret_value()) hashed_password = create_password_hash(password.get_secret_value())
query = insert(User).values({"username": username, "hashed_password": hashed_password}) user = User(username=username, hashed_password=hashed_password)
session.execute(query) session.add(user)
session.flush()
session.refresh(user)
access_token = AccessToken(user_id=user.id)
session.add(access_token)
session.commit() session.commit()

View File

@ -11,6 +11,8 @@ TELEGRAM_API_TOKEN="123456789:AABBCCDDEEFFaabbccddeeff-1234567890"
# set to true to start with webhook. Else bot will start on polling method # set to true to start with webhook. Else bot will start on polling method
START_WITH_WEBHOOK="false" START_WITH_WEBHOOK="false"
SUPERUSER="Superuser"
# ==== domain settings ==== # ==== domain settings ====
DOMAIN="http://localhost" DOMAIN="http://localhost"
URL_PREFIX= URL_PREFIX=

View File

@ -6,7 +6,9 @@ from sqlalchemy import desc
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.bot.models.chatgpt import ChatGptModels from core.bot.models.chatgpt import ChatGptModels
from settings.config import AppSettings
from tests.integration.factories.bot import ChatGptModelFactory from tests.integration.factories.bot import ChatGptModelFactory
from tests.integration.factories.user import AccessTokenFactory, UserFactory
pytestmark = [ pytestmark = [
pytest.mark.asyncio, pytest.mark.asyncio,
@ -51,11 +53,18 @@ async def test_change_chatgpt_model_priority(
dbsession: Session, dbsession: Session,
rest_client: AsyncClient, rest_client: AsyncClient,
faker: Faker, faker: Faker,
test_settings: AppSettings,
) -> None: ) -> None:
model1 = ChatGptModelFactory(priority=0) model1 = ChatGptModelFactory(priority=0)
model2 = ChatGptModelFactory(priority=1) model2 = ChatGptModelFactory(priority=1)
priority = faker.random_int(min=2, max=7) priority = faker.random_int(min=2, max=7)
response = await rest_client.put(url=f"/api/chatgpt/models/{model2.id}/priority", json={"priority": priority}) user = UserFactory(username=test_settings.SUPERUSER)
access_token = AccessTokenFactory(user_id=user.id)
response = await rest_client.put(
url=f"/api/chatgpt/models/{model2.id}/priority",
json={"priority": priority},
headers={"BOT-API-KEY": access_token.token},
)
assert response.status_code == 202 assert response.status_code == 202
upd_model1, upd_model2 = dbsession.query(ChatGptModels).order_by(ChatGptModels.priority).all() upd_model1, upd_model2 = dbsession.query(ChatGptModels).order_by(ChatGptModels.priority).all()
@ -69,11 +78,18 @@ async def test_change_chatgpt_model_priority(
async def test_reset_chatgpt_models_priority( async def test_reset_chatgpt_models_priority(
dbsession: Session, dbsession: Session,
rest_client: AsyncClient, rest_client: AsyncClient,
test_settings: AppSettings,
) -> None: ) -> None:
ChatGptModelFactory.create_batch(size=4) ChatGptModelFactory.create_batch(size=4)
ChatGptModelFactory(priority=42) ChatGptModelFactory(priority=42)
response = await rest_client.put(url="/api/chatgpt/models/priority/reset") user = UserFactory(username=test_settings.SUPERUSER)
access_token = AccessTokenFactory(user_id=user.id)
response = await rest_client.put(
url="/api/chatgpt/models/priority/reset",
headers={"BOT-API-KEY": access_token.token},
)
assert response.status_code == 202 assert response.status_code == 202
models = dbsession.query(ChatGptModels).all() models = dbsession.query(ChatGptModels).all()
@ -89,10 +105,14 @@ async def test_create_new_chatgpt_model(
dbsession: Session, dbsession: Session,
rest_client: AsyncClient, rest_client: AsyncClient,
faker: Faker, faker: Faker,
test_settings: AppSettings,
) -> None: ) -> None:
ChatGptModelFactory.create_batch(size=2) ChatGptModelFactory.create_batch(size=2)
ChatGptModelFactory(priority=42) ChatGptModelFactory(priority=42)
user = UserFactory(username=test_settings.SUPERUSER)
access_token = AccessTokenFactory(user_id=user.id)
model_name = "new-gpt-model" model_name = "new-gpt-model"
model_priority = faker.random_int(min=1, max=5) model_priority = faker.random_int(min=1, max=5)
@ -105,6 +125,7 @@ async def test_create_new_chatgpt_model(
"model": model_name, "model": model_name,
"priority": model_priority, "priority": model_priority,
}, },
headers={"BOT-API-KEY": access_token.token},
) )
assert response.status_code == 201 assert response.status_code == 201
@ -125,9 +146,12 @@ async def test_add_existing_chatgpt_model(
dbsession: Session, dbsession: Session,
rest_client: AsyncClient, rest_client: AsyncClient,
faker: Faker, faker: Faker,
test_settings: AppSettings,
) -> None: ) -> None:
ChatGptModelFactory.create_batch(size=2) ChatGptModelFactory.create_batch(size=2)
model = ChatGptModelFactory(priority=42) model = ChatGptModelFactory(priority=42)
user = UserFactory(username=test_settings.SUPERUSER)
access_token = AccessTokenFactory(user_id=user.id)
model_name = model.model model_name = model.model
model_priority = faker.random_int(min=1, max=5) model_priority = faker.random_int(min=1, max=5)
@ -141,6 +165,7 @@ async def test_add_existing_chatgpt_model(
"model": model_name, "model": model_name,
"priority": model_priority, "priority": model_priority,
}, },
headers={"BOT-API-KEY": access_token.token},
) )
assert response.status_code == 201 assert response.status_code == 201
@ -151,14 +176,21 @@ async def test_add_existing_chatgpt_model(
async def test_delete_chatgpt_model( async def test_delete_chatgpt_model(
dbsession: Session, dbsession: Session,
rest_client: AsyncClient, rest_client: AsyncClient,
test_settings: AppSettings,
) -> None: ) -> None:
ChatGptModelFactory.create_batch(size=2) ChatGptModelFactory.create_batch(size=2)
model = ChatGptModelFactory(priority=42) model = ChatGptModelFactory(priority=42)
user = UserFactory(username=test_settings.SUPERUSER)
access_token = AccessTokenFactory(user_id=user.id)
models = dbsession.query(ChatGptModels).all() models = dbsession.query(ChatGptModels).all()
assert len(models) == 3 assert len(models) == 3
response = await rest_client.delete(url=f"/api/chatgpt/models/{model.id}") response = await rest_client.delete(
url=f"/api/chatgpt/models/{model.id}",
headers={"BOT-API-KEY": access_token.token},
)
assert response.status_code == 204 assert response.status_code == 204
models = dbsession.query(ChatGptModels).all() models = dbsession.query(ChatGptModels).all()

View File

@ -1,13 +1,15 @@
import uuid
import factory import factory
from core.auth.models.users import User from core.auth.models.users import AccessToken, User
from tests.integration.factories.utils import BaseModelFactory from tests.integration.factories.utils import BaseModelFactory
class UserFactory(BaseModelFactory): class UserFactory(BaseModelFactory):
id = factory.Sequence(lambda n: n + 1) id = factory.Sequence(lambda n: n + 1)
email = factory.Faker("email") email = factory.Faker("email")
username = factory.Faker("user_name", locale="en_EN") username = factory.Faker("user_name", locale="en")
first_name = factory.Faker("word") first_name = factory.Faker("word")
last_name = factory.Faker("word") last_name = factory.Faker("word")
ban_reason = factory.Faker("text", max_nb_chars=100) ban_reason = factory.Faker("text", max_nb_chars=100)
@ -18,3 +20,12 @@ class UserFactory(BaseModelFactory):
class Meta: class Meta:
model = User model = User
class AccessTokenFactory(BaseModelFactory):
user_id = factory.Sequence(lambda n: n + 1)
token = factory.LazyAttribute(lambda o: str(uuid.uuid4()))
created_at = factory.Faker("past_datetime")
class Meta:
model = AccessToken