mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-06 11:57:14 +00:00
chore: remove deprecated TEI embedders (#7907)
* remove deprecated TEI embedders * rm from the embedders init * rm related tests
This commit is contained in:
parent
d1f8c0dcd6
commit
75ad76a7ce
@ -5,8 +5,6 @@ loaders:
|
|||||||
[
|
[
|
||||||
"azure_document_embedder",
|
"azure_document_embedder",
|
||||||
"azure_text_embedder",
|
"azure_text_embedder",
|
||||||
"hugging_face_tei_document_embedder",
|
|
||||||
"hugging_face_tei_text_embedder",
|
|
||||||
"hugging_face_api_document_embedder",
|
"hugging_face_api_document_embedder",
|
||||||
"hugging_face_api_text_embedder",
|
"hugging_face_api_text_embedder",
|
||||||
"openai_document_embedder",
|
"openai_document_embedder",
|
||||||
|
|||||||
@ -6,16 +6,12 @@ from haystack.components.embedders.azure_document_embedder import AzureOpenAIDoc
|
|||||||
from haystack.components.embedders.azure_text_embedder import AzureOpenAITextEmbedder
|
from haystack.components.embedders.azure_text_embedder import AzureOpenAITextEmbedder
|
||||||
from haystack.components.embedders.hugging_face_api_document_embedder import HuggingFaceAPIDocumentEmbedder
|
from haystack.components.embedders.hugging_face_api_document_embedder import HuggingFaceAPIDocumentEmbedder
|
||||||
from haystack.components.embedders.hugging_face_api_text_embedder import HuggingFaceAPITextEmbedder
|
from haystack.components.embedders.hugging_face_api_text_embedder import HuggingFaceAPITextEmbedder
|
||||||
from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder
|
|
||||||
from haystack.components.embedders.hugging_face_tei_text_embedder import HuggingFaceTEITextEmbedder
|
|
||||||
from haystack.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder
|
from haystack.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder
|
||||||
from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
|
from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
|
||||||
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
|
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
|
||||||
from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
|
from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"HuggingFaceTEITextEmbedder",
|
|
||||||
"HuggingFaceTEIDocumentEmbedder",
|
|
||||||
"HuggingFaceAPITextEmbedder",
|
"HuggingFaceAPITextEmbedder",
|
||||||
"HuggingFaceAPIDocumentEmbedder",
|
"HuggingFaceAPIDocumentEmbedder",
|
||||||
"SentenceTransformersTextEmbedder",
|
"SentenceTransformersTextEmbedder",
|
||||||
|
|||||||
@ -1,231 +0,0 @@
|
|||||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
||||||
#
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
import json
|
|
||||||
import warnings
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from haystack import component, default_from_dict, default_to_dict, logging
|
|
||||||
from haystack.dataclasses import Document
|
|
||||||
from haystack.lazy_imports import LazyImport
|
|
||||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
|
||||||
from haystack.utils.hf import HFModelType, check_valid_model
|
|
||||||
|
|
||||||
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
|
|
||||||
from huggingface_hub import InferenceClient
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@component
|
|
||||||
class HuggingFaceTEIDocumentEmbedder:
|
|
||||||
"""
|
|
||||||
A component for computing Document embeddings using HuggingFace Text-Embeddings-Inference endpoints.
|
|
||||||
|
|
||||||
This component can be used with embedding models hosted on Hugging Face Inference endpoints, the rate-limited
|
|
||||||
Inference API tier, for embedding models hosted on [the paid inference endpoint](https://huggingface.co/inference-endpoints)
|
|
||||||
and/or your own custom TEI endpoint.
|
|
||||||
|
|
||||||
Usage example:
|
|
||||||
```python
|
|
||||||
from haystack.dataclasses import Document
|
|
||||||
from haystack.components.embedders import HuggingFaceTEIDocumentEmbedder
|
|
||||||
from haystack.utils import Secret
|
|
||||||
|
|
||||||
doc = Document(content="I love pizza!")
|
|
||||||
|
|
||||||
document_embedder = HuggingFaceTEIDocumentEmbedder(
|
|
||||||
model="BAAI/bge-small-en-v1.5", token=Secret.from_token("<your-api-key>")
|
|
||||||
)
|
|
||||||
|
|
||||||
result = document_embedder.run([doc])
|
|
||||||
print(result["documents"][0].embedding)
|
|
||||||
|
|
||||||
# [0.017020374536514282, -0.023255806416273117, ...]
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: str = "BAAI/bge-small-en-v1.5",
|
|
||||||
url: Optional[str] = None,
|
|
||||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
|
||||||
prefix: str = "",
|
|
||||||
suffix: str = "",
|
|
||||||
truncate: bool = True,
|
|
||||||
normalize: bool = False,
|
|
||||||
batch_size: int = 32,
|
|
||||||
progress_bar: bool = True,
|
|
||||||
meta_fields_to_embed: Optional[List[str]] = None,
|
|
||||||
embedding_separator: str = "\n",
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create a HuggingFaceTEIDocumentEmbedder component.
|
|
||||||
|
|
||||||
:param model:
|
|
||||||
ID of the model on HuggingFace Hub.
|
|
||||||
:param url:
|
|
||||||
The URL of your self-deployed Text-Embeddings-Inference service or the URL of your paid HF Inference
|
|
||||||
Endpoint.
|
|
||||||
:param token:
|
|
||||||
The HuggingFace Hub token. This is needed if you are using a paid HF Inference Endpoint or serving
|
|
||||||
a private or gated model.
|
|
||||||
:param prefix:
|
|
||||||
A string to add at the beginning of each text.
|
|
||||||
:param suffix:
|
|
||||||
A string to add at the end of each text.
|
|
||||||
:param truncate:
|
|
||||||
Truncate input text from the end to the maximum length supported by the model. This option is only available
|
|
||||||
for self-deployed Text Embedding Inference (TEI) endpoints and paid HF Inference Endpoints deployed with
|
|
||||||
TEI. It will be ignored when used with free HF Inference endpoints or paid HF Inference endpoints deployed
|
|
||||||
without TEI.
|
|
||||||
:param normalize:
|
|
||||||
Normalize the embeddings to unit length. This option is only available for self-deployed Text Embedding
|
|
||||||
Inference (TEI) endpoints and paid HF Inference Endpoints deployed with TEI. It will be ignored when used
|
|
||||||
with free HF Inference endpoints or paid HF Inference endpoints deployed without TEI.
|
|
||||||
:param batch_size:
|
|
||||||
Number of Documents to encode at once.
|
|
||||||
:param progress_bar:
|
|
||||||
If True shows a progress bar when running.
|
|
||||||
:param meta_fields_to_embed:
|
|
||||||
List of meta fields that will be embedded along with the Document text.
|
|
||||||
:param embedding_separator:
|
|
||||||
Separator used to concatenate the meta fields to the Document text.
|
|
||||||
"""
|
|
||||||
warnings.warn(
|
|
||||||
"`HuggingFaceTEIDocumentEmbedder` is deprecated and will be removed in Haystack 2.3.0."
|
|
||||||
"Use `HuggingFaceAPIDocumentEmbedder` instead.",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
huggingface_hub_import.check()
|
|
||||||
|
|
||||||
if url:
|
|
||||||
r = urlparse(url)
|
|
||||||
is_valid_url = all([r.scheme in ["http", "https"], r.netloc])
|
|
||||||
if not is_valid_url:
|
|
||||||
raise ValueError(f"Invalid TEI endpoint URL provided: {url}")
|
|
||||||
|
|
||||||
check_valid_model(model, HFModelType.EMBEDDING, token)
|
|
||||||
|
|
||||||
self.model = model
|
|
||||||
self.url = url
|
|
||||||
self.token = token
|
|
||||||
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
|
|
||||||
self.prefix = prefix
|
|
||||||
self.suffix = suffix
|
|
||||||
self.truncate = truncate
|
|
||||||
self.normalize = normalize
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.progress_bar = progress_bar
|
|
||||||
self.meta_fields_to_embed = meta_fields_to_embed or []
|
|
||||||
self.embedding_separator = embedding_separator
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Serializes the component to a dictionary.
|
|
||||||
|
|
||||||
:returns:
|
|
||||||
Dictionary with serialized data.
|
|
||||||
"""
|
|
||||||
return default_to_dict(
|
|
||||||
self,
|
|
||||||
model=self.model,
|
|
||||||
url=self.url,
|
|
||||||
prefix=self.prefix,
|
|
||||||
suffix=self.suffix,
|
|
||||||
truncate=self.truncate,
|
|
||||||
normalize=self.normalize,
|
|
||||||
batch_size=self.batch_size,
|
|
||||||
progress_bar=self.progress_bar,
|
|
||||||
meta_fields_to_embed=self.meta_fields_to_embed,
|
|
||||||
embedding_separator=self.embedding_separator,
|
|
||||||
token=self.token.to_dict() if self.token else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTEIDocumentEmbedder":
|
|
||||||
"""
|
|
||||||
Deserializes the component from a dictionary.
|
|
||||||
|
|
||||||
:param data:
|
|
||||||
Dictionary to deserialize from.
|
|
||||||
:returns:
|
|
||||||
Deserialized component.
|
|
||||||
"""
|
|
||||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
|
||||||
return default_from_dict(cls, data)
|
|
||||||
|
|
||||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Data that is sent to Posthog for usage analytics.
|
|
||||||
"""
|
|
||||||
# Don't send URL as it is sensitive information
|
|
||||||
return {"model": self.model}
|
|
||||||
|
|
||||||
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
|
|
||||||
"""
|
|
||||||
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
|
|
||||||
"""
|
|
||||||
texts_to_embed = []
|
|
||||||
for doc in documents:
|
|
||||||
meta_values_to_embed = [
|
|
||||||
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
text_to_embed = (
|
|
||||||
self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
|
|
||||||
)
|
|
||||||
|
|
||||||
texts_to_embed.append(text_to_embed)
|
|
||||||
return texts_to_embed
|
|
||||||
|
|
||||||
def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[float]]:
|
|
||||||
"""
|
|
||||||
Embed a list of texts in batches.
|
|
||||||
"""
|
|
||||||
|
|
||||||
all_embeddings = []
|
|
||||||
for i in tqdm(
|
|
||||||
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
|
|
||||||
):
|
|
||||||
batch = texts_to_embed[i : i + batch_size]
|
|
||||||
response = self.client.post(
|
|
||||||
json={"inputs": batch, "truncate": self.truncate, "normalize": self.normalize},
|
|
||||||
task="feature-extraction",
|
|
||||||
)
|
|
||||||
embeddings = json.loads(response.decode())
|
|
||||||
all_embeddings.extend(embeddings)
|
|
||||||
|
|
||||||
return all_embeddings
|
|
||||||
|
|
||||||
@component.output_types(documents=List[Document])
|
|
||||||
def run(self, documents: List[Document]):
|
|
||||||
"""
|
|
||||||
Embed a list of Documents.
|
|
||||||
|
|
||||||
:param documents:
|
|
||||||
Documents to embed.
|
|
||||||
|
|
||||||
:returns:
|
|
||||||
A dictionary with the following keys:
|
|
||||||
- `documents`: Documents with embeddings
|
|
||||||
"""
|
|
||||||
if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
|
|
||||||
raise TypeError(
|
|
||||||
"HuggingFaceTEIDocumentEmbedder expects a list of Documents as input."
|
|
||||||
" In case you want to embed a string, please use the HuggingFaceTEITextEmbedder."
|
|
||||||
)
|
|
||||||
|
|
||||||
texts_to_embed = self._prepare_texts_to_embed(documents=documents)
|
|
||||||
|
|
||||||
embeddings = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
|
|
||||||
|
|
||||||
for doc, emb in zip(documents, embeddings):
|
|
||||||
doc.embedding = emb
|
|
||||||
|
|
||||||
return {"documents": documents}
|
|
||||||
@ -1,171 +0,0 @@
|
|||||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
||||||
#
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
import json
|
|
||||||
import warnings
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from haystack import component, default_from_dict, default_to_dict, logging
|
|
||||||
from haystack.lazy_imports import LazyImport
|
|
||||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
|
||||||
from haystack.utils.hf import HFModelType, check_valid_model
|
|
||||||
|
|
||||||
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
|
|
||||||
from huggingface_hub import InferenceClient
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@component
|
|
||||||
class HuggingFaceTEITextEmbedder:
|
|
||||||
"""
|
|
||||||
A component for embedding strings using HuggingFace Text-Embeddings-Inference endpoints.
|
|
||||||
|
|
||||||
This component can be used with embedding models hosted on Hugging Face Inference endpoints, the rate-limited
|
|
||||||
Inference API tier, for embedding models hosted on [the paid inference endpoint](https://huggingface.co/inference-endpoints)
|
|
||||||
and/or your own custom TEI endpoint.
|
|
||||||
|
|
||||||
Usage example:
|
|
||||||
```python
|
|
||||||
from haystack.components.embedders import HuggingFaceTEITextEmbedder
|
|
||||||
from haystack.utils import Secret
|
|
||||||
|
|
||||||
text_to_embed = "I love pizza!"
|
|
||||||
|
|
||||||
text_embedder = HuggingFaceTEITextEmbedder(
|
|
||||||
model="BAAI/bge-small-en-v1.5", token=Secret.from_token("<your-api-key>")
|
|
||||||
)
|
|
||||||
|
|
||||||
print(text_embedder.run(text_to_embed))
|
|
||||||
|
|
||||||
# {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: str = "BAAI/bge-small-en-v1.5",
|
|
||||||
url: Optional[str] = None,
|
|
||||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
|
||||||
prefix: str = "",
|
|
||||||
suffix: str = "",
|
|
||||||
truncate: bool = True,
|
|
||||||
normalize: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create an HuggingFaceTEITextEmbedder component.
|
|
||||||
|
|
||||||
:param model:
|
|
||||||
ID of the model on HuggingFace Hub.
|
|
||||||
:param url:
|
|
||||||
The URL of your self-deployed Text-Embeddings-Inference service or the URL of your paid HF Inference
|
|
||||||
Endpoint.
|
|
||||||
:param token:
|
|
||||||
The HuggingFace Hub token. This is needed if you are using a paid HF Inference Endpoint or serving
|
|
||||||
a private or gated model.
|
|
||||||
:param prefix:
|
|
||||||
A string to add at the beginning of each text.
|
|
||||||
:param suffix:
|
|
||||||
A string to add at the end of each text.
|
|
||||||
:param truncate:
|
|
||||||
Truncate input text from the end to the maximum length supported by the model. This option is only available
|
|
||||||
for self-deployed Text Embedding Inference (TEI) endpoints and paid HF Inference Endpoints deployed with
|
|
||||||
TEI. It will be ignored when used with free HF Inference endpoints or paid HF Inference endpoints deployed
|
|
||||||
without TEI.
|
|
||||||
:param normalize:
|
|
||||||
Normalize the embeddings to unit length. This option is only available for self-deployed Text Embedding
|
|
||||||
Inference (TEI) endpoints and paid HF Inference Endpoints deployed with TEI. It will be ignored when used
|
|
||||||
with free HF Inference endpoints or paid HF Inference endpoints deployed without TEI.
|
|
||||||
"""
|
|
||||||
warnings.warn(
|
|
||||||
"`HuggingFaceTEITextEmbedder` is deprecated and will be removed in Haystack 2.3.0."
|
|
||||||
"Use `HuggingFaceAPITextEmbedder` instead.",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
huggingface_hub_import.check()
|
|
||||||
|
|
||||||
if url:
|
|
||||||
r = urlparse(url)
|
|
||||||
is_valid_url = all([r.scheme in ["http", "https"], r.netloc])
|
|
||||||
if not is_valid_url:
|
|
||||||
raise ValueError(f"Invalid TEI endpoint URL provided: {url}")
|
|
||||||
|
|
||||||
check_valid_model(model, HFModelType.EMBEDDING, token)
|
|
||||||
|
|
||||||
self.model = model
|
|
||||||
self.url = url
|
|
||||||
self.token = token
|
|
||||||
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
|
|
||||||
self.prefix = prefix
|
|
||||||
self.suffix = suffix
|
|
||||||
self.truncate = truncate
|
|
||||||
self.normalize = normalize
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Serializes the component to a dictionary.
|
|
||||||
|
|
||||||
:returns:
|
|
||||||
Dictionary with serialized data.
|
|
||||||
"""
|
|
||||||
return default_to_dict(
|
|
||||||
self,
|
|
||||||
model=self.model,
|
|
||||||
url=self.url,
|
|
||||||
prefix=self.prefix,
|
|
||||||
suffix=self.suffix,
|
|
||||||
token=self.token.to_dict() if self.token else None,
|
|
||||||
truncate=self.truncate,
|
|
||||||
normalize=self.normalize,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTEITextEmbedder":
|
|
||||||
"""
|
|
||||||
Deserializes the component from a dictionary.
|
|
||||||
|
|
||||||
:param data:
|
|
||||||
Dictionary to deserialize from.
|
|
||||||
:returns:
|
|
||||||
Deserialized component.
|
|
||||||
"""
|
|
||||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
|
||||||
return default_from_dict(cls, data)
|
|
||||||
|
|
||||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Data that is sent to Posthog for usage analytics.
|
|
||||||
"""
|
|
||||||
# Don't send URL as it is sensitive information
|
|
||||||
return {"model": self.model}
|
|
||||||
|
|
||||||
@component.output_types(embedding=List[float])
|
|
||||||
def run(self, text: str):
|
|
||||||
"""
|
|
||||||
Embed a single string.
|
|
||||||
|
|
||||||
:param text:
|
|
||||||
Text to embed.
|
|
||||||
|
|
||||||
:returns:
|
|
||||||
A dictionary with the following keys:
|
|
||||||
- `embedding`: The embedding of the input text.
|
|
||||||
"""
|
|
||||||
if not isinstance(text, str):
|
|
||||||
raise TypeError(
|
|
||||||
"HuggingFaceTEITextEmbedder expects a string as an input."
|
|
||||||
"In case you want to embed a list of Documents, please use the HuggingFaceTEIDocumentEmbedder."
|
|
||||||
)
|
|
||||||
|
|
||||||
text_to_embed = self.prefix + text + self.suffix
|
|
||||||
|
|
||||||
response = self.client.post(
|
|
||||||
json={"inputs": [text_to_embed], "truncate": self.truncate, "normalize": self.normalize},
|
|
||||||
task="feature-extraction",
|
|
||||||
)
|
|
||||||
embedding = json.loads(response.decode())[0]
|
|
||||||
|
|
||||||
return {"embedding": embedding}
|
|
||||||
@ -0,0 +1,5 @@
|
|||||||
|
---
|
||||||
|
upgrade:
|
||||||
|
- |
|
||||||
|
Deprecated `HuggingFaceTEITextEmbedder` and `HuggingFaceTEIDocumentEmbedder` have been removed.
|
||||||
|
Use `HuggingFaceAPITextEmbedder` and `HuggingFaceAPIDocumentEmbedder` instead.
|
||||||
@ -1,398 +0,0 @@
|
|||||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
||||||
#
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
import os
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
from huggingface_hub.utils import RepositoryNotFoundError
|
|
||||||
|
|
||||||
from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder
|
|
||||||
from haystack.dataclasses import Document
|
|
||||||
from haystack.utils.auth import Secret
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_check_valid_model():
|
|
||||||
with patch(
|
|
||||||
"haystack.components.embedders.hugging_face_tei_document_embedder.check_valid_model",
|
|
||||||
MagicMock(return_value=None),
|
|
||||||
) as mock:
|
|
||||||
yield mock
|
|
||||||
|
|
||||||
|
|
||||||
def mock_embedding_generation(json, **kwargs):
|
|
||||||
response = str(np.array([np.random.rand(384) for i in range(len(json["inputs"]))]).tolist()).encode()
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
class TestHuggingFaceTEIDocumentEmbedder:
|
|
||||||
def test_init_default(self, monkeypatch, mock_check_valid_model):
|
|
||||||
monkeypatch.setenv("HF_API_TOKEN", "fake-api-token")
|
|
||||||
embedder = HuggingFaceTEIDocumentEmbedder()
|
|
||||||
|
|
||||||
assert embedder.model == "BAAI/bge-small-en-v1.5"
|
|
||||||
assert embedder.url is None
|
|
||||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
|
||||||
assert embedder.prefix == ""
|
|
||||||
assert embedder.suffix == ""
|
|
||||||
assert embedder.truncate is True
|
|
||||||
assert embedder.normalize is False
|
|
||||||
assert embedder.batch_size == 32
|
|
||||||
assert embedder.progress_bar is True
|
|
||||||
assert embedder.meta_fields_to_embed == []
|
|
||||||
assert embedder.embedding_separator == "\n"
|
|
||||||
|
|
||||||
def test_init_with_parameters(self, mock_check_valid_model):
|
|
||||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
|
||||||
model="sentence-transformers/all-mpnet-base-v2",
|
|
||||||
url="https://some_embedding_model.com",
|
|
||||||
token=Secret.from_token("fake-api-token"),
|
|
||||||
prefix="prefix",
|
|
||||||
suffix="suffix",
|
|
||||||
truncate=False,
|
|
||||||
normalize=True,
|
|
||||||
batch_size=64,
|
|
||||||
progress_bar=False,
|
|
||||||
meta_fields_to_embed=["test_field"],
|
|
||||||
embedding_separator=" | ",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
|
|
||||||
assert embedder.url == "https://some_embedding_model.com"
|
|
||||||
assert embedder.token == Secret.from_token("fake-api-token")
|
|
||||||
assert embedder.prefix == "prefix"
|
|
||||||
assert embedder.suffix == "suffix"
|
|
||||||
assert embedder.truncate is False
|
|
||||||
assert embedder.normalize is True
|
|
||||||
assert embedder.batch_size == 64
|
|
||||||
assert embedder.progress_bar is False
|
|
||||||
assert embedder.meta_fields_to_embed == ["test_field"]
|
|
||||||
assert embedder.embedding_separator == " | "
|
|
||||||
|
|
||||||
def test_initialize_with_invalid_url(self, mock_check_valid_model):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
HuggingFaceTEIDocumentEmbedder(model="sentence-transformers/all-mpnet-base-v2", url="invalid_url")
|
|
||||||
|
|
||||||
def test_initialize_with_url_but_invalid_model(self, mock_check_valid_model):
|
|
||||||
# When custom TEI endpoint is used via URL, model must be provided and valid HuggingFace Hub model id
|
|
||||||
mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
|
|
||||||
with pytest.raises(RepositoryNotFoundError):
|
|
||||||
HuggingFaceTEIDocumentEmbedder(model="invalid_model_id", url="https://some_embedding_model.com")
|
|
||||||
|
|
||||||
def test_to_dict(self, mock_check_valid_model):
|
|
||||||
component = HuggingFaceTEIDocumentEmbedder()
|
|
||||||
data = component.to_dict()
|
|
||||||
|
|
||||||
assert data == {
|
|
||||||
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
|
|
||||||
"init_parameters": {
|
|
||||||
"model": "BAAI/bge-small-en-v1.5",
|
|
||||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
|
||||||
"url": None,
|
|
||||||
"prefix": "",
|
|
||||||
"suffix": "",
|
|
||||||
"truncate": True,
|
|
||||||
"normalize": False,
|
|
||||||
"batch_size": 32,
|
|
||||||
"progress_bar": True,
|
|
||||||
"meta_fields_to_embed": [],
|
|
||||||
"embedding_separator": "\n",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def test_from_dict(self, mock_check_valid_model):
|
|
||||||
data = {
|
|
||||||
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
|
|
||||||
"init_parameters": {
|
|
||||||
"model": "BAAI/bge-small-en-v1.5",
|
|
||||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
|
||||||
"url": None,
|
|
||||||
"prefix": "",
|
|
||||||
"suffix": "",
|
|
||||||
"truncate": True,
|
|
||||||
"normalize": False,
|
|
||||||
"batch_size": 32,
|
|
||||||
"progress_bar": True,
|
|
||||||
"meta_fields_to_embed": [],
|
|
||||||
"embedding_separator": "\n",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
embedder = HuggingFaceTEIDocumentEmbedder.from_dict(data)
|
|
||||||
|
|
||||||
assert embedder.model == "BAAI/bge-small-en-v1.5"
|
|
||||||
assert embedder.url is None
|
|
||||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
|
||||||
assert embedder.prefix == ""
|
|
||||||
assert embedder.suffix == ""
|
|
||||||
assert embedder.truncate is True
|
|
||||||
assert embedder.normalize is False
|
|
||||||
assert embedder.batch_size == 32
|
|
||||||
assert embedder.progress_bar is True
|
|
||||||
assert embedder.meta_fields_to_embed == []
|
|
||||||
assert embedder.embedding_separator == "\n"
|
|
||||||
|
|
||||||
def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
|
||||||
component = HuggingFaceTEIDocumentEmbedder(
|
|
||||||
model="sentence-transformers/all-mpnet-base-v2",
|
|
||||||
url="https://some_embedding_model.com",
|
|
||||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
|
||||||
prefix="prefix",
|
|
||||||
suffix="suffix",
|
|
||||||
truncate=False,
|
|
||||||
normalize=True,
|
|
||||||
batch_size=64,
|
|
||||||
progress_bar=False,
|
|
||||||
meta_fields_to_embed=["test_field"],
|
|
||||||
embedding_separator=" | ",
|
|
||||||
)
|
|
||||||
|
|
||||||
data = component.to_dict()
|
|
||||||
|
|
||||||
assert data == {
|
|
||||||
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
|
|
||||||
"init_parameters": {
|
|
||||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
|
||||||
"model": "sentence-transformers/all-mpnet-base-v2",
|
|
||||||
"url": "https://some_embedding_model.com",
|
|
||||||
"prefix": "prefix",
|
|
||||||
"suffix": "suffix",
|
|
||||||
"truncate": False,
|
|
||||||
"normalize": True,
|
|
||||||
"batch_size": 64,
|
|
||||||
"progress_bar": False,
|
|
||||||
"meta_fields_to_embed": ["test_field"],
|
|
||||||
"embedding_separator": " | ",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def test_from_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
|
||||||
data = {
|
|
||||||
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
|
|
||||||
"init_parameters": {
|
|
||||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
|
||||||
"model": "sentence-transformers/all-mpnet-base-v2",
|
|
||||||
"url": "https://some_embedding_model.com",
|
|
||||||
"prefix": "prefix",
|
|
||||||
"suffix": "suffix",
|
|
||||||
"truncate": False,
|
|
||||||
"normalize": True,
|
|
||||||
"batch_size": 64,
|
|
||||||
"progress_bar": False,
|
|
||||||
"meta_fields_to_embed": ["test_field"],
|
|
||||||
"embedding_separator": " | ",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
embedder = HuggingFaceTEIDocumentEmbedder.from_dict(data)
|
|
||||||
|
|
||||||
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
|
|
||||||
assert embedder.url == "https://some_embedding_model.com"
|
|
||||||
assert embedder.token == Secret.from_env_var("ENV_VAR", strict=False)
|
|
||||||
assert embedder.prefix == "prefix"
|
|
||||||
assert embedder.suffix == "suffix"
|
|
||||||
assert embedder.truncate is False
|
|
||||||
assert embedder.normalize is True
|
|
||||||
assert embedder.batch_size == 64
|
|
||||||
assert embedder.progress_bar is False
|
|
||||||
assert embedder.meta_fields_to_embed == ["test_field"]
|
|
||||||
assert embedder.embedding_separator == " | "
|
|
||||||
|
|
||||||
def test_prepare_texts_to_embed_w_metadata(self, mock_check_valid_model):
|
|
||||||
documents = [
|
|
||||||
Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5)
|
|
||||||
]
|
|
||||||
|
|
||||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
|
||||||
model="sentence-transformers/all-mpnet-base-v2",
|
|
||||||
url="https://some_embedding_model.com",
|
|
||||||
token=Secret.from_token("fake-api-token"),
|
|
||||||
meta_fields_to_embed=["meta_field"],
|
|
||||||
embedding_separator=" | ",
|
|
||||||
)
|
|
||||||
|
|
||||||
prepared_texts = embedder._prepare_texts_to_embed(documents)
|
|
||||||
|
|
||||||
assert prepared_texts == [
|
|
||||||
"meta_value 0 | document number 0: content",
|
|
||||||
"meta_value 1 | document number 1: content",
|
|
||||||
"meta_value 2 | document number 2: content",
|
|
||||||
"meta_value 3 | document number 3: content",
|
|
||||||
"meta_value 4 | document number 4: content",
|
|
||||||
]
|
|
||||||
|
|
||||||
def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model):
|
|
||||||
documents = [Document(content=f"document number {i}") for i in range(5)]
|
|
||||||
|
|
||||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
|
||||||
model="sentence-transformers/all-mpnet-base-v2",
|
|
||||||
url="https://some_embedding_model.com",
|
|
||||||
token=Secret.from_token("fake-api-token"),
|
|
||||||
prefix="my_prefix ",
|
|
||||||
suffix=" my_suffix",
|
|
||||||
)
|
|
||||||
|
|
||||||
prepared_texts = embedder._prepare_texts_to_embed(documents)
|
|
||||||
|
|
||||||
assert prepared_texts == [
|
|
||||||
"my_prefix document number 0 my_suffix",
|
|
||||||
"my_prefix document number 1 my_suffix",
|
|
||||||
"my_prefix document number 2 my_suffix",
|
|
||||||
"my_prefix document number 3 my_suffix",
|
|
||||||
"my_prefix document number 4 my_suffix",
|
|
||||||
]
|
|
||||||
|
|
||||||
def test_embed_batch(self, mock_check_valid_model):
|
|
||||||
texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]
|
|
||||||
|
|
||||||
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
|
|
||||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
|
||||||
|
|
||||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
|
||||||
model="BAAI/bge-small-en-v1.5",
|
|
||||||
url="https://some_embedding_model.com",
|
|
||||||
token=Secret.from_token("fake-api-token"),
|
|
||||||
)
|
|
||||||
embeddings = embedder._embed_batch(texts_to_embed=texts, batch_size=2)
|
|
||||||
|
|
||||||
assert mock_embedding_patch.call_count == 3
|
|
||||||
|
|
||||||
assert isinstance(embeddings, list)
|
|
||||||
assert len(embeddings) == len(texts)
|
|
||||||
for embedding in embeddings:
|
|
||||||
assert isinstance(embedding, list)
|
|
||||||
assert len(embedding) == 384
|
|
||||||
assert all(isinstance(x, float) for x in embedding)
|
|
||||||
|
|
||||||
def test_run(self, mock_check_valid_model):
|
|
||||||
docs = [
|
|
||||||
Document(content="I love cheese", meta={"topic": "Cuisine"}),
|
|
||||||
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
|
|
||||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
|
||||||
|
|
||||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
|
||||||
model="BAAI/bge-small-en-v1.5",
|
|
||||||
token=Secret.from_token("fake-api-token"),
|
|
||||||
prefix="prefix ",
|
|
||||||
suffix=" suffix",
|
|
||||||
meta_fields_to_embed=["topic"],
|
|
||||||
embedding_separator=" | ",
|
|
||||||
)
|
|
||||||
|
|
||||||
result = embedder.run(documents=docs)
|
|
||||||
|
|
||||||
mock_embedding_patch.assert_called_once_with(
|
|
||||||
json={
|
|
||||||
"inputs": [
|
|
||||||
"prefix Cuisine | I love cheese suffix",
|
|
||||||
"prefix ML | A transformer is a deep learning architecture suffix",
|
|
||||||
],
|
|
||||||
"truncate": True,
|
|
||||||
"normalize": False,
|
|
||||||
},
|
|
||||||
task="feature-extraction",
|
|
||||||
)
|
|
||||||
documents_with_embeddings = result["documents"]
|
|
||||||
|
|
||||||
assert isinstance(documents_with_embeddings, list)
|
|
||||||
assert len(documents_with_embeddings) == len(docs)
|
|
||||||
for doc in documents_with_embeddings:
|
|
||||||
assert isinstance(doc, Document)
|
|
||||||
assert isinstance(doc.embedding, list)
|
|
||||||
assert len(doc.embedding) == 384
|
|
||||||
assert all(isinstance(x, float) for x in doc.embedding)
|
|
||||||
|
|
||||||
@pytest.mark.flaky(reruns=5, reruns_delay=5)
|
|
||||||
@pytest.mark.integration
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
not os.environ.get("HF_API_TOKEN", None),
|
|
||||||
reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
|
|
||||||
)
|
|
||||||
def test_run_inference_api_endpoint(self):
|
|
||||||
docs = [
|
|
||||||
Document(content="I love cheese", meta={"topic": "Cuisine"}),
|
|
||||||
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
|
|
||||||
]
|
|
||||||
|
|
||||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
|
||||||
model="sentence-transformers/all-MiniLM-L6-v2", meta_fields_to_embed=["topic"], embedding_separator=" | "
|
|
||||||
)
|
|
||||||
|
|
||||||
result = embedder.run(documents=docs)
|
|
||||||
documents_with_embeddings = result["documents"]
|
|
||||||
|
|
||||||
assert isinstance(documents_with_embeddings, list)
|
|
||||||
assert len(documents_with_embeddings) == len(docs)
|
|
||||||
for doc in documents_with_embeddings:
|
|
||||||
assert isinstance(doc, Document)
|
|
||||||
assert isinstance(doc.embedding, list)
|
|
||||||
assert len(doc.embedding) == 384
|
|
||||||
assert all(isinstance(x, float) for x in doc.embedding)
|
|
||||||
|
|
||||||
def test_run_custom_batch_size(self, mock_check_valid_model):
|
|
||||||
docs = [
|
|
||||||
Document(content="I love cheese", meta={"topic": "Cuisine"}),
|
|
||||||
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
|
|
||||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
|
||||||
|
|
||||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
|
||||||
model="BAAI/bge-small-en-v1.5",
|
|
||||||
token=Secret.from_token("fake-api-token"),
|
|
||||||
prefix="prefix ",
|
|
||||||
suffix=" suffix",
|
|
||||||
meta_fields_to_embed=["topic"],
|
|
||||||
embedding_separator=" | ",
|
|
||||||
batch_size=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = embedder.run(documents=docs)
|
|
||||||
|
|
||||||
assert mock_embedding_patch.call_count == 2
|
|
||||||
|
|
||||||
documents_with_embeddings = result["documents"]
|
|
||||||
|
|
||||||
assert isinstance(documents_with_embeddings, list)
|
|
||||||
assert len(documents_with_embeddings) == len(docs)
|
|
||||||
for doc in documents_with_embeddings:
|
|
||||||
assert isinstance(doc, Document)
|
|
||||||
assert isinstance(doc.embedding, list)
|
|
||||||
assert len(doc.embedding) == 384
|
|
||||||
assert all(isinstance(x, float) for x in doc.embedding)
|
|
||||||
|
|
||||||
def test_run_wrong_input_format(self, mock_check_valid_model):
|
|
||||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
|
||||||
model="BAAI/bge-small-en-v1.5",
|
|
||||||
url="https://some_embedding_model.com",
|
|
||||||
token=Secret.from_token("fake-api-token"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# wrong formats
|
|
||||||
string_input = "text"
|
|
||||||
list_integers_input = [1, 2, 3]
|
|
||||||
|
|
||||||
with pytest.raises(TypeError, match="HuggingFaceTEIDocumentEmbedder expects a list of Documents as input"):
|
|
||||||
embedder.run(documents=string_input)
|
|
||||||
|
|
||||||
with pytest.raises(TypeError, match="HuggingFaceTEIDocumentEmbedder expects a list of Documents as input"):
|
|
||||||
embedder.run(documents=list_integers_input)
|
|
||||||
|
|
||||||
def test_run_on_empty_list(self, mock_check_valid_model):
|
|
||||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
|
||||||
model="BAAI/bge-small-en-v1.5",
|
|
||||||
url="https://some_embedding_model.com",
|
|
||||||
token=Secret.from_token("fake-api-token"),
|
|
||||||
)
|
|
||||||
|
|
||||||
empty_list_input = []
|
|
||||||
result = embedder.run(documents=empty_list_input)
|
|
||||||
|
|
||||||
assert result["documents"] is not None
|
|
||||||
assert not result["documents"] # empty list
|
|
||||||
@ -1,205 +0,0 @@
|
|||||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
||||||
#
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
import os
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
from huggingface_hub.utils import RepositoryNotFoundError
|
|
||||||
|
|
||||||
from haystack.components.embedders.hugging_face_tei_text_embedder import HuggingFaceTEITextEmbedder
|
|
||||||
from haystack.utils.auth import Secret
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_check_valid_model():
|
|
||||||
with patch(
|
|
||||||
"haystack.components.embedders.hugging_face_tei_text_embedder.check_valid_model", MagicMock(return_value=None)
|
|
||||||
) as mock:
|
|
||||||
yield mock
|
|
||||||
|
|
||||||
|
|
||||||
def mock_embedding_generation(json, **kwargs):
|
|
||||||
response = str(np.array([np.random.rand(384) for i in range(len(json["inputs"]))]).tolist()).encode()
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
class TestHuggingFaceTEITextEmbedder:
|
|
||||||
def test_init_default(self, monkeypatch, mock_check_valid_model):
|
|
||||||
monkeypatch.setenv("HF_API_TOKEN", "fake-api-token")
|
|
||||||
embedder = HuggingFaceTEITextEmbedder()
|
|
||||||
|
|
||||||
assert embedder.model == "BAAI/bge-small-en-v1.5"
|
|
||||||
assert embedder.url is None
|
|
||||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
|
||||||
assert embedder.prefix == ""
|
|
||||||
assert embedder.suffix == ""
|
|
||||||
assert embedder.truncate is True
|
|
||||||
assert embedder.normalize is False
|
|
||||||
|
|
||||||
def test_init_with_parameters(self, mock_check_valid_model):
|
|
||||||
embedder = HuggingFaceTEITextEmbedder(
|
|
||||||
model="sentence-transformers/all-mpnet-base-v2",
|
|
||||||
url="https://some_embedding_model.com",
|
|
||||||
token=Secret.from_token("fake-api-token"),
|
|
||||||
prefix="prefix",
|
|
||||||
suffix="suffix",
|
|
||||||
truncate=False,
|
|
||||||
normalize=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
|
|
||||||
assert embedder.url == "https://some_embedding_model.com"
|
|
||||||
assert embedder.token == Secret.from_token("fake-api-token")
|
|
||||||
assert embedder.prefix == "prefix"
|
|
||||||
assert embedder.suffix == "suffix"
|
|
||||||
assert embedder.truncate is False
|
|
||||||
assert embedder.normalize is True
|
|
||||||
|
|
||||||
def test_initialize_with_invalid_url(self, mock_check_valid_model):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
HuggingFaceTEITextEmbedder(model="sentence-transformers/all-mpnet-base-v2", url="invalid_url")
|
|
||||||
|
|
||||||
def test_initialize_with_url_but_invalid_model(self, mock_check_valid_model):
|
|
||||||
# When custom TEI endpoint is used via URL, model must be provided and valid HuggingFace Hub model id
|
|
||||||
mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
|
|
||||||
with pytest.raises(RepositoryNotFoundError):
|
|
||||||
HuggingFaceTEITextEmbedder(model="invalid_model_id", url="https://some_embedding_model.com")
|
|
||||||
|
|
||||||
def test_to_dict(self, mock_check_valid_model):
|
|
||||||
component = HuggingFaceTEITextEmbedder()
|
|
||||||
data = component.to_dict()
|
|
||||||
|
|
||||||
assert data == {
|
|
||||||
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
|
|
||||||
"init_parameters": {
|
|
||||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
|
||||||
"model": "BAAI/bge-small-en-v1.5",
|
|
||||||
"url": None,
|
|
||||||
"prefix": "",
|
|
||||||
"suffix": "",
|
|
||||||
"truncate": True,
|
|
||||||
"normalize": False,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def test_from_dict(self, mock_check_valid_model):
|
|
||||||
data = {
|
|
||||||
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
|
|
||||||
"init_parameters": {
|
|
||||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
|
||||||
"model": "BAAI/bge-small-en-v1.5",
|
|
||||||
"url": None,
|
|
||||||
"prefix": "",
|
|
||||||
"suffix": "",
|
|
||||||
"truncate": True,
|
|
||||||
"normalize": False,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
embedder = HuggingFaceTEITextEmbedder.from_dict(data)
|
|
||||||
|
|
||||||
assert embedder.model == "BAAI/bge-small-en-v1.5"
|
|
||||||
assert embedder.url is None
|
|
||||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
|
||||||
assert embedder.prefix == ""
|
|
||||||
assert embedder.suffix == ""
|
|
||||||
assert embedder.truncate is True
|
|
||||||
assert embedder.normalize is False
|
|
||||||
|
|
||||||
def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
|
||||||
component = HuggingFaceTEITextEmbedder(
|
|
||||||
model="sentence-transformers/all-mpnet-base-v2",
|
|
||||||
url="https://some_embedding_model.com",
|
|
||||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
|
||||||
prefix="prefix",
|
|
||||||
suffix="suffix",
|
|
||||||
truncate=False,
|
|
||||||
normalize=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
data = component.to_dict()
|
|
||||||
|
|
||||||
assert data == {
|
|
||||||
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
|
|
||||||
"init_parameters": {
|
|
||||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
|
||||||
"model": "sentence-transformers/all-mpnet-base-v2",
|
|
||||||
"url": "https://some_embedding_model.com",
|
|
||||||
"prefix": "prefix",
|
|
||||||
"suffix": "suffix",
|
|
||||||
"truncate": False,
|
|
||||||
"normalize": True,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def test_from_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
|
||||||
data = {
|
|
||||||
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
|
|
||||||
"init_parameters": {
|
|
||||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
|
||||||
"model": "sentence-transformers/all-mpnet-base-v2",
|
|
||||||
"url": "https://some_embedding_model.com",
|
|
||||||
"prefix": "prefix",
|
|
||||||
"suffix": "suffix",
|
|
||||||
"truncate": False,
|
|
||||||
"normalize": True,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
embedder = HuggingFaceTEITextEmbedder.from_dict(data)
|
|
||||||
|
|
||||||
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
|
|
||||||
assert embedder.url == "https://some_embedding_model.com"
|
|
||||||
assert embedder.token == Secret.from_env_var("ENV_VAR", strict=False)
|
|
||||||
assert embedder.prefix == "prefix"
|
|
||||||
assert embedder.suffix == "suffix"
|
|
||||||
assert embedder.truncate is False
|
|
||||||
assert embedder.normalize is True
|
|
||||||
|
|
||||||
def test_run(self, mock_check_valid_model):
|
|
||||||
with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
|
|
||||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
|
||||||
|
|
||||||
embedder = HuggingFaceTEITextEmbedder(
|
|
||||||
model="BAAI/bge-small-en-v1.5",
|
|
||||||
token=Secret.from_token("fake-api-token"),
|
|
||||||
prefix="prefix ",
|
|
||||||
suffix=" suffix",
|
|
||||||
)
|
|
||||||
|
|
||||||
result = embedder.run(text="The food was delicious")
|
|
||||||
|
|
||||||
mock_embedding_patch.assert_called_once_with(
|
|
||||||
json={"inputs": ["prefix The food was delicious suffix"], "truncate": True, "normalize": False},
|
|
||||||
task="feature-extraction",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(result["embedding"]) == 384
|
|
||||||
assert all(isinstance(x, float) for x in result["embedding"])
|
|
||||||
|
|
||||||
@pytest.mark.flaky(reruns=5, reruns_delay=5)
|
|
||||||
@pytest.mark.integration
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
not os.environ.get("HF_API_TOKEN", None),
|
|
||||||
reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
|
|
||||||
)
|
|
||||||
def test_run_inference_api_endpoint(self):
|
|
||||||
embedder = HuggingFaceTEITextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
|
|
||||||
result = embedder.run(text="The food was delicious")
|
|
||||||
|
|
||||||
assert len(result["embedding"]) == 384
|
|
||||||
assert all(isinstance(x, float) for x in result["embedding"])
|
|
||||||
|
|
||||||
def test_run_wrong_input_format(self, mock_check_valid_model):
|
|
||||||
embedder = HuggingFaceTEITextEmbedder(
|
|
||||||
model="BAAI/bge-small-en-v1.5",
|
|
||||||
url="https://some_embedding_model.com",
|
|
||||||
token=Secret.from_token("fake-api-token"),
|
|
||||||
)
|
|
||||||
|
|
||||||
list_integers_input = [1, 2, 3]
|
|
||||||
|
|
||||||
with pytest.raises(TypeError, match="HuggingFaceTEITextEmbedder expects a string as an input"):
|
|
||||||
embedder.run(text=list_integers_input)
|
|
||||||
Loading…
x
Reference in New Issue
Block a user