mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-08 04:56:45 +00:00
feat: HuggingFaceAPIDocumentEmbedder (#7485)
* add HuggingFaceAPITextEmbedder * add HuggingFaceAPITextEmbedder * rm unneeded else * wip * small fixes * deprecation; reno * Apply suggestions from code review Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> * make params mandatory * changes requested * fix test * fix test --------- Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
This commit is contained in:
parent
c91bd49cae
commit
eff53a9131
@ -7,6 +7,7 @@ loaders:
|
||||
"azure_text_embedder",
|
||||
"hugging_face_tei_document_embedder",
|
||||
"hugging_face_tei_text_embedder",
|
||||
"hugging_face_api_document_embedder",
|
||||
"hugging_face_api_text_embedder",
|
||||
"openai_document_embedder",
|
||||
"openai_text_embedder",
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from haystack.components.embedders.azure_document_embedder import AzureOpenAIDocumentEmbedder
|
||||
from haystack.components.embedders.azure_text_embedder import AzureOpenAITextEmbedder
|
||||
from haystack.components.embedders.hugging_face_api_document_embedder import HuggingFaceAPIDocumentEmbedder
|
||||
from haystack.components.embedders.hugging_face_api_text_embedder import HuggingFaceAPITextEmbedder
|
||||
from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder
|
||||
from haystack.components.embedders.hugging_face_tei_text_embedder import HuggingFaceTEITextEmbedder
|
||||
@ -12,6 +13,7 @@ __all__ = [
|
||||
"HuggingFaceTEITextEmbedder",
|
||||
"HuggingFaceTEIDocumentEmbedder",
|
||||
"HuggingFaceAPITextEmbedder",
|
||||
"HuggingFaceAPIDocumentEmbedder",
|
||||
"SentenceTransformersTextEmbedder",
|
||||
"SentenceTransformersDocumentEmbedder",
|
||||
"OpenAITextEmbedder",
|
||||
|
||||
@ -0,0 +1,263 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.dataclasses import Document
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
|
||||
from haystack.utils.url_validation import is_valid_http_url
|
||||
|
||||
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@component
class HuggingFaceAPIDocumentEmbedder:
    """
    This component can be used to compute Document embeddings using different Hugging Face APIs:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)


    Example usage with the free Serverless Inference API:
    ```python
    from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
    from haystack.utils import Secret
    from haystack.dataclasses import Document

    doc = Document(content="I love pizza!")

    doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="serverless_inference_api",
                                                  api_params={"model": "BAAI/bge-small-en-v1.5"},
                                                  token=Secret.from_token("<your-api-key>"))

    result = doc_embedder.run([doc])
    print(result["documents"][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```

    Example usage with paid Inference Endpoints:
    ```python
    from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
    from haystack.utils import Secret
    from haystack.dataclasses import Document

    doc = Document(content="I love pizza!")

    doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="inference_endpoints",
                                                  api_params={"url": "<your-inference-endpoint-url>"},
                                                  token=Secret.from_token("<your-api-key>"))

    result = doc_embedder.run([doc])
    print(result["documents"][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```

    Example usage with self-hosted Text Embeddings Inference:
    ```python
    from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
    from haystack.dataclasses import Document

    doc = Document(content="I love pizza!")

    doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="text_embeddings_inference",
                                                  api_params={"url": "http://localhost:8080"})

    result = doc_embedder.run([doc])
    print(result["documents"][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """

    def __init__(
        self,
        api_type: Union[HFEmbeddingAPIType, str],
        api_params: Dict[str, str],
        token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
        prefix: str = "",
        suffix: str = "",
        truncate: bool = True,
        normalize: bool = False,
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
    ):
        """
        Create an HuggingFaceAPIDocumentEmbedder component.

        :param api_type:
            The type of Hugging Face API to use.
        :param api_params:
            A dictionary containing the following keys:
            - `model`: model ID on the Hugging Face Hub. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or `TEXT_EMBEDDINGS_INFERENCE`.
        :param token: The HuggingFace token to use as HTTP bearer authorization.
            You can find your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param truncate:
            Truncate input text from the end to the maximum length supported by the model.
            This parameter takes effect when the `api_type` is `TEXT_EMBEDDINGS_INFERENCE`.
            It also takes effect when the `api_type` is `INFERENCE_ENDPOINTS` and the backend is based on Text Embeddings Inference.
            This parameter is ignored when the `api_type` is `SERVERLESS_INFERENCE_API` (it is always set to `True` and cannot be changed).
        :param normalize:
            Normalize the embeddings to unit length.
            This parameter takes effect when the `api_type` is `TEXT_EMBEDDINGS_INFERENCE`.
            It also takes effect when the `api_type` is `INFERENCE_ENDPOINTS` and the backend is based on Text Embeddings Inference.
            This parameter is ignored when the `api_type` is `SERVERLESS_INFERENCE_API` (it is always set to `False` and cannot be changed).
        :param batch_size:
            Number of Documents to process at once.
        :param progress_bar:
            If `True` shows a progress bar when running.
        :param meta_fields_to_embed:
            List of meta fields that will be embedded along with the Document text.
        :param embedding_separator:
            Separator used to concatenate the meta fields to the Document text.
        """
        huggingface_hub_import.check()

        if isinstance(api_type, str):
            # from_str raises ValueError for unknown API type strings
            api_type = HFEmbeddingAPIType.from_str(api_type)

        api_params = api_params or {}

        if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.EMBEDDING, token)
            model_or_url = model
        elif api_type in [HFEmbeddingAPIType.INFERENCE_ENDPOINTS, HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                raise ValueError(
                    "To use Text Embeddings Inference or Inference Endpoints, you need to specify the `url` parameter in `api_params`."
                )
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url
        else:
            # Guard against future enum members: without this branch, `model_or_url`
            # would be unbound below and raise a confusing UnboundLocalError.
            raise ValueError(f"Unknown api_type: {api_type}")

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.truncate = truncate
        self.normalize = normalize
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator
        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            api_type=self.api_type,
            api_params=self.api_params,
            prefix=self.prefix,
            suffix=self.suffix,
            token=self.token.to_dict() if self.token else None,
            truncate=self.truncate,
            normalize=self.normalize,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            meta_fields_to_embed=self.meta_fields_to_embed,
            embedding_separator=self.embedding_separator,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIDocumentEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        return default_from_dict(cls, data)

    def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
        """
        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
        """
        texts_to_embed = []
        for doc in documents:
            # Only embed meta fields that are present and not None
            meta_values_to_embed = [
                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
            ]

            text_to_embed = (
                self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
            )

            texts_to_embed.append(text_to_embed)
        return texts_to_embed

    def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[float]]:
        """
        Embed a list of texts in batches.
        """

        all_embeddings = []
        for i in tqdm(
            range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        ):
            batch = texts_to_embed[i : i + batch_size]
            response = self._client.post(
                json={"inputs": batch, "truncate": self.truncate, "normalize": self.normalize},
                task="feature-extraction",
            )
            # The raw API response is a JSON-encoded list of embeddings (one per input)
            embeddings = json.loads(response.decode())
            all_embeddings.extend(embeddings)

        return all_embeddings

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        """
        Embed a list of Documents.

        :param documents:
            Documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: Documents with embeddings
        """
        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            raise TypeError(
                "HuggingFaceAPIDocumentEmbedder expects a list of Documents as input."
                " In case you want to embed a string, please use the HuggingFaceAPITextEmbedder."
            )

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        embeddings = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb

        return {"documents": documents}
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@ -91,6 +92,12 @@ class HuggingFaceTEIDocumentEmbedder:
|
||||
:param embedding_separator:
|
||||
Separator used to concatenate the meta fields to the Document text.
|
||||
"""
|
||||
warnings.warn(
|
||||
"`HuggingFaceTEIDocumentEmbedder` is deprecated and will be removed in Haystack 2.3.0."
|
||||
"Use `HuggingFaceAPIDocumentEmbedder` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
huggingface_hub_import.check()
|
||||
|
||||
if url:
|
||||
|
||||
13
releasenotes/notes/hfapidocembedder-4c3970d002275edb.yaml
Normal file
13
releasenotes/notes/hfapidocembedder-4c3970d002275edb.yaml
Normal file
@ -0,0 +1,13 @@
|
||||
---
features:
  - |
    Introduce `HuggingFaceAPIDocumentEmbedder`.
    This component can be used to compute Document embeddings using different Hugging Face APIs:
    - free Serverless Inference API
    - paid Inference Endpoints
    - self-hosted Text Embeddings Inference.
    This embedder will replace the `HuggingFaceTEIDocumentEmbedder` in the future.
deprecations:
  - |
    Deprecate `HuggingFaceTEIDocumentEmbedder`. This component will be removed in Haystack 2.3.0.
    Use `HuggingFaceAPIDocumentEmbedder` instead.
||||
@ -0,0 +1,344 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
from numpy import array, random
|
||||
|
||||
from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
|
||||
from haystack.dataclasses import Document
|
||||
from haystack.utils.auth import Secret
|
||||
from haystack.utils.hf import HFEmbeddingAPIType
|
||||
|
||||
|
||||
@pytest.fixture
def mock_check_valid_model():
    """Patch check_valid_model inside the embedder module so tests never contact the HF Hub."""
    patch_target = "haystack.components.embedders.hugging_face_api_document_embedder.check_valid_model"
    with patch(patch_target, MagicMock(return_value=None)) as mocked:
        yield mocked
|
||||
|
||||
|
||||
def mock_embedding_generation(json, **kwargs):
    """Fake InferenceClient.post side effect: one random 384-dim embedding per input, JSON-encoded as bytes."""
    n_inputs = len(json["inputs"])
    fake_embeddings = array([random.rand(384) for _ in range(n_inputs)])
    return str(fake_embeddings.tolist()).encode()
|
||||
|
||||
|
||||
class TestHuggingFaceAPIDocumentEmbedder:
    """Unit tests for HuggingFaceAPIDocumentEmbedder (init validation, (de)serialization, embedding)."""

    def test_init_invalid_api_type(self):
        with pytest.raises(ValueError):
            HuggingFaceAPIDocumentEmbedder(api_type="invalid_api_type", api_params={})

    def test_init_serverless(self, mock_check_valid_model):
        model = "BAAI/bge-small-en-v1.5"
        embedder = HuggingFaceAPIDocumentEmbedder(
            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model}
        )

        assert embedder.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API
        assert embedder.api_params == {"model": model}
        assert embedder.prefix == ""
        assert embedder.suffix == ""
        assert embedder.truncate
        assert not embedder.normalize
        assert embedder.batch_size == 32
        assert embedder.progress_bar
        assert embedder.meta_fields_to_embed == []
        assert embedder.embedding_separator == "\n"

    def test_init_serverless_invalid_model(self, mock_check_valid_model):
        mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
        with pytest.raises(RepositoryNotFoundError):
            HuggingFaceAPIDocumentEmbedder(
                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "invalid_model_id"}
            )

    def test_init_serverless_no_model(self):
        with pytest.raises(ValueError):
            HuggingFaceAPIDocumentEmbedder(
                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"param": "irrelevant"}
            )

    def test_init_tei(self):
        url = "https://some_model.com"

        embedder = HuggingFaceAPIDocumentEmbedder(
            api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"url": url}
        )

        assert embedder.api_type == HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE
        assert embedder.api_params == {"url": url}
        assert embedder.prefix == ""
        assert embedder.suffix == ""
        assert embedder.truncate
        assert not embedder.normalize
        assert embedder.batch_size == 32
        assert embedder.progress_bar
        assert embedder.meta_fields_to_embed == []
        assert embedder.embedding_separator == "\n"

    def test_init_tei_invalid_url(self):
        with pytest.raises(ValueError):
            HuggingFaceAPIDocumentEmbedder(
                api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"url": "invalid_url"}
            )

    def test_init_tei_no_url(self):
        with pytest.raises(ValueError):
            HuggingFaceAPIDocumentEmbedder(
                api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"param": "irrelevant"}
            )

    def test_to_dict(self, mock_check_valid_model):
        embedder = HuggingFaceAPIDocumentEmbedder(
            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "BAAI/bge-small-en-v1.5"},
            prefix="prefix",
            suffix="suffix",
            truncate=False,
            normalize=True,
            batch_size=128,
            progress_bar=False,
            meta_fields_to_embed=["meta_field"],
            embedding_separator=" ",
        )

        data = embedder.to_dict()

        assert data == {
            "type": "haystack.components.embedders.hugging_face_api_document_embedder.HuggingFaceAPIDocumentEmbedder",
            "init_parameters": {
                "api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
                "api_params": {"model": "BAAI/bge-small-en-v1.5"},
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "prefix": "prefix",
                "suffix": "suffix",
                "truncate": False,
                "normalize": True,
                "batch_size": 128,
                "progress_bar": False,
                "meta_fields_to_embed": ["meta_field"],
                "embedding_separator": " ",
            },
        }

    def test_from_dict(self, mock_check_valid_model):
        data = {
            "type": "haystack.components.embedders.hugging_face_api_document_embedder.HuggingFaceAPIDocumentEmbedder",
            "init_parameters": {
                "api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
                "api_params": {"model": "BAAI/bge-small-en-v1.5"},
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "prefix": "prefix",
                "suffix": "suffix",
                "truncate": False,
                "normalize": True,
                "batch_size": 128,
                "progress_bar": False,
                "meta_fields_to_embed": ["meta_field"],
                "embedding_separator": " ",
            },
        }

        embedder = HuggingFaceAPIDocumentEmbedder.from_dict(data)

        assert embedder.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API
        assert embedder.api_params == {"model": "BAAI/bge-small-en-v1.5"}
        assert embedder.prefix == "prefix"
        assert embedder.suffix == "suffix"
        assert not embedder.truncate
        assert embedder.normalize
        assert embedder.batch_size == 128
        assert not embedder.progress_bar
        assert embedder.meta_fields_to_embed == ["meta_field"]
        assert embedder.embedding_separator == " "

    def test_prepare_texts_to_embed_w_metadata(self):
        documents = [
            Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5)
        ]

        embedder = HuggingFaceAPIDocumentEmbedder(
            api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE,
            api_params={"url": "https://some_model.com"},
            token=Secret.from_token("fake-api-token"),
            meta_fields_to_embed=["meta_field"],
            embedding_separator=" | ",
        )

        prepared_texts = embedder._prepare_texts_to_embed(documents)

        assert prepared_texts == [
            "meta_value 0 | document number 0: content",
            "meta_value 1 | document number 1: content",
            "meta_value 2 | document number 2: content",
            "meta_value 3 | document number 3: content",
            "meta_value 4 | document number 4: content",
        ]

    def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model):
        documents = [Document(content=f"document number {i}") for i in range(5)]

        embedder = HuggingFaceAPIDocumentEmbedder(
            api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE,
            api_params={"url": "https://some_model.com"},
            token=Secret.from_token("fake-api-token"),
            prefix="my_prefix ",
            suffix=" my_suffix",
        )

        prepared_texts = embedder._prepare_texts_to_embed(documents)

        assert prepared_texts == [
            "my_prefix document number 0 my_suffix",
            "my_prefix document number 1 my_suffix",
            "my_prefix document number 2 my_suffix",
            "my_prefix document number 3 my_suffix",
            "my_prefix document number 4 my_suffix",
        ]

    def test_embed_batch(self, mock_check_valid_model):
        texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]

        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
            mock_embedding_patch.side_effect = mock_embedding_generation

            embedder = HuggingFaceAPIDocumentEmbedder(
                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
                api_params={"model": "BAAI/bge-small-en-v1.5"},
                token=Secret.from_token("fake-api-token"),
            )
            embeddings = embedder._embed_batch(texts_to_embed=texts, batch_size=2)

            # 5 texts with batch_size=2 -> 3 API calls
            assert mock_embedding_patch.call_count == 3

        assert isinstance(embeddings, list)
        assert len(embeddings) == len(texts)
        for embedding in embeddings:
            assert isinstance(embedding, list)
            assert len(embedding) == 384
            assert all(isinstance(x, float) for x in embedding)

    def test_run_wrong_input_format(self, mock_check_valid_model):
        embedder = HuggingFaceAPIDocumentEmbedder(
            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
        )

        list_integers_input = [1, 2, 3]

        # Pass the bad input through the `documents` parameter so the TypeError comes
        # from the embedder's own validation, not from an unexpected keyword argument.
        with pytest.raises(TypeError):
            embedder.run(documents=list_integers_input)

    def test_run_on_empty_list(self, mock_check_valid_model):
        embedder = HuggingFaceAPIDocumentEmbedder(
            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "BAAI/bge-small-en-v1.5"},
            token=Secret.from_token("fake-api-token"),
        )

        empty_list_input = []
        result = embedder.run(documents=empty_list_input)

        assert result["documents"] is not None
        assert not result["documents"]  # empty list

    def test_run(self, mock_check_valid_model):
        docs = [
            Document(content="I love cheese", meta={"topic": "Cuisine"}),
            Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
        ]

        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
            mock_embedding_patch.side_effect = mock_embedding_generation

            embedder = HuggingFaceAPIDocumentEmbedder(
                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
                api_params={"model": "BAAI/bge-small-en-v1.5"},
                token=Secret.from_token("fake-api-token"),
                prefix="prefix ",
                suffix=" suffix",
                meta_fields_to_embed=["topic"],
                embedding_separator=" | ",
            )

            result = embedder.run(documents=docs)

            mock_embedding_patch.assert_called_once_with(
                json={
                    "inputs": [
                        "prefix Cuisine | I love cheese suffix",
                        "prefix ML | A transformer is a deep learning architecture suffix",
                    ],
                    "truncate": True,
                    "normalize": False,
                },
                task="feature-extraction",
            )
        documents_with_embeddings = result["documents"]

        assert isinstance(documents_with_embeddings, list)
        assert len(documents_with_embeddings) == len(docs)
        for doc in documents_with_embeddings:
            assert isinstance(doc, Document)
            assert isinstance(doc.embedding, list)
            assert len(doc.embedding) == 384
            assert all(isinstance(x, float) for x in doc.embedding)

    def test_run_custom_batch_size(self, mock_check_valid_model):
        docs = [
            Document(content="I love cheese", meta={"topic": "Cuisine"}),
            Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
        ]

        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
            mock_embedding_patch.side_effect = mock_embedding_generation

            embedder = HuggingFaceAPIDocumentEmbedder(
                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
                api_params={"model": "BAAI/bge-small-en-v1.5"},
                token=Secret.from_token("fake-api-token"),
                prefix="prefix ",
                suffix=" suffix",
                meta_fields_to_embed=["topic"],
                embedding_separator=" | ",
                batch_size=1,
            )

            result = embedder.run(documents=docs)

            # batch_size=1 with 2 docs -> 2 API calls
            assert mock_embedding_patch.call_count == 2

        documents_with_embeddings = result["documents"]

        assert isinstance(documents_with_embeddings, list)
        assert len(documents_with_embeddings) == len(docs)
        for doc in documents_with_embeddings:
            assert isinstance(doc, Document)
            assert isinstance(doc.embedding, list)
            assert len(doc.embedding) == 384
            assert all(isinstance(x, float) for x in doc.embedding)

    @pytest.mark.flaky(reruns=5, reruns_delay=5)
    @pytest.mark.integration
    def test_live_run_serverless(self):
        docs = [
            Document(content="I love cheese", meta={"topic": "Cuisine"}),
            Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
        ]

        embedder = HuggingFaceAPIDocumentEmbedder(
            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "sentence-transformers/all-MiniLM-L6-v2"},
            meta_fields_to_embed=["topic"],
            embedding_separator=" | ",
        )
        result = embedder.run(documents=docs)
        documents_with_embeddings = result["documents"]

        assert isinstance(documents_with_embeddings, list)
        assert len(documents_with_embeddings) == len(docs)
        for doc in documents_with_embeddings:
            assert isinstance(doc, Document)
            assert isinstance(doc.embedding, list)
            assert len(doc.embedding) == 384
            assert all(isinstance(x, float) for x in doc.embedding)
|
||||
Loading…
x
Reference in New Issue
Block a user