fix: address embedder review feedback

This commit is contained in:
2026-04-13 14:14:00 -04:00
parent 6e7f347bd3
commit 0948d2dcb7
2 changed files with 85 additions and 12 deletions

View File

@@ -5,10 +5,18 @@ import httpx
class OllamaEmbedder: class OllamaEmbedder:
def __init__(self, base_url: str, model: str, batch_size: int): def __init__(
self,
base_url: str,
model: str,
batch_size: int,
timeout: float = 300.0,
):
self.base_url = base_url.rstrip("/") self.base_url = base_url.rstrip("/")
self.model = model self.model = model
self.batch_size = batch_size self.batch_size = batch_size
self.timeout = timeout
self._client = httpx.Client(timeout=timeout)
def embed( def embed(
self, texts: List[str], retries: int = 3, backoff: float = 1.0 self, texts: List[str], retries: int = 3, backoff: float = 1.0
@@ -22,23 +30,41 @@ class OllamaEmbedder:
for attempt in range(retries): for attempt in range(retries):
try: try:
with httpx.Client(timeout=300.0) as client: response = self._client.post(
response = client.post( url,
url, json={"model": self.model, "input": batch},
json={"model": self.model, "input": batch}, )
response.raise_for_status()
data = response.json()
embeddings = data["embeddings"]
# Validate response count matches batch count
if len(embeddings) != len(batch):
raise ValueError(
f"Ollama returned {len(embeddings)} embeddings for {len(batch)} texts"
) )
response.raise_for_status()
data = response.json() all_embeddings.extend(embeddings)
embeddings = data["embeddings"] break
all_embeddings.extend(embeddings) except (httpx.HTTPError, httpx.RequestError) as exc:
break
except Exception as exc:
last_exception = exc last_exception = exc
if attempt < retries - 1: if attempt < retries - 1:
time.sleep(backoff * (2**attempt)) time.sleep(backoff * (2**attempt))
else: else:
raise RuntimeError( raise RuntimeError(
f"Failed to embed batch after {retries} retries" f"Failed to embed batch after {retries} retries. "
f"Last error: {last_exception}"
) from last_exception ) from last_exception
return all_embeddings return all_embeddings
def close(self) -> None:
"""Close the HTTP client."""
self._client.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
return False

View File

@@ -19,3 +19,50 @@ def test_embed_batch():
assert result[0][0] == 0.1 assert result[0][0] == 0.1
assert result[1][0] == 0.2 assert result[1][0] == 0.2
assert route.called assert route.called
@respx.mock
def test_embed_multi_batch():
"""Test that texts are split into multiple batches and results are concatenated."""
route = respx.post("http://localhost:11434/api/embed").mock(
side_effect=[
Response(200, json={"embeddings": [[0.1] * 1024]}),
Response(200, json={"embeddings": [[0.2] * 1024]}),
]
)
embedder = OllamaEmbedder(
base_url="http://localhost:11434", model="mxbai-embed-large", batch_size=1
)
result = embedder.embed(["hello", "world"])
assert len(result) == 2
assert result[0][0] == 0.1
assert result[1][0] == 0.2
assert route.call_count == 2
@respx.mock
def test_embed_retry_exhaustion():
"""Test that RuntimeError is raised after all retries fail."""
route = respx.post("http://localhost:11434/api/embed").mock(
return_value=Response(500, text="Internal Server Error")
)
embedder = OllamaEmbedder(
base_url="http://localhost:11434", model="mxbai-embed-large", batch_size=2
)
with pytest.raises(RuntimeError, match="Failed to embed batch after 3 retries"):
embedder.embed(["hello world"], retries=3, backoff=0.01)
assert route.call_count == 3
@respx.mock
def test_embed_count_mismatch():
"""Test that ValueError is raised when Ollama returns fewer embeddings than texts."""
route = respx.post("http://localhost:11434/api/embed").mock(
return_value=Response(200, json={"embeddings": [[0.1] * 1024]})
)
embedder = OllamaEmbedder(
base_url="http://localhost:11434", model="mxbai-embed-large", batch_size=2
)
with pytest.raises(ValueError, match="returned 1 embeddings for 2 texts"):
embedder.embed(["hello world", "goodbye world"])
assert route.called