fix: add error handling and docstrings to search engine
This commit is contained in:
@@ -5,6 +5,22 @@ from companion.rag.vector_store import VectorStore
|
||||
|
||||
|
||||
class SearchEngine:
|
||||
"""Search engine for semantic search using vector embeddings.
|
||||
|
||||
Args:
|
||||
vector_store: Vector store instance for searching
|
||||
embedder_base_url: Base URL for the embedding service (Ollama)
|
||||
embedder_model: Model name for embeddings
|
||||
embedder_batch_size: Batch size for embedding requests
|
||||
default_top_k: Default number of results to return
|
||||
similarity_threshold: Maximum distance a result may have to be kept; a value of 0 or less disables filtering
|
||||
hybrid_search_enabled: Reserved for future hybrid search implementation
|
||||
keyword_weight: Reserved for future hybrid search (keyword component weight)
|
||||
semantic_weight: Reserved for future hybrid search (semantic component weight)
|
||||
"""
|
||||
|
||||
_DISTANCE_FIELD = "_distance"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store: VectorStore,
|
||||
@@ -35,13 +51,34 @@ class SearchEngine:
|
||||
top_k: int | None = None,
|
||||
filters: dict[str, Any] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search for relevant documents using semantic similarity.
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
top_k: Number of results to return (falls back to default_top_k when None or 0)
|
||||
filters: Optional metadata filters to apply
|
||||
|
||||
Returns:
|
||||
List of matching documents with their distance scores (the `_distance` field, where present; lower means closer)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If embedding generation fails
|
||||
"""
|
||||
k = top_k or self.default_top_k
|
||||
query_embedding = self.embedder.embed([query])[0]
|
||||
|
||||
try:
|
||||
query_embedding = self.embedder.embed([query])[0]
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(f"Failed to generate embedding for query: {e}") from e
|
||||
|
||||
results = self.vector_store.search(query_embedding, top_k=k, filters=filters)
|
||||
|
||||
if self.similarity_threshold > 0 and results:
|
||||
results = [
|
||||
r
|
||||
for r in results
|
||||
if r.get("_distance", float("inf")) <= self.similarity_threshold
|
||||
if r.get(self._DISTANCE_FIELD, float("inf"))
|
||||
<= self.similarity_threshold
|
||||
]
|
||||
|
||||
return results
|
||||
|
||||
Reference in New Issue
Block a user