fix: add error handling and docstrings to search engine

This commit is contained in:
2026-04-13 14:32:02 -04:00
parent 9a02c5fcbd
commit b2a42e5fe6

View File

@@ -5,6 +5,22 @@ from companion.rag.vector_store import VectorStore
class SearchEngine: class SearchEngine:
"""Search engine for semantic search using vector embeddings.
Args:
vector_store: Vector store instance for searching
embedder_base_url: Base URL for the embedding service (Ollama)
embedder_model: Model name for embeddings
embedder_batch_size: Batch size for embedding requests
default_top_k: Default number of results to return
similarity_threshold: Maximum distance a result may have to be kept; a value <= 0 disables distance filtering
hybrid_search_enabled: Reserved for future hybrid search implementation
keyword_weight: Reserved for future hybrid search (keyword component weight)
semantic_weight: Reserved for future hybrid search (semantic component weight)
"""
_DISTANCE_FIELD = "_distance"
def __init__( def __init__(
self, self,
vector_store: VectorStore, vector_store: VectorStore,
@@ -35,13 +51,34 @@ class SearchEngine:
top_k: int | None = None, top_k: int | None = None,
filters: dict[str, Any] | None = None, filters: dict[str, Any] | None = None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Search for relevant documents using semantic similarity.
Args:
query: The search query string
top_k: Number of results to return (defaults to default_top_k)
filters: Optional metadata filters to apply
Returns:
List of matching documents; when similarity_threshold > 0, results whose distance exceeds it (or that carry no distance value) are dropped
Raises:
RuntimeError: If embedding generation fails
"""
k = top_k or self.default_top_k k = top_k or self.default_top_k
try:
query_embedding = self.embedder.embed([query])[0] query_embedding = self.embedder.embed([query])[0]
except RuntimeError as e:
raise RuntimeError(f"Failed to generate embedding for query: {e}") from e
results = self.vector_store.search(query_embedding, top_k=k, filters=filters) results = self.vector_store.search(query_embedding, top_k=k, filters=filters)
if self.similarity_threshold > 0 and results: if self.similarity_threshold > 0 and results:
results = [ results = [
r r
for r in results for r in results
if r.get("_distance", float("inf")) <= self.similarity_threshold if r.get(self._DISTANCE_FIELD, float("inf"))
<= self.similarity_threshold
] ]
return results return results