"""Tests for ``OllamaEmbedder`` against a mocked Ollama ``/api/embed`` endpoint."""

import pytest
import respx
from httpx import Response

from companion.rag.embedder import OllamaEmbedder

# Shared fixture values; every test talks to the same mocked local server.
BASE_URL = "http://localhost:11434"
EMBED_URL = f"{BASE_URL}/api/embed"
MODEL = "mxbai-embed-large"
DIM = 1024  # length of the fake embedding vectors returned by the mock


def _make_embedder(batch_size: int) -> OllamaEmbedder:
    """Build an embedder pointed at the mocked local Ollama server."""
    return OllamaEmbedder(base_url=BASE_URL, model=MODEL, batch_size=batch_size)


@respx.mock
def test_embed_batch():
    """Test that a single batch returns one vector per input text, in order."""
    route = respx.post(EMBED_URL).mock(
        return_value=Response(200, json={"embeddings": [[0.1] * DIM, [0.2] * DIM]})
    )
    embedder = _make_embedder(batch_size=2)

    result = embedder.embed(["hello world", "goodbye world"])

    assert len(result) == 2
    assert len(result[0]) == DIM
    assert result[0][0] == 0.1
    assert result[1][0] == 0.2
    assert route.called


@respx.mock
def test_embed_multi_batch():
    """Test that texts are split into multiple batches and results are concatenated."""
    # One mocked response per request, consumed in order.
    route = respx.post(EMBED_URL).mock(
        side_effect=[
            Response(200, json={"embeddings": [[0.1] * DIM]}),
            Response(200, json={"embeddings": [[0.2] * DIM]}),
        ]
    )
    embedder = _make_embedder(batch_size=1)

    result = embedder.embed(["hello", "world"])

    assert len(result) == 2
    assert result[0][0] == 0.1
    assert result[1][0] == 0.2
    assert route.call_count == 2


@respx.mock
def test_embed_retry_exhaustion():
    """Test that RuntimeError is raised after all retries fail."""
    route = respx.post(EMBED_URL).mock(
        return_value=Response(500, text="Internal Server Error")
    )
    embedder = _make_embedder(batch_size=2)

    # Small backoff keeps the retry loop from slowing the test down
    # (assumes `backoff` is a delay between attempts — confirm in embedder).
    with pytest.raises(RuntimeError, match="Failed to embed batch after 3 retries"):
        embedder.embed(["hello world"], retries=3, backoff=0.01)

    assert route.call_count == 3


@respx.mock
def test_embed_count_mismatch():
    """Test that ValueError is raised when Ollama returns fewer embeddings than texts."""
    route = respx.post(EMBED_URL).mock(
        return_value=Response(200, json={"embeddings": [[0.1] * DIM]})
    )
    embedder = _make_embedder(batch_size=2)

    with pytest.raises(ValueError, match="returned 1 embeddings for 2 texts"):
        embedder.embed(["hello world", "goodbye world"])

    assert route.called