"""Semantic search over embedded documents, returning citation-ready results."""
from dataclasses import dataclass
from typing import Any

from companion.rag.embedder import OllamaEmbedder
from companion.rag.vector_store import VectorStore
@dataclass
|
|
class SearchResult:
|
|
"""Structured search result with citation information."""
|
|
|
|
id: str
|
|
text: str
|
|
source_file: str
|
|
source_directory: str
|
|
section: str | None
|
|
date: str | None
|
|
tags: list[str]
|
|
chunk_index: int
|
|
total_chunks: int
|
|
distance: float
|
|
|
|
@property
|
|
def citation(self) -> str:
|
|
"""Generate a citation string for this result."""
|
|
parts = [self.source_file]
|
|
if self.section:
|
|
parts.append(f"#{self.section}")
|
|
if self.date:
|
|
parts.append(f"({self.date})")
|
|
return " - ".join(parts)
|
|
|
|
def to_dict(self) -> dict[str, Any]:
|
|
"""Convert to dictionary for API serialization."""
|
|
return {
|
|
"id": self.id,
|
|
"text": self.text,
|
|
"source_file": self.source_file,
|
|
"source_directory": self.source_directory,
|
|
"section": self.section,
|
|
"date": self.date,
|
|
"tags": self.tags,
|
|
"chunk_index": self.chunk_index,
|
|
"total_chunks": self.total_chunks,
|
|
"distance": self.distance,
|
|
"citation": self.citation,
|
|
}
|
|
|
|
|
|
class SearchEngine:
    """Search engine for semantic search using vector embeddings.

    Args:
        vector_store: Vector store instance for searching
        embedder_base_url: Base URL for the embedding service (Ollama)
        embedder_model: Model name for embeddings
        embedder_batch_size: Batch size for embedding requests
        default_top_k: Default number of results to return
        similarity_threshold: Maximum distance threshold for filtering results
        hybrid_search_enabled: Reserved for future hybrid search implementation
        keyword_weight: Reserved for future hybrid search (keyword component weight)
        semantic_weight: Reserved for future hybrid search (semantic component weight)
    """

    # Key under which the vector store reports each result's distance.
    _DISTANCE_FIELD = "_distance"

    def __init__(
        self,
        vector_store: VectorStore,
        embedder_base_url: str,
        embedder_model: str,
        embedder_batch_size: int,
        default_top_k: int,
        similarity_threshold: float,
        hybrid_search_enabled: bool,
        keyword_weight: float = 0.3,
        semantic_weight: float = 0.7,
    ):
        self.vector_store = vector_store
        self.embedder = OllamaEmbedder(
            base_url=embedder_base_url,
            model=embedder_model,
            batch_size=embedder_batch_size,
        )
        self.default_top_k = default_top_k
        self.similarity_threshold = similarity_threshold
        # Hybrid-search knobs are stored but not yet consulted by search().
        self.hybrid_search_enabled = hybrid_search_enabled
        self.keyword_weight = keyword_weight
        self.semantic_weight = semantic_weight

    def search(
        self,
        query: str,
        top_k: int | None = None,
        filters: dict[str, Any] | None = None,
    ) -> list[SearchResult]:
        """Search for relevant documents using semantic similarity.

        Args:
            query: The search query string
            top_k: Number of results to return (defaults to default_top_k)
            filters: Optional metadata filters to apply

        Returns:
            List of SearchResult objects with similarity scores

        Raises:
            RuntimeError: If embedding generation fails
        """
        # Explicit None check: `top_k or default` would also discard an
        # explicit top_k=0, contrary to the documented "defaults when omitted".
        k = top_k if top_k is not None else self.default_top_k

        try:
            embeddings = self.embedder.embed([query])
        except RuntimeError as e:
            raise RuntimeError(f"Failed to generate embedding for query: {e}") from e
        # Checked *outside* the try/except: raising inside it would be caught
        # by the clause above and the message wrapped a second time.
        if not embeddings:
            raise RuntimeError(
                "Failed to generate embedding for query: embedder returned empty result"
            )
        query_embedding = embeddings[0]

        raw_results = self.vector_store.search(
            query_embedding, top_k=k, filters=filters
        )

        # A threshold of 0 disables filtering entirely; results missing a
        # distance are treated as infinitely far and therefore dropped.
        if self.similarity_threshold > 0 and raw_results:
            raw_results = [
                r
                for r in raw_results
                if r.get(self._DISTANCE_FIELD, float("inf"))
                <= self.similarity_threshold
            ]

        # Convert raw dicts to SearchResult objects, tolerating missing keys.
        results: list[SearchResult] = []
        for r in raw_results:
            result = SearchResult(
                id=r.get("id", ""),
                text=r.get("text", ""),
                source_file=r.get("source_file", ""),
                source_directory=r.get("source_directory", ""),
                section=r.get("section"),
                date=r.get("date"),
                tags=r.get("tags") or [],
                chunk_index=r.get("chunk_index", 0),
                total_chunks=r.get("total_chunks", 1),
                # NOTE(review): default 1.0 here vs inf in the filter above is
                # intentional asymmetry — unfiltered missing distances surface
                # as "far but present".
                distance=r.get(self._DISTANCE_FIELD, 1.0),
            )
            results.append(result)

        return results