diff --git a/src/companion/api.py b/src/companion/api.py index 0dff87e..2dc6015 100644 --- a/src/companion/api.py +++ b/src/companion/api.py @@ -25,7 +25,8 @@ class ChatRequest(BaseModel): message: str session_id: str | None = None - temperature: float | None = None + stream: bool = True + use_rag: bool = True class ChatResponse(BaseModel): @@ -118,78 +119,53 @@ async def chat(request: ChatRequest) -> EventSourceResponse: # Generate or use existing session ID session_id = request.session_id or str(uuid.uuid4()) - # Validate temperature if provided - temperature = request.temperature - if temperature is not None and not 0.0 <= temperature <= 2.0: - raise HTTPException( - status_code=400, detail="Temperature must be between 0.0 and 2.0" - ) - async def event_generator() -> AsyncGenerator[str, None]: """Generate SSE events.""" try: - # Get conversation history - history = memory.get_history(session_id) - - # Retrieve relevant context - context_chunks = search_engine.search( - request.message, - top_k=config.rag.search.default_top_k, - ) - - # Stream response from LLM - full_response = "" - async for chunk in orchestrator.stream_response( - query=request.message, - history=history, - context_chunks=context_chunks, - temperature=temperature, + # Stream response from orchestrator + async for chunk in orchestrator.chat( + session_id=session_id, + user_message=request.message, + use_rag=request.use_rag, ): - full_response += chunk - yield json.dumps( - { - "type": "chunk", - "content": chunk, - } - ) - - # Store the conversation - memory.add_message(session_id, "user", request.message) - memory.add_message(session_id, "assistant", full_response) - - # Send sources - sources = [ - { - "file": chunk.get("source_file", ""), - "section": chunk.get("section"), - "date": chunk.get("date"), - } - for chunk in context_chunks[:5] # Top 5 sources - ] - yield json.dumps( - { - "type": "sources", - "sources": sources, - } - ) - - # Send done event - yield json.dumps( - { - "type": "done", - "session_id": session_id, - } - ) + # Parse the SSE data format + if chunk.startswith("data: "): + data = chunk[6:] + if data == "[DONE]": + yield json.dumps({"type": "done", "session_id": session_id}) + else: + try: + parsed = json.loads(data) + if "content" in parsed: + yield json.dumps( + {"type": "chunk", "content": parsed["content"]} + ) + elif "citations" in parsed: + yield json.dumps( + { + "type": "citations", + "citations": parsed["citations"], + } + ) + elif "error" in parsed: + yield json.dumps( + {"type": "error", "message": parsed["error"]} + ) + else: + yield data + except json.JSONDecodeError: + # Pass through raw data + yield data + else: + yield chunk except Exception as e: - yield json.dumps( - { - "type": "error", - "message": str(e), - } - ) + yield json.dumps({"type": "error", "message": str(e)}) - return EventSourceResponse(event_generator()) + return EventSourceResponse( + event_generator(), + headers={"X-Session-ID": session_id}, + ) @app.get("/sessions/{session_id}/history") diff --git a/src/companion/orchestrator.py b/src/companion/orchestrator.py index ab05b5a..1ed8bbf 100644 --- a/src/companion/orchestrator.py +++ b/src/companion/orchestrator.py @@ -12,7 +12,7 @@ import httpx from companion.config import Config from companion.memory import Message, SessionMemory from companion.prompts import build_system_prompt, format_conversation_history -from companion.rag.search import SearchEngine +from companion.rag.search import SearchEngine, SearchResult class ChatOrchestrator: @@ -47,7 +47,7 @@ class ChatOrchestrator: Streaming response chunks (SSE format) """ # Retrieve relevant context from RAG if enabled - retrieved_context: list[dict[str, Any]] = [] + retrieved_context: list[SearchResult] = [] if use_rag: try: retrieved_context = self.search.search( @@ -103,9 +103,21 @@ class ChatOrchestrator: session_id, "assistant", full_response, - metadata={"rag_context_used": len(retrieved_context) > 0}, + metadata={ + "rag_context_used": len(retrieved_context) > 0, + "citations": [ctx.to_dict() for ctx in retrieved_context[:5]] + if retrieved_context + else [], + }, ) + # Yield citations after the main response + if retrieved_context: + citations_data = { + "citations": [ctx.to_dict() for ctx in retrieved_context[:5]] + } + yield f"data: {json.dumps(citations_data)}\n\n" + async def _stream_llm_response( self, messages: list[dict[str, Any]] ) -> AsyncGenerator[str, None]: diff --git a/src/companion/prompts.py b/src/companion/prompts.py index 0afd922..4e84f9d 100644 --- a/src/companion/prompts.py +++ b/src/companion/prompts.py @@ -4,10 +4,12 @@ from __future__ import annotations from typing import Any +from companion.rag.search import SearchResult + def build_system_prompt( persona: dict[str, Any], - retrieved_context: list[dict[str, Any]] | None = None, + retrieved_context: list[SearchResult] | None = None, memory_context: str | None = None, ) -> str: """Build the system prompt for the companion. @@ -49,10 +51,9 @@ Core principles: if retrieved_context: context_section += "\n\nRelevant context from your vault:\n" for i, ctx in enumerate(retrieved_context[:8], 1): # Limit to top 8 chunks - source = ctx.get("source_file", "unknown") - text = ctx.get("text", "").strip() + text = ctx.text.strip() if text: - context_section += f"[{i}] From {source}:\n{text}\n\n" + context_section += f"[{i}] From {ctx.citation}:\n{text}\n\n" # Add memory context if available memory_section = "" diff --git a/src/companion/rag/search.py b/src/companion/rag/search.py index 598ceb1..eece647 100644 --- a/src/companion/rag/search.py +++ b/src/companion/rag/search.py @@ -1,9 +1,52 @@ +from dataclasses import dataclass from typing import Any from companion.rag.embedder import OllamaEmbedder from companion.rag.vector_store import VectorStore +@dataclass +class SearchResult: + """Structured search result with citation information.""" + + id: str + text: str + source_file: str + source_directory: str + section: str | None + date: str | None + tags: list[str] + chunk_index: int + total_chunks: int + distance: float + + @property + def citation(self) -> str: + """Generate a citation string for this result.""" + parts = [self.source_file] + if self.section: + parts.append(f"#{self.section}") + if self.date: + parts.append(f"({self.date})") + return " - ".join(parts) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for API serialization.""" + return { + "id": self.id, + "text": self.text, + "source_file": self.source_file, + "source_directory": self.source_directory, + "section": self.section, + "date": self.date, + "tags": self.tags, + "chunk_index": self.chunk_index, + "total_chunks": self.total_chunks, + "distance": self.distance, + "citation": self.citation, + } + + class SearchEngine: """Search engine for semantic search using vector embeddings. @@ -50,7 +93,7 @@ class SearchEngine: query: str, top_k: int | None = None, filters: dict[str, Any] | None = None, - ) -> list[dict[str, Any]]: + ) -> list[SearchResult]: """Search for relevant documents using semantic similarity. Args: @@ -59,7 +102,7 @@ class SearchEngine: filters: Optional metadata filters to apply Returns: - List of matching documents with similarity scores + List of SearchResult objects with similarity scores Raises: RuntimeError: If embedding generation fails @@ -76,14 +119,33 @@ class SearchEngine: except RuntimeError as e: raise RuntimeError(f"Failed to generate embedding for query: {e}") from e - results = self.vector_store.search(query_embedding, top_k=k, filters=filters) + raw_results = self.vector_store.search( + query_embedding, top_k=k, filters=filters + ) - if self.similarity_threshold > 0 and results: - results = [ + if self.similarity_threshold > 0 and raw_results: + raw_results = [ r - for r in results + for r in raw_results if r.get(self._DISTANCE_FIELD, float("inf")) <= self.similarity_threshold ] + # Convert raw results to SearchResult objects + results: list[SearchResult] = [] + for r in raw_results: + result = SearchResult( + id=r.get("id", ""), + text=r.get("text", ""), + source_file=r.get("source_file", ""), + source_directory=r.get("source_directory", ""), + section=r.get("section"), + date=r.get("date"), + tags=r.get("tags") or [], + chunk_index=r.get("chunk_index", 0), + total_chunks=r.get("total_chunks", 1), + distance=r.get(self._DISTANCE_FIELD, 1.0), + ) + results.append(result) + return results diff --git a/ui/src/App.tsx b/ui/src/App.tsx index 358896e..5804f70 100644 --- a/ui/src/App.tsx +++ b/ui/src/App.tsx @@ -2,7 +2,8 @@ import { useState, useRef, useEffect } from 'react' import './App.css' import MessageList from './components/MessageList' import ChatInput from './components/ChatInput' -import { useChatStream } from './hooks/useChatStream' +import CitationsPanel from './components/CitationsPanel' +import { useChatStream, Citation } from './hooks/useChatStream' export interface Message { role: 'user' | 'assistant' @@ -13,15 +14,13 @@ function App() { const [messages, setMessages] = useState([]) const [input, setInput] = useState('') const [isLoading, setIsLoading] = useState(false) + const [citations, setCitations] = useState([]) + const [showCitations, setShowCitations] = useState(false) const messagesEndRef = useRef(null) const { sendMessage } = useChatStream() - const scrollToBottom = () => { - messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' }) - } - useEffect(() => { - scrollToBottom() + messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' }) }, [messages]) const handleSend = async () => { @@ -31,24 +30,37 @@ function App() { setInput('') setMessages(prev => [...prev, { role: 'user', content: userMessage }]) setIsLoading(true) + setCitations([]) + setShowCitations(false) let assistantContent = '' - await sendMessage(userMessage, (chunk) => { - assistantContent += chunk - setMessages(prev => { - const newMessages = [...prev] - const lastMsg = newMessages[newMessages.length - 1] - if (lastMsg?.role === 'assistant') { - lastMsg.content = assistantContent - } else { - newMessages.push({ role: 'assistant', content: assistantContent }) - } - return newMessages - }) + await sendMessage(userMessage, { + onChunk: (chunk) => { + assistantContent += chunk + setMessages(prev => { + const newMessages = [...prev] + const lastMsg = newMessages[newMessages.length - 1] + if (lastMsg?.role === 'assistant') { + lastMsg.content = assistantContent + } else { + newMessages.push({ role: 'assistant', content: assistantContent }) + } + return newMessages + }) + }, + onCitations: (newCitations) => { + setCitations(newCitations) + setShowCitations(newCitations.length > 0) + }, + onDone: () => { + setIsLoading(false) + }, + onError: (error) => { + console.error('Stream error:', error) + setIsLoading(false) + } }) - - setIsLoading(false) } const handleKeyDown = (e: React.KeyboardEvent) => { @@ -76,6 +88,11 @@ function App() { disabled={isLoading} /> + setShowCitations(false)} + /> ) } diff --git a/ui/src/components/CitationsPanel.css b/ui/src/components/CitationsPanel.css new file mode 100644 index 0000000..84750b9 --- /dev/null +++ b/ui/src/components/CitationsPanel.css @@ -0,0 +1,124 @@ +/* CitationsPanel.css */ +.citations-overlay { + position: fixed; + top: 0; + left: 0; + right: 0; + bottom: 0; + background: rgba(0, 0, 0, 0.5); + z-index: 100; + border: none; + width: 100%; + cursor: pointer; +} + +.citations-panel { + position: fixed; + top: 0; + right: 0; + width: 400px; + height: 100vh; + background: var(--bg-secondary); + border-left: 1px solid var(--border); + z-index: 101; + display: flex; + flex-direction: column; + animation: slideIn 0.2s ease-out; +} + +@keyframes slideIn { + from { + transform: translateX(100%); + } + to { + transform: translateX(0); + } +} + +.citations-header { + display: flex; + justify-content: space-between; + align-items: center; + padding: 16px 20px; + border-bottom: 1px solid var(--border); +} + +.citations-header h3 { + margin: 0; + font-size: 16px; + font-weight: 600; + color: var(--text-primary); +} + +.close-button { + background: none; + border: none; + color: var(--text-secondary); + font-size: 24px; + cursor: pointer; + padding: 0 4px; + line-height: 1; +} + +.close-button:hover { + color: var(--text-primary); +} + +.citations-list { + flex: 1; + overflow-y: auto; + padding: 16px 20px; +} + +.citation-item { + display: flex; + gap: 12px; + padding: 16px; + margin-bottom: 12px; + background: var(--bg-tertiary); + border-radius: 8px; + border: 1px solid var(--border); +} + +.citation-number { + flex-shrink: 0; + font-weight: 600; + color: var(--accent-primary); + font-size: 14px; +} + +.citation-content { + flex: 1; + min-width: 0; +} + +.citation-source { + font-size: 12px; + color: var(--text-secondary); + margin-bottom: 8px; + word-break: break-all; +} + +.citation-text { + font-size: 14px; + color: var(--text-primary); + line-height: 1.5; + white-space: pre-wrap; + word-wrap: break-word; +} + +.citation-tags { + display: flex; + flex-wrap: wrap; + gap: 6px; + margin-top: 12px; +} + +.citation-tag { + font-size: 11px; + padding: 2px 8px; + background: var(--bg-secondary); + color: var(--text-secondary); + border-radius: 4px; + border: 1px solid var(--border); +} diff --git a/ui/src/components/CitationsPanel.tsx b/ui/src/components/CitationsPanel.tsx new file mode 100644 index 0000000..6fd9693 --- /dev/null +++ b/ui/src/components/CitationsPanel.tsx @@ -0,0 +1,47 @@ +import { Citation } from '../hooks/useChatStream' +import './CitationsPanel.css' + +interface CitationsPanelProps { + citations: Citation[] + isOpen: boolean + onClose: () => void +} + +export default function CitationsPanel({ citations, isOpen, onClose }: CitationsPanelProps) { + if (!isOpen || citations.length === 0) return null + + return ( + <> + + +
+ {citations.map((citation, index) => ( +
+
[{index + 1}]
+
+
{citation.citation}
+
{citation.text}
+ {citation.tags && citation.tags.length > 0 && ( +
+ {citation.tags.map(tag => ( + {tag} + ))} +
+ )} +
+
+ ))} +
+ + + ) +} diff --git a/ui/src/hooks/useChatStream.ts b/ui/src/hooks/useChatStream.ts index aea3445..2b23734 100644 --- a/ui/src/hooks/useChatStream.ts +++ b/ui/src/hooks/useChatStream.ts @@ -2,13 +2,33 @@ import { useState } from 'react' const API_BASE = '/api' +export interface Citation { + id: string + text: string + source_file: string + source_directory: string + section: string | null + date: string | null + tags: string[] + citation: string +} + +export interface StreamCallbacks { + onChunk: (chunk: string) => void + onCitations?: (citations: Citation[]) => void + onError?: (error: string) => void + onDone?: () => void +} + export function useChatStream() { const [sessionId, setSessionId] = useState(null) const sendMessage = async ( message: string, - onChunk: (chunk: string) => void + callbacks: StreamCallbacks ): Promise => { + const { onChunk, onCitations, onError, onDone } = callbacks + const response = await fetch(`${API_BASE}/chat`, { method: 'POST', headers: { @@ -18,6 +38,7 @@ export function useChatStream() { message, session_id: sessionId, stream: true, + use_rag: true, }), }) @@ -50,15 +71,33 @@ export function useChatStream() { if (line.startsWith('data: ')) { const data = line.slice(6) if (data === '[DONE]') { + onDone?.() return } try { const parsed = JSON.parse(data) - if (parsed.content) { - onChunk(parsed.content) + switch (parsed.type) { + case 'chunk': + if (parsed.content) { + onChunk(parsed.content) + } + break + case 'citations': + if (parsed.citations && onCitations) { + onCitations(parsed.citations as Citation[]) + } + break + case 'error': + if (parsed.message && onError) { + onError(parsed.message) + } + break + case 'done': + onDone?.() + return } } catch { - // Ignore parse errors + // Ignore parse errors for non-JSON data } } }