refactor chat testing client

This commit is contained in:
grillazz 2025-05-03 08:47:08 +02:00
parent 6fd874c2cc
commit 61ba8cc12c

View File

@ -2,29 +2,30 @@ import anyio
import httpx import httpx
import orjson import orjson
API_URL = "http://localhost:8000/chat/"
async def chat_with_endpoint(): async def chat_with_endpoint():
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
while True: while True:
# Get user input
prompt = input("\nYou: ") prompt = input("\nYou: ")
if prompt.lower() == "exit": if prompt.lower() == "exit":
break break
# Send request to the API
print("\nModel: ", end="", flush=True) print("\nModel: ", end="", flush=True)
async with client.stream( try:
"POST", async with client.stream("POST", API_URL, data={"prompt": prompt}, timeout=60) as response:
"http://localhost:8000/chat/", async for chunk in response.aiter_lines():
data={"prompt": prompt}, if not chunk:
timeout=60 continue
) as response:
async for chunk in response.aiter_lines():
if chunk:
try: try:
data = orjson.loads(chunk) print(orjson.loads(chunk)["content"], end="", flush=True)
print(data["content"], end="", flush=True)
except Exception as e: except Exception as e:
print(f"\nError parsing chunk: {e}") print(f"\nError parsing chunk: {e}")
except httpx.RequestError as e:
print(f"\nConnection error: {e}")
if __name__ == "__main__": if __name__ == "__main__":
anyio.run(chat_with_endpoint) anyio.run(chat_with_endpoint)