WIP: Phase 4 forge extract module with tests

2026-04-13 15:14:35 -04:00
parent 922e724cfe
commit f944bdc573
11 changed files with 1780 additions and 0 deletions

.gitignore vendored

@@ -9,3 +9,9 @@ venv/
.companion/
dist/
build/
# Node.js / UI
node_modules/
ui/node_modules/
ui/dist/


@@ -0,0 +1,156 @@
# Phase 4: Fine-Tuning Pipeline Implementation Plan
## Goal
Build a pipeline to extract training examples from the Obsidian vault and fine-tune a local 8B model using QLoRA on the RTX 5070.
## Architecture
```
┌─────────────────────────────────────────────────────────┐
│ Training Data Pipeline │
│ ───────────────────── │
│ 1. Extract reflections from vault │
│ 2. Curate into conversation format │
│ 3. Split train/validation │
│ 4. Export to HuggingFace datasets format │
└─────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────┐
│ QLoRA Fine-Tuning (Unsloth) │
│ ─────────────────────────── │
│ - Base: Llama 3.1 8B Instruct (4-bit) │
│ - LoRA rank: 16, alpha: 32 │
│ - Target modules: q_proj, k_proj, v_proj, o_proj │
│ - Learning rate: 2e-4 │
│ - Epochs: 3 │
│ - Batch: 4, Gradient accumulation: 4 │
└─────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────┐
│ Model Export & Serving │
│ ───────────────────── │
│ - Export to GGUF (Q4_K_M quantization) │
│ - Serve via llama.cpp or vLLM │
│ - Hot-swap in FastAPI backend │
└─────────────────────────────────────────────────────────┘
```
## Tasks
### Task 1: Training Data Extractor
**Files:**
- `src/companion/forge/extract.py` - Extract reflection examples from vault
- `tests/test_forge_extract.py` - Test extraction logic
**Spec:**
- Parse vault for "reflection" patterns (journal entries with insights, decision analyses)
- Look for tags: #reflection, #decision, #learning, etc.
- Extract entries where you reflect on situations, weigh options, or analyze outcomes
- Format as conversation: user prompt + assistant response (your reflection)
- Output: JSONL file with {"messages": [{"role": "...", "content": "..."}]}
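To make the target format concrete, here is a minimal sketch of one JSONL record in the `{"messages": [...]}` shape the spec describes. The file path, tags, and message contents are invented for illustration; only the key layout follows the spec.

```python
import json

# One hypothetical extracted example in conversation format.
record = {
    "messages": [
        {"role": "system", "content": "You are a thoughtful, reflective companion."},
        {"role": "user", "content": "I'm facing a decision. How should I think through this?"},
        {"role": "assistant", "content": "#decision I weighed staying put against switching teams..."},
    ],
    "source_file": "journal/2026-04-12.md",
    "tags": ["#decision"],
}

# JSONL means exactly one JSON object per line, no pretty-printing.
line = json.dumps(record, ensure_ascii=False)
parsed = json.loads(line)
print(parsed["messages"][2]["role"])  # assistant
```

Each line must round-trip through `json.loads` and contain no embedded newlines, so downstream loaders (e.g. HuggingFace `datasets`) can stream the file line by line.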
### Task 2: Training Data Curator
**Files:**
- `src/companion/forge/curate.py` - Human-in-the-loop curation
- `src/companion/forge/cli.py` - CLI for curation workflow
**Spec:**
- Load extracted examples
- Interactive review: show each example, allow approve/reject/edit
- Track curation decisions in SQLite
- Export approved examples to final training set
- Deduplicate similar examples (use embeddings similarity)
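The dedup step above can be sketched with plain cosine similarity over precomputed embedding vectors. The 0.92 threshold and the toy 2-dimensional vectors are assumptions for illustration; real embeddings would come from the existing embedder.

```python
import math

def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)

def dedupe(examples: list[str], vectors: list[list[float]], threshold: float = 0.92) -> list[str]:
    """Greedily keep an example only if it is not too similar to one already kept."""
    kept: list[str] = []
    kept_vecs: list[list[float]] = []
    for text, vec in zip(examples, vectors):
        if all(cosine(vec, kv) < threshold for kv in kept_vecs):
            kept.append(text)
            kept_vecs.append(vec)
    return kept

# Toy vectors: the first two examples are near-duplicates.
examples = ["reflection A", "reflection A'", "reflection B"]
vectors = [[1.0, 0.0], [0.999, 0.01], [0.0, 1.0]]
print(dedupe(examples, vectors))  # ['reflection A', 'reflection B']
```

Greedy first-wins dedup is order-dependent, which is fine here since the curator reviews everything anyway.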
### Task 3: Training Configuration
**Files:**
- `src/companion/forge/config.py` - Training hyperparameters
- `config.json` updates for fine_tuning section
**Spec:**
- Pydantic models for training config
- Hyperparameters tuned for RTX 5070 (12GB VRAM)
- Output paths, logging config
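A minimal stand-in for the config shape, using stdlib dataclasses for illustration (the plan calls for Pydantic models, which would look nearly identical). Defaults mirror the hyperparameters from the architecture section; the field names are assumptions.

```python
from dataclasses import dataclass, field

@dataclass
class LoraSettings:
    r: int = 16
    alpha: int = 32
    target_modules: list[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )

@dataclass
class TrainingConfig:
    base_model: str = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
    learning_rate: float = 2e-4
    epochs: int = 3
    batch_size: int = 4
    gradient_accumulation: int = 4
    lora: LoraSettings = field(default_factory=LoraSettings)

cfg = TrainingConfig()
# Effective batch size per optimizer step:
print(cfg.batch_size * cfg.gradient_accumulation)  # 16
```

Keeping every knob in one typed object makes it trivial to serialize the exact config alongside each training run.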
### Task 4: QLoRA Training Script
**Files:**
- `src/companion/forge/train.py` - Unsloth training script
- `scripts/train.sh` - Convenience launcher
**Spec:**
- Load base model: unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit
- Apply LoRA config (r=16, alpha=32, target_modules)
- Load and tokenize dataset
- Training loop with wandb logging (optional)
- Save checkpoints every 500 steps
- Validate on holdout set
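Back-of-the-envelope arithmetic for the schedule above, assuming the upper end of the curation target (100 examples). It shows why the 500-step checkpoint interval will rarely trigger mid-run at this dataset size.

```python
examples = 100      # curated training examples (upper end of the target)
batch_size = 4
grad_accum = 4
epochs = 3

effective_batch = batch_size * grad_accum          # sequences per optimizer step
steps_per_epoch = -(-examples // effective_batch)  # ceiling division
total_steps = steps_per_epoch * epochs
print(effective_batch, steps_per_epoch, total_steps)  # 16 7 21
```

At ~21 total optimizer steps, checkpointing "every 500 steps" only ever saves the final state, so an end-of-epoch checkpoint may be the more useful setting.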
### Task 5: Model Export
**Files:**
- `src/companion/forge/export.py` - Export to GGUF
- `src/companion/forge/merge.py` - Merge LoRA weights into base
**Spec:**
- Merge LoRA weights into base model
- Export to GGUF with Q4_K_M quantization
- Save to `~/.companion/models/`
- Update config.json with new model path
### Task 6: Model Hot-Swap
**Files:**
- Update `src/companion/api.py` - Add endpoint to reload model
- `src/companion/forge/reload.py` - Model reloader utility
**Spec:**
- `/admin/reload-model` endpoint (requires auth/local-only)
- Gracefully unload old model, load new GGUF
- Return status: success or error
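The reload endpoint reduces to a thread-safe holder that swaps the model reference atomically. This is a minimal sketch of that pattern, with string placeholders standing in for loaded GGUF models; class and method names are assumptions, not the planned API.

```python
import threading

class ModelHolder:
    """Hot-swap pattern: requests read the current model under a lock
    while the reload endpoint swaps in a freshly loaded one."""

    def __init__(self, model: object):
        self._lock = threading.Lock()
        self._model = model

    def get(self) -> object:
        with self._lock:
            return self._model

    def swap(self, loader) -> dict:
        try:
            new_model = loader()  # e.g. load the new GGUF file
        except Exception as e:
            # Loading failed: keep serving the old model.
            return {"status": "error", "message": str(e)}
        with self._lock:
            self._model = new_model
        # The old model is freed once no in-flight request holds a reference.
        return {"status": "success"}

holder = ModelHolder("llama-3.1-base.gguf")
print(holder.swap(lambda: "llama-3.1-finetuned-q4_k_m.gguf"))  # {'status': 'success'}
print(holder.get())
```

Loading before taking the lock keeps the serving path blocked only for the pointer swap, not for the (slow) model load.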
### Task 7: Evaluation Framework
**Files:**
- `src/companion/forge/eval.py` - Evaluate model on test prompts
- `tests/test_forge_eval.py` - Evaluation tests
**Spec:**
- Load test prompts (decision scenarios, relationship questions)
- Generate responses from both base and fine-tuned model
- Store outputs for human comparison
- Track metrics: response time, token count
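The metrics loop above can be sketched as a small harness that times each generation and records a crude token count. The stub generator and whitespace tokenization are placeholders; a real run would call the base and fine-tuned models and use the model tokenizer.

```python
import time

def run_eval(prompts, generate):
    """Collect latency and token-count metrics for one model."""
    rows = []
    for p in prompts:
        start = time.perf_counter()
        out = generate(p)
        rows.append({
            "prompt": p,
            "response": out,
            "seconds": time.perf_counter() - start,
            "tokens": len(out.split()),  # crude whitespace token count
        })
    return rows

# Stub generator standing in for a model call.
rows = run_eval(
    ["Should I take the new role?"],
    lambda p: "Weigh what you value most first.",
)
print(rows[0]["tokens"])  # 6
```

Running the same harness twice, once per model, yields paired rows that can be dumped side by side for the human comparison step.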
## Success Criteria
- [ ] Extract 100+ reflection examples from vault
- [ ] Curate down to 50-100 high-quality training examples
- [ ] Complete training run in <6 hours on RTX 5070
- [ ] Export produces valid GGUF file
- [ ] Hot-swap endpoint successfully reloads model
- [ ] Evaluation shows a distinguishable "Santhosh" style in fine-tuned outputs
## Dependencies
```
unsloth>=2024.1.0
torch>=2.1.0
transformers>=4.36.0
datasets>=2.14.0
peft>=0.7.0
accelerate>=0.25.0
bitsandbytes>=0.41.0
sentencepiece>=0.1.99
protobuf>=3.20.0
```
## Commands
```bash
# Extract training data
python -m companion.forge.cli extract
# Curate examples
python -m companion.forge.cli curate
# Train
python -m companion.forge.train
# Export
python -m companion.forge.export
# Reload model in API
python -m companion.forge.reload
```


@@ -12,6 +12,11 @@ dependencies = [
"typer>=0.12.0",
"rich>=13.0.0",
"numpy>=1.26.0",
"fastapi>=0.109.0",
"uvicorn[standard]>=0.27.0",
"httpx>=0.27.0",
"sse-starlette>=2.0.0",
"python-multipart>=0.0.9",
]
[project.optional-dependencies]

src/companion/api.py Normal file

@@ -0,0 +1,215 @@
"""FastAPI backend for Companion AI."""
from __future__ import annotations
import json
import uuid
from contextlib import asynccontextmanager
from typing import AsyncGenerator
import httpx
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
from companion.config import Config, load_config
from companion.memory import SessionMemory
from companion.orchestrator import ChatOrchestrator
from companion.rag.search import SearchEngine
from companion.rag.vector_store import VectorStore
class ChatRequest(BaseModel):
"""Chat request model."""
message: str
session_id: str | None = None
temperature: float | None = None
class ChatResponse(BaseModel):
"""Chat response model (non-streaming)."""
response: str
session_id: str
sources: list[dict] | None = None
# Global instances
config: Config
vector_store: VectorStore
search_engine: SearchEngine
memory: SessionMemory
orchestrator: ChatOrchestrator
http_client: httpx.AsyncClient
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Manage application lifespan."""
global config, vector_store, search_engine, memory, orchestrator, http_client
# Startup
config = load_config("config.json")
vector_store = VectorStore(
uri=config.rag.vector_store.path, dimensions=config.rag.embedding.dimensions
)
search_engine = SearchEngine(
vector_store=vector_store,
embedder_base_url=config.rag.embedding.base_url,
embedder_model=config.rag.embedding.model,
embedder_batch_size=config.rag.embedding.batch_size,
default_top_k=config.rag.search.default_top_k,
similarity_threshold=config.rag.search.similarity_threshold,
hybrid_search_enabled=config.rag.search.hybrid_search.enabled,
)
memory = SessionMemory(db_path=config.companion.memory.persistent_store)
http_client = httpx.AsyncClient(timeout=300.0)
orchestrator = ChatOrchestrator(
config=config,
search_engine=search_engine,
session_memory=memory,
)
yield
# Shutdown
await http_client.aclose()
memory.close()
app = FastAPI(
title="Companion AI",
description="Personal AI companion with RAG",
version="0.1.0",
lifespan=lifespan,
)
# CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=config.api.cors_origins
if "config" in globals()
else ["http://localhost:5173"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/health")
async def health_check() -> dict:
"""Health check endpoint."""
return {
"status": "healthy",
"version": "0.1.0",
"indexed_chunks": vector_store.count() if "vector_store" in globals() else 0,
}
@app.post("/chat")
async def chat(request: ChatRequest) -> EventSourceResponse:
"""Chat endpoint with SSE streaming."""
if not request.message.strip():
raise HTTPException(status_code=400, detail="Message cannot be empty")
# Generate or use existing session ID
session_id = request.session_id or str(uuid.uuid4())
# Validate temperature if provided
temperature = request.temperature
if temperature is not None and not 0.0 <= temperature <= 2.0:
raise HTTPException(
status_code=400, detail="Temperature must be between 0.0 and 2.0"
)
async def event_generator() -> AsyncGenerator[str, None]:
"""Generate SSE events."""
try:
# Retrieve context for the sources payload; the orchestrator performs
# its own retrieval when building the prompt.
context_chunks = search_engine.search(
request.message,
top_k=config.rag.search.default_top_k,
)
# Stream from the orchestrator, which also loads history and persists
# both sides of the turn to session memory.
async for chunk in orchestrator.chat(
session_id=session_id,
user_message=request.message,
):
payload = chunk.removeprefix("data: ").strip()
if payload and payload != "[DONE]":
yield json.dumps(
{
"type": "chunk",
"content": payload,
}
)
# Send sources
sources = [
{
"file": chunk.get("source_file", ""),
"section": chunk.get("section"),
"date": chunk.get("date"),
}
for chunk in context_chunks[:5] # Top 5 sources
]
yield json.dumps(
{
"type": "sources",
"sources": sources,
}
)
# Send done event
yield json.dumps(
{
"type": "done",
"session_id": session_id,
}
)
except Exception as e:
yield json.dumps(
{
"type": "error",
"message": str(e),
}
)
return EventSourceResponse(event_generator())
@app.get("/sessions/{session_id}/history")
async def get_session_history(session_id: str) -> dict:
"""Get conversation history for a session."""
history = memory.get_messages(session_id)
return {
"session_id": session_id,
"messages": [
{
"role": msg.role,
"content": msg.content,
"timestamp": msg.timestamp.isoformat(),
}
for msg in history
],
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="127.0.0.1", port=7373)


@@ -0,0 +1 @@
# Model Forge - Fine-tuning pipeline for companion AI


@@ -0,0 +1,329 @@
"""Extract training examples from Obsidian vault.
Looks for reflection patterns, decision analyses, and personal insights
to create training data that teaches the model to reason like San.
"""
from __future__ import annotations
import json
import re
from dataclasses import dataclass
from pathlib import Path
from companion.config import Config
@dataclass
class TrainingExample:
"""A single training example in conversation format."""
messages: list[dict[str, str]]
source_file: str
tags: list[str]
date: str | None = None
def to_dict(self) -> dict:
return {
"messages": self.messages,
"source_file": self.source_file,
"tags": self.tags,
"date": self.date,
}
# Tags that indicate reflection/decision content
REFLECTION_TAGS = {
"#reflection",
"#decision",
"#learning",
"#insight",
"#analysis",
"#pondering",
"#evaluation",
"#takeaway",
"#realization",
}
# Patterns that suggest reflection content
REFLECTION_PATTERNS = [
r"(?i)I think\s+",
r"(?i)I feel\s+",
r"(?i)I realize\s+",
r"(?i)I wonder\s+",
r"(?i)I should\s+",
r"(?i)I need to\s+",
r"(?i)The reason\s+",
r"(?i)What if\s+",
r"(?i)Maybe\s+",
r"(?i)Perhaps\s+",
r"(?i)On one hand.*?on the other hand",
r"(?i)Pros?:.*?Cons?:",
r"(?i)Weighing.*?(?:options?|choices?|alternatives?)",
r"(?i)Ultimately.*?decided",
r"(?i)Looking back",
r"(?i)In hindsight",
r"(?i)I've learned\s+",
r"(?i)The lesson\s+",
]
def _has_reflection_tags(text: str) -> bool:
"""Check if text contains reflection-related hashtags."""
hashtags = set(re.findall(r"#\w+", text))
return bool(hashtags & REFLECTION_TAGS)
def _has_reflection_patterns(text: str) -> bool:
"""Check if text contains reflection language patterns."""
for pattern in REFLECTION_PATTERNS:
if re.search(pattern, text):
return True
return False
def _is_likely_reflection(text: str) -> bool:
"""Determine if a text chunk is likely a reflection."""
# Must have reflection tags OR strong reflection patterns
has_tags = _has_reflection_tags(text)
has_patterns = _has_reflection_patterns(text)
# Require at least one indicator
return has_tags or has_patterns
def _extract_date_from_filename(filename: str) -> str | None:
"""Extract date from filename patterns like 2026-04-12 or "12 Apr 2026"."""
# ISO format: 2026-04-12
m = re.search(r"(\d{4}-\d{2}-\d{2})", filename)
if m:
return m.group(1)
# Human format: 12-Apr-2026 or "12 Apr 2026"
m = re.search(r"(\d{1,2}[-\s][A-Za-z]{3}[-\s]\d{4})", filename)
if m:
return m.group(1).replace(" ", "-").replace("--", "-")
return None
def _create_training_prompt(chunk_text: str) -> str:
"""Create a user prompt that elicits a reflection similar to the example."""
# Extract context from the chunk to form a relevant question
# Look for decision patterns
if re.search(r"(?i)decided|decision|chose|choice", chunk_text):
return "I'm facing a decision. How should I think through this?"
# Look for relationship content
if re.search(r"(?i)friend|relationship|person|people|someone", chunk_text):
return "What do you think about this situation with someone I'm close to?"
# Look for work/career content
if re.search(r"(?i)work|job|career|project|professional", chunk_text):
return "I'm thinking about something at work. What's your perspective?"
# Look for health content
if re.search(r"(?i)health|mental|physical|stress|wellness", chunk_text):
return "I've been thinking about my well-being. What do you notice?"
# Look for financial content
if re.search(r"(?i)money|finance|financial|invest|saving|spending", chunk_text):
return "I'm considering a financial decision. How should I evaluate this?"
# Default prompt
return "I'm reflecting on something. What patterns do you see?"
def _create_training_example(
chunk_text: str, source_file: str, tags: list[str], date: str | None
) -> TrainingExample | None:
"""Convert a chunk into a training example if it meets criteria."""
# Skip if too short
if len(chunk_text) < 100:
return None
# Skip if not a reflection
if not _is_likely_reflection(chunk_text):
return None
# Clean up the text
cleaned = chunk_text.strip()
# Create prompt-response pair
prompt = _create_training_prompt(cleaned)
# The response is your reflection/insight
messages = [
{"role": "system", "content": "You are a thoughtful, reflective companion."},
{"role": "user", "content": prompt},
{"role": "assistant", "content": cleaned},
]
return TrainingExample(
messages=messages,
source_file=source_file,
tags=tags,
date=date,
)
class TrainingDataExtractor:
"""Extracts training examples from Obsidian vault."""
def __init__(self, config: Config):
self.config = config
self.vault_path = Path(config.vault.path)
self.examples: list[TrainingExample] = []
def extract(self) -> list[TrainingExample]:
"""Extract all training examples from the vault."""
self.examples = []
# Walk vault for markdown files
for md_file in self.vault_path.rglob("*.md"):
# Skip denied directories
relative = md_file.relative_to(self.vault_path)
if any(part.startswith(".") for part in relative.parts):
continue
if any(
part in [".obsidian", ".trash", "zzz-Archive"]
for part in relative.parts
):
continue
# Extract from this file
file_examples = self._extract_from_file(md_file)
self.examples.extend(file_examples)
return self.examples
def _extract_from_file(self, md_file: Path) -> list[TrainingExample]:
"""Extract training examples from a single file."""
examples = []
try:
text = md_file.read_text(encoding="utf-8")
except Exception:
return examples
# Get date from filename
date = _extract_date_from_filename(md_file.name)
# Split into sections (by headers like #Section or # Tag:)
# Pattern matches lines starting with # that have content
sections = re.split(r"\n(?=#[^#])", text)
for section in sections:
section = section.strip()
if not section or len(section) < 50: # Skip very short sections
continue
# Extract tags from section
hashtags = re.findall(r"#[\w\-]+", section)
# Try to create training example
example = _create_training_example(
chunk_text=section,
source_file=str(md_file.relative_to(self.vault_path)).replace(
"\\", "/"
),
tags=hashtags,
date=date,
)
if example:
examples.append(example)
# If no reflection sections found, try the whole file
if not examples and len(text) >= 100:
hashtags = re.findall(r"#[\w\-]+", text)
if _is_likely_reflection(text):
example = _create_training_example(
chunk_text=text.strip(),
source_file=str(md_file.relative_to(self.vault_path)).replace(
"\\", "/"
),
tags=hashtags,
date=date,
)
if example:
examples.append(example)
return examples
def save_to_jsonl(self, output_path: Path) -> int:
"""Save extracted examples to JSONL file."""
with open(output_path, "w", encoding="utf-8") as f:
for example in self.examples:
f.write(json.dumps(example.to_dict(), ensure_ascii=False) + "\n")
return len(self.examples)
def get_stats(self) -> dict:
"""Get statistics about extracted examples."""
if not self.examples:
return {"total": 0}
tag_counts: dict[str, int] = {}
for ex in self.examples:
for tag in ex.tags:
tag_counts[tag] = tag_counts.get(tag, 0) + 1
return {
"total": len(self.examples),
"avg_length": sum(len(ex.messages[2]["content"]) for ex in self.examples)
// len(self.examples),
"top_tags": sorted(tag_counts.items(), key=lambda x: x[1], reverse=True)[
:10
],
}
def extract_training_data(
config_path: str = "config.json",
) -> tuple[list[TrainingExample], dict]:
"""Convenience function to extract training data from vault.
Returns:
Tuple of (examples list, stats dict)
"""
from companion.config import load_config
config = load_config(config_path)
extractor = TrainingDataExtractor(config)
examples = extractor.extract()
stats = extractor.get_stats()
return examples, stats

src/companion/memory.py Normal file

@@ -0,0 +1,178 @@
"""SQLite-based session memory for the companion."""
from __future__ import annotations
import json
import sqlite3
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any
@dataclass
class Message:
"""A single message in the conversation."""
role: str # "user" | "assistant" | "system"
content: str
timestamp: datetime
metadata: dict[str, Any] | None = None
class SessionMemory:
"""Manages conversation history in SQLite."""
def __init__(self, db_path: str | Path):
self.db_path = Path(db_path)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self._init_db()
def _init_db(self) -> None:
"""Initialize the database schema."""
with sqlite3.connect(self.db_path) as conn:
conn.execute("""
CREATE TABLE IF NOT EXISTS sessions (
session_id TEXT PRIMARY KEY,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
metadata TEXT,
FOREIGN KEY (session_id) REFERENCES sessions(session_id)
)
""")
conn.execute("""
CREATE INDEX IF NOT EXISTS idx_messages_session
ON messages(session_id, timestamp)
""")
conn.commit()
def get_or_create_session(self, session_id: str) -> str:
"""Get existing session or create new one."""
with sqlite3.connect(self.db_path) as conn:
conn.execute(
"INSERT OR IGNORE INTO sessions (session_id) VALUES (?)", (session_id,)
)
conn.execute(
"""UPDATE sessions SET updated_at = CURRENT_TIMESTAMP
WHERE session_id = ?""",
(session_id,),
)
conn.commit()
return session_id
def add_message(
self,
session_id: str,
role: str,
content: str,
metadata: dict[str, Any] | None = None,
) -> None:
"""Add a message to the session."""
self.get_or_create_session(session_id)
with sqlite3.connect(self.db_path) as conn:
conn.execute(
"""INSERT INTO messages (session_id, role, content, metadata)
VALUES (?, ?, ?, ?)""",
(session_id, role, content, json.dumps(metadata) if metadata else None),
)
conn.commit()
def get_messages(
self, session_id: str, limit: int = 20, before_id: int | None = None
) -> list[Message]:
"""Get messages from a session, most recent first."""
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
if before_id:
rows = conn.execute(
"""SELECT role, content, timestamp, metadata
FROM messages
WHERE session_id = ? AND id < ?
ORDER BY id DESC
LIMIT ?""",
(session_id, before_id, limit),
).fetchall()
else:
rows = conn.execute(
"""SELECT role, content, timestamp, metadata
FROM messages
WHERE session_id = ?
ORDER BY id DESC
LIMIT ?""",
(session_id, limit),
).fetchall()
messages = []
for row in rows:
meta = json.loads(row["metadata"]) if row["metadata"] else None
messages.append(
Message(
role=row["role"],
content=row["content"],
timestamp=datetime.fromisoformat(row["timestamp"]),
metadata=meta,
)
)
# Return in chronological order
return list(reversed(messages))
def get_session_summary(self, session_id: str) -> dict[str, Any] | None:
"""Get summary info about a session."""
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
row = conn.execute(
"""SELECT session_id, created_at, updated_at,
(SELECT COUNT(*) FROM messages WHERE session_id = ?) as message_count
FROM sessions WHERE session_id = ?""",
(session_id, session_id),
).fetchone()
if not row:
return None
return {
"session_id": row["session_id"],
"created_at": row["created_at"],
"updated_at": row["updated_at"],
"message_count": row["message_count"],
}
def list_sessions(self, limit: int = 100) -> list[dict[str, Any]]:
"""List recent sessions."""
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
rows = conn.execute(
"""SELECT session_id, created_at, updated_at
FROM sessions
ORDER BY updated_at DESC
LIMIT ?""",
(limit,),
).fetchall()
return [
{
"session_id": r["session_id"],
"created_at": r["created_at"],
"updated_at": r["updated_at"],
}
for r in rows
]
def clear_session(self, session_id: str) -> None:
"""Clear all messages from a session."""
with sqlite3.connect(self.db_path) as conn:
conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
conn.commit()
def delete_session(self, session_id: str) -> None:
"""Delete a session and all its messages."""
with sqlite3.connect(self.db_path) as conn:
conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
conn.execute("DELETE FROM sessions WHERE session_id = ?", (session_id,))
conn.commit()


@@ -0,0 +1,155 @@
"""Chat orchestrator - coordinates RAG, LLM, and memory for chat responses."""
from __future__ import annotations
import json
from collections.abc import AsyncGenerator
from typing import Any
import httpx
from companion.config import Config
from companion.memory import Message, SessionMemory
from companion.prompts import build_system_prompt, format_conversation_history
from companion.rag.search import SearchEngine
class ChatOrchestrator:
"""Orchestrates chat by combining RAG, memory, and LLM streaming."""
def __init__(
self,
config: Config,
search_engine: SearchEngine,
session_memory: SessionMemory,
):
self.config = config
self.search = search_engine
self.memory = session_memory
self.model_endpoint = config.model.inference.backend
self.model_path = config.model.inference.model_path
async def chat(
self,
session_id: str,
user_message: str,
use_rag: bool = True,
) -> AsyncGenerator[str, None]:
"""Process a chat message and yield streaming responses.
Args:
session_id: Unique session identifier
user_message: The user's input message
use_rag: Whether to retrieve context from vault
Yields:
Streaming response chunks (SSE format)
"""
# Retrieve relevant context from RAG if enabled
retrieved_context: list[dict[str, Any]] = []
if use_rag:
try:
retrieved_context = self.search.search(
user_message, top_k=self.config.rag.search.default_top_k
)
except Exception as e:
# Log error but continue without RAG context
print(f"RAG retrieval failed: {e}")
# Get conversation history
history = self.memory.get_messages(
session_id, limit=self.config.companion.memory.session_turns
)
# Format history for prompt
history_formatted = format_conversation_history(
[{"role": msg.role, "content": msg.content} for msg in history]
)
# Build system prompt
system_prompt = build_system_prompt(
persona=self.config.companion.persona.model_dump(),
retrieved_context=retrieved_context if retrieved_context else None,
memory_context=history_formatted if history_formatted else None,
)
# Add user message to memory
self.memory.add_message(session_id, "user", user_message)
# Build messages for LLM
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_message},
]
# Stream from LLM (using llama.cpp server or similar)
full_response = ""
async for chunk in self._stream_llm_response(messages):
yield chunk
if chunk.startswith("data: "):
data = chunk[6:] # Remove "data: " prefix
if data != "[DONE]":
try:
delta = json.loads(data)
if "content" in delta:
full_response += delta["content"]
except json.JSONDecodeError:
pass
# Store assistant response in memory
if full_response:
self.memory.add_message(
session_id,
"assistant",
full_response,
metadata={"rag_context_used": len(retrieved_context) > 0},
)
async def _stream_llm_response(
self, messages: list[dict[str, Any]]
) -> AsyncGenerator[str, None]:
"""Stream response from local LLM.
Uses llama.cpp HTTP server format or Ollama API.
"""
# Resolve the chat server base URL from config
base_url = self.config.model.inference.backend
if base_url == "llama.cpp":
base_url = "http://localhost:8080"  # Default llama.cpp server port
elif not base_url.startswith("http"):
# Not a URL: fall back to the Ollama server used for embeddings
base_url = self.config.rag.embedding.base_url.replace("/api", "")
try:
async with httpx.AsyncClient(timeout=120.0) as client:
# Stream from the Ollama-compatible chat endpoint
async with client.stream(
"POST",
f"{base_url}/api/chat",
json={
# Placeholder: derive the chat model name from the
# embedding model name until a dedicated config key exists
"model": self.config.rag.embedding.model.replace("-embed", ""),
"messages": messages,
"stream": True,
"options": {
"temperature": self.config.companion.chat.default_temperature,
},
},
) as response:
async for line in response.aiter_lines():
if line.strip():
yield f"data: {line}\n\n"
except Exception as e:
yield f'data: {{"error": "LLM streaming failed: {e}"}}\n\n'
yield "data: [DONE]\n\n"
def get_session_history(self, session_id: str, limit: int = 50) -> list[Message]:
"""Get conversation history for a session."""
return self.memory.get_messages(session_id, limit=limit)
def clear_session(self, session_id: str) -> None:
"""Clear a session's conversation history."""
self.memory.clear_session(session_id)

src/companion/prompts.py Normal file

@@ -0,0 +1,100 @@
"""System prompts for the companion."""
from __future__ import annotations
from typing import Any
def build_system_prompt(
persona: dict[str, Any],
retrieved_context: list[dict[str, Any]] | None = None,
memory_context: str | None = None,
) -> str:
"""Build the system prompt for the companion.
Args:
persona: Companion persona configuration (name, role, tone, style, boundaries)
retrieved_context: Optional RAG context from vault search
memory_context: Optional memory summary from session
Returns:
The complete system prompt
"""
name = persona.get("name", "Companion")
role = persona.get("role", "companion")
tone = persona.get("tone", "reflective")
style = persona.get("style", "questioning")
boundaries = persona.get("boundaries", [])
# Base persona description
base_prompt = f"""You are {name}, a {tone}, {style} {role}.
Your role is to be a thoughtful, reflective companion—not to impersonate or speak for the user, but to explore alongside them. You know their life through their journal entries and are here to help them reflect, remember, and explore patterns.
Core principles:
- You do not speak as the user. You speak to them.
- You listen deeply and reflect patterns back gently.
- You ask questions that help them explore their own thoughts.
- You respect the boundaries of a confidante, not an oracle.
"""
# Add boundaries section
if boundaries:
base_prompt += "\nBoundaries:\n"
for boundary in boundaries:
base_prompt += f"- {boundary.replace('_', ' ').title()}\n"
# Add retrieved context section if available
context_section = ""
if retrieved_context:
context_section += "\n\nRelevant context from your vault:\n"
for i, ctx in enumerate(retrieved_context[:8], 1): # Limit to top 8 chunks
source = ctx.get("source_file", "unknown")
text = ctx.get("text", "").strip()
if text:
context_section += f"[{i}] From {source}:\n{text}\n\n"
# Add memory context if available
memory_section = ""
if memory_context:
memory_section = f"\n\nContext from your conversation:\n{memory_context}"
# Final instructions
closing = """
When responding:
- Draw from the vault context if relevant, but don't force it
- Be concise but thoughtful—no unnecessary length
- If uncertain, acknowledge the uncertainty
- Ask follow-up questions that deepen reflection
"""
return base_prompt + context_section + memory_section + closing
def format_conversation_history(
messages: list[dict[str, Any]], max_turns: int = 10
) -> str:
"""Format conversation history for prompt context.
Args:
messages: List of message dicts with 'role' and 'content'
max_turns: Maximum number of recent turns to include
Returns:
Formatted conversation history
"""
if not messages:
return ""
# Take only recent messages
recent = messages[-max_turns * 2 :] if len(messages) > max_turns * 2 else messages
formatted = []
for msg in recent:
role = msg.get("role", "user")
content = msg.get("content", "")
if content.strip():
formatted.append(f"{role.upper()}: {content}")
return "\n\n".join(formatted)

tests/test_api.py Normal file

@@ -0,0 +1,31 @@
"""Simple smoke tests for FastAPI backend."""
import pytest
def test_api_imports():
"""Test that API module imports correctly."""
# This will fail if there are any import errors
from companion.api import app, ChatRequest
assert app is not None
assert ChatRequest is not None
def test_chat_request_model():
"""Test ChatRequest model validation."""
from companion.api import ChatRequest
# Valid request
req = ChatRequest(message="hello", session_id="abc123")
assert req.message == "hello"
assert req.session_id == "abc123"
# Valid request with temperature
req2 = ChatRequest(message="hello", temperature=0.7)
assert req2.temperature == 0.7
# Valid request with minimal fields
req3 = ChatRequest(message="hello")
assert req3.session_id is None
assert req3.temperature is None

tests/test_forge_extract.py Normal file

@@ -0,0 +1,604 @@
"""Tests for training data extractor."""
import tempfile
from pathlib import Path
import pytest
from companion.config import Config, VaultConfig, IndexingConfig
from companion.forge.extract import (
TrainingDataExtractor,
TrainingExample,
_create_training_example,
_extract_date_from_filename,
_has_reflection_patterns,
_has_reflection_tags,
_is_likely_reflection,
extract_training_data,
)
def test_has_reflection_tags():
assert _has_reflection_tags("#reflection on today's events")
assert _has_reflection_tags("#decision made today")
assert not _has_reflection_tags("#worklog entry")
def test_has_reflection_patterns():
assert _has_reflection_patterns("I think this is important")
assert _has_reflection_patterns("I wonder if I should change")
assert _has_reflection_patterns("Looking back, I see the pattern")
assert not _has_reflection_patterns("The meeting was at 3pm")
def test_is_likely_reflection():
assert _is_likely_reflection("#reflection I think this matters")
assert _is_likely_reflection("I realize now that I was wrong")
assert not _is_likely_reflection("Just a regular note")
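The heuristics these tests pin down could be sketched roughly as follows. This is an illustrative implementation consistent with the assertions above, not the module's actual code; the tag set and cue phrases are assumptions drawn from the spec ("#reflection, #decision, #learning, etc.") and the test inputs:

```python
import re

# Assumed tag and cue lists -- illustrative, not the module's actual constants.
REFLECTION_TAGS = {"#reflection", "#decision", "#learning", "#insight"}
REFLECTION_CUES = [
    r"\bI think\b",
    r"\bI wonder\b",
    r"\bI realize\b",
    r"\bLooking back\b",
]


def has_reflection_tags(text: str) -> bool:
    """True if the chunk carries one of the reflection-style tags."""
    return any(tag in text for tag in REFLECTION_TAGS)


def has_reflection_patterns(text: str) -> bool:
    """True if the chunk contains a first-person reflective cue."""
    return any(re.search(cue, text) for cue in REFLECTION_CUES)


def is_likely_reflection(text: str) -> bool:
    """A tag is a strong signal; otherwise fall back to phrase cues."""
    return has_reflection_tags(text) or has_reflection_patterns(text)
```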
def test_extract_date_from_filename():
assert _extract_date_from_filename("2026-04-12.md") == "2026-04-12"
assert _extract_date_from_filename("12-Apr-2026.md") == "12-Apr-2026"
assert _extract_date_from_filename("2026-04-12-journal.md") == "2026-04-12"
assert _extract_date_from_filename("notes.md") is None
def test_create_training_example():
text = "#reflection I think I need to reconsider my approach. The way I've been handling this isn't working."
example = _create_training_example(
chunk_text=text,
source_file="journal/2026-04-12.md",
tags=["#reflection"],
date="2026-04-12",
)
assert example is not None
assert len(example.messages) == 3
assert example.messages[0]["role"] == "system"
assert example.messages[1]["role"] == "user"
assert example.messages[2]["role"] == "assistant"
assert example.messages[2]["content"] == text
assert example.source_file == "journal/2026-04-12.md"
def test_create_training_example_too_short():
text = "I think." # Too short
example = _create_training_example(
chunk_text=text,
source_file="test.md",
tags=["#reflection"],
date=None,
)
assert example is None
def test_create_training_example_no_reflection():
text = "This is just a regular note about the meeting at 3pm. Nothing special." * 5
example = _create_training_example(
chunk_text=text,
source_file="test.md",
tags=["#work"],
date=None,
)
assert example is None
def test_training_example_to_dict():
example = TrainingExample(
messages=[
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"},
],
source_file="test.md",
tags=["#test"],
date="2026-04-12",
)
d = example.to_dict()
assert d["messages"][0]["role"] == "user"
assert d["source_file"] == "test.md"
assert d["date"] == "2026-04-12"
class TestTrainingDataExtractor:
def _get_config_dict(self, vault_path: Path) -> dict:
"""Return minimal config dict for testing."""
return {
"companion": {
"name": "SAN",
"persona": {
"role": "companion",
"tone": "reflective",
"style": "questioning",
"boundaries": [],
},
"memory": {
"session_turns": 20,
"persistent_store": "",
"summarize_after": 10,
},
"chat": {
"streaming": True,
"max_response_tokens": 2048,
"default_temperature": 0.7,
"allow_temperature_override": True,
},
},
"vault": {
"path": str(vault_path),
"indexing": {
"auto_sync": False,
"auto_sync_interval_minutes": 1440,
"watch_fs_events": False,
"file_patterns": ["*.md"],
"deny_dirs": [".git"],
"deny_patterns": [],
},
"chunking_rules": {},
},
"rag": {
"embedding": {
"provider": "ollama",
"model": "mxbai-embed-large",
"base_url": "http://localhost:11434",
"dimensions": 1024,
"batch_size": 32,
},
"vector_store": {"type": "lancedb", "path": ".test.vectors"},
"search": {
"default_top_k": 8,
"max_top_k": 20,
"similarity_threshold": 0.75,
"hybrid_search": {
"enabled": False,
"keyword_weight": 0.3,
"semantic_weight": 0.7,
},
"filters": {
"date_range_enabled": True,
"tag_filter_enabled": True,
"directory_filter_enabled": True,
},
},
},
"model": {
"inference": {
"backend": "llama.cpp",
"model_path": "",
"context_length": 8192,
"gpu_layers": 35,
"batch_size": 512,
"threads": 8,
},
"fine_tuning": {
"base_model": "",
"output_dir": "",
"lora_rank": 16,
"lora_alpha": 32,
"learning_rate": 0.0002,
"batch_size": 4,
"gradient_accumulation_steps": 4,
"num_epochs": 3,
"warmup_steps": 100,
"save_steps": 500,
"eval_steps": 250,
"training_data_path": "",
"validation_split": 0.1,
},
"retrain_schedule": {
"auto_reminder": True,
"default_interval_days": 90,
"reminder_channels": [],
},
},
"api": {
"host": "127.0.0.1",
"port": 7373,
"cors_origins": [],
"auth": {"enabled": False},
},
"ui": {
"web": {
"enabled": True,
"theme": "obsidian",
"features": {
"streaming": True,
"citations": True,
"source_preview": True,
},
},
"cli": {"enabled": True, "rich_output": True},
},
"logging": {
"level": "INFO",
"file": "",
"max_size_mb": 100,
"backup_count": 5,
},
"security": {
"local_only": True,
"vault_path_traversal_check": True,
"sensitive_content_detection": True,
"sensitive_patterns": [],
"require_confirmation_for_external_apis": True,
},
}
def test_extract_from_single_file(self):
with tempfile.TemporaryDirectory() as tmp:
vault = Path(tmp)
journal = vault / "Journal" / "2026" / "04"
journal.mkdir(parents=True)
content = """#DayInShort: Busy day
#reflection I think I need to slow down. The pace has been unsustainable.
#work Normal work day with meetings.
#insight I realize that I've been prioritizing urgency over importance.
"""
(journal / "2026-04-12.md").write_text(content, encoding="utf-8")
# Use helper method for config
from companion.config import load_config
import json
config_dict = self._get_config_dict(vault)
config_path = Path(tmp) / "test_config.json"
with open(config_path, "w") as f:
json.dump(config_dict, f)
config = load_config(config_path)
extractor = TrainingDataExtractor(config)
examples = extractor.extract()
# Should extract at least 2 reflection examples
assert len(examples) >= 2
# Check they have the right structure
for ex in examples:
assert len(ex.messages) == 3
assert ex.messages[2]["role"] == "assistant"
def test_save_to_jsonl(self):
with tempfile.TemporaryDirectory() as tmp:
output = Path(tmp) / "training.jsonl"
examples = [
TrainingExample(
messages=[
{"role": "system", "content": "sys"},
{"role": "user", "content": "user"},
{"role": "assistant", "content": "assistant"},
],
source_file="test.md",
tags=["#test"],
date="2026-04-12",
)
]
# Create minimal config for extractor
config_dict = self._get_config_dict(Path(tmp))
config_path = Path(tmp) / "test_config.json"
import json
with open(config_path, "w") as f:
json.dump(config_dict, f)
from companion.config import load_config
config = load_config(config_path)
extractor = TrainingDataExtractor(config)
extractor.examples = examples
count = extractor.save_to_jsonl(output)
assert count == 1
# Verify file content
lines = output.read_text(encoding="utf-8").strip().split("\n")
assert len(lines) == 1
assert "assistant" in lines[0]
def test_get_stats(self):
examples = [
TrainingExample(
messages=[
{"role": "system", "content": "sys"},
{"role": "user", "content": "user"},
{"role": "assistant", "content": "a" * 100},
],
source_file="test1.md",
tags=["#reflection", "#learning"],
date="2026-04-12",
),
TrainingExample(
messages=[
{"role": "system", "content": "sys"},
{"role": "user", "content": "user"},
{"role": "assistant", "content": "b" * 200},
],
source_file="test2.md",
tags=["#reflection", "#decision"],
date="2026-04-13",
),
]
# Create minimal config
with tempfile.TemporaryDirectory() as tmp:
config_dict = self._get_config_dict(Path(tmp))
config_path = Path(tmp) / "test_config.json"
import json
with open(config_path, "w") as f:
json.dump(config_dict, f)
from companion.config import load_config
config = load_config(config_path)
extractor = TrainingDataExtractor(config)
extractor.examples = examples
stats = extractor.get_stats()
assert stats["total"] == 2
assert stats["avg_length"] == 150 # (100 + 200) // 2
assert len(stats["top_tags"]) > 0
assert stats["top_tags"][0][0] == "#reflection"
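Putting the pieces together, the tests imply that `save_to_jsonl` writes one JSON object per line in the chat format that `TrainingExample.to_dict()` produces. A sketch of what a single output record would look like, with hypothetical message contents (the system/user prompt wording is an assumption; only the field layout is pinned down by the tests):

```python
import json

# Hypothetical record mirroring the TrainingExample.to_dict() shape
# exercised by test_training_example_to_dict and test_save_to_jsonl.
record = {
    "messages": [
        {"role": "system", "content": "You are SAN, a reflective companion."},
        {"role": "user", "content": "What stood out to you today?"},
        {"role": "assistant", "content": "#reflection I think I need to slow down."},
    ],
    "source_file": "journal/2026-04-12.md",
    "tags": ["#reflection"],
    "date": "2026-04-12",
}

# One JSON object per line, as save_to_jsonl is asserted to write.
line = json.dumps(record, ensure_ascii=False)
print(line)
```

This is the same conversation layout HuggingFace chat datasets expect, which keeps the hand-off to the QLoRA step in the plan straightforward.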