69 lines
2.3 KiB
Python
69 lines
2.3 KiB
Python
import pytest
|
|
import respx
|
|
from httpx import Response
|
|
|
|
from companion.rag.embedder import OllamaEmbedder
|
|
|
|
|
|
@respx.mock
def test_embed_batch():
    """Two texts with batch_size=2 fit in a single batch.

    Exactly one request must reach the embed endpoint, and both embeddings
    must come back in input order with their full dimensionality.
    """
    route = respx.post("http://localhost:11434/api/embed").mock(
        return_value=Response(200, json={"embeddings": [[0.1] * 1024, [0.2] * 1024]})
    )
    embedder = OllamaEmbedder(
        base_url="http://localhost:11434", model="mxbai-embed-large", batch_size=2
    )

    result = embedder.embed(["hello world", "goodbye world"])

    assert len(result) == 2
    # Check every vector's dimensionality, not only the first one.
    assert all(len(vector) == 1024 for vector in result)
    assert result[0][0] == 0.1
    assert result[1][0] == 0.2
    # Both texts fit in one batch, so exactly one request is expected.
    # `route.called` alone would also pass if each text were sent separately.
    assert route.call_count == 1
|
|
|
|
|
|
@respx.mock
def test_embed_multi_batch():
    """With batch_size=1, each text is sent in its own request.

    The per-request embeddings must be concatenated back in input order.
    """
    first_batch = Response(200, json={"embeddings": [[0.1] * 1024]})
    second_batch = Response(200, json={"embeddings": [[0.2] * 1024]})
    # A list side_effect makes respx answer successive calls in order.
    embed_route = respx.post("http://localhost:11434/api/embed").mock(
        side_effect=[first_batch, second_batch]
    )
    client = OllamaEmbedder(
        base_url="http://localhost:11434",
        model="mxbai-embed-large",
        batch_size=1,
    )

    vectors = client.embed(["hello", "world"])

    assert len(vectors) == 2
    assert vectors[0][0] == 0.1
    assert vectors[1][0] == 0.2
    # One request per single-text batch.
    assert embed_route.call_count == 2
|
|
|
|
|
|
@respx.mock
def test_embed_retry_exhaustion():
    """An endpoint that always answers 500 exhausts the retry budget.

    embed() must then give up with a RuntimeError. With retries=3 the mock
    is expected to be hit exactly three times.
    """
    failing_route = respx.post("http://localhost:11434/api/embed").mock(
        return_value=Response(500, text="Internal Server Error")
    )
    client = OllamaEmbedder(
        base_url="http://localhost:11434",
        model="mxbai-embed-large",
        batch_size=2,
    )
    expected_message = "Failed to embed batch after 3 retries"

    # NOTE(review): backoff is presumably the delay between attempts; a tiny
    # value keeps the test fast. Confirm against OllamaEmbedder.embed.
    with pytest.raises(RuntimeError, match=expected_message):
        client.embed(["hello world"], retries=3, backoff=0.01)

    # Asserted outside the `with` block: code after the raising call inside
    # it would never run.
    assert failing_route.call_count == 3
|
|
|
|
|
|
@respx.mock
def test_embed_count_mismatch():
    """A response with fewer embeddings than input texts must be rejected.

    embed() must raise ValueError, naming both counts in the message.
    """
    texts = ["hello world", "goodbye world"]
    # Only one vector is returned for the two texts sent.
    short_reply = Response(200, json={"embeddings": [[0.1] * 1024]})
    mismatch_route = respx.post("http://localhost:11434/api/embed").mock(
        return_value=short_reply
    )
    client = OllamaEmbedder(
        base_url="http://localhost:11434",
        model="mxbai-embed-large",
        batch_size=2,
    )

    with pytest.raises(ValueError, match="returned 1 embeddings for 2 texts"):
        client.embed(texts)

    assert mismatch_route.called
|