"""Ollama API client for embedding generation."""

from __future__ import annotations

import time
from typing import TYPE_CHECKING

import httpx

if TYPE_CHECKING:
    from obsidian_rag.config import ObsidianRagConfig

DEFAULT_TIMEOUT = 120.0  # seconds


class EmbeddingError(Exception):
    """Raised when embedding generation fails."""


class OllamaUnavailableError(EmbeddingError):
    """Raised when Ollama is unreachable."""


class OllamaEmbedder:
    """Client that requests text embeddings from an Ollama server.

    Each text is sent as its own POST to ``{base_url}/api/embeddings`` with a
    ``{"model": ..., "prompt": ...}`` payload. The base URL, model name,
    dimensions and batch size are read from ``config.embedding``.

    The instance owns an ``httpx.Client``; call :meth:`close` or use the
    embedder as a context manager to release it.
    """

    def __init__(self, config: ObsidianRagConfig) -> None:
        """Read embedding settings from *config* and create the HTTP client."""
        self.base_url = config.embedding.base_url.rstrip("/")
        self.model = config.embedding.model
        self.dimensions = config.embedding.dimensions
        self.batch_size = config.embedding.batch_size
        self._client = httpx.Client(timeout=DEFAULT_TIMEOUT)

    def __enter__(self) -> OllamaEmbedder:
        """Return the embedder itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the HTTP client when leaving a ``with`` block."""
        self.close()

    def is_available(self) -> bool:
        """Check if Ollama is reachable and has the model.

        Queries ``/api/tags`` with a 5 second timeout and returns True when
        the configured model name is a substring of any listed model name.
        Never raises: any failure is reported as ``False``.
        """
        try:
            resp = self._client.get(f"{self.base_url}/api/tags", timeout=5.0)
            if resp.status_code != 200:
                return False
            models = resp.json().get("models", [])
            return any(self.model in m.get("name", "") for m in models)
        except Exception:
            # Deliberately broad: this is a best-effort health probe, so
            # network errors and malformed payloads all mean "not available".
            return False

    def embed_chunks(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for a list of texts.

        The texts are processed in slices of ``batch_size``.

        Returns:
            One vector per input text, in input order.

        Raises:
            OllamaUnavailableError: If the server cannot be reached.
            EmbeddingError: On any other request or response failure.
        """
        if not texts:
            return []

        # Guard against a zero/negative batch size: range() would raise on 0
        # and yield nothing on a negative step, silently dropping every text.
        step = max(1, self.batch_size)

        all_vectors: list[list[float]] = []
        for start in range(0, len(texts), step):
            all_vectors.extend(self._embed_batch(texts[start : start + step]))
        return all_vectors

    def embed_single(self, text: str) -> list[float]:
        """Generate embedding for a single text.

        Raises:
            OllamaUnavailableError: If the server cannot be reached.
            EmbeddingError: On any other request or response failure.
        """
        [vec] = self._embed_batch([text])
        return vec

    def _embed_batch(self, batch: list[str]) -> list[list[float]]:
        """Embed every text in *batch*, one request per text.

        Returns one vector per text, in order; an empty batch yields ``[]``.
        Raises EmbeddingError on failure.
        """
        return [self._embed_one(text) for text in batch]

    def _embed_one(self, text: str) -> list[float]:
        """POST a single text to ``/api/embeddings`` and return its vector.

        Raises:
            OllamaUnavailableError: If the connection cannot be established.
            EmbeddingError: On timeout, any other transport error, a non-200
                status, a non-JSON body, or a body without an embedding.
        """
        try:
            resp = self._client.post(
                f"{self.base_url}/api/embeddings",
                json={"model": self.model, "prompt": text},
                timeout=DEFAULT_TIMEOUT,
            )
        except httpx.ConnectError as e:
            raise OllamaUnavailableError(
                f"Cannot connect to Ollama at {self.base_url}"
            ) from e
        except httpx.TimeoutException as e:
            raise EmbeddingError(
                f"Embedding request timed out after {DEFAULT_TIMEOUT}s"
            ) from e
        except httpx.HTTPError as e:
            # Any remaining transport failure (read/write/protocol errors).
            raise EmbeddingError(f"Embedding request to Ollama failed: {e}") from e

        if resp.status_code != 200:
            raise EmbeddingError(f"Ollama returned {resp.status_code}: {resp.text}")

        try:
            data = resp.json()
        except ValueError as e:
            raise EmbeddingError("Ollama returned a non-JSON response") from e

        return self._extract_embedding(data)

    @staticmethod
    def _extract_embedding(data: object) -> list[float]:
        """Pull the vector out of a decoded Ollama response body.

        Prefers the ``"embedding"`` key and falls back to the first entry of
        ``"embeddings"`` when the former is missing or empty. An explicitly
        empty ``"embedding"`` list is returned as-is; a body that carries no
        vector at all raises EmbeddingError.
        """
        if not isinstance(data, dict):
            raise EmbeddingError("Unexpected Ollama response: expected a JSON object")

        vector = data.get("embedding")
        if not vector:
            batched = data.get("embeddings")
            if isinstance(batched, list) and batched:
                vector = batched[0]

        if not isinstance(vector, list):
            raise EmbeddingError("Ollama response did not contain an embedding")
        return vector

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._client.close()