mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-17 20:23:29 +00:00
Add truncate and normalize parameters to TEI Embedders (#7460)
This commit is contained in:
parent
1ce12c7a6a
commit
1c7d1618d8
@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
@ -50,6 +51,8 @@ class HuggingFaceTEIDocumentEmbedder:
|
|||||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
suffix: str = "",
|
suffix: str = "",
|
||||||
|
truncate: bool = True,
|
||||||
|
normalize: bool = False,
|
||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
progress_bar: bool = True,
|
progress_bar: bool = True,
|
||||||
meta_fields_to_embed: Optional[List[str]] = None,
|
meta_fields_to_embed: Optional[List[str]] = None,
|
||||||
@ -70,6 +73,15 @@ class HuggingFaceTEIDocumentEmbedder:
|
|||||||
A string to add at the beginning of each text.
|
A string to add at the beginning of each text.
|
||||||
:param suffix:
|
:param suffix:
|
||||||
A string to add at the end of each text.
|
A string to add at the end of each text.
|
||||||
|
:param truncate:
|
||||||
|
Truncate input text from the end to the maximum length supported by the model. This option is only available
|
||||||
|
for self-deployed Text Embedding Inference (TEI) endpoints and paid HF Inference Endpoints deployed with
|
||||||
|
TEI. It will be ignored when used with free HF Inference endpoints or paid HF Inference endpoints deployed
|
||||||
|
without TEI.
|
||||||
|
:param normalize:
|
||||||
|
Normalize the embeddings to unit length. This option is only available for self-deployed Text Embedding
|
||||||
|
Inference (TEI) endpoints and paid HF Inference Endpoints deployed with TEI. It will be ignored when used
|
||||||
|
with free HF Inference endpoints or paid HF Inference endpoints deployed without TEI.
|
||||||
:param batch_size:
|
:param batch_size:
|
||||||
Number of Documents to encode at once.
|
Number of Documents to encode at once.
|
||||||
:param progress_bar:
|
:param progress_bar:
|
||||||
@ -95,6 +107,8 @@ class HuggingFaceTEIDocumentEmbedder:
|
|||||||
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
|
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
self.suffix = suffix
|
self.suffix = suffix
|
||||||
|
self.truncate = truncate
|
||||||
|
self.normalize = normalize
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.progress_bar = progress_bar
|
self.progress_bar = progress_bar
|
||||||
self.meta_fields_to_embed = meta_fields_to_embed or []
|
self.meta_fields_to_embed = meta_fields_to_embed or []
|
||||||
@ -113,6 +127,8 @@ class HuggingFaceTEIDocumentEmbedder:
|
|||||||
url=self.url,
|
url=self.url,
|
||||||
prefix=self.prefix,
|
prefix=self.prefix,
|
||||||
suffix=self.suffix,
|
suffix=self.suffix,
|
||||||
|
truncate=self.truncate,
|
||||||
|
normalize=self.normalize,
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
progress_bar=self.progress_bar,
|
progress_bar=self.progress_bar,
|
||||||
meta_fields_to_embed=self.meta_fields_to_embed,
|
meta_fields_to_embed=self.meta_fields_to_embed,
|
||||||
@ -167,8 +183,12 @@ class HuggingFaceTEIDocumentEmbedder:
|
|||||||
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
|
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
|
||||||
):
|
):
|
||||||
batch = texts_to_embed[i : i + batch_size]
|
batch = texts_to_embed[i : i + batch_size]
|
||||||
embeddings = self.client.feature_extraction(text=batch)
|
response = self.client.post(
|
||||||
all_embeddings.extend(embeddings.tolist())
|
json={"inputs": batch, "truncate": self.truncate, "normalize": self.normalize},
|
||||||
|
task="feature-extraction",
|
||||||
|
)
|
||||||
|
embeddings = json.loads(response.decode())
|
||||||
|
all_embeddings.extend(embeddings)
|
||||||
|
|
||||||
return all_embeddings
|
return all_embeddings
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
@ -45,6 +46,8 @@ class HuggingFaceTEITextEmbedder:
|
|||||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
suffix: str = "",
|
suffix: str = "",
|
||||||
|
truncate: bool = True,
|
||||||
|
normalize: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create an HuggingFaceTEITextEmbedder component.
|
Create an HuggingFaceTEITextEmbedder component.
|
||||||
@ -61,6 +64,15 @@ class HuggingFaceTEITextEmbedder:
|
|||||||
A string to add at the beginning of each text.
|
A string to add at the beginning of each text.
|
||||||
:param suffix:
|
:param suffix:
|
||||||
A string to add at the end of each text.
|
A string to add at the end of each text.
|
||||||
|
:param truncate:
|
||||||
|
Truncate input text from the end to the maximum length supported by the model. This option is only available
|
||||||
|
for self-deployed Text Embedding Inference (TEI) endpoints and paid HF Inference Endpoints deployed with
|
||||||
|
TEI. It will be ignored when used with free HF Inference endpoints or paid HF Inference endpoints deployed
|
||||||
|
without TEI.
|
||||||
|
:param normalize:
|
||||||
|
Normalize the embeddings to unit length. This option is only available for self-deployed Text Embedding
|
||||||
|
Inference (TEI) endpoints and paid HF Inference Endpoints deployed with TEI. It will be ignored when used
|
||||||
|
with free HF Inference endpoints or paid HF Inference endpoints deployed without TEI.
|
||||||
"""
|
"""
|
||||||
huggingface_hub_import.check()
|
huggingface_hub_import.check()
|
||||||
|
|
||||||
@ -78,6 +90,8 @@ class HuggingFaceTEITextEmbedder:
|
|||||||
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
|
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
self.suffix = suffix
|
self.suffix = suffix
|
||||||
|
self.truncate = truncate
|
||||||
|
self.normalize = normalize
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@ -93,6 +107,8 @@ class HuggingFaceTEITextEmbedder:
|
|||||||
prefix=self.prefix,
|
prefix=self.prefix,
|
||||||
suffix=self.suffix,
|
suffix=self.suffix,
|
||||||
token=self.token.to_dict() if self.token else None,
|
token=self.token.to_dict() if self.token else None,
|
||||||
|
truncate=self.truncate,
|
||||||
|
normalize=self.normalize,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -135,8 +151,10 @@ class HuggingFaceTEITextEmbedder:
|
|||||||
|
|
||||||
text_to_embed = self.prefix + text + self.suffix
|
text_to_embed = self.prefix + text + self.suffix
|
||||||
|
|
||||||
embeddings = self.client.feature_extraction(text=[text_to_embed])
|
response = self.client.post(
|
||||||
# The client returns a numpy array
|
json={"inputs": [text_to_embed], "truncate": self.truncate, "normalize": self.normalize},
|
||||||
embedding = embeddings.tolist()[0]
|
task="feature-extraction",
|
||||||
|
)
|
||||||
|
embedding = json.loads(response.decode())[0]
|
||||||
|
|
||||||
return {"embedding": embedding}
|
return {"embedding": embedding}
|
||||||
|
@ -0,0 +1,4 @@
|
|||||||
|
---
|
||||||
|
features:
|
||||||
|
- |
|
||||||
|
Adds `truncate` and `normalize` parameters to `HuggingFaceTEIDocumentEmbedder` and `HuggingFaceTEITextEmbedder` for allowing truncation and normalization of embeddings.
|
@ -18,8 +18,8 @@ def mock_check_valid_model():
|
|||||||
yield mock
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
def mock_embedding_generation(text, **kwargs):
|
def mock_embedding_generation(json, **kwargs):
|
||||||
response = np.array([np.random.rand(384) for i in range(len(text))])
|
response = str(np.array([np.random.rand(384) for i in range(len(json["inputs"]))]).tolist()).encode()
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@ -33,6 +33,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
|||||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||||
assert embedder.prefix == ""
|
assert embedder.prefix == ""
|
||||||
assert embedder.suffix == ""
|
assert embedder.suffix == ""
|
||||||
|
assert embedder.truncate is True
|
||||||
|
assert embedder.normalize is False
|
||||||
assert embedder.batch_size == 32
|
assert embedder.batch_size == 32
|
||||||
assert embedder.progress_bar is True
|
assert embedder.progress_bar is True
|
||||||
assert embedder.meta_fields_to_embed == []
|
assert embedder.meta_fields_to_embed == []
|
||||||
@ -45,6 +47,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
|||||||
token=Secret.from_token("fake-api-token"),
|
token=Secret.from_token("fake-api-token"),
|
||||||
prefix="prefix",
|
prefix="prefix",
|
||||||
suffix="suffix",
|
suffix="suffix",
|
||||||
|
truncate=False,
|
||||||
|
normalize=True,
|
||||||
batch_size=64,
|
batch_size=64,
|
||||||
progress_bar=False,
|
progress_bar=False,
|
||||||
meta_fields_to_embed=["test_field"],
|
meta_fields_to_embed=["test_field"],
|
||||||
@ -56,6 +60,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
|||||||
assert embedder.token == Secret.from_token("fake-api-token")
|
assert embedder.token == Secret.from_token("fake-api-token")
|
||||||
assert embedder.prefix == "prefix"
|
assert embedder.prefix == "prefix"
|
||||||
assert embedder.suffix == "suffix"
|
assert embedder.suffix == "suffix"
|
||||||
|
assert embedder.truncate is False
|
||||||
|
assert embedder.normalize is True
|
||||||
assert embedder.batch_size == 64
|
assert embedder.batch_size == 64
|
||||||
assert embedder.progress_bar is False
|
assert embedder.progress_bar is False
|
||||||
assert embedder.meta_fields_to_embed == ["test_field"]
|
assert embedder.meta_fields_to_embed == ["test_field"]
|
||||||
@ -83,6 +89,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
|||||||
"url": None,
|
"url": None,
|
||||||
"prefix": "",
|
"prefix": "",
|
||||||
"suffix": "",
|
"suffix": "",
|
||||||
|
"truncate": True,
|
||||||
|
"normalize": False,
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
"progress_bar": True,
|
"progress_bar": True,
|
||||||
"meta_fields_to_embed": [],
|
"meta_fields_to_embed": [],
|
||||||
@ -90,6 +98,38 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def test_from_dict(self, mock_check_valid_model):
|
||||||
|
data = {
|
||||||
|
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
|
||||||
|
"init_parameters": {
|
||||||
|
"model": "BAAI/bge-small-en-v1.5",
|
||||||
|
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||||
|
"url": None,
|
||||||
|
"prefix": "",
|
||||||
|
"suffix": "",
|
||||||
|
"truncate": True,
|
||||||
|
"normalize": False,
|
||||||
|
"batch_size": 32,
|
||||||
|
"progress_bar": True,
|
||||||
|
"meta_fields_to_embed": [],
|
||||||
|
"embedding_separator": "\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
embedder = HuggingFaceTEIDocumentEmbedder.from_dict(data)
|
||||||
|
|
||||||
|
assert embedder.model == "BAAI/bge-small-en-v1.5"
|
||||||
|
assert embedder.url is None
|
||||||
|
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||||
|
assert embedder.prefix == ""
|
||||||
|
assert embedder.suffix == ""
|
||||||
|
assert embedder.truncate is True
|
||||||
|
assert embedder.normalize is False
|
||||||
|
assert embedder.batch_size == 32
|
||||||
|
assert embedder.progress_bar is True
|
||||||
|
assert embedder.meta_fields_to_embed == []
|
||||||
|
assert embedder.embedding_separator == "\n"
|
||||||
|
|
||||||
def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
||||||
component = HuggingFaceTEIDocumentEmbedder(
|
component = HuggingFaceTEIDocumentEmbedder(
|
||||||
model="sentence-transformers/all-mpnet-base-v2",
|
model="sentence-transformers/all-mpnet-base-v2",
|
||||||
@ -97,6 +137,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
|||||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
token=Secret.from_env_var("ENV_VAR", strict=False),
|
||||||
prefix="prefix",
|
prefix="prefix",
|
||||||
suffix="suffix",
|
suffix="suffix",
|
||||||
|
truncate=False,
|
||||||
|
normalize=True,
|
||||||
batch_size=64,
|
batch_size=64,
|
||||||
progress_bar=False,
|
progress_bar=False,
|
||||||
meta_fields_to_embed=["test_field"],
|
meta_fields_to_embed=["test_field"],
|
||||||
@ -113,6 +155,8 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
|||||||
"url": "https://some_embedding_model.com",
|
"url": "https://some_embedding_model.com",
|
||||||
"prefix": "prefix",
|
"prefix": "prefix",
|
||||||
"suffix": "suffix",
|
"suffix": "suffix",
|
||||||
|
"truncate": False,
|
||||||
|
"normalize": True,
|
||||||
"batch_size": 64,
|
"batch_size": 64,
|
||||||
"progress_bar": False,
|
"progress_bar": False,
|
||||||
"meta_fields_to_embed": ["test_field"],
|
"meta_fields_to_embed": ["test_field"],
|
||||||
@ -120,6 +164,38 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def test_from_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
||||||
|
data = {
|
||||||
|
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
|
||||||
|
"init_parameters": {
|
||||||
|
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||||
|
"model": "sentence-transformers/all-mpnet-base-v2",
|
||||||
|
"url": "https://some_embedding_model.com",
|
||||||
|
"prefix": "prefix",
|
||||||
|
"suffix": "suffix",
|
||||||
|
"truncate": False,
|
||||||
|
"normalize": True,
|
||||||
|
"batch_size": 64,
|
||||||
|
"progress_bar": False,
|
||||||
|
"meta_fields_to_embed": ["test_field"],
|
||||||
|
"embedding_separator": " | ",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
embedder = HuggingFaceTEIDocumentEmbedder.from_dict(data)
|
||||||
|
|
||||||
|
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
|
||||||
|
assert embedder.url == "https://some_embedding_model.com"
|
||||||
|
assert embedder.token == Secret.from_env_var("ENV_VAR", strict=False)
|
||||||
|
assert embedder.prefix == "prefix"
|
||||||
|
assert embedder.suffix == "suffix"
|
||||||
|
assert embedder.truncate is False
|
||||||
|
assert embedder.normalize is True
|
||||||
|
assert embedder.batch_size == 64
|
||||||
|
assert embedder.progress_bar is False
|
||||||
|
assert embedder.meta_fields_to_embed == ["test_field"]
|
||||||
|
assert embedder.embedding_separator == " | "
|
||||||
|
|
||||||
def test_prepare_texts_to_embed_w_metadata(self, mock_check_valid_model):
|
def test_prepare_texts_to_embed_w_metadata(self, mock_check_valid_model):
|
||||||
documents = [
|
documents = [
|
||||||
Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5)
|
Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5)
|
||||||
@ -167,7 +243,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
|||||||
def test_embed_batch(self, mock_check_valid_model):
|
def test_embed_batch(self, mock_check_valid_model):
|
||||||
texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]
|
texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]
|
||||||
|
|
||||||
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
|
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
|
||||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
mock_embedding_patch.side_effect = mock_embedding_generation
|
||||||
|
|
||||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||||
@ -192,7 +268,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
|||||||
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
|
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
|
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
|
||||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
mock_embedding_patch.side_effect = mock_embedding_generation
|
||||||
|
|
||||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||||
@ -207,10 +283,15 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
|||||||
result = embedder.run(documents=docs)
|
result = embedder.run(documents=docs)
|
||||||
|
|
||||||
mock_embedding_patch.assert_called_once_with(
|
mock_embedding_patch.assert_called_once_with(
|
||||||
text=[
|
json={
|
||||||
"prefix Cuisine | I love cheese suffix",
|
"inputs": [
|
||||||
"prefix ML | A transformer is a deep learning architecture suffix",
|
"prefix Cuisine | I love cheese suffix",
|
||||||
]
|
"prefix ML | A transformer is a deep learning architecture suffix",
|
||||||
|
],
|
||||||
|
"truncate": True,
|
||||||
|
"normalize": False,
|
||||||
|
},
|
||||||
|
task="feature-extraction",
|
||||||
)
|
)
|
||||||
documents_with_embeddings = result["documents"]
|
documents_with_embeddings = result["documents"]
|
||||||
|
|
||||||
@ -251,7 +332,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
|||||||
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
|
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
|
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
|
||||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
mock_embedding_patch.side_effect = mock_embedding_generation
|
||||||
|
|
||||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||||
|
@ -16,8 +16,8 @@ def mock_check_valid_model():
|
|||||||
yield mock
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
def mock_embedding_generation(text, **kwargs):
|
def mock_embedding_generation(json, **kwargs):
|
||||||
response = np.array([np.random.rand(384) for i in range(len(text))])
|
response = str(np.array([np.random.rand(384) for i in range(len(json["inputs"]))]).tolist()).encode()
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@ -31,6 +31,8 @@ class TestHuggingFaceTEITextEmbedder:
|
|||||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||||
assert embedder.prefix == ""
|
assert embedder.prefix == ""
|
||||||
assert embedder.suffix == ""
|
assert embedder.suffix == ""
|
||||||
|
assert embedder.truncate is True
|
||||||
|
assert embedder.normalize is False
|
||||||
|
|
||||||
def test_init_with_parameters(self, mock_check_valid_model):
|
def test_init_with_parameters(self, mock_check_valid_model):
|
||||||
embedder = HuggingFaceTEITextEmbedder(
|
embedder = HuggingFaceTEITextEmbedder(
|
||||||
@ -39,6 +41,8 @@ class TestHuggingFaceTEITextEmbedder:
|
|||||||
token=Secret.from_token("fake-api-token"),
|
token=Secret.from_token("fake-api-token"),
|
||||||
prefix="prefix",
|
prefix="prefix",
|
||||||
suffix="suffix",
|
suffix="suffix",
|
||||||
|
truncate=False,
|
||||||
|
normalize=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
|
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
|
||||||
@ -46,6 +50,8 @@ class TestHuggingFaceTEITextEmbedder:
|
|||||||
assert embedder.token == Secret.from_token("fake-api-token")
|
assert embedder.token == Secret.from_token("fake-api-token")
|
||||||
assert embedder.prefix == "prefix"
|
assert embedder.prefix == "prefix"
|
||||||
assert embedder.suffix == "suffix"
|
assert embedder.suffix == "suffix"
|
||||||
|
assert embedder.truncate is False
|
||||||
|
assert embedder.normalize is True
|
||||||
|
|
||||||
def test_initialize_with_invalid_url(self, mock_check_valid_model):
|
def test_initialize_with_invalid_url(self, mock_check_valid_model):
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
@ -69,9 +75,35 @@ class TestHuggingFaceTEITextEmbedder:
|
|||||||
"url": None,
|
"url": None,
|
||||||
"prefix": "",
|
"prefix": "",
|
||||||
"suffix": "",
|
"suffix": "",
|
||||||
|
"truncate": True,
|
||||||
|
"normalize": False,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def test_from_dict(self, mock_check_valid_model):
|
||||||
|
data = {
|
||||||
|
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
|
||||||
|
"init_parameters": {
|
||||||
|
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||||
|
"model": "BAAI/bge-small-en-v1.5",
|
||||||
|
"url": None,
|
||||||
|
"prefix": "",
|
||||||
|
"suffix": "",
|
||||||
|
"truncate": True,
|
||||||
|
"normalize": False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
embedder = HuggingFaceTEITextEmbedder.from_dict(data)
|
||||||
|
|
||||||
|
assert embedder.model == "BAAI/bge-small-en-v1.5"
|
||||||
|
assert embedder.url is None
|
||||||
|
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||||
|
assert embedder.prefix == ""
|
||||||
|
assert embedder.suffix == ""
|
||||||
|
assert embedder.truncate is True
|
||||||
|
assert embedder.normalize is False
|
||||||
|
|
||||||
def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
||||||
component = HuggingFaceTEITextEmbedder(
|
component = HuggingFaceTEITextEmbedder(
|
||||||
model="sentence-transformers/all-mpnet-base-v2",
|
model="sentence-transformers/all-mpnet-base-v2",
|
||||||
@ -79,6 +111,8 @@ class TestHuggingFaceTEITextEmbedder:
|
|||||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
token=Secret.from_env_var("ENV_VAR", strict=False),
|
||||||
prefix="prefix",
|
prefix="prefix",
|
||||||
suffix="suffix",
|
suffix="suffix",
|
||||||
|
truncate=False,
|
||||||
|
normalize=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
data = component.to_dict()
|
data = component.to_dict()
|
||||||
@ -91,11 +125,37 @@ class TestHuggingFaceTEITextEmbedder:
|
|||||||
"url": "https://some_embedding_model.com",
|
"url": "https://some_embedding_model.com",
|
||||||
"prefix": "prefix",
|
"prefix": "prefix",
|
||||||
"suffix": "suffix",
|
"suffix": "suffix",
|
||||||
|
"truncate": False,
|
||||||
|
"normalize": True,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def test_from_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
||||||
|
data = {
|
||||||
|
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
|
||||||
|
"init_parameters": {
|
||||||
|
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||||
|
"model": "sentence-transformers/all-mpnet-base-v2",
|
||||||
|
"url": "https://some_embedding_model.com",
|
||||||
|
"prefix": "prefix",
|
||||||
|
"suffix": "suffix",
|
||||||
|
"truncate": False,
|
||||||
|
"normalize": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
embedder = HuggingFaceTEITextEmbedder.from_dict(data)
|
||||||
|
|
||||||
|
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
|
||||||
|
assert embedder.url == "https://some_embedding_model.com"
|
||||||
|
assert embedder.token == Secret.from_env_var("ENV_VAR", strict=False)
|
||||||
|
assert embedder.prefix == "prefix"
|
||||||
|
assert embedder.suffix == "suffix"
|
||||||
|
assert embedder.truncate is False
|
||||||
|
assert embedder.normalize is True
|
||||||
|
|
||||||
def test_run(self, mock_check_valid_model):
|
def test_run(self, mock_check_valid_model):
|
||||||
with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
|
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
|
||||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
mock_embedding_patch.side_effect = mock_embedding_generation
|
||||||
|
|
||||||
embedder = HuggingFaceTEITextEmbedder(
|
embedder = HuggingFaceTEITextEmbedder(
|
||||||
@ -107,7 +167,10 @@ class TestHuggingFaceTEITextEmbedder:
|
|||||||
|
|
||||||
result = embedder.run(text="The food was delicious")
|
result = embedder.run(text="The food was delicious")
|
||||||
|
|
||||||
mock_embedding_patch.assert_called_once_with(text=["prefix The food was delicious suffix"])
|
mock_embedding_patch.assert_called_once_with(
|
||||||
|
json={"inputs": ["prefix The food was delicious suffix"], "truncate": True, "normalize": False},
|
||||||
|
task="feature-extraction",
|
||||||
|
)
|
||||||
|
|
||||||
assert len(result["embedding"]) == 384
|
assert len(result["embedding"]) == 384
|
||||||
assert all(isinstance(x, float) for x in result["embedding"])
|
assert all(isinstance(x, float) for x in result["embedding"])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user