diff --git a/src/companion/rag/embedder.py b/src/companion/rag/embedder.py
index d156176..00c3954 100644
--- a/src/companion/rag/embedder.py
+++ b/src/companion/rag/embedder.py
@@ -5,10 +5,18 @@ import httpx
 
 
 class OllamaEmbedder:
-    def __init__(self, base_url: str, model: str, batch_size: int):
+    def __init__(
+        self,
+        base_url: str,
+        model: str,
+        batch_size: int,
+        timeout: float = 300.0,
+    ):
         self.base_url = base_url.rstrip("/")
         self.model = model
         self.batch_size = batch_size
+        self.timeout = timeout
+        self._client = httpx.Client(timeout=timeout)
 
     def embed(
         self, texts: List[str], retries: int = 3, backoff: float = 1.0
@@ -22,23 +30,41 @@ class OllamaEmbedder:
 
         for attempt in range(retries):
             try:
-                with httpx.Client(timeout=300.0) as client:
-                    response = client.post(
-                        url,
-                        json={"model": self.model, "input": batch},
+                response = self._client.post(
+                    url,
+                    json={"model": self.model, "input": batch},
+                )
+                response.raise_for_status()
+                data = response.json()
+                embeddings = data["embeddings"]
+
+                # Validate response count matches batch count
+                if len(embeddings) != len(batch):
+                    raise ValueError(
+                        f"Ollama returned {len(embeddings)} embeddings for {len(batch)} texts"
                     )
-                    response.raise_for_status()
-                    data = response.json()
-                    embeddings = data["embeddings"]
-                    all_embeddings.extend(embeddings)
-                    break
-            except Exception as exc:
+
+                all_embeddings.extend(embeddings)
+                break
+            except (httpx.HTTPError, httpx.RequestError) as exc:
                 last_exception = exc
                 if attempt < retries - 1:
                     time.sleep(backoff * (2**attempt))
                 else:
                     raise RuntimeError(
-                        f"Failed to embed batch after {retries} retries"
+                        f"Failed to embed batch after {retries} retries. "
+                        f"Last error: {last_exception}"
                     ) from last_exception
 
         return all_embeddings
+
+    def close(self) -> None:
+        """Close the HTTP client."""
+        self._client.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+        return False
diff --git a/tests/test_embedder.py b/tests/test_embedder.py
index 9c61e53..af5e2b8 100644
--- a/tests/test_embedder.py
+++ b/tests/test_embedder.py
@@ -19,3 +19,50 @@ def test_embed_batch():
     assert result[0][0] == 0.1
     assert result[1][0] == 0.2
     assert route.called
+
+
+@respx.mock
+def test_embed_multi_batch():
+    """Test that texts are split into multiple batches and results are concatenated."""
+    route = respx.post("http://localhost:11434/api/embed").mock(
+        side_effect=[
+            Response(200, json={"embeddings": [[0.1] * 1024]}),
+            Response(200, json={"embeddings": [[0.2] * 1024]}),
+        ]
+    )
+    embedder = OllamaEmbedder(
+        base_url="http://localhost:11434", model="mxbai-embed-large", batch_size=1
+    )
+    result = embedder.embed(["hello", "world"])
+    assert len(result) == 2
+    assert result[0][0] == 0.1
+    assert result[1][0] == 0.2
+    assert route.call_count == 2
+
+
+@respx.mock
+def test_embed_retry_exhaustion():
+    """Test that RuntimeError is raised after all retries fail."""
+    route = respx.post("http://localhost:11434/api/embed").mock(
+        return_value=Response(500, text="Internal Server Error")
+    )
+    embedder = OllamaEmbedder(
+        base_url="http://localhost:11434", model="mxbai-embed-large", batch_size=2
+    )
+    with pytest.raises(RuntimeError, match="Failed to embed batch after 3 retries"):
+        embedder.embed(["hello world"], retries=3, backoff=0.01)
+    assert route.call_count == 3
+
+
+@respx.mock
+def test_embed_count_mismatch():
+    """Test that ValueError is raised when Ollama returns fewer embeddings than texts."""
+    route = respx.post("http://localhost:11434/api/embed").mock(
+        return_value=Response(200, json={"embeddings": [[0.1] * 1024]})
+    )
+    embedder = OllamaEmbedder(
+        base_url="http://localhost:11434", model="mxbai-embed-large", batch_size=2
+    )
+    with pytest.raises(ValueError, match="returned 1 embeddings for 2 texts"):
+        embedder.embed(["hello world", "goodbye world"])
+    assert route.called