Add truncate and normalize parameters to TEI Embedders (#7460)

Author: Ashwin Mathur, 2024-04-03 20:11:30 +05:30 (committed by GitHub)
parent 1ce12c7a6a
commit 1c7d1618d8
5 changed files with 204 additions and 18 deletions

File: haystack/components/embedders/hugging_face_tei_document_embedder.py

@@ -1,3 +1,4 @@
+import json
 from typing import Any, Dict, List, Optional
 from urllib.parse import urlparse
@@ -50,6 +51,8 @@ class HuggingFaceTEIDocumentEmbedder:
         token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
         prefix: str = "",
         suffix: str = "",
+        truncate: bool = True,
+        normalize: bool = False,
         batch_size: int = 32,
         progress_bar: bool = True,
         meta_fields_to_embed: Optional[List[str]] = None,
@@ -70,6 +73,15 @@ class HuggingFaceTEIDocumentEmbedder:
             A string to add at the beginning of each text.
         :param suffix:
             A string to add at the end of each text.
+        :param truncate:
+            Truncate input text from the end to the maximum length supported by the model. This option is only available
+            for self-deployed Text Embedding Inference (TEI) endpoints and paid HF Inference Endpoints deployed with
+            TEI. It will be ignored when used with free HF Inference endpoints or paid HF Inference endpoints deployed
+            without TEI.
+        :param normalize:
+            Normalize the embeddings to unit length. This option is only available for self-deployed Text Embedding
+            Inference (TEI) endpoints and paid HF Inference Endpoints deployed with TEI. It will be ignored when used
+            with free HF Inference endpoints or paid HF Inference endpoints deployed without TEI.
         :param batch_size:
             Number of Documents to encode at once.
         :param progress_bar:
@@ -95,6 +107,8 @@ class HuggingFaceTEIDocumentEmbedder:
         self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
         self.prefix = prefix
         self.suffix = suffix
+        self.truncate = truncate
+        self.normalize = normalize
         self.batch_size = batch_size
         self.progress_bar = progress_bar
         self.meta_fields_to_embed = meta_fields_to_embed or []
@@ -113,6 +127,8 @@ class HuggingFaceTEIDocumentEmbedder:
             url=self.url,
             prefix=self.prefix,
             suffix=self.suffix,
+            truncate=self.truncate,
+            normalize=self.normalize,
             batch_size=self.batch_size,
             progress_bar=self.progress_bar,
             meta_fields_to_embed=self.meta_fields_to_embed,
@@ -167,8 +183,12 @@ class HuggingFaceTEIDocumentEmbedder:
             range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
         ):
             batch = texts_to_embed[i : i + batch_size]
-            embeddings = self.client.feature_extraction(text=batch)
-            all_embeddings.extend(embeddings.tolist())
+            response = self.client.post(
+                json={"inputs": batch, "truncate": self.truncate, "normalize": self.normalize},
+                task="feature-extraction",
+            )
+            embeddings = json.loads(response.decode())
+            all_embeddings.extend(embeddings)

         return all_embeddings
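
For orientation, here is a minimal usage sketch of the document embedder with the two new flags. It is not part of the commit; the localhost URL is a placeholder for any self-deployed TEI endpoint, and the import path is inferred from the module path that appears in the tests below.

from haystack import Document
from haystack.components.embedders import HuggingFaceTEIDocumentEmbedder

embedder = HuggingFaceTEIDocumentEmbedder(
    model="BAAI/bge-small-en-v1.5",
    url="http://localhost:8080",  # placeholder: any self-deployed TEI endpoint
    truncate=True,   # clip over-long inputs instead of failing on them
    normalize=True,  # return unit-length vectors (handy for cosine similarity)
)

docs = [Document(content="I love cheese"), Document(content="Transformers are a deep learning architecture")]
result = embedder.run(documents=docs)
print(len(result["documents"][0].embedding))  # embedding dimension, e.g. 384 for bge-small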

File: haystack/components/embedders/hugging_face_tei_text_embedder.py

@@ -1,3 +1,4 @@
+import json
 from typing import Any, Dict, List, Optional
 from urllib.parse import urlparse
@@ -45,6 +46,8 @@ class HuggingFaceTEITextEmbedder:
         token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
         prefix: str = "",
         suffix: str = "",
+        truncate: bool = True,
+        normalize: bool = False,
     ):
         """
         Create an HuggingFaceTEITextEmbedder component.
@@ -61,6 +64,15 @@ class HuggingFaceTEITextEmbedder:
             A string to add at the beginning of each text.
         :param suffix:
             A string to add at the end of each text.
+        :param truncate:
+            Truncate input text from the end to the maximum length supported by the model. This option is only available
+            for self-deployed Text Embedding Inference (TEI) endpoints and paid HF Inference Endpoints deployed with
+            TEI. It will be ignored when used with free HF Inference endpoints or paid HF Inference endpoints deployed
+            without TEI.
+        :param normalize:
+            Normalize the embeddings to unit length. This option is only available for self-deployed Text Embedding
+            Inference (TEI) endpoints and paid HF Inference Endpoints deployed with TEI. It will be ignored when used
+            with free HF Inference endpoints or paid HF Inference endpoints deployed without TEI.
         """

         huggingface_hub_import.check()
@@ -78,6 +90,8 @@ class HuggingFaceTEITextEmbedder:
         self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
         self.prefix = prefix
         self.suffix = suffix
+        self.truncate = truncate
+        self.normalize = normalize

     def to_dict(self) -> Dict[str, Any]:
         """
@@ -93,6 +107,8 @@ class HuggingFaceTEITextEmbedder:
             prefix=self.prefix,
             suffix=self.suffix,
             token=self.token.to_dict() if self.token else None,
+            truncate=self.truncate,
+            normalize=self.normalize,
         )

     @classmethod
@@ -135,8 +151,10 @@ class HuggingFaceTEITextEmbedder:
         text_to_embed = self.prefix + text + self.suffix

-        embeddings = self.client.feature_extraction(text=[text_to_embed])
-        # The client returns a numpy array
-        embedding = embeddings.tolist()[0]
+        response = self.client.post(
+            json={"inputs": [text_to_embed], "truncate": self.truncate, "normalize": self.normalize},
+            task="feature-extraction",
+        )
+        embedding = json.loads(response.decode())[0]

         return {"embedding": embedding}
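
The text embedder mirrors the same flags. A corresponding sketch, again not part of the commit and under the same placeholder-URL assumption:

from haystack.components.embedders import HuggingFaceTEITextEmbedder

embedder = HuggingFaceTEITextEmbedder(
    model="BAAI/bge-small-en-v1.5",
    url="http://localhost:8080",  # placeholder self-deployed TEI endpoint
    truncate=True,
    normalize=False,
)
result = embedder.run(text="The food was delicious")
embedding = result["embedding"]  # a flat list of floats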

File: release note (new file; exact path not shown in this view)

@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Adds `truncate` and `normalize` parameters to `HuggingFaceTEITextEmbedder` and `HuggingFaceTEIDocumentEmbedder` to allow truncating input text and normalizing the returned embeddings.
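
To make the wire format concrete: switching from feature_extraction to InferenceClient.post means the embedders now send the TEI request body directly. Below is a sketch with plain requests against a hypothetical self-hosted TEI server; TEI's /embed route accepts the same fields this commit sends.

import requests

payload = {
    "inputs": ["prefix The food was delicious suffix"],
    "truncate": True,    # drop tokens past the model's maximum input length
    "normalize": False,  # keep raw, unnormalized embeddings
}
response = requests.post("http://localhost:8080/embed", json=payload, timeout=30)
embeddings = response.json()  # one embedding vector per input string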

File: test_hugging_face_tei_document_embedder.py

@@ -18,8 +18,8 @@ def mock_check_valid_model():
     yield mock


-def mock_embedding_generation(text, **kwargs):
-    response = np.array([np.random.rand(384) for i in range(len(text))])
+def mock_embedding_generation(json, **kwargs):
+    response = str(np.array([np.random.rand(384) for i in range(len(json["inputs"]))]).tolist()).encode()
     return response
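
The reworked mock stands in for InferenceClient.post, which returns raw bytes rather than a numpy array. A quick self-contained check, an illustration rather than a test from this commit, that the fake byte payload survives the json.loads(response.decode()) step used in the embedders:

import json
import numpy as np

# str() of a nested list of Python floats is valid JSON, so encoding it to
# bytes gives exactly the shape of response the embedder expects to parse.
fake = str(np.array([np.random.rand(3) for _ in range(2)]).tolist()).encode()
vectors = json.loads(fake.decode())
assert len(vectors) == 2 and len(vectors[0]) == 3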
@@ -33,6 +33,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
         assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
         assert embedder.prefix == ""
         assert embedder.suffix == ""
+        assert embedder.truncate is True
+        assert embedder.normalize is False
         assert embedder.batch_size == 32
         assert embedder.progress_bar is True
         assert embedder.meta_fields_to_embed == []
@@ -45,6 +47,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
             token=Secret.from_token("fake-api-token"),
             prefix="prefix",
             suffix="suffix",
+            truncate=False,
+            normalize=True,
             batch_size=64,
             progress_bar=False,
             meta_fields_to_embed=["test_field"],
@@ -56,6 +60,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
         assert embedder.token == Secret.from_token("fake-api-token")
         assert embedder.prefix == "prefix"
         assert embedder.suffix == "suffix"
+        assert embedder.truncate is False
+        assert embedder.normalize is True
         assert embedder.batch_size == 64
         assert embedder.progress_bar is False
         assert embedder.meta_fields_to_embed == ["test_field"]
@@ -83,6 +89,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
                 "url": None,
                 "prefix": "",
                 "suffix": "",
+                "truncate": True,
+                "normalize": False,
                 "batch_size": 32,
                 "progress_bar": True,
                 "meta_fields_to_embed": [],
@@ -90,6 +98,38 @@ class TestHuggingFaceTEIDocumentEmbedder:
             },
         }

+    def test_from_dict(self, mock_check_valid_model):
+        data = {
+            "type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
+            "init_parameters": {
+                "model": "BAAI/bge-small-en-v1.5",
+                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
+                "url": None,
+                "prefix": "",
+                "suffix": "",
+                "truncate": True,
+                "normalize": False,
+                "batch_size": 32,
+                "progress_bar": True,
+                "meta_fields_to_embed": [],
+                "embedding_separator": "\n",
+            },
+        }
+
+        embedder = HuggingFaceTEIDocumentEmbedder.from_dict(data)
+
+        assert embedder.model == "BAAI/bge-small-en-v1.5"
+        assert embedder.url is None
+        assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
+        assert embedder.prefix == ""
+        assert embedder.suffix == ""
+        assert embedder.truncate is True
+        assert embedder.normalize is False
+        assert embedder.batch_size == 32
+        assert embedder.progress_bar is True
+        assert embedder.meta_fields_to_embed == []
+        assert embedder.embedding_separator == "\n"
+
     def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
         component = HuggingFaceTEIDocumentEmbedder(
             model="sentence-transformers/all-mpnet-base-v2",
@@ -97,6 +137,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
             token=Secret.from_env_var("ENV_VAR", strict=False),
             prefix="prefix",
             suffix="suffix",
+            truncate=False,
+            normalize=True,
             batch_size=64,
             progress_bar=False,
             meta_fields_to_embed=["test_field"],
@@ -113,6 +155,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
                 "url": "https://some_embedding_model.com",
                 "prefix": "prefix",
                 "suffix": "suffix",
+                "truncate": False,
+                "normalize": True,
                 "batch_size": 64,
                 "progress_bar": False,
                 "meta_fields_to_embed": ["test_field"],
@@ -120,6 +164,38 @@ class TestHuggingFaceTEIDocumentEmbedder:
             },
         }

+    def test_from_dict_with_custom_init_parameters(self, mock_check_valid_model):
+        data = {
+            "type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
+            "init_parameters": {
+                "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
+                "model": "sentence-transformers/all-mpnet-base-v2",
+                "url": "https://some_embedding_model.com",
+                "prefix": "prefix",
+                "suffix": "suffix",
+                "truncate": False,
+                "normalize": True,
+                "batch_size": 64,
+                "progress_bar": False,
+                "meta_fields_to_embed": ["test_field"],
+                "embedding_separator": " | ",
+            },
+        }
+
+        embedder = HuggingFaceTEIDocumentEmbedder.from_dict(data)
+
+        assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
+        assert embedder.url == "https://some_embedding_model.com"
+        assert embedder.token == Secret.from_env_var("ENV_VAR", strict=False)
+        assert embedder.prefix == "prefix"
+        assert embedder.suffix == "suffix"
+        assert embedder.truncate is False
+        assert embedder.normalize is True
+        assert embedder.batch_size == 64
+        assert embedder.progress_bar is False
+        assert embedder.meta_fields_to_embed == ["test_field"]
+        assert embedder.embedding_separator == " | "
+
     def test_prepare_texts_to_embed_w_metadata(self, mock_check_valid_model):
         documents = [
             Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5)
@@ -167,7 +243,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
     def test_embed_batch(self, mock_check_valid_model):
         texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]

-        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
             mock_embedding_patch.side_effect = mock_embedding_generation

             embedder = HuggingFaceTEIDocumentEmbedder(
@@ -192,7 +268,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
             Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
         ]

-        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
             mock_embedding_patch.side_effect = mock_embedding_generation

             embedder = HuggingFaceTEIDocumentEmbedder(
@@ -207,10 +283,15 @@ class TestHuggingFaceTEIDocumentEmbedder:
             result = embedder.run(documents=docs)

             mock_embedding_patch.assert_called_once_with(
-                text=[
-                    "prefix Cuisine | I love cheese suffix",
-                    "prefix ML | A transformer is a deep learning architecture suffix",
-                ]
+                json={
+                    "inputs": [
+                        "prefix Cuisine | I love cheese suffix",
+                        "prefix ML | A transformer is a deep learning architecture suffix",
+                    ],
+                    "truncate": True,
+                    "normalize": False,
+                },
+                task="feature-extraction",
             )

             documents_with_embeddings = result["documents"]
@@ -251,7 +332,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
             Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
         ]

-        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
             mock_embedding_patch.side_effect = mock_embedding_generation

             embedder = HuggingFaceTEIDocumentEmbedder(

File: test_hugging_face_tei_text_embedder.py

@@ -16,8 +16,8 @@ def mock_check_valid_model():
     yield mock


-def mock_embedding_generation(text, **kwargs):
-    response = np.array([np.random.rand(384) for i in range(len(text))])
+def mock_embedding_generation(json, **kwargs):
+    response = str(np.array([np.random.rand(384) for i in range(len(json["inputs"]))]).tolist()).encode()
     return response
@@ -31,6 +31,8 @@ class TestHuggingFaceTEITextEmbedder:
         assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
         assert embedder.prefix == ""
         assert embedder.suffix == ""
+        assert embedder.truncate is True
+        assert embedder.normalize is False

     def test_init_with_parameters(self, mock_check_valid_model):
         embedder = HuggingFaceTEITextEmbedder(
@@ -39,6 +41,8 @@ class TestHuggingFaceTEITextEmbedder:
             token=Secret.from_token("fake-api-token"),
             prefix="prefix",
             suffix="suffix",
+            truncate=False,
+            normalize=True,
         )

         assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
@@ -46,6 +50,8 @@ class TestHuggingFaceTEITextEmbedder:
         assert embedder.token == Secret.from_token("fake-api-token")
         assert embedder.prefix == "prefix"
         assert embedder.suffix == "suffix"
+        assert embedder.truncate is False
+        assert embedder.normalize is True

     def test_initialize_with_invalid_url(self, mock_check_valid_model):
         with pytest.raises(ValueError):
@@ -69,9 +75,35 @@ class TestHuggingFaceTEITextEmbedder:
                 "url": None,
                 "prefix": "",
                 "suffix": "",
+                "truncate": True,
+                "normalize": False,
             },
         }

+    def test_from_dict(self, mock_check_valid_model):
+        data = {
+            "type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
+            "init_parameters": {
+                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
+                "model": "BAAI/bge-small-en-v1.5",
+                "url": None,
+                "prefix": "",
+                "suffix": "",
+                "truncate": True,
+                "normalize": False,
+            },
+        }
+
+        embedder = HuggingFaceTEITextEmbedder.from_dict(data)
+
+        assert embedder.model == "BAAI/bge-small-en-v1.5"
+        assert embedder.url is None
+        assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
+        assert embedder.prefix == ""
+        assert embedder.suffix == ""
+        assert embedder.truncate is True
+        assert embedder.normalize is False
+
     def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
         component = HuggingFaceTEITextEmbedder(
             model="sentence-transformers/all-mpnet-base-v2",
@@ -79,6 +111,8 @@ class TestHuggingFaceTEITextEmbedder:
             token=Secret.from_env_var("ENV_VAR", strict=False),
             prefix="prefix",
             suffix="suffix",
+            truncate=False,
+            normalize=True,
         )

         data = component.to_dict()
@@ -91,11 +125,37 @@ class TestHuggingFaceTEITextEmbedder:
                 "url": "https://some_embedding_model.com",
                 "prefix": "prefix",
                 "suffix": "suffix",
+                "truncate": False,
+                "normalize": True,
             },
         }

+    def test_from_dict_with_custom_init_parameters(self, mock_check_valid_model):
+        data = {
+            "type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
+            "init_parameters": {
+                "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
+                "model": "sentence-transformers/all-mpnet-base-v2",
+                "url": "https://some_embedding_model.com",
+                "prefix": "prefix",
+                "suffix": "suffix",
+                "truncate": False,
+                "normalize": True,
+            },
+        }
+
+        embedder = HuggingFaceTEITextEmbedder.from_dict(data)
+
+        assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
+        assert embedder.url == "https://some_embedding_model.com"
+        assert embedder.token == Secret.from_env_var("ENV_VAR", strict=False)
+        assert embedder.prefix == "prefix"
+        assert embedder.suffix == "suffix"
+        assert embedder.truncate is False
+        assert embedder.normalize is True
+
     def test_run(self, mock_check_valid_model):
-        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
             mock_embedding_patch.side_effect = mock_embedding_generation

             embedder = HuggingFaceTEITextEmbedder(
@@ -107,7 +167,10 @@ class TestHuggingFaceTEITextEmbedder:
             result = embedder.run(text="The food was delicious")

-            mock_embedding_patch.assert_called_once_with(text=["prefix The food was delicious suffix"])
+            mock_embedding_patch.assert_called_once_with(
+                json={"inputs": ["prefix The food was delicious suffix"], "truncate": True, "normalize": False},
+                task="feature-extraction",
+            )

             assert len(result["embedding"]) == 384
             assert all(isinstance(x, float) for x in result["embedding"])