feat: add search engine interface with embedding and filtering

This commit is contained in:
2026-04-13 14:30:03 -04:00
parent 827ebfadaa
commit 9a02c5fcbd
2 changed files with 81 additions and 0 deletions

View File

@@ -0,0 +1,47 @@
from typing import Any
from companion.rag.embedder import OllamaEmbedder
from companion.rag.vector_store import VectorStore
class SearchEngine:
    """Embed a text query and look it up in a vector store.

    The engine owns an ``OllamaEmbedder`` built from the supplied connection
    settings and delegates nearest-neighbour lookup to the given
    ``VectorStore``.

    ``hybrid_search_enabled``, ``keyword_weight`` and ``semantic_weight`` are
    stored on the instance but are not read by ``search`` in this class.
    """

    def __init__(
        self,
        vector_store: VectorStore,
        embedder_base_url: str,
        embedder_model: str,
        embedder_batch_size: int,
        default_top_k: int,
        similarity_threshold: float,
        hybrid_search_enabled: bool,
        keyword_weight: float = 0.3,
        semantic_weight: float = 0.7,
    ):
        """Store the search settings and construct the query embedder.

        Args:
            vector_store: Store that ``search`` queries.
            embedder_base_url: Passed to ``OllamaEmbedder`` as ``base_url``.
            embedder_model: Passed to ``OllamaEmbedder`` as ``model``.
            embedder_batch_size: Passed to ``OllamaEmbedder`` as ``batch_size``.
            default_top_k: Result count used when ``search`` gets no ``top_k``.
            similarity_threshold: Upper bound applied to each result's
                ``_distance`` value; ``0`` (or less) disables the filter.
            hybrid_search_enabled: Stored only; not consulted by ``search``.
            keyword_weight: Stored only; not consulted by ``search``.
            semantic_weight: Stored only; not consulted by ``search``.
        """
        self.vector_store = vector_store
        self.embedder = OllamaEmbedder(
            base_url=embedder_base_url,
            model=embedder_model,
            batch_size=embedder_batch_size,
        )
        self.default_top_k = default_top_k
        self.similarity_threshold = similarity_threshold
        self.hybrid_search_enabled = hybrid_search_enabled
        self.keyword_weight = keyword_weight
        self.semantic_weight = semantic_weight

    def search(
        self,
        query: str,
        top_k: int | None = None,
        filters: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        """Return the stored rows nearest to ``query``.

        Args:
            query: Text to embed and search for. A blank or whitespace-only
                query returns ``[]`` without calling the embedder or the store.
            top_k: Maximum number of rows to request from the store. ``None``
                or ``0`` falls back to ``default_top_k``.
            filters: Forwarded unchanged to ``VectorStore.search``.

        Returns:
            The rows returned by the store. When ``similarity_threshold`` is
            positive, only rows whose ``_distance`` is at most that threshold
            are kept; rows lacking ``_distance`` are treated as infinitely
            distant.
        """
        # A blank query has nothing meaningful to embed, so skip the embedder
        # and store round-trips entirely.
        if not query or not query.strip():
            return []

        # ``or`` rather than ``is None``: an explicit 0 also uses the default.
        k = top_k or self.default_top_k
        query_embedding = self.embedder.embed([query])[0]
        results = self.vector_store.search(query_embedding, top_k=k, filters=filters)

        # NOTE: despite its name, the threshold is compared against a distance,
        # so smaller values are stricter.
        if self.similarity_threshold > 0 and results:
            results = [
                r
                for r in results
                if r.get("_distance", float("inf")) <= self.similarity_threshold
            ]
        return results

34
tests/test_search.py Normal file
View File

@@ -0,0 +1,34 @@
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
from companion.rag.search import SearchEngine
from companion.rag.vector_store import VectorStore
@patch("companion.rag.search.OllamaEmbedder")
def test_search_returns_results(mock_embedder_cls):
    """A single stored row is returned by ``SearchEngine.search`` with its metadata."""
    # The engine instantiates the patched class, so its embedder is this stub.
    # The stub's query embedding is the same vector as the stored row's.
    stub_embedder = MagicMock()
    mock_embedder_cls.return_value = stub_embedder
    stub_embedder.embed.return_value = [[1.0, 0.0, 0.0, 0.0]]

    engine_settings = {
        "embedder_base_url": "http://localhost:11434",
        "embedder_model": "dummy",
        "embedder_batch_size": 32,
        "default_top_k": 5,
        # 0.0 leaves the engine's distance filter switched off.
        "similarity_threshold": 0.0,
        "hybrid_search_enabled": False,
    }

    # Everything that touches the store stays inside the block, while the
    # temporary directory backing it still exists.
    with tempfile.TemporaryDirectory() as tmp:
        store = VectorStore(uri=tmp, dimensions=4)
        store.upsert(
            ids=["a"],
            texts=["hello world"],
            embeddings=[[1.0, 0.0, 0.0, 0.0]],
            metadatas=[{"source_file": "a.md", "source_directory": "docs"}],
        )
        engine = SearchEngine(vector_store=store, **engine_settings)

        hits = engine.search("hello")

        assert len(hits) == 1
        assert hits[0]["source_file"] == "a.md"