"""Chat orchestrator - coordinates RAG, LLM, and memory for chat responses.""" from __future__ import annotations import json import time from collections.abc import AsyncGenerator from typing import Any import httpx from companion.config import Config from companion.memory import Message, SessionMemory from companion.prompts import build_system_prompt, format_conversation_history from companion.rag.search import SearchEngine class ChatOrchestrator: """Orchestrates chat by combining RAG, memory, and LLM streaming.""" def __init__( self, config: Config, search_engine: SearchEngine, session_memory: SessionMemory, ): self.config = config self.search = search_engine self.memory = session_memory self.model_endpoint = config.model.inference.backend self.model_path = config.model.inference.model_path async def chat( self, session_id: str, user_message: str, use_rag: bool = True, ) -> AsyncGenerator[str, None]: """Process a chat message and yield streaming responses. Args: session_id: Unique session identifier user_message: The user's input message use_rag: Whether to retrieve context from vault Yields: Streaming response chunks (SSE format) """ # Retrieve relevant context from RAG if enabled retrieved_context: list[dict[str, Any]] = [] if use_rag: try: retrieved_context = self.search.search( user_message, top_k=self.config.rag.search.default_top_k ) except Exception as e: # Log error but continue without RAG context print(f"RAG retrieval failed: {e}") # Get conversation history history = self.memory.get_messages( session_id, limit=self.config.companion.memory.session_turns ) # Format history for prompt history_formatted = format_conversation_history( [{"role": msg.role, "content": msg.content} for msg in history] ) # Build system prompt system_prompt = build_system_prompt( persona=self.config.companion.persona.model_dump(), retrieved_context=retrieved_context if retrieved_context else None, memory_context=history_formatted if history_formatted else None, ) # Add user message to memory self.memory.add_message(session_id, "user", user_message) # Build messages for LLM messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_message}, ] # Stream from LLM (using llama.cpp server or similar) full_response = "" async for chunk in self._stream_llm_response(messages): yield chunk if chunk.startswith("data: "): data = chunk[6:] # Remove "data: " prefix if data != "[DONE]": try: delta = json.loads(data) if "content" in delta: full_response += delta["content"] except json.JSONDecodeError: pass # Store assistant response in memory if full_response: self.memory.add_message( session_id, "assistant", full_response, metadata={"rag_context_used": len(retrieved_context) > 0}, ) async def _stream_llm_response( self, messages: list[dict[str, Any]] ) -> AsyncGenerator[str, None]: """Stream response from local LLM. Uses llama.cpp HTTP server format or Ollama API. """ # Try llama.cpp server first base_url = self.config.model.inference.backend if base_url == "llama.cpp": base_url = "http://localhost:8080" # Default llama.cpp server port # Default to Ollama API if base_url not in ["llama.cpp", "http://localhost:8080"]: base_url = self.config.rag.embedding.base_url.replace("/api", "") try: async with httpx.AsyncClient(timeout=120.0) as client: # Try Ollama chat endpoint response = await client.post( f"{base_url}/api/chat", json={ "model": self.config.rag.embedding.model.replace("-embed", ""), "messages": messages, "stream": True, "options": { "temperature": self.config.companion.chat.default_temperature, }, }, timeout=120.0, ) async for line in response.aiter_lines(): if line.strip(): yield f"data: {line}\n\n" except Exception as e: yield f'data: {{"error": "LLM streaming failed: {str(e)}"}}\n\n' yield "data: [DONE]\n\n" def get_session_history(self, session_id: str, limit: int = 50) -> list[Message]: """Get conversation history for a session.""" return self.memory.get_messages(session_id, limit=limit) def clear_session(self, session_id: str) -> None: """Clear a session's conversation history.""" self.memory.clear_session(session_id)