Files
obsidian-rag/python/obsidian_rag/embedder.py

142 lines
5.1 KiB
Python

"""Ollama API client for embedding generation."""
from __future__ import annotations
import time
import urllib.parse
from typing import TYPE_CHECKING
import httpx
if TYPE_CHECKING:
from obsidian_rag.config import ObsidianRagConfig
DEFAULT_TIMEOUT = 120.0 # seconds
MAX_CHUNK_CHARS = 8000 # safe default for most Ollama models
class EmbeddingError(Exception):
    """Base error for any failure while generating embeddings."""
class OllamaUnavailableError(EmbeddingError):
    """Signals that the Ollama server could not be reached at all."""
class SecurityError(Exception):
    """Signals that a security validation check did not pass."""
class OllamaEmbedder:
    """Client for the Ollama ``/api/embeddings`` endpoint.

    The endpoint accepts a single ``prompt`` per request, so batching here
    means issuing sequential requests; ``batch_size`` only bounds how many
    texts are grouped per internal call. Configured for mxbai-embed-large
    (1024-dim) by default, but model/dimensions come from *config*.
    """

    # Hostnames accepted when config.security.local_only is enabled.
    _LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1")

    def __init__(self, config: "ObsidianRagConfig"):
        """Build a client from *config*.

        Raises:
            SecurityError: if ``local_only`` is set and ``base_url`` is remote
                (validated eagerly so a misconfiguration fails at startup).
        """
        self.base_url = config.embedding.base_url.rstrip("/")
        self.model = config.embedding.model
        self.dimensions = config.embedding.dimensions
        self.batch_size = config.embedding.batch_size
        self.local_only = config.security.local_only
        self._client = httpx.Client(timeout=DEFAULT_TIMEOUT)
        self._validate_network_isolation()

    def _validate_network_isolation(self) -> None:
        """Reject non-local embedding endpoints when ``local_only`` is True."""
        if not self.local_only:
            return
        parsed = urllib.parse.urlparse(self.base_url)
        # parsed.hostname is None for malformed URLs, which also (correctly)
        # fails the whitelist check below.
        if parsed.hostname not in self._LOCAL_HOSTS:
            raise SecurityError(
                f"Remote embedding service not allowed when local_only=True: {self.base_url}"
            )

    def is_available(self) -> bool:
        """Return True if Ollama is reachable and the configured model is installed."""
        try:
            resp = self._client.get(f"{self.base_url}/api/tags", timeout=5.0)
            if resp.status_code != 200:
                return False
            models = resp.json().get("models", [])
            # Substring match tolerates tag suffixes, e.g. "mxbai-embed-large:latest".
            return any(self.model in m.get("name", "") for m in models)
        except Exception:
            # Deliberately broad: this is a best-effort availability probe, and
            # any transport or JSON-parsing failure means "not available".
            return False

    def embed_chunks(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for *texts*, preserving order.

        Returns an empty list for empty input. Raises EmbeddingError /
        OllamaUnavailableError on failure (see ``_embed_one``).
        """
        if not texts:
            return []
        all_vectors: list[list[float]] = []
        for i in range(0, len(texts), self.batch_size):
            all_vectors.extend(self._embed_batch(texts[i : i + self.batch_size]))
        return all_vectors

    def embed_single(self, text: str) -> list[float]:
        """Generate an embedding for a single text."""
        return self._embed_one(text)

    def _embed_batch(self, batch: list[str]) -> list[list[float]]:
        """Embed each text in *batch* with one request apiece.

        Ollama's ``/api/embeddings`` takes one ``prompt`` per request, so a
        batch is simply a sequence of sequential calls.
        """
        return [self._embed_one(text) for text in batch]

    def _embed_one(self, text: str) -> list[float]:
        """Request one embedding vector from Ollama.

        Raises:
            OllamaUnavailableError: if the server cannot be reached.
            EmbeddingError: on timeout or a non-200 response.
        """
        # Truncate to stay within the model's context window.
        payload = {"model": self.model, "prompt": text[:MAX_CHUNK_CHARS]}
        try:
            resp = self._client.post(
                f"{self.base_url}/api/embeddings",
                json=payload,
                timeout=DEFAULT_TIMEOUT,
            )
        except httpx.ConnectError as e:
            raise OllamaUnavailableError(
                f"Cannot connect to Ollama at {self.base_url}"
            ) from e
        except httpx.TimeoutException as e:
            raise EmbeddingError(
                f"Embedding request timed out after {DEFAULT_TIMEOUT}s"
            ) from e
        if resp.status_code != 200:
            raise EmbeddingError(f"Ollama returned {resp.status_code}: {resp.text}")
        data = resp.json()
        # Server versions differ in the response key: "embedding" (singular)
        # vs "embeddings" (list-of-vectors); accept either.
        embedding = data.get("embedding") or data.get("embeddings", [[]])[0]
        return embedding

    def close(self) -> None:
        """Release the underlying HTTP connection pool."""
        self._client.close()

    def __enter__(self) -> "OllamaEmbedder":
        """Support ``with OllamaEmbedder(cfg) as embedder:`` usage."""
        return self

    def __exit__(self, *exc) -> None:
        self.close()