# kv-ai/tests/test_integration.py
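"""End-to-end test of the indexing and search pipeline.

Embeddings are mocked with fixed vectors, so no Ollama server is needed.
"""
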
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
from companion.config import (
Config,
VaultConfig,
IndexingConfig,
RagConfig,
EmbeddingConfig,
VectorStoreConfig,
SearchConfig,
HybridSearchConfig,
FiltersConfig,
CompanionConfig,
PersonaConfig,
MemoryConfig,
ChatConfig,
ModelConfig,
InferenceConfig,
FineTuningConfig,
RetrainScheduleConfig,
ApiConfig,
AuthConfig,
UiConfig,
WebConfig,
WebFeaturesConfig,
CliConfig,
LoggingConfig,
SecurityConfig,
)
from companion.rag.indexer import Indexer
from companion.rag.search import SearchEngine
from companion.rag.vector_store import VectorStore


def _make_config(vault_path: Path, vector_store_path: Path) -> Config:
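    """Build a minimal Config wired to a temp vault and vector store.

    Four-dimensional embeddings and small batch sizes keep the mocked
    vectors easy to reason about.
    """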
return Config(
companion=CompanionConfig(
name="SAN",
persona=PersonaConfig(
role="companion", tone="reflective", style="questioning", boundaries=[]
),
memory=MemoryConfig(
session_turns=20, persistent_store="", summarize_after=10
),
chat=ChatConfig(
streaming=True,
max_response_tokens=2048,
default_temperature=0.7,
allow_temperature_override=True,
),
),
vault=VaultConfig(
path=str(vault_path),
indexing=IndexingConfig(
auto_sync=False,
auto_sync_interval_minutes=1440,
watch_fs_events=False,
file_patterns=["*.md"],
deny_dirs=[".git"],
deny_patterns=[".*"],
),
chunking_rules={},
),
rag=RagConfig(
embedding=EmbeddingConfig(
provider="ollama",
model="dummy",
base_url="http://localhost:11434",
dimensions=4,
batch_size=2,
),
vector_store=VectorStoreConfig(type="lancedb", path=str(vector_store_path)),
search=SearchConfig(
default_top_k=8,
max_top_k=20,
similarity_threshold=0.0,
hybrid_search=HybridSearchConfig(
enabled=False, keyword_weight=0.3, semantic_weight=0.7
),
filters=FiltersConfig(
date_range_enabled=True,
tag_filter_enabled=True,
directory_filter_enabled=True,
),
),
),
model=ModelConfig(
inference=InferenceConfig(
backend="llama.cpp",
model_path="",
context_length=8192,
gpu_layers=35,
batch_size=512,
threads=8,
),
fine_tuning=FineTuningConfig(
base_model="",
output_dir="",
lora_rank=16,
lora_alpha=32,
learning_rate=0.0002,
batch_size=4,
gradient_accumulation_steps=4,
num_epochs=3,
warmup_steps=100,
save_steps=500,
eval_steps=250,
training_data_path="",
validation_split=0.1,
),
retrain_schedule=RetrainScheduleConfig(
auto_reminder=True, default_interval_days=90, reminder_channels=[]
),
),
api=ApiConfig(
host="127.0.0.1", port=7373, cors_origins=[], auth=AuthConfig(enabled=False)
),
ui=UiConfig(
web=WebConfig(
enabled=True,
theme="obsidian",
features=WebFeaturesConfig(
streaming=True, citations=True, source_preview=True
),
),
cli=CliConfig(enabled=True, rich_output=True),
),
logging=LoggingConfig(level="INFO", file="", max_size_mb=100, backup_count=5),
security=SecurityConfig(
local_only=True,
vault_path_traversal_check=True,
sensitive_content_detection=True,
sensitive_patterns=[],
require_confirmation_for_external_apis=True,
),
)
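

# patch decorators apply bottom-up: the indexer patch (nearest the function)
# binds to the first argument, the search patch to the second.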
@patch("companion.rag.search.OllamaEmbedder")
@patch("companion.rag.indexer.OllamaEmbedder")
def test_index_and_search_flow(mock_indexer_embedder, mock_search_embedder):
"""Verify end-to-end indexing and semantic search with mocked embeddings."""
    mock_embed = MagicMock()

    def mock_embed_side_effect(texts):
        # Deterministic 4-d embeddings: any text mentioning "hello" maps to
        # one basis vector, everything else to an orthogonal one, so each
        # query lands nearest the note it is meant to retrieve. (An
        # index-based mapping would give every single-text query batch the
        # same vector regardless of content.)
        return [
            [1.0, 0.0, 0.0, 0.0] if "hello" in text else [0.0, 1.0, 0.0, 0.0]
            for text in texts
        ]

    mock_embed.embed.side_effect = mock_embed_side_effect
mock_indexer_embedder.return_value = mock_embed
mock_search_embedder.return_value = mock_embed
with tempfile.TemporaryDirectory() as tmp:
vault = Path(tmp) / "vault"
vault.mkdir()
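        # Two tiny notes the mocked embedder maps to orthogonal vectors.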
(vault / "note1.md").write_text("hello world", encoding="utf-8")
(vault / "note2.md").write_text("goodbye world", encoding="utf-8")
vs_path = Path(tmp) / "vectors"
config = _make_config(vault, vs_path)
store = VectorStore(uri=vs_path, dimensions=4)
indexer = Indexer(config, store)
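        # Index the whole vault; each note should land as one stored chunk.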
indexer.full_index()
assert store.count() == 2
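        # Engine parameters mirror the RagConfig above; a similarity
        # threshold of 0.0 keeps every hit within top_k.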
engine = SearchEngine(
vector_store=store,
embedder_base_url="http://localhost:11434",
embedder_model="dummy",
embedder_batch_size=2,
default_top_k=5,
similarity_threshold=0.0,
hybrid_search_enabled=False,
)
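        # Each query embeds to exactly one note's vector, so the matching
        # note must appear among the results.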
results = engine.search("hello")
assert len(results) >= 1
files = {r["source_file"] for r in results}
assert "note1.md" in files
results = engine.search("goodbye")
assert len(results) >= 1
files = {r["source_file"] for r in results}
assert "note2.md" in files