# Python source (187 lines, 5.9 KiB)
import tempfile
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from companion.config import (
|
|
Config,
|
|
VaultConfig,
|
|
IndexingConfig,
|
|
RagConfig,
|
|
EmbeddingConfig,
|
|
VectorStoreConfig,
|
|
SearchConfig,
|
|
HybridSearchConfig,
|
|
FiltersConfig,
|
|
CompanionConfig,
|
|
PersonaConfig,
|
|
MemoryConfig,
|
|
ChatConfig,
|
|
ModelConfig,
|
|
InferenceConfig,
|
|
FineTuningConfig,
|
|
RetrainScheduleConfig,
|
|
ApiConfig,
|
|
AuthConfig,
|
|
UiConfig,
|
|
WebConfig,
|
|
WebFeaturesConfig,
|
|
CliConfig,
|
|
LoggingConfig,
|
|
SecurityConfig,
|
|
)
|
|
from companion.rag.indexer import Indexer
|
|
from companion.rag.search import SearchEngine
|
|
from companion.rag.vector_store import VectorStore
|
|
|
|
|
|
def _make_config(vault_path: Path, vector_store_path: Path) -> Config:
    """Assemble a fully populated ``Config`` for the index/search test.

    Every section is filled with fixed literal values; only the vault
    location and the vector-store location depend on the arguments.

    Args:
        vault_path: Directory stored, as a string, in ``vault.path``.
        vector_store_path: Location stored, as a string, in
            ``rag.vector_store.path``.

    Returns:
        The assembled ``Config`` instance.
    """
    # Sections are built one at a time, in the same order the final
    # ``Config`` call receives them, so each can be read in isolation.
    companion_section = CompanionConfig(
        name="SAN",
        persona=PersonaConfig(
            role="companion",
            tone="reflective",
            style="questioning",
            boundaries=[],
        ),
        memory=MemoryConfig(
            session_turns=20,
            persistent_store="",
            summarize_after=10,
        ),
        chat=ChatConfig(
            streaming=True,
            max_response_tokens=2048,
            default_temperature=0.7,
            allow_temperature_override=True,
        ),
    )

    vault_section = VaultConfig(
        path=str(vault_path),
        indexing=IndexingConfig(
            auto_sync=False,
            auto_sync_interval_minutes=1440,
            watch_fs_events=False,
            file_patterns=["*.md"],
            deny_dirs=[".git"],
            deny_patterns=[".*"],
        ),
        chunking_rules={},
    )

    # Embedding dimensions are 4 here; the test's fake vectors have 4 entries.
    rag_section = RagConfig(
        embedding=EmbeddingConfig(
            provider="ollama",
            model="dummy",
            base_url="http://localhost:11434",
            dimensions=4,
            batch_size=2,
        ),
        vector_store=VectorStoreConfig(
            type="lancedb",
            path=str(vector_store_path),
        ),
        search=SearchConfig(
            default_top_k=8,
            max_top_k=20,
            similarity_threshold=0.0,
            hybrid_search=HybridSearchConfig(
                enabled=False,
                keyword_weight=0.3,
                semantic_weight=0.7,
            ),
            filters=FiltersConfig(
                date_range_enabled=True,
                tag_filter_enabled=True,
                directory_filter_enabled=True,
            ),
        ),
    )

    model_section = ModelConfig(
        inference=InferenceConfig(
            backend="llama.cpp",
            model_path="",
            context_length=8192,
            gpu_layers=35,
            batch_size=512,
            threads=8,
        ),
        fine_tuning=FineTuningConfig(
            base_model="",
            output_dir="",
            lora_rank=16,
            lora_alpha=32,
            learning_rate=0.0002,
            batch_size=4,
            gradient_accumulation_steps=4,
            num_epochs=3,
            warmup_steps=100,
            save_steps=500,
            eval_steps=250,
            training_data_path="",
            validation_split=0.1,
        ),
        retrain_schedule=RetrainScheduleConfig(
            auto_reminder=True,
            default_interval_days=90,
            reminder_channels=[],
        ),
    )

    api_section = ApiConfig(
        host="127.0.0.1",
        port=7373,
        cors_origins=[],
        auth=AuthConfig(enabled=False),
    )

    ui_section = UiConfig(
        web=WebConfig(
            enabled=True,
            theme="obsidian",
            features=WebFeaturesConfig(
                streaming=True,
                citations=True,
                source_preview=True,
            ),
        ),
        cli=CliConfig(enabled=True, rich_output=True),
    )

    logging_section = LoggingConfig(
        level="INFO",
        file="",
        max_size_mb=100,
        backup_count=5,
    )

    security_section = SecurityConfig(
        local_only=True,
        vault_path_traversal_check=True,
        sensitive_content_detection=True,
        sensitive_patterns=[],
        require_confirmation_for_external_apis=True,
    )

    return Config(
        companion=companion_section,
        vault=vault_section,
        rag=rag_section,
        model=model_section,
        api=api_section,
        ui=ui_section,
        logging=logging_section,
        security=security_section,
    )
|
|
|
|
|
|
@patch("companion.rag.search.OllamaEmbedder")
@patch("companion.rag.indexer.OllamaEmbedder")
def test_index_and_search_flow(mock_indexer_embedder, mock_search_embedder):
    """Verify end-to-end indexing and semantic search with mocked embeddings.

    Both ``OllamaEmbedder`` references (in the indexer and search modules)
    are patched to return one shared fake. Its ``embed`` maps any text
    containing "hello" to one unit vector and every other text to an
    orthogonal unit vector. Two notes are written to a temporary vault,
    indexed, and each is then looked up with a query sharing its keyword.

    Args:
        mock_indexer_embedder: Patched ``companion.rag.indexer.OllamaEmbedder``
            (innermost decorator, hence the first argument).
        mock_search_embedder: Patched ``companion.rag.search.OllamaEmbedder``.
    """
    hello_vector = [1.0, 0.0, 0.0, 0.0]
    other_vector = [0.0, 1.0, 0.0, 0.0]

    def mock_embed_side_effect(texts):
        # Choose the vector from the text itself, not from its position in
        # the batch: a positional rule would embed every one-item query
        # identically and would make note vectors depend on batching.
        if isinstance(texts, str):
            # Tolerate a bare string so it is not iterated per character.
            texts = [texts]
        return [
            list(hello_vector) if "hello" in text else list(other_vector)
            for text in texts
        ]

    mock_embed = MagicMock()
    mock_embed.embed.side_effect = mock_embed_side_effect
    mock_indexer_embedder.return_value = mock_embed
    mock_search_embedder.return_value = mock_embed

    with tempfile.TemporaryDirectory() as tmp:
        tmp_root = Path(tmp)

        vault = tmp_root / "vault"
        vault.mkdir()
        (vault / "note1.md").write_text("hello world", encoding="utf-8")
        (vault / "note2.md").write_text("goodbye world", encoding="utf-8")

        vs_path = tmp_root / "vectors"
        config = _make_config(vault, vs_path)
        store = VectorStore(uri=vs_path, dimensions=4)

        indexer = Indexer(config, store)
        indexer.full_index()
        assert store.count() == 2

        engine = SearchEngine(
            vector_store=store,
            embedder_base_url="http://localhost:11434",
            embedder_model="dummy",
            embedder_batch_size=2,
            default_top_k=5,
            similarity_threshold=0.0,
            hybrid_search_enabled=False,
        )

        # Each query must surface the note that shares its keyword.
        for query, expected_file in (
            ("hello", "note1.md"),
            ("goodbye", "note2.md"),
        ):
            results = engine.search(query)
            assert len(results) >= 1
            files = {r["source_file"] for r in results}
            assert expected_file in files
|