mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-25 05:58:57 +00:00
fix: HuggingFaceTEITextEmbedder returning embedding of incorrect shape when used with Docker endpoint (#7319)
* Fix HuggingFaceTEITextEmbedder
* Update haystack/components/embedders/hugging_face_tei_text_embedder.py
  Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
* Improve imports; Add additional tests
---------
Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
This commit is contained in:
parent
95837ab6b5
commit
8d7a58347d
@ -9,7 +9,7 @@ from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
from haystack.utils.hf import HFModelType, check_valid_model
|
||||
|
||||
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
|
||||
with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import:
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -79,7 +79,7 @@ class HuggingFaceTEIDocumentEmbedder:
|
||||
:param embedding_separator:
|
||||
Separator used to concatenate the meta fields to the Document text.
|
||||
"""
|
||||
transformers_import.check()
|
||||
huggingface_hub_import.check()
|
||||
|
||||
if url:
|
||||
r = urlparse(url)
|
||||
|
||||
@ -6,7 +6,7 @@ from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
from haystack.utils.hf import HFModelType, check_valid_model
|
||||
|
||||
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
|
||||
with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import:
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -62,7 +62,7 @@ class HuggingFaceTEITextEmbedder:
|
||||
:param suffix:
|
||||
A string to add at the end of each text.
|
||||
"""
|
||||
transformers_import.check()
|
||||
huggingface_hub_import.check()
|
||||
|
||||
if url:
|
||||
r = urlparse(url)
|
||||
@ -135,8 +135,8 @@ class HuggingFaceTEITextEmbedder:
|
||||
|
||||
text_to_embed = self.prefix + text + self.suffix
|
||||
|
||||
embedding = self.client.feature_extraction(text=text_to_embed)
|
||||
embeddings = self.client.feature_extraction(text=[text_to_embed])
|
||||
# The client returns a numpy array
|
||||
embedding = embedding.tolist()
|
||||
embedding = embeddings.tolist()[0]
|
||||
|
||||
return {"embedding": embedding}
|
||||
|
||||
5
releasenotes/notes/hf-tei-bug-fix-07732c672600aadd.yaml
Normal file
5
releasenotes/notes/hf-tei-bug-fix-07732c672600aadd.yaml
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
fixes:
|
||||
- |
|
||||
Fixes `HuggingFaceTEITextEmbedder` returning an embedding of incorrect shape when used with a
|
||||
Text-Embedding-Inference endpoint deployed using Docker.
|
||||
@ -3,10 +3,10 @@ from unittest.mock import MagicMock, patch
|
||||
import numpy as np
|
||||
import pytest
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder
|
||||
from haystack.dataclasses import Document
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -222,6 +222,29 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
assert len(doc.embedding) == 384
|
||||
assert all(isinstance(x, float) for x in doc.embedding)
|
||||
|
||||
@pytest.mark.flaky(reruns=5, reruns_delay=5)
|
||||
@pytest.mark.integration
|
||||
def test_run_inference_api_endpoint(self):
|
||||
docs = [
|
||||
Document(content="I love cheese", meta={"topic": "Cuisine"}),
|
||||
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
|
||||
]
|
||||
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2", meta_fields_to_embed=["topic"], embedding_separator=" | "
|
||||
)
|
||||
|
||||
result = embedder.run(documents=docs)
|
||||
documents_with_embeddings = result["documents"]
|
||||
|
||||
assert isinstance(documents_with_embeddings, list)
|
||||
assert len(documents_with_embeddings) == len(docs)
|
||||
for doc in documents_with_embeddings:
|
||||
assert isinstance(doc, Document)
|
||||
assert isinstance(doc.embedding, list)
|
||||
assert len(doc.embedding) == 384
|
||||
assert all(isinstance(x, float) for x in doc.embedding)
|
||||
|
||||
def test_run_custom_batch_size(self, mock_check_valid_model):
|
||||
docs = [
|
||||
Document(content="I love cheese", meta={"topic": "Cuisine"}),
|
||||
|
||||
@ -3,9 +3,9 @@ from unittest.mock import MagicMock, patch
|
||||
import numpy as np
|
||||
import pytest
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
from haystack.components.embedders.hugging_face_tei_text_embedder import HuggingFaceTEITextEmbedder
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -17,7 +17,7 @@ def mock_check_valid_model():
|
||||
|
||||
|
||||
def mock_embedding_generation(text, **kwargs):
|
||||
response = np.random.rand(384)
|
||||
response = np.array([np.random.rand(384) for i in range(len(text))])
|
||||
return response
|
||||
|
||||
|
||||
@ -107,7 +107,16 @@ class TestHuggingFaceTEITextEmbedder:
|
||||
|
||||
result = embedder.run(text="The food was delicious")
|
||||
|
||||
mock_embedding_patch.assert_called_once_with(text="prefix The food was delicious suffix")
|
||||
mock_embedding_patch.assert_called_once_with(text=["prefix The food was delicious suffix"])
|
||||
|
||||
assert len(result["embedding"]) == 384
|
||||
assert all(isinstance(x, float) for x in result["embedding"])
|
||||
|
||||
@pytest.mark.flaky(reruns=5, reruns_delay=5)
|
||||
@pytest.mark.integration
|
||||
def test_run_inference_api_endpoint(self):
|
||||
embedder = HuggingFaceTEITextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
|
||||
result = embedder.run(text="The food was delicious")
|
||||
|
||||
assert len(result["embedding"]) == 384
|
||||
assert all(isinstance(x, float) for x in result["embedding"])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user