WIP: Phase 4 forge extract module with tests
This commit is contained in:
155
src/companion/orchestrator.py
Normal file
155
src/companion/orchestrator.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""Chat orchestrator - coordinates RAG, LLM, and memory for chat responses."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from companion.config import Config
|
||||
from companion.memory import Message, SessionMemory
|
||||
from companion.prompts import build_system_prompt, format_conversation_history
|
||||
from companion.rag.search import SearchEngine
|
||||
|
||||
|
||||
class ChatOrchestrator:
    """Orchestrates chat by combining RAG, memory, and LLM streaming.

    A chat turn retrieves optional context through the search engine, reads
    recent session history from memory, builds a system prompt, streams the
    model reply as SSE ``data:`` events, and stores both the user message
    and the reassembled assistant reply in session memory.
    """

    # Prefix of every SSE event this class emits and parses.
    _SSE_PREFIX = "data: "
    # Payload of the terminating event emitted after a streaming failure.
    _DONE_SENTINEL = "[DONE]"
    # Default llama.cpp server address used when the backend is "llama.cpp".
    _LLAMA_CPP_URL = "http://localhost:8080"
    # Timeout (seconds) applied to the HTTP client used for inference.
    _REQUEST_TIMEOUT = 120.0

    def __init__(
        self,
        config: Config,
        search_engine: SearchEngine,
        session_memory: SessionMemory,
    ) -> None:
        """Store the collaborators and cache inference settings from config.

        Args:
            config: Application configuration.
            search_engine: Engine used to retrieve RAG context.
            session_memory: Store for per-session conversation messages.
        """
        self.config = config
        self.search = search_engine
        self.memory = session_memory
        self.model_endpoint = config.model.inference.backend
        self.model_path = config.model.inference.model_path

    async def chat(
        self,
        session_id: str,
        user_message: str,
        use_rag: bool = True,
    ) -> AsyncGenerator[str, None]:
        """Process a chat message and yield streaming responses.

        Args:
            session_id: Unique session identifier
            user_message: The user's input message
            use_rag: Whether to retrieve context from vault

        Yields:
            Streaming response chunks (SSE format)
        """
        # Retrieve relevant context from RAG if enabled. Retrieval is
        # best-effort: a failure is reported and the chat continues without
        # retrieved context.
        retrieved_context: list[dict[str, Any]] = []
        if use_rag:
            try:
                retrieved_context = self.search.search(
                    user_message, top_k=self.config.rag.search.default_top_k
                )
            except Exception as e:
                print(f"RAG retrieval failed: {e}")

        # History is read BEFORE the new user message is stored, so the
        # prompt's memory context does not duplicate the current message.
        history = self.memory.get_messages(
            session_id, limit=self.config.companion.memory.session_turns
        )
        history_formatted = format_conversation_history(
            [{"role": msg.role, "content": msg.content} for msg in history]
        )

        system_prompt = build_system_prompt(
            persona=self.config.companion.persona.model_dump(),
            retrieved_context=retrieved_context or None,
            memory_context=history_formatted or None,
        )

        self.memory.add_message(session_id, "user", user_message)

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]

        # Forward every chunk to the caller unchanged while collecting the
        # text fragments needed to store the complete reply afterwards.
        response_parts: list[str] = []
        async for chunk in self._stream_llm_response(messages):
            yield chunk
            fragment = self._extract_content(chunk)
            if fragment:
                response_parts.append(fragment)

        # Store assistant response in memory (only when any text arrived).
        full_response = "".join(response_parts)
        if full_response:
            self.memory.add_message(
                session_id,
                "assistant",
                full_response,
                metadata={"rag_context_used": bool(retrieved_context)},
            )

    @staticmethod
    def _extract_content(chunk: str) -> str:
        """Return the assistant text carried by one SSE chunk, or ``""``.

        Accepts a JSON object with a top-level ``"content"`` string, or one
        with the text nested under ``"message" -> "content"`` (the shape
        documented for Ollama's streaming ``/api/chat`` endpoint). Chunks
        that are not ``data:`` events, the ``[DONE]`` sentinel, invalid
        JSON, and non-string content all yield an empty string.
        """
        prefix = ChatOrchestrator._SSE_PREFIX
        if not chunk.startswith(prefix):
            return ""

        # Strip the trailing "\n\n" event terminator before comparing with
        # the sentinel; otherwise the comparison can never match.
        payload = chunk[len(prefix):].strip()
        if not payload or payload == ChatOrchestrator._DONE_SENTINEL:
            return ""

        try:
            delta = json.loads(payload)
        except json.JSONDecodeError:
            return ""
        if not isinstance(delta, dict):
            return ""

        content = delta.get("content")
        if isinstance(content, str):
            return content

        message = delta.get("message")
        if isinstance(message, dict):
            nested = message.get("content")
            if isinstance(nested, str):
                return nested
        return ""

    @staticmethod
    def _error_event(error: Exception) -> str:
        """Format *error* as an SSE ``data:`` event with a valid JSON body.

        ``json.dumps`` escapes quotes, backslashes and newlines in the
        exception text, which plain string interpolation would not.
        """
        body = json.dumps({"error": f"LLM streaming failed: {error}"})
        return f"{ChatOrchestrator._SSE_PREFIX}{body}\n\n"

    def _resolve_base_url(self) -> str:
        """Return the base URL of the inference server.

        The backend values ``"llama.cpp"`` and ``"http://localhost:8080"``
        map to the default llama.cpp server address. Any other backend falls
        back to the embedding base URL with every ``"/api"`` removed.
        """
        backend = self.config.model.inference.backend
        if backend in ("llama.cpp", self._LLAMA_CPP_URL):
            return self._LLAMA_CPP_URL
        return self.config.rag.embedding.base_url.replace("/api", "")

    async def _stream_llm_response(
        self, messages: list[dict[str, Any]]
    ) -> AsyncGenerator[str, None]:
        """Stream response from local LLM.

        Posts ``messages`` to ``{base_url}/api/chat`` and re-emits every
        non-blank response line as an SSE ``data:`` event. On any failure an
        error event followed by a ``[DONE]`` event is yielded instead of
        raising, so the consumer always receives a well-formed stream.

        Args:
            messages: Chat messages (role/content dicts) sent to the model.

        Yields:
            SSE-formatted strings, each terminated by a blank line.
        """
        base_url = self._resolve_base_url()
        payload = {
            # NOTE(review): the chat model name is derived from the
            # embedding model name by dropping "-embed", and
            # self.model_path is unused here - confirm this is intended.
            "model": self.config.rag.embedding.model.replace("-embed", ""),
            "messages": messages,
            "stream": True,
            "options": {
                "temperature": self.config.companion.chat.default_temperature,
            },
        }

        try:
            async with httpx.AsyncClient(timeout=self._REQUEST_TIMEOUT) as client:
                # client.stream() hands back the response before the body is
                # read, so lines are forwarded as they arrive; client.post()
                # would buffer the entire reply first.
                async with client.stream(
                    "POST", f"{base_url}/api/chat", json=payload
                ) as response:
                    async for line in response.aiter_lines():
                        if line.strip():
                            yield f"{self._SSE_PREFIX}{line}\n\n"
        except Exception as e:
            # Boundary handler: surface the failure to the client as events.
            yield self._error_event(e)
            yield f"{self._SSE_PREFIX}{self._DONE_SENTINEL}\n\n"

    def get_session_history(self, session_id: str, limit: int = 50) -> list[Message]:
        """Get conversation history for a session.

        Args:
            session_id: Unique session identifier.
            limit: Passed through to the memory store as ``limit``.

        Returns:
            Whatever ``SessionMemory.get_messages`` returns for the session.
        """
        return self.memory.get_messages(session_id, limit=limit)

    def clear_session(self, session_id: str) -> None:
        """Clear a session's conversation history."""
        self.memory.clear_session(session_id)
|
||||
Reference in New Issue
Block a user