fix: HuggingFaceTEITextEmbedder returning embedding of incorrect shape when used with Docker endpoint (#7319)

* Fix HuggingFaceTEITextEmbedder

* Update haystack/components/embedders/hugging_face_tei_text_embedder.py

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>

* Improve imports; Add additional tests

---------

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
This commit is contained in:
Ashwin Mathur 2024-03-07 20:53:57 +05:30 committed by GitHub
parent 95837ab6b5
commit 8d7a58347d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 47 additions and 10 deletions

View File

@ -9,7 +9,7 @@ from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import HFModelType, check_valid_model
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import:
from huggingface_hub import InferenceClient
logger = logging.getLogger(__name__)
@ -79,7 +79,7 @@ class HuggingFaceTEIDocumentEmbedder:
:param embedding_separator:
Separator used to concatenate the meta fields to the Document text.
"""
transformers_import.check()
huggingface_hub_import.check()
if url:
r = urlparse(url)

View File

@ -6,7 +6,7 @@ from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import HFModelType, check_valid_model
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import:
from huggingface_hub import InferenceClient
logger = logging.getLogger(__name__)
@ -62,7 +62,7 @@ class HuggingFaceTEITextEmbedder:
:param suffix:
A string to add at the end of each text.
"""
transformers_import.check()
huggingface_hub_import.check()
if url:
r = urlparse(url)
@ -135,8 +135,8 @@ class HuggingFaceTEITextEmbedder:
text_to_embed = self.prefix + text + self.suffix
embedding = self.client.feature_extraction(text=text_to_embed)
embeddings = self.client.feature_extraction(text=[text_to_embed])
# The client returns a numpy array
embedding = embedding.tolist()
embedding = embeddings.tolist()[0]
return {"embedding": embedding}

View File

@ -0,0 +1,5 @@
---
fixes:
- |
Fixes `HuggingFaceTEITextEmbedder` returning an embedding of incorrect shape when used with a
Text-Embedding-Inference endpoint deployed using Docker.

View File

@ -3,10 +3,10 @@ from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from huggingface_hub.utils import RepositoryNotFoundError
from haystack.utils.auth import Secret
from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder
from haystack.dataclasses import Document
from haystack.utils.auth import Secret
@pytest.fixture
@ -222,6 +222,29 @@ class TestHuggingFaceTEIDocumentEmbedder:
assert len(doc.embedding) == 384
assert all(isinstance(x, float) for x in doc.embedding)
@pytest.mark.flaky(reruns=5, reruns_delay=5)
@pytest.mark.integration
def test_run_inference_api_endpoint(self):
    """
    Embed two documents with an embedder configured by model name only (no ``url``).

    Verifies that ``run`` returns one ``Document`` per input and that every
    embedding is a plain list of 384 floats.
    """
    expected_dimension = 384
    docs = [
        Document(content="I love cheese", meta={"topic": "Cuisine"}),
        Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
    ]
    embedder = HuggingFaceTEIDocumentEmbedder(
        model="sentence-transformers/all-MiniLM-L6-v2",
        meta_fields_to_embed=["topic"],
        embedding_separator=" | ",
    )

    embedded_docs = embedder.run(documents=docs)["documents"]

    # The output must be a list with exactly one entry per input document.
    assert isinstance(embedded_docs, list)
    assert len(embedded_docs) == len(docs)

    for embedded in embedded_docs:
        assert isinstance(embedded, Document)
        assert isinstance(embedded.embedding, list)
        assert len(embedded.embedding) == expected_dimension
        assert all(isinstance(value, float) for value in embedded.embedding)
def test_run_custom_batch_size(self, mock_check_valid_model):
docs = [
Document(content="I love cheese", meta={"topic": "Cuisine"}),

View File

@ -3,9 +3,9 @@ from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from huggingface_hub.utils import RepositoryNotFoundError
from haystack.utils.auth import Secret
from haystack.components.embedders.hugging_face_tei_text_embedder import HuggingFaceTEITextEmbedder
from haystack.utils.auth import Secret
@pytest.fixture
@ -17,7 +17,7 @@ def mock_check_valid_model():
def mock_embedding_generation(text, **kwargs):
    """
    Stand-in for the client's feature-extraction call used by these tests.

    :param text:
        The batch of inputs to "embed"; only its length is used.
    :param kwargs:
        Accepted and ignored so the mock tolerates any extra keyword arguments.
    :returns:
        A random float array of shape ``(len(text), 384)``, one row per input.
    """
    # Build the whole batch in one call: this avoids the previous dead assignment
    # and the unused loop index, and keeps the result 2-D even for an empty batch.
    return np.random.rand(len(text), 384)
@ -107,7 +107,16 @@ class TestHuggingFaceTEITextEmbedder:
result = embedder.run(text="The food was delicious")
mock_embedding_patch.assert_called_once_with(text="prefix The food was delicious suffix")
mock_embedding_patch.assert_called_once_with(text=["prefix The food was delicious suffix"])
assert len(result["embedding"]) == 384
assert all(isinstance(x, float) for x in result["embedding"])
@pytest.mark.flaky(reruns=5, reruns_delay=5)
@pytest.mark.integration
def test_run_inference_api_endpoint(self):
    """
    Embed a single text with an embedder configured by model name only (no ``url``).

    Verifies that the returned ``"embedding"`` is a flat sequence of 384 floats.
    """
    text_embedder = HuggingFaceTEITextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")

    output = text_embedder.run(text="The food was delicious")
    vector = output["embedding"]

    assert len(vector) == 384
    assert all(isinstance(component, float) for component in vector)