diff --git a/.env b/.env
index ef79d5e..a447112 100644
--- a/.env
+++ b/.env
@@ -2,14 +2,14 @@
 PYTHONDONTWRITEBYTECODE=1
 PYTHONUNBUFFERED=1

 # Postgres
-POSTGRES_HOST=db
+POSTGRES_HOST=localhost
 POSTGRES_PORT=5432
 POSTGRES_DB=devdb
 POSTGRES_USER=devdb
 POSTGRES_PASSWORD=secret

 # Redis
-REDIS_HOST=inmemory
+REDIS_HOST=localhost
 REDIS_PORT=6379
 REDIS_DB=2
diff --git a/app/services/llm.py b/app/services/llm.py
index eb8903c..db29c28 100644
--- a/app/services/llm.py
+++ b/app/services/llm.py
@@ -10,36 +10,42 @@ class StreamLLMService:
     async def stream_chat(self, prompt: str) -> AsyncGenerator[bytes, None]:
         """Stream chat completion responses from LLM."""
-        # Send initial user message
-        yield orjson.dumps({"role": "user", "content": prompt}) + b"\n"
+        # Send user message first
+        user_msg = {
+            "role": "user",
+            "content": prompt,
+        }
+        yield orjson.dumps(user_msg) + b"\n"

+        # Open client as context manager and stream responses
         async with httpx.AsyncClient(base_url=self.base_url) as client:
-            request_data = {
-                "model": self.model,
-                "messages": [{"role": "user", "content": prompt}],
-                "stream": True,
-            }
-
             async with client.stream(
-                "POST", "/chat/completions", json=request_data, timeout=60.0
+                "POST",
+                "/chat/completions",
+                json={
+                    "model": self.model,
+                    "messages": [{"role": "user", "content": prompt}],
+                    "stream": True,
+                },
+                timeout=60.0,
             ) as response:
                 async for line in response.aiter_lines():
-                    if not (line.startswith("data: ") and line != "data: [DONE]"):
-                        continue
-                    try:
-                        data = orjson.loads(line[6:])  # Skip "data: " prefix
-                        if (
-                            content := data.get("choices", [{}])[0]
-                            .get("delta", {})
-                            .get("content", "")
-                        ):
-                            yield (
-                                orjson.dumps({"role": "model", "content": content})
-                                + b"\n"
+                    if line.startswith("data: ") and line != "data: [DONE]":
+                        try:
+                            json_line = line[6:]  # Remove "data: " prefix
+                            data = orjson.loads(json_line)
+                            content = (
+                                data.get("choices", [{}])[0]
+                                .get("delta", {})
+                                .get("content", "")
                             )
-                    except Exception:
-                        pass
+                            if content:
+                                model_msg = {"role": "model", "content": content}
+                                yield orjson.dumps(model_msg) + b"\n"
+                        except Exception:
+                            pass


+# FastAPI dependency
 def get_llm_service(base_url: Optional[str] = None) -> StreamLLMService:
-    return StreamLLMService(base_url=base_url)
+    return StreamLLMService(base_url=base_url or "http://localhost:11434/v1")
diff --git a/compose.yml b/compose.yml
index 892008b..9640db2 100644
--- a/compose.yml
+++ b/compose.yml
@@ -1,6 +1,7 @@
 services:
   app:
     container_name: fsap_app
+    network_mode: host
     build: .
     env_file:
       - .env
@@ -22,6 +23,7 @@ services:

   db:
     container_name: fsap_db
+    network_mode: host
     build:
       context: ./db
       dockerfile: Dockerfile
@@ -46,6 +48,7 @@ services:

   inmemory:
     image: redis:latest
+    network_mode: host
     container_name: fsap_inmemory
     ports:
       - "6379:6379"
diff --git a/tests/chat.py b/tests/chat.py
index 617ee14..25bfa2f 100644
--- a/tests/chat.py
+++ b/tests/chat.py
@@ -1,33 +1,53 @@
-import anyio
+from typing import Optional, AsyncGenerator
+
 import httpx
 import orjson

-API_URL = "http://localhost:8000/chat/"

+class StreamLLMService:
+    def __init__(self, base_url: str = "http://localhost:11434/v1"):
+        self.base_url = base_url
+        self.model = "llama3.2"

-async def chat_with_endpoint():
-    async with httpx.AsyncClient() as client:
-        while True:
-            prompt = input("\nYou: ")
-            if prompt.lower() == "exit":
-                break
-
-            print("\nModel: ", end="", flush=True)
-            try:
-                async with client.stream(
-                    "POST", API_URL, data={"prompt": prompt}, timeout=60
-                ) as response:
-                    async for chunk in response.aiter_lines():
-                        if not chunk:
-                            continue
+    async def stream_chat(self, prompt: str) -> AsyncGenerator[bytes, None]:
+        """Stream chat completion responses from LLM."""
+        # Send user message first
+        user_msg = {
+            "role": "user",
+            "content": prompt,
+        }
+        yield orjson.dumps(user_msg) + b"\n"
+
+        # Open client as context manager and stream responses
+        async with httpx.AsyncClient(base_url=self.base_url) as client:
+            async with client.stream(
+                "POST",
+                "/chat/completions",
+                json={
+                    "model": self.model,
+                    "messages": [{"role": "user", "content": prompt}],
+                    "stream": True,
+                },
+                timeout=60.0,
+            ) as response:
+                async for line in response.aiter_lines():
+                    print(line)
+                    if line.startswith("data: ") and line != "data: [DONE]":
                         try:
-                            print(orjson.loads(chunk)["content"], end="", flush=True)
-                        except Exception as e:
-                            print(f"\nError parsing chunk: {e}")
-            except httpx.RequestError as e:
-                print(f"\nConnection error: {e}")
+                            json_line = line[6:]  # Remove "data: " prefix
+                            data = orjson.loads(json_line)
+                            content = (
+                                data.get("choices", [{}])[0]
+                                .get("delta", {})
+                                .get("content", "")
+                            )
+                            if content:
+                                model_msg = {"role": "model", "content": content}
+                                yield orjson.dumps(model_msg) + b"\n"
+                        except Exception:
+                            pass


-if __name__ == "__main__":
-    anyio.run(chat_with_endpoint)
+# FastAPI dependency
+def get_llm_service(base_url: Optional[str] = None) -> StreamLLMService:
+    return StreamLLMService(base_url=base_url or "http://localhost:11434/v1")
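
For reference, here is a minimal sketch of the FastAPI route that the original tests/chat.py client was exercising (it POSTed a `prompt` form field to `http://localhost:8000/chat/`). The route module is not part of this diff, so the path, field name, and media type below are assumptions; the sketch simply injects `get_llm_service` via `Depends` and streams the newline-delimited JSON produced by `stream_chat` back to the caller.

```python
# Hypothetical route wiring -- not included in the diff above.
# Assumes app/services/llm.py exposes StreamLLMService and get_llm_service
# as shown, and that python-multipart is installed for Form parsing.
from typing import Annotated

from fastapi import Depends, FastAPI, Form
from fastapi.responses import StreamingResponse

from app.services.llm import StreamLLMService, get_llm_service

app = FastAPI()


@app.post("/chat/")
async def chat(
    prompt: Annotated[str, Form()],
    llm: Annotated[StreamLLMService, Depends(get_llm_service)],
) -> StreamingResponse:
    # stream_chat yields one JSON object per line (the user message first,
    # then model deltas), so NDJSON is a reasonable media type here.
    return StreamingResponse(
        llm.stream_chat(prompt), media_type="application/x-ndjson"
    )
```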