fix: add error handling and docstrings to search engine
This commit is contained in:
@@ -5,6 +5,22 @@ from companion.rag.vector_store import VectorStore
|
|||||||
|
|
||||||
|
|
||||||
class SearchEngine:
|
class SearchEngine:
|
||||||
|
"""Search engine for semantic search using vector embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vector_store: Vector store instance for searching
|
||||||
|
embedder_base_url: Base URL for the embedding service (Ollama)
|
||||||
|
embedder_model: Model name for embeddings
|
||||||
|
embedder_batch_size: Batch size for embedding requests
|
||||||
|
default_top_k: Default number of results to return
|
||||||
|
similarity_threshold: Maximum distance threshold for filtering results
|
||||||
|
hybrid_search_enabled: Reserved for future hybrid search implementation
|
||||||
|
keyword_weight: Reserved for future hybrid search (keyword component weight)
|
||||||
|
semantic_weight: Reserved for future hybrid search (semantic component weight)
|
||||||
|
"""
|
||||||
|
|
||||||
|
_DISTANCE_FIELD = "_distance"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vector_store: VectorStore,
|
vector_store: VectorStore,
|
||||||
@@ -35,13 +51,34 @@ class SearchEngine:
|
|||||||
top_k: int | None = None,
|
top_k: int | None = None,
|
||||||
filters: dict[str, Any] | None = None,
|
filters: dict[str, Any] | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Search for relevant documents using semantic similarity.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query string
|
||||||
|
top_k: Number of results to return (defaults to default_top_k)
|
||||||
|
filters: Optional metadata filters to apply
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching documents with similarity scores
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If embedding generation fails
|
||||||
|
"""
|
||||||
k = top_k or self.default_top_k
|
k = top_k or self.default_top_k
|
||||||
|
|
||||||
|
try:
|
||||||
query_embedding = self.embedder.embed([query])[0]
|
query_embedding = self.embedder.embed([query])[0]
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise RuntimeError(f"Failed to generate embedding for query: {e}") from e
|
||||||
|
|
||||||
results = self.vector_store.search(query_embedding, top_k=k, filters=filters)
|
results = self.vector_store.search(query_embedding, top_k=k, filters=filters)
|
||||||
|
|
||||||
if self.similarity_threshold > 0 and results:
|
if self.similarity_threshold > 0 and results:
|
||||||
results = [
|
results = [
|
||||||
r
|
r
|
||||||
for r in results
|
for r in results
|
||||||
if r.get("_distance", float("inf")) <= self.similarity_threshold
|
if r.get(self._DISTANCE_FIELD, float("inf"))
|
||||||
|
<= self.similarity_threshold
|
||||||
]
|
]
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
Reference in New Issue
Block a user