fix: address embedder review feedback
This commit is contained in:
@@ -5,10 +5,18 @@ import httpx
|
|||||||
|
|
||||||
|
|
||||||
class OllamaEmbedder:
|
class OllamaEmbedder:
|
||||||
def __init__(self, base_url: str, model: str, batch_size: int):
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str,
|
||||||
|
model: str,
|
||||||
|
batch_size: int,
|
||||||
|
timeout: float = 300.0,
|
||||||
|
):
|
||||||
self.base_url = base_url.rstrip("/")
|
self.base_url = base_url.rstrip("/")
|
||||||
self.model = model
|
self.model = model
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
self.timeout = timeout
|
||||||
|
self._client = httpx.Client(timeout=timeout)
|
||||||
|
|
||||||
def embed(
|
def embed(
|
||||||
self, texts: List[str], retries: int = 3, backoff: float = 1.0
|
self, texts: List[str], retries: int = 3, backoff: float = 1.0
|
||||||
@@ -22,23 +30,41 @@ class OllamaEmbedder:
|
|||||||
|
|
||||||
for attempt in range(retries):
|
for attempt in range(retries):
|
||||||
try:
|
try:
|
||||||
with httpx.Client(timeout=300.0) as client:
|
response = self._client.post(
|
||||||
response = client.post(
|
|
||||||
url,
|
url,
|
||||||
json={"model": self.model, "input": batch},
|
json={"model": self.model, "input": batch},
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
embeddings = data["embeddings"]
|
embeddings = data["embeddings"]
|
||||||
|
|
||||||
|
# Validate response count matches batch count
|
||||||
|
if len(embeddings) != len(batch):
|
||||||
|
raise ValueError(
|
||||||
|
f"Ollama returned {len(embeddings)} embeddings for {len(batch)} texts"
|
||||||
|
)
|
||||||
|
|
||||||
all_embeddings.extend(embeddings)
|
all_embeddings.extend(embeddings)
|
||||||
break
|
break
|
||||||
except Exception as exc:
|
except (httpx.HTTPError, httpx.RequestError) as exc:
|
||||||
last_exception = exc
|
last_exception = exc
|
||||||
if attempt < retries - 1:
|
if attempt < retries - 1:
|
||||||
time.sleep(backoff * (2**attempt))
|
time.sleep(backoff * (2**attempt))
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Failed to embed batch after {retries} retries"
|
f"Failed to embed batch after {retries} retries. "
|
||||||
|
f"Last error: {last_exception}"
|
||||||
) from last_exception
|
) from last_exception
|
||||||
|
|
||||||
return all_embeddings
|
return all_embeddings
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the HTTP client."""
|
||||||
|
self._client.close()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.close()
|
||||||
|
return False
|
||||||
|
|||||||
@@ -19,3 +19,50 @@ def test_embed_batch():
|
|||||||
assert result[0][0] == 0.1
|
assert result[0][0] == 0.1
|
||||||
assert result[1][0] == 0.2
|
assert result[1][0] == 0.2
|
||||||
assert route.called
|
assert route.called
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_embed_multi_batch():
|
||||||
|
"""Test that texts are split into multiple batches and results are concatenated."""
|
||||||
|
route = respx.post("http://localhost:11434/api/embed").mock(
|
||||||
|
side_effect=[
|
||||||
|
Response(200, json={"embeddings": [[0.1] * 1024]}),
|
||||||
|
Response(200, json={"embeddings": [[0.2] * 1024]}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
embedder = OllamaEmbedder(
|
||||||
|
base_url="http://localhost:11434", model="mxbai-embed-large", batch_size=1
|
||||||
|
)
|
||||||
|
result = embedder.embed(["hello", "world"])
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result[0][0] == 0.1
|
||||||
|
assert result[1][0] == 0.2
|
||||||
|
assert route.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_embed_retry_exhaustion():
|
||||||
|
"""Test that RuntimeError is raised after all retries fail."""
|
||||||
|
route = respx.post("http://localhost:11434/api/embed").mock(
|
||||||
|
return_value=Response(500, text="Internal Server Error")
|
||||||
|
)
|
||||||
|
embedder = OllamaEmbedder(
|
||||||
|
base_url="http://localhost:11434", model="mxbai-embed-large", batch_size=2
|
||||||
|
)
|
||||||
|
with pytest.raises(RuntimeError, match="Failed to embed batch after 3 retries"):
|
||||||
|
embedder.embed(["hello world"], retries=3, backoff=0.01)
|
||||||
|
assert route.call_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_embed_count_mismatch():
|
||||||
|
"""Test that ValueError is raised when Ollama returns fewer embeddings than texts."""
|
||||||
|
route = respx.post("http://localhost:11434/api/embed").mock(
|
||||||
|
return_value=Response(200, json={"embeddings": [[0.1] * 1024]})
|
||||||
|
)
|
||||||
|
embedder = OllamaEmbedder(
|
||||||
|
base_url="http://localhost:11434", model="mxbai-embed-large", batch_size=2
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="returned 1 embeddings for 2 texts"):
|
||||||
|
embedder.embed(["hello world", "goodbye world"])
|
||||||
|
assert route.called
|
||||||
|
|||||||
Reference in New Issue
Block a user