feat: add search engine interface with embedding and filtering
This commit is contained in:
47
src/companion/rag/search.py
Normal file
47
src/companion/rag/search.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from companion.rag.embedder import OllamaEmbedder
|
||||||
|
from companion.rag.vector_store import VectorStore
|
||||||
|
|
||||||
|
|
||||||
|
class SearchEngine:
    """Embed a text query and look up the nearest rows in a vector store.

    The engine owns an ``OllamaEmbedder`` built from the ``embedder_*``
    settings and delegates the actual lookup to the injected
    ``vector_store``.

    NOTE(review): ``hybrid_search_enabled``, ``keyword_weight`` and
    ``semantic_weight`` are stored on the instance but nothing in this class
    reads them yet; ``search`` is purely embedding-based.
    """

    def __init__(
        self,
        vector_store: VectorStore,
        embedder_base_url: str,
        embedder_model: str,
        embedder_batch_size: int,
        default_top_k: int,
        similarity_threshold: float,
        hybrid_search_enabled: bool,
        keyword_weight: float = 0.3,
        semantic_weight: float = 0.7,
    ) -> None:
        """Store the search settings and build the query embedder.

        Args:
            vector_store: Store that is queried by ``search``.
            embedder_base_url: Passed to ``OllamaEmbedder`` as ``base_url``.
            embedder_model: Passed to ``OllamaEmbedder`` as ``model``.
            embedder_batch_size: Passed to ``OllamaEmbedder`` as
                ``batch_size``.
            default_top_k: Result count used when ``search`` is called
                without a ``top_k`` (or with a falsy one).
            similarity_threshold: Upper bound applied to each result's
                ``_distance`` value; values ``<= 0`` disable the filter.
            hybrid_search_enabled: Stored only; not consulted by ``search``.
            keyword_weight: Stored only; not consulted by ``search``.
            semantic_weight: Stored only; not consulted by ``search``.
        """
        self.vector_store = vector_store
        self.embedder = OllamaEmbedder(
            base_url=embedder_base_url,
            model=embedder_model,
            batch_size=embedder_batch_size,
        )
        self.default_top_k = default_top_k
        self.similarity_threshold = similarity_threshold
        self.hybrid_search_enabled = hybrid_search_enabled
        self.keyword_weight = keyword_weight
        self.semantic_weight = semantic_weight

    def search(
        self,
        query: str,
        top_k: int | None = None,
        filters: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        """Return the stored rows closest to ``query``.

        Args:
            query: Free-text query to embed.
            top_k: Maximum number of rows to request from the store. ``None``
                (or any falsy value such as ``0``) falls back to
                ``default_top_k``.
            filters: Passed through unchanged to ``vector_store.search``.

        Returns:
            The rows returned by the vector store, minus any rows rejected by
            the distance threshold. An empty list for an empty or
            whitespace-only query.
        """
        # A blank query has nothing to embed; skip the embedder and the store
        # lookup entirely instead of searching with a meaningless vector.
        if not query or not query.strip():
            return []

        k = top_k or self.default_top_k
        query_embedding = self.embedder.embed([query])[0]
        results = self.vector_store.search(query_embedding, top_k=k, filters=filters)

        if self.similarity_threshold > 0 and results:
            # NOTE(review): despite the attribute name, the comparison is made
            # against the row's "_distance" field, so smaller values pass.
            # Rows that carry no "_distance" are treated as infinitely far
            # away and are therefore dropped whenever the filter is active.
            results = [
                r
                for r in results
                if r.get("_distance", float("inf")) <= self.similarity_threshold
            ]
        return results
|
||||||
34
tests/test_search.py
Normal file
34
tests/test_search.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from companion.rag.search import SearchEngine
|
||||||
|
from companion.rag.vector_store import VectorStore
|
||||||
|
|
||||||
|
|
||||||
|
@patch("companion.rag.search.OllamaEmbedder")
def test_search_returns_results(mock_embedder_cls):
    """A stored row is returned when the query embedding matches it exactly.

    ``OllamaEmbedder`` is patched where ``companion.rag.search`` looks it up,
    so the engine receives a stub that returns a fixed 4-dimensional vector
    and no embedding service is contacted.
    """
    mock_embedder = MagicMock()
    mock_embedder.embed.return_value = [[1.0, 0.0, 0.0, 0.0]]
    mock_embedder_cls.return_value = mock_embedder

    with tempfile.TemporaryDirectory() as tmp:
        store = VectorStore(uri=tmp, dimensions=4)
        store.upsert(
            ids=["a"],
            texts=["hello world"],
            embeddings=[[1.0, 0.0, 0.0, 0.0]],
            metadatas=[{"source_file": "a.md", "source_directory": "docs"}],
        )
        engine = SearchEngine(
            vector_store=store,
            embedder_base_url="http://localhost:11434",
            embedder_model="dummy",
            embedder_batch_size=32,
            default_top_k=5,
            # 0.0 disables the distance filter, so the single row is kept.
            similarity_threshold=0.0,
            hybrid_search_enabled=False,
        )
        results = engine.search("hello")

        assert len(results) == 1
        assert results[0]["source_file"] == "a.md"

    # Confirm the query went through the patched embedder as a one-item batch.
    mock_embedder.embed.assert_called_once_with(["hello"])
|
||||||
Reference in New Issue
Block a user