feat: add Ollama embedder with batching and retries

This commit is contained in:
2026-04-13 14:11:48 -04:00
parent 95687fad2e
commit 6e7f347bd3
2 changed files with 65 additions and 0 deletions

View File

@@ -0,0 +1,44 @@
import time
from typing import List
import httpx
class OllamaEmbedder:
    """Compute text embeddings via an Ollama server's ``/api/embed`` endpoint.

    Texts are sent in batches of ``batch_size``, and each batch is retried
    with exponential backoff on transient HTTP failures.
    """

    def __init__(self, base_url: str, model: str, batch_size: int):
        """
        Args:
            base_url: Root URL of the Ollama server; a trailing slash is stripped.
            model: Name of the embedding model to request.
            batch_size: Number of texts sent per HTTP request; must be >= 1.
        """
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.batch_size = batch_size

    def embed(
        self, texts: List[str], retries: int = 3, backoff: float = 1.0
    ) -> List[List[float]]:
        """Embed ``texts`` and return one vector per input text, in order.

        Args:
            texts: Texts to embed. An empty list yields an empty result.
            retries: Attempts per batch before giving up; must be >= 1.
            backoff: Base delay in seconds; the sleep before attempt ``k+1``
                is ``backoff * 2**k``.

        Returns:
            Embedding vectors aligned with ``texts``.

        Raises:
            ValueError: If ``retries`` is less than 1.
            RuntimeError: If a batch still fails after all retries; the last
                underlying HTTP error is attached as ``__cause__``.
        """
        if retries < 1:
            # With retries <= 0 the retry loop would never run, silently
            # dropping every batch and returning a short result; fail loudly.
            raise ValueError(f"retries must be >= 1, got {retries}")
        all_embeddings: List[List[float]] = []
        url = f"{self.base_url}/api/embed"
        # One client for the whole call: reuses the connection pool across
        # batches and attempts instead of opening a fresh client per attempt.
        with httpx.Client(timeout=300.0) as client:
            for i in range(0, len(texts), self.batch_size):
                batch = texts[i : i + self.batch_size]
                last_exception: Exception | None = None
                for attempt in range(retries):
                    try:
                        response = client.post(
                            url,
                            json={"model": self.model, "input": batch},
                        )
                        response.raise_for_status()
                        data = response.json()
                        all_embeddings.extend(data["embeddings"])
                        break
                    except httpx.HTTPError as exc:
                        # Only network/protocol errors are worth retrying; a
                        # malformed response body (e.g. missing "embeddings")
                        # should surface immediately rather than burn retries.
                        last_exception = exc
                        if attempt < retries - 1:
                            time.sleep(backoff * (2**attempt))
                else:
                    # for/else: every attempt raised and we never hit `break`.
                    raise RuntimeError(
                        f"Failed to embed batch after {retries} retries"
                    ) from last_exception
        return all_embeddings

21
tests/test_embedder.py Normal file
View File

@@ -0,0 +1,21 @@
import pytest
import respx
from httpx import Response
from companion.rag.embedder import OllamaEmbedder
@respx.mock
def test_embed_batch():
    """One two-text batch against a mocked endpoint yields two aligned vectors."""
    fake_vectors = [[0.1] * 1024, [0.2] * 1024]
    mocked = respx.post("http://localhost:11434/api/embed").mock(
        return_value=Response(200, json={"embeddings": fake_vectors})
    )
    embedder = OllamaEmbedder(
        base_url="http://localhost:11434", model="mxbai-embed-large", batch_size=2
    )
    vectors = embedder.embed(["hello world", "goodbye world"])
    assert len(vectors) == 2
    assert len(vectors[0]) == 1024
    assert vectors[0][0] == 0.1
    assert vectors[1][0] == 0.2
    assert mocked.called