WIP: Phase 4 forge extract module with tests
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -9,3 +9,9 @@ venv/
|
|||||||
.companion/
|
.companion/
|
||||||
dist/
|
dist/
|
||||||
build/
|
build/
|
||||||
|
|
||||||
|
# Node.js / UI
|
||||||
|
node_modules/
|
||||||
|
ui/node_modules/
|
||||||
|
ui/dist/
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,156 @@
|
|||||||
|
# Phase 4: Fine-Tuning Pipeline Implementation Plan
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
Build a pipeline to extract training examples from the Obsidian vault and fine-tune a local 7B model using QLoRA on the RTX 5070.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────────────────┐
|
||||||
|
│ Training Data Pipeline │
|
||||||
|
│ ───────────────────── │
|
||||||
|
│ 1. Extract reflections from vault │
|
||||||
|
│ 2. Curate into conversation format │
|
||||||
|
│ 3. Split train/validation │
|
||||||
|
│ 4. Export to HuggingFace datasets format │
|
||||||
|
└─────────────────────────────────────────────────────────┘
|
||||||
|
↓
|
||||||
|
┌─────────────────────────────────────────────────────────┐
|
||||||
|
│ QLoRA Fine-Tuning (Unsloth) │
|
||||||
|
│ ─────────────────────────── │
|
||||||
|
│ - Base: Llama 3.1 8B Instruct (4-bit) │
|
||||||
|
│ - LoRA rank: 16, alpha: 32 │
|
||||||
|
│ - Target modules: q_proj, k_proj, v_proj, o_proj │
|
||||||
|
│ - Learning rate: 2e-4 │
|
||||||
|
│ - Epochs: 3 │
|
||||||
|
│ - Batch: 4, Gradient accumulation: 4 │
|
||||||
|
└─────────────────────────────────────────────────────────┘
|
||||||
|
↓
|
||||||
|
┌─────────────────────────────────────────────────────────┐
|
||||||
|
│ Model Export & Serving │
|
||||||
|
│ ───────────────────── │
|
||||||
|
│ - Export to GGUF (Q4_K_M quantization) │
|
||||||
|
│ - Serve via llama.cpp or vLLM │
|
||||||
|
│ - Hot-swap in FastAPI backend │
|
||||||
|
└─────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
## Tasks
|
||||||
|
|
||||||
|
### Task 1: Training Data Extractor
|
||||||
|
**Files:**
|
||||||
|
- `src/companion/forge/extract.py` - Extract reflection examples from vault
|
||||||
|
- `tests/test_forge_extract.py` - Test extraction logic
|
||||||
|
|
||||||
|
**Spec:**
|
||||||
|
- Parse vault for "reflection" patterns (journal entries with insights, decision analyses)
|
||||||
|
- Look for tags: #reflection, #decision, #learning, etc.
|
||||||
|
- Extract entries where you reflect on situations, weigh options, or analyze outcomes
|
||||||
|
- Format as conversation: user prompt + assistant response (your reflection)
|
||||||
|
- Output: JSONL file with {"messages": [{"role": "...", "content": "..."}]}
|
||||||
|
|
||||||
|
### Task 2: Training Data Curator
|
||||||
|
**Files:**
|
||||||
|
- `src/companion/forge/curate.py` - Human-in-the-loop curation
|
||||||
|
- `src/companion/forge/cli.py` - CLI for curation workflow
|
||||||
|
|
||||||
|
**Spec:**
|
||||||
|
- Load extracted examples
|
||||||
|
- Interactive review: show each example, allow approve/reject/edit
|
||||||
|
- Track curation decisions in SQLite
|
||||||
|
- Export approved examples to final training set
|
||||||
|
- Deduplicate similar examples (use embeddings similarity)
|
||||||
|
|
||||||
|
### Task 3: Training Configuration
|
||||||
|
**Files:**
|
||||||
|
- `src/companion/forge/config.py` - Training hyperparameters
|
||||||
|
- `config.json` updates for fine_tuning section
|
||||||
|
|
||||||
|
**Spec:**
|
||||||
|
- Pydantic models for training config
|
||||||
|
- Hyperparameters tuned for RTX 5070 (12GB VRAM)
|
||||||
|
- Output paths, logging config
|
||||||
|
|
||||||
|
### Task 4: QLoRA Training Script
|
||||||
|
**Files:**
|
||||||
|
- `src/companion/forge/train.py` - Unsloth training script
|
||||||
|
- `scripts/train.sh` - Convenience launcher
|
||||||
|
|
||||||
|
**Spec:**
|
||||||
|
- Load base model: unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit
|
||||||
|
- Apply LoRA config (r=16, alpha=32, target_modules)
|
||||||
|
- Load and tokenize dataset
|
||||||
|
- Training loop with wandb logging (optional)
|
||||||
|
- Save checkpoints every 500 steps
|
||||||
|
- Validate on holdout set
|
||||||
|
|
||||||
|
### Task 5: Model Export
|
||||||
|
**Files:**
|
||||||
|
- `src/companion/forge/export.py` - Export to GGUF
|
||||||
|
- `src/companion/forge/merge.py` - Merge LoRA weights into base
|
||||||
|
|
||||||
|
**Spec:**
|
||||||
|
- Merge LoRA weights into base model
|
||||||
|
- Export to GGUF with Q4_K_M quantization
|
||||||
|
- Save to `~/.companion/models/`
|
||||||
|
- Update config.json with new model path
|
||||||
|
|
||||||
|
### Task 6: Model Hot-Swap
|
||||||
|
**Files:**
|
||||||
|
- Update `src/companion/api.py` - Add endpoint to reload model
|
||||||
|
- `src/companion/forge/reload.py` - Model reloader utility
|
||||||
|
|
||||||
|
**Spec:**
|
||||||
|
- `/admin/reload-model` endpoint (requires auth/local-only)
|
||||||
|
- Gracefully unload old model, load new GGUF
|
||||||
|
- Return status: success or error
|
||||||
|
|
||||||
|
### Task 7: Evaluation Framework
|
||||||
|
**Files:**
|
||||||
|
- `src/companion/forge/eval.py` - Evaluate model on test prompts
|
||||||
|
- `tests/test_forge_eval.py` - Evaluation tests
|
||||||
|
|
||||||
|
**Spec:**
|
||||||
|
- Load test prompts (decision scenarios, relationship questions)
|
||||||
|
- Generate responses from both base and fine-tuned model
|
||||||
|
- Store outputs for human comparison
|
||||||
|
- Track metrics: response time, token count
|
||||||
|
|
||||||
|
## Success Criteria
|
||||||
|
- [ ] Extract 100+ reflection examples from vault
|
||||||
|
- [ ] Curate down to 50-100 high-quality training examples
|
||||||
|
- [ ] Complete training run in <6 hours on RTX 5070
|
||||||
|
- [ ] Export produces valid GGUF file
|
||||||
|
- [ ] Hot-swap endpoint successfully reloads model
|
||||||
|
- [ ] Evaluation shows distinguishable "Santhosh-style" in outputs
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
```
|
||||||
|
unsloth>=2024.1.0
|
||||||
|
torch>=2.1.0
|
||||||
|
transformers>=4.36.0
|
||||||
|
datasets>=2.14.0
|
||||||
|
peft>=0.7.0
|
||||||
|
accelerate>=0.25.0
|
||||||
|
bitsandbytes>=0.41.0
|
||||||
|
sentencepiece>=0.1.99
|
||||||
|
protobuf>=3.20.0
|
||||||
|
```
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
```bash
|
||||||
|
# Extract training data
|
||||||
|
python -m companion.forge.cli extract
|
||||||
|
|
||||||
|
# Curate examples
|
||||||
|
python -m companion.forge.cli curate
|
||||||
|
|
||||||
|
# Train
|
||||||
|
python -m companion.forge.train
|
||||||
|
|
||||||
|
# Export
|
||||||
|
python -m companion.forge.export
|
||||||
|
|
||||||
|
# Reload model in API
|
||||||
|
python -m companion.forge.reload
|
||||||
|
```
|
||||||
@@ -12,6 +12,11 @@ dependencies = [
|
|||||||
"typer>=0.12.0",
|
"typer>=0.12.0",
|
||||||
"rich>=13.0.0",
|
"rich>=13.0.0",
|
||||||
"numpy>=1.26.0",
|
"numpy>=1.26.0",
|
||||||
|
"fastapi>=0.109.0",
|
||||||
|
"uvicorn[standard]>=0.27.0",
|
||||||
|
"httpx>=0.27.0",
|
||||||
|
"sse-starlette>=2.0.0",
|
||||||
|
"python-multipart>=0.0.9",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
215
src/companion/api.py
Normal file
215
src/companion/api.py
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
"""FastAPI backend for Companion AI."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
|
||||||
|
from companion.config import Config, load_config
|
||||||
|
from companion.memory import SessionMemory
|
||||||
|
from companion.orchestrator import ChatOrchestrator
|
||||||
|
from companion.rag.search import SearchEngine
|
||||||
|
from companion.rag.vector_store import VectorStore
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
"""Chat request model."""
|
||||||
|
|
||||||
|
message: str
|
||||||
|
session_id: str | None = None
|
||||||
|
temperature: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatResponse(BaseModel):
|
||||||
|
"""Chat response model (non-streaming)."""
|
||||||
|
|
||||||
|
response: str
|
||||||
|
session_id: str
|
||||||
|
sources: list[dict] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# Global instances
|
||||||
|
config: Config
|
||||||
|
vector_store: VectorStore
|
||||||
|
search_engine: SearchEngine
|
||||||
|
memory: SessionMemory
|
||||||
|
orchestrator: ChatOrchestrator
|
||||||
|
http_client: httpx.AsyncClient
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
|
"""Manage application lifespan."""
|
||||||
|
global config, vector_store, search_engine, memory, orchestrator, http_client
|
||||||
|
|
||||||
|
# Startup
|
||||||
|
config = load_config("config.json")
|
||||||
|
vector_store = VectorStore(
|
||||||
|
uri=config.rag.vector_store.path, dimensions=config.rag.embedding.dimensions
|
||||||
|
)
|
||||||
|
search_engine = SearchEngine(
|
||||||
|
vector_store=vector_store,
|
||||||
|
embedder_base_url=config.rag.embedding.base_url,
|
||||||
|
embedder_model=config.rag.embedding.model,
|
||||||
|
embedder_batch_size=config.rag.embedding.batch_size,
|
||||||
|
default_top_k=config.rag.search.default_top_k,
|
||||||
|
similarity_threshold=config.rag.search.similarity_threshold,
|
||||||
|
hybrid_search_enabled=config.rag.search.hybrid_search.enabled,
|
||||||
|
)
|
||||||
|
memory = SessionMemory(db_path=config.companion.memory.persistent_store)
|
||||||
|
http_client = httpx.AsyncClient(timeout=300.0)
|
||||||
|
orchestrator = ChatOrchestrator(
|
||||||
|
config=config,
|
||||||
|
search_engine=search_engine,
|
||||||
|
memory=memory,
|
||||||
|
http_client=http_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Shutdown
|
||||||
|
await http_client.aclose()
|
||||||
|
memory.close()
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Companion AI",
|
||||||
|
description="Personal AI companion with RAG",
|
||||||
|
version="0.1.0",
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
# CORS middleware
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=config.api.cors_origins
|
||||||
|
if "config" in globals()
|
||||||
|
else ["http://localhost:5173"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check() -> dict:
|
||||||
|
"""Health check endpoint."""
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"version": "0.1.0",
|
||||||
|
"indexed_chunks": vector_store.count() if "vector_store" in globals() else 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/chat")
|
||||||
|
async def chat(request: ChatRequest) -> EventSourceResponse:
|
||||||
|
"""Chat endpoint with SSE streaming."""
|
||||||
|
if not request.message.strip():
|
||||||
|
raise HTTPException(status_code=400, detail="Message cannot be empty")
|
||||||
|
|
||||||
|
# Generate or use existing session ID
|
||||||
|
session_id = request.session_id or str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Validate temperature if provided
|
||||||
|
temperature = request.temperature
|
||||||
|
if temperature is not None and not 0.0 <= temperature <= 2.0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Temperature must be between 0.0 and 2.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
|
"""Generate SSE events."""
|
||||||
|
try:
|
||||||
|
# Get conversation history
|
||||||
|
history = memory.get_history(session_id)
|
||||||
|
|
||||||
|
# Retrieve relevant context
|
||||||
|
context_chunks = search_engine.search(
|
||||||
|
request.message,
|
||||||
|
top_k=config.rag.search.default_top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stream response from LLM
|
||||||
|
full_response = ""
|
||||||
|
async for chunk in orchestrator.stream_response(
|
||||||
|
query=request.message,
|
||||||
|
history=history,
|
||||||
|
context_chunks=context_chunks,
|
||||||
|
temperature=temperature,
|
||||||
|
):
|
||||||
|
full_response += chunk
|
||||||
|
yield json.dumps(
|
||||||
|
{
|
||||||
|
"type": "chunk",
|
||||||
|
"content": chunk,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store the conversation
|
||||||
|
memory.add_message(session_id, "user", request.message)
|
||||||
|
memory.add_message(session_id, "assistant", full_response)
|
||||||
|
|
||||||
|
# Send sources
|
||||||
|
sources = [
|
||||||
|
{
|
||||||
|
"file": chunk.get("source_file", ""),
|
||||||
|
"section": chunk.get("section"),
|
||||||
|
"date": chunk.get("date"),
|
||||||
|
}
|
||||||
|
for chunk in context_chunks[:5] # Top 5 sources
|
||||||
|
]
|
||||||
|
yield json.dumps(
|
||||||
|
{
|
||||||
|
"type": "sources",
|
||||||
|
"sources": sources,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send done event
|
||||||
|
yield json.dumps(
|
||||||
|
{
|
||||||
|
"type": "done",
|
||||||
|
"session_id": session_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
yield json.dumps(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"message": str(e),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return EventSourceResponse(event_generator())
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/sessions/{session_id}/history")
|
||||||
|
async def get_session_history(session_id: str) -> dict:
|
||||||
|
"""Get conversation history for a session."""
|
||||||
|
history = memory.get_history(session_id)
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": msg.role,
|
||||||
|
"content": msg.content,
|
||||||
|
"timestamp": msg.timestamp.isoformat(),
|
||||||
|
}
|
||||||
|
for msg in history
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
uvicorn.run(app, host="127.0.0.1", port=7373)
|
||||||
1
src/companion/forge/__init__.py
Normal file
1
src/companion/forge/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Model Forge - Fine-tuning pipeline for companion AI
|
||||||
329
src/companion/forge/extract.py
Normal file
329
src/companion/forge/extract.py
Normal file
@@ -0,0 +1,329 @@
|
|||||||
|
"""Extract training examples from Obsidian vault.
|
||||||
|
|
||||||
|
Looks for reflection patterns, decision analyses, and personal insights
|
||||||
|
to create training data that teaches the model to reason like San.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
from companion.config import Config
|
||||||
|
from companion.rag.chunker import chunk_file
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainingExample:
|
||||||
|
"""A single training example in conversation format."""
|
||||||
|
|
||||||
|
messages: list[dict[str, str]]
|
||||||
|
source_file: str
|
||||||
|
tags: list[str]
|
||||||
|
date: str | None = None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"messages": self.messages,
|
||||||
|
"source_file": self.source_file,
|
||||||
|
"tags": self.tags,
|
||||||
|
"date": self.date,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Tags that indicate reflection/decision content
|
||||||
|
REFLECTION_TAGS = {
|
||||||
|
"#reflection",
|
||||||
|
"#decision",
|
||||||
|
"#learning",
|
||||||
|
"#insight",
|
||||||
|
"#analysis",
|
||||||
|
"#pondering",
|
||||||
|
"#evaluation",
|
||||||
|
"#takeaway",
|
||||||
|
"#realization",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Patterns that suggest reflection content
|
||||||
|
REFLECTION_PATTERNS = [
|
||||||
|
r"(?i)I think\s+",
|
||||||
|
r"(?i)I feel\s+",
|
||||||
|
r"(?i)I realize\s+",
|
||||||
|
r"(?i)I wonder\s+",
|
||||||
|
r"(?i)I should\s+",
|
||||||
|
r"(?i)I need to\s+",
|
||||||
|
r"(?i)The reason\s+",
|
||||||
|
r"(?i)What if\s+",
|
||||||
|
r"(?i)Maybe\s+",
|
||||||
|
r"(?i)Perhaps\s+",
|
||||||
|
r"(?i)On one hand.*?on the other hand",
|
||||||
|
r"(?i)Pros?:.*?Cons?:",
|
||||||
|
r"(?i)Weighing.*?(?:options?|choices?|alternatives?)",
|
||||||
|
r"(?i)Ultimately.*?decided",
|
||||||
|
r"(?i)Looking back",
|
||||||
|
r"(?i)In hindsight",
|
||||||
|
r"(?i)I've learned\s+",
|
||||||
|
r"(?i)The lesson\s+",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _has_reflection_tags(text: str) -> bool:
|
||||||
|
"""Check if text contains reflection-related hashtags."""
|
||||||
|
hashtags = set(re.findall(r"#\w+", text))
|
||||||
|
return bool(hashtags & REFLECTION_TAGS)
|
||||||
|
|
||||||
|
|
||||||
|
def _has_reflection_patterns(text: str) -> bool:
|
||||||
|
"""Check if text contains reflection language patterns."""
|
||||||
|
for pattern in REFLECTION_PATTERNS:
|
||||||
|
if re.search(pattern, text):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_likely_reflection(text: str) -> bool:
|
||||||
|
"""Determine if a text chunk is likely a reflection."""
|
||||||
|
# Must have reflection tags OR strong reflection patterns
|
||||||
|
has_tags = _has_reflection_tags(text)
|
||||||
|
has_patterns = _has_reflection_patterns(text)
|
||||||
|
|
||||||
|
# Require at least one indicator
|
||||||
|
return has_tags or has_patterns
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_date_from_filename(filename: str) -> str | None:
|
||||||
|
"""Extract date from filename patterns like 2026-04-12 or "12 Apr 2026"."""
|
||||||
|
# ISO format: 2026-04-12
|
||||||
|
m = re.search(r"(\d{4}-\d{2}-\d{2})", filename)
|
||||||
|
if m:
|
||||||
|
return m.group(1)
|
||||||
|
|
||||||
|
# Human format: 12-Apr-2026 or "12 Apr 2026"
|
||||||
|
m = re.search(r"(\d{1,2}[-\s][A-Za-z]{3}[-\s]\d{4})", filename)
|
||||||
|
if m:
|
||||||
|
return m.group(1).replace(" ", "-").replace("--", "-")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _create_training_prompt(chunk_text: str) -> str:
|
||||||
|
"""Create a user prompt that elicits a reflection similar to the example."""
|
||||||
|
# Extract context from the chunk to form a relevant question
|
||||||
|
|
||||||
|
# Look for decision patterns
|
||||||
|
if re.search(r"(?i)decided|decision|chose|choice", chunk_text):
|
||||||
|
return "I'm facing a decision. How should I think through this?"
|
||||||
|
|
||||||
|
# Look for relationship content
|
||||||
|
if re.search(r"(?i)friend|relationship|person|people|someone", chunk_text):
|
||||||
|
return "What do you think about this situation with someone I'm close to?"
|
||||||
|
|
||||||
|
# Look for work/career content
|
||||||
|
if re.search(r"(?i)work|job|career|project|professional", chunk_text):
|
||||||
|
return "I'm thinking about something at work. What's your perspective?"
|
||||||
|
|
||||||
|
# Look for health content
|
||||||
|
if re.search(r"(?i)health|mental|physical|stress|wellness", chunk_text):
|
||||||
|
return "I've been thinking about my well-being. What do you notice?"
|
||||||
|
|
||||||
|
# Look for financial content
|
||||||
|
if re.search(r"(?i)money|finance|financial|invest|saving|spending", chunk_text):
|
||||||
|
return "I'm considering a financial decision. How should I evaluate this?"
|
||||||
|
|
||||||
|
# Default prompt
|
||||||
|
return "I'm reflecting on something. What patterns do you see?"
|
||||||
|
|
||||||
|
|
||||||
|
def _create_training_example(
|
||||||
|
chunk_text: str, source_file: str, tags: list[str], date: str | None
|
||||||
|
) -> TrainingExample | None:
|
||||||
|
"""Convert a chunk into a training example if it meets criteria."""
|
||||||
|
|
||||||
|
# Skip if too short
|
||||||
|
if len(chunk_text) < 100:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Skip if not a reflection
|
||||||
|
if not _is_likely_reflection(chunk_text):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Clean up the text
|
||||||
|
cleaned = chunk_text.strip()
|
||||||
|
|
||||||
|
# Create prompt-response pair
|
||||||
|
prompt = _create_training_prompt(cleaned)
|
||||||
|
|
||||||
|
# The response is your reflection/insight
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "You are a thoughtful, reflective companion."},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
{"role": "assistant", "content": cleaned},
|
||||||
|
]
|
||||||
|
|
||||||
|
return TrainingExample(
|
||||||
|
messages=messages,
|
||||||
|
source_file=source_file,
|
||||||
|
tags=tags,
|
||||||
|
date=date,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingDataExtractor:
|
||||||
|
"""Extracts training examples from Obsidian vault."""
|
||||||
|
|
||||||
|
def __init__(self, config: Config):
|
||||||
|
self.config = config
|
||||||
|
self.vault_path = Path(config.vault.path)
|
||||||
|
self.examples: list[TrainingExample] = []
|
||||||
|
|
||||||
|
def extract(self) -> list[TrainingExample]:
|
||||||
|
"""Extract all training examples from the vault."""
|
||||||
|
self.examples = []
|
||||||
|
|
||||||
|
# Walk vault for markdown files
|
||||||
|
for md_file in self.vault_path.rglob("*.md"):
|
||||||
|
# Skip denied directories
|
||||||
|
relative = md_file.relative_to(self.vault_path)
|
||||||
|
if any(part.startswith(".") for part in relative.parts):
|
||||||
|
continue
|
||||||
|
if any(
|
||||||
|
part in [".obsidian", ".trash", "zzz-Archive"]
|
||||||
|
for part in relative.parts
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract from this file
|
||||||
|
file_examples = self._extract_from_file(md_file)
|
||||||
|
self.examples.extend(file_examples)
|
||||||
|
|
||||||
|
return self.examples
|
||||||
|
|
||||||
|
def _extract_from_file(self, md_file: Path) -> list[TrainingExample]:
|
||||||
|
"""Extract training examples from a single file."""
|
||||||
|
examples = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
text = md_file.read_text(encoding="utf-8")
|
||||||
|
except Exception:
|
||||||
|
return examples
|
||||||
|
|
||||||
|
# Get date from filename
|
||||||
|
date = _extract_date_from_filename(md_file.name)
|
||||||
|
|
||||||
|
# Split into sections (by headers like #Section or # Tag:)
|
||||||
|
# Pattern matches lines starting with # that have content
|
||||||
|
sections = re.split(r"\n(?=#[^#])", text)
|
||||||
|
|
||||||
|
for section in sections:
|
||||||
|
section = section.strip()
|
||||||
|
if not section or len(section) < 50: # Skip very short sections
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract tags from section
|
||||||
|
hashtags = re.findall(r"#[\w\-]+", section)
|
||||||
|
|
||||||
|
# Try to create training example
|
||||||
|
example = _create_training_example(
|
||||||
|
chunk_text=section,
|
||||||
|
source_file=str(md_file.relative_to(self.vault_path)).replace(
|
||||||
|
"\\", "/"
|
||||||
|
),
|
||||||
|
tags=hashtags,
|
||||||
|
date=date,
|
||||||
|
)
|
||||||
|
|
||||||
|
if example:
|
||||||
|
examples.append(example)
|
||||||
|
|
||||||
|
# If no reflection sections found, try the whole file
|
||||||
|
if not examples and len(text) >= 100:
|
||||||
|
hashtags = re.findall(r"#[\w\-]+", text)
|
||||||
|
if _is_likely_reflection(text):
|
||||||
|
example = _create_training_example(
|
||||||
|
chunk_text=text.strip(),
|
||||||
|
source_file=str(md_file.relative_to(self.vault_path)).replace(
|
||||||
|
"\\", "/"
|
||||||
|
),
|
||||||
|
tags=hashtags,
|
||||||
|
date=date,
|
||||||
|
)
|
||||||
|
if example:
|
||||||
|
examples.append(example)
|
||||||
|
|
||||||
|
return examples
|
||||||
|
|
||||||
|
# Get date from filename
|
||||||
|
date = _extract_date_from_filename(md_file.name)
|
||||||
|
|
||||||
|
# Split into sections (by headers)
|
||||||
|
sections = re.split(r"\n(?=#+\s)", text)
|
||||||
|
|
||||||
|
for section in sections:
|
||||||
|
section = section.strip()
|
||||||
|
if not section:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract tags from section
|
||||||
|
hashtags = re.findall(r"#\w+", section)
|
||||||
|
|
||||||
|
# Try to create training example
|
||||||
|
example = _create_training_example(
|
||||||
|
chunk_text=section,
|
||||||
|
source_file=str(md_file.relative_to(self.vault_path)).replace(
|
||||||
|
"\\", "/"
|
||||||
|
),
|
||||||
|
tags=hashtags,
|
||||||
|
date=date,
|
||||||
|
)
|
||||||
|
|
||||||
|
if example:
|
||||||
|
examples.append(example)
|
||||||
|
|
||||||
|
return examples
|
||||||
|
|
||||||
|
def save_to_jsonl(self, output_path: Path) -> int:
|
||||||
|
"""Save extracted examples to JSONL file."""
|
||||||
|
with open(output_path, "w", encoding="utf-8") as f:
|
||||||
|
for example in self.examples:
|
||||||
|
f.write(json.dumps(example.to_dict(), ensure_ascii=False) + "\n")
|
||||||
|
return len(self.examples)
|
||||||
|
|
||||||
|
def get_stats(self) -> dict:
|
||||||
|
"""Get statistics about extracted examples."""
|
||||||
|
if not self.examples:
|
||||||
|
return {"total": 0}
|
||||||
|
|
||||||
|
tag_counts: dict[str, int] = {}
|
||||||
|
for ex in self.examples:
|
||||||
|
for tag in ex.tags:
|
||||||
|
tag_counts[tag] = tag_counts.get(tag, 0) + 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": len(self.examples),
|
||||||
|
"avg_length": sum(len(ex.messages[2]["content"]) for ex in self.examples)
|
||||||
|
// len(self.examples),
|
||||||
|
"top_tags": sorted(tag_counts.items(), key=lambda x: x[1], reverse=True)[
|
||||||
|
:10
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def extract_training_data(
|
||||||
|
config_path: str = "config.json",
|
||||||
|
) -> tuple[list[TrainingExample], dict]:
|
||||||
|
"""Convenience function to extract training data from vault.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (examples list, stats dict)
|
||||||
|
"""
|
||||||
|
from companion.config import load_config
|
||||||
|
|
||||||
|
config = load_config(config_path)
|
||||||
|
extractor = TrainingDataExtractor(config)
|
||||||
|
examples = extractor.extract()
|
||||||
|
stats = extractor.get_stats()
|
||||||
|
|
||||||
|
return examples, stats
|
||||||
178
src/companion/memory.py
Normal file
178
src/companion/memory.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
"""SQLite-based session memory for the companion."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sqlite3
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Message:
|
||||||
|
"""A single message in the conversation."""
|
||||||
|
|
||||||
|
role: str # "user" | "assistant" | "system"
|
||||||
|
content: str
|
||||||
|
timestamp: datetime
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SessionMemory:
|
||||||
|
"""Manages conversation history in SQLite."""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str | Path):
|
||||||
|
self.db_path = Path(db_path)
|
||||||
|
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._init_db()
|
||||||
|
|
||||||
|
def _init_db(self) -> None:
|
||||||
|
"""Initialize the database schema."""
|
||||||
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
|
session_id TEXT PRIMARY KEY,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS messages (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
session_id TEXT NOT NULL,
|
||||||
|
role TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
metadata TEXT,
|
||||||
|
FOREIGN KEY (session_id) REFERENCES sessions(session_id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
conn.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_messages_session
|
||||||
|
ON messages(session_id, timestamp)
|
||||||
|
""")
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def get_or_create_session(self, session_id: str) -> str:
|
||||||
|
"""Get existing session or create new one."""
|
||||||
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR IGNORE INTO sessions (session_id) VALUES (?)", (session_id,)
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""UPDATE sessions SET updated_at = CURRENT_TIMESTAMP
|
||||||
|
WHERE session_id = ?""",
|
||||||
|
(session_id,),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
def add_message(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
role: str,
|
||||||
|
content: str,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Add a message to the session."""
|
||||||
|
self.get_or_create_session(session_id)
|
||||||
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""INSERT INTO messages (session_id, role, content, metadata)
|
||||||
|
VALUES (?, ?, ?, ?)""",
|
||||||
|
(session_id, role, content, json.dumps(metadata) if metadata else None),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def get_messages(
|
||||||
|
self, session_id: str, limit: int = 20, before_id: int | None = None
|
||||||
|
) -> list[Message]:
|
||||||
|
"""Get messages from a session, most recent first."""
|
||||||
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
if before_id:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""SELECT role, content, timestamp, metadata
|
||||||
|
FROM messages
|
||||||
|
WHERE session_id = ? AND id < ?
|
||||||
|
ORDER BY timestamp DESC
|
||||||
|
LIMIT ?""",
|
||||||
|
(session_id, before_id, limit),
|
||||||
|
).fetchall()
|
||||||
|
else:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""SELECT role, content, timestamp, metadata
|
||||||
|
FROM messages
|
||||||
|
WHERE session_id = ?
|
||||||
|
ORDER BY timestamp DESC
|
||||||
|
LIMIT ?""",
|
||||||
|
(session_id, limit),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
for row in rows:
|
||||||
|
meta = json.loads(row["metadata"]) if row["metadata"] else None
|
||||||
|
messages.append(
|
||||||
|
Message(
|
||||||
|
role=row["role"],
|
||||||
|
content=row["content"],
|
||||||
|
timestamp=datetime.fromisoformat(row["timestamp"]),
|
||||||
|
metadata=meta,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Return in chronological order
|
||||||
|
return list(reversed(messages))
|
||||||
|
|
||||||
|
def get_session_summary(self, session_id: str) -> dict[str, Any] | None:
|
||||||
|
"""Get summary info about a session."""
|
||||||
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
row = conn.execute(
|
||||||
|
"""SELECT session_id, created_at, updated_at,
|
||||||
|
(SELECT COUNT(*) FROM messages WHERE session_id = ?) as message_count
|
||||||
|
FROM sessions WHERE session_id = ?""",
|
||||||
|
(session_id, session_id),
|
||||||
|
).fetchone()
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"session_id": row["session_id"],
|
||||||
|
"created_at": row["created_at"],
|
||||||
|
"updated_at": row["updated_at"],
|
||||||
|
"message_count": row["message_count"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def list_sessions(self, limit: int = 100) -> list[dict[str, Any]]:
|
||||||
|
"""List recent sessions."""
|
||||||
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
rows = conn.execute(
|
||||||
|
"""SELECT session_id, created_at, updated_at
|
||||||
|
FROM sessions
|
||||||
|
ORDER BY updated_at DESC
|
||||||
|
LIMIT ?""",
|
||||||
|
(limit,),
|
||||||
|
).fetchall()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"session_id": r["session_id"],
|
||||||
|
"created_at": r["created_at"],
|
||||||
|
"updated_at": r["updated_at"],
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
def clear_session(self, session_id: str) -> None:
|
||||||
|
"""Clear all messages from a session."""
|
||||||
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
|
conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def delete_session(self, session_id: str) -> None:
|
||||||
|
"""Delete a session and all its messages."""
|
||||||
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
|
conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||||
|
conn.execute("DELETE FROM sessions WHERE session_id = ?", (session_id,))
|
||||||
|
conn.commit()
|
||||||
155
src/companion/orchestrator.py
Normal file
155
src/companion/orchestrator.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
"""Chat orchestrator - coordinates RAG, LLM, and memory for chat responses."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from companion.config import Config
|
||||||
|
from companion.memory import Message, SessionMemory
|
||||||
|
from companion.prompts import build_system_prompt, format_conversation_history
|
||||||
|
from companion.rag.search import SearchEngine
|
||||||
|
|
||||||
|
|
||||||
|
class ChatOrchestrator:
|
||||||
|
"""Orchestrates chat by combining RAG, memory, and LLM streaming."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Config,
|
||||||
|
search_engine: SearchEngine,
|
||||||
|
session_memory: SessionMemory,
|
||||||
|
):
|
||||||
|
self.config = config
|
||||||
|
self.search = search_engine
|
||||||
|
self.memory = session_memory
|
||||||
|
self.model_endpoint = config.model.inference.backend
|
||||||
|
self.model_path = config.model.inference.model_path
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
user_message: str,
|
||||||
|
use_rag: bool = True,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Process a chat message and yield streaming responses.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: Unique session identifier
|
||||||
|
user_message: The user's input message
|
||||||
|
use_rag: Whether to retrieve context from vault
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Streaming response chunks (SSE format)
|
||||||
|
"""
|
||||||
|
# Retrieve relevant context from RAG if enabled
|
||||||
|
retrieved_context: list[dict[str, Any]] = []
|
||||||
|
if use_rag:
|
||||||
|
try:
|
||||||
|
retrieved_context = self.search.search(
|
||||||
|
user_message, top_k=self.config.rag.search.default_top_k
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Log error but continue without RAG context
|
||||||
|
print(f"RAG retrieval failed: {e}")
|
||||||
|
|
||||||
|
# Get conversation history
|
||||||
|
history = self.memory.get_messages(
|
||||||
|
session_id, limit=self.config.companion.memory.session_turns
|
||||||
|
)
|
||||||
|
|
||||||
|
# Format history for prompt
|
||||||
|
history_formatted = format_conversation_history(
|
||||||
|
[{"role": msg.role, "content": msg.content} for msg in history]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build system prompt
|
||||||
|
system_prompt = build_system_prompt(
|
||||||
|
persona=self.config.companion.persona.model_dump(),
|
||||||
|
retrieved_context=retrieved_context if retrieved_context else None,
|
||||||
|
memory_context=history_formatted if history_formatted else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add user message to memory
|
||||||
|
self.memory.add_message(session_id, "user", user_message)
|
||||||
|
|
||||||
|
# Build messages for LLM
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": user_message},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Stream from LLM (using llama.cpp server or similar)
|
||||||
|
full_response = ""
|
||||||
|
async for chunk in self._stream_llm_response(messages):
|
||||||
|
yield chunk
|
||||||
|
if chunk.startswith("data: "):
|
||||||
|
data = chunk[6:] # Remove "data: " prefix
|
||||||
|
if data != "[DONE]":
|
||||||
|
try:
|
||||||
|
delta = json.loads(data)
|
||||||
|
if "content" in delta:
|
||||||
|
full_response += delta["content"]
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Store assistant response in memory
|
||||||
|
if full_response:
|
||||||
|
self.memory.add_message(
|
||||||
|
session_id,
|
||||||
|
"assistant",
|
||||||
|
full_response,
|
||||||
|
metadata={"rag_context_used": len(retrieved_context) > 0},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _stream_llm_response(
|
||||||
|
self, messages: list[dict[str, Any]]
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Stream response from local LLM.
|
||||||
|
|
||||||
|
Uses llama.cpp HTTP server format or Ollama API.
|
||||||
|
"""
|
||||||
|
# Try llama.cpp server first
|
||||||
|
base_url = self.config.model.inference.backend
|
||||||
|
if base_url == "llama.cpp":
|
||||||
|
base_url = "http://localhost:8080" # Default llama.cpp server port
|
||||||
|
|
||||||
|
# Default to Ollama API
|
||||||
|
if base_url not in ["llama.cpp", "http://localhost:8080"]:
|
||||||
|
base_url = self.config.rag.embedding.base_url.replace("/api", "")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||||
|
# Try Ollama chat endpoint
|
||||||
|
response = await client.post(
|
||||||
|
f"{base_url}/api/chat",
|
||||||
|
json={
|
||||||
|
"model": self.config.rag.embedding.model.replace("-embed", ""),
|
||||||
|
"messages": messages,
|
||||||
|
"stream": True,
|
||||||
|
"options": {
|
||||||
|
"temperature": self.config.companion.chat.default_temperature,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
timeout=120.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.strip():
|
||||||
|
yield f"data: {line}\n\n"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
yield f'data: {{"error": "LLM streaming failed: {str(e)}"}}\n\n'
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
def get_session_history(self, session_id: str, limit: int = 50) -> list[Message]:
|
||||||
|
"""Get conversation history for a session."""
|
||||||
|
return self.memory.get_messages(session_id, limit=limit)
|
||||||
|
|
||||||
|
def clear_session(self, session_id: str) -> None:
|
||||||
|
"""Clear a session's conversation history."""
|
||||||
|
self.memory.clear_session(session_id)
|
||||||
100
src/companion/prompts.py
Normal file
100
src/companion/prompts.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
"""System prompts for the companion."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def build_system_prompt(
|
||||||
|
persona: dict[str, Any],
|
||||||
|
retrieved_context: list[dict[str, Any]] | None = None,
|
||||||
|
memory_context: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Build the system prompt for the companion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
persona: Companion persona configuration (name, role, tone, style, boundaries)
|
||||||
|
retrieved_context: Optional RAG context from vault search
|
||||||
|
memory_context: Optional memory summary from session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The complete system prompt
|
||||||
|
"""
|
||||||
|
name = persona.get("name", "Companion")
|
||||||
|
role = persona.get("role", "companion")
|
||||||
|
tone = persona.get("tone", "reflective")
|
||||||
|
style = persona.get("style", "questioning")
|
||||||
|
boundaries = persona.get("boundaries", [])
|
||||||
|
|
||||||
|
# Base persona description
|
||||||
|
base_prompt = f"""You are {name}, a {tone}, {style} {role}.
|
||||||
|
|
||||||
|
Your role is to be a thoughtful, reflective companion—not to impersonate or speak for the user, but to explore alongside them. You know their life through their journal entries and are here to help them reflect, remember, and explore patterns.
|
||||||
|
|
||||||
|
Core principles:
|
||||||
|
- You do not speak as the user. You speak to them.
|
||||||
|
- You listen deeply and reflect patterns back gently.
|
||||||
|
- You ask questions that help them explore their own thoughts.
|
||||||
|
- You respect the boundaries of a confidante, not an oracle.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Add boundaries section
|
||||||
|
if boundaries:
|
||||||
|
base_prompt += "\nBoundaries:\n"
|
||||||
|
for boundary in boundaries:
|
||||||
|
base_prompt += f"- {boundary.replace('_', ' ').title()}\n"
|
||||||
|
|
||||||
|
# Add retrieved context section if available
|
||||||
|
context_section = ""
|
||||||
|
if retrieved_context:
|
||||||
|
context_section += "\n\nRelevant context from your vault:\n"
|
||||||
|
for i, ctx in enumerate(retrieved_context[:8], 1): # Limit to top 8 chunks
|
||||||
|
source = ctx.get("source_file", "unknown")
|
||||||
|
text = ctx.get("text", "").strip()
|
||||||
|
if text:
|
||||||
|
context_section += f"[{i}] From {source}:\n{text}\n\n"
|
||||||
|
|
||||||
|
# Add memory context if available
|
||||||
|
memory_section = ""
|
||||||
|
if memory_context:
|
||||||
|
memory_section = f"\n\nContext from your conversation:\n{memory_context}"
|
||||||
|
|
||||||
|
# Final instructions
|
||||||
|
closing = """
|
||||||
|
|
||||||
|
When responding:
|
||||||
|
- Draw from the vault context if relevant, but don't force it
|
||||||
|
- Be concise but thoughtful—no unnecessary length
|
||||||
|
- If uncertain, acknowledge the uncertainty
|
||||||
|
- Ask follow-up questions that deepen reflection
|
||||||
|
"""
|
||||||
|
|
||||||
|
return base_prompt + context_section + memory_section + closing
|
||||||
|
|
||||||
|
|
||||||
|
def format_conversation_history(
|
||||||
|
messages: list[dict[str, Any]], max_turns: int = 10
|
||||||
|
) -> str:
|
||||||
|
"""Format conversation history for prompt context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dicts with 'role' and 'content'
|
||||||
|
max_turns: Maximum number of recent turns to include
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted conversation history
|
||||||
|
"""
|
||||||
|
if not messages:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Take only recent messages
|
||||||
|
recent = messages[-max_turns * 2 :] if len(messages) > max_turns * 2 else messages
|
||||||
|
|
||||||
|
formatted = []
|
||||||
|
for msg in recent:
|
||||||
|
role = msg.get("role", "user")
|
||||||
|
content = msg.get("content", "")
|
||||||
|
if content.strip():
|
||||||
|
formatted.append(f"{role.upper()}: {content}")
|
||||||
|
|
||||||
|
return "\n\n".join(formatted)
|
||||||
31
tests/test_api.py
Normal file
31
tests/test_api.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""Simple smoke tests for FastAPI backend."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_imports():
|
||||||
|
"""Test that API module imports correctly."""
|
||||||
|
# This will fail if there are any import errors
|
||||||
|
from companion.api import app, ChatRequest
|
||||||
|
|
||||||
|
assert app is not None
|
||||||
|
assert ChatRequest is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_request_model():
|
||||||
|
"""Test ChatRequest model validation."""
|
||||||
|
from companion.api import ChatRequest
|
||||||
|
|
||||||
|
# Valid request
|
||||||
|
req = ChatRequest(message="hello", session_id="abc123")
|
||||||
|
assert req.message == "hello"
|
||||||
|
assert req.session_id == "abc123"
|
||||||
|
|
||||||
|
# Valid request with temperature
|
||||||
|
req2 = ChatRequest(message="hello", temperature=0.7)
|
||||||
|
assert req2.temperature == 0.7
|
||||||
|
|
||||||
|
# Valid request with minimal fields
|
||||||
|
req3 = ChatRequest(message="hello")
|
||||||
|
assert req3.session_id is None
|
||||||
|
assert req3.temperature is None
|
||||||
604
tests/test_forge_extract.py
Normal file
604
tests/test_forge_extract.py
Normal file
@@ -0,0 +1,604 @@
|
|||||||
|
"""Tests for training data extractor."""
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from companion.config import Config, VaultConfig, IndexingConfig
|
||||||
|
from companion.forge.extract import (
|
||||||
|
TrainingDataExtractor,
|
||||||
|
TrainingExample,
|
||||||
|
_create_training_example,
|
||||||
|
_extract_date_from_filename,
|
||||||
|
_has_reflection_patterns,
|
||||||
|
_has_reflection_tags,
|
||||||
|
_is_likely_reflection,
|
||||||
|
extract_training_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_reflection_tags():
|
||||||
|
assert _has_reflection_tags("#reflection on today's events")
|
||||||
|
assert _has_reflection_tags("#decision made today")
|
||||||
|
assert not _has_reflection_tags("#worklog entry")
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_reflection_patterns():
|
||||||
|
assert _has_reflection_patterns("I think this is important")
|
||||||
|
assert _has_reflection_patterns("I wonder if I should change")
|
||||||
|
assert _has_reflection_patterns("Looking back, I see the pattern")
|
||||||
|
assert not _has_reflection_patterns("The meeting was at 3pm")
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_likely_reflection():
|
||||||
|
assert _is_likely_reflection("#reflection I think this matters")
|
||||||
|
assert _is_likely_reflection("I realize now that I was wrong")
|
||||||
|
assert not _is_likely_reflection("Just a regular note")
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_date_from_filename():
|
||||||
|
assert _extract_date_from_filename("2026-04-12.md") == "2026-04-12"
|
||||||
|
assert _extract_date_from_filename("12-Apr-2026.md") == "12-Apr-2026"
|
||||||
|
assert _extract_date_from_filename("2026-04-12-journal.md") == "2026-04-12"
|
||||||
|
assert _extract_date_from_filename("notes.md") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_training_example():
|
||||||
|
text = "#reflection I think I need to reconsider my approach. The way I've been handling this isn't working."
|
||||||
|
example = _create_training_example(
|
||||||
|
chunk_text=text,
|
||||||
|
source_file="journal/2026-04-12.md",
|
||||||
|
tags=["#reflection"],
|
||||||
|
date="2026-04-12",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert example is not None
|
||||||
|
assert len(example.messages) == 3
|
||||||
|
assert example.messages[0]["role"] == "system"
|
||||||
|
assert example.messages[1]["role"] == "user"
|
||||||
|
assert example.messages[2]["role"] == "assistant"
|
||||||
|
assert example.messages[2]["content"] == text
|
||||||
|
assert example.source_file == "journal/2026-04-12.md"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_training_example_too_short():
|
||||||
|
text = "I think." # Too short
|
||||||
|
example = _create_training_example(
|
||||||
|
chunk_text=text,
|
||||||
|
source_file="test.md",
|
||||||
|
tags=["#reflection"],
|
||||||
|
date=None,
|
||||||
|
)
|
||||||
|
assert example is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_training_example_no_reflection():
|
||||||
|
text = "This is just a regular note about the meeting at 3pm. Nothing special." * 5
|
||||||
|
example = _create_training_example(
|
||||||
|
chunk_text=text,
|
||||||
|
source_file="test.md",
|
||||||
|
tags=["#work"],
|
||||||
|
date=None,
|
||||||
|
)
|
||||||
|
assert example is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_training_example_to_dict():
|
||||||
|
example = TrainingExample(
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{"role": "assistant", "content": "Hi"},
|
||||||
|
],
|
||||||
|
source_file="test.md",
|
||||||
|
tags=["#test"],
|
||||||
|
date="2026-04-12",
|
||||||
|
)
|
||||||
|
d = example.to_dict()
|
||||||
|
assert d["messages"][0]["role"] == "user"
|
||||||
|
assert d["source_file"] == "test.md"
|
||||||
|
assert d["date"] == "2026-04-12"
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrainingDataExtractor:
|
||||||
|
def _get_config_dict(self, vault_path: Path) -> dict:
|
||||||
|
"""Return minimal config dict for testing."""
|
||||||
|
return {
|
||||||
|
"companion": {
|
||||||
|
"name": "SAN",
|
||||||
|
"persona": {
|
||||||
|
"role": "companion",
|
||||||
|
"tone": "reflective",
|
||||||
|
"style": "questioning",
|
||||||
|
"boundaries": [],
|
||||||
|
},
|
||||||
|
"memory": {
|
||||||
|
"session_turns": 20,
|
||||||
|
"persistent_store": "",
|
||||||
|
"summarize_after": 10,
|
||||||
|
},
|
||||||
|
"chat": {
|
||||||
|
"streaming": True,
|
||||||
|
"max_response_tokens": 2048,
|
||||||
|
"default_temperature": 0.7,
|
||||||
|
"allow_temperature_override": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"vault": {
|
||||||
|
"path": str(vault_path),
|
||||||
|
"indexing": {
|
||||||
|
"auto_sync": False,
|
||||||
|
"auto_sync_interval_minutes": 1440,
|
||||||
|
"watch_fs_events": False,
|
||||||
|
"file_patterns": ["*.md"],
|
||||||
|
"deny_dirs": [".git"],
|
||||||
|
"deny_patterns": [],
|
||||||
|
},
|
||||||
|
"chunking_rules": {},
|
||||||
|
},
|
||||||
|
"rag": {
|
||||||
|
"embedding": {
|
||||||
|
"provider": "ollama",
|
||||||
|
"model": "mxbai-embed-large",
|
||||||
|
"base_url": "http://localhost:11434",
|
||||||
|
"dimensions": 1024,
|
||||||
|
"batch_size": 32,
|
||||||
|
},
|
||||||
|
"vector_store": {"type": "lancedb", "path": ".test.vectors"},
|
||||||
|
"search": {
|
||||||
|
"default_top_k": 8,
|
||||||
|
"max_top_k": 20,
|
||||||
|
"similarity_threshold": 0.75,
|
||||||
|
"hybrid_search": {
|
||||||
|
"enabled": False,
|
||||||
|
"keyword_weight": 0.3,
|
||||||
|
"semantic_weight": 0.7,
|
||||||
|
},
|
||||||
|
"filters": {
|
||||||
|
"date_range_enabled": True,
|
||||||
|
"tag_filter_enabled": True,
|
||||||
|
"directory_filter_enabled": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"model": {
|
||||||
|
"inference": {
|
||||||
|
"backend": "llama.cpp",
|
||||||
|
"model_path": "",
|
||||||
|
"context_length": 8192,
|
||||||
|
"gpu_layers": 35,
|
||||||
|
"batch_size": 512,
|
||||||
|
"threads": 8,
|
||||||
|
},
|
||||||
|
"fine_tuning": {
|
||||||
|
"base_model": "",
|
||||||
|
"output_dir": "",
|
||||||
|
"lora_rank": 16,
|
||||||
|
"lora_alpha": 32,
|
||||||
|
"learning_rate": 0.0002,
|
||||||
|
"batch_size": 4,
|
||||||
|
"gradient_accumulation_steps": 4,
|
||||||
|
"num_epochs": 3,
|
||||||
|
"warmup_steps": 100,
|
||||||
|
"save_steps": 500,
|
||||||
|
"eval_steps": 250,
|
||||||
|
"training_data_path": "",
|
||||||
|
"validation_split": 0.1,
|
||||||
|
},
|
||||||
|
"retrain_schedule": {
|
||||||
|
"auto_reminder": True,
|
||||||
|
"default_interval_days": 90,
|
||||||
|
"reminder_channels": [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"api": {
|
||||||
|
"host": "127.0.0.1",
|
||||||
|
"port": 7373,
|
||||||
|
"cors_origins": [],
|
||||||
|
"auth": {"enabled": False},
|
||||||
|
},
|
||||||
|
"ui": {
|
||||||
|
"web": {
|
||||||
|
"enabled": True,
|
||||||
|
"theme": "obsidian",
|
||||||
|
"features": {
|
||||||
|
"streaming": True,
|
||||||
|
"citations": True,
|
||||||
|
"source_preview": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"cli": {"enabled": True, "rich_output": True},
|
||||||
|
},
|
||||||
|
"logging": {
|
||||||
|
"level": "INFO",
|
||||||
|
"file": "",
|
||||||
|
"max_size_mb": 100,
|
||||||
|
"backup_count": 5,
|
||||||
|
},
|
||||||
|
"security": {
|
||||||
|
"local_only": True,
|
||||||
|
"vault_path_traversal_check": True,
|
||||||
|
"sensitive_content_detection": True,
|
||||||
|
"sensitive_patterns": [],
|
||||||
|
"require_confirmation_for_external_apis": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_extract_from_single_file(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
vault = Path(tmp)
|
||||||
|
journal = vault / "Journal" / "2026" / "04"
|
||||||
|
journal.mkdir(parents=True)
|
||||||
|
|
||||||
|
content = """#DayInShort: Busy day
|
||||||
|
|
||||||
|
#reflection I think I need to slow down. The pace has been unsustainable.
|
||||||
|
|
||||||
|
#work Normal work day with meetings.
|
||||||
|
|
||||||
|
#insight I realize that I've been prioritizing urgency over importance.
|
||||||
|
"""
|
||||||
|
(journal / "2026-04-12.md").write_text(content, encoding="utf-8")
|
||||||
|
|
||||||
|
# Use helper method for config
|
||||||
|
from companion.config import load_config
|
||||||
|
import json
|
||||||
|
|
||||||
|
config_dict = self._get_config_dict(vault)
|
||||||
|
config_path = Path(tmp) / "test_config.json"
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(config_dict, f)
|
||||||
|
|
||||||
|
config = load_config(config_path)
|
||||||
|
extractor = TrainingDataExtractor(config)
|
||||||
|
examples = extractor.extract()
|
||||||
|
|
||||||
|
# Should extract at least 2 reflection examples
|
||||||
|
assert len(examples) >= 2
|
||||||
|
|
||||||
|
# Check they have the right structure
|
||||||
|
for ex in examples:
|
||||||
|
assert len(ex.messages) == 3
|
||||||
|
assert ex.messages[2]["role"] == "assistant"
|
||||||
|
|
||||||
|
def _save_to_jsonl_helper(self):
|
||||||
|
"""Helper extracted to reduce nesting."""
|
||||||
|
pass # placeholder
|
||||||
|
"companion": {
|
||||||
|
"name": "SAN",
|
||||||
|
"persona": {
|
||||||
|
"role": "companion",
|
||||||
|
"tone": "reflective",
|
||||||
|
"style": "questioning",
|
||||||
|
"boundaries": [],
|
||||||
|
},
|
||||||
|
"memory": {
|
||||||
|
"session_turns": 20,
|
||||||
|
"persistent_store": "",
|
||||||
|
"summarize_after": 10,
|
||||||
|
},
|
||||||
|
"chat": {
|
||||||
|
"streaming": True,
|
||||||
|
"max_response_tokens": 2048,
|
||||||
|
"default_temperature": 0.7,
|
||||||
|
"allow_temperature_override": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"vault": {
|
||||||
|
"path": str(vault),
|
||||||
|
"indexing": {
|
||||||
|
"auto_sync": False,
|
||||||
|
"auto_sync_interval_minutes": 1440,
|
||||||
|
"watch_fs_events": False,
|
||||||
|
"file_patterns": ["*.md"],
|
||||||
|
"deny_dirs": [".git"],
|
||||||
|
"deny_patterns": [],
|
||||||
|
},
|
||||||
|
"chunking_rules": {},
|
||||||
|
},
|
||||||
|
"rag": {
|
||||||
|
"embedding": {
|
||||||
|
"provider": "ollama",
|
||||||
|
"model": "mxbai-embed-large",
|
||||||
|
"base_url": "http://localhost:11434",
|
||||||
|
"dimensions": 1024,
|
||||||
|
"batch_size": 32,
|
||||||
|
},
|
||||||
|
"vector_store": {"type": "lancedb", "path": ".test.vectors"},
|
||||||
|
"search": {
|
||||||
|
"default_top_k": 8,
|
||||||
|
"max_top_k": 20,
|
||||||
|
"similarity_threshold": 0.75,
|
||||||
|
"hybrid_search": {
|
||||||
|
"enabled": False,
|
||||||
|
"keyword_weight": 0.3,
|
||||||
|
"semantic_weight": 0.7,
|
||||||
|
},
|
||||||
|
"filters": {
|
||||||
|
"date_range_enabled": True,
|
||||||
|
"tag_filter_enabled": True,
|
||||||
|
"directory_filter_enabled": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"model": {
|
||||||
|
"inference": {
|
||||||
|
"backend": "llama.cpp",
|
||||||
|
"model_path": "",
|
||||||
|
"context_length": 8192,
|
||||||
|
"gpu_layers": 35,
|
||||||
|
"batch_size": 512,
|
||||||
|
"threads": 8,
|
||||||
|
},
|
||||||
|
"fine_tuning": {
|
||||||
|
"base_model": "",
|
||||||
|
"output_dir": "",
|
||||||
|
"lora_rank": 16,
|
||||||
|
"lora_alpha": 32,
|
||||||
|
"learning_rate": 0.0002,
|
||||||
|
"batch_size": 4,
|
||||||
|
"gradient_accumulation_steps": 4,
|
||||||
|
"num_epochs": 3,
|
||||||
|
"warmup_steps": 100,
|
||||||
|
"save_steps": 500,
|
||||||
|
"eval_steps": 250,
|
||||||
|
"training_data_path": "",
|
||||||
|
"validation_split": 0.1,
|
||||||
|
},
|
||||||
|
"retrain_schedule": {
|
||||||
|
"auto_reminder": True,
|
||||||
|
"default_interval_days": 90,
|
||||||
|
"reminder_channels": [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"api": {
|
||||||
|
"host": "127.0.0.1",
|
||||||
|
"port": 7373,
|
||||||
|
"cors_origins": [],
|
||||||
|
"auth": {"enabled": False},
|
||||||
|
},
|
||||||
|
"ui": {
|
||||||
|
"web": {
|
||||||
|
"enabled": True,
|
||||||
|
"theme": "obsidian",
|
||||||
|
"features": {
|
||||||
|
"streaming": True,
|
||||||
|
"citations": True,
|
||||||
|
"source_preview": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"cli": {"enabled": True, "rich_output": True},
|
||||||
|
},
|
||||||
|
"logging": {
|
||||||
|
"level": "INFO",
|
||||||
|
"file": "",
|
||||||
|
"max_size_mb": 100,
|
||||||
|
"backup_count": 5,
|
||||||
|
},
|
||||||
|
"security": {
|
||||||
|
"local_only": True,
|
||||||
|
"vault_path_traversal_check": True,
|
||||||
|
"sensitive_content_detection": True,
|
||||||
|
"sensitive_patterns": [],
|
||||||
|
"require_confirmation_for_external_apis": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config_path = Path(tmp) / "test_config.json"
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(config_dict, f)
|
||||||
|
|
||||||
|
config = load_config(config_path)
|
||||||
|
extractor = TrainingDataExtractor(config)
|
||||||
|
examples = extractor.extract()
|
||||||
|
|
||||||
|
# Should extract at least 2 reflection examples
|
||||||
|
assert len(examples) >= 2
|
||||||
|
|
||||||
|
# Check they have the right structure
|
||||||
|
for ex in examples:
|
||||||
|
assert len(ex.messages) == 3
|
||||||
|
assert ex.messages[2]["role"] == "assistant"
|
||||||
|
|
||||||
|
def test_save_to_jsonl(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
output = Path(tmp) / "training.jsonl"
|
||||||
|
|
||||||
|
examples = [
|
||||||
|
TrainingExample(
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "sys"},
|
||||||
|
{"role": "user", "content": "user"},
|
||||||
|
{"role": "assistant", "content": "assistant"},
|
||||||
|
],
|
||||||
|
source_file="test.md",
|
||||||
|
tags=["#test"],
|
||||||
|
date="2026-04-12",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create minimal config for extractor
|
||||||
|
config_dict = self._get_config_dict(Path(tmp))
|
||||||
|
config_path = Path(tmp) / "test_config.json"
|
||||||
|
import json
|
||||||
|
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(config_dict, f)
|
||||||
|
|
||||||
|
from companion.config import load_config
|
||||||
|
|
||||||
|
config = load_config(config_path)
|
||||||
|
extractor = TrainingDataExtractor(config)
|
||||||
|
extractor.examples = examples
|
||||||
|
|
||||||
|
count = extractor.save_to_jsonl(output)
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
# Verify file content
|
||||||
|
lines = output.read_text(encoding="utf-8").strip().split("\n")
|
||||||
|
assert len(lines) == 1
|
||||||
|
assert "assistant" in lines[0]
|
||||||
|
|
||||||
|
def test_get_stats(self):
|
||||||
|
examples = [
|
||||||
|
TrainingExample(
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "sys"},
|
||||||
|
{"role": "user", "content": "user"},
|
||||||
|
{"role": "assistant", "content": "a" * 100},
|
||||||
|
],
|
||||||
|
source_file="test1.md",
|
||||||
|
tags=["#reflection", "#learning"],
|
||||||
|
date="2026-04-12",
|
||||||
|
),
|
||||||
|
TrainingExample(
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "sys"},
|
||||||
|
{"role": "user", "content": "user"},
|
||||||
|
{"role": "assistant", "content": "b" * 200},
|
||||||
|
],
|
||||||
|
source_file="test2.md",
|
||||||
|
tags=["#reflection", "#decision"],
|
||||||
|
date="2026-04-13",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create minimal config
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
config_dict = {
|
||||||
|
"companion": {
|
||||||
|
"name": "SAN",
|
||||||
|
"persona": {
|
||||||
|
"role": "companion",
|
||||||
|
"tone": "reflective",
|
||||||
|
"style": "questioning",
|
||||||
|
"boundaries": [],
|
||||||
|
},
|
||||||
|
"memory": {
|
||||||
|
"session_turns": 20,
|
||||||
|
"persistent_store": "",
|
||||||
|
"summarize_after": 10,
|
||||||
|
},
|
||||||
|
"chat": {
|
||||||
|
"streaming": True,
|
||||||
|
"max_response_tokens": 2048,
|
||||||
|
"default_temperature": 0.7,
|
||||||
|
"allow_temperature_override": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"vault": {
|
||||||
|
"path": str(tmp),
|
||||||
|
"indexing": {
|
||||||
|
"auto_sync": False,
|
||||||
|
"auto_sync_interval_minutes": 1440,
|
||||||
|
"watch_fs_events": False,
|
||||||
|
"file_patterns": ["*.md"],
|
||||||
|
"deny_dirs": [".git"],
|
||||||
|
"deny_patterns": [],
|
||||||
|
},
|
||||||
|
"chunking_rules": {},
|
||||||
|
},
|
||||||
|
"rag": {
|
||||||
|
"embedding": {
|
||||||
|
"provider": "ollama",
|
||||||
|
"model": "mxbai-embed-large",
|
||||||
|
"base_url": "http://localhost:11434",
|
||||||
|
"dimensions": 1024,
|
||||||
|
"batch_size": 32,
|
||||||
|
},
|
||||||
|
"vector_store": {"type": "lancedb", "path": ".test.vectors"},
|
||||||
|
"search": {
|
||||||
|
"default_top_k": 8,
|
||||||
|
"max_top_k": 20,
|
||||||
|
"similarity_threshold": 0.75,
|
||||||
|
"hybrid_search": {
|
||||||
|
"enabled": False,
|
||||||
|
"keyword_weight": 0.3,
|
||||||
|
"semantic_weight": 0.7,
|
||||||
|
},
|
||||||
|
"filters": {
|
||||||
|
"date_range_enabled": True,
|
||||||
|
"tag_filter_enabled": True,
|
||||||
|
"directory_filter_enabled": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"model": {
|
||||||
|
"inference": {
|
||||||
|
"backend": "llama.cpp",
|
||||||
|
"model_path": "",
|
||||||
|
"context_length": 8192,
|
||||||
|
"gpu_layers": 35,
|
||||||
|
"batch_size": 512,
|
||||||
|
"threads": 8,
|
||||||
|
},
|
||||||
|
"fine_tuning": {
|
||||||
|
"base_model": "",
|
||||||
|
"output_dir": "",
|
||||||
|
"lora_rank": 16,
|
||||||
|
"lora_alpha": 32,
|
||||||
|
"learning_rate": 0.0002,
|
||||||
|
"batch_size": 4,
|
||||||
|
"gradient_accumulation_steps": 4,
|
||||||
|
"num_epochs": 3,
|
||||||
|
"warmup_steps": 100,
|
||||||
|
"save_steps": 500,
|
||||||
|
"eval_steps": 250,
|
||||||
|
"training_data_path": "",
|
||||||
|
"validation_split": 0.1,
|
||||||
|
},
|
||||||
|
"retrain_schedule": {
|
||||||
|
"auto_reminder": True,
|
||||||
|
"default_interval_days": 90,
|
||||||
|
"reminder_channels": [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"api": {
|
||||||
|
"host": "127.0.0.1",
|
||||||
|
"port": 7373,
|
||||||
|
"cors_origins": [],
|
||||||
|
"auth": {"enabled": False},
|
||||||
|
},
|
||||||
|
"ui": {
|
||||||
|
"web": {
|
||||||
|
"enabled": True,
|
||||||
|
"theme": "obsidian",
|
||||||
|
"features": {
|
||||||
|
"streaming": True,
|
||||||
|
"citations": True,
|
||||||
|
"source_preview": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"cli": {"enabled": True, "rich_output": True},
|
||||||
|
},
|
||||||
|
"logging": {
|
||||||
|
"level": "INFO",
|
||||||
|
"file": "",
|
||||||
|
"max_size_mb": 100,
|
||||||
|
"backup_count": 5,
|
||||||
|
},
|
||||||
|
"security": {
|
||||||
|
"local_only": True,
|
||||||
|
"vault_path_traversal_check": True,
|
||||||
|
"sensitive_content_detection": True,
|
||||||
|
"sensitive_patterns": [],
|
||||||
|
"require_confirmation_for_external_apis": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
config_path = Path(tmp) / "test_config.json"
|
||||||
|
import json
|
||||||
|
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(config_dict, f)
|
||||||
|
|
||||||
|
from companion.config import load_config
|
||||||
|
|
||||||
|
config = load_config(config_path)
|
||||||
|
extractor = TrainingDataExtractor(config)
|
||||||
|
extractor.examples = examples
|
||||||
|
|
||||||
|
stats = extractor.get_stats()
|
||||||
|
assert stats["total"] == 2
|
||||||
|
assert stats["avg_length"] == 150 # (100 + 200) // 2
|
||||||
|
assert len(stats["top_tags"]) > 0
|
||||||
|
assert stats["top_tags"][0][0] == "#reflection"
|
||||||
Reference in New Issue
Block a user