From 6e7f347bd30d8e7b1911e497463b8b344c897534 Mon Sep 17 00:00:00 2001
From: Santhosh Janardhanan
Date: Mon, 13 Apr 2026 14:11:48 -0400
Subject: [PATCH] feat: add Ollama embedder with batching and retries

---
 src/companion/rag/embedder.py | 51 +++++++++++++++++++++++++++++++++++
 tests/test_embedder.py        | 20 ++++++++++++++
 2 files changed, 71 insertions(+)
 create mode 100644 src/companion/rag/embedder.py
 create mode 100644 tests/test_embedder.py

diff --git a/src/companion/rag/embedder.py b/src/companion/rag/embedder.py
new file mode 100644
index 0000000..d156176
--- /dev/null
+++ b/src/companion/rag/embedder.py
@@ -0,0 +1,51 @@
+import time
+from typing import List
+
+import httpx
+
+
+class OllamaEmbedder:
+    """Client for the Ollama ``/api/embed`` endpoint with batching and retries."""
+
+    def __init__(self, base_url: str, model: str, batch_size: int):
+        self.base_url = base_url.rstrip("/")
+        self.model = model
+        self.batch_size = batch_size
+
+    def embed(
+        self, texts: List[str], retries: int = 3, backoff: float = 1.0
+    ) -> List[List[float]]:
+        """Embed *texts* in batches; retry each batch with exponential backoff.
+
+        Raises ``ValueError`` if ``retries`` < 1, and ``RuntimeError`` (chained
+        to the last underlying error) when a batch fails every attempt.
+        """
+        if retries < 1:
+            # Guard: retries=0 would otherwise skip every batch silently.
+            raise ValueError("retries must be >= 1")
+
+        all_embeddings: List[List[float]] = []
+        url = f"{self.base_url}/api/embed"
+
+        # One client for the whole call — reuses connections across batches
+        # and attempts instead of building a new client per request.
+        with httpx.Client(timeout=300.0) as client:
+            for i in range(0, len(texts), self.batch_size):
+                batch = texts[i : i + self.batch_size]
+                last_exception: Exception | None = None
+
+                for attempt in range(retries):
+                    try:
+                        response = client.post(
+                            url,
+                            json={"model": self.model, "input": batch},
+                        )
+                        response.raise_for_status()
+                        all_embeddings.extend(response.json()["embeddings"])
+                        break
+                    except (httpx.HTTPError, KeyError, ValueError) as exc:
+                        # Transport error, bad status, malformed JSON, or
+                        # missing "embeddings" key — all worth retrying.
+                        last_exception = exc
+                        if attempt < retries - 1:
+                            time.sleep(backoff * (2**attempt))
+                else:
+                    # Every attempt for this batch failed.
+                    raise RuntimeError(
+                        f"Failed to embed batch after {retries} retries"
+                    ) from last_exception
+
+        return all_embeddings
diff --git a/tests/test_embedder.py b/tests/test_embedder.py
new file mode 100644
index 0000000..9c61e53
--- /dev/null
+++ b/tests/test_embedder.py
@@ -0,0 +1,20 @@
+import respx
+from httpx import Response
+
+from companion.rag.embedder import OllamaEmbedder
+
+
+@respx.mock
+def test_embed_batch():
+    route = respx.post("http://localhost:11434/api/embed").mock(
+        return_value=Response(200, json={"embeddings": [[0.1] * 1024, [0.2] * 1024]})
+    )
+    embedder = OllamaEmbedder(
+        base_url="http://localhost:11434", model="mxbai-embed-large", batch_size=2
+    )
+    result = embedder.embed(["hello world", "goodbye world"])
+    assert len(result) == 2
+    assert len(result[0]) == 1024
+    assert result[0][0] == 0.1
+    assert result[1][0] == 0.2
+    assert route.called