diff --git a/src/companion/rag/search.py b/src/companion/rag/search.py
new file mode 100644
index 0000000..b67742e
--- /dev/null
+++ b/src/companion/rag/search.py
@@ -0,0 +1,64 @@
+from typing import Any
+
+from companion.rag.embedder import OllamaEmbedder
+from companion.rag.vector_store import VectorStore
+
+
+class SearchEngine:
+    """Embed a text query and look it up in a vector store."""
+
+    def __init__(
+        self,
+        vector_store: VectorStore,
+        embedder_base_url: str,
+        embedder_model: str,
+        embedder_batch_size: int,
+        default_top_k: int,
+        similarity_threshold: float,
+        hybrid_search_enabled: bool,
+        keyword_weight: float = 0.3,
+        semantic_weight: float = 0.7,
+    ):
+        """Record the search settings and construct the query embedder.
+
+        The ``embedder_*`` arguments are forwarded to ``OllamaEmbedder``.
+        ``hybrid_search_enabled``, ``keyword_weight`` and ``semantic_weight``
+        are stored on the instance only; ``search`` does not read them.
+        """
+        self.vector_store = vector_store
+        self.embedder = OllamaEmbedder(
+            base_url=embedder_base_url,
+            model=embedder_model,
+            batch_size=embedder_batch_size,
+        )
+        self.default_top_k = default_top_k
+        self.similarity_threshold = similarity_threshold
+        self.hybrid_search_enabled = hybrid_search_enabled
+        self.keyword_weight = keyword_weight
+        self.semantic_weight = semantic_weight
+
+    def search(
+        self,
+        query: str,
+        top_k: int | None = None,
+        filters: dict[str, Any] | None = None,
+    ) -> list[dict[str, Any]]:
+        """Return the vector-store hits for ``query``.
+
+        A falsy ``top_k`` (``None`` or ``0``) falls back to ``default_top_k``.
+        When ``similarity_threshold`` is positive, hits whose ``"_distance"``
+        exceeds it are dropped, and so are hits that lack that key.
+        """
+        limit = top_k if top_k else self.default_top_k
+        query_vector = self.embedder.embed([query])[0]
+        hits = self.vector_store.search(query_vector, top_k=limit, filters=filters)
+        max_distance = self.similarity_threshold
+        # Filtering is skipped for a non-positive threshold or when there are no hits.
+        if not (max_distance > 0 and hits):
+            return hits
+
+        def within_threshold(hit: dict[str, Any]) -> bool:
+            # A hit without "_distance" compares as infinitely far away.
+            return hit.get("_distance", float("inf")) <= max_distance
+
+        return [hit for hit in hits if within_threshold(hit)]
diff --git a/tests/test_search.py b/tests/test_search.py
new file mode 100644
index 0000000..f86bf03
--- /dev/null
+++ b/tests/test_search.py
@@ -0,0 +1,38 @@
+import tempfile
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+from companion.rag.search import SearchEngine
+from companion.rag.vector_store import VectorStore
+
+
+@patch("companion.rag.search.OllamaEmbedder")
+def test_search_returns_results(mock_embedder_cls):
+    """A stored document comes back for a query embedded to its vector."""
+    # Stub the embedder so the query always maps to a fixed vector.
+    fake_embedder = MagicMock()
+    fake_embedder.embed.return_value = [[1.0, 0.0, 0.0, 0.0]]
+    mock_embedder_cls.return_value = fake_embedder
+    engine_settings = {
+        "embedder_base_url": "http://localhost:11434",
+        "embedder_model": "dummy",
+        "embedder_batch_size": 32,
+        "default_top_k": 5,
+        "similarity_threshold": 0.0,
+        "hybrid_search_enabled": False,
+    }
+
+    with tempfile.TemporaryDirectory() as tmp:
+        store = VectorStore(uri=tmp, dimensions=4)
+        store.upsert(
+            ids=["a"],
+            texts=["hello world"],
+            embeddings=[[1.0, 0.0, 0.0, 0.0]],
+            metadatas=[{"source_file": "a.md", "source_directory": "docs"}],
+        )
+        engine = SearchEngine(vector_store=store, **engine_settings)
+
+        hits = engine.search("hello")
+
+        assert len(hits) == 1
+        assert hits[0]["source_file"] == "a.md"