Files
kv-ai/src/companion/orchestrator.py

156 lines
5.4 KiB
Python

"""Chat orchestrator - coordinates RAG, LLM, and memory for chat responses."""
from __future__ import annotations

import json
import logging
import time
from collections.abc import AsyncGenerator
from typing import Any

import httpx

from companion.config import Config
from companion.memory import Message, SessionMemory
from companion.prompts import build_system_prompt, format_conversation_history
from companion.rag.search import SearchEngine
class ChatOrchestrator:
    """Orchestrates chat by combining RAG, memory, and LLM streaming."""

    def __init__(
        self,
        config: Config,
        search_engine: SearchEngine,
        session_memory: SessionMemory,
    ):
        self.config = config
        self.search = search_engine
        self.memory = session_memory
        # NOTE(review): model_endpoint holds the backend *identifier*
        # (e.g. "llama.cpp"), not a URL; neither attribute is read by the
        # methods below — kept for external callers. Confirm they are used.
        self.model_endpoint = config.model.inference.backend
        self.model_path = config.model.inference.model_path

    async def chat(
        self,
        session_id: str,
        user_message: str,
        use_rag: bool = True,
    ) -> AsyncGenerator[str, None]:
        """Process a chat message and yield streaming responses.

        Args:
            session_id: Unique session identifier
            user_message: The user's input message
            use_rag: Whether to retrieve context from vault

        Yields:
            Streaming response chunks (SSE format)
        """
        # Retrieve relevant context from RAG if enabled. Retrieval is
        # best-effort: on failure we log and answer without vault context.
        retrieved_context: list[dict[str, Any]] = []
        if use_rag:
            try:
                retrieved_context = self.search.search(
                    user_message, top_k=self.config.rag.search.default_top_k
                )
            except Exception:
                logging.getLogger(__name__).exception("RAG retrieval failed")

        # Recent conversation turns, injected into the system prompt rather
        # than sent as separate chat messages.
        history = self.memory.get_messages(
            session_id, limit=self.config.companion.memory.session_turns
        )
        history_formatted = format_conversation_history(
            [{"role": msg.role, "content": msg.content} for msg in history]
        )

        system_prompt = build_system_prompt(
            persona=self.config.companion.persona.model_dump(),
            retrieved_context=retrieved_context if retrieved_context else None,
            memory_context=history_formatted if history_formatted else None,
        )

        # Record the user turn before streaming so it is persisted even if
        # the LLM call fails partway through.
        self.memory.add_message(session_id, "user", user_message)

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]

        # Stream SSE chunks to the caller while accumulating the assistant
        # text so the full reply can be stored in session memory afterwards.
        full_response = ""
        async for chunk in self._stream_llm_response(messages):
            yield chunk
            if not chunk.startswith("data: "):
                continue
            data = chunk[6:]  # Strip the SSE "data: " prefix
            if data == "[DONE]":
                continue
            try:
                delta = json.loads(data)
            except json.JSONDecodeError:
                continue
            if not isinstance(delta, dict):
                continue
            # llama.cpp-style chunks carry top-level "content"; Ollama
            # /api/chat chunks nest it under "message". Accept both.
            piece = delta.get("content")
            if piece is None:
                piece = delta.get("message", {}).get("content")
            if piece:
                full_response += piece

        # Store assistant response in memory (skip if nothing was produced).
        if full_response:
            self.memory.add_message(
                session_id,
                "assistant",
                full_response,
                metadata={"rag_context_used": len(retrieved_context) > 0},
            )

    async def _stream_llm_response(
        self, messages: list[dict[str, Any]]
    ) -> AsyncGenerator[str, None]:
        """Stream response from local LLM as SSE "data: ..." lines.

        Uses llama.cpp HTTP server format or Ollama API. Always terminates
        the stream with a "data: [DONE]" sentinel, even on error.
        """
        backend = self.config.model.inference.backend
        if backend == "llama.cpp":
            # Default llama.cpp server port.
            base_url = "http://localhost:8080"
        else:
            # Any other backend value falls back to the Ollama host derived
            # from the embedding endpoint (its "/api" suffix stripped).
            base_url = self.config.rag.embedding.base_url.replace("/api", "")

        # NOTE(review): the chat model name is derived from the embedding
        # model by stripping "-embed" — fragile; a dedicated chat-model
        # config key would be safer. Confirm against the config schema.
        model_name = self.config.rag.embedding.model.replace("-embed", "")

        try:
            # Generous timeout: local models can be slow to produce tokens.
            async with httpx.AsyncClient(timeout=120.0) as client:
                # client.stream() yields lines as they arrive; a plain
                # post() would buffer the whole body and defeat streaming.
                async with client.stream(
                    "POST",
                    f"{base_url}/api/chat",
                    json={
                        "model": model_name,
                        "messages": messages,
                        "stream": True,
                        "options": {
                            "temperature": self.config.companion.chat.default_temperature,
                        },
                    },
                ) as response:
                    async for line in response.aiter_lines():
                        if line.strip():
                            yield f"data: {line}\n\n"
        except Exception as e:
            logging.getLogger(__name__).exception("LLM streaming failed")
            # json.dumps escapes quotes/backslashes in the exception text so
            # the emitted SSE payload is always valid JSON.
            payload = json.dumps({"error": f"LLM streaming failed: {e}"})
            yield f"data: {payload}\n\n"
        yield "data: [DONE]\n\n"

    def get_session_history(self, session_id: str, limit: int = 50) -> list[Message]:
        """Get conversation history for a session."""
        return self.memory.get_messages(session_id, limit=limit)

    def clear_session(self, session_id: str) -> None:
        """Clear a session's conversation history."""
        self.memory.clear_session(session_id)