diff --git a/haystack/components/embedders/hugging_face_tei_document_embedder.py b/haystack/components/embedders/hugging_face_tei_document_embedder.py index 73e7e25b3..629246ee7 100644 --- a/haystack/components/embedders/hugging_face_tei_document_embedder.py +++ b/haystack/components/embedders/hugging_face_tei_document_embedder.py @@ -9,7 +9,7 @@ from haystack.lazy_imports import LazyImport from haystack.utils import Secret, deserialize_secrets_inplace from haystack.utils.hf import HFModelType, check_valid_model -with LazyImport(message="Run 'pip install transformers'") as transformers_import: +with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import: from huggingface_hub import InferenceClient logger = logging.getLogger(__name__) @@ -79,7 +79,7 @@ class HuggingFaceTEIDocumentEmbedder: :param embedding_separator: Separator used to concatenate the meta fields to the Document text. """ - transformers_import.check() + huggingface_hub_import.check() if url: r = urlparse(url) diff --git a/haystack/components/embedders/hugging_face_tei_text_embedder.py b/haystack/components/embedders/hugging_face_tei_text_embedder.py index 0c773929b..e8c258607 100644 --- a/haystack/components/embedders/hugging_face_tei_text_embedder.py +++ b/haystack/components/embedders/hugging_face_tei_text_embedder.py @@ -6,7 +6,7 @@ from haystack.lazy_imports import LazyImport from haystack.utils import Secret, deserialize_secrets_inplace from haystack.utils.hf import HFModelType, check_valid_model -with LazyImport(message="Run 'pip install transformers'") as transformers_import: +with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import: from huggingface_hub import InferenceClient logger = logging.getLogger(__name__) @@ -62,7 +62,7 @@ class HuggingFaceTEITextEmbedder: :param suffix: A string to add at the end of each text. """ - transformers_import.check() + huggingface_hub_import.check() if url: r = urlparse(url) @@ -135,8 +135,8 @@ class HuggingFaceTEITextEmbedder: text_to_embed = self.prefix + text + self.suffix - embedding = self.client.feature_extraction(text=text_to_embed) + embeddings = self.client.feature_extraction(text=[text_to_embed]) # The client returns a numpy array - embedding = embedding.tolist() + embedding = embeddings.tolist()[0] return {"embedding": embedding} diff --git a/releasenotes/notes/hf-tei-bug-fix-07732c672600aadd.yaml b/releasenotes/notes/hf-tei-bug-fix-07732c672600aadd.yaml new file mode 100644 index 000000000..7fd5fd178 --- /dev/null +++ b/releasenotes/notes/hf-tei-bug-fix-07732c672600aadd.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + Fixes `HuggingFaceTEITextEmbedder` returning an embedding of incorrect shape when used with a + Text-Embedding-Inference endpoint deployed using Docker. diff --git a/test/components/embedders/test_hugging_face_tei_document_embedder.py b/test/components/embedders/test_hugging_face_tei_document_embedder.py index 98d884a67..e4b75615e 100644 --- a/test/components/embedders/test_hugging_face_tei_document_embedder.py +++ b/test/components/embedders/test_hugging_face_tei_document_embedder.py @@ -3,10 +3,10 @@ from unittest.mock import MagicMock, patch import numpy as np import pytest from huggingface_hub.utils import RepositoryNotFoundError -from haystack.utils.auth import Secret from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder from haystack.dataclasses import Document +from haystack.utils.auth import Secret @pytest.fixture @@ -222,6 +222,29 @@ class TestHuggingFaceTEIDocumentEmbedder: assert len(doc.embedding) == 384 assert all(isinstance(x, float) for x in doc.embedding) + @pytest.mark.flaky(reruns=5, reruns_delay=5) + @pytest.mark.integration + def test_run_inference_api_endpoint(self): + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + embedder = HuggingFaceTEIDocumentEmbedder( + model="sentence-transformers/all-MiniLM-L6-v2", meta_fields_to_embed=["topic"], embedding_separator=" | " + ) + + result = embedder.run(documents=docs) + documents_with_embeddings = result["documents"] + + assert isinstance(documents_with_embeddings, list) + assert len(documents_with_embeddings) == len(docs) + for doc in documents_with_embeddings: + assert isinstance(doc, Document) + assert isinstance(doc.embedding, list) + assert len(doc.embedding) == 384 + assert all(isinstance(x, float) for x in doc.embedding) + def test_run_custom_batch_size(self, mock_check_valid_model): docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), diff --git a/test/components/embedders/test_hugging_face_tei_text_embedder.py b/test/components/embedders/test_hugging_face_tei_text_embedder.py index 569b3eea4..5efed23d4 100644 --- a/test/components/embedders/test_hugging_face_tei_text_embedder.py +++ b/test/components/embedders/test_hugging_face_tei_text_embedder.py @@ -3,9 +3,9 @@ from unittest.mock import MagicMock, patch import numpy as np import pytest from huggingface_hub.utils import RepositoryNotFoundError -from haystack.utils.auth import Secret from haystack.components.embedders.hugging_face_tei_text_embedder import HuggingFaceTEITextEmbedder +from haystack.utils.auth import Secret @pytest.fixture @@ -17,7 +17,7 @@ def mock_check_valid_model(): def mock_embedding_generation(text, **kwargs): - response = np.random.rand(384) + response = np.array([np.random.rand(384) for i in range(len(text))]) return response @@ -107,7 +107,16 @@ class TestHuggingFaceTEITextEmbedder: result = embedder.run(text="The food was delicious") - mock_embedding_patch.assert_called_once_with(text="prefix The food was delicious suffix") + mock_embedding_patch.assert_called_once_with(text=["prefix The food was delicious suffix"]) + + assert len(result["embedding"]) == 384 + assert all(isinstance(x, float) for x in result["embedding"]) + + @pytest.mark.flaky(reruns=5, reruns_delay=5) + @pytest.mark.integration + def test_run_inference_api_endpoint(self): + embedder = HuggingFaceTEITextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2") + result = embedder.run(text="The food was delicious") assert len(result["embedding"]) == 384 assert all(isinstance(x, float) for x in result["embedding"])