feat: Phase 5 - citations, source highlighting, and UI polish
This commit is contained in:
@@ -25,7 +25,8 @@ class ChatRequest(BaseModel):
|
||||
|
||||
message: str
|
||||
session_id: str | None = None
|
||||
temperature: float | None = None
|
||||
stream: bool = True
|
||||
use_rag: bool = True
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
@@ -118,78 +119,53 @@ async def chat(request: ChatRequest) -> EventSourceResponse:
|
||||
# Generate or use existing session ID
|
||||
session_id = request.session_id or str(uuid.uuid4())
|
||||
|
||||
# Validate temperature if provided
|
||||
temperature = request.temperature
|
||||
if temperature is not None and not 0.0 <= temperature <= 2.0:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Temperature must be between 0.0 and 2.0"
|
||||
)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
"""Generate SSE events."""
|
||||
try:
|
||||
# Get conversation history
|
||||
history = memory.get_history(session_id)
|
||||
|
||||
# Retrieve relevant context
|
||||
context_chunks = search_engine.search(
|
||||
request.message,
|
||||
top_k=config.rag.search.default_top_k,
|
||||
)
|
||||
|
||||
# Stream response from LLM
|
||||
full_response = ""
|
||||
async for chunk in orchestrator.stream_response(
|
||||
query=request.message,
|
||||
history=history,
|
||||
context_chunks=context_chunks,
|
||||
temperature=temperature,
|
||||
# Stream response from orchestrator
|
||||
async for chunk in orchestrator.chat(
|
||||
session_id=session_id,
|
||||
user_message=request.message,
|
||||
use_rag=request.use_rag,
|
||||
):
|
||||
full_response += chunk
|
||||
# Parse the SSE data format
|
||||
if chunk.startswith("data: "):
|
||||
data = chunk[6:]
|
||||
if data == "[DONE]":
|
||||
yield json.dumps({"type": "done", "session_id": session_id})
|
||||
else:
|
||||
try:
|
||||
parsed = json.loads(data)
|
||||
if "content" in parsed:
|
||||
yield json.dumps(
|
||||
{"type": "chunk", "content": parsed["content"]}
|
||||
)
|
||||
elif "citations" in parsed:
|
||||
yield json.dumps(
|
||||
{
|
||||
"type": "chunk",
|
||||
"content": chunk,
|
||||
"type": "citations",
|
||||
"citations": parsed["citations"],
|
||||
}
|
||||
)
|
||||
|
||||
# Store the conversation
|
||||
memory.add_message(session_id, "user", request.message)
|
||||
memory.add_message(session_id, "assistant", full_response)
|
||||
|
||||
# Send sources
|
||||
sources = [
|
||||
{
|
||||
"file": chunk.get("source_file", ""),
|
||||
"section": chunk.get("section"),
|
||||
"date": chunk.get("date"),
|
||||
}
|
||||
for chunk in context_chunks[:5] # Top 5 sources
|
||||
]
|
||||
elif "error" in parsed:
|
||||
yield json.dumps(
|
||||
{
|
||||
"type": "sources",
|
||||
"sources": sources,
|
||||
}
|
||||
)
|
||||
|
||||
# Send done event
|
||||
yield json.dumps(
|
||||
{
|
||||
"type": "done",
|
||||
"session_id": session_id,
|
||||
}
|
||||
{"type": "error", "message": parsed["error"]}
|
||||
)
|
||||
else:
|
||||
yield data
|
||||
except json.JSONDecodeError:
|
||||
# Pass through raw data
|
||||
yield data
|
||||
else:
|
||||
yield chunk
|
||||
|
||||
except Exception as e:
|
||||
yield json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"message": str(e),
|
||||
}
|
||||
)
|
||||
yield json.dumps({"type": "error", "message": str(e)})
|
||||
|
||||
return EventSourceResponse(event_generator())
|
||||
return EventSourceResponse(
|
||||
event_generator(),
|
||||
headers={"X-Session-ID": session_id},
|
||||
)
|
||||
|
||||
|
||||
@app.get("/sessions/{session_id}/history")
|
||||
|
||||
@@ -12,7 +12,7 @@ import httpx
|
||||
from companion.config import Config
|
||||
from companion.memory import Message, SessionMemory
|
||||
from companion.prompts import build_system_prompt, format_conversation_history
|
||||
from companion.rag.search import SearchEngine
|
||||
from companion.rag.search import SearchEngine, SearchResult
|
||||
|
||||
|
||||
class ChatOrchestrator:
|
||||
@@ -47,7 +47,7 @@ class ChatOrchestrator:
|
||||
Streaming response chunks (SSE format)
|
||||
"""
|
||||
# Retrieve relevant context from RAG if enabled
|
||||
retrieved_context: list[dict[str, Any]] = []
|
||||
retrieved_context: list[SearchResult] = []
|
||||
if use_rag:
|
||||
try:
|
||||
retrieved_context = self.search.search(
|
||||
@@ -103,9 +103,21 @@ class ChatOrchestrator:
|
||||
session_id,
|
||||
"assistant",
|
||||
full_response,
|
||||
metadata={"rag_context_used": len(retrieved_context) > 0},
|
||||
metadata={
|
||||
"rag_context_used": len(retrieved_context) > 0,
|
||||
"citations": [ctx.to_dict() for ctx in retrieved_context[:5]]
|
||||
if retrieved_context
|
||||
else [],
|
||||
},
|
||||
)
|
||||
|
||||
# Yield citations after the main response
|
||||
if retrieved_context:
|
||||
citations_data = {
|
||||
"citations": [ctx.to_dict() for ctx in retrieved_context[:5]]
|
||||
}
|
||||
yield f"data: {json.dumps(citations_data)}\n\n"
|
||||
|
||||
async def _stream_llm_response(
|
||||
self, messages: list[dict[str, Any]]
|
||||
) -> AsyncGenerator[str, None]:
|
||||
|
||||
@@ -4,10 +4,12 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from companion.rag.search import SearchResult
|
||||
|
||||
|
||||
def build_system_prompt(
|
||||
persona: dict[str, Any],
|
||||
retrieved_context: list[dict[str, Any]] | None = None,
|
||||
retrieved_context: list[SearchResult] | None = None,
|
||||
memory_context: str | None = None,
|
||||
) -> str:
|
||||
"""Build the system prompt for the companion.
|
||||
@@ -49,10 +51,9 @@ Core principles:
|
||||
if retrieved_context:
|
||||
context_section += "\n\nRelevant context from your vault:\n"
|
||||
for i, ctx in enumerate(retrieved_context[:8], 1): # Limit to top 8 chunks
|
||||
source = ctx.get("source_file", "unknown")
|
||||
text = ctx.get("text", "").strip()
|
||||
text = ctx.text.strip()
|
||||
if text:
|
||||
context_section += f"[{i}] From {source}:\n{text}\n\n"
|
||||
context_section += f"[{i}] From {ctx.citation}:\n{text}\n\n"
|
||||
|
||||
# Add memory context if available
|
||||
memory_section = ""
|
||||
|
||||
@@ -1,9 +1,52 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from companion.rag.embedder import OllamaEmbedder
|
||||
from companion.rag.vector_store import VectorStore
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""Structured search result with citation information."""
|
||||
|
||||
id: str
|
||||
text: str
|
||||
source_file: str
|
||||
source_directory: str
|
||||
section: str | None
|
||||
date: str | None
|
||||
tags: list[str]
|
||||
chunk_index: int
|
||||
total_chunks: int
|
||||
distance: float
|
||||
|
||||
@property
|
||||
def citation(self) -> str:
|
||||
"""Generate a citation string for this result."""
|
||||
parts = [self.source_file]
|
||||
if self.section:
|
||||
parts.append(f"#{self.section}")
|
||||
if self.date:
|
||||
parts.append(f"({self.date})")
|
||||
return " - ".join(parts)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for API serialization."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"text": self.text,
|
||||
"source_file": self.source_file,
|
||||
"source_directory": self.source_directory,
|
||||
"section": self.section,
|
||||
"date": self.date,
|
||||
"tags": self.tags,
|
||||
"chunk_index": self.chunk_index,
|
||||
"total_chunks": self.total_chunks,
|
||||
"distance": self.distance,
|
||||
"citation": self.citation,
|
||||
}
|
||||
|
||||
|
||||
class SearchEngine:
|
||||
"""Search engine for semantic search using vector embeddings.
|
||||
|
||||
@@ -50,7 +93,7 @@ class SearchEngine:
|
||||
query: str,
|
||||
top_k: int | None = None,
|
||||
filters: dict[str, Any] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[SearchResult]:
|
||||
"""Search for relevant documents using semantic similarity.
|
||||
|
||||
Args:
|
||||
@@ -59,7 +102,7 @@ class SearchEngine:
|
||||
filters: Optional metadata filters to apply
|
||||
|
||||
Returns:
|
||||
List of matching documents with similarity scores
|
||||
List of SearchResult objects with similarity scores
|
||||
|
||||
Raises:
|
||||
RuntimeError: If embedding generation fails
|
||||
@@ -76,14 +119,33 @@ class SearchEngine:
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(f"Failed to generate embedding for query: {e}") from e
|
||||
|
||||
results = self.vector_store.search(query_embedding, top_k=k, filters=filters)
|
||||
raw_results = self.vector_store.search(
|
||||
query_embedding, top_k=k, filters=filters
|
||||
)
|
||||
|
||||
if self.similarity_threshold > 0 and results:
|
||||
results = [
|
||||
if self.similarity_threshold > 0 and raw_results:
|
||||
raw_results = [
|
||||
r
|
||||
for r in results
|
||||
for r in raw_results
|
||||
if r.get(self._DISTANCE_FIELD, float("inf"))
|
||||
<= self.similarity_threshold
|
||||
]
|
||||
|
||||
# Convert raw results to SearchResult objects
|
||||
results: list[SearchResult] = []
|
||||
for r in raw_results:
|
||||
result = SearchResult(
|
||||
id=r.get("id", ""),
|
||||
text=r.get("text", ""),
|
||||
source_file=r.get("source_file", ""),
|
||||
source_directory=r.get("source_directory", ""),
|
||||
section=r.get("section"),
|
||||
date=r.get("date"),
|
||||
tags=r.get("tags") or [],
|
||||
chunk_index=r.get("chunk_index", 0),
|
||||
total_chunks=r.get("total_chunks", 1),
|
||||
distance=r.get(self._DISTANCE_FIELD, 1.0),
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
@@ -2,7 +2,8 @@ import { useState, useRef, useEffect } from 'react'
|
||||
import './App.css'
|
||||
import MessageList from './components/MessageList'
|
||||
import ChatInput from './components/ChatInput'
|
||||
import { useChatStream } from './hooks/useChatStream'
|
||||
import CitationsPanel from './components/CitationsPanel'
|
||||
import { useChatStream, Citation } from './hooks/useChatStream'
|
||||
|
||||
export interface Message {
|
||||
role: 'user' | 'assistant'
|
||||
@@ -13,15 +14,13 @@ function App() {
|
||||
const [messages, setMessages] = useState<Message[]>([])
|
||||
const [input, setInput] = useState('')
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
const [citations, setCitations] = useState<Citation[]>([])
|
||||
const [showCitations, setShowCitations] = useState(false)
|
||||
const messagesEndRef = useRef<HTMLDivElement>(null)
|
||||
const { sendMessage } = useChatStream()
|
||||
|
||||
const scrollToBottom = () => {
|
||||
messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' })
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
scrollToBottom()
|
||||
messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' })
|
||||
}, [messages])
|
||||
|
||||
const handleSend = async () => {
|
||||
@@ -31,10 +30,13 @@ function App() {
|
||||
setInput('')
|
||||
setMessages(prev => [...prev, { role: 'user', content: userMessage }])
|
||||
setIsLoading(true)
|
||||
setCitations([])
|
||||
setShowCitations(false)
|
||||
|
||||
let assistantContent = ''
|
||||
|
||||
await sendMessage(userMessage, (chunk) => {
|
||||
await sendMessage(userMessage, {
|
||||
onChunk: (chunk) => {
|
||||
assistantContent += chunk
|
||||
setMessages(prev => {
|
||||
const newMessages = [...prev]
|
||||
@@ -46,9 +48,19 @@ function App() {
|
||||
}
|
||||
return newMessages
|
||||
})
|
||||
})
|
||||
|
||||
},
|
||||
onCitations: (newCitations) => {
|
||||
setCitations(newCitations)
|
||||
setShowCitations(newCitations.length > 0)
|
||||
},
|
||||
onDone: () => {
|
||||
setIsLoading(false)
|
||||
},
|
||||
onError: (error) => {
|
||||
console.error('Stream error:', error)
|
||||
setIsLoading(false)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent) => {
|
||||
@@ -76,6 +88,11 @@ function App() {
|
||||
disabled={isLoading}
|
||||
/>
|
||||
</footer>
|
||||
<CitationsPanel
|
||||
citations={citations}
|
||||
isOpen={showCitations}
|
||||
onClose={() => setShowCitations(false)}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
124
ui/src/components/CitationsPanel.css
Normal file
124
ui/src/components/CitationsPanel.css
Normal file
@@ -0,0 +1,124 @@
|
||||
/* CitationsPanel.css */
|
||||
.citations-overlay {
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background: rgba(0, 0, 0, 0.5);
|
||||
z-index: 100;
|
||||
border: none;
|
||||
width: 100%;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
/* Slide-in panel pinned to the right edge, stacked above the overlay. */
.citations-panel {
  position: fixed;
  top: 0;
  right: 0;
  width: 400px;
  /* Cap at the viewport width so narrow screens do not clip the left edge. */
  max-width: 100%;
  height: 100vh;
  background: var(--bg-secondary);
  border-left: 1px solid var(--border);
  z-index: 101;
  display: flex;
  flex-direction: column;
  animation: slideIn 0.2s ease-out;
}
|
||||
|
||||
@keyframes slideIn {
|
||||
from {
|
||||
transform: translateX(100%);
|
||||
}
|
||||
to {
|
||||
transform: translateX(0);
|
||||
}
|
||||
}
|
||||
|
||||
.citations-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 16px 20px;
|
||||
border-bottom: 1px solid var(--border);
|
||||
}
|
||||
|
||||
.citations-header h3 {
|
||||
margin: 0;
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.close-button {
|
||||
background: none;
|
||||
border: none;
|
||||
color: var(--text-secondary);
|
||||
font-size: 24px;
|
||||
cursor: pointer;
|
||||
padding: 0 4px;
|
||||
line-height: 1;
|
||||
}
|
||||
|
||||
.close-button:hover {
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.citations-list {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: 16px 20px;
|
||||
}
|
||||
|
||||
.citation-item {
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
padding: 16px;
|
||||
margin-bottom: 12px;
|
||||
background: var(--bg-tertiary);
|
||||
border-radius: 8px;
|
||||
border: 1px solid var(--border);
|
||||
}
|
||||
|
||||
.citation-number {
|
||||
flex-shrink: 0;
|
||||
font-weight: 600;
|
||||
color: var(--accent-primary);
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.citation-content {
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.citation-source {
|
||||
font-size: 12px;
|
||||
color: var(--text-secondary);
|
||||
margin-bottom: 8px;
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
.citation-text {
|
||||
font-size: 14px;
|
||||
color: var(--text-primary);
|
||||
line-height: 1.5;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
}
|
||||
|
||||
.citation-tags {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 6px;
|
||||
margin-top: 12px;
|
||||
}
|
||||
|
||||
.citation-tag {
|
||||
font-size: 11px;
|
||||
padding: 2px 8px;
|
||||
background: var(--bg-secondary);
|
||||
color: var(--text-secondary);
|
||||
border-radius: 4px;
|
||||
border: 1px solid var(--border);
|
||||
}
|
||||
47
ui/src/components/CitationsPanel.tsx
Normal file
47
ui/src/components/CitationsPanel.tsx
Normal file
@@ -0,0 +1,47 @@
|
||||
import { Citation } from '../hooks/useChatStream'
|
||||
import './CitationsPanel.css'
|
||||
|
||||
/** Props accepted by CitationsPanel. */
interface CitationsPanelProps {
  /** Citations to list, in display order. */
  citations: Citation[]
  /** Whether the panel should be shown. */
  isOpen: boolean
  /** Called when the user clicks the backdrop or the close button. */
  onClose: () => void
}

/**
 * Side panel listing the given citations with their source, text and tags.
 * Renders nothing when the panel is closed or there are no citations.
 */
export default function CitationsPanel({ citations, isOpen, onClose }: CitationsPanelProps) {
  if (!isOpen || citations.length === 0) return null

  return (
    <>
      <button
        type="button"
        className="citations-overlay"
        onClick={onClose}
        aria-label="Close citations panel"
      />
      <aside className="citations-panel" aria-label="Sources">
        <div className="citations-header">
          <h3>Sources</h3>
          {/* The "×" glyph alone gives assistive tech no usable name. */}
          <button
            type="button"
            className="close-button"
            onClick={onClose}
            aria-label="Close sources"
          >
            ×
          </button>
        </div>
        <div className="citations-list">
          {/* Ids may be empty or repeated, so the index is part of the key. */}
          {citations.map((citation, index) => (
            <div key={`${citation.id}-${index}`} className="citation-item">
              <div className="citation-number">[{index + 1}]</div>
              <div className="citation-content">
                <div className="citation-source">{citation.citation}</div>
                <div className="citation-text">{citation.text}</div>
                {citation.tags && citation.tags.length > 0 && (
                  <div className="citation-tags">
                    {/* A tag can appear twice; the index keeps keys unique. */}
                    {citation.tags.map((tag, tagIndex) => (
                      <span key={`${tag}-${tagIndex}`} className="citation-tag">{tag}</span>
                    ))}
                  </div>
                )}
              </div>
            </div>
          ))}
        </div>
      </aside>
    </>
  )
}
|
||||
@@ -2,13 +2,33 @@ import { useState } from 'react'
|
||||
|
||||
const API_BASE = '/api'
|
||||
|
||||
/** One retrieved source, as serialized in the stream's "citations" event. */
export interface Citation {
  id: string
  text: string
  source_file: string
  source_directory: string
  section: string | null
  date: string | null
  tags: string[]
  /** Position of the chunk within its source file. */
  chunk_index?: number
  /** Number of chunks the source file was split into. */
  total_chunks?: number
  /** Distance reported by the vector search for this hit. */
  distance?: number
  /** Preformatted citation string for display. */
  citation: string
}

/** Handlers invoked while a chat response is being streamed. */
export interface StreamCallbacks {
  /** Called with each piece of assistant text as it arrives. */
  onChunk: (chunk: string) => void
  /** Called when the stream delivers a citations event. */
  onCitations?: (citations: Citation[]) => void
  /** Called with the message of an error event. */
  onError?: (error: string) => void
  /** Called when the stream signals completion. */
  onDone?: () => void
}
|
||||
|
||||
export function useChatStream() {
|
||||
const [sessionId, setSessionId] = useState<string | null>(null)
|
||||
|
||||
const sendMessage = async (
|
||||
message: string,
|
||||
onChunk: (chunk: string) => void
|
||||
callbacks: StreamCallbacks
|
||||
): Promise<void> => {
|
||||
const { onChunk, onCitations, onError, onDone } = callbacks
|
||||
|
||||
const response = await fetch(`${API_BASE}/chat`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -18,6 +38,7 @@ export function useChatStream() {
|
||||
message,
|
||||
session_id: sessionId,
|
||||
stream: true,
|
||||
use_rag: true,
|
||||
}),
|
||||
})
|
||||
|
||||
@@ -50,15 +71,33 @@ export function useChatStream() {
|
||||
if (line.startsWith('data: ')) {
|
||||
const data = line.slice(6)
|
||||
if (data === '[DONE]') {
|
||||
onDone?.()
|
||||
return
|
||||
}
|
||||
try {
|
||||
const parsed = JSON.parse(data)
|
||||
switch (parsed.type) {
|
||||
case 'chunk':
|
||||
if (parsed.content) {
|
||||
onChunk(parsed.content)
|
||||
}
|
||||
break
|
||||
case 'citations':
|
||||
if (parsed.citations && onCitations) {
|
||||
onCitations(parsed.citations as Citation[])
|
||||
}
|
||||
break
|
||||
case 'error':
|
||||
if (parsed.message && onError) {
|
||||
onError(parsed.message)
|
||||
}
|
||||
break
|
||||
case 'done':
|
||||
onDone?.()
|
||||
return
|
||||
}
|
||||
} catch {
|
||||
// Ignore parse errors
|
||||
// Ignore parse errors for non-JSON data
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user