129 lines
4.7 KiB
Python
129 lines
4.7 KiB
Python
"""Ollama API client for embedding generation."""
|
|
|
|
from __future__ import annotations

import ipaddress
import time
import urllib.parse
from typing import TYPE_CHECKING

import httpx

if TYPE_CHECKING:
    from obsidian_rag.config import ObsidianRagConfig
|
|
|
|
# Timeout, in seconds, used for the Ollama HTTP client and for each
# embedding request it sends.
DEFAULT_TIMEOUT: float = 120.0
|
|
|
|
|
|
class EmbeddingError(Exception):
    """Base error for any failure while producing embedding vectors."""
|
|
|
|
|
|
class OllamaUnavailableError(EmbeddingError):
    """Signals that no connection to the Ollama server could be made."""
|
|
|
|
|
|
class SecurityError(Exception):
    """Signals that a security check rejected the current configuration."""
|
|
|
|
|
|
class OllamaEmbedder:
    """Client that generates embeddings through an Ollama server.

    Requests go to Ollama's ``/api/embeddings`` endpoint, which takes one
    ``prompt`` per call, so every text costs one HTTP round trip. The model,
    expected vector length, slice size and network policy are read from the
    project config (originally set up for mxbai-embed-large, 1024-dim).

    The instance owns an ``httpx.Client``. Call :meth:`close`, or use the
    embedder as a context manager, to release its connections.
    """

    def __init__(self, config: "ObsidianRagConfig"):
        """Read embedding settings from *config* and open the HTTP client.

        Raises:
            SecurityError: if ``security.local_only`` is set and
                ``embedding.base_url`` does not point at a loopback host.
        """
        self.base_url = config.embedding.base_url.rstrip("/")
        self.model = config.embedding.model
        self.dimensions = config.embedding.dimensions
        self.batch_size = config.embedding.batch_size
        self.local_only = config.security.local_only
        # Validate BEFORE creating the client. If validation raised after the
        # client existed, the half-built object would be discarded with its
        # connection pool never closed.
        self._validate_network_isolation()
        self._client = httpx.Client(timeout=DEFAULT_TIMEOUT)

    def __enter__(self):
        """Return ``self`` so the embedder can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the HTTP client when leaving a ``with`` block."""
        self.close()

    def _validate_network_isolation(self):
        """Enforce ``local_only`` by requiring a loopback embedding host.

        Does nothing when ``local_only`` is false. The accepted hosts are
        ``localhost`` and any loopback IP literal (``127.0.0.0/8``, ``::1``).

        Raises:
            SecurityError: if ``local_only`` is true and ``base_url`` has no
                parseable host or its host is not loopback.
        """
        if not self.local_only:
            return

        # urlparse lower-cases the hostname and strips IPv6 brackets.
        host = urllib.parse.urlparse(self.base_url).hostname
        if not self._is_loopback_host(host):
            raise SecurityError(
                f"Remote embedding service not allowed when local_only=True: {self.base_url}"
            )

    @staticmethod
    def _is_loopback_host(host):
        """Return True if *host* is ``localhost`` or a loopback IP literal."""
        if not host:
            return False
        if host == "localhost":
            return True
        try:
            return ipaddress.ip_address(host).is_loopback
        except ValueError:
            # Not an IP literal: a DNS name other than localhost is not trusted.
            return False

    def is_available(self) -> bool:
        """Return True if Ollama answers ``/api/tags`` and lists ``self.model``.

        This is a best-effort probe. Any failure (connection error, non-200
        status, unexpected payload) yields False rather than raising.
        """
        try:
            resp = self._client.get(f"{self.base_url}/api/tags", timeout=5.0)
            if resp.status_code != 200:
                return False
            models = resp.json().get("models", [])
            return any(self._model_matches(m.get("name", "")) for m in models)
        except Exception:
            # Deliberately broad: an availability check must never raise.
            return False

    def _model_matches(self, name: str) -> bool:
        """Return True if the installed model *name* satisfies ``self.model``.

        A tagged config name (``"m:335m"``) must match exactly. An untagged
        one (``"m"``) matches any tag of that same model (``"m:latest"``). It
        does not match a different model that merely contains the string
        (``"m-v2:latest"``).
        """
        if name == self.model:
            return True
        if ":" in self.model:
            return False
        return name.rsplit(":", 1)[0] == self.model

    def embed_chunks(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* in order and return one vector per text.

        Texts are processed in slices of ``batch_size``. A non-positive
        ``batch_size`` is treated as 1 instead of crashing ``range``.

        Raises:
            OllamaUnavailableError: if the server cannot be reached.
            EmbeddingError: on timeouts, HTTP/transport errors, or malformed
                responses.
        """
        if not texts:
            return []

        step = max(1, self.batch_size)
        all_vectors: list[list[float]] = []
        for start in range(0, len(texts), step):
            all_vectors.extend(self._embed_batch(texts[start : start + step]))
        return all_vectors

    def embed_single(self, text: str) -> list[float]:
        """Embed a single text and return its vector (same errors as ``embed_chunks``)."""
        [vec] = self._embed_batch([text])
        return vec

    def _embed_batch(self, batch: list[str]) -> list[list[float]]:
        """Embed each text in *batch*, preserving order.

        ``/api/embeddings`` accepts a single ``prompt``, so this issues one
        request per text, sequentially.
        """
        return [self._embed_one(text) for text in batch]

    def _embed_one(self, text: str) -> list[float]:
        """POST one text to ``/api/embeddings`` and return its vector.

        Raises:
            OllamaUnavailableError: if the connection cannot be established.
            EmbeddingError: on timeout, any other httpx transport error, a
                non-200 status, a non-JSON body, or a missing/mis-sized vector.
        """
        try:
            resp = self._client.post(
                f"{self.base_url}/api/embeddings",
                json={"model": self.model, "prompt": text},
                timeout=DEFAULT_TIMEOUT,
            )
        except httpx.ConnectError as e:
            raise OllamaUnavailableError(f"Cannot connect to Ollama at {self.base_url}") from e
        except httpx.TimeoutException as e:
            raise EmbeddingError(f"Embedding request timed out after {DEFAULT_TIMEOUT}s") from e
        except httpx.HTTPError as e:
            # Any other transport failure (read error, protocol error, ...).
            raise EmbeddingError(f"Embedding request to Ollama failed: {e}") from e

        if resp.status_code != 200:
            raise EmbeddingError(f"Ollama returned {resp.status_code}: {resp.text}")

        try:
            data = resp.json()
        except ValueError as e:
            raise EmbeddingError("Ollama returned a non-JSON embedding response") from e

        return self._extract_embedding(data)

    def _extract_embedding(self, data) -> list[float]:
        """Return the vector from a decoded embedding response.

        ``{"embedding": [...]}`` is read first. ``{"embeddings": [[...]]}`` is
        accepted as a fallback, using its first vector.

        Raises:
            EmbeddingError: if *data* is not a JSON object or carries no
                non-empty vector. Also raised when ``self.dimensions`` is set
                and the vector length differs from it.
        """
        if not isinstance(data, dict):
            raise EmbeddingError(
                f"Unexpected Ollama response payload: {type(data).__name__}"
            )

        embedding = data.get("embedding")
        if not embedding:
            embeddings = data.get("embeddings")
            embedding = embeddings[0] if embeddings else None

        if not embedding:
            # Previously an empty list leaked out here and was stored as a vector.
            raise EmbeddingError("Ollama response did not contain an embedding")

        if self.dimensions and len(embedding) != self.dimensions:
            raise EmbeddingError(
                f"Expected {self.dimensions}-dim embedding from {self.model}, "
                f"got {len(embedding)}"
            )
        return embedding

    def close(self):
        """Close the underlying HTTP client and release its connections."""
        self._client.close()
|