diff --git a/app/api/ml.py b/app/api/ml.py
index e8d2c77..ca600a9 100644
--- a/app/api/ml.py
+++ b/app/api/ml.py
@@ -11,9 +11,6 @@
 logger = AppLogger().get_logger()
 router = APIRouter()
 
-@router.post('/chat/')
-async def chat(
-    prompt: Annotated[str, Form()],
-    llm_service = Depends(get_llm_service)
-):
-    return StreamingResponse(llm_service.stream_chat(prompt), media_type="text/plain")
\ No newline at end of file
+@router.post("/chat/")
+async def chat(prompt: Annotated[str, Form()], llm_service=Depends(get_llm_service)):
+    return StreamingResponse(llm_service.stream_chat(prompt), media_type="text/plain")
diff --git a/tests/chat.py b/tests/chat.py
index f99c819..617ee14 100644
--- a/tests/chat.py
+++ b/tests/chat.py
@@ -14,7 +14,9 @@ async def chat_with_endpoint():
         print("\nModel: ", end="", flush=True)
 
         try:
-            async with client.stream("POST", API_URL, data={"prompt": prompt}, timeout=60) as response:
+            async with client.stream(
+                "POST", API_URL, data={"prompt": prompt}, timeout=60
+            ) as response:
                 async for chunk in response.aiter_lines():
                     if not chunk:
                         continue
@@ -28,4 +30,4 @@
 
 
 if __name__ == "__main__":
-    anyio.run(chat_with_endpoint)
\ No newline at end of file
+    anyio.run(chat_with_endpoint)