mirror of
https://github.com/Balshgit/gpt_chat_bot.git
synced 2025-09-11 22:30:41 +03:00
close dangerous api methods under api auth (#78)
* close dangerous api methods under api auth * rename access_token method
This commit is contained in:
parent
8266342214
commit
de55d873f9
1
bot_microservice/api/bot/constants.py
Normal file
1
bot_microservice/api/bot/constants.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
BOT_ACCESS_API_HEADER = "BOT-API-KEY"
|
@ -3,13 +3,19 @@ from starlette import status
|
|||||||
from starlette.responses import JSONResponse, Response
|
from starlette.responses import JSONResponse, Response
|
||||||
from telegram import Update
|
from telegram import Update
|
||||||
|
|
||||||
from api.bot.deps import get_bot_queue, get_chatgpt_service, get_update_from_request
|
from api.bot.deps import (
|
||||||
|
get_access_to_bot_api_or_403,
|
||||||
|
get_bot_queue,
|
||||||
|
get_chatgpt_service,
|
||||||
|
get_update_from_request,
|
||||||
|
)
|
||||||
from api.bot.serializers import (
|
from api.bot.serializers import (
|
||||||
ChatGptModelSerializer,
|
ChatGptModelSerializer,
|
||||||
ChatGptModelsPrioritySerializer,
|
ChatGptModelsPrioritySerializer,
|
||||||
GETChatGptModelsSerializer,
|
GETChatGptModelsSerializer,
|
||||||
LightChatGptModel,
|
LightChatGptModel,
|
||||||
)
|
)
|
||||||
|
from api.exceptions import PermissionMissingResponse
|
||||||
from core.bot.app import BotQueue
|
from core.bot.app import BotQueue
|
||||||
from core.bot.services import ChatGptService
|
from core.bot.services import ChatGptService
|
||||||
from settings.config import settings
|
from settings.config import settings
|
||||||
@ -53,6 +59,10 @@ async def models_list(
|
|||||||
@router.put(
|
@router.put(
|
||||||
"/chatgpt/models/{model_id}/priority",
|
"/chatgpt/models/{model_id}/priority",
|
||||||
name="bot:change_model_priority",
|
name="bot:change_model_priority",
|
||||||
|
dependencies=[Depends(get_access_to_bot_api_or_403)],
|
||||||
|
responses={
|
||||||
|
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
|
||||||
|
},
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
status_code=status.HTTP_202_ACCEPTED,
|
status_code=status.HTTP_202_ACCEPTED,
|
||||||
summary="change gpt model priority",
|
summary="change gpt model priority",
|
||||||
@ -69,6 +79,10 @@ async def change_model_priority(
|
|||||||
@router.put(
|
@router.put(
|
||||||
"/chatgpt/models/priority/reset",
|
"/chatgpt/models/priority/reset",
|
||||||
name="bot:reset_models_priority",
|
name="bot:reset_models_priority",
|
||||||
|
dependencies=[Depends(get_access_to_bot_api_or_403)],
|
||||||
|
responses={
|
||||||
|
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
|
||||||
|
},
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
status_code=status.HTTP_202_ACCEPTED,
|
status_code=status.HTTP_202_ACCEPTED,
|
||||||
summary="reset all model priority to default",
|
summary="reset all model priority to default",
|
||||||
@ -83,6 +97,11 @@ async def reset_models_priority(
|
|||||||
@router.post(
|
@router.post(
|
||||||
"/chatgpt/models",
|
"/chatgpt/models",
|
||||||
name="bot:add_new_model",
|
name="bot:add_new_model",
|
||||||
|
dependencies=[Depends(get_access_to_bot_api_or_403)],
|
||||||
|
responses={
|
||||||
|
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
|
||||||
|
status.HTTP_201_CREATED: {"model": ChatGptModelSerializer},
|
||||||
|
},
|
||||||
response_model=ChatGptModelSerializer,
|
response_model=ChatGptModelSerializer,
|
||||||
status_code=status.HTTP_201_CREATED,
|
status_code=status.HTTP_201_CREATED,
|
||||||
summary="add new model",
|
summary="add new model",
|
||||||
@ -100,6 +119,10 @@ async def add_new_model(
|
|||||||
@router.delete(
|
@router.delete(
|
||||||
"/chatgpt/models/{model_id}",
|
"/chatgpt/models/{model_id}",
|
||||||
name="bot:delete_gpt_model",
|
name="bot:delete_gpt_model",
|
||||||
|
dependencies=[Depends(get_access_to_bot_api_or_403)],
|
||||||
|
responses={
|
||||||
|
status.HTTP_403_FORBIDDEN: {"model": PermissionMissingResponse},
|
||||||
|
},
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
status_code=status.HTTP_204_NO_CONTENT,
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
summary="delete gpt model",
|
summary="delete gpt model",
|
||||||
|
@ -1,15 +1,17 @@
|
|||||||
from fastapi import Depends
|
from fastapi import Depends, Header, HTTPException
|
||||||
|
from starlette import status
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from telegram import Update
|
from telegram import Update
|
||||||
|
|
||||||
from api.auth.deps import get_user_service
|
from api.auth.deps import get_user_service
|
||||||
|
from api.bot.constants import BOT_ACCESS_API_HEADER
|
||||||
from api.deps import get_database
|
from api.deps import get_database
|
||||||
from core.auth.services import UserService
|
from core.auth.services import UserService
|
||||||
from core.bot.app import BotApplication, BotQueue
|
from core.bot.app import BotApplication, BotQueue
|
||||||
from core.bot.repository import ChatGPTRepository
|
from core.bot.repository import ChatGPTRepository
|
||||||
from core.bot.services import ChatGptService
|
from core.bot.services import ChatGptService
|
||||||
from infra.database.db_adapter import Database
|
from infra.database.db_adapter import Database
|
||||||
from settings.config import AppSettings, get_settings
|
from settings.config import AppSettings, get_settings, settings
|
||||||
|
|
||||||
|
|
||||||
def get_bot_app(request: Request) -> BotApplication:
|
def get_bot_app(request: Request) -> BotApplication:
|
||||||
@ -40,3 +42,13 @@ def get_chatgpt_service(
|
|||||||
user_service: UserService = Depends(get_user_service),
|
user_service: UserService = Depends(get_user_service),
|
||||||
) -> ChatGptService:
|
) -> ChatGptService:
|
||||||
return ChatGptService(repository=chatgpt_repository, user_service=user_service)
|
return ChatGptService(repository=chatgpt_repository, user_service=user_service)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_access_to_bot_api_or_403(
|
||||||
|
bot_api_key: str | None = Header(None, alias=BOT_ACCESS_API_HEADER, description="Ключ доступа до API бота"),
|
||||||
|
user_service: UserService = Depends(get_user_service),
|
||||||
|
) -> None:
|
||||||
|
access_token = await user_service.get_user_access_token_by_username(settings.SUPERUSER)
|
||||||
|
|
||||||
|
if not access_token or access_token != bot_api_key:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Could not validate api header")
|
||||||
|
@ -1,11 +1,39 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from fastapi.responses import ORJSONResponse
|
from fastapi.responses import ORJSONResponse
|
||||||
|
from starlette import status
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from api.base_schemas import BaseError, BaseResponse
|
from api.base_schemas import BaseError, BaseResponse
|
||||||
|
|
||||||
|
|
||||||
class BaseAPIException(Exception):
|
class BaseAPIException(Exception):
|
||||||
pass
|
_content_type: str = "application/json"
|
||||||
|
model: type[BaseResponse] = BaseResponse
|
||||||
|
status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||||
|
title: str | None = None
|
||||||
|
type: str | None = None
|
||||||
|
detail: str | None = None
|
||||||
|
instance: str | None = None
|
||||||
|
headers: dict[str, str] | None = None
|
||||||
|
|
||||||
|
def __init__(self, **ctx: Any) -> None:
|
||||||
|
self.__dict__ = ctx
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def example(cls) -> dict[str, Any] | None:
|
||||||
|
if isinstance(cls.model.Config.schema_extra, dict): # type: ignore[attr-defined]
|
||||||
|
return cls.model.Config.schema_extra.get("example") # type: ignore[attr-defined]
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def response(cls) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"model": cls.model,
|
||||||
|
"content": {
|
||||||
|
cls._content_type: cls.model.Config.schema_extra, # type: ignore[attr-defined]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class InternalServerError(BaseError):
|
class InternalServerError(BaseError):
|
||||||
@ -28,6 +56,31 @@ class InternalServerErrorResponse(BaseResponse):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionMissing(BaseError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionMissingResponse(BaseResponse):
|
||||||
|
error: PermissionMissing
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"status": 403,
|
||||||
|
"error": {
|
||||||
|
"type": "PermissionMissing",
|
||||||
|
"title": "Permission required for this endpoint is missing",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionMissingError(BaseAPIException):
|
||||||
|
model = PermissionMissingResponse
|
||||||
|
status_code = status.HTTP_403_FORBIDDEN
|
||||||
|
title: str = "Permission required for this endpoint is missing"
|
||||||
|
|
||||||
|
|
||||||
async def internal_server_error_handler(_request: Request, _exception: Exception) -> ORJSONResponse:
|
async def internal_server_error_handler(_request: Request, _exception: Exception) -> ORJSONResponse:
|
||||||
error = InternalServerError(title="Something went wrong!", type="InternalServerError")
|
error = InternalServerError(title="Something went wrong!", type="InternalServerError")
|
||||||
response = InternalServerErrorResponse(status=500, error=error).model_dump(exclude_unset=True)
|
response = InternalServerErrorResponse(status=500, error=error).model_dump(exclude_unset=True)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import INTEGER, TIMESTAMP, VARCHAR, Boolean, ForeignKey, String
|
from sqlalchemy import INTEGER, TIMESTAMP, VARCHAR, Boolean, ForeignKey, String
|
||||||
@ -26,7 +27,7 @@ class User(Base):
|
|||||||
"UserQuestionCount",
|
"UserQuestionCount",
|
||||||
primaryjoin="UserQuestionCount.user_id == User.id",
|
primaryjoin="UserQuestionCount.user_id == User.id",
|
||||||
backref="user",
|
backref="user",
|
||||||
lazy="selectin",
|
lazy="noload",
|
||||||
uselist=False,
|
uselist=False,
|
||||||
cascade="delete",
|
cascade="delete",
|
||||||
)
|
)
|
||||||
@ -68,11 +69,25 @@ class AccessToken(Base):
|
|||||||
__tablename__ = "access_token" # type: ignore[assignment]
|
__tablename__ = "access_token" # type: ignore[assignment]
|
||||||
|
|
||||||
user_id = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), nullable=False)
|
user_id = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), nullable=False)
|
||||||
token: Mapped[str] = mapped_column(String(length=42), primary_key=True)
|
token: Mapped[str] = mapped_column(String(length=42), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
TIMESTAMP(timezone=True), index=True, nullable=False, default=datetime.now
|
TIMESTAMP(timezone=True), index=True, nullable=False, default=datetime.now
|
||||||
)
|
)
|
||||||
|
|
||||||
|
user: Mapped["User"] = relationship(
|
||||||
|
"User",
|
||||||
|
backref="access_token",
|
||||||
|
lazy="selectin",
|
||||||
|
uselist=False,
|
||||||
|
cascade="expunge",
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def username(self) -> str:
|
||||||
|
if self.user:
|
||||||
|
return self.user.username
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
class UserQuestionCount(Base):
|
class UserQuestionCount(Base):
|
||||||
__tablename__ = "user_question_count" # type: ignore[assignment]
|
__tablename__ = "user_question_count" # type: ignore[assignment]
|
||||||
|
@ -5,7 +5,7 @@ from sqlalchemy.dialects.sqlite import insert
|
|||||||
from sqlalchemy.orm import load_only
|
from sqlalchemy.orm import load_only
|
||||||
|
|
||||||
from core.auth.dto import UserIsBannedDTO
|
from core.auth.dto import UserIsBannedDTO
|
||||||
from core.auth.models.users import User, UserQuestionCount
|
from core.auth.models.users import AccessToken, User, UserQuestionCount
|
||||||
from infra.database.db_adapter import Database
|
from infra.database.db_adapter import Database
|
||||||
|
|
||||||
|
|
||||||
@ -76,3 +76,10 @@ class UserRepository:
|
|||||||
|
|
||||||
async with self.db.session() as session:
|
async with self.db.session() as session:
|
||||||
await session.execute(query)
|
await session.execute(query)
|
||||||
|
|
||||||
|
async def get_user_access_token(self, username: str | None) -> str | None:
|
||||||
|
query = select(AccessToken.token).join(AccessToken.user).where(User.username == username)
|
||||||
|
|
||||||
|
async with self.db.session() as session:
|
||||||
|
result = await session.execute(query)
|
||||||
|
return result.scalar()
|
||||||
|
@ -62,6 +62,9 @@ class UserService:
|
|||||||
async def check_user_is_banned(self, user_id: int) -> UserIsBannedDTO:
|
async def check_user_is_banned(self, user_id: int) -> UserIsBannedDTO:
|
||||||
return await self.repository.check_user_is_banned(user_id)
|
return await self.repository.check_user_is_banned(user_id)
|
||||||
|
|
||||||
|
async def get_user_access_token_by_username(self, username: str | None) -> str | None:
|
||||||
|
return await self.repository.get_user_access_token(username)
|
||||||
|
|
||||||
|
|
||||||
def check_user_is_banned(func: Any) -> Any:
|
def check_user_is_banned(func: Any) -> Any:
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from sqladmin import Admin, ModelView
|
from sqladmin import Admin, ModelView
|
||||||
|
from sqlalchemy import Select, desc, select
|
||||||
|
from sqlalchemy.orm import contains_eager, load_only
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
from core.auth.models.users import User
|
from core.auth.models.users import AccessToken, User, UserQuestionCount
|
||||||
from core.bot.models.chatgpt import ChatGptModels
|
from core.bot.models.chatgpt import ChatGptModels
|
||||||
from core.utils import build_uri
|
from core.utils import build_uri
|
||||||
from settings.config import settings
|
from settings.config import settings
|
||||||
@ -36,10 +39,34 @@ class UserAdmin(ModelView, model=User):
|
|||||||
"question_count",
|
"question_count",
|
||||||
User.created_at,
|
User.created_at,
|
||||||
]
|
]
|
||||||
column_sortable_list = [User.created_at]
|
|
||||||
column_default_sort = ("created_at", True)
|
column_default_sort = ("created_at", True)
|
||||||
form_widget_args = {"created_at": {"readonly": True}}
|
form_widget_args = {"created_at": {"readonly": True}}
|
||||||
|
|
||||||
|
def list_query(self, request: Request) -> Select[tuple[User]]:
|
||||||
|
return (
|
||||||
|
select(User)
|
||||||
|
.options(
|
||||||
|
load_only(
|
||||||
|
User.id,
|
||||||
|
User.username,
|
||||||
|
User.first_name,
|
||||||
|
User.last_name,
|
||||||
|
User.is_active,
|
||||||
|
User.created_at,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.outerjoin(User.user_question_count)
|
||||||
|
.options(contains_eager(User.user_question_count).options(load_only(UserQuestionCount.question_count)))
|
||||||
|
).order_by(desc(UserQuestionCount.question_count))
|
||||||
|
|
||||||
|
|
||||||
|
class AccessTokenAdmin(ModelView, model=AccessToken):
|
||||||
|
name = "API access token"
|
||||||
|
name_plural = "API access tokens"
|
||||||
|
column_list = [AccessToken.user_id, "username", AccessToken.token, AccessToken.created_at]
|
||||||
|
form_widget_args = {"created_at": {"readonly": True}}
|
||||||
|
|
||||||
|
|
||||||
def create_admin(application: "Application") -> Admin:
|
def create_admin(application: "Application") -> Admin:
|
||||||
admin = Admin(
|
admin = Admin(
|
||||||
@ -51,4 +78,5 @@ def create_admin(application: "Application") -> Admin:
|
|||||||
)
|
)
|
||||||
admin.add_view(ChatGptAdmin)
|
admin.add_view(ChatGptAdmin)
|
||||||
admin.add_view(UserAdmin)
|
admin.add_view(UserAdmin)
|
||||||
|
admin.add_view(AccessTokenAdmin)
|
||||||
return admin
|
return admin
|
||||||
|
@ -10,9 +10,8 @@ from datetime import datetime
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from alembic import op
|
from alembic import op
|
||||||
from sqlalchemy import TIMESTAMP
|
from sqlalchemy import TIMESTAMP
|
||||||
from sqlalchemy.dialects.sqlite import insert
|
|
||||||
|
|
||||||
from core.auth.models.users import User
|
from core.auth.models.users import AccessToken, User
|
||||||
from core.auth.utils import create_password_hash
|
from core.auth.utils import create_password_hash
|
||||||
from infra.database.deps import get_sync_session
|
from infra.database.deps import get_sync_session
|
||||||
from settings.config import settings
|
from settings.config import settings
|
||||||
@ -58,8 +57,14 @@ def upgrade() -> None:
|
|||||||
return
|
return
|
||||||
with get_sync_session() as session:
|
with get_sync_session() as session:
|
||||||
hashed_password = create_password_hash(password.get_secret_value())
|
hashed_password = create_password_hash(password.get_secret_value())
|
||||||
query = insert(User).values({"username": username, "hashed_password": hashed_password})
|
user = User(username=username, hashed_password=hashed_password)
|
||||||
session.execute(query)
|
session.add(user)
|
||||||
|
session.flush()
|
||||||
|
session.refresh(user)
|
||||||
|
|
||||||
|
access_token = AccessToken(user_id=user.id)
|
||||||
|
session.add(access_token)
|
||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@ -11,6 +11,8 @@ TELEGRAM_API_TOKEN="123456789:AABBCCDDEEFFaabbccddeeff-1234567890"
|
|||||||
# set to true to start with webhook. Else bot will start on polling method
|
# set to true to start with webhook. Else bot will start on polling method
|
||||||
START_WITH_WEBHOOK="false"
|
START_WITH_WEBHOOK="false"
|
||||||
|
|
||||||
|
SUPERUSER="Superuser"
|
||||||
|
|
||||||
# ==== domain settings ====
|
# ==== domain settings ====
|
||||||
DOMAIN="http://localhost"
|
DOMAIN="http://localhost"
|
||||||
URL_PREFIX=
|
URL_PREFIX=
|
||||||
|
@ -6,7 +6,9 @@ from sqlalchemy import desc
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.bot.models.chatgpt import ChatGptModels
|
from core.bot.models.chatgpt import ChatGptModels
|
||||||
|
from settings.config import AppSettings
|
||||||
from tests.integration.factories.bot import ChatGptModelFactory
|
from tests.integration.factories.bot import ChatGptModelFactory
|
||||||
|
from tests.integration.factories.user import AccessTokenFactory, UserFactory
|
||||||
|
|
||||||
pytestmark = [
|
pytestmark = [
|
||||||
pytest.mark.asyncio,
|
pytest.mark.asyncio,
|
||||||
@ -51,11 +53,18 @@ async def test_change_chatgpt_model_priority(
|
|||||||
dbsession: Session,
|
dbsession: Session,
|
||||||
rest_client: AsyncClient,
|
rest_client: AsyncClient,
|
||||||
faker: Faker,
|
faker: Faker,
|
||||||
|
test_settings: AppSettings,
|
||||||
) -> None:
|
) -> None:
|
||||||
model1 = ChatGptModelFactory(priority=0)
|
model1 = ChatGptModelFactory(priority=0)
|
||||||
model2 = ChatGptModelFactory(priority=1)
|
model2 = ChatGptModelFactory(priority=1)
|
||||||
priority = faker.random_int(min=2, max=7)
|
priority = faker.random_int(min=2, max=7)
|
||||||
response = await rest_client.put(url=f"/api/chatgpt/models/{model2.id}/priority", json={"priority": priority})
|
user = UserFactory(username=test_settings.SUPERUSER)
|
||||||
|
access_token = AccessTokenFactory(user_id=user.id)
|
||||||
|
response = await rest_client.put(
|
||||||
|
url=f"/api/chatgpt/models/{model2.id}/priority",
|
||||||
|
json={"priority": priority},
|
||||||
|
headers={"BOT-API-KEY": access_token.token},
|
||||||
|
)
|
||||||
assert response.status_code == 202
|
assert response.status_code == 202
|
||||||
|
|
||||||
upd_model1, upd_model2 = dbsession.query(ChatGptModels).order_by(ChatGptModels.priority).all()
|
upd_model1, upd_model2 = dbsession.query(ChatGptModels).order_by(ChatGptModels.priority).all()
|
||||||
@ -69,11 +78,18 @@ async def test_change_chatgpt_model_priority(
|
|||||||
async def test_reset_chatgpt_models_priority(
|
async def test_reset_chatgpt_models_priority(
|
||||||
dbsession: Session,
|
dbsession: Session,
|
||||||
rest_client: AsyncClient,
|
rest_client: AsyncClient,
|
||||||
|
test_settings: AppSettings,
|
||||||
) -> None:
|
) -> None:
|
||||||
ChatGptModelFactory.create_batch(size=4)
|
ChatGptModelFactory.create_batch(size=4)
|
||||||
ChatGptModelFactory(priority=42)
|
ChatGptModelFactory(priority=42)
|
||||||
|
|
||||||
response = await rest_client.put(url="/api/chatgpt/models/priority/reset")
|
user = UserFactory(username=test_settings.SUPERUSER)
|
||||||
|
access_token = AccessTokenFactory(user_id=user.id)
|
||||||
|
|
||||||
|
response = await rest_client.put(
|
||||||
|
url="/api/chatgpt/models/priority/reset",
|
||||||
|
headers={"BOT-API-KEY": access_token.token},
|
||||||
|
)
|
||||||
assert response.status_code == 202
|
assert response.status_code == 202
|
||||||
|
|
||||||
models = dbsession.query(ChatGptModels).all()
|
models = dbsession.query(ChatGptModels).all()
|
||||||
@ -89,10 +105,14 @@ async def test_create_new_chatgpt_model(
|
|||||||
dbsession: Session,
|
dbsession: Session,
|
||||||
rest_client: AsyncClient,
|
rest_client: AsyncClient,
|
||||||
faker: Faker,
|
faker: Faker,
|
||||||
|
test_settings: AppSettings,
|
||||||
) -> None:
|
) -> None:
|
||||||
ChatGptModelFactory.create_batch(size=2)
|
ChatGptModelFactory.create_batch(size=2)
|
||||||
ChatGptModelFactory(priority=42)
|
ChatGptModelFactory(priority=42)
|
||||||
|
|
||||||
|
user = UserFactory(username=test_settings.SUPERUSER)
|
||||||
|
access_token = AccessTokenFactory(user_id=user.id)
|
||||||
|
|
||||||
model_name = "new-gpt-model"
|
model_name = "new-gpt-model"
|
||||||
model_priority = faker.random_int(min=1, max=5)
|
model_priority = faker.random_int(min=1, max=5)
|
||||||
|
|
||||||
@ -105,6 +125,7 @@ async def test_create_new_chatgpt_model(
|
|||||||
"model": model_name,
|
"model": model_name,
|
||||||
"priority": model_priority,
|
"priority": model_priority,
|
||||||
},
|
},
|
||||||
|
headers={"BOT-API-KEY": access_token.token},
|
||||||
)
|
)
|
||||||
assert response.status_code == 201
|
assert response.status_code == 201
|
||||||
|
|
||||||
@ -125,9 +146,12 @@ async def test_add_existing_chatgpt_model(
|
|||||||
dbsession: Session,
|
dbsession: Session,
|
||||||
rest_client: AsyncClient,
|
rest_client: AsyncClient,
|
||||||
faker: Faker,
|
faker: Faker,
|
||||||
|
test_settings: AppSettings,
|
||||||
) -> None:
|
) -> None:
|
||||||
ChatGptModelFactory.create_batch(size=2)
|
ChatGptModelFactory.create_batch(size=2)
|
||||||
model = ChatGptModelFactory(priority=42)
|
model = ChatGptModelFactory(priority=42)
|
||||||
|
user = UserFactory(username=test_settings.SUPERUSER)
|
||||||
|
access_token = AccessTokenFactory(user_id=user.id)
|
||||||
|
|
||||||
model_name = model.model
|
model_name = model.model
|
||||||
model_priority = faker.random_int(min=1, max=5)
|
model_priority = faker.random_int(min=1, max=5)
|
||||||
@ -141,6 +165,7 @@ async def test_add_existing_chatgpt_model(
|
|||||||
"model": model_name,
|
"model": model_name,
|
||||||
"priority": model_priority,
|
"priority": model_priority,
|
||||||
},
|
},
|
||||||
|
headers={"BOT-API-KEY": access_token.token},
|
||||||
)
|
)
|
||||||
assert response.status_code == 201
|
assert response.status_code == 201
|
||||||
|
|
||||||
@ -151,14 +176,21 @@ async def test_add_existing_chatgpt_model(
|
|||||||
async def test_delete_chatgpt_model(
|
async def test_delete_chatgpt_model(
|
||||||
dbsession: Session,
|
dbsession: Session,
|
||||||
rest_client: AsyncClient,
|
rest_client: AsyncClient,
|
||||||
|
test_settings: AppSettings,
|
||||||
) -> None:
|
) -> None:
|
||||||
ChatGptModelFactory.create_batch(size=2)
|
ChatGptModelFactory.create_batch(size=2)
|
||||||
model = ChatGptModelFactory(priority=42)
|
model = ChatGptModelFactory(priority=42)
|
||||||
|
|
||||||
|
user = UserFactory(username=test_settings.SUPERUSER)
|
||||||
|
access_token = AccessTokenFactory(user_id=user.id)
|
||||||
|
|
||||||
models = dbsession.query(ChatGptModels).all()
|
models = dbsession.query(ChatGptModels).all()
|
||||||
assert len(models) == 3
|
assert len(models) == 3
|
||||||
|
|
||||||
response = await rest_client.delete(url=f"/api/chatgpt/models/{model.id}")
|
response = await rest_client.delete(
|
||||||
|
url=f"/api/chatgpt/models/{model.id}",
|
||||||
|
headers={"BOT-API-KEY": access_token.token},
|
||||||
|
)
|
||||||
assert response.status_code == 204
|
assert response.status_code == 204
|
||||||
|
|
||||||
models = dbsession.query(ChatGptModels).all()
|
models = dbsession.query(ChatGptModels).all()
|
||||||
|
@ -1,13 +1,15 @@
|
|||||||
|
import uuid
|
||||||
|
|
||||||
import factory
|
import factory
|
||||||
|
|
||||||
from core.auth.models.users import User
|
from core.auth.models.users import AccessToken, User
|
||||||
from tests.integration.factories.utils import BaseModelFactory
|
from tests.integration.factories.utils import BaseModelFactory
|
||||||
|
|
||||||
|
|
||||||
class UserFactory(BaseModelFactory):
|
class UserFactory(BaseModelFactory):
|
||||||
id = factory.Sequence(lambda n: n + 1)
|
id = factory.Sequence(lambda n: n + 1)
|
||||||
email = factory.Faker("email")
|
email = factory.Faker("email")
|
||||||
username = factory.Faker("user_name", locale="en_EN")
|
username = factory.Faker("user_name", locale="en")
|
||||||
first_name = factory.Faker("word")
|
first_name = factory.Faker("word")
|
||||||
last_name = factory.Faker("word")
|
last_name = factory.Faker("word")
|
||||||
ban_reason = factory.Faker("text", max_nb_chars=100)
|
ban_reason = factory.Faker("text", max_nb_chars=100)
|
||||||
@ -18,3 +20,12 @@ class UserFactory(BaseModelFactory):
|
|||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = User
|
model = User
|
||||||
|
|
||||||
|
|
||||||
|
class AccessTokenFactory(BaseModelFactory):
|
||||||
|
user_id = factory.Sequence(lambda n: n + 1)
|
||||||
|
token = factory.LazyAttribute(lambda o: str(uuid.uuid4()))
|
||||||
|
created_at = factory.Faker("past_datetime")
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = AccessToken
|
||||||
|
Loading…
x
Reference in New Issue
Block a user