diff --git a/python/obsidian_rag/embedder.py b/python/obsidian_rag/embedder.py index 358bb95..c408fc4 100644 --- a/python/obsidian_rag/embedder.py +++ b/python/obsidian_rag/embedder.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from obsidian_rag.config import ObsidianRagConfig DEFAULT_TIMEOUT = 120.0 # seconds +MAX_CHUNK_CHARS = 8000 # safe default for most Ollama models class EmbeddingError(Exception): @@ -42,9 +43,9 @@ class OllamaEmbedder: """Validate that embedding service is local when local_only is True.""" if not self.local_only: return - + parsed = urllib.parse.urlparse(self.base_url) - if parsed.hostname not in ['localhost', '127.0.0.1', '::1']: + if parsed.hostname not in ["localhost", "127.0.0.1", "::1"]: raise SecurityError( f"Remote embedding service not allowed when local_only=True: {self.base_url}" ) @@ -84,23 +85,31 @@ class OllamaEmbedder: # For batch, call /api/embeddings multiple times sequentially if len(batch) == 1: endpoint = f"{self.base_url}/api/embeddings" - payload = {"model": self.model, "prompt": batch[0]} + prompt = batch[0][:MAX_CHUNK_CHARS] + payload = {"model": self.model, "prompt": prompt} else: # For batch, use /api/embeddings with "input" (multiple calls) results = [] for text in batch: + truncated = text[:MAX_CHUNK_CHARS] try: resp = self._client.post( f"{self.base_url}/api/embeddings", - json={"model": self.model, "prompt": text}, + json={"model": self.model, "prompt": truncated}, timeout=DEFAULT_TIMEOUT, ) except httpx.ConnectError as e: - raise OllamaUnavailableError(f"Cannot connect to Ollama at {self.base_url}") from e + raise OllamaUnavailableError( + f"Cannot connect to Ollama at {self.base_url}" + ) from e except httpx.TimeoutException as e: - raise EmbeddingError(f"Embedding request timed out after {DEFAULT_TIMEOUT}s") from e + raise EmbeddingError( + f"Embedding request timed out after {DEFAULT_TIMEOUT}s" + ) from e if resp.status_code != 200: - raise EmbeddingError(f"Ollama returned {resp.status_code}: {resp.text}") + raise EmbeddingError( + f"Ollama returned {resp.status_code}: {resp.text}" + ) data = resp.json() embedding = data.get("embedding", []) if not embedding: @@ -111,9 +120,13 @@ class OllamaEmbedder: try: resp = self._client.post(endpoint, json=payload, timeout=DEFAULT_TIMEOUT) except httpx.ConnectError as e: - raise OllamaUnavailableError(f"Cannot connect to Ollama at {self.base_url}") from e + raise OllamaUnavailableError( + f"Cannot connect to Ollama at {self.base_url}" + ) from e except httpx.TimeoutException as e: - raise EmbeddingError(f"Embedding request timed out after {DEFAULT_TIMEOUT}s") from e + raise EmbeddingError( + f"Embedding request timed out after {DEFAULT_TIMEOUT}s" + ) from e if resp.status_code != 200: raise EmbeddingError(f"Ollama returned {resp.status_code}: {resp.text}")