mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-04 02:57:34 +00:00
feat: HuggingFaceAPITextEmbedder (#7484)
* add HuggingFaceAPITextEmbedder * add HuggingFaceAPITextEmbedder * rm unneeded else * small fixes * changes requested * fix test
This commit is contained in:
parent
3777f4342f
commit
c91bd49cae
@ -7,6 +7,7 @@ loaders:
|
||||
"azure_text_embedder",
|
||||
"hugging_face_tei_document_embedder",
|
||||
"hugging_face_tei_text_embedder",
|
||||
"hugging_face_api_text_embedder",
|
||||
"openai_document_embedder",
|
||||
"openai_text_embedder",
|
||||
"sentence_transformers_document_embedder",
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from haystack.components.embedders.azure_document_embedder import AzureOpenAIDocumentEmbedder
|
||||
from haystack.components.embedders.azure_text_embedder import AzureOpenAITextEmbedder
|
||||
from haystack.components.embedders.hugging_face_api_text_embedder import HuggingFaceAPITextEmbedder
|
||||
from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder
|
||||
from haystack.components.embedders.hugging_face_tei_text_embedder import HuggingFaceTEITextEmbedder
|
||||
from haystack.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder
|
||||
@ -10,6 +11,7 @@ from haystack.components.embedders.sentence_transformers_text_embedder import Se
|
||||
__all__ = [
|
||||
"HuggingFaceTEITextEmbedder",
|
||||
"HuggingFaceTEIDocumentEmbedder",
|
||||
"HuggingFaceAPITextEmbedder",
|
||||
"SentenceTransformersTextEmbedder",
|
||||
"SentenceTransformersDocumentEmbedder",
|
||||
"OpenAITextEmbedder",
|
||||
|
||||
191
haystack/components/embedders/hugging_face_api_text_embedder.py
Normal file
191
haystack/components/embedders/hugging_face_api_text_embedder.py
Normal file
@ -0,0 +1,191 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
|
||||
from haystack.utils.url_validation import is_valid_http_url
|
||||
|
||||
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@component
class HuggingFaceAPITextEmbedder:
    """
    This component can be used to embed strings using different Hugging Face APIs:

    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)

    Example usage with the free Serverless Inference API:
    ```python
    from haystack.components.embedders import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret

    text_embedder = HuggingFaceAPITextEmbedder(api_type="serverless_inference_api",
                                               api_params={"model": "BAAI/bge-small-en-v1.5"},
                                               token=Secret.from_token("<your-api-key>"))

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...]}
    ```

    Example usage with paid Inference Endpoints:
    ```python
    from haystack.components.embedders import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret

    text_embedder = HuggingFaceAPITextEmbedder(api_type="inference_endpoints",
                                               api_params={"url": "<your-inference-endpoint-url>"},
                                               token=Secret.from_token("<your-api-key>"))

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...]}
    ```

    Example usage with self-hosted Text Embeddings Inference:
    ```python
    from haystack.components.embedders import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret

    text_embedder = HuggingFaceAPITextEmbedder(api_type="text_embeddings_inference",
                                               api_params={"url": "http://localhost:8080"})

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...]}
    ```
    """

    def __init__(
        self,
        api_type: Union[HFEmbeddingAPIType, str],
        api_params: Dict[str, str],
        token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
        prefix: str = "",
        suffix: str = "",
        truncate: bool = True,
        normalize: bool = False,
    ):
        """
        Create an HuggingFaceAPITextEmbedder component.

        :param api_type:
            The type of Hugging Face API to use. A string is converted with `HFEmbeddingAPIType.from_str`.
        :param api_params:
            A dictionary containing the following keys:
            - `model`: model ID on the Hugging Face Hub. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or `TEXT_EMBEDDINGS_INFERENCE`.
        :param token: The HuggingFace token to use as HTTP bearer authorization
            You can find your HF token in your [account settings](https://huggingface.co/settings/tokens)
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param truncate:
            Truncate input text from the end to the maximum length supported by the model.
            This parameter takes effect when the `api_type` is `TEXT_EMBEDDINGS_INFERENCE`.
            It also takes effect when the `api_type` is `INFERENCE_ENDPOINTS` and the backend is based on Text Embeddings Inference.
            This parameter is ignored when the `api_type` is `SERVERLESS_INFERENCE_API` (it is always set to `True` and cannot be changed).
        :param normalize:
            Normalize the embeddings to unit length.
            This parameter takes effect when the `api_type` is `TEXT_EMBEDDINGS_INFERENCE`.
            It also takes effect when the `api_type` is `INFERENCE_ENDPOINTS` and the backend is based on Text Embeddings Inference.
            This parameter is ignored when the `api_type` is `SERVERLESS_INFERENCE_API` (it is always set to `False` and cannot be changed).
        :raises ValueError:
            If `api_type` is an unknown string, if the key required by the chosen `api_type` is missing
            from `api_params`, or if the provided `url` is not a valid HTTP URL.
        """
        huggingface_hub_import.check()

        # Accept the plain string form as well as the enum member.
        if isinstance(api_type, str):
            api_type = HFEmbeddingAPIType.from_str(api_type)

        if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.EMBEDDING, token)
            model_or_url = model
        elif api_type in [HFEmbeddingAPIType.INFERENCE_ENDPOINTS, HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                raise ValueError(
                    "To use Text Embeddings Inference or Inference Endpoints, you need to specify the `url` parameter in `api_params`."
                )
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.truncate = truncate
        self.normalize = normalize
        # The client is addressed either by model ID or by endpoint URL, depending on the API type.
        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        # NOTE(review): `api_type` is passed through as the enum member, not its string value;
        # confirm the downstream serialization (e.g. YAML/JSON dumping) handles enum members.
        return default_to_dict(
            self,
            api_type=self.api_type,
            api_params=self.api_params,
            prefix=self.prefix,
            suffix=self.suffix,
            token=self.token.to_dict() if self.token else None,
            truncate=self.truncate,
            normalize=self.normalize,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPITextEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        # The serialized token must be turned back into a `Secret` before `__init__` receives it.
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        return default_from_dict(cls, data)

    @component.output_types(embedding=List[float])
    def run(self, text: str):
        """
        Embed a single string.

        :param text:
            Text to embed.

        :returns:
            A dictionary with the following keys:
            - `embedding`: The embedding of the input text.
        :raises TypeError:
            If `text` is not a string.
        """
        if not isinstance(text, str):
            raise TypeError(
                "HuggingFaceAPITextEmbedder expects a string as an input. "
                "In case you want to embed a list of Documents, please use the HuggingFaceAPIDocumentEmbedder."
            )

        text_to_embed = self.prefix + text + self.suffix

        response = self._client.post(
            json={"inputs": [text_to_embed], "truncate": self.truncate, "normalize": self.normalize},
            task="feature-extraction",
        )
        # The text is sent as a one-element list, so the first item of the decoded response is its embedding.
        embedding = json.loads(response.decode())[0]

        return {"embedding": embedding}
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@ -74,6 +75,12 @@ class HuggingFaceTEITextEmbedder:
|
||||
Inference (TEI) endpoints and paid HF Inference Endpoints deployed with TEI. It will be ignored when used
|
||||
with free HF Inference endpoints or paid HF Inference endpoints deployed without TEI.
|
||||
"""
|
||||
warnings.warn(
|
||||
"`HuggingFaceTEITextEmbedder` is deprecated and will be removed in Haystack 2.3.0."
|
||||
"Use `HuggingFaceAPITextEmbedder` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
huggingface_hub_import.check()
|
||||
|
||||
if url:
|
||||
|
||||
@ -55,6 +55,33 @@ class HFGenerationAPIType(Enum):
|
||||
return mode
|
||||
|
||||
|
||||
class HFEmbeddingAPIType(Enum):
    """
    API type to use for Hugging Face API Embedders.
    """

    # HF [Text Embeddings Inference (TEI)](https://github.com/huggingface/text-embeddings-inference).
    TEXT_EMBEDDINGS_INFERENCE = "text_embeddings_inference"

    # HF [Inference Endpoints](https://huggingface.co/inference-endpoints).
    INFERENCE_ENDPOINTS = "inference_endpoints"

    # HF [Serverless Inference API](https://huggingface.co/inference-api).
    SERVERLESS_INFERENCE_API = "serverless_inference_api"

    def __str__(self):
        # Render a member as its plain string value.
        return self.value

    @staticmethod
    def from_str(string: str) -> "HFEmbeddingAPIType":
        """
        Convert a string to the matching enum member.

        :param string:
            The string value of one of the members (exact match).
        :returns:
            The enum member whose value equals `string`.
        :raises ValueError:
            If `string` does not match the value of any member.
        """
        enum_map = {e.value: e for e in HFEmbeddingAPIType}
        mode = enum_map.get(string)
        if mode is None:
            msg = f"Unknown Hugging Face API type '{string}'. Supported types are: {list(enum_map.keys())}"
            raise ValueError(msg)
        return mode
|
||||
|
||||
|
||||
class HFModelType(Enum):
|
||||
EMBEDDING = 1
|
||||
GENERATION = 2
|
||||
|
||||
13
releasenotes/notes/hfapitextembedder-97bf5f739f413f3e.yaml
Normal file
13
releasenotes/notes/hfapitextembedder-97bf5f739f413f3e.yaml
Normal file
@ -0,0 +1,13 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Introduce `HuggingFaceAPITextEmbedder`.
|
||||
This component can be used to embed strings using different Hugging Face APIs:
|
||||
- free Serverless Inference API
|
||||
- paid Inference Endpoints
|
||||
- self-hosted Text Embeddings Inference.
|
||||
This embedder will replace the `HuggingFaceTEITextEmbedder` in the future.
|
||||
deprecations:
|
||||
- |
|
||||
Deprecate `HuggingFaceTEITextEmbedder`. This component will be removed in Haystack 2.3.0.
|
||||
Use `HuggingFaceAPITextEmbedder` instead.
|
||||
172
test/components/embedders/test_hugging_face_api_text_embedder.py
Normal file
172
test/components/embedders/test_hugging_face_api_text_embedder.py
Normal file
@ -0,0 +1,172 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
from numpy import array, random
|
||||
|
||||
from haystack.components.embedders import HuggingFaceAPITextEmbedder
|
||||
from haystack.utils.auth import Secret
|
||||
from haystack.utils.hf import HFEmbeddingAPIType
|
||||
|
||||
|
||||
@pytest.fixture
def mock_check_valid_model():
    """
    Patch `check_valid_model` as imported by the embedder module with a mock returning `None`.

    Yields the mock so a test can set a `side_effect` on it.
    """
    with patch(
        "haystack.components.embedders.hugging_face_api_text_embedder.check_valid_model", MagicMock(return_value=None)
    ) as mock:
        yield mock
|
||||
|
||||
|
||||
def mock_embedding_generation(json, **kwargs):
    """
    Stand-in for `InferenceClient.post` that fabricates random embeddings.

    :param json:
        The request payload; one 384-dimensional random vector is produced per item in `json["inputs"]`.
    :param kwargs:
        Any other keyword arguments of the patched call; ignored.
    :returns:
        The list of vectors rendered as text and encoded to bytes.
    """
    vectors = [random.rand(384) for _ in json["inputs"]]
    response = str(array(vectors).tolist()).encode()
    return response
|
||||
|
||||
|
||||
class TestHuggingFaceAPITextEmbedder:
    """Unit and integration tests for `HuggingFaceAPITextEmbedder`."""

    def test_init_invalid_api_type(self):
        # An unknown api_type string must be rejected.
        with pytest.raises(ValueError):
            HuggingFaceAPITextEmbedder(api_type="invalid_api_type", api_params={})

    def test_init_serverless(self, mock_check_valid_model):
        model = "BAAI/bge-small-en-v1.5"
        embedder = HuggingFaceAPITextEmbedder(
            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model}
        )

        assert embedder.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API
        assert embedder.api_params == {"model": model}
        assert embedder.prefix == ""
        assert embedder.suffix == ""
        assert embedder.truncate
        assert not embedder.normalize

    def test_init_serverless_invalid_model(self, mock_check_valid_model):
        # Errors raised by the model check propagate out of the constructor.
        mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
        with pytest.raises(RepositoryNotFoundError):
            HuggingFaceAPITextEmbedder(
                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "invalid_model_id"}
            )

    def test_init_serverless_no_model(self):
        with pytest.raises(ValueError):
            HuggingFaceAPITextEmbedder(
                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"param": "irrelevant"}
            )

    def test_init_tei(self):
        url = "https://some_model.com"

        embedder = HuggingFaceAPITextEmbedder(
            api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"url": url}
        )

        assert embedder.api_type == HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE
        assert embedder.api_params == {"url": url}
        assert embedder.prefix == ""
        assert embedder.suffix == ""
        assert embedder.truncate
        assert not embedder.normalize

    def test_init_tei_invalid_url(self):
        with pytest.raises(ValueError):
            HuggingFaceAPITextEmbedder(
                api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"url": "invalid_url"}
            )

    def test_init_tei_no_url(self):
        with pytest.raises(ValueError):
            HuggingFaceAPITextEmbedder(
                api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"param": "irrelevant"}
            )

    def test_to_dict(self, mock_check_valid_model):
        embedder = HuggingFaceAPITextEmbedder(
            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "BAAI/bge-small-en-v1.5"},
            prefix="prefix",
            suffix="suffix",
            truncate=False,
            normalize=True,
        )

        data = embedder.to_dict()

        assert data == {
            "type": "haystack.components.embedders.hugging_face_api_text_embedder.HuggingFaceAPITextEmbedder",
            "init_parameters": {
                "api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
                "api_params": {"model": "BAAI/bge-small-en-v1.5"},
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "prefix": "prefix",
                "suffix": "suffix",
                "truncate": False,
                "normalize": True,
            },
        }

    def test_from_dict(self, mock_check_valid_model):
        data = {
            "type": "haystack.components.embedders.hugging_face_api_text_embedder.HuggingFaceAPITextEmbedder",
            "init_parameters": {
                "api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
                "api_params": {"model": "BAAI/bge-small-en-v1.5"},
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "prefix": "prefix",
                "suffix": "suffix",
                "truncate": False,
                "normalize": True,
            },
        }

        embedder = HuggingFaceAPITextEmbedder.from_dict(data)

        assert embedder.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API
        assert embedder.api_params == {"model": "BAAI/bge-small-en-v1.5"}
        assert embedder.prefix == "prefix"
        assert embedder.suffix == "suffix"
        assert not embedder.truncate
        assert embedder.normalize

    def test_run_wrong_input_format(self, mock_check_valid_model):
        embedder = HuggingFaceAPITextEmbedder(
            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
        )

        list_integers_input = [1, 2, 3]

        with pytest.raises(TypeError):
            embedder.run(text=list_integers_input)

    def test_run(self, mock_check_valid_model):
        # The HTTP call is replaced by a generator of random 384-dimensional vectors.
        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
            mock_embedding_patch.side_effect = mock_embedding_generation

            embedder = HuggingFaceAPITextEmbedder(
                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
                api_params={"model": "BAAI/bge-small-en-v1.5"},
                token=Secret.from_token("fake-api-token"),
                prefix="prefix ",
                suffix=" suffix",
            )

            result = embedder.run(text="The food was delicious")

            mock_embedding_patch.assert_called_once_with(
                json={"inputs": ["prefix The food was delicious suffix"], "truncate": True, "normalize": False},
                task="feature-extraction",
            )

        assert len(result["embedding"]) == 384
        assert all(isinstance(x, float) for x in result["embedding"])

    @pytest.mark.flaky(reruns=5, reruns_delay=5)
    @pytest.mark.integration
    def test_live_run_serverless(self):
        # Integration test: calls the real Serverless Inference API.
        embedder = HuggingFaceAPITextEmbedder(
            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "sentence-transformers/all-MiniLM-L6-v2"},
        )
        result = embedder.run(text="The food was delicious")

        assert len(result["embedding"]) == 384
        assert all(isinstance(x, float) for x in result["embedding"])
|
||||
Loading…
x
Reference in New Issue
Block a user