mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-06-26 22:00:13 +00:00
feat: Add HuggingFace API (text-embeddings-inference for rerank model) for component.rankers (#9414)
* feat(component.rankers): Add HuggingFace API (text-embeddings-inference for rerank) ranker component * update test flow & doc loaders * Support run_async for HuggingFaceAPIRanker * Add release note for HuggingFace API support in component.rankers * Add release note for HuggingFace API support in component.rankers * Add release note for HuggingFace API support in component.rankers * Add release note for HuggingFace API support in component.rankers * fix: 1. `hugging_face_api.HuggingFaceAPIRanker` rename to `hugging_face_tei.HuggingFaceAPIRanker` 2. HuggingFaceAPIRanker: use our Secret API for token 3. add the missing modules for `docs/pydoc/config/rankers_api.yml` 4. added function `async_request_with_retry` for `haystack/utils/requests_utils.py` and added unittest on `test/utils/test_requests_utils.py` 5. HuggingFaceAPIRanker: refactor the retry function to support configuration based on attempts and status code. 6. HuggingFaceAPIRanker: refactor the test into unit tests using mocks * fix(HuggingFaceTEIRanker): change the token check logic to use the resolve_value method. * fix(format): run `hatch run format` * fix: - Force keyword-only arguments in __init__ method by adding *, - Clarify token docstring that it's not always required - Copy documents to avoid modifying original objects - Remove test file from slow workflow - Add monkeypatch environment variable cleanup in tests - Fix missing module in rankers_api.yml and sort modules alphabetically - Remove unnecessary test info from release notes * fix HuggingFaceTEIRanker: - "None" of "Optional[Secret]" has no attribute "resolve_value" - run/run_async: too many parameters * fix(HuggingFaceTEIRanker) :Revise the docstring of the HuggingFaceTEIRanker, improve the parameter descriptions, ensure consistency and clarity. Add error handling information to enhance the readability of the API response. 
* fix:unit test for HuggingFaceTEIRanker raise message * fix fmt * minor refinements * refine release note --------- Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
This commit is contained in:
parent
db3d95b12a
commit
3deaa20cb6
3
.gitignore
vendored
3
.gitignore
vendored
@ -163,3 +163,6 @@ haystack/json-schemas
|
||||
|
||||
# Zed configs
|
||||
.zed/*
|
||||
|
||||
# uv
|
||||
uv.lock
|
||||
|
@ -1,8 +1,14 @@
|
||||
loaders:
|
||||
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
|
||||
search_path: [../../../haystack/components/rankers]
|
||||
modules: ["lost_in_the_middle", "meta_field", "meta_field_grouping_ranker", "transformers_similarity",
|
||||
"sentence_transformers_diversity", "sentence_transformers_similarity"]
|
||||
modules: [
|
||||
"hugging_face_tei",
|
||||
"lost_in_the_middle",
|
||||
"meta_field",
|
||||
"meta_field_grouping_ranker",
|
||||
"sentence_transformers_diversity",
|
||||
"sentence_transformers_similarity",
|
||||
"transformers_similarity"]
|
||||
ignore_when_discovered: ["__init__"]
|
||||
processors:
|
||||
- type: filter
|
||||
|
@ -8,6 +8,7 @@ from typing import TYPE_CHECKING
|
||||
from lazy_imports import LazyImporter
|
||||
|
||||
_import_structure = {
|
||||
"hugging_face_tei": ["HuggingFaceTEIRanker"],
|
||||
"lost_in_the_middle": ["LostInTheMiddleRanker"],
|
||||
"meta_field": ["MetaFieldRanker"],
|
||||
"meta_field_grouping_ranker": ["MetaFieldGroupingRanker"],
|
||||
@ -17,6 +18,7 @@ _import_structure = {
|
||||
}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .hugging_face_tei import HuggingFaceTEIRanker
|
||||
from .lost_in_the_middle import LostInTheMiddleRanker
|
||||
from .meta_field import MetaFieldRanker
|
||||
from .meta_field_grouping_ranker import MetaFieldGroupingRanker
|
||||
|
270
haystack/components/rankers/hugging_face_tei.py
Normal file
270
haystack/components/rankers/hugging_face_tei.py
Normal file
@ -0,0 +1,270 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import copy
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from haystack import Document, component, default_from_dict, default_to_dict
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
from haystack.utils.requests_utils import async_request_with_retry, request_with_retry
|
||||
|
||||
|
||||
class TruncationDirection(str, Enum):
    """
    Side from which over-long input text is cut when it exceeds the model's limit.

    Members:
        LEFT: Cut text from the left side (the start of the text).
        RIGHT: Cut text from the right side (the end of the text).
    """

    # The member values are the literal strings placed in the rerank request payload
    # under the `truncation_direction` key.
    LEFT = "Left"
    RIGHT = "Right"
|
||||
|
||||
|
||||
@component
class HuggingFaceTEIRanker:
    """
    Ranks documents based on their semantic similarity to the query.

    It can be used with a Text Embeddings Inference (TEI) API endpoint:
    - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)
    - [Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints)

    Usage example:
    ```python
    from haystack import Document
    from haystack.components.rankers import HuggingFaceTEIRanker
    from haystack.utils import Secret

    reranker = HuggingFaceTEIRanker(
        url="http://localhost:8080",
        top_k=5,
        timeout=30,
        token=Secret.from_token("my_api_token")
    )

    docs = [Document(content="The capital of France is Paris"), Document(content="The capital of Germany is Berlin")]

    result = reranker.run(query="What is the capital of France?", documents=docs)

    ranked_docs = result["documents"]
    print(ranked_docs)
    >> [Document(id=..., content: 'The capital of France is Paris', score: 0.9979767),
    >>  Document(id=..., content: 'The capital of Germany is Berlin', score: 0.13982213)]
    ```
    """

    def __init__(
        self,
        *,
        url: str,
        top_k: int = 10,
        raw_scores: bool = False,
        timeout: Optional[int] = 30,
        max_retries: int = 3,
        retry_status_codes: Optional[List[int]] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
    ) -> None:
        """
        Initializes the TEI reranker component.

        :param url: Base URL of the TEI reranking service (for example, "https://api.example.com").
            The endpoint is built with `urljoin(url, "/rerank")`, so any path component of `url`
            is replaced by `/rerank`.
        :param top_k: Maximum number of top documents to return.
        :param raw_scores: If True, include raw relevance scores in the API payload.
        :param timeout: Request timeout in seconds.
        :param max_retries: Maximum number of attempts for a request. The value is forwarded as
            `attempts` to the retrying request helpers.
        :param retry_status_codes: List of HTTP status codes that will trigger a retry.
            When None, HTTP 408, 418, 429 and 503 will be retried (default: None).
        :param token: The Hugging Face token to use as HTTP bearer authorization. Not always required
            depending on your TEI server configuration.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        """
        self.url = url
        self.top_k = top_k
        self.timeout = timeout
        self.token = token
        self.max_retries = max_retries
        self.retry_status_codes = retry_status_codes
        self.raw_scores = raw_scores

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            url=self.url,
            top_k=self.top_k,
            # `raw_scores` must be serialized too, otherwise a to_dict/from_dict round trip
            # silently resets it to the constructor default.
            raw_scores=self.raw_scores,
            timeout=self.timeout,
            token=self.token.to_dict() if self.token else None,
            max_retries=self.max_retries,
            retry_status_codes=self.retry_status_codes,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTEIRanker":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        return default_from_dict(cls, data)

    def _prepare_request(
        self, query: str, documents: List[Document], truncation_direction: Optional[TruncationDirection]
    ) -> Dict[str, Any]:
        """
        Builds the keyword arguments shared by the sync and async rerank requests.

        :param query: The user query string to guide reranking.
        :param documents: List of `Document` objects to rerank.
        :param truncation_direction: If set, enables text truncation in the specified direction.

        :returns: Keyword arguments for `request_with_retry` / `async_request_with_retry`.
        """
        # A document without content is sent as an empty string rather than JSON `null`.
        texts = [doc.content or "" for doc in documents]
        payload: Dict[str, Any] = {"query": query, "texts": texts, "raw_scores": self.raw_scores}
        if truncation_direction:
            # Coercing through the enum accepts both enum members and their plain string values.
            direction = TruncationDirection(truncation_direction)
            payload.update({"truncate": True, "truncation_direction": direction.value})

        headers: Dict[str, str] = {}
        # Resolve the secret once; it may be absent or resolve to nothing when no token is configured.
        token_value = self.token.resolve_value() if self.token else None
        if token_value:
            headers["Authorization"] = f"Bearer {token_value}"

        return {
            "method": "POST",
            "url": urljoin(self.url, "/rerank"),
            "json": payload,
            "timeout": self.timeout,
            "headers": headers,
            "attempts": self.max_retries,
            "status_codes_to_retry": self.retry_status_codes,
        }

    def _compose_response(
        self, result: Union[Dict[str, str], List[Dict[str, Any]]], top_k: Optional[int], documents: List[Document]
    ) -> Dict[str, List[Document]]:
        """
        Processes the API response into a structured format.

        :param result: The raw (JSON-decoded) response from the API.
        :param top_k: Optional override for the maximum number of documents to return.
            When None, the `top_k` given at initialization is used.
        :param documents: The documents that were sent for reranking; response items refer to
            them by position through their `index` field.

        :returns: A dictionary with the following keys:
            - `documents`: A list of reranked documents.

        :raises RuntimeError:
            - If the API returns an error response or an unexpected response format.
        """
        if isinstance(result, dict) and "error" in result:
            error_type = result.get("error_type", "UnknownError")
            error_msg = result.get("error", "No additional information.")
            raise RuntimeError(f"HuggingFaceTEIRanker API call failed ({error_type}): {error_msg}")

        # Ensure we have a list of score dicts
        if not isinstance(result, list):
            # Expected list or dict, but encountered an unknown response format.
            error_msg = f"Expected a list of score dictionaries, but got `{type(result).__name__}`. "
            error_msg += f"Response content: {result}"
            raise RuntimeError(f"Unexpected response format from text-embeddings-inference rerank API: {error_msg}")

        # Determine number of docs to return. `is not None` (rather than `or`) keeps an
        # explicit `top_k=0` from silently falling back to the default.
        requested_k = top_k if top_k is not None else self.top_k
        final_k = min(requested_k, len(result))

        # Select and return the top_k documents.
        # NOTE(review): the response order is kept as-is; this presumably relies on the API
        # returning items best-first - confirm against the TEI documentation.
        ranked_docs = []
        for item in result[:final_k]:
            index: int = item["index"]
            # Copy so the caller's documents are not modified when the score is set.
            doc_copy = copy.copy(documents[index])
            doc_copy.score = item["score"]
            ranked_docs.append(doc_copy)
        return {"documents": ranked_docs}

    @component.output_types(documents=List[Document])
    def run(
        self,
        query: str,
        documents: List[Document],
        top_k: Optional[int] = None,
        truncation_direction: Optional[TruncationDirection] = None,
    ) -> Dict[str, List[Document]]:
        """
        Reranks the provided documents by relevance to the query using the TEI API.

        :param query: The user query string to guide reranking.
        :param documents: List of `Document` objects to rerank.
        :param top_k: Optional override for the maximum number of documents to return.
        :param truncation_direction: If set, enables text truncation in the specified direction.

        :returns: A dictionary with the following keys:
            - `documents`: A list of reranked documents.

        :raises requests.exceptions.RequestException:
            - If the API request fails.

        :raises RuntimeError:
            - If the API returns an error response.
        """
        # Return empty if no documents provided
        if not documents:
            return {"documents": []}

        # Call the external service with retry
        response = request_with_retry(**self._prepare_request(query, documents, truncation_direction))

        result: Union[Dict[str, str], List[Dict[str, Any]]] = response.json()

        return self._compose_response(result, top_k, documents)

    @component.output_types(documents=List[Document])
    async def run_async(
        self,
        query: str,
        documents: List[Document],
        top_k: Optional[int] = None,
        truncation_direction: Optional[TruncationDirection] = None,
    ) -> Dict[str, List[Document]]:
        """
        Asynchronously reranks the provided documents by relevance to the query using the TEI API.

        :param query: The user query string to guide reranking.
        :param documents: List of `Document` objects to rerank.
        :param top_k: Optional override for the maximum number of documents to return.
        :param truncation_direction: If set, enables text truncation in the specified direction.

        :returns: A dictionary with the following keys:
            - `documents`: A list of reranked documents.

        :raises httpx.RequestError:
            - If the API request fails.
        :raises RuntimeError:
            - If the API returns an error response.
        """
        # Return empty if no documents provided
        if not documents:
            return {"documents": []}

        # Call the external service with retry
        response = await async_request_with_retry(**self._prepare_request(query, documents, truncation_direction))

        result: Union[Dict[str, str], List[Dict[str, Any]]] = response.json()

        return self._compose_response(result, top_k, documents)
|
@ -18,7 +18,7 @@ _import_structure = {
|
||||
"jinja2_extensions": ["Jinja2TimeExtension"],
|
||||
"jupyter": ["is_in_jupyter"],
|
||||
"misc": ["expit", "expand_page_range"],
|
||||
"requests_utils": ["request_with_retry"],
|
||||
"requests_utils": ["request_with_retry", "async_request_with_retry"],
|
||||
"type_serialization": ["deserialize_type", "serialize_type"],
|
||||
}
|
||||
|
||||
@ -33,7 +33,7 @@ if TYPE_CHECKING:
|
||||
from .jinja2_extensions import Jinja2TimeExtension
|
||||
from .jupyter import is_in_jupyter
|
||||
from .misc import expand_page_range, expit
|
||||
from .requests_utils import request_with_retry
|
||||
from .requests_utils import async_request_with_retry, request_with_retry
|
||||
from .type_serialization import deserialize_type, serialize_type
|
||||
else:
|
||||
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
|
||||
|
@ -5,6 +5,7 @@
|
||||
import logging
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
from tenacity import after_log, before_log, retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
||||
|
||||
@ -95,3 +96,113 @@ def request_with_retry(
|
||||
# won't trigger a retry, this way the call will still cause an explicit exception
|
||||
res.raise_for_status()
|
||||
return res
|
||||
|
||||
|
||||
async def async_request_with_retry(
    attempts: int = 3, status_codes_to_retry: Optional[List[int]] = None, **kwargs: Any
) -> httpx.Response:
    """
    Executes an asynchronous HTTP request with a configurable exponential backoff retry on failures.

    Usage example:
    ```python
    import asyncio
    from haystack.utils import async_request_with_retry

    # Sending an async HTTP request with default retry configs
    async def example():
        res = await async_request_with_retry(method="GET", url="https://example.com")
        return res

    # Sending an async HTTP request with custom number of attempts
    async def example_with_attempts():
        res = await async_request_with_retry(method="GET", url="https://example.com", attempts=10)
        return res

    # Sending an async HTTP request with custom HTTP codes to retry
    async def example_with_status_codes():
        res = await async_request_with_retry(method="GET", url="https://example.com", status_codes_to_retry=[408, 503])
        return res

    # Sending an async HTTP request with custom timeout in seconds
    async def example_with_timeout():
        res = await async_request_with_retry(method="GET", url="https://example.com", timeout=5)
        return res

    # Sending an async HTTP request with custom headers
    async def example_with_headers():
        headers = {"Authorization": "Bearer <my_token_here>"}
        res = await async_request_with_retry(method="GET", url="https://example.com", headers=headers)
        return res

    # All of the above combined
    async def example_combined():
        headers = {"Authorization": "Bearer <my_token_here>"}
        res = await async_request_with_retry(
            method="GET",
            url="https://example.com",
            headers=headers,
            attempts=10,
            status_codes_to_retry=[408, 503],
            timeout=5
        )
        return res

    # Sending an async POST request
    async def example_post():
        res = await async_request_with_retry(
            method="POST",
            url="https://example.com",
            json={"key": "value"},
            attempts=10
        )
        return res

    # Retry all 5xx status codes
    async def example_5xx():
        res = await async_request_with_retry(
            method="GET",
            url="https://example.com",
            status_codes_to_retry=list(range(500, 600))
        )
        return res
    ```

    :param attempts:
        Maximum number of attempts to retry the request.
    :param status_codes_to_retry:
        List of HTTP status codes that will trigger a retry.
        When param is `None`, HTTP 408, 418, 429 and 503 will be retried.
    :param kwargs:
        Optional arguments that `httpx.AsyncClient.request` accepts.
        `timeout` defaults to 10 seconds when not given.
    :returns:
        The `httpx.Response` object.
    """

    if status_codes_to_retry is None:
        status_codes_to_retry = [408, 418, 429, 503]

    # Extract the timeout once, before the retried coroutine is defined. Popping it inside
    # `run` would consume the caller's value on the first attempt, so every retry would
    # silently fall back to the 10 second default.
    timeout = kwargs.pop("timeout", 10)

    @retry(
        reraise=True,
        wait=wait_exponential(),
        retry=retry_if_exception_type((httpx.HTTPError, TimeoutError)),
        stop=stop_after_attempt(attempts),
        before=before_log(logger, logging.DEBUG),
        after=after_log(logger, logging.DEBUG),
    )
    async def run():
        async with httpx.AsyncClient() as client:
            res = await client.request(**kwargs, timeout=timeout)

            if res.status_code in status_codes_to_retry:
                # We raise only for the status codes that must trigger a retry
                res.raise_for_status()

            return res

    res = await run()
    # We raise here too in case the request failed with a status code that
    # won't trigger a retry, this way the call will still cause an explicit exception
    res.raise_for_status()
    return res
|
||||
|
@ -0,0 +1,6 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Added new `HuggingFaceTEIRanker` component to enable reranking with Text Embeddings Inference (TEI) API.
|
||||
This component supports both self-hosted Text Embeddings Inference services and Hugging Face Inference
|
||||
Endpoints.
|
331
test/components/rankers/test_hugging_face_tei.py
Normal file
331
test/components/rankers/test_hugging_face_tei.py
Normal file
@ -0,0 +1,331 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
from haystack import Document
|
||||
from haystack.components.rankers.hugging_face_tei import HuggingFaceTEIRanker, TruncationDirection
|
||||
from haystack.utils import Secret
|
||||
|
||||
|
||||
class TestHuggingFaceTEIRanker:
|
||||
def test_init(self, monkeypatch):
|
||||
"""Test initialization with default and custom parameters"""
|
||||
# Ensure we're not using system environment variables
|
||||
monkeypatch.delenv("HF_API_TOKEN", raising=False)
|
||||
monkeypatch.delenv("HF_TOKEN", raising=False)
|
||||
|
||||
# Default parameters
|
||||
ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com")
|
||||
assert ranker.url == "https://api.my-tei-service.com"
|
||||
assert ranker.top_k == 10
|
||||
assert ranker.timeout == 30
|
||||
assert not ranker.token.resolve_value()
|
||||
assert ranker.max_retries == 3
|
||||
assert ranker.retry_status_codes is None
|
||||
|
||||
# Custom parameters
|
||||
token = Secret.from_token("my_api_token")
|
||||
ranker = HuggingFaceTEIRanker(
|
||||
url="https://api.my-tei-service.com",
|
||||
top_k=5,
|
||||
timeout=60,
|
||||
token=token,
|
||||
max_retries=5,
|
||||
retry_status_codes=[500, 502, 503],
|
||||
)
|
||||
assert ranker.url == "https://api.my-tei-service.com"
|
||||
assert ranker.top_k == 5
|
||||
assert ranker.timeout == 60
|
||||
assert ranker.token == token
|
||||
assert ranker.max_retries == 5
|
||||
assert ranker.retry_status_codes == [500, 502, 503]
|
||||
|
||||
def test_to_dict(self, monkeypatch):
|
||||
"""Test serialization to dict with Secret token"""
|
||||
# Ensure we're not using system environment variables
|
||||
monkeypatch.delenv("HF_API_TOKEN", raising=False)
|
||||
monkeypatch.delenv("HF_TOKEN", raising=False)
|
||||
|
||||
component = HuggingFaceTEIRanker(
|
||||
url="https://api.my-tei-service.com", top_k=5, timeout=30, max_retries=4, retry_status_codes=[500, 502]
|
||||
)
|
||||
data = component.to_dict()
|
||||
|
||||
assert data["type"] == "haystack.components.rankers.hugging_face_tei.HuggingFaceTEIRanker"
|
||||
assert data["init_parameters"]["url"] == "https://api.my-tei-service.com"
|
||||
assert data["init_parameters"]["top_k"] == 5
|
||||
assert data["init_parameters"]["timeout"] == 30
|
||||
assert data["init_parameters"]["token"] == {
|
||||
"env_vars": ["HF_API_TOKEN", "HF_TOKEN"],
|
||||
"strict": False,
|
||||
"type": "env_var",
|
||||
}
|
||||
assert data["init_parameters"]["max_retries"] == 4
|
||||
assert data["init_parameters"]["retry_status_codes"] == [500, 502]
|
||||
|
||||
def test_from_dict(self, monkeypatch):
|
||||
"""Test deserialization from dict with environment variable token"""
|
||||
# Ensure we're not using system environment variables
|
||||
monkeypatch.delenv("HF_API_TOKEN", raising=False)
|
||||
monkeypatch.delenv("HF_TOKEN", raising=False)
|
||||
|
||||
data = {
|
||||
"type": "haystack.components.rankers.hugging_face_tei.HuggingFaceTEIRanker",
|
||||
"init_parameters": {
|
||||
"url": "https://api.my-tei-service.com",
|
||||
"top_k": 5,
|
||||
"timeout": 30,
|
||||
"token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
|
||||
"max_retries": 4,
|
||||
"retry_status_codes": [500, 502],
|
||||
},
|
||||
}
|
||||
|
||||
component = HuggingFaceTEIRanker.from_dict(data)
|
||||
|
||||
assert component.url == "https://api.my-tei-service.com"
|
||||
assert component.top_k == 5
|
||||
assert component.timeout == 30
|
||||
assert component.max_retries == 4
|
||||
assert component.retry_status_codes == [500, 502]
|
||||
|
||||
def test_empty_documents(self, monkeypatch):
|
||||
"""Test that empty documents list returns empty result"""
|
||||
# Ensure we're not using system environment variables
|
||||
monkeypatch.delenv("HF_API_TOKEN", raising=False)
|
||||
monkeypatch.delenv("HF_TOKEN", raising=False)
|
||||
|
||||
ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com")
|
||||
result = ranker.run(query="test query", documents=[])
|
||||
assert result == {"documents": []}
|
||||
|
||||
@patch("haystack.components.rankers.hugging_face_tei.request_with_retry")
|
||||
def test_run_with_mock(self, mock_request, monkeypatch):
|
||||
"""Test run method with mocked API response"""
|
||||
# Ensure we're not using system environment variables
|
||||
monkeypatch.delenv("HF_API_TOKEN", raising=False)
|
||||
monkeypatch.delenv("HF_TOKEN", raising=False)
|
||||
|
||||
# Setup mock response
|
||||
mock_response = MagicMock(spec=requests.Response)
|
||||
mock_response.json.return_value = [
|
||||
{"index": 2, "score": 0.95},
|
||||
{"index": 1, "score": 0.85},
|
||||
{"index": 0, "score": 0.75},
|
||||
]
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
# Create ranker and test documents
|
||||
token = Secret.from_token("test_token")
|
||||
ranker = HuggingFaceTEIRanker(
|
||||
url="https://api.my-tei-service.com",
|
||||
top_k=3,
|
||||
timeout=30,
|
||||
token=token,
|
||||
max_retries=4,
|
||||
retry_status_codes=[500, 502],
|
||||
)
|
||||
|
||||
docs = [Document(content="Document A"), Document(content="Document B"), Document(content="Document C")]
|
||||
|
||||
# Run the ranker
|
||||
result = ranker.run(query="test query", documents=docs)
|
||||
|
||||
# Check that request_with_retry was called with correct parameters
|
||||
mock_request.assert_called_once_with(
|
||||
method="POST",
|
||||
url="https://api.my-tei-service.com/rerank",
|
||||
json={"query": "test query", "texts": ["Document A", "Document B", "Document C"], "raw_scores": False},
|
||||
timeout=30,
|
||||
headers={"Authorization": "Bearer test_token"},
|
||||
attempts=4,
|
||||
status_codes_to_retry=[500, 502],
|
||||
)
|
||||
|
||||
# Check that documents are ranked correctly
|
||||
assert len(result["documents"]) == 3
|
||||
assert result["documents"][0].content == "Document C"
|
||||
assert result["documents"][0].score == 0.95
|
||||
assert result["documents"][1].content == "Document B"
|
||||
assert result["documents"][1].score == 0.85
|
||||
assert result["documents"][2].content == "Document A"
|
||||
assert result["documents"][2].score == 0.75
|
||||
|
||||
@patch("haystack.components.rankers.hugging_face_tei.request_with_retry")
|
||||
def test_run_with_truncation_direction(self, mock_request, monkeypatch):
|
||||
"""Test run method with truncation direction parameter"""
|
||||
# Ensure we're not using system environment variables
|
||||
monkeypatch.delenv("HF_API_TOKEN", raising=False)
|
||||
monkeypatch.delenv("HF_TOKEN", raising=False)
|
||||
|
||||
# Setup mock response
|
||||
mock_response = MagicMock(spec=requests.Response)
|
||||
mock_response.json.return_value = [{"index": 0, "score": 0.95}]
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
# Create ranker and test documents
|
||||
token = Secret.from_token("test_token")
|
||||
ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com", token=token)
|
||||
docs = [Document(content="Document A")]
|
||||
|
||||
# Run the ranker with truncation direction
|
||||
ranker.run(query="test query", documents=docs, truncation_direction=TruncationDirection.LEFT)
|
||||
|
||||
# Check that request includes truncation parameters
|
||||
mock_request.assert_called_once_with(
|
||||
method="POST",
|
||||
url="https://api.my-tei-service.com/rerank",
|
||||
json={
|
||||
"query": "test query",
|
||||
"texts": ["Document A"],
|
||||
"raw_scores": False,
|
||||
"truncate": True,
|
||||
"truncation_direction": "Left",
|
||||
},
|
||||
timeout=30,
|
||||
headers={"Authorization": "Bearer test_token"},
|
||||
attempts=3,
|
||||
status_codes_to_retry=None,
|
||||
)
|
||||
|
||||
@patch("haystack.components.rankers.hugging_face_tei.request_with_retry")
|
||||
def test_run_with_custom_top_k(self, mock_request, monkeypatch):
|
||||
"""Test run method with custom top_k parameter"""
|
||||
# Ensure we're not using system environment variables
|
||||
monkeypatch.delenv("HF_API_TOKEN", raising=False)
|
||||
monkeypatch.delenv("HF_TOKEN", raising=False)
|
||||
|
||||
# Setup mock response with 5 documents
|
||||
mock_response = MagicMock(spec=requests.Response)
|
||||
mock_response.json.return_value = [
|
||||
{"index": 4, "score": 0.95},
|
||||
{"index": 3, "score": 0.90},
|
||||
{"index": 2, "score": 0.85},
|
||||
{"index": 1, "score": 0.80},
|
||||
{"index": 0, "score": 0.75},
|
||||
]
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
# Create ranker with top_k=3
|
||||
ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com", top_k=3)
|
||||
|
||||
# Create 5 test documents
|
||||
docs = [Document(content=f"Document {i}") for i in range(5)]
|
||||
|
||||
# Run the ranker
|
||||
result = ranker.run(query="test query", documents=docs)
|
||||
|
||||
# Check that only top 3 documents are returned
|
||||
assert len(result["documents"]) == 3
|
||||
assert result["documents"][0].content == "Document 4"
|
||||
assert result["documents"][1].content == "Document 3"
|
||||
assert result["documents"][2].content == "Document 2"
|
||||
|
||||
# Test with run-time top_k override
|
||||
result = ranker.run(query="test query", documents=docs, top_k=2)
|
||||
|
||||
# Check that only top 2 documents are returned
|
||||
assert len(result["documents"]) == 2
|
||||
assert result["documents"][0].content == "Document 4"
|
||||
assert result["documents"][1].content == "Document 3"
|
||||
|
||||
@patch("haystack.components.rankers.hugging_face_tei.request_with_retry")
|
||||
def test_error_handling(self, mock_request, monkeypatch):
|
||||
"""Test error handling in the ranker"""
|
||||
# Ensure we're not using system environment variables
|
||||
monkeypatch.delenv("HF_API_TOKEN", raising=False)
|
||||
monkeypatch.delenv("HF_TOKEN", raising=False)
|
||||
|
||||
# Setup mock response with error
|
||||
mock_response = MagicMock(spec=requests.Response)
|
||||
mock_response.json.return_value = {"error": "Some error occurred", "error_type": "TestError"}
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
# Create ranker and test documents
|
||||
ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com")
|
||||
docs = [Document(content="Document A")]
|
||||
|
||||
# Test that RuntimeError is raised with the correct message
|
||||
with pytest.raises(
|
||||
RuntimeError, match=r"HuggingFaceTEIRanker API call failed \(TestError\): Some error occurred"
|
||||
):
|
||||
ranker.run(query="test query", documents=docs)
|
||||
|
||||
# Test unexpected response format
|
||||
mock_response.json.return_value = {"unexpected": "format"}
|
||||
with pytest.raises(RuntimeError, match="Unexpected response format from text-embeddings-inference rerank API"):
|
||||
ranker.run(query="test query", documents=docs)
|
||||
|
||||
@pytest.mark.asyncio
@patch("haystack.components.rankers.hugging_face_tei.async_request_with_retry")
async def test_run_async_with_mock(self, mock_request, monkeypatch):
    """Exercise ``run_async`` against a mocked rerank endpoint; check the request sent and the resulting order."""
    # Keep ambient Hugging Face credentials out of the test
    for env_var in ("HF_API_TOKEN", "HF_TOKEN"):
        monkeypatch.delenv(env_var, raising=False)

    # The mocked endpoint ranks the third input first and the first input last
    api_response = MagicMock(spec=httpx.Response)
    api_response.json.return_value = [
        {"index": 2, "score": 0.95},
        {"index": 1, "score": 0.85},
        {"index": 0, "score": 0.75},
    ]
    mock_request.return_value = api_response

    ranker = HuggingFaceTEIRanker(
        url="https://api.my-tei-service.com",
        top_k=3,
        timeout=30,
        token=Secret.from_token("test_token"),
        max_retries=4,
        retry_status_codes=[500, 502],
    )
    contents = ["Document A", "Document B", "Document C"]
    docs = [Document(content=text) for text in contents]

    result = await ranker.run_async(query="test query", documents=docs)

    # The ranker's configuration must be forwarded to the retrying HTTP helper
    mock_request.assert_called_once_with(
        method="POST",
        url="https://api.my-tei-service.com/rerank",
        json={"query": "test query", "texts": contents, "raw_scores": False},
        timeout=30,
        headers={"Authorization": "Bearer test_token"},
        attempts=4,
        status_codes_to_retry=[500, 502],
    )

    # Documents come back in the order given by the mocked ranking, carrying its scores
    ranked = result["documents"]
    assert len(ranked) == 3
    expected = [("Document C", 0.95), ("Document B", 0.85), ("Document A", 0.75)]
    for doc, (content, score) in zip(ranked, expected):
        assert doc.content == content
        assert doc.score == score
|
||||
|
||||
@pytest.mark.asyncio
@patch("haystack.components.rankers.hugging_face_tei.async_request_with_retry")
async def test_run_async_empty_documents(self, mock_request, monkeypatch):
    """Check that ``run_async`` with no documents returns an empty result without calling the API."""
    # Keep ambient Hugging Face credentials out of the test
    for env_var in ("HF_API_TOKEN", "HF_TOKEN"):
        monkeypatch.delenv(env_var, raising=False)

    ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com")
    outcome = await ranker.run_async(query="test query", documents=[])

    # Nothing to rank, so the HTTP helper must not have been invoked
    mock_request.assert_not_called()
    assert outcome == {"documents": []}
|
226
test/utils/test_requests_utils.py
Normal file
226
test/utils/test_requests_utils.py
Normal file
@ -0,0 +1,226 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import httpx
|
||||
import requests
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from haystack.utils.requests_utils import request_with_retry, async_request_with_retry
|
||||
|
||||
|
||||
@pytest.fixture
def mock_requests_response():
    """Provide a ``requests.Response``-shaped mock with status 200 whose ``raise_for_status`` returns None."""
    fake = MagicMock(spec=requests.Response)
    # Dotted keys let configure_mock set the nested return value in the same call
    fake.configure_mock(**{"status_code": 200, "raise_for_status.return_value": None})
    return fake
|
||||
|
||||
|
||||
@pytest.fixture
def mock_httpx_response():
    """Provide an ``httpx.Response``-shaped mock with status 200 whose ``raise_for_status`` returns None."""
    fake = MagicMock(spec=httpx.Response)
    # Dotted keys let configure_mock set the nested return value in the same call
    fake.configure_mock(**{"status_code": 200, "raise_for_status.return_value": None})
    return fake
|
||||
|
||||
|
||||
class TestRequestWithRetry:
    """Tests for the synchronous ``request_with_retry`` helper, with ``requests.request`` patched out."""

    def test_request_with_retry_success(self, mock_requests_response):
        """Test that request_with_retry works with default parameters (timeout of 10 is forwarded)"""
        with patch("requests.request", return_value=mock_requests_response) as mock_request:
            response = request_with_retry(method="GET", url="https://example.com")

        assert response == mock_requests_response
        mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10)

    def test_request_with_retry_custom_attempts(self, mock_requests_response):
        """Test that a custom attempts value is accepted and not forwarded to requests.request"""
        with patch("requests.request", return_value=mock_requests_response) as mock_request:
            response = request_with_retry(method="GET", url="https://example.com", attempts=5)

        assert response == mock_requests_response
        # A successful first call must not trigger further attempts
        mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10)

    def test_request_with_retry_custom_status_codes(self, mock_requests_response):
        """Test that custom status_codes_to_retry are accepted and not forwarded to requests.request"""
        with patch("requests.request", return_value=mock_requests_response) as mock_request:
            response = request_with_retry(method="GET", url="https://example.com", status_codes_to_retry=[500, 502])

        assert response == mock_requests_response
        mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10)

    def test_request_with_retry_custom_timeout(self, mock_requests_response):
        """Test that request_with_retry respects custom timeout parameter"""
        with patch("requests.request", return_value=mock_requests_response) as mock_request:
            response = request_with_retry(method="GET", url="https://example.com", timeout=30)

        assert response == mock_requests_response
        mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=30)

    def test_request_with_retry_with_headers(self, mock_requests_response):
        """Test that request_with_retry passes headers correctly"""
        headers = {"Authorization": "Bearer token123"}
        with patch("requests.request", return_value=mock_requests_response) as mock_request:
            response = request_with_retry(method="GET", url="https://example.com", headers=headers)

        assert response == mock_requests_response
        mock_request.assert_called_once_with(method="GET", url="https://example.com", headers=headers, timeout=10)

    def test_request_with_retry_with_json(self, mock_requests_response):
        """Test that request_with_retry passes JSON data correctly"""
        json_data = {"key": "value"}
        with patch("requests.request", return_value=mock_requests_response) as mock_request:
            response = request_with_retry(method="POST", url="https://example.com", json=json_data)

        assert response == mock_requests_response
        mock_request.assert_called_once_with(method="POST", url="https://example.com", json=json_data, timeout=10)

    def test_request_with_retry_retries_on_error(self):
        """Test that request_with_retry retries when the underlying request raises an HTTPError"""
        success_response = requests.Response()
        success_response.status_code = 200

        with patch("requests.request") as mock_request:
            # First call raises an error, second call succeeds
            mock_request.side_effect = [requests.exceptions.HTTPError("Server error"), success_response]

            response = request_with_retry(method="GET", url="https://example.com", attempts=2)

        assert response == success_response
        assert mock_request.call_count == 2

    def test_request_with_retry_retries_on_status_code(self):
        """Test that request_with_retry retries on specified status codes"""
        error_response = requests.Response()
        error_response.status_code = 503

        def raise_for_status():
            if error_response.status_code == 503:
                raise requests.exceptions.HTTPError("Service Unavailable")

        error_response.raise_for_status = raise_for_status

        success_response = requests.Response()
        success_response.status_code = 200
        success_response.raise_for_status = lambda: None

        with patch("requests.request") as mock_request:
            # First call returns error status code, second call succeeds
            mock_request.side_effect = [error_response, success_response]

            response = request_with_retry(
                method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503]
            )

        assert response == success_response
        assert mock_request.call_count == 2
|
||||
|
||||
|
||||
class TestAsyncRequestWithRetry:
    """Tests for the ``async_request_with_retry`` helper, with ``httpx.AsyncClient.request`` patched out."""

    @pytest.mark.asyncio
    async def test_async_request_with_retry_success(self, mock_httpx_response):
        """Test that async_request_with_retry works with default parameters (timeout of 10 is forwarded)"""
        with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request:
            response = await async_request_with_retry(method="GET", url="https://example.com")

        assert response == mock_httpx_response
        mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10)

    @pytest.mark.asyncio
    async def test_async_request_with_retry_custom_attempts(self, mock_httpx_response):
        """Test that a custom attempts value is accepted and not forwarded to the HTTP client"""
        with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request:
            response = await async_request_with_retry(method="GET", url="https://example.com", attempts=5)

        assert response == mock_httpx_response
        # A successful first call must not trigger further attempts
        mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10)

    @pytest.mark.asyncio
    async def test_async_request_with_retry_custom_status_codes(self, mock_httpx_response):
        """Test that custom status_codes_to_retry are accepted and not forwarded to the HTTP client"""
        with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request:
            response = await async_request_with_retry(
                method="GET", url="https://example.com", status_codes_to_retry=[500, 502]
            )

        assert response == mock_httpx_response
        mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10)

    @pytest.mark.asyncio
    async def test_async_request_with_retry_custom_timeout(self, mock_httpx_response):
        """Test that async_request_with_retry respects custom timeout parameter"""
        with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request:
            response = await async_request_with_retry(method="GET", url="https://example.com", timeout=30)

        assert response == mock_httpx_response
        mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=30)

    @pytest.mark.asyncio
    async def test_async_request_with_retry_with_headers(self, mock_httpx_response):
        """Test that async_request_with_retry passes headers correctly"""
        headers = {"Authorization": "Bearer token123"}
        with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request:
            response = await async_request_with_retry(method="GET", url="https://example.com", headers=headers)

        assert response == mock_httpx_response
        mock_request.assert_called_once_with(method="GET", url="https://example.com", headers=headers, timeout=10)

    @pytest.mark.asyncio
    async def test_async_request_with_retry_with_json(self, mock_httpx_response):
        """Test that async_request_with_retry passes JSON data correctly"""
        json_data = {"key": "value"}
        with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request:
            response = await async_request_with_retry(method="POST", url="https://example.com", json=json_data)

        assert response == mock_httpx_response
        mock_request.assert_called_once_with(method="POST", url="https://example.com", json=json_data, timeout=10)

    @pytest.mark.asyncio
    async def test_async_request_with_retry_retries_on_error(self):
        """Test that async_request_with_retry retries when the request raises httpx.RequestError"""
        success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))

        with patch("httpx.AsyncClient.request") as mock_request:
            # First call raises an error, second call succeeds
            mock_request.side_effect = [
                httpx.RequestError("Server error", request=httpx.Request("GET", "https://example.com")),
                success_response,
            ]

            response = await async_request_with_retry(method="GET", url="https://example.com", attempts=2)

        assert response == success_response
        assert mock_request.call_count == 2

    @pytest.mark.asyncio
    async def test_async_request_with_retry_retries_on_status_code(self):
        """Test that async_request_with_retry retries on specified status codes"""
        error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com"))

        def raise_for_status():
            if error_response.status_code == 503:
                raise httpx.HTTPStatusError(
                    "Service Unavailable", request=error_response.request, response=error_response
                )

        error_response.raise_for_status = raise_for_status

        success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))
        success_response.raise_for_status = lambda: None

        with patch("httpx.AsyncClient.request") as mock_request:
            # First call returns error status code, second call succeeds
            mock_request.side_effect = [error_response, success_response]

            response = await async_request_with_retry(
                method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503]
            )

        assert response == success_response
        assert mock_request.call_count == 2
|
Loading…
x
Reference in New Issue
Block a user