From 1c7d1618d8fc5e6435cd8c0ee3fabf3c970f54e9 Mon Sep 17 00:00:00 2001
From: Ashwin Mathur <97467100+awinml@users.noreply.github.com>
Date: Wed, 3 Apr 2024 20:11:30 +0530
Subject: [PATCH] Add truncate and normalize parameters to TEI Embedders
 (#7460)

---
 .../hugging_face_tei_document_embedder.py     | 24 ++++-
 .../hugging_face_tei_text_embedder.py         | 24 ++++-
 ...uncate-normalize-tei-6c998b14154267bb.yaml |  4 +
 ...test_hugging_face_tei_document_embedder.py | 99 +++++++++++++++++--
 .../test_hugging_face_tei_text_embedder.py    | 71 ++++++++++++-
 5 files changed, 204 insertions(+), 18 deletions(-)
 create mode 100644 releasenotes/notes/add-truncate-normalize-tei-6c998b14154267bb.yaml

diff --git a/haystack/components/embedders/hugging_face_tei_document_embedder.py b/haystack/components/embedders/hugging_face_tei_document_embedder.py
index a9e32f88d..9a9803e45 100644
--- a/haystack/components/embedders/hugging_face_tei_document_embedder.py
+++ b/haystack/components/embedders/hugging_face_tei_document_embedder.py
@@ -1,3 +1,4 @@
+import json
 from typing import Any, Dict, List, Optional
 from urllib.parse import urlparse

@@ -50,6 +51,8 @@ class HuggingFaceTEIDocumentEmbedder:
         token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
         prefix: str = "",
         suffix: str = "",
+        truncate: bool = True,
+        normalize: bool = False,
         batch_size: int = 32,
         progress_bar: bool = True,
         meta_fields_to_embed: Optional[List[str]] = None,
@@ -70,6 +73,15 @@ class HuggingFaceTEIDocumentEmbedder:
             A string to add at the beginning of each text.
         :param suffix:
             A string to add at the end of each text.
+        :param truncate:
+            Truncate input text from the end to the maximum length supported by the model. This option is only available
+            for self-deployed Text Embedding Inference (TEI) endpoints and paid HF Inference Endpoints deployed with
+            TEI. It will be ignored when used with free HF Inference endpoints or paid HF Inference endpoints deployed
+            without TEI.
+        :param normalize:
+            Normalize the embeddings to unit length. This option is only available for self-deployed Text Embedding
+            Inference (TEI) endpoints and paid HF Inference Endpoints deployed with TEI. It will be ignored when used
+            with free HF Inference endpoints or paid HF Inference endpoints deployed without TEI.
         :param batch_size:
             Number of Documents to encode at once.
         :param progress_bar:
@@ -95,6 +107,8 @@ class HuggingFaceTEIDocumentEmbedder:
         self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
         self.prefix = prefix
         self.suffix = suffix
+        self.truncate = truncate
+        self.normalize = normalize
         self.batch_size = batch_size
         self.progress_bar = progress_bar
         self.meta_fields_to_embed = meta_fields_to_embed or []
@@ -113,6 +127,8 @@ class HuggingFaceTEIDocumentEmbedder:
             url=self.url,
             prefix=self.prefix,
             suffix=self.suffix,
+            truncate=self.truncate,
+            normalize=self.normalize,
             batch_size=self.batch_size,
             progress_bar=self.progress_bar,
             meta_fields_to_embed=self.meta_fields_to_embed,
@@ -167,8 +183,12 @@ class HuggingFaceTEIDocumentEmbedder:
             range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
         ):
             batch = texts_to_embed[i : i + batch_size]
-            embeddings = self.client.feature_extraction(text=batch)
-            all_embeddings.extend(embeddings.tolist())
+            response = self.client.post(
+                json={"inputs": batch, "truncate": self.truncate, "normalize": self.normalize},
+                task="feature-extraction",
+            )
+            embeddings = json.loads(response.decode())
+            all_embeddings.extend(embeddings)

         return all_embeddings

diff --git a/haystack/components/embedders/hugging_face_tei_text_embedder.py b/haystack/components/embedders/hugging_face_tei_text_embedder.py
index 28757ded8..f618214e3 100644
--- a/haystack/components/embedders/hugging_face_tei_text_embedder.py
+++ b/haystack/components/embedders/hugging_face_tei_text_embedder.py
@@ -1,3 +1,4 @@
+import json
 from typing import Any, Dict, List, Optional
 from urllib.parse import urlparse

@@ -45,6 +46,8 @@ class HuggingFaceTEITextEmbedder:
         token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
         prefix: str = "",
         suffix: str = "",
+        truncate: bool = True,
+        normalize: bool = False,
     ):
         """
         Create an HuggingFaceTEITextEmbedder component.
@@ -61,6 +64,15 @@ class HuggingFaceTEITextEmbedder:
             A string to add at the beginning of each text.
         :param suffix:
             A string to add at the end of each text.
+        :param truncate:
+            Truncate input text from the end to the maximum length supported by the model. This option is only available
+            for self-deployed Text Embedding Inference (TEI) endpoints and paid HF Inference Endpoints deployed with
+            TEI. It will be ignored when used with free HF Inference endpoints or paid HF Inference endpoints deployed
+            without TEI.
+        :param normalize:
+            Normalize the embeddings to unit length. This option is only available for self-deployed Text Embedding
+            Inference (TEI) endpoints and paid HF Inference Endpoints deployed with TEI. It will be ignored when used
+            with free HF Inference endpoints or paid HF Inference endpoints deployed without TEI.
""" huggingface_hub_import.check() @@ -78,6 +90,8 @@ class HuggingFaceTEITextEmbedder: self.client = InferenceClient(url or model, token=token.resolve_value() if token else None) self.prefix = prefix self.suffix = suffix + self.truncate = truncate + self.normalize = normalize def to_dict(self) -> Dict[str, Any]: """ @@ -93,6 +107,8 @@ class HuggingFaceTEITextEmbedder: prefix=self.prefix, suffix=self.suffix, token=self.token.to_dict() if self.token else None, + truncate=self.truncate, + normalize=self.normalize, ) @classmethod @@ -135,8 +151,10 @@ class HuggingFaceTEITextEmbedder: text_to_embed = self.prefix + text + self.suffix - embeddings = self.client.feature_extraction(text=[text_to_embed]) - # The client returns a numpy array - embedding = embeddings.tolist()[0] + response = self.client.post( + json={"inputs": [text_to_embed], "truncate": self.truncate, "normalize": self.normalize}, + task="feature-extraction", + ) + embedding = json.loads(response.decode())[0] return {"embedding": embedding} diff --git a/releasenotes/notes/add-truncate-normalize-tei-6c998b14154267bb.yaml b/releasenotes/notes/add-truncate-normalize-tei-6c998b14154267bb.yaml new file mode 100644 index 000000000..a76a4b61b --- /dev/null +++ b/releasenotes/notes/add-truncate-normalize-tei-6c998b14154267bb.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Adds `truncate` and `normalize` parameters to `HuggingFaceTEITextEmbedder` and `HuggingFaceTEITextEmbedder` for allowing truncation and normalization of embeddings. diff --git a/test/components/embedders/test_hugging_face_tei_document_embedder.py b/test/components/embedders/test_hugging_face_tei_document_embedder.py index e4b75615e..06d9673cd 100644 --- a/test/components/embedders/test_hugging_face_tei_document_embedder.py +++ b/test/components/embedders/test_hugging_face_tei_document_embedder.py @@ -18,8 +18,8 @@ def mock_check_valid_model(): yield mock -def mock_embedding_generation(text, **kwargs): - response = np.array([np.random.rand(384) for i in range(len(text))]) +def mock_embedding_generation(json, **kwargs): + response = str(np.array([np.random.rand(384) for i in range(len(json["inputs"]))]).tolist()).encode() return response @@ -33,6 +33,8 @@ class TestHuggingFaceTEIDocumentEmbedder: assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False) assert embedder.prefix == "" assert embedder.suffix == "" + assert embedder.truncate is True + assert embedder.normalize is False assert embedder.batch_size == 32 assert embedder.progress_bar is True assert embedder.meta_fields_to_embed == [] @@ -45,6 +47,8 @@ class TestHuggingFaceTEIDocumentEmbedder: token=Secret.from_token("fake-api-token"), prefix="prefix", suffix="suffix", + truncate=False, + normalize=True, batch_size=64, progress_bar=False, meta_fields_to_embed=["test_field"], @@ -56,6 +60,8 @@ class TestHuggingFaceTEIDocumentEmbedder: assert embedder.token == Secret.from_token("fake-api-token") assert embedder.prefix == "prefix" assert embedder.suffix == "suffix" + assert embedder.truncate is False + assert embedder.normalize is True assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.meta_fields_to_embed == ["test_field"] @@ -83,6 +89,8 @@ class TestHuggingFaceTEIDocumentEmbedder: "url": None, "prefix": "", "suffix": "", + "truncate": True, + "normalize": False, "batch_size": 32, "progress_bar": True, "meta_fields_to_embed": [], @@ -90,6 +98,38 @@ class TestHuggingFaceTEIDocumentEmbedder: }, } + def test_from_dict(self, mock_check_valid_model): + data = { + "type": 
"haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder", + "init_parameters": { + "model": "BAAI/bge-small-en-v1.5", + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, + "url": None, + "prefix": "", + "suffix": "", + "truncate": True, + "normalize": False, + "batch_size": 32, + "progress_bar": True, + "meta_fields_to_embed": [], + "embedding_separator": "\n", + }, + } + + embedder = HuggingFaceTEIDocumentEmbedder.from_dict(data) + + assert embedder.model == "BAAI/bge-small-en-v1.5" + assert embedder.url is None + assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False) + assert embedder.prefix == "" + assert embedder.suffix == "" + assert embedder.truncate is True + assert embedder.normalize is False + assert embedder.batch_size == 32 + assert embedder.progress_bar is True + assert embedder.meta_fields_to_embed == [] + assert embedder.embedding_separator == "\n" + def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model): component = HuggingFaceTEIDocumentEmbedder( model="sentence-transformers/all-mpnet-base-v2", @@ -97,6 +137,8 @@ class TestHuggingFaceTEIDocumentEmbedder: token=Secret.from_env_var("ENV_VAR", strict=False), prefix="prefix", suffix="suffix", + truncate=False, + normalize=True, batch_size=64, progress_bar=False, meta_fields_to_embed=["test_field"], @@ -113,6 +155,8 @@ class TestHuggingFaceTEIDocumentEmbedder: "url": "https://some_embedding_model.com", "prefix": "prefix", "suffix": "suffix", + "truncate": False, + "normalize": True, "batch_size": 64, "progress_bar": False, "meta_fields_to_embed": ["test_field"], @@ -120,6 +164,38 @@ class TestHuggingFaceTEIDocumentEmbedder: }, } + def test_from_dict_with_custom_init_parameters(self, mock_check_valid_model): + data = { + "type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder", + "init_parameters": { + "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, + "model": "sentence-transformers/all-mpnet-base-v2", + "url": "https://some_embedding_model.com", + "prefix": "prefix", + "suffix": "suffix", + "truncate": False, + "normalize": True, + "batch_size": 64, + "progress_bar": False, + "meta_fields_to_embed": ["test_field"], + "embedding_separator": " | ", + }, + } + + embedder = HuggingFaceTEIDocumentEmbedder.from_dict(data) + + assert embedder.model == "sentence-transformers/all-mpnet-base-v2" + assert embedder.url == "https://some_embedding_model.com" + assert embedder.token == Secret.from_env_var("ENV_VAR", strict=False) + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" + assert embedder.truncate is False + assert embedder.normalize is True + assert embedder.batch_size == 64 + assert embedder.progress_bar is False + assert embedder.meta_fields_to_embed == ["test_field"] + assert embedder.embedding_separator == " | " + def test_prepare_texts_to_embed_w_metadata(self, mock_check_valid_model): documents = [ Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5) @@ -167,7 +243,7 @@ class TestHuggingFaceTEIDocumentEmbedder: def test_embed_batch(self, mock_check_valid_model): texts = ["text 1", "text 2", "text 3", "text 4", "text 5"] - with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: + with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch: mock_embedding_patch.side_effect = mock_embedding_generation embedder = 
             embedder = HuggingFaceTEIDocumentEmbedder(
@@ -192,7 +268,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
             Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
         ]

-        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
             mock_embedding_patch.side_effect = mock_embedding_generation

             embedder = HuggingFaceTEIDocumentEmbedder(
@@ -207,10 +283,15 @@ class TestHuggingFaceTEIDocumentEmbedder:
             result = embedder.run(documents=docs)

         mock_embedding_patch.assert_called_once_with(
-            text=[
-                "prefix Cuisine | I love cheese suffix",
-                "prefix ML | A transformer is a deep learning architecture suffix",
-            ]
+            json={
+                "inputs": [
+                    "prefix Cuisine | I love cheese suffix",
+                    "prefix ML | A transformer is a deep learning architecture suffix",
+                ],
+                "truncate": True,
+                "normalize": False,
+            },
+            task="feature-extraction",
         )

         documents_with_embeddings = result["documents"]
@@ -251,7 +332,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
             Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
         ]

-        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
             mock_embedding_patch.side_effect = mock_embedding_generation

             embedder = HuggingFaceTEIDocumentEmbedder(
diff --git a/test/components/embedders/test_hugging_face_tei_text_embedder.py b/test/components/embedders/test_hugging_face_tei_text_embedder.py
index 5efed23d4..1e50cd6f2 100644
--- a/test/components/embedders/test_hugging_face_tei_text_embedder.py
+++ b/test/components/embedders/test_hugging_face_tei_text_embedder.py
@@ -16,8 +16,8 @@ def mock_check_valid_model():
     yield mock


-def mock_embedding_generation(text, **kwargs):
-    response = np.array([np.random.rand(384) for i in range(len(text))])
+def mock_embedding_generation(json, **kwargs):
+    response = str(np.array([np.random.rand(384) for i in range(len(json["inputs"]))]).tolist()).encode()
     return response


@@ -31,6 +31,8 @@ class TestHuggingFaceTEITextEmbedder:
         assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
         assert embedder.prefix == ""
         assert embedder.suffix == ""
+        assert embedder.truncate is True
+        assert embedder.normalize is False

     def test_init_with_parameters(self, mock_check_valid_model):
         embedder = HuggingFaceTEITextEmbedder(
@@ -39,6 +41,8 @@ class TestHuggingFaceTEITextEmbedder:
             token=Secret.from_token("fake-api-token"),
             prefix="prefix",
             suffix="suffix",
+            truncate=False,
+            normalize=True,
         )

         assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
@@ -46,6 +50,8 @@ class TestHuggingFaceTEITextEmbedder:
         assert embedder.token == Secret.from_token("fake-api-token")
         assert embedder.prefix == "prefix"
         assert embedder.suffix == "suffix"
+        assert embedder.truncate is False
+        assert embedder.normalize is True

     def test_initialize_with_invalid_url(self, mock_check_valid_model):
         with pytest.raises(ValueError):
@@ -69,9 +75,35 @@ class TestHuggingFaceTEITextEmbedder:
                 "url": None,
                 "prefix": "",
                 "suffix": "",
+                "truncate": True,
+                "normalize": False,
             },
         }

+    def test_from_dict(self, mock_check_valid_model):
+        data = {
+            "type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
+            "init_parameters": {
+                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
+                "model": "BAAI/bge-small-en-v1.5",
+                "url": None,
+                "prefix": "",
+                "suffix": "",
"truncate": True, + "normalize": False, + }, + } + + embedder = HuggingFaceTEITextEmbedder.from_dict(data) + + assert embedder.model == "BAAI/bge-small-en-v1.5" + assert embedder.url is None + assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False) + assert embedder.prefix == "" + assert embedder.suffix == "" + assert embedder.truncate is True + assert embedder.normalize is False + def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model): component = HuggingFaceTEITextEmbedder( model="sentence-transformers/all-mpnet-base-v2", @@ -79,6 +111,8 @@ class TestHuggingFaceTEITextEmbedder: token=Secret.from_env_var("ENV_VAR", strict=False), prefix="prefix", suffix="suffix", + truncate=False, + normalize=True, ) data = component.to_dict() @@ -91,11 +125,37 @@ class TestHuggingFaceTEITextEmbedder: "url": "https://some_embedding_model.com", "prefix": "prefix", "suffix": "suffix", + "truncate": False, + "normalize": True, }, } + def test_from_dict_with_custom_init_parameters(self, mock_check_valid_model): + data = { + "type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder", + "init_parameters": { + "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, + "model": "sentence-transformers/all-mpnet-base-v2", + "url": "https://some_embedding_model.com", + "prefix": "prefix", + "suffix": "suffix", + "truncate": False, + "normalize": True, + }, + } + + embedder = HuggingFaceTEITextEmbedder.from_dict(data) + + assert embedder.model == "sentence-transformers/all-mpnet-base-v2" + assert embedder.url == "https://some_embedding_model.com" + assert embedder.token == Secret.from_env_var("ENV_VAR", strict=False) + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" + assert embedder.truncate is False + assert embedder.normalize is True + def test_run(self, mock_check_valid_model): - with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch: + with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch: mock_embedding_patch.side_effect = mock_embedding_generation embedder = HuggingFaceTEITextEmbedder( @@ -107,7 +167,10 @@ class TestHuggingFaceTEITextEmbedder: result = embedder.run(text="The food was delicious") - mock_embedding_patch.assert_called_once_with(text=["prefix The food was delicious suffix"]) + mock_embedding_patch.assert_called_once_with( + json={"inputs": ["prefix The food was delicious suffix"], "truncate": True, "normalize": False}, + task="feature-extraction", + ) assert len(result["embedding"]) == 384 assert all(isinstance(x, float) for x in result["embedding"])