mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-17 12:13:35 +00:00
Add truncate and normalize parameters to TEI Embedders (#7460)
This commit is contained in:
parent
1ce12c7a6a
commit
1c7d1618d8
@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@ -50,6 +51,8 @@ class HuggingFaceTEIDocumentEmbedder:
|
||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||
prefix: str = "",
|
||||
suffix: str = "",
|
||||
truncate: bool = True,
|
||||
normalize: bool = False,
|
||||
batch_size: int = 32,
|
||||
progress_bar: bool = True,
|
||||
meta_fields_to_embed: Optional[List[str]] = None,
|
||||
@ -70,6 +73,15 @@ class HuggingFaceTEIDocumentEmbedder:
|
||||
A string to add at the beginning of each text.
|
||||
:param suffix:
|
||||
A string to add at the end of each text.
|
||||
:param truncate:
|
||||
Truncate input text from the end to the maximum length supported by the model. This option is only available
|
||||
for self-deployed Text Embedding Inference (TEI) endpoints and paid HF Inference Endpoints deployed with
|
||||
TEI. It will be ignored when used with free HF Inference endpoints or paid HF Inference endpoints deployed
|
||||
without TEI.
|
||||
:param normalize:
|
||||
Normalize the embeddings to unit length. This option is only available for self-deployed Text Embedding
|
||||
Inference (TEI) endpoints and paid HF Inference Endpoints deployed with TEI. It will be ignored when used
|
||||
with free HF Inference endpoints or paid HF Inference endpoints deployed without TEI.
|
||||
:param batch_size:
|
||||
Number of Documents to encode at once.
|
||||
:param progress_bar:
|
||||
@ -95,6 +107,8 @@ class HuggingFaceTEIDocumentEmbedder:
|
||||
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
|
||||
self.prefix = prefix
|
||||
self.suffix = suffix
|
||||
self.truncate = truncate
|
||||
self.normalize = normalize
|
||||
self.batch_size = batch_size
|
||||
self.progress_bar = progress_bar
|
||||
self.meta_fields_to_embed = meta_fields_to_embed or []
|
||||
@ -113,6 +127,8 @@ class HuggingFaceTEIDocumentEmbedder:
|
||||
url=self.url,
|
||||
prefix=self.prefix,
|
||||
suffix=self.suffix,
|
||||
truncate=self.truncate,
|
||||
normalize=self.normalize,
|
||||
batch_size=self.batch_size,
|
||||
progress_bar=self.progress_bar,
|
||||
meta_fields_to_embed=self.meta_fields_to_embed,
|
||||
@ -167,8 +183,12 @@ class HuggingFaceTEIDocumentEmbedder:
|
||||
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
|
||||
):
|
||||
batch = texts_to_embed[i : i + batch_size]
|
||||
embeddings = self.client.feature_extraction(text=batch)
|
||||
all_embeddings.extend(embeddings.tolist())
|
||||
response = self.client.post(
|
||||
json={"inputs": batch, "truncate": self.truncate, "normalize": self.normalize},
|
||||
task="feature-extraction",
|
||||
)
|
||||
embeddings = json.loads(response.decode())
|
||||
all_embeddings.extend(embeddings)
|
||||
|
||||
return all_embeddings
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@ -45,6 +46,8 @@ class HuggingFaceTEITextEmbedder:
|
||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||
prefix: str = "",
|
||||
suffix: str = "",
|
||||
truncate: bool = True,
|
||||
normalize: bool = False,
|
||||
):
|
||||
"""
|
||||
Create an HuggingFaceTEITextEmbedder component.
|
||||
@ -61,6 +64,15 @@ class HuggingFaceTEITextEmbedder:
|
||||
A string to add at the beginning of each text.
|
||||
:param suffix:
|
||||
A string to add at the end of each text.
|
||||
:param truncate:
|
||||
Truncate input text from the end to the maximum length supported by the model. This option is only available
|
||||
for self-deployed Text Embedding Inference (TEI) endpoints and paid HF Inference Endpoints deployed with
|
||||
TEI. It will be ignored when used with free HF Inference endpoints or paid HF Inference endpoints deployed
|
||||
without TEI.
|
||||
:param normalize:
|
||||
Normalize the embeddings to unit length. This option is only available for self-deployed Text Embedding
|
||||
Inference (TEI) endpoints and paid HF Inference Endpoints deployed with TEI. It will be ignored when used
|
||||
with free HF Inference endpoints or paid HF Inference endpoints deployed without TEI.
|
||||
"""
|
||||
huggingface_hub_import.check()
|
||||
|
||||
@ -78,6 +90,8 @@ class HuggingFaceTEITextEmbedder:
|
||||
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
|
||||
self.prefix = prefix
|
||||
self.suffix = suffix
|
||||
self.truncate = truncate
|
||||
self.normalize = normalize
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -93,6 +107,8 @@ class HuggingFaceTEITextEmbedder:
|
||||
prefix=self.prefix,
|
||||
suffix=self.suffix,
|
||||
token=self.token.to_dict() if self.token else None,
|
||||
truncate=self.truncate,
|
||||
normalize=self.normalize,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -135,8 +151,10 @@ class HuggingFaceTEITextEmbedder:
|
||||
|
||||
text_to_embed = self.prefix + text + self.suffix
|
||||
|
||||
embeddings = self.client.feature_extraction(text=[text_to_embed])
|
||||
# The client returns a numpy array
|
||||
embedding = embeddings.tolist()[0]
|
||||
response = self.client.post(
|
||||
json={"inputs": [text_to_embed], "truncate": self.truncate, "normalize": self.normalize},
|
||||
task="feature-extraction",
|
||||
)
|
||||
embedding = json.loads(response.decode())[0]
|
||||
|
||||
return {"embedding": embedding}
|
||||
|
@ -0,0 +1,4 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Adds `truncate` and `normalize` parameters to `HuggingFaceTEITextEmbedder` and `HuggingFaceTEIDocumentEmbedder`, allowing truncation of the input text and normalization of the returned embeddings.
|
@ -18,8 +18,8 @@ def mock_check_valid_model():
|
||||
yield mock
|
||||
|
||||
|
||||
def mock_embedding_generation(text, **kwargs):
|
||||
response = np.array([np.random.rand(384) for i in range(len(text))])
|
||||
def mock_embedding_generation(json, **kwargs):
|
||||
response = str(np.array([np.random.rand(384) for i in range(len(json["inputs"]))]).tolist()).encode()
|
||||
return response
|
||||
|
||||
|
||||
@ -33,6 +33,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||
assert embedder.prefix == ""
|
||||
assert embedder.suffix == ""
|
||||
assert embedder.truncate is True
|
||||
assert embedder.normalize is False
|
||||
assert embedder.batch_size == 32
|
||||
assert embedder.progress_bar is True
|
||||
assert embedder.meta_fields_to_embed == []
|
||||
@ -45,6 +47,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
truncate=False,
|
||||
normalize=True,
|
||||
batch_size=64,
|
||||
progress_bar=False,
|
||||
meta_fields_to_embed=["test_field"],
|
||||
@ -56,6 +60,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
assert embedder.token == Secret.from_token("fake-api-token")
|
||||
assert embedder.prefix == "prefix"
|
||||
assert embedder.suffix == "suffix"
|
||||
assert embedder.truncate is False
|
||||
assert embedder.normalize is True
|
||||
assert embedder.batch_size == 64
|
||||
assert embedder.progress_bar is False
|
||||
assert embedder.meta_fields_to_embed == ["test_field"]
|
||||
@ -83,6 +89,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
"url": None,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
"truncate": True,
|
||||
"normalize": False,
|
||||
"batch_size": 32,
|
||||
"progress_bar": True,
|
||||
"meta_fields_to_embed": [],
|
||||
@ -90,6 +98,38 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
},
|
||||
}
|
||||
|
||||
def test_from_dict(self, mock_check_valid_model):
|
||||
data = {
|
||||
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
|
||||
"init_parameters": {
|
||||
"model": "BAAI/bge-small-en-v1.5",
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"url": None,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
"truncate": True,
|
||||
"normalize": False,
|
||||
"batch_size": 32,
|
||||
"progress_bar": True,
|
||||
"meta_fields_to_embed": [],
|
||||
"embedding_separator": "\n",
|
||||
},
|
||||
}
|
||||
|
||||
embedder = HuggingFaceTEIDocumentEmbedder.from_dict(data)
|
||||
|
||||
assert embedder.model == "BAAI/bge-small-en-v1.5"
|
||||
assert embedder.url is None
|
||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||
assert embedder.prefix == ""
|
||||
assert embedder.suffix == ""
|
||||
assert embedder.truncate is True
|
||||
assert embedder.normalize is False
|
||||
assert embedder.batch_size == 32
|
||||
assert embedder.progress_bar is True
|
||||
assert embedder.meta_fields_to_embed == []
|
||||
assert embedder.embedding_separator == "\n"
|
||||
|
||||
def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
||||
component = HuggingFaceTEIDocumentEmbedder(
|
||||
model="sentence-transformers/all-mpnet-base-v2",
|
||||
@ -97,6 +137,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
truncate=False,
|
||||
normalize=True,
|
||||
batch_size=64,
|
||||
progress_bar=False,
|
||||
meta_fields_to_embed=["test_field"],
|
||||
@ -113,6 +155,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
"url": "https://some_embedding_model.com",
|
||||
"prefix": "prefix",
|
||||
"suffix": "suffix",
|
||||
"truncate": False,
|
||||
"normalize": True,
|
||||
"batch_size": 64,
|
||||
"progress_bar": False,
|
||||
"meta_fields_to_embed": ["test_field"],
|
||||
@ -120,6 +164,38 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
},
|
||||
}
|
||||
|
||||
def test_from_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
||||
data = {
|
||||
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
|
||||
"init_parameters": {
|
||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"model": "sentence-transformers/all-mpnet-base-v2",
|
||||
"url": "https://some_embedding_model.com",
|
||||
"prefix": "prefix",
|
||||
"suffix": "suffix",
|
||||
"truncate": False,
|
||||
"normalize": True,
|
||||
"batch_size": 64,
|
||||
"progress_bar": False,
|
||||
"meta_fields_to_embed": ["test_field"],
|
||||
"embedding_separator": " | ",
|
||||
},
|
||||
}
|
||||
|
||||
embedder = HuggingFaceTEIDocumentEmbedder.from_dict(data)
|
||||
|
||||
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
|
||||
assert embedder.url == "https://some_embedding_model.com"
|
||||
assert embedder.token == Secret.from_env_var("ENV_VAR", strict=False)
|
||||
assert embedder.prefix == "prefix"
|
||||
assert embedder.suffix == "suffix"
|
||||
assert embedder.truncate is False
|
||||
assert embedder.normalize is True
|
||||
assert embedder.batch_size == 64
|
||||
assert embedder.progress_bar is False
|
||||
assert embedder.meta_fields_to_embed == ["test_field"]
|
||||
assert embedder.embedding_separator == " | "
|
||||
|
||||
def test_prepare_texts_to_embed_w_metadata(self, mock_check_valid_model):
|
||||
documents = [
|
||||
Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5)
|
||||
@ -167,7 +243,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
def test_embed_batch(self, mock_check_valid_model):
|
||||
texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]
|
||||
|
||||
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
|
||||
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
|
||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
||||
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
@ -192,7 +268,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
|
||||
]
|
||||
|
||||
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
|
||||
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
|
||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
||||
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
@ -207,10 +283,15 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
result = embedder.run(documents=docs)
|
||||
|
||||
mock_embedding_patch.assert_called_once_with(
|
||||
text=[
|
||||
"prefix Cuisine | I love cheese suffix",
|
||||
"prefix ML | A transformer is a deep learning architecture suffix",
|
||||
]
|
||||
json={
|
||||
"inputs": [
|
||||
"prefix Cuisine | I love cheese suffix",
|
||||
"prefix ML | A transformer is a deep learning architecture suffix",
|
||||
],
|
||||
"truncate": True,
|
||||
"normalize": False,
|
||||
},
|
||||
task="feature-extraction",
|
||||
)
|
||||
documents_with_embeddings = result["documents"]
|
||||
|
||||
@ -251,7 +332,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
|
||||
]
|
||||
|
||||
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
|
||||
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
|
||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
||||
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
|
@ -16,8 +16,8 @@ def mock_check_valid_model():
|
||||
yield mock
|
||||
|
||||
|
||||
def mock_embedding_generation(text, **kwargs):
|
||||
response = np.array([np.random.rand(384) for i in range(len(text))])
|
||||
def mock_embedding_generation(json, **kwargs):
|
||||
response = str(np.array([np.random.rand(384) for i in range(len(json["inputs"]))]).tolist()).encode()
|
||||
return response
|
||||
|
||||
|
||||
@ -31,6 +31,8 @@ class TestHuggingFaceTEITextEmbedder:
|
||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||
assert embedder.prefix == ""
|
||||
assert embedder.suffix == ""
|
||||
assert embedder.truncate is True
|
||||
assert embedder.normalize is False
|
||||
|
||||
def test_init_with_parameters(self, mock_check_valid_model):
|
||||
embedder = HuggingFaceTEITextEmbedder(
|
||||
@ -39,6 +41,8 @@ class TestHuggingFaceTEITextEmbedder:
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
truncate=False,
|
||||
normalize=True,
|
||||
)
|
||||
|
||||
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
|
||||
@ -46,6 +50,8 @@ class TestHuggingFaceTEITextEmbedder:
|
||||
assert embedder.token == Secret.from_token("fake-api-token")
|
||||
assert embedder.prefix == "prefix"
|
||||
assert embedder.suffix == "suffix"
|
||||
assert embedder.truncate is False
|
||||
assert embedder.normalize is True
|
||||
|
||||
def test_initialize_with_invalid_url(self, mock_check_valid_model):
|
||||
with pytest.raises(ValueError):
|
||||
@ -69,9 +75,35 @@ class TestHuggingFaceTEITextEmbedder:
|
||||
"url": None,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
"truncate": True,
|
||||
"normalize": False,
|
||||
},
|
||||
}
|
||||
|
||||
def test_from_dict(self, mock_check_valid_model):
|
||||
data = {
|
||||
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
|
||||
"init_parameters": {
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"model": "BAAI/bge-small-en-v1.5",
|
||||
"url": None,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
"truncate": True,
|
||||
"normalize": False,
|
||||
},
|
||||
}
|
||||
|
||||
embedder = HuggingFaceTEITextEmbedder.from_dict(data)
|
||||
|
||||
assert embedder.model == "BAAI/bge-small-en-v1.5"
|
||||
assert embedder.url is None
|
||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||
assert embedder.prefix == ""
|
||||
assert embedder.suffix == ""
|
||||
assert embedder.truncate is True
|
||||
assert embedder.normalize is False
|
||||
|
||||
def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
||||
component = HuggingFaceTEITextEmbedder(
|
||||
model="sentence-transformers/all-mpnet-base-v2",
|
||||
@ -79,6 +111,8 @@ class TestHuggingFaceTEITextEmbedder:
|
||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
truncate=False,
|
||||
normalize=True,
|
||||
)
|
||||
|
||||
data = component.to_dict()
|
||||
@ -91,11 +125,37 @@ class TestHuggingFaceTEITextEmbedder:
|
||||
"url": "https://some_embedding_model.com",
|
||||
"prefix": "prefix",
|
||||
"suffix": "suffix",
|
||||
"truncate": False,
|
||||
"normalize": True,
|
||||
},
|
||||
}
|
||||
|
||||
def test_from_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
||||
data = {
|
||||
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
|
||||
"init_parameters": {
|
||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"model": "sentence-transformers/all-mpnet-base-v2",
|
||||
"url": "https://some_embedding_model.com",
|
||||
"prefix": "prefix",
|
||||
"suffix": "suffix",
|
||||
"truncate": False,
|
||||
"normalize": True,
|
||||
},
|
||||
}
|
||||
|
||||
embedder = HuggingFaceTEITextEmbedder.from_dict(data)
|
||||
|
||||
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
|
||||
assert embedder.url == "https://some_embedding_model.com"
|
||||
assert embedder.token == Secret.from_env_var("ENV_VAR", strict=False)
|
||||
assert embedder.prefix == "prefix"
|
||||
assert embedder.suffix == "suffix"
|
||||
assert embedder.truncate is False
|
||||
assert embedder.normalize is True
|
||||
|
||||
def test_run(self, mock_check_valid_model):
|
||||
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
|
||||
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
|
||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
||||
|
||||
embedder = HuggingFaceTEITextEmbedder(
|
||||
@ -107,7 +167,10 @@ class TestHuggingFaceTEITextEmbedder:
|
||||
|
||||
result = embedder.run(text="The food was delicious")
|
||||
|
||||
mock_embedding_patch.assert_called_once_with(text=["prefix The food was delicious suffix"])
|
||||
mock_embedding_patch.assert_called_once_with(
|
||||
json={"inputs": ["prefix The food was delicious suffix"], "truncate": True, "normalize": False},
|
||||
task="feature-extraction",
|
||||
)
|
||||
|
||||
assert len(result["embedding"]) == 384
|
||||
assert all(isinstance(x, float) for x in result["embedding"])
|
||||
|
Loading…
x
Reference in New Issue
Block a user