mirror of
https://github.com/Balshgit/gpt_chat_bot.git
synced 2025-09-10 17:20:41 +03:00
close dangerous api methods under api auth (#78)
* close dangerous api methods under api auth * rename access_token method
This commit is contained in:
parent
8266342214
commit
de55d873f9
1
bot_microservice/api/bot/constants.py
Normal file
1
bot_microservice/api/bot/constants.py
Normal file
@ -0,0 +1 @@
|
||||
BOT_ACCESS_API_HEADER = "BOT-API-KEY"
|
@ -3,13 +3,19 @@ from starlette import status
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from telegram import Update
|
||||
|
||||
from api.bot.deps import get_bot_queue, get_chatgpt_service, get_update_from_request
|
||||
from api.bot.deps import (
|
||||
get_access_to_bot_api_or_403,
|
||||
get_bot_queue,
|
||||
get_chatgpt_service,
|
||||
get_update_from_request,
|
||||
)
|
||||
from api.bot.serializers import (
|
||||
ChatGptModelSerializer,
|
||||
ChatGptModelsPrioritySerializer,
|
||||
GETChatGptModelsSerializer,
|
||||
LightChatGptModel,
|
||||
)
|
||||
from api.exceptions import PermissionMissingResponse
|
||||
from core.bot.app import BotQueue
|
||||
from core.bot.services import ChatGptService
|
||||
from settings.config import settings
|
||||
@ -53,6 +59,10 @@ async def models_list(
|
||||
@router.put(
|
||||
"/chatgpt/models/{model_id}/priority",
|
||||
name="bot:change_model_priority",
|
||||
dependencies=[Depends(get_access_to_bot_api_or_403)],
|
||||
responses={
|
||||
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
|
||||
},
|
||||
response_class=Response,
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="change gpt model priority",
|
||||
@ -69,6 +79,10 @@ async def change_model_priority(
|
||||
@router.put(
|
||||
"/chatgpt/models/priority/reset",
|
||||
name="bot:reset_models_priority",
|
||||
dependencies=[Depends(get_access_to_bot_api_or_403)],
|
||||
responses={
|
||||
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
|
||||
},
|
||||
response_class=Response,
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="reset all model priority to default",
|
||||
@ -83,6 +97,11 @@ async def reset_models_priority(
|
||||
@router.post(
|
||||
"/chatgpt/models",
|
||||
name="bot:add_new_model",
|
||||
dependencies=[Depends(get_access_to_bot_api_or_403)],
|
||||
responses={
|
||||
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
|
||||
status.HTTP_201_CREATED: {"model": ChatGptModelSerializer},
|
||||
},
|
||||
response_model=ChatGptModelSerializer,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="add new model",
|
||||
@ -100,6 +119,10 @@ async def add_new_model(
|
||||
@router.delete(
|
||||
"/chatgpt/models/{model_id}",
|
||||
name="bot:delete_gpt_model",
|
||||
dependencies=[Depends(get_access_to_bot_api_or_403)],
|
||||
responses={
|
||||
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
|
||||
},
|
||||
response_class=Response,
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="delete gpt model",
|
||||
|
@ -1,15 +1,17 @@
|
||||
from fastapi import Depends
|
||||
from fastapi import Depends, Header, HTTPException
|
||||
from starlette import status
|
||||
from starlette.requests import Request
|
||||
from telegram import Update
|
||||
|
||||
from api.auth.deps import get_user_service
|
||||
from api.bot.constants import BOT_ACCESS_API_HEADER
|
||||
from api.deps import get_database
|
||||
from core.auth.services import UserService
|
||||
from core.bot.app import BotApplication, BotQueue
|
||||
from core.bot.repository import ChatGPTRepository
|
||||
from core.bot.services import ChatGptService
|
||||
from infra.database.db_adapter import Database
|
||||
from settings.config import AppSettings, get_settings
|
||||
from settings.config import AppSettings, get_settings, settings
|
||||
|
||||
|
||||
def get_bot_app(request: Request) -> BotApplication:
|
||||
@ -40,3 +42,13 @@ def get_chatgpt_service(
|
||||
user_service: UserService = Depends(get_user_service),
|
||||
) -> ChatGptService:
|
||||
return ChatGptService(repository=chatgpt_repository, user_service=user_service)
|
||||
|
||||
|
||||
async def get_access_to_bot_api_or_403(
|
||||
bot_api_key: str | None = Header(None, alias=BOT_ACCESS_API_HEADER, description="Ключ доступа до API бота"),
|
||||
user_service: UserService = Depends(get_user_service),
|
||||
) -> None:
|
||||
access_token = await user_service.get_user_access_token_by_username(settings.SUPERUSER)
|
||||
|
||||
if not access_token or access_token != bot_api_key:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Could not validate api header")
|
||||
|
@ -1,11 +1,39 @@
|
||||
from typing import Any
|
||||
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from starlette import status
|
||||
from starlette.requests import Request
|
||||
|
||||
from api.base_schemas import BaseError, BaseResponse
|
||||
|
||||
|
||||
class BaseAPIException(Exception):
|
||||
pass
|
||||
_content_type: str = "application/json"
|
||||
model: type[BaseResponse] = BaseResponse
|
||||
status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
title: str | None = None
|
||||
type: str | None = None
|
||||
detail: str | None = None
|
||||
instance: str | None = None
|
||||
headers: dict[str, str] | None = None
|
||||
|
||||
def __init__(self, **ctx: Any) -> None:
|
||||
self.__dict__ = ctx
|
||||
|
||||
@classmethod
|
||||
def example(cls) -> dict[str, Any] | None:
|
||||
if isinstance(cls.model.Config.schema_extra, dict): # type: ignore[attr-defined]
|
||||
return cls.model.Config.schema_extra.get("example") # type: ignore[attr-defined]
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def response(cls) -> dict[str, Any]:
|
||||
return {
|
||||
"model": cls.model,
|
||||
"content": {
|
||||
cls._content_type: cls.model.Config.schema_extra, # type: ignore[attr-defined]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class InternalServerError(BaseError):
|
||||
@ -28,6 +56,31 @@ class InternalServerErrorResponse(BaseResponse):
|
||||
}
|
||||
|
||||
|
||||
class PermissionMissing(BaseError):
|
||||
pass
|
||||
|
||||
|
||||
class PermissionMissingResponse(BaseResponse):
|
||||
error: PermissionMissing
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"status": 403,
|
||||
"error": {
|
||||
"type": "PermissionMissing",
|
||||
"title": "Permission required for this endpoint is missing",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class PermissionMissingError(BaseAPIException):
|
||||
model = PermissionMissingResponse
|
||||
status_code = status.HTTP_403_FORBIDDEN
|
||||
title: str = "Permission required for this endpoint is missing"
|
||||
|
||||
|
||||
async def internal_server_error_handler(_request: Request, _exception: Exception) -> ORJSONResponse:
|
||||
error = InternalServerError(title="Something went wrong!", type="InternalServerError")
|
||||
response = InternalServerErrorResponse(status=500, error=error).model_dump(exclude_unset=True)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import INTEGER, TIMESTAMP, VARCHAR, Boolean, ForeignKey, String
|
||||
@ -26,7 +27,7 @@ class User(Base):
|
||||
"UserQuestionCount",
|
||||
primaryjoin="UserQuestionCount.user_id == User.id",
|
||||
backref="user",
|
||||
lazy="selectin",
|
||||
lazy="noload",
|
||||
uselist=False,
|
||||
cascade="delete",
|
||||
)
|
||||
@ -68,11 +69,25 @@ class AccessToken(Base):
|
||||
__tablename__ = "access_token" # type: ignore[assignment]
|
||||
|
||||
user_id = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), nullable=False)
|
||||
token: Mapped[str] = mapped_column(String(length=42), primary_key=True)
|
||||
token: Mapped[str] = mapped_column(String(length=42), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
TIMESTAMP(timezone=True), index=True, nullable=False, default=datetime.now
|
||||
)
|
||||
|
||||
user: Mapped["User"] = relationship(
|
||||
"User",
|
||||
backref="access_token",
|
||||
lazy="selectin",
|
||||
uselist=False,
|
||||
cascade="expunge",
|
||||
)
|
||||
|
||||
@property
|
||||
def username(self) -> str:
|
||||
if self.user:
|
||||
return self.user.username
|
||||
return ""
|
||||
|
||||
|
||||
class UserQuestionCount(Base):
|
||||
__tablename__ = "user_question_count" # type: ignore[assignment]
|
||||
|
@ -5,7 +5,7 @@ from sqlalchemy.dialects.sqlite import insert
|
||||
from sqlalchemy.orm import load_only
|
||||
|
||||
from core.auth.dto import UserIsBannedDTO
|
||||
from core.auth.models.users import User, UserQuestionCount
|
||||
from core.auth.models.users import AccessToken, User, UserQuestionCount
|
||||
from infra.database.db_adapter import Database
|
||||
|
||||
|
||||
@ -76,3 +76,10 @@ class UserRepository:
|
||||
|
||||
async with self.db.session() as session:
|
||||
await session.execute(query)
|
||||
|
||||
async def get_user_access_token(self, username: str | None) -> str | None:
|
||||
query = select(AccessToken.token).join(AccessToken.user).where(User.username == username)
|
||||
|
||||
async with self.db.session() as session:
|
||||
result = await session.execute(query)
|
||||
return result.scalar()
|
||||
|
@ -62,6 +62,9 @@ class UserService:
|
||||
async def check_user_is_banned(self, user_id: int) -> UserIsBannedDTO:
|
||||
return await self.repository.check_user_is_banned(user_id)
|
||||
|
||||
async def get_user_access_token_by_username(self, username: str | None) -> str | None:
|
||||
return await self.repository.get_user_access_token(username)
|
||||
|
||||
|
||||
def check_user_is_banned(func: Any) -> Any:
|
||||
@wraps(func)
|
||||
|
@ -1,8 +1,11 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqladmin import Admin, ModelView
|
||||
from sqlalchemy import Select, desc, select
|
||||
from sqlalchemy.orm import contains_eager, load_only
|
||||
from starlette.requests import Request
|
||||
|
||||
from core.auth.models.users import User
|
||||
from core.auth.models.users import AccessToken, User, UserQuestionCount
|
||||
from core.bot.models.chatgpt import ChatGptModels
|
||||
from core.utils import build_uri
|
||||
from settings.config import settings
|
||||
@ -36,10 +39,34 @@ class UserAdmin(ModelView, model=User):
|
||||
"question_count",
|
||||
User.created_at,
|
||||
]
|
||||
column_sortable_list = [User.created_at]
|
||||
|
||||
column_default_sort = ("created_at", True)
|
||||
form_widget_args = {"created_at": {"readonly": True}}
|
||||
|
||||
def list_query(self, request: Request) -> Select[tuple[User]]:
|
||||
return (
|
||||
select(User)
|
||||
.options(
|
||||
load_only(
|
||||
User.id,
|
||||
User.username,
|
||||
User.first_name,
|
||||
User.last_name,
|
||||
User.is_active,
|
||||
User.created_at,
|
||||
)
|
||||
)
|
||||
.outerjoin(User.user_question_count)
|
||||
.options(contains_eager(User.user_question_count).options(load_only(UserQuestionCount.question_count)))
|
||||
).order_by(desc(UserQuestionCount.question_count))
|
||||
|
||||
|
||||
class AccessTokenAdmin(ModelView, model=AccessToken):
|
||||
name = "API access token"
|
||||
name_plural = "API access tokens"
|
||||
column_list = [AccessToken.user_id, "username", AccessToken.token, AccessToken.created_at]
|
||||
form_widget_args = {"created_at": {"readonly": True}}
|
||||
|
||||
|
||||
def create_admin(application: "Application") -> Admin:
|
||||
admin = Admin(
|
||||
@ -51,4 +78,5 @@ def create_admin(application: "Application") -> Admin:
|
||||
)
|
||||
admin.add_view(ChatGptAdmin)
|
||||
admin.add_view(UserAdmin)
|
||||
admin.add_view(AccessTokenAdmin)
|
||||
return admin
|
||||
|
@ -10,9 +10,8 @@ from datetime import datetime
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy import TIMESTAMP
|
||||
from sqlalchemy.dialects.sqlite import insert
|
||||
|
||||
from core.auth.models.users import User
|
||||
from core.auth.models.users import AccessToken, User
|
||||
from core.auth.utils import create_password_hash
|
||||
from infra.database.deps import get_sync_session
|
||||
from settings.config import settings
|
||||
@ -58,8 +57,14 @@ def upgrade() -> None:
|
||||
return
|
||||
with get_sync_session() as session:
|
||||
hashed_password = create_password_hash(password.get_secret_value())
|
||||
query = insert(User).values({"username": username, "hashed_password": hashed_password})
|
||||
session.execute(query)
|
||||
user = User(username=username, hashed_password=hashed_password)
|
||||
session.add(user)
|
||||
session.flush()
|
||||
session.refresh(user)
|
||||
|
||||
access_token = AccessToken(user_id=user.id)
|
||||
session.add(access_token)
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
|
@ -11,6 +11,8 @@ TELEGRAM_API_TOKEN="123456789:AABBCCDDEEFFaabbccddeeff-1234567890"
|
||||
# set to true to start with webhook. Else bot will start on polling method
|
||||
START_WITH_WEBHOOK="false"
|
||||
|
||||
SUPERUSER="Superuser"
|
||||
|
||||
# ==== domain settings ====
|
||||
DOMAIN="http://localhost"
|
||||
URL_PREFIX=
|
||||
|
@ -6,7 +6,9 @@ from sqlalchemy import desc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.bot.models.chatgpt import ChatGptModels
|
||||
from settings.config import AppSettings
|
||||
from tests.integration.factories.bot import ChatGptModelFactory
|
||||
from tests.integration.factories.user import AccessTokenFactory, UserFactory
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.asyncio,
|
||||
@ -51,11 +53,18 @@ async def test_change_chatgpt_model_priority(
|
||||
dbsession: Session,
|
||||
rest_client: AsyncClient,
|
||||
faker: Faker,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
model1 = ChatGptModelFactory(priority=0)
|
||||
model2 = ChatGptModelFactory(priority=1)
|
||||
priority = faker.random_int(min=2, max=7)
|
||||
response = await rest_client.put(url=f"/api/chatgpt/models/{model2.id}/priority", json={"priority": priority})
|
||||
user = UserFactory(username=test_settings.SUPERUSER)
|
||||
access_token = AccessTokenFactory(user_id=user.id)
|
||||
response = await rest_client.put(
|
||||
url=f"/api/chatgpt/models/{model2.id}/priority",
|
||||
json={"priority": priority},
|
||||
headers={"BOT-API-KEY": access_token.token},
|
||||
)
|
||||
assert response.status_code == 202
|
||||
|
||||
upd_model1, upd_model2 = dbsession.query(ChatGptModels).order_by(ChatGptModels.priority).all()
|
||||
@ -69,11 +78,18 @@ async def test_change_chatgpt_model_priority(
|
||||
async def test_reset_chatgpt_models_priority(
|
||||
dbsession: Session,
|
||||
rest_client: AsyncClient,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
ChatGptModelFactory.create_batch(size=4)
|
||||
ChatGptModelFactory(priority=42)
|
||||
|
||||
response = await rest_client.put(url="/api/chatgpt/models/priority/reset")
|
||||
user = UserFactory(username=test_settings.SUPERUSER)
|
||||
access_token = AccessTokenFactory(user_id=user.id)
|
||||
|
||||
response = await rest_client.put(
|
||||
url="/api/chatgpt/models/priority/reset",
|
||||
headers={"BOT-API-KEY": access_token.token},
|
||||
)
|
||||
assert response.status_code == 202
|
||||
|
||||
models = dbsession.query(ChatGptModels).all()
|
||||
@ -89,10 +105,14 @@ async def test_create_new_chatgpt_model(
|
||||
dbsession: Session,
|
||||
rest_client: AsyncClient,
|
||||
faker: Faker,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
ChatGptModelFactory.create_batch(size=2)
|
||||
ChatGptModelFactory(priority=42)
|
||||
|
||||
user = UserFactory(username=test_settings.SUPERUSER)
|
||||
access_token = AccessTokenFactory(user_id=user.id)
|
||||
|
||||
model_name = "new-gpt-model"
|
||||
model_priority = faker.random_int(min=1, max=5)
|
||||
|
||||
@ -105,6 +125,7 @@ async def test_create_new_chatgpt_model(
|
||||
"model": model_name,
|
||||
"priority": model_priority,
|
||||
},
|
||||
headers={"BOT-API-KEY": access_token.token},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
@ -125,9 +146,12 @@ async def test_add_existing_chatgpt_model(
|
||||
dbsession: Session,
|
||||
rest_client: AsyncClient,
|
||||
faker: Faker,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
ChatGptModelFactory.create_batch(size=2)
|
||||
model = ChatGptModelFactory(priority=42)
|
||||
user = UserFactory(username=test_settings.SUPERUSER)
|
||||
access_token = AccessTokenFactory(user_id=user.id)
|
||||
|
||||
model_name = model.model
|
||||
model_priority = faker.random_int(min=1, max=5)
|
||||
@ -141,6 +165,7 @@ async def test_add_existing_chatgpt_model(
|
||||
"model": model_name,
|
||||
"priority": model_priority,
|
||||
},
|
||||
headers={"BOT-API-KEY": access_token.token},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
@ -151,14 +176,21 @@ async def test_add_existing_chatgpt_model(
|
||||
async def test_delete_chatgpt_model(
|
||||
dbsession: Session,
|
||||
rest_client: AsyncClient,
|
||||
test_settings: AppSettings,
|
||||
) -> None:
|
||||
ChatGptModelFactory.create_batch(size=2)
|
||||
model = ChatGptModelFactory(priority=42)
|
||||
|
||||
user = UserFactory(username=test_settings.SUPERUSER)
|
||||
access_token = AccessTokenFactory(user_id=user.id)
|
||||
|
||||
models = dbsession.query(ChatGptModels).all()
|
||||
assert len(models) == 3
|
||||
|
||||
response = await rest_client.delete(url=f"/api/chatgpt/models/{model.id}")
|
||||
response = await rest_client.delete(
|
||||
url=f"/api/chatgpt/models/{model.id}",
|
||||
headers={"BOT-API-KEY": access_token.token},
|
||||
)
|
||||
assert response.status_code == 204
|
||||
|
||||
models = dbsession.query(ChatGptModels).all()
|
||||
|
@ -1,13 +1,15 @@
|
||||
import uuid
|
||||
|
||||
import factory
|
||||
|
||||
from core.auth.models.users import User
|
||||
from core.auth.models.users import AccessToken, User
|
||||
from tests.integration.factories.utils import BaseModelFactory
|
||||
|
||||
|
||||
class UserFactory(BaseModelFactory):
|
||||
id = factory.Sequence(lambda n: n + 1)
|
||||
email = factory.Faker("email")
|
||||
username = factory.Faker("user_name", locale="en_EN")
|
||||
username = factory.Faker("user_name", locale="en")
|
||||
first_name = factory.Faker("word")
|
||||
last_name = factory.Faker("word")
|
||||
ban_reason = factory.Faker("text", max_nb_chars=100)
|
||||
@ -18,3 +20,12 @@ class UserFactory(BaseModelFactory):
|
||||
|
||||
class Meta:
|
||||
model = User
|
||||
|
||||
|
||||
class AccessTokenFactory(BaseModelFactory):
|
||||
user_id = factory.Sequence(lambda n: n + 1)
|
||||
token = factory.LazyAttribute(lambda o: str(uuid.uuid4()))
|
||||
created_at = factory.Faker("past_datetime")
|
||||
|
||||
class Meta:
|
||||
model = AccessToken
|
||||
|
Loading…
x
Reference in New Issue
Block a user