feat: Phase 5 - citations, source highlighting, and UI polish

This commit is contained in:
2026-04-13 15:47:47 -04:00
parent e77fa69b31
commit 732555cf55
8 changed files with 381 additions and 103 deletions

View File

@@ -25,7 +25,8 @@ class ChatRequest(BaseModel):
message: str
session_id: str | None = None
temperature: float | None = None
stream: bool = True
use_rag: bool = True
class ChatResponse(BaseModel):
@@ -118,78 +119,53 @@ async def chat(request: ChatRequest) -> EventSourceResponse:
# Generate or use existing session ID
session_id = request.session_id or str(uuid.uuid4())
# Validate temperature if provided
temperature = request.temperature
if temperature is not None and not 0.0 <= temperature <= 2.0:
raise HTTPException(
status_code=400, detail="Temperature must be between 0.0 and 2.0"
)
async def event_generator() -> AsyncGenerator[str, None]:
"""Generate SSE events."""
try:
# Get conversation history
history = memory.get_history(session_id)
# Retrieve relevant context
context_chunks = search_engine.search(
request.message,
top_k=config.rag.search.default_top_k,
)
# Stream response from LLM
full_response = ""
async for chunk in orchestrator.stream_response(
query=request.message,
history=history,
context_chunks=context_chunks,
temperature=temperature,
# Stream response from orchestrator
async for chunk in orchestrator.chat(
session_id=session_id,
user_message=request.message,
use_rag=request.use_rag,
):
full_response += chunk
yield json.dumps(
{
"type": "chunk",
"content": chunk,
}
)
# Store the conversation
memory.add_message(session_id, "user", request.message)
memory.add_message(session_id, "assistant", full_response)
# Send sources
sources = [
{
"file": chunk.get("source_file", ""),
"section": chunk.get("section"),
"date": chunk.get("date"),
}
for chunk in context_chunks[:5] # Top 5 sources
]
yield json.dumps(
{
"type": "sources",
"sources": sources,
}
)
# Send done event
yield json.dumps(
{
"type": "done",
"session_id": session_id,
}
)
# Parse the SSE data format
if chunk.startswith("data: "):
data = chunk[6:]
if data == "[DONE]":
yield json.dumps({"type": "done", "session_id": session_id})
else:
try:
parsed = json.loads(data)
if "content" in parsed:
yield json.dumps(
{"type": "chunk", "content": parsed["content"]}
)
elif "citations" in parsed:
yield json.dumps(
{
"type": "citations",
"citations": parsed["citations"],
}
)
elif "error" in parsed:
yield json.dumps(
{"type": "error", "message": parsed["error"]}
)
else:
yield data
except json.JSONDecodeError:
# Pass through raw data
yield data
else:
yield chunk
except Exception as e:
yield json.dumps(
{
"type": "error",
"message": str(e),
}
)
yield json.dumps({"type": "error", "message": str(e)})
return EventSourceResponse(event_generator())
return EventSourceResponse(
event_generator(),
headers={"X-Session-ID": session_id},
)
@app.get("/sessions/{session_id}/history")

View File

@@ -12,7 +12,7 @@ import httpx
from companion.config import Config
from companion.memory import Message, SessionMemory
from companion.prompts import build_system_prompt, format_conversation_history
from companion.rag.search import SearchEngine
from companion.rag.search import SearchEngine, SearchResult
class ChatOrchestrator:
@@ -47,7 +47,7 @@ class ChatOrchestrator:
Streaming response chunks (SSE format)
"""
# Retrieve relevant context from RAG if enabled
retrieved_context: list[dict[str, Any]] = []
retrieved_context: list[SearchResult] = []
if use_rag:
try:
retrieved_context = self.search.search(
@@ -103,9 +103,21 @@ class ChatOrchestrator:
session_id,
"assistant",
full_response,
metadata={"rag_context_used": len(retrieved_context) > 0},
metadata={
"rag_context_used": len(retrieved_context) > 0,
"citations": [ctx.to_dict() for ctx in retrieved_context[:5]]
if retrieved_context
else [],
},
)
# Yield citations after the main response
if retrieved_context:
citations_data = {
"citations": [ctx.to_dict() for ctx in retrieved_context[:5]]
}
yield f"data: {json.dumps(citations_data)}\n\n"
async def _stream_llm_response(
self, messages: list[dict[str, Any]]
) -> AsyncGenerator[str, None]:

View File

@@ -4,10 +4,12 @@ from __future__ import annotations
from typing import Any
from companion.rag.search import SearchResult
def build_system_prompt(
persona: dict[str, Any],
retrieved_context: list[dict[str, Any]] | None = None,
retrieved_context: list[SearchResult] | None = None,
memory_context: str | None = None,
) -> str:
"""Build the system prompt for the companion.
@@ -49,10 +51,9 @@ Core principles:
if retrieved_context:
context_section += "\n\nRelevant context from your vault:\n"
for i, ctx in enumerate(retrieved_context[:8], 1): # Limit to top 8 chunks
source = ctx.get("source_file", "unknown")
text = ctx.get("text", "").strip()
text = ctx.text.strip()
if text:
context_section += f"[{i}] From {source}:\n{text}\n\n"
context_section += f"[{i}] From {ctx.citation}:\n{text}\n\n"
# Add memory context if available
memory_section = ""

View File

@@ -1,9 +1,52 @@
from dataclasses import dataclass
from typing import Any
from companion.rag.embedder import OllamaEmbedder
from companion.rag.vector_store import VectorStore
@dataclass
class SearchResult:
"""Structured search result with citation information."""
id: str
text: str
source_file: str
source_directory: str
section: str | None
date: str | None
tags: list[str]
chunk_index: int
total_chunks: int
distance: float
@property
def citation(self) -> str:
"""Generate a citation string for this result."""
parts = [self.source_file]
if self.section:
parts.append(f"#{self.section}")
if self.date:
parts.append(f"({self.date})")
return " - ".join(parts)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for API serialization."""
return {
"id": self.id,
"text": self.text,
"source_file": self.source_file,
"source_directory": self.source_directory,
"section": self.section,
"date": self.date,
"tags": self.tags,
"chunk_index": self.chunk_index,
"total_chunks": self.total_chunks,
"distance": self.distance,
"citation": self.citation,
}
class SearchEngine:
"""Search engine for semantic search using vector embeddings.
@@ -50,7 +93,7 @@ class SearchEngine:
query: str,
top_k: int | None = None,
filters: dict[str, Any] | None = None,
) -> list[dict[str, Any]]:
) -> list[SearchResult]:
"""Search for relevant documents using semantic similarity.
Args:
@@ -59,7 +102,7 @@ class SearchEngine:
filters: Optional metadata filters to apply
Returns:
List of matching documents with similarity scores
List of SearchResult objects with similarity scores
Raises:
RuntimeError: If embedding generation fails
@@ -76,14 +119,33 @@ class SearchEngine:
except RuntimeError as e:
raise RuntimeError(f"Failed to generate embedding for query: {e}") from e
results = self.vector_store.search(query_embedding, top_k=k, filters=filters)
raw_results = self.vector_store.search(
query_embedding, top_k=k, filters=filters
)
if self.similarity_threshold > 0 and results:
results = [
if self.similarity_threshold > 0 and raw_results:
raw_results = [
r
for r in results
for r in raw_results
if r.get(self._DISTANCE_FIELD, float("inf"))
<= self.similarity_threshold
]
# Convert raw results to SearchResult objects
results: list[SearchResult] = []
for r in raw_results:
result = SearchResult(
id=r.get("id", ""),
text=r.get("text", ""),
source_file=r.get("source_file", ""),
source_directory=r.get("source_directory", ""),
section=r.get("section"),
date=r.get("date"),
tags=r.get("tags") or [],
chunk_index=r.get("chunk_index", 0),
total_chunks=r.get("total_chunks", 1),
distance=r.get(self._DISTANCE_FIELD, 1.0),
)
results.append(result)
return results

View File

@@ -2,7 +2,8 @@ import { useState, useRef, useEffect } from 'react'
import './App.css'
import MessageList from './components/MessageList'
import ChatInput from './components/ChatInput'
import { useChatStream } from './hooks/useChatStream'
import CitationsPanel from './components/CitationsPanel'
import { useChatStream, Citation } from './hooks/useChatStream'
export interface Message {
role: 'user' | 'assistant'
@@ -13,15 +14,13 @@ function App() {
const [messages, setMessages] = useState<Message[]>([])
const [input, setInput] = useState('')
const [isLoading, setIsLoading] = useState(false)
const [citations, setCitations] = useState<Citation[]>([])
const [showCitations, setShowCitations] = useState(false)
const messagesEndRef = useRef<HTMLDivElement>(null)
const { sendMessage } = useChatStream()
const scrollToBottom = () => {
messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' })
}
useEffect(() => {
scrollToBottom()
messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' })
}, [messages])
const handleSend = async () => {
@@ -31,24 +30,37 @@ function App() {
setInput('')
setMessages(prev => [...prev, { role: 'user', content: userMessage }])
setIsLoading(true)
setCitations([])
setShowCitations(false)
let assistantContent = ''
await sendMessage(userMessage, (chunk) => {
assistantContent += chunk
setMessages(prev => {
const newMessages = [...prev]
const lastMsg = newMessages[newMessages.length - 1]
if (lastMsg?.role === 'assistant') {
lastMsg.content = assistantContent
} else {
newMessages.push({ role: 'assistant', content: assistantContent })
}
return newMessages
})
await sendMessage(userMessage, {
onChunk: (chunk) => {
assistantContent += chunk
setMessages(prev => {
const newMessages = [...prev]
const lastMsg = newMessages[newMessages.length - 1]
if (lastMsg?.role === 'assistant') {
lastMsg.content = assistantContent
} else {
newMessages.push({ role: 'assistant', content: assistantContent })
}
return newMessages
})
},
onCitations: (newCitations) => {
setCitations(newCitations)
setShowCitations(newCitations.length > 0)
},
onDone: () => {
setIsLoading(false)
},
onError: (error) => {
console.error('Stream error:', error)
setIsLoading(false)
}
})
setIsLoading(false)
}
const handleKeyDown = (e: React.KeyboardEvent) => {
@@ -76,6 +88,11 @@ function App() {
disabled={isLoading}
/>
</footer>
<CitationsPanel
citations={citations}
isOpen={showCitations}
onClose={() => setShowCitations(false)}
/>
</div>
)
}

View File

@@ -0,0 +1,124 @@
/* CitationsPanel.css
   Styles for the slide-in "Sources" (citations) side panel. */

/* Full-screen dimmed backdrop behind the panel. It is rendered as a
   <button> in the component (for keyboard accessibility), hence the
   border/width resets. */
.citations-overlay {
  position: fixed;
  top: 0;
  left: 0;
  right: 0;
  bottom: 0;
  background: rgba(0, 0, 0, 0.5);
  z-index: 100;
  border: none;
  width: 100%;
  cursor: pointer;
}

/* The panel itself: pinned to the right edge, stacked above the
   overlay (z-index 101 > 100), laid out as a column so the list can
   scroll under a fixed header. */
.citations-panel {
  position: fixed;
  top: 0;
  right: 0;
  width: 400px;
  height: 100vh;
  background: var(--bg-secondary);
  border-left: 1px solid var(--border);
  z-index: 101;
  display: flex;
  flex-direction: column;
  animation: slideIn 0.2s ease-out;
}

/* Entry animation: slide in from the right edge. */
@keyframes slideIn {
  from {
    transform: translateX(100%);
  }
  to {
    transform: translateX(0);
  }
}

/* Header row: "Sources" title on the left, close button on the right. */
.citations-header {
  display: flex;
  justify-content: space-between;
  align-items: center;
  padding: 16px 20px;
  border-bottom: 1px solid var(--border);
}

.citations-header h3 {
  margin: 0;
  font-size: 16px;
  font-weight: 600;
  color: var(--text-primary);
}

/* Unstyled "×" close button in the header. */
.close-button {
  background: none;
  border: none;
  color: var(--text-secondary);
  font-size: 24px;
  cursor: pointer;
  padding: 0 4px;
  line-height: 1;
}

.close-button:hover {
  color: var(--text-primary);
}

/* Scrollable body of the panel holding the citation cards. */
.citations-list {
  flex: 1;
  overflow-y: auto;
  padding: 16px 20px;
}

/* One citation card: number badge on the left, content on the right. */
.citation-item {
  display: flex;
  gap: 12px;
  padding: 16px;
  margin-bottom: 12px;
  background: var(--bg-tertiary);
  border-radius: 8px;
  border: 1px solid var(--border);
}

/* The "[n]" index badge. */
.citation-number {
  flex-shrink: 0;
  font-weight: 600;
  color: var(--accent-primary);
  font-size: 14px;
}

/* min-width: 0 lets long text shrink/wrap inside the flex row. */
.citation-content {
  flex: 1;
  min-width: 0;
}

/* Source label (file/section/date); break-all keeps long paths from
   overflowing the fixed-width panel. */
.citation-source {
  font-size: 12px;
  color: var(--text-secondary);
  margin-bottom: 8px;
  word-break: break-all;
}

/* The quoted chunk text; pre-wrap preserves the source's line breaks. */
.citation-text {
  font-size: 14px;
  color: var(--text-primary);
  line-height: 1.5;
  white-space: pre-wrap;
  word-wrap: break-word;
}

/* Optional tag chips row under the text. */
.citation-tags {
  display: flex;
  flex-wrap: wrap;
  gap: 6px;
  margin-top: 12px;
}

.citation-tag {
  font-size: 11px;
  padding: 2px 8px;
  background: var(--bg-secondary);
  color: var(--text-secondary);
  border-radius: 4px;
  border: 1px solid var(--border);
}

View File

@@ -0,0 +1,47 @@
import { Citation } from '../hooks/useChatStream'
import './CitationsPanel.css'
/** Props for the slide-in citations ("Sources") side panel. */
interface CitationsPanelProps {
  citations: Citation[]
  isOpen: boolean
  onClose: () => void
}

/**
 * Slide-in side panel listing the source citations for the latest
 * assistant reply. Renders nothing unless it is open and has at least
 * one citation to show. The backdrop is a <button> so dismissing the
 * panel stays keyboard-accessible.
 */
export default function CitationsPanel({ citations, isOpen, onClose }: CitationsPanelProps) {
  // Nothing to render when closed or when there are no sources.
  if (!isOpen || citations.length === 0) return null

  return (
    <>
      {/* Full-screen backdrop; clicking anywhere outside closes the panel. */}
      <button
        type="button"
        className="citations-overlay"
        onClick={onClose}
        aria-label="Close citations panel"
      />
      <aside className="citations-panel">
        <div className="citations-header">
          <h3>Sources</h3>
          <button type="button" className="close-button" onClick={onClose}>×</button>
        </div>
        <div className="citations-list">
          {/* One numbered card per citation, in received order. */}
          {citations.map((citation, index) => (
            <div key={citation.id} className="citation-item">
              <div className="citation-number">[{index + 1}]</div>
              <div className="citation-content">
                <div className="citation-source">{citation.citation}</div>
                <div className="citation-text">{citation.text}</div>
                {/* Tag chips are optional; hidden when the list is empty. */}
                {citation.tags && citation.tags.length > 0 && (
                  <div className="citation-tags">
                    {citation.tags.map(tag => (
                      <span key={tag} className="citation-tag">{tag}</span>
                    ))}
                  </div>
                )}
              </div>
            </div>
          ))}
        </div>
      </aside>
    </>
  )
}

View File

@@ -2,13 +2,33 @@ import { useState } from 'react'
const API_BASE = '/api'
/**
 * One retrieved source chunk attached to an assistant reply.
 * Field names mirror the backend's serialized search-result payload.
 */
export interface Citation {
  id: string
  text: string
  source_file: string
  source_directory: string
  section: string | null
  date: string | null
  tags: string[]
  /** Pre-formatted citation label produced by the backend. */
  citation: string
}

/** Callbacks fired while consuming the chat SSE stream. */
export interface StreamCallbacks {
  /** Invoked for each streamed text chunk of the assistant reply. */
  onChunk: (chunk: string) => void
  /** Invoked with the reply's source citations, when any arrive. */
  onCitations?: (citations: Citation[]) => void
  /** Invoked when the stream reports an error message. */
  onError?: (error: string) => void
  /** Invoked when the stream signals completion. */
  onDone?: () => void
}
export function useChatStream() {
const [sessionId, setSessionId] = useState<string | null>(null)
const sendMessage = async (
message: string,
onChunk: (chunk: string) => void
callbacks: StreamCallbacks
): Promise<void> => {
const { onChunk, onCitations, onError, onDone } = callbacks
const response = await fetch(`${API_BASE}/chat`, {
method: 'POST',
headers: {
@@ -18,6 +38,7 @@ export function useChatStream() {
message,
session_id: sessionId,
stream: true,
use_rag: true,
}),
})
@@ -50,15 +71,33 @@ export function useChatStream() {
if (line.startsWith('data: ')) {
const data = line.slice(6)
if (data === '[DONE]') {
onDone?.()
return
}
try {
const parsed = JSON.parse(data)
if (parsed.content) {
onChunk(parsed.content)
switch (parsed.type) {
case 'chunk':
if (parsed.content) {
onChunk(parsed.content)
}
break
case 'citations':
if (parsed.citations && onCitations) {
onCitations(parsed.citations as Citation[])
}
break
case 'error':
if (parsed.message && onError) {
onError(parsed.message)
}
break
case 'done':
onDone?.()
return
}
} catch {
// Ignore parse errors
// Ignore parse errors for non-JSON data
}
}
}