"""Tests for training data extractor."""
import tempfile
from pathlib import Path

import pytest

from companion.config import Config, IndexingConfig, VaultConfig
from companion.forge.extract import (
    TrainingDataExtractor,
    TrainingExample,
    _create_training_example,
    _extract_date_from_filename,
    _has_reflection_patterns,
    _has_reflection_tags,
    _is_likely_reflection,
    extract_training_data,
)


def test_has_reflection_tags():
    """Tags like #reflection and #decision are recognized; #worklog is not."""
    for tagged in ("#reflection on today's events", "#decision made today"):
        assert _has_reflection_tags(tagged)
    assert not _has_reflection_tags("#worklog entry")


def test_has_reflection_patterns():
    """First-person introspective phrasing is detected; plain facts are not."""
    reflective_samples = (
        "I think this is important",
        "I wonder if I should change",
        "Looking back, I see the pattern",
    )
    for sample in reflective_samples:
        assert _has_reflection_patterns(sample)
    assert not _has_reflection_patterns("The meeting was at 3pm")


def test_is_likely_reflection():
    """Either a reflection tag or reflective phrasing marks text as reflection."""
    cases = [
        ("#reflection I think this matters", True),
        ("I realize now that I was wrong", True),
        ("Just a regular note", False),
    ]
    for text, expected in cases:
        assert bool(_is_likely_reflection(text)) == expected


def test_extract_date_from_filename():
    """Date-bearing filenames yield the date portion; others yield None."""
    expectations = {
        "2026-04-12.md": "2026-04-12",
        "12-Apr-2026.md": "12-Apr-2026",
        "2026-04-12-journal.md": "2026-04-12",
    }
    for filename, expected in expectations.items():
        assert _extract_date_from_filename(filename) == expected
    assert _extract_date_from_filename("notes.md") is None


def test_create_training_example():
    """A reflective chunk yields a 3-message chat example carrying its metadata."""
    text = (
        "#reflection I think I need to reconsider my approach. "
        "The way I've been handling this isn't working."
    )
    example = _create_training_example(
        chunk_text=text,
        source_file="journal/2026-04-12.md",
        tags=["#reflection"],
        date="2026-04-12",
    )

    assert example is not None
    # system -> user -> assistant, with the chunk itself as the assistant turn
    roles = [message["role"] for message in example.messages]
    assert roles == ["system", "user", "assistant"]
    assert example.messages[-1]["content"] == text
    assert example.source_file == "journal/2026-04-12.md"


def test_create_training_example_too_short():
    """Chunks below the minimum length are rejected (returns None)."""
    result = _create_training_example(
        chunk_text="I think.",  # Too short
        source_file="test.md",
        tags=["#reflection"],
        date=None,
    )
    assert result is None


def test_create_training_example_no_reflection():
    """Long but non-reflective text is rejected (returns None)."""
    mundane = "This is just a regular note about the meeting at 3pm. Nothing special." * 5
    result = _create_training_example(
        chunk_text=mundane,
        source_file="test.md",
        tags=["#work"],
        date=None,
    )
    assert result is None


def test_training_example_to_dict():
    """to_dict() exposes messages and metadata as plain dict entries."""
    example = TrainingExample(
        messages=[
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi"},
        ],
        source_file="test.md",
        tags=["#test"],
        date="2026-04-12",
    )

    as_dict = example.to_dict()
    assert as_dict["messages"][0]["role"] == "user"
    assert as_dict["source_file"] == "test.md"
    assert as_dict["date"] == "2026-04-12"


class TestTrainingDataExtractor:
    """Integration-style tests for TrainingDataExtractor against a temp vault.

    Fixes applied in review:
    - Removed the broken ``_save_to_jsonl_helper`` placeholder and the orphaned
      config-dict/statement residue that followed it (it was a syntax error left
      over from an aborted refactor, duplicating ``test_extract_from_single_file``).
    - ``test_get_stats`` now uses ``_get_config_dict`` instead of its own inline
      copy of the entire config dict.
    - The write-config-then-load dance shared by all tests lives in ``_load_config``.
    """

    def _get_config_dict(self, vault_path: Path) -> dict:
        """Return a minimal but complete config dict pointing at *vault_path*."""
        return {
            "companion": {
                "name": "SAN",
                "persona": {
                    "role": "companion",
                    "tone": "reflective",
                    "style": "questioning",
                    "boundaries": [],
                },
                "memory": {
                    "session_turns": 20,
                    "persistent_store": "",
                    "summarize_after": 10,
                },
                "chat": {
                    "streaming": True,
                    "max_response_tokens": 2048,
                    "default_temperature": 0.7,
                    "allow_temperature_override": True,
                },
            },
            "vault": {
                "path": str(vault_path),
                "indexing": {
                    "auto_sync": False,
                    "auto_sync_interval_minutes": 1440,
                    "watch_fs_events": False,
                    "file_patterns": ["*.md"],
                    "deny_dirs": [".git"],
                    "deny_patterns": [],
                },
                "chunking_rules": {},
            },
            "rag": {
                "embedding": {
                    "provider": "ollama",
                    "model": "mxbai-embed-large",
                    "base_url": "http://localhost:11434",
                    "dimensions": 1024,
                    "batch_size": 32,
                },
                "vector_store": {"type": "lancedb", "path": ".test.vectors"},
                "search": {
                    "default_top_k": 8,
                    "max_top_k": 20,
                    "similarity_threshold": 0.75,
                    "hybrid_search": {
                        "enabled": False,
                        "keyword_weight": 0.3,
                        "semantic_weight": 0.7,
                    },
                    "filters": {
                        "date_range_enabled": True,
                        "tag_filter_enabled": True,
                        "directory_filter_enabled": True,
                    },
                },
            },
            "model": {
                "inference": {
                    "backend": "llama.cpp",
                    "model_path": "",
                    "context_length": 8192,
                    "gpu_layers": 35,
                    "batch_size": 512,
                    "threads": 8,
                },
                "fine_tuning": {
                    "base_model": "",
                    "output_dir": "",
                    "lora_rank": 16,
                    "lora_alpha": 32,
                    "learning_rate": 0.0002,
                    "batch_size": 4,
                    "gradient_accumulation_steps": 4,
                    "num_epochs": 3,
                    "warmup_steps": 100,
                    "save_steps": 500,
                    "eval_steps": 250,
                    "training_data_path": "",
                    "validation_split": 0.1,
                },
                "retrain_schedule": {
                    "auto_reminder": True,
                    "default_interval_days": 90,
                    "reminder_channels": [],
                },
            },
            "api": {
                "host": "127.0.0.1",
                "port": 7373,
                "cors_origins": [],
                "auth": {"enabled": False},
            },
            "ui": {
                "web": {
                    "enabled": True,
                    "theme": "obsidian",
                    "features": {
                        "streaming": True,
                        "citations": True,
                        "source_preview": True,
                    },
                },
                "cli": {"enabled": True, "rich_output": True},
            },
            "logging": {
                "level": "INFO",
                "file": "",
                "max_size_mb": 100,
                "backup_count": 5,
            },
            "security": {
                "local_only": True,
                "vault_path_traversal_check": True,
                "sensitive_content_detection": True,
                "sensitive_patterns": [],
                "require_confirmation_for_external_apis": True,
            },
        }

    def _load_config(self, tmp: str, vault_path: Path):
        """Write the minimal config dict to *tmp* as JSON and load it.

        Imports are function-local to match the original tests' style.
        """
        import json

        from companion.config import load_config

        config_path = Path(tmp) / "test_config.json"
        with open(config_path, "w") as f:
            json.dump(self._get_config_dict(vault_path), f)
        return load_config(config_path)

    def test_extract_from_single_file(self):
        """A journal note with reflective sections yields chat-format examples."""
        with tempfile.TemporaryDirectory() as tmp:
            vault = Path(tmp)
            journal = vault / "Journal" / "2026" / "04"
            journal.mkdir(parents=True)

            content = """#DayInShort: Busy day

#reflection I think I need to slow down. The pace has been unsustainable.

#work Normal work day with meetings.

#insight I realize that I've been prioritizing urgency over importance.
"""
            (journal / "2026-04-12.md").write_text(content, encoding="utf-8")

            config = self._load_config(tmp, vault)
            extractor = TrainingDataExtractor(config)
            examples = extractor.extract()

            # Should extract at least 2 reflection examples
            assert len(examples) >= 2

            # Check they have the right structure
            for ex in examples:
                assert len(ex.messages) == 3
                assert ex.messages[2]["role"] == "assistant"

    def test_save_to_jsonl(self):
        """save_to_jsonl writes one JSON line per example and reports the count."""
        with tempfile.TemporaryDirectory() as tmp:
            output = Path(tmp) / "training.jsonl"

            examples = [
                TrainingExample(
                    messages=[
                        {"role": "system", "content": "sys"},
                        {"role": "user", "content": "user"},
                        {"role": "assistant", "content": "assistant"},
                    ],
                    source_file="test.md",
                    tags=["#test"],
                    date="2026-04-12",
                )
            ]

            config = self._load_config(tmp, Path(tmp))
            extractor = TrainingDataExtractor(config)
            extractor.examples = examples

            count = extractor.save_to_jsonl(output)
            assert count == 1

            # Verify file content
            lines = output.read_text(encoding="utf-8").strip().split("\n")
            assert len(lines) == 1
            assert "assistant" in lines[0]

    def test_get_stats(self):
        """get_stats aggregates totals, average length, and tag frequencies."""
        examples = [
            TrainingExample(
                messages=[
                    {"role": "system", "content": "sys"},
                    {"role": "user", "content": "user"},
                    {"role": "assistant", "content": "a" * 100},
                ],
                source_file="test1.md",
                tags=["#reflection", "#learning"],
                date="2026-04-12",
            ),
            TrainingExample(
                messages=[
                    {"role": "system", "content": "sys"},
                    {"role": "user", "content": "user"},
                    {"role": "assistant", "content": "b" * 200},
                ],
                source_file="test2.md",
                tags=["#reflection", "#decision"],
                date="2026-04-13",
            ),
        ]

        with tempfile.TemporaryDirectory() as tmp:
            config = self._load_config(tmp, Path(tmp))
            extractor = TrainingDataExtractor(config)
            extractor.examples = examples

            stats = extractor.get_stats()
            assert stats["total"] == 2
            assert stats["avg_length"] == 150  # (100 + 200) // 2
            assert len(stats["top_tags"]) > 0
            # "#reflection" appears in both examples, so it must rank first
            assert stats["top_tags"][0][0] == "#reflection"