Add truncate and normalize parameters to TEI Embedders (#7460)

This commit is contained in:
Ashwin Mathur 2024-04-03 20:11:30 +05:30 committed by GitHub
parent 1ce12c7a6a
commit 1c7d1618d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 204 additions and 18 deletions

View File

@ -1,3 +1,4 @@
import json
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
@ -50,6 +51,8 @@ class HuggingFaceTEIDocumentEmbedder:
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
prefix: str = "",
suffix: str = "",
truncate: bool = True,
normalize: bool = False,
batch_size: int = 32,
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
@ -70,6 +73,15 @@ class HuggingFaceTEIDocumentEmbedder:
A string to add at the beginning of each text.
:param suffix:
A string to add at the end of each text.
:param truncate:
Truncate input text from the end to the maximum length supported by the model. This option is only available
for self-deployed Text Embedding Inference (TEI) endpoints and paid HF Inference Endpoints deployed with
TEI. It will be ignored when used with free HF Inference endpoints or paid HF Inference endpoints deployed
without TEI.
:param normalize:
Normalize the embeddings to unit length. This option is only available for self-deployed Text Embedding
Inference (TEI) endpoints and paid HF Inference Endpoints deployed with TEI. It will be ignored when used
with free HF Inference endpoints or paid HF Inference endpoints deployed without TEI.
:param batch_size:
Number of Documents to encode at once.
:param progress_bar:
@ -95,6 +107,8 @@ class HuggingFaceTEIDocumentEmbedder:
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
self.prefix = prefix
self.suffix = suffix
self.truncate = truncate
self.normalize = normalize
self.batch_size = batch_size
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed or []
@ -113,6 +127,8 @@ class HuggingFaceTEIDocumentEmbedder:
url=self.url,
prefix=self.prefix,
suffix=self.suffix,
truncate=self.truncate,
normalize=self.normalize,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
meta_fields_to_embed=self.meta_fields_to_embed,
@ -167,8 +183,12 @@ class HuggingFaceTEIDocumentEmbedder:
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
):
batch = texts_to_embed[i : i + batch_size]
embeddings = self.client.feature_extraction(text=batch)
all_embeddings.extend(embeddings.tolist())
response = self.client.post(
json={"inputs": batch, "truncate": self.truncate, "normalize": self.normalize},
task="feature-extraction",
)
embeddings = json.loads(response.decode())
all_embeddings.extend(embeddings)
return all_embeddings

View File

@ -1,3 +1,4 @@
import json
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
@ -45,6 +46,8 @@ class HuggingFaceTEITextEmbedder:
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
prefix: str = "",
suffix: str = "",
truncate: bool = True,
normalize: bool = False,
):
"""
Create an HuggingFaceTEITextEmbedder component.
@ -61,6 +64,15 @@ class HuggingFaceTEITextEmbedder:
A string to add at the beginning of each text.
:param suffix:
A string to add at the end of each text.
:param truncate:
Truncate input text from the end to the maximum length supported by the model. This option is only available
for self-deployed Text Embedding Inference (TEI) endpoints and paid HF Inference Endpoints deployed with
TEI. It will be ignored when used with free HF Inference endpoints or paid HF Inference endpoints deployed
without TEI.
:param normalize:
Normalize the embeddings to unit length. This option is only available for self-deployed Text Embedding
Inference (TEI) endpoints and paid HF Inference Endpoints deployed with TEI. It will be ignored when used
with free HF Inference endpoints or paid HF Inference endpoints deployed without TEI.
"""
huggingface_hub_import.check()
@ -78,6 +90,8 @@ class HuggingFaceTEITextEmbedder:
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
self.prefix = prefix
self.suffix = suffix
self.truncate = truncate
self.normalize = normalize
def to_dict(self) -> Dict[str, Any]:
"""
@ -93,6 +107,8 @@ class HuggingFaceTEITextEmbedder:
prefix=self.prefix,
suffix=self.suffix,
token=self.token.to_dict() if self.token else None,
truncate=self.truncate,
normalize=self.normalize,
)
@classmethod
@ -135,8 +151,10 @@ class HuggingFaceTEITextEmbedder:
text_to_embed = self.prefix + text + self.suffix
embeddings = self.client.feature_extraction(text=[text_to_embed])
# The client returns a numpy array
embedding = embeddings.tolist()[0]
response = self.client.post(
json={"inputs": [text_to_embed], "truncate": self.truncate, "normalize": self.normalize},
task="feature-extraction",
)
embedding = json.loads(response.decode())[0]
return {"embedding": embedding}

View File

@ -0,0 +1,4 @@
---
features:
- |
Adds `truncate` and `normalize` parameters to `HuggingFaceTEITextEmbedder` and `HuggingFaceTEIDocumentEmbedder` to allow truncating the input text and normalizing the resulting embeddings.

View File

@ -18,8 +18,8 @@ def mock_check_valid_model():
yield mock
def mock_embedding_generation(text, **kwargs):
response = np.array([np.random.rand(384) for i in range(len(text))])
def mock_embedding_generation(json, **kwargs):
response = str(np.array([np.random.rand(384) for i in range(len(json["inputs"]))]).tolist()).encode()
return response
@ -33,6 +33,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.truncate is True
assert embedder.normalize is False
assert embedder.batch_size == 32
assert embedder.progress_bar is True
assert embedder.meta_fields_to_embed == []
@ -45,6 +47,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
token=Secret.from_token("fake-api-token"),
prefix="prefix",
suffix="suffix",
truncate=False,
normalize=True,
batch_size=64,
progress_bar=False,
meta_fields_to_embed=["test_field"],
@ -56,6 +60,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
assert embedder.token == Secret.from_token("fake-api-token")
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.truncate is False
assert embedder.normalize is True
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.meta_fields_to_embed == ["test_field"]
@ -83,6 +89,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
"url": None,
"prefix": "",
"suffix": "",
"truncate": True,
"normalize": False,
"batch_size": 32,
"progress_bar": True,
"meta_fields_to_embed": [],
@ -90,6 +98,38 @@ class TestHuggingFaceTEIDocumentEmbedder:
},
}
def test_from_dict(self, mock_check_valid_model):
data = {
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"url": None,
"prefix": "",
"suffix": "",
"truncate": True,
"normalize": False,
"batch_size": 32,
"progress_bar": True,
"meta_fields_to_embed": [],
"embedding_separator": "\n",
},
}
embedder = HuggingFaceTEIDocumentEmbedder.from_dict(data)
assert embedder.model == "BAAI/bge-small-en-v1.5"
assert embedder.url is None
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.truncate is True
assert embedder.normalize is False
assert embedder.batch_size == 32
assert embedder.progress_bar is True
assert embedder.meta_fields_to_embed == []
assert embedder.embedding_separator == "\n"
def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
component = HuggingFaceTEIDocumentEmbedder(
model="sentence-transformers/all-mpnet-base-v2",
@ -97,6 +137,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
token=Secret.from_env_var("ENV_VAR", strict=False),
prefix="prefix",
suffix="suffix",
truncate=False,
normalize=True,
batch_size=64,
progress_bar=False,
meta_fields_to_embed=["test_field"],
@ -113,6 +155,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
"url": "https://some_embedding_model.com",
"prefix": "prefix",
"suffix": "suffix",
"truncate": False,
"normalize": True,
"batch_size": 64,
"progress_bar": False,
"meta_fields_to_embed": ["test_field"],
@ -120,6 +164,38 @@ class TestHuggingFaceTEIDocumentEmbedder:
},
}
def test_from_dict_with_custom_init_parameters(self, mock_check_valid_model):
data = {
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
"init_parameters": {
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"model": "sentence-transformers/all-mpnet-base-v2",
"url": "https://some_embedding_model.com",
"prefix": "prefix",
"suffix": "suffix",
"truncate": False,
"normalize": True,
"batch_size": 64,
"progress_bar": False,
"meta_fields_to_embed": ["test_field"],
"embedding_separator": " | ",
},
}
embedder = HuggingFaceTEIDocumentEmbedder.from_dict(data)
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
assert embedder.url == "https://some_embedding_model.com"
assert embedder.token == Secret.from_env_var("ENV_VAR", strict=False)
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.truncate is False
assert embedder.normalize is True
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.meta_fields_to_embed == ["test_field"]
assert embedder.embedding_separator == " | "
def test_prepare_texts_to_embed_w_metadata(self, mock_check_valid_model):
documents = [
Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5)
@ -167,7 +243,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
def test_embed_batch(self, mock_check_valid_model):
texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
mock_embedding_patch.side_effect = mock_embedding_generation
embedder = HuggingFaceTEIDocumentEmbedder(
@ -192,7 +268,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
]
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
mock_embedding_patch.side_effect = mock_embedding_generation
embedder = HuggingFaceTEIDocumentEmbedder(
@ -207,10 +283,15 @@ class TestHuggingFaceTEIDocumentEmbedder:
result = embedder.run(documents=docs)
mock_embedding_patch.assert_called_once_with(
text=[
json={
"inputs": [
"prefix Cuisine | I love cheese suffix",
"prefix ML | A transformer is a deep learning architecture suffix",
]
],
"truncate": True,
"normalize": False,
},
task="feature-extraction",
)
documents_with_embeddings = result["documents"]
@ -251,7 +332,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
]
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
mock_embedding_patch.side_effect = mock_embedding_generation
embedder = HuggingFaceTEIDocumentEmbedder(

View File

@ -16,8 +16,8 @@ def mock_check_valid_model():
yield mock
def mock_embedding_generation(text, **kwargs):
response = np.array([np.random.rand(384) for i in range(len(text))])
def mock_embedding_generation(json, **kwargs):
response = str(np.array([np.random.rand(384) for i in range(len(json["inputs"]))]).tolist()).encode()
return response
@ -31,6 +31,8 @@ class TestHuggingFaceTEITextEmbedder:
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.truncate is True
assert embedder.normalize is False
def test_init_with_parameters(self, mock_check_valid_model):
embedder = HuggingFaceTEITextEmbedder(
@ -39,6 +41,8 @@ class TestHuggingFaceTEITextEmbedder:
token=Secret.from_token("fake-api-token"),
prefix="prefix",
suffix="suffix",
truncate=False,
normalize=True,
)
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
@ -46,6 +50,8 @@ class TestHuggingFaceTEITextEmbedder:
assert embedder.token == Secret.from_token("fake-api-token")
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.truncate is False
assert embedder.normalize is True
def test_initialize_with_invalid_url(self, mock_check_valid_model):
with pytest.raises(ValueError):
@ -69,9 +75,35 @@ class TestHuggingFaceTEITextEmbedder:
"url": None,
"prefix": "",
"suffix": "",
"truncate": True,
"normalize": False,
},
}
def test_from_dict(self, mock_check_valid_model):
data = {
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
"init_parameters": {
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"model": "BAAI/bge-small-en-v1.5",
"url": None,
"prefix": "",
"suffix": "",
"truncate": True,
"normalize": False,
},
}
embedder = HuggingFaceTEITextEmbedder.from_dict(data)
assert embedder.model == "BAAI/bge-small-en-v1.5"
assert embedder.url is None
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.truncate is True
assert embedder.normalize is False
def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
component = HuggingFaceTEITextEmbedder(
model="sentence-transformers/all-mpnet-base-v2",
@ -79,6 +111,8 @@ class TestHuggingFaceTEITextEmbedder:
token=Secret.from_env_var("ENV_VAR", strict=False),
prefix="prefix",
suffix="suffix",
truncate=False,
normalize=True,
)
data = component.to_dict()
@ -91,11 +125,37 @@ class TestHuggingFaceTEITextEmbedder:
"url": "https://some_embedding_model.com",
"prefix": "prefix",
"suffix": "suffix",
"truncate": False,
"normalize": True,
},
}
def test_from_dict_with_custom_init_parameters(self, mock_check_valid_model):
data = {
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
"init_parameters": {
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"model": "sentence-transformers/all-mpnet-base-v2",
"url": "https://some_embedding_model.com",
"prefix": "prefix",
"suffix": "suffix",
"truncate": False,
"normalize": True,
},
}
embedder = HuggingFaceTEITextEmbedder.from_dict(data)
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
assert embedder.url == "https://some_embedding_model.com"
assert embedder.token == Secret.from_env_var("ENV_VAR", strict=False)
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.truncate is False
assert embedder.normalize is True
def test_run(self, mock_check_valid_model):
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
mock_embedding_patch.side_effect = mock_embedding_generation
embedder = HuggingFaceTEITextEmbedder(
@ -107,7 +167,10 @@ class TestHuggingFaceTEITextEmbedder:
result = embedder.run(text="The food was delicious")
mock_embedding_patch.assert_called_once_with(text=["prefix The food was delicious suffix"])
mock_embedding_patch.assert_called_once_with(
json={"inputs": ["prefix The food was delicious suffix"], "truncate": True, "normalize": False},
task="feature-extraction",
)
assert len(result["embedding"]) == 384
assert all(isinstance(x, float) for x in result["embedding"])