mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-03 18:36:04 +00:00
chore: remove deprecated TEI embedders (#7907)
* remove deprecated TEI embedders * rm from the embedders init * rm related tests
This commit is contained in:
parent
d1f8c0dcd6
commit
75ad76a7ce
@ -5,8 +5,6 @@ loaders:
|
||||
[
|
||||
"azure_document_embedder",
|
||||
"azure_text_embedder",
|
||||
"hugging_face_tei_document_embedder",
|
||||
"hugging_face_tei_text_embedder",
|
||||
"hugging_face_api_document_embedder",
|
||||
"hugging_face_api_text_embedder",
|
||||
"openai_document_embedder",
|
||||
|
||||
@ -6,16 +6,12 @@ from haystack.components.embedders.azure_document_embedder import AzureOpenAIDoc
|
||||
from haystack.components.embedders.azure_text_embedder import AzureOpenAITextEmbedder
|
||||
from haystack.components.embedders.hugging_face_api_document_embedder import HuggingFaceAPIDocumentEmbedder
|
||||
from haystack.components.embedders.hugging_face_api_text_embedder import HuggingFaceAPITextEmbedder
|
||||
from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder
|
||||
from haystack.components.embedders.hugging_face_tei_text_embedder import HuggingFaceTEITextEmbedder
|
||||
from haystack.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder
|
||||
from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
|
||||
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
|
||||
from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
|
||||
|
||||
__all__ = [
|
||||
"HuggingFaceTEITextEmbedder",
|
||||
"HuggingFaceTEIDocumentEmbedder",
|
||||
"HuggingFaceAPITextEmbedder",
|
||||
"HuggingFaceAPIDocumentEmbedder",
|
||||
"SentenceTransformersTextEmbedder",
|
||||
|
||||
@ -1,231 +0,0 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.dataclasses import Document
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
from haystack.utils.hf import HFModelType, check_valid_model
|
||||
|
||||
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@component
|
||||
class HuggingFaceTEIDocumentEmbedder:
|
||||
"""
|
||||
A component for computing Document embeddings using HuggingFace Text-Embeddings-Inference endpoints.
|
||||
|
||||
This component can be used with embedding models hosted on Hugging Face Inference endpoints, the rate-limited
|
||||
Inference API tier, for embedding models hosted on [the paid inference endpoint](https://huggingface.co/inference-endpoints)
|
||||
and/or your own custom TEI endpoint.
|
||||
|
||||
Usage example:
|
||||
```python
|
||||
from haystack.dataclasses import Document
|
||||
from haystack.components.embedders import HuggingFaceTEIDocumentEmbedder
|
||||
from haystack.utils import Secret
|
||||
|
||||
doc = Document(content="I love pizza!")
|
||||
|
||||
document_embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5", token=Secret.from_token("<your-api-key>")
|
||||
)
|
||||
|
||||
result = document_embedder.run([doc])
|
||||
print(result["documents"][0].embedding)
|
||||
|
||||
# [0.017020374536514282, -0.023255806416273117, ...]
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "BAAI/bge-small-en-v1.5",
|
||||
url: Optional[str] = None,
|
||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||
prefix: str = "",
|
||||
suffix: str = "",
|
||||
truncate: bool = True,
|
||||
normalize: bool = False,
|
||||
batch_size: int = 32,
|
||||
progress_bar: bool = True,
|
||||
meta_fields_to_embed: Optional[List[str]] = None,
|
||||
embedding_separator: str = "\n",
|
||||
):
|
||||
"""
|
||||
Create a HuggingFaceTEIDocumentEmbedder component.
|
||||
|
||||
:param model:
|
||||
ID of the model on HuggingFace Hub.
|
||||
:param url:
|
||||
The URL of your self-deployed Text-Embeddings-Inference service or the URL of your paid HF Inference
|
||||
Endpoint.
|
||||
:param token:
|
||||
The HuggingFace Hub token. This is needed if you are using a paid HF Inference Endpoint or serving
|
||||
a private or gated model.
|
||||
:param prefix:
|
||||
A string to add at the beginning of each text.
|
||||
:param suffix:
|
||||
A string to add at the end of each text.
|
||||
:param truncate:
|
||||
Truncate input text from the end to the maximum length supported by the model. This option is only available
|
||||
for self-deployed Text Embedding Inference (TEI) endpoints and paid HF Inference Endpoints deployed with
|
||||
TEI. It will be ignored when used with free HF Inference endpoints or paid HF Inference endpoints deployed
|
||||
without TEI.
|
||||
:param normalize:
|
||||
Normalize the embeddings to unit length. This option is only available for self-deployed Text Embedding
|
||||
Inference (TEI) endpoints and paid HF Inference Endpoints deployed with TEI. It will be ignored when used
|
||||
with free HF Inference endpoints or paid HF Inference endpoints deployed without TEI.
|
||||
:param batch_size:
|
||||
Number of Documents to encode at once.
|
||||
:param progress_bar:
|
||||
If True shows a progress bar when running.
|
||||
:param meta_fields_to_embed:
|
||||
List of meta fields that will be embedded along with the Document text.
|
||||
:param embedding_separator:
|
||||
Separator used to concatenate the meta fields to the Document text.
|
||||
"""
|
||||
warnings.warn(
|
||||
"`HuggingFaceTEIDocumentEmbedder` is deprecated and will be removed in Haystack 2.3.0."
|
||||
"Use `HuggingFaceAPIDocumentEmbedder` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
huggingface_hub_import.check()
|
||||
|
||||
if url:
|
||||
r = urlparse(url)
|
||||
is_valid_url = all([r.scheme in ["http", "https"], r.netloc])
|
||||
if not is_valid_url:
|
||||
raise ValueError(f"Invalid TEI endpoint URL provided: {url}")
|
||||
|
||||
check_valid_model(model, HFModelType.EMBEDDING, token)
|
||||
|
||||
self.model = model
|
||||
self.url = url
|
||||
self.token = token
|
||||
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
|
||||
self.prefix = prefix
|
||||
self.suffix = suffix
|
||||
self.truncate = truncate
|
||||
self.normalize = normalize
|
||||
self.batch_size = batch_size
|
||||
self.progress_bar = progress_bar
|
||||
self.meta_fields_to_embed = meta_fields_to_embed or []
|
||||
self.embedding_separator = embedding_separator
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes the component to a dictionary.
|
||||
|
||||
:returns:
|
||||
Dictionary with serialized data.
|
||||
"""
|
||||
return default_to_dict(
|
||||
self,
|
||||
model=self.model,
|
||||
url=self.url,
|
||||
prefix=self.prefix,
|
||||
suffix=self.suffix,
|
||||
truncate=self.truncate,
|
||||
normalize=self.normalize,
|
||||
batch_size=self.batch_size,
|
||||
progress_bar=self.progress_bar,
|
||||
meta_fields_to_embed=self.meta_fields_to_embed,
|
||||
embedding_separator=self.embedding_separator,
|
||||
token=self.token.to_dict() if self.token else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTEIDocumentEmbedder":
|
||||
"""
|
||||
Deserializes the component from a dictionary.
|
||||
|
||||
:param data:
|
||||
Dictionary to deserialize from.
|
||||
:returns:
|
||||
Deserialized component.
|
||||
"""
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Data that is sent to Posthog for usage analytics.
|
||||
"""
|
||||
# Don't send URL as it is sensitive information
|
||||
return {"model": self.model}
|
||||
|
||||
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
|
||||
"""
|
||||
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
|
||||
"""
|
||||
texts_to_embed = []
|
||||
for doc in documents:
|
||||
meta_values_to_embed = [
|
||||
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
|
||||
]
|
||||
|
||||
text_to_embed = (
|
||||
self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
|
||||
)
|
||||
|
||||
texts_to_embed.append(text_to_embed)
|
||||
return texts_to_embed
|
||||
|
||||
def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[float]]:
|
||||
"""
|
||||
Embed a list of texts in batches.
|
||||
"""
|
||||
|
||||
all_embeddings = []
|
||||
for i in tqdm(
|
||||
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
|
||||
):
|
||||
batch = texts_to_embed[i : i + batch_size]
|
||||
response = self.client.post(
|
||||
json={"inputs": batch, "truncate": self.truncate, "normalize": self.normalize},
|
||||
task="feature-extraction",
|
||||
)
|
||||
embeddings = json.loads(response.decode())
|
||||
all_embeddings.extend(embeddings)
|
||||
|
||||
return all_embeddings
|
||||
|
||||
@component.output_types(documents=List[Document])
|
||||
def run(self, documents: List[Document]):
|
||||
"""
|
||||
Embed a list of Documents.
|
||||
|
||||
:param documents:
|
||||
Documents to embed.
|
||||
|
||||
:returns:
|
||||
A dictionary with the following keys:
|
||||
- `documents`: Documents with embeddings
|
||||
"""
|
||||
if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
|
||||
raise TypeError(
|
||||
"HuggingFaceTEIDocumentEmbedder expects a list of Documents as input."
|
||||
" In case you want to embed a string, please use the HuggingFaceTEITextEmbedder."
|
||||
)
|
||||
|
||||
texts_to_embed = self._prepare_texts_to_embed(documents=documents)
|
||||
|
||||
embeddings = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
|
||||
|
||||
for doc, emb in zip(documents, embeddings):
|
||||
doc.embedding = emb
|
||||
|
||||
return {"documents": documents}
|
||||
@ -1,171 +0,0 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
from haystack.utils.hf import HFModelType, check_valid_model
|
||||
|
||||
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@component
|
||||
class HuggingFaceTEITextEmbedder:
|
||||
"""
|
||||
A component for embedding strings using HuggingFace Text-Embeddings-Inference endpoints.
|
||||
|
||||
This component can be used with embedding models hosted on Hugging Face Inference endpoints, the rate-limited
|
||||
Inference API tier, for embedding models hosted on [the paid inference endpoint](https://huggingface.co/inference-endpoints)
|
||||
and/or your own custom TEI endpoint.
|
||||
|
||||
Usage example:
|
||||
```python
|
||||
from haystack.components.embedders import HuggingFaceTEITextEmbedder
|
||||
from haystack.utils import Secret
|
||||
|
||||
text_to_embed = "I love pizza!"
|
||||
|
||||
text_embedder = HuggingFaceTEITextEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5", token=Secret.from_token("<your-api-key>")
|
||||
)
|
||||
|
||||
print(text_embedder.run(text_to_embed))
|
||||
|
||||
# {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "BAAI/bge-small-en-v1.5",
|
||||
url: Optional[str] = None,
|
||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||
prefix: str = "",
|
||||
suffix: str = "",
|
||||
truncate: bool = True,
|
||||
normalize: bool = False,
|
||||
):
|
||||
"""
|
||||
Create an HuggingFaceTEITextEmbedder component.
|
||||
|
||||
:param model:
|
||||
ID of the model on HuggingFace Hub.
|
||||
:param url:
|
||||
The URL of your self-deployed Text-Embeddings-Inference service or the URL of your paid HF Inference
|
||||
Endpoint.
|
||||
:param token:
|
||||
The HuggingFace Hub token. This is needed if you are using a paid HF Inference Endpoint or serving
|
||||
a private or gated model.
|
||||
:param prefix:
|
||||
A string to add at the beginning of each text.
|
||||
:param suffix:
|
||||
A string to add at the end of each text.
|
||||
:param truncate:
|
||||
Truncate input text from the end to the maximum length supported by the model. This option is only available
|
||||
for self-deployed Text Embedding Inference (TEI) endpoints and paid HF Inference Endpoints deployed with
|
||||
TEI. It will be ignored when used with free HF Inference endpoints or paid HF Inference endpoints deployed
|
||||
without TEI.
|
||||
:param normalize:
|
||||
Normalize the embeddings to unit length. This option is only available for self-deployed Text Embedding
|
||||
Inference (TEI) endpoints and paid HF Inference Endpoints deployed with TEI. It will be ignored when used
|
||||
with free HF Inference endpoints or paid HF Inference endpoints deployed without TEI.
|
||||
"""
|
||||
warnings.warn(
|
||||
"`HuggingFaceTEITextEmbedder` is deprecated and will be removed in Haystack 2.3.0."
|
||||
"Use `HuggingFaceAPITextEmbedder` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
huggingface_hub_import.check()
|
||||
|
||||
if url:
|
||||
r = urlparse(url)
|
||||
is_valid_url = all([r.scheme in ["http", "https"], r.netloc])
|
||||
if not is_valid_url:
|
||||
raise ValueError(f"Invalid TEI endpoint URL provided: {url}")
|
||||
|
||||
check_valid_model(model, HFModelType.EMBEDDING, token)
|
||||
|
||||
self.model = model
|
||||
self.url = url
|
||||
self.token = token
|
||||
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
|
||||
self.prefix = prefix
|
||||
self.suffix = suffix
|
||||
self.truncate = truncate
|
||||
self.normalize = normalize
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes the component to a dictionary.
|
||||
|
||||
:returns:
|
||||
Dictionary with serialized data.
|
||||
"""
|
||||
return default_to_dict(
|
||||
self,
|
||||
model=self.model,
|
||||
url=self.url,
|
||||
prefix=self.prefix,
|
||||
suffix=self.suffix,
|
||||
token=self.token.to_dict() if self.token else None,
|
||||
truncate=self.truncate,
|
||||
normalize=self.normalize,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTEITextEmbedder":
|
||||
"""
|
||||
Deserializes the component from a dictionary.
|
||||
|
||||
:param data:
|
||||
Dictionary to deserialize from.
|
||||
:returns:
|
||||
Deserialized component.
|
||||
"""
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Data that is sent to Posthog for usage analytics.
|
||||
"""
|
||||
# Don't send URL as it is sensitive information
|
||||
return {"model": self.model}
|
||||
|
||||
@component.output_types(embedding=List[float])
|
||||
def run(self, text: str):
|
||||
"""
|
||||
Embed a single string.
|
||||
|
||||
:param text:
|
||||
Text to embed.
|
||||
|
||||
:returns:
|
||||
A dictionary with the following keys:
|
||||
- `embedding`: The embedding of the input text.
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
raise TypeError(
|
||||
"HuggingFaceTEITextEmbedder expects a string as an input."
|
||||
"In case you want to embed a list of Documents, please use the HuggingFaceTEIDocumentEmbedder."
|
||||
)
|
||||
|
||||
text_to_embed = self.prefix + text + self.suffix
|
||||
|
||||
response = self.client.post(
|
||||
json={"inputs": [text_to_embed], "truncate": self.truncate, "normalize": self.normalize},
|
||||
task="feature-extraction",
|
||||
)
|
||||
embedding = json.loads(response.decode())[0]
|
||||
|
||||
return {"embedding": embedding}
|
||||
@ -0,0 +1,5 @@
|
||||
---
|
||||
upgrade:
|
||||
- |
|
||||
Deprecated `HuggingFaceTEITextEmbedder` and `HuggingFaceTEIDocumentEmbedder` have been removed.
|
||||
Use `HuggingFaceAPITextEmbedder` and `HuggingFaceAPIDocumentEmbedder` instead.
|
||||
@ -1,398 +0,0 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
|
||||
from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder
|
||||
from haystack.dataclasses import Document
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_check_valid_model():
|
||||
with patch(
|
||||
"haystack.components.embedders.hugging_face_tei_document_embedder.check_valid_model",
|
||||
MagicMock(return_value=None),
|
||||
) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
def mock_embedding_generation(json, **kwargs):
|
||||
response = str(np.array([np.random.rand(384) for i in range(len(json["inputs"]))]).tolist()).encode()
|
||||
return response
|
||||
|
||||
|
||||
class TestHuggingFaceTEIDocumentEmbedder:
|
||||
def test_init_default(self, monkeypatch, mock_check_valid_model):
|
||||
monkeypatch.setenv("HF_API_TOKEN", "fake-api-token")
|
||||
embedder = HuggingFaceTEIDocumentEmbedder()
|
||||
|
||||
assert embedder.model == "BAAI/bge-small-en-v1.5"
|
||||
assert embedder.url is None
|
||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||
assert embedder.prefix == ""
|
||||
assert embedder.suffix == ""
|
||||
assert embedder.truncate is True
|
||||
assert embedder.normalize is False
|
||||
assert embedder.batch_size == 32
|
||||
assert embedder.progress_bar is True
|
||||
assert embedder.meta_fields_to_embed == []
|
||||
assert embedder.embedding_separator == "\n"
|
||||
|
||||
def test_init_with_parameters(self, mock_check_valid_model):
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="sentence-transformers/all-mpnet-base-v2",
|
||||
url="https://some_embedding_model.com",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
truncate=False,
|
||||
normalize=True,
|
||||
batch_size=64,
|
||||
progress_bar=False,
|
||||
meta_fields_to_embed=["test_field"],
|
||||
embedding_separator=" | ",
|
||||
)
|
||||
|
||||
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
|
||||
assert embedder.url == "https://some_embedding_model.com"
|
||||
assert embedder.token == Secret.from_token("fake-api-token")
|
||||
assert embedder.prefix == "prefix"
|
||||
assert embedder.suffix == "suffix"
|
||||
assert embedder.truncate is False
|
||||
assert embedder.normalize is True
|
||||
assert embedder.batch_size == 64
|
||||
assert embedder.progress_bar is False
|
||||
assert embedder.meta_fields_to_embed == ["test_field"]
|
||||
assert embedder.embedding_separator == " | "
|
||||
|
||||
def test_initialize_with_invalid_url(self, mock_check_valid_model):
|
||||
with pytest.raises(ValueError):
|
||||
HuggingFaceTEIDocumentEmbedder(model="sentence-transformers/all-mpnet-base-v2", url="invalid_url")
|
||||
|
||||
def test_initialize_with_url_but_invalid_model(self, mock_check_valid_model):
|
||||
# When custom TEI endpoint is used via URL, model must be provided and valid HuggingFace Hub model id
|
||||
mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
|
||||
with pytest.raises(RepositoryNotFoundError):
|
||||
HuggingFaceTEIDocumentEmbedder(model="invalid_model_id", url="https://some_embedding_model.com")
|
||||
|
||||
def test_to_dict(self, mock_check_valid_model):
|
||||
component = HuggingFaceTEIDocumentEmbedder()
|
||||
data = component.to_dict()
|
||||
|
||||
assert data == {
|
||||
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
|
||||
"init_parameters": {
|
||||
"model": "BAAI/bge-small-en-v1.5",
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"url": None,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
"truncate": True,
|
||||
"normalize": False,
|
||||
"batch_size": 32,
|
||||
"progress_bar": True,
|
||||
"meta_fields_to_embed": [],
|
||||
"embedding_separator": "\n",
|
||||
},
|
||||
}
|
||||
|
||||
def test_from_dict(self, mock_check_valid_model):
|
||||
data = {
|
||||
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
|
||||
"init_parameters": {
|
||||
"model": "BAAI/bge-small-en-v1.5",
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"url": None,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
"truncate": True,
|
||||
"normalize": False,
|
||||
"batch_size": 32,
|
||||
"progress_bar": True,
|
||||
"meta_fields_to_embed": [],
|
||||
"embedding_separator": "\n",
|
||||
},
|
||||
}
|
||||
|
||||
embedder = HuggingFaceTEIDocumentEmbedder.from_dict(data)
|
||||
|
||||
assert embedder.model == "BAAI/bge-small-en-v1.5"
|
||||
assert embedder.url is None
|
||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||
assert embedder.prefix == ""
|
||||
assert embedder.suffix == ""
|
||||
assert embedder.truncate is True
|
||||
assert embedder.normalize is False
|
||||
assert embedder.batch_size == 32
|
||||
assert embedder.progress_bar is True
|
||||
assert embedder.meta_fields_to_embed == []
|
||||
assert embedder.embedding_separator == "\n"
|
||||
|
||||
def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
||||
component = HuggingFaceTEIDocumentEmbedder(
|
||||
model="sentence-transformers/all-mpnet-base-v2",
|
||||
url="https://some_embedding_model.com",
|
||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
truncate=False,
|
||||
normalize=True,
|
||||
batch_size=64,
|
||||
progress_bar=False,
|
||||
meta_fields_to_embed=["test_field"],
|
||||
embedding_separator=" | ",
|
||||
)
|
||||
|
||||
data = component.to_dict()
|
||||
|
||||
assert data == {
|
||||
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
|
||||
"init_parameters": {
|
||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"model": "sentence-transformers/all-mpnet-base-v2",
|
||||
"url": "https://some_embedding_model.com",
|
||||
"prefix": "prefix",
|
||||
"suffix": "suffix",
|
||||
"truncate": False,
|
||||
"normalize": True,
|
||||
"batch_size": 64,
|
||||
"progress_bar": False,
|
||||
"meta_fields_to_embed": ["test_field"],
|
||||
"embedding_separator": " | ",
|
||||
},
|
||||
}
|
||||
|
||||
def test_from_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
||||
data = {
|
||||
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
|
||||
"init_parameters": {
|
||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"model": "sentence-transformers/all-mpnet-base-v2",
|
||||
"url": "https://some_embedding_model.com",
|
||||
"prefix": "prefix",
|
||||
"suffix": "suffix",
|
||||
"truncate": False,
|
||||
"normalize": True,
|
||||
"batch_size": 64,
|
||||
"progress_bar": False,
|
||||
"meta_fields_to_embed": ["test_field"],
|
||||
"embedding_separator": " | ",
|
||||
},
|
||||
}
|
||||
|
||||
embedder = HuggingFaceTEIDocumentEmbedder.from_dict(data)
|
||||
|
||||
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
|
||||
assert embedder.url == "https://some_embedding_model.com"
|
||||
assert embedder.token == Secret.from_env_var("ENV_VAR", strict=False)
|
||||
assert embedder.prefix == "prefix"
|
||||
assert embedder.suffix == "suffix"
|
||||
assert embedder.truncate is False
|
||||
assert embedder.normalize is True
|
||||
assert embedder.batch_size == 64
|
||||
assert embedder.progress_bar is False
|
||||
assert embedder.meta_fields_to_embed == ["test_field"]
|
||||
assert embedder.embedding_separator == " | "
|
||||
|
||||
def test_prepare_texts_to_embed_w_metadata(self, mock_check_valid_model):
|
||||
documents = [
|
||||
Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5)
|
||||
]
|
||||
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="sentence-transformers/all-mpnet-base-v2",
|
||||
url="https://some_embedding_model.com",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
meta_fields_to_embed=["meta_field"],
|
||||
embedding_separator=" | ",
|
||||
)
|
||||
|
||||
prepared_texts = embedder._prepare_texts_to_embed(documents)
|
||||
|
||||
assert prepared_texts == [
|
||||
"meta_value 0 | document number 0: content",
|
||||
"meta_value 1 | document number 1: content",
|
||||
"meta_value 2 | document number 2: content",
|
||||
"meta_value 3 | document number 3: content",
|
||||
"meta_value 4 | document number 4: content",
|
||||
]
|
||||
|
||||
def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model):
|
||||
documents = [Document(content=f"document number {i}") for i in range(5)]
|
||||
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="sentence-transformers/all-mpnet-base-v2",
|
||||
url="https://some_embedding_model.com",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
prefix="my_prefix ",
|
||||
suffix=" my_suffix",
|
||||
)
|
||||
|
||||
prepared_texts = embedder._prepare_texts_to_embed(documents)
|
||||
|
||||
assert prepared_texts == [
|
||||
"my_prefix document number 0 my_suffix",
|
||||
"my_prefix document number 1 my_suffix",
|
||||
"my_prefix document number 2 my_suffix",
|
||||
"my_prefix document number 3 my_suffix",
|
||||
"my_prefix document number 4 my_suffix",
|
||||
]
|
||||
|
||||
def test_embed_batch(self, mock_check_valid_model):
|
||||
texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]
|
||||
|
||||
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
|
||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
||||
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5",
|
||||
url="https://some_embedding_model.com",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
)
|
||||
embeddings = embedder._embed_batch(texts_to_embed=texts, batch_size=2)
|
||||
|
||||
assert mock_embedding_patch.call_count == 3
|
||||
|
||||
assert isinstance(embeddings, list)
|
||||
assert len(embeddings) == len(texts)
|
||||
for embedding in embeddings:
|
||||
assert isinstance(embedding, list)
|
||||
assert len(embedding) == 384
|
||||
assert all(isinstance(x, float) for x in embedding)
|
||||
|
||||
def test_run(self, mock_check_valid_model):
|
||||
docs = [
|
||||
Document(content="I love cheese", meta={"topic": "Cuisine"}),
|
||||
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
|
||||
]
|
||||
|
||||
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
|
||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
||||
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
prefix="prefix ",
|
||||
suffix=" suffix",
|
||||
meta_fields_to_embed=["topic"],
|
||||
embedding_separator=" | ",
|
||||
)
|
||||
|
||||
result = embedder.run(documents=docs)
|
||||
|
||||
mock_embedding_patch.assert_called_once_with(
|
||||
json={
|
||||
"inputs": [
|
||||
"prefix Cuisine | I love cheese suffix",
|
||||
"prefix ML | A transformer is a deep learning architecture suffix",
|
||||
],
|
||||
"truncate": True,
|
||||
"normalize": False,
|
||||
},
|
||||
task="feature-extraction",
|
||||
)
|
||||
documents_with_embeddings = result["documents"]
|
||||
|
||||
assert isinstance(documents_with_embeddings, list)
|
||||
assert len(documents_with_embeddings) == len(docs)
|
||||
for doc in documents_with_embeddings:
|
||||
assert isinstance(doc, Document)
|
||||
assert isinstance(doc.embedding, list)
|
||||
assert len(doc.embedding) == 384
|
||||
assert all(isinstance(x, float) for x in doc.embedding)
|
||||
|
||||
@pytest.mark.flaky(reruns=5, reruns_delay=5)
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("HF_API_TOKEN", None),
|
||||
reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
|
||||
)
|
||||
def test_run_inference_api_endpoint(self):
|
||||
docs = [
|
||||
Document(content="I love cheese", meta={"topic": "Cuisine"}),
|
||||
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
|
||||
]
|
||||
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2", meta_fields_to_embed=["topic"], embedding_separator=" | "
|
||||
)
|
||||
|
||||
result = embedder.run(documents=docs)
|
||||
documents_with_embeddings = result["documents"]
|
||||
|
||||
assert isinstance(documents_with_embeddings, list)
|
||||
assert len(documents_with_embeddings) == len(docs)
|
||||
for doc in documents_with_embeddings:
|
||||
assert isinstance(doc, Document)
|
||||
assert isinstance(doc.embedding, list)
|
||||
assert len(doc.embedding) == 384
|
||||
assert all(isinstance(x, float) for x in doc.embedding)
|
||||
|
||||
def test_run_custom_batch_size(self, mock_check_valid_model):
|
||||
docs = [
|
||||
Document(content="I love cheese", meta={"topic": "Cuisine"}),
|
||||
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
|
||||
]
|
||||
|
||||
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
|
||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
||||
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
prefix="prefix ",
|
||||
suffix=" suffix",
|
||||
meta_fields_to_embed=["topic"],
|
||||
embedding_separator=" | ",
|
||||
batch_size=1,
|
||||
)
|
||||
|
||||
result = embedder.run(documents=docs)
|
||||
|
||||
assert mock_embedding_patch.call_count == 2
|
||||
|
||||
documents_with_embeddings = result["documents"]
|
||||
|
||||
assert isinstance(documents_with_embeddings, list)
|
||||
assert len(documents_with_embeddings) == len(docs)
|
||||
for doc in documents_with_embeddings:
|
||||
assert isinstance(doc, Document)
|
||||
assert isinstance(doc.embedding, list)
|
||||
assert len(doc.embedding) == 384
|
||||
assert all(isinstance(x, float) for x in doc.embedding)
|
||||
|
||||
def test_run_wrong_input_format(self, mock_check_valid_model):
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5",
|
||||
url="https://some_embedding_model.com",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
)
|
||||
|
||||
# wrong formats
|
||||
string_input = "text"
|
||||
list_integers_input = [1, 2, 3]
|
||||
|
||||
with pytest.raises(TypeError, match="HuggingFaceTEIDocumentEmbedder expects a list of Documents as input"):
|
||||
embedder.run(documents=string_input)
|
||||
|
||||
with pytest.raises(TypeError, match="HuggingFaceTEIDocumentEmbedder expects a list of Documents as input"):
|
||||
embedder.run(documents=list_integers_input)
|
||||
|
||||
def test_run_on_empty_list(self, mock_check_valid_model):
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5",
|
||||
url="https://some_embedding_model.com",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
)
|
||||
|
||||
empty_list_input = []
|
||||
result = embedder.run(documents=empty_list_input)
|
||||
|
||||
assert result["documents"] is not None
|
||||
assert not result["documents"] # empty list
|
||||
@ -1,205 +0,0 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
|
||||
from haystack.components.embedders.hugging_face_tei_text_embedder import HuggingFaceTEITextEmbedder
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_check_valid_model():
|
||||
with patch(
|
||||
"haystack.components.embedders.hugging_face_tei_text_embedder.check_valid_model", MagicMock(return_value=None)
|
||||
) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
def mock_embedding_generation(json, **kwargs):
|
||||
response = str(np.array([np.random.rand(384) for i in range(len(json["inputs"]))]).tolist()).encode()
|
||||
return response
|
||||
|
||||
|
||||
class TestHuggingFaceTEITextEmbedder:
|
||||
def test_init_default(self, monkeypatch, mock_check_valid_model):
|
||||
monkeypatch.setenv("HF_API_TOKEN", "fake-api-token")
|
||||
embedder = HuggingFaceTEITextEmbedder()
|
||||
|
||||
assert embedder.model == "BAAI/bge-small-en-v1.5"
|
||||
assert embedder.url is None
|
||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||
assert embedder.prefix == ""
|
||||
assert embedder.suffix == ""
|
||||
assert embedder.truncate is True
|
||||
assert embedder.normalize is False
|
||||
|
||||
def test_init_with_parameters(self, mock_check_valid_model):
|
||||
embedder = HuggingFaceTEITextEmbedder(
|
||||
model="sentence-transformers/all-mpnet-base-v2",
|
||||
url="https://some_embedding_model.com",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
truncate=False,
|
||||
normalize=True,
|
||||
)
|
||||
|
||||
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
|
||||
assert embedder.url == "https://some_embedding_model.com"
|
||||
assert embedder.token == Secret.from_token("fake-api-token")
|
||||
assert embedder.prefix == "prefix"
|
||||
assert embedder.suffix == "suffix"
|
||||
assert embedder.truncate is False
|
||||
assert embedder.normalize is True
|
||||
|
||||
def test_initialize_with_invalid_url(self, mock_check_valid_model):
|
||||
with pytest.raises(ValueError):
|
||||
HuggingFaceTEITextEmbedder(model="sentence-transformers/all-mpnet-base-v2", url="invalid_url")
|
||||
|
||||
def test_initialize_with_url_but_invalid_model(self, mock_check_valid_model):
|
||||
# When custom TEI endpoint is used via URL, model must be provided and valid HuggingFace Hub model id
|
||||
mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
|
||||
with pytest.raises(RepositoryNotFoundError):
|
||||
HuggingFaceTEITextEmbedder(model="invalid_model_id", url="https://some_embedding_model.com")
|
||||
|
||||
def test_to_dict(self, mock_check_valid_model):
|
||||
component = HuggingFaceTEITextEmbedder()
|
||||
data = component.to_dict()
|
||||
|
||||
assert data == {
|
||||
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
|
||||
"init_parameters": {
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"model": "BAAI/bge-small-en-v1.5",
|
||||
"url": None,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
"truncate": True,
|
||||
"normalize": False,
|
||||
},
|
||||
}
|
||||
|
||||
def test_from_dict(self, mock_check_valid_model):
|
||||
data = {
|
||||
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
|
||||
"init_parameters": {
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"model": "BAAI/bge-small-en-v1.5",
|
||||
"url": None,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
"truncate": True,
|
||||
"normalize": False,
|
||||
},
|
||||
}
|
||||
|
||||
embedder = HuggingFaceTEITextEmbedder.from_dict(data)
|
||||
|
||||
assert embedder.model == "BAAI/bge-small-en-v1.5"
|
||||
assert embedder.url is None
|
||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||
assert embedder.prefix == ""
|
||||
assert embedder.suffix == ""
|
||||
assert embedder.truncate is True
|
||||
assert embedder.normalize is False
|
||||
|
||||
def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
||||
component = HuggingFaceTEITextEmbedder(
|
||||
model="sentence-transformers/all-mpnet-base-v2",
|
||||
url="https://some_embedding_model.com",
|
||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
truncate=False,
|
||||
normalize=True,
|
||||
)
|
||||
|
||||
data = component.to_dict()
|
||||
|
||||
assert data == {
|
||||
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
|
||||
"init_parameters": {
|
||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"model": "sentence-transformers/all-mpnet-base-v2",
|
||||
"url": "https://some_embedding_model.com",
|
||||
"prefix": "prefix",
|
||||
"suffix": "suffix",
|
||||
"truncate": False,
|
||||
"normalize": True,
|
||||
},
|
||||
}
|
||||
|
||||
def test_from_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
||||
data = {
|
||||
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
|
||||
"init_parameters": {
|
||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"model": "sentence-transformers/all-mpnet-base-v2",
|
||||
"url": "https://some_embedding_model.com",
|
||||
"prefix": "prefix",
|
||||
"suffix": "suffix",
|
||||
"truncate": False,
|
||||
"normalize": True,
|
||||
},
|
||||
}
|
||||
|
||||
embedder = HuggingFaceTEITextEmbedder.from_dict(data)
|
||||
|
||||
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
|
||||
assert embedder.url == "https://some_embedding_model.com"
|
||||
assert embedder.token == Secret.from_env_var("ENV_VAR", strict=False)
|
||||
assert embedder.prefix == "prefix"
|
||||
assert embedder.suffix == "suffix"
|
||||
assert embedder.truncate is False
|
||||
assert embedder.normalize is True
|
||||
|
||||
def test_run(self, mock_check_valid_model):
|
||||
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
|
||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
||||
|
||||
embedder = HuggingFaceTEITextEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
prefix="prefix ",
|
||||
suffix=" suffix",
|
||||
)
|
||||
|
||||
result = embedder.run(text="The food was delicious")
|
||||
|
||||
mock_embedding_patch.assert_called_once_with(
|
||||
json={"inputs": ["prefix The food was delicious suffix"], "truncate": True, "normalize": False},
|
||||
task="feature-extraction",
|
||||
)
|
||||
|
||||
assert len(result["embedding"]) == 384
|
||||
assert all(isinstance(x, float) for x in result["embedding"])
|
||||
|
||||
@pytest.mark.flaky(reruns=5, reruns_delay=5)
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("HF_API_TOKEN", None),
|
||||
reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
|
||||
)
|
||||
def test_run_inference_api_endpoint(self):
|
||||
embedder = HuggingFaceTEITextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
|
||||
result = embedder.run(text="The food was delicious")
|
||||
|
||||
assert len(result["embedding"]) == 384
|
||||
assert all(isinstance(x, float) for x in result["embedding"])
|
||||
|
||||
def test_run_wrong_input_format(self, mock_check_valid_model):
|
||||
embedder = HuggingFaceTEITextEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5",
|
||||
url="https://some_embedding_model.com",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
)
|
||||
|
||||
list_integers_input = [1, 2, 3]
|
||||
|
||||
with pytest.raises(TypeError, match="HuggingFaceTEITextEmbedder expects a string as an input"):
|
||||
embedder.run(text=list_integers_input)
|
||||
Loading…
x
Reference in New Issue
Block a user