feat: add Ollama embedder with batching and retries
This commit is contained in:
44
src/companion/rag/embedder.py
Normal file
44
src/companion/rag/embedder.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import time
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaEmbedder:
    """Embed texts through an Ollama server's ``/api/embed`` endpoint.

    Texts are sent in batches of ``batch_size``. Each batch is attempted
    several times with exponential backoff before the call is abandoned.
    """

    def __init__(self, base_url: str, model: str, batch_size: int):
        """Store the connection settings.

        Args:
            base_url: Root URL of the Ollama server. Trailing slashes are
                stripped so the endpoint path can be appended directly.
            model: Model name sent with every embed request.
            batch_size: Maximum number of texts sent per request.
        """
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.batch_size = batch_size

    def embed(
        self, texts: List[str], retries: int = 3, backoff: float = 1.0
    ) -> List[List[float]]:
        """Return one embedding per input text, in input order.

        Args:
            texts: Texts to embed. An empty list returns ``[]`` without
                contacting the server.
            retries: Total number of attempts per batch. Values below 1 are
                treated as 1 so that no batch is ever silently skipped.
            backoff: Base delay in seconds; the wait before the next attempt
                is ``backoff * 2**attempt``.

        Returns:
            The embeddings of all batches concatenated in order.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            RuntimeError: If a batch still fails on its final attempt; the
                underlying error is chained as ``__cause__``.
        """
        # A negative step would make range() empty and silently yield no
        # embeddings; a zero step would raise an unhelpful range() error.
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if not texts:
            return []

        # Always try at least once, otherwise batches would be dropped
        # without any error when retries <= 0.
        attempts = max(1, retries)
        url = f"{self.base_url}/api/embed"
        all_embeddings: List[List[float]] = []

        # One client for the whole call so connections are reused across
        # batches and attempts.
        with httpx.Client(timeout=300.0) as client:
            for start in range(0, len(texts), self.batch_size):
                batch = texts[start : start + self.batch_size]

                for attempt in range(attempts):
                    try:
                        response = client.post(
                            url,
                            json={"model": self.model, "input": batch},
                        )
                        response.raise_for_status()
                        embeddings = response.json()["embeddings"]
                        # A count mismatch would misalign vectors and texts
                        # for every later batch, so treat it as a failure.
                        if len(embeddings) != len(batch):
                            raise ValueError(
                                f"expected {len(batch)} embeddings, "
                                f"got {len(embeddings)}"
                            )
                    except Exception as exc:
                        # Kept broad on purpose: callers rely on every
                        # failure surfacing as RuntimeError.
                        if attempt == attempts - 1:
                            raise RuntimeError(
                                f"Failed to embed batch after {attempts} retries"
                            ) from exc
                        time.sleep(backoff * (2**attempt))
                    else:
                        all_embeddings.extend(embeddings)
                        break

        return all_embeddings
|
||||||
21
tests/test_embedder.py
Normal file
21
tests/test_embedder.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import pytest
|
||||||
|
import respx
|
||||||
|
from httpx import Response
|
||||||
|
|
||||||
|
from companion.rag.embedder import OllamaEmbedder
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
def test_embed_batch():
    """A two-text batch comes back as the two vectors served by the mocked route."""
    mocked_payload = {"embeddings": [[0.1] * 1024, [0.2] * 1024]}
    embed_route = respx.post("http://localhost:11434/api/embed").mock(
        return_value=Response(200, json=mocked_payload)
    )

    vectors = OllamaEmbedder(
        base_url="http://localhost:11434",
        model="mxbai-embed-large",
        batch_size=2,
    ).embed(["hello world", "goodbye world"])

    assert len(vectors) == 2
    first_vector, second_vector = vectors
    assert len(first_vector) == 1024
    assert first_vector[0] == 0.1
    assert second_vector[0] == 0.2
    # The mocked endpoint must actually have been hit.
    assert embed_route.called
|
||||||
Reference in New Issue
Block a user