feat: Add HuggingFace API (text-embeddings-inference for rerank model) for component.rankers (#9414)

* feat(component.rankers): Add HuggingFace API (text-embeddings-inference for rerank) ranker component

* update test flow & doc loaders

* Support run_async for HuggingFaceAPIRanker

* Add release note for HuggingFace API support in component.rankers

* Add release note for HuggingFace API support in component.rankers

* Add release note for HuggingFace API support in component.rankers

* Add release note for HuggingFace API support in component.rankers

* fix:
1. `hugging_face_api.HuggingFaceAPIRanker` rename to `hugging_face_tei.HuggingFaceAPIRanker`
2. HuggingFaceAPIRanker: use our Secret API for token
3. add the missing modules for `docs/pydoc/config/rankers_api.yml`
4. added function `async_request_with_retry` for `haystack/utils/requests_utils.py` and added unittest on `test/utils/test_requests_utils.py`
5. HuggingFaceAPIRanker: refactor the retry function to support configuration based on attempts and status code.
6. HuggingFaceAPIRanker: refactor the test into unit tests using mocks

* fix(HuggingFaceTEIRanker): change the token check logic to use the resolve_value method.

* fix(format): run `hatch run format`

* fix:
- Force keyword-only arguments in __init__ method by adding *,
- Clarify token docstring that it's not always required
- Copy documents to avoid modifying original objects
- Remove test file from slow workflow
- Add monkeypatch environment variable cleanup in tests
- Fix missing module in rankers_api.yml and sort modules alphabetically
- Remove unnecessary test info from release notes

* fix HuggingFaceTEIRanker:
- "None" of "Optional[Secret]" has no attribute "resolve_value"
- run/run_async: too many parameters

* fix(HuggingFaceTEIRanker): Revise the docstring of the HuggingFaceTEIRanker, improve the parameter descriptions, and ensure consistency and clarity. Add error handling information to enhance the readability of the API response.

* fix:unit test for HuggingFaceTEIRanker raise message

* fix fmt

* minor refinements

* refine release note

---------

Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
This commit is contained in:
atopx 2025-05-27 18:44:54 +08:00 committed by GitHub
parent db3d95b12a
commit 3deaa20cb6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 959 additions and 4 deletions

3
.gitignore vendored
View File

@ -163,3 +163,6 @@ haystack/json-schemas
# Zed configs
.zed/*
# uv
uv.lock

View File

@ -1,8 +1,14 @@
loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../haystack/components/rankers]
modules: ["lost_in_the_middle", "meta_field", "meta_field_grouping_ranker", "transformers_similarity",
"sentence_transformers_diversity", "sentence_transformers_similarity"]
modules: [
"hugging_face_tei",
"lost_in_the_middle",
"meta_field",
"meta_field_grouping_ranker",
"sentence_transformers_diversity",
"sentence_transformers_similarity",
"transformers_similarity"]
ignore_when_discovered: ["__init__"]
processors:
- type: filter

View File

@ -8,6 +8,7 @@ from typing import TYPE_CHECKING
from lazy_imports import LazyImporter
_import_structure = {
"hugging_face_tei": ["HuggingFaceTEIRanker"],
"lost_in_the_middle": ["LostInTheMiddleRanker"],
"meta_field": ["MetaFieldRanker"],
"meta_field_grouping_ranker": ["MetaFieldGroupingRanker"],
@ -17,6 +18,7 @@ _import_structure = {
}
if TYPE_CHECKING:
from .hugging_face_tei import HuggingFaceTEIRanker
from .lost_in_the_middle import LostInTheMiddleRanker
from .meta_field import MetaFieldRanker
from .meta_field_grouping_ranker import MetaFieldGroupingRanker

View File

@ -0,0 +1,270 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import copy
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urljoin
from haystack import Document, component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.requests_utils import async_request_with_retry, request_with_retry
class TruncationDirection(str, Enum):
    """
    Side from which input text is cut when it is longer than the model's limit.

    Attributes:
        LEFT: Cut from the left side, i.e. the start of the text.
        RIGHT: Cut from the right side, i.e. the end of the text.
    """

    # The string values are sent verbatim as `truncation_direction` in the rerank payload.
    LEFT = "Left"
    RIGHT = "Right"
@component
class HuggingFaceTEIRanker:
    """
    Ranks documents based on their semantic similarity to the query.

    It can be used with a Text Embeddings Inference (TEI) API endpoint:
    - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)
    - [Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints)

    Usage example:
    ```python
    from haystack import Document
    from haystack.components.rankers import HuggingFaceTEIRanker
    from haystack.utils import Secret

    reranker = HuggingFaceTEIRanker(
        url="http://localhost:8080",
        top_k=5,
        timeout=30,
        token=Secret.from_token("my_api_token")
    )

    docs = [Document(content="The capital of France is Paris"), Document(content="The capital of Germany is Berlin")]

    result = reranker.run(query="What is the capital of France?", documents=docs)

    ranked_docs = result["documents"]
    print(ranked_docs)
    >> [Document(id=..., content: 'The capital of France is Paris', score: 0.9979767),
    >>  Document(id=..., content: 'The capital of Germany is Berlin', score: 0.13982213)]
    ```
    """

    def __init__(
        self,
        *,
        url: str,
        top_k: int = 10,
        raw_scores: bool = False,
        timeout: Optional[int] = 30,
        max_retries: int = 3,
        retry_status_codes: Optional[List[int]] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
    ) -> None:
        """
        Initializes the TEI reranker component.

        :param url: Base URL of the TEI reranking service (for example, "https://api.example.com").
            The request is sent to `urljoin(url, "/rerank")`, so any path component of `url` is replaced
            by `/rerank`.
        :param top_k: Maximum number of top documents to return.
        :param raw_scores: Value sent as the `raw_scores` field of the API payload
            (if True, raw relevance scores are requested).
        :param timeout: Request timeout in seconds.
        :param max_retries: Maximum number of attempts for the request (passed as `attempts` to the retry helper).
        :param retry_status_codes: List of HTTP status codes that will trigger a retry.
            When None, HTTP 408, 418, 429 and 503 will be retried (default: None).
        :param token: The Hugging Face token to use as HTTP bearer authorization. Not always required
            depending on your TEI server configuration.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        """
        self.url = url
        self.top_k = top_k
        self.timeout = timeout
        self.token = token
        self.max_retries = max_retries
        self.retry_status_codes = retry_status_codes
        self.raw_scores = raw_scores

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            url=self.url,
            top_k=self.top_k,
            # raw_scores must be serialized too, otherwise a to_dict/from_dict round trip resets it to False.
            raw_scores=self.raw_scores,
            timeout=self.timeout,
            token=self.token.to_dict() if self.token else None,
            max_retries=self.max_retries,
            retry_status_codes=self.retry_status_codes,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTEIRanker":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        return default_from_dict(cls, data)

    def _build_request_kwargs(
        self, query: str, documents: List[Document], truncation_direction: Optional[TruncationDirection]
    ) -> Dict[str, Any]:
        """
        Builds the keyword arguments shared by the sync and async rerank requests.

        :param query: The user query string.
        :param documents: Documents whose `content` is sent as the `texts` to rerank.
        :param truncation_direction: If set, `truncate` is enabled with this direction.
        :returns: Keyword arguments for `request_with_retry` / `async_request_with_retry`.
        """
        payload: Dict[str, Any] = {
            "query": query,
            "texts": [doc.content for doc in documents],
            "raw_scores": self.raw_scores,
        }
        if truncation_direction:
            payload.update({"truncate": True, "truncation_direction": truncation_direction.value})

        # Resolve the secret once; the Authorization header is only sent when a non-empty token is available.
        headers: Dict[str, str] = {}
        token_value = self.token.resolve_value() if self.token else None
        if token_value:
            headers["Authorization"] = f"Bearer {token_value}"

        return {
            "method": "POST",
            "url": urljoin(self.url, "/rerank"),
            "json": payload,
            "timeout": self.timeout,
            "headers": headers,
            "attempts": self.max_retries,
            "status_codes_to_retry": self.retry_status_codes,
        }

    def _compose_response(
        self, result: Union[Dict[str, str], List[Dict[str, Any]]], top_k: Optional[int], documents: List[Document]
    ) -> Dict[str, List[Document]]:
        """
        Processes the API response into a structured format.

        :param result: The decoded JSON response from the API: either a list of
            `{"index": ..., "score": ...}` items or an error dictionary.
        :param top_k: Run-time override for the number of documents to return; falls back to `self.top_k`.
        :param documents: The documents that were sent to the API; `index` values refer to this list.
        :returns: A dictionary with the following keys:
            - `documents`: A list of reranked documents.
        :raises RuntimeError:
            - If the API returns an error response or an unexpected response format.
        """
        if isinstance(result, dict) and "error" in result:
            error_type = result.get("error_type", "UnknownError")
            error_msg = result.get("error", "No additional information.")
            raise RuntimeError(f"HuggingFaceTEIRanker API call failed ({error_type}): {error_msg}")

        # Ensure we have a list of score dicts
        if not isinstance(result, list):
            # Expected list or dict, but encountered an unknown response format.
            error_msg = f"Expected a list of score dictionaries, but got `{type(result).__name__}`. "
            error_msg += f"Response content: {result}"
            raise RuntimeError(f"Unexpected response format from text-embeddings-inference rerank API: {error_msg}")

        # Determine number of docs to return
        final_k = min(top_k or self.top_k, len(result))

        # Select the top_k items in the order returned by the API
        ranked_docs = []
        for item in result[:final_k]:
            index: int = item["index"]
            # Copy so that the caller's Document objects are not modified when the score is set.
            doc_copy = copy.copy(documents[index])
            doc_copy.score = item["score"]
            ranked_docs.append(doc_copy)
        return {"documents": ranked_docs}

    @component.output_types(documents=List[Document])
    def run(
        self,
        query: str,
        documents: List[Document],
        top_k: Optional[int] = None,
        truncation_direction: Optional[TruncationDirection] = None,
    ) -> Dict[str, List[Document]]:
        """
        Reranks the provided documents by relevance to the query using the TEI API.

        :param query: The user query string to guide reranking.
        :param documents: List of `Document` objects to rerank.
        :param top_k: Optional override for the maximum number of documents to return.
        :param truncation_direction: If set, enables text truncation in the specified direction.

        :returns: A dictionary with the following keys:
            - `documents`: A list of reranked documents.

        :raises requests.exceptions.RequestException:
            - If the API request fails.
        :raises RuntimeError:
            - If the API returns an error response.
        """
        # Return empty if no documents provided
        if not documents:
            return {"documents": []}

        # Call the external service with retry
        response = request_with_retry(**self._build_request_kwargs(query, documents, truncation_direction))

        result: Union[Dict[str, str], List[Dict[str, Any]]] = response.json()
        return self._compose_response(result, top_k, documents)

    @component.output_types(documents=List[Document])
    async def run_async(
        self,
        query: str,
        documents: List[Document],
        top_k: Optional[int] = None,
        truncation_direction: Optional[TruncationDirection] = None,
    ) -> Dict[str, List[Document]]:
        """
        Asynchronously reranks the provided documents by relevance to the query using the TEI API.

        :param query: The user query string to guide reranking.
        :param documents: List of `Document` objects to rerank.
        :param top_k: Optional override for the maximum number of documents to return.
        :param truncation_direction: If set, enables text truncation in the specified direction.

        :returns: A dictionary with the following keys:
            - `documents`: A list of reranked documents.

        :raises httpx.RequestError:
            - If the API request fails.
        :raises RuntimeError:
            - If the API returns an error response.
        """
        # Return empty if no documents provided
        if not documents:
            return {"documents": []}

        # Call the external service with retry
        response = await async_request_with_retry(
            **self._build_request_kwargs(query, documents, truncation_direction)
        )

        result: Union[Dict[str, str], List[Dict[str, Any]]] = response.json()
        return self._compose_response(result, top_k, documents)

View File

@ -18,7 +18,7 @@ _import_structure = {
"jinja2_extensions": ["Jinja2TimeExtension"],
"jupyter": ["is_in_jupyter"],
"misc": ["expit", "expand_page_range"],
"requests_utils": ["request_with_retry"],
"requests_utils": ["request_with_retry", "async_request_with_retry"],
"type_serialization": ["deserialize_type", "serialize_type"],
}
@ -33,7 +33,7 @@ if TYPE_CHECKING:
from .jinja2_extensions import Jinja2TimeExtension
from .jupyter import is_in_jupyter
from .misc import expand_page_range, expit
from .requests_utils import request_with_retry
from .requests_utils import async_request_with_retry, request_with_retry
from .type_serialization import deserialize_type, serialize_type
else:
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)

View File

@ -5,6 +5,7 @@
import logging
from typing import Any, List, Optional
import httpx
import requests
from tenacity import after_log, before_log, retry, retry_if_exception_type, stop_after_attempt, wait_exponential
@ -95,3 +96,113 @@ def request_with_retry(
# won't trigger a retry, this way the call will still cause an explicit exception
res.raise_for_status()
return res
async def async_request_with_retry(
    attempts: int = 3, status_codes_to_retry: Optional[List[int]] = None, **kwargs: Any
) -> httpx.Response:
    """
    Executes an asynchronous HTTP request with a configurable exponential backoff retry on failures.

    Usage example:
    ```python
    import asyncio
    from haystack.utils import async_request_with_retry

    # Sending an async HTTP request with default retry configs
    async def example():
        res = await async_request_with_retry(method="GET", url="https://example.com")
        return res

    # Sending an async HTTP request with custom number of attempts
    async def example_with_attempts():
        res = await async_request_with_retry(method="GET", url="https://example.com", attempts=10)
        return res

    # Sending an async HTTP request with custom HTTP codes to retry
    async def example_with_status_codes():
        res = await async_request_with_retry(method="GET", url="https://example.com", status_codes_to_retry=[408, 503])
        return res

    # Sending an async HTTP request with custom timeout in seconds
    async def example_with_timeout():
        res = await async_request_with_retry(method="GET", url="https://example.com", timeout=5)
        return res

    # Sending an async HTTP request with custom headers
    async def example_with_headers():
        headers = {"Authorization": "Bearer <my_token_here>"}
        res = await async_request_with_retry(method="GET", url="https://example.com", headers=headers)
        return res

    # All of the above combined
    async def example_combined():
        headers = {"Authorization": "Bearer <my_token_here>"}
        res = await async_request_with_retry(
            method="GET",
            url="https://example.com",
            headers=headers,
            attempts=10,
            status_codes_to_retry=[408, 503],
            timeout=5
        )
        return res

    # Sending an async POST request
    async def example_post():
        res = await async_request_with_retry(
            method="POST",
            url="https://example.com",
            json={"key": "value"},
            attempts=10
        )
        return res

    # Retry all 5xx status codes
    async def example_5xx():
        res = await async_request_with_retry(
            method="GET",
            url="https://example.com",
            status_codes_to_retry=list(range(500, 600))
        )
        return res
    ```

    :param attempts:
        Maximum number of attempts to perform the request.
    :param status_codes_to_retry:
        List of HTTP status codes that will trigger a retry.
        When param is `None`, HTTP 408, 418, 429 and 503 will be retried.
    :param kwargs:
        Optional arguments that `httpx.AsyncClient.request` accepts.
        `timeout` defaults to 10 seconds when not provided.
    :returns:
        The `httpx.Response` object.
    :raises httpx.HTTPError:
        If the request still fails after the last attempt, or if the final response
        has an error status code (including codes that are not retried).
    """
    if status_codes_to_retry is None:
        status_codes_to_retry = [408, 418, 429, 503]

    # Take the timeout out of kwargs exactly once, before the retried closure is defined.
    # Popping it inside `run` would remove it on the first attempt, so every retry
    # would silently fall back to the 10 second default instead of the caller's value.
    timeout = kwargs.pop("timeout", 10)

    @retry(
        reraise=True,
        wait=wait_exponential(),
        retry=retry_if_exception_type((httpx.HTTPError, TimeoutError)),
        stop=stop_after_attempt(attempts),
        before=before_log(logger, logging.DEBUG),
        after=after_log(logger, logging.DEBUG),
    )
    async def run() -> httpx.Response:
        async with httpx.AsyncClient() as client:
            res = await client.request(**kwargs, timeout=timeout)
            if res.status_code in status_codes_to_retry:
                # We raise only for the status codes that must trigger a retry
                res.raise_for_status()
            return res

    res = await run()
    # We raise here too in case the request failed with a status code that
    # won't trigger a retry, this way the call will still cause an explicit exception
    res.raise_for_status()
    return res

View File

@ -0,0 +1,6 @@
---
features:
- |
Added new `HuggingFaceTEIRanker` component to enable reranking with Text Embeddings Inference (TEI) API.
This component supports both self-hosted Text Embeddings Inference services and Hugging Face Inference
Endpoints.

View File

@ -0,0 +1,331 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
from unittest.mock import patch, MagicMock
import requests
import httpx
from haystack import Document
from haystack.components.rankers.hugging_face_tei import HuggingFaceTEIRanker, TruncationDirection
from haystack.utils import Secret
class TestHuggingFaceTEIRanker:
    """Unit tests for `HuggingFaceTEIRanker`; all HTTP calls are mocked."""

    @pytest.fixture(autouse=True)
    def _clear_hf_token_env(self, monkeypatch):
        """Ensure no test picks up a Hugging Face token from the system environment variables."""
        monkeypatch.delenv("HF_API_TOKEN", raising=False)
        monkeypatch.delenv("HF_TOKEN", raising=False)

    def test_init(self):
        """Test initialization with default and custom parameters"""
        # Default parameters
        ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com")
        assert ranker.url == "https://api.my-tei-service.com"
        assert ranker.top_k == 10
        assert ranker.raw_scores is False
        assert ranker.timeout == 30
        assert not ranker.token.resolve_value()
        assert ranker.max_retries == 3
        assert ranker.retry_status_codes is None

        # Custom parameters
        token = Secret.from_token("my_api_token")
        ranker = HuggingFaceTEIRanker(
            url="https://api.my-tei-service.com",
            top_k=5,
            timeout=60,
            token=token,
            max_retries=5,
            retry_status_codes=[500, 502, 503],
        )
        assert ranker.url == "https://api.my-tei-service.com"
        assert ranker.top_k == 5
        assert ranker.timeout == 60
        assert ranker.token == token
        assert ranker.max_retries == 5
        assert ranker.retry_status_codes == [500, 502, 503]

    def test_to_dict(self):
        """Test serialization to dict with Secret token"""
        component = HuggingFaceTEIRanker(
            url="https://api.my-tei-service.com", top_k=5, timeout=30, max_retries=4, retry_status_codes=[500, 502]
        )
        data = component.to_dict()

        assert data["type"] == "haystack.components.rankers.hugging_face_tei.HuggingFaceTEIRanker"
        assert data["init_parameters"]["url"] == "https://api.my-tei-service.com"
        assert data["init_parameters"]["top_k"] == 5
        assert data["init_parameters"]["timeout"] == 30
        assert data["init_parameters"]["token"] == {
            "env_vars": ["HF_API_TOKEN", "HF_TOKEN"],
            "strict": False,
            "type": "env_var",
        }
        assert data["init_parameters"]["max_retries"] == 4
        assert data["init_parameters"]["retry_status_codes"] == [500, 502]

    def test_from_dict(self):
        """Test deserialization from dict with environment variable token"""
        data = {
            "type": "haystack.components.rankers.hugging_face_tei.HuggingFaceTEIRanker",
            "init_parameters": {
                "url": "https://api.my-tei-service.com",
                "top_k": 5,
                "timeout": 30,
                "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
                "max_retries": 4,
                "retry_status_codes": [500, 502],
            },
        }
        component = HuggingFaceTEIRanker.from_dict(data)

        assert component.url == "https://api.my-tei-service.com"
        assert component.top_k == 5
        assert component.timeout == 30
        assert component.max_retries == 4
        assert component.retry_status_codes == [500, 502]
        # Not present in the dict, so the constructor default applies
        assert component.raw_scores is False

    def test_empty_documents(self):
        """Test that empty documents list returns empty result"""
        ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com")
        result = ranker.run(query="test query", documents=[])
        assert result == {"documents": []}

    @patch("haystack.components.rankers.hugging_face_tei.request_with_retry")
    def test_run_with_mock(self, mock_request):
        """Test run method with mocked API response"""
        # Setup mock response
        mock_response = MagicMock(spec=requests.Response)
        mock_response.json.return_value = [
            {"index": 2, "score": 0.95},
            {"index": 1, "score": 0.85},
            {"index": 0, "score": 0.75},
        ]
        mock_request.return_value = mock_response

        # Create ranker and test documents
        token = Secret.from_token("test_token")
        ranker = HuggingFaceTEIRanker(
            url="https://api.my-tei-service.com",
            top_k=3,
            timeout=30,
            token=token,
            max_retries=4,
            retry_status_codes=[500, 502],
        )
        docs = [Document(content="Document A"), Document(content="Document B"), Document(content="Document C")]

        # Run the ranker
        result = ranker.run(query="test query", documents=docs)

        # Check that request_with_retry was called with correct parameters
        mock_request.assert_called_once_with(
            method="POST",
            url="https://api.my-tei-service.com/rerank",
            json={"query": "test query", "texts": ["Document A", "Document B", "Document C"], "raw_scores": False},
            timeout=30,
            headers={"Authorization": "Bearer test_token"},
            attempts=4,
            status_codes_to_retry=[500, 502],
        )

        # Check that documents are ranked correctly
        assert len(result["documents"]) == 3
        assert result["documents"][0].content == "Document C"
        assert result["documents"][0].score == 0.95
        assert result["documents"][1].content == "Document B"
        assert result["documents"][1].score == 0.85
        assert result["documents"][2].content == "Document A"
        assert result["documents"][2].score == 0.75

    @patch("haystack.components.rankers.hugging_face_tei.request_with_retry")
    def test_run_with_truncation_direction(self, mock_request):
        """Test run method with truncation direction parameter"""
        # Setup mock response
        mock_response = MagicMock(spec=requests.Response)
        mock_response.json.return_value = [{"index": 0, "score": 0.95}]
        mock_request.return_value = mock_response

        # Create ranker and test documents
        token = Secret.from_token("test_token")
        ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com", token=token)
        docs = [Document(content="Document A")]

        # Run the ranker with truncation direction
        ranker.run(query="test query", documents=docs, truncation_direction=TruncationDirection.LEFT)

        # Check that request includes truncation parameters
        mock_request.assert_called_once_with(
            method="POST",
            url="https://api.my-tei-service.com/rerank",
            json={
                "query": "test query",
                "texts": ["Document A"],
                "raw_scores": False,
                "truncate": True,
                "truncation_direction": "Left",
            },
            timeout=30,
            headers={"Authorization": "Bearer test_token"},
            attempts=3,
            status_codes_to_retry=None,
        )

    @patch("haystack.components.rankers.hugging_face_tei.request_with_retry")
    def test_run_with_custom_top_k(self, mock_request):
        """Test run method with custom top_k parameter"""
        # Setup mock response with 5 documents
        mock_response = MagicMock(spec=requests.Response)
        mock_response.json.return_value = [
            {"index": 4, "score": 0.95},
            {"index": 3, "score": 0.90},
            {"index": 2, "score": 0.85},
            {"index": 1, "score": 0.80},
            {"index": 0, "score": 0.75},
        ]
        mock_request.return_value = mock_response

        # Create ranker with top_k=3
        ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com", top_k=3)

        # Create 5 test documents
        docs = [Document(content=f"Document {i}") for i in range(5)]

        # Run the ranker
        result = ranker.run(query="test query", documents=docs)

        # Check that only top 3 documents are returned
        assert len(result["documents"]) == 3
        assert result["documents"][0].content == "Document 4"
        assert result["documents"][1].content == "Document 3"
        assert result["documents"][2].content == "Document 2"

        # Test with run-time top_k override
        result = ranker.run(query="test query", documents=docs, top_k=2)

        # Check that only top 2 documents are returned
        assert len(result["documents"]) == 2
        assert result["documents"][0].content == "Document 4"
        assert result["documents"][1].content == "Document 3"

    @patch("haystack.components.rankers.hugging_face_tei.request_with_retry")
    def test_error_handling(self, mock_request):
        """Test error handling in the ranker"""
        # Setup mock response with error
        mock_response = MagicMock(spec=requests.Response)
        mock_response.json.return_value = {"error": "Some error occurred", "error_type": "TestError"}
        mock_request.return_value = mock_response

        # Create ranker and test documents
        ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com")
        docs = [Document(content="Document A")]

        # Test that RuntimeError is raised with the correct message
        with pytest.raises(
            RuntimeError, match=r"HuggingFaceTEIRanker API call failed \(TestError\): Some error occurred"
        ):
            ranker.run(query="test query", documents=docs)

        # Test unexpected response format
        mock_response.json.return_value = {"unexpected": "format"}
        with pytest.raises(RuntimeError, match="Unexpected response format from text-embeddings-inference rerank API"):
            ranker.run(query="test query", documents=docs)

    @pytest.mark.asyncio
    @patch("haystack.components.rankers.hugging_face_tei.async_request_with_retry")
    async def test_run_async_with_mock(self, mock_request):
        """Test run_async method with mocked API response"""
        # Setup mock response
        mock_response = MagicMock(spec=httpx.Response)
        mock_response.json.return_value = [
            {"index": 2, "score": 0.95},
            {"index": 1, "score": 0.85},
            {"index": 0, "score": 0.75},
        ]
        mock_request.return_value = mock_response

        # Create ranker and test documents
        token = Secret.from_token("test_token")
        ranker = HuggingFaceTEIRanker(
            url="https://api.my-tei-service.com",
            top_k=3,
            timeout=30,
            token=token,
            max_retries=4,
            retry_status_codes=[500, 502],
        )
        docs = [Document(content="Document A"), Document(content="Document B"), Document(content="Document C")]

        # Run the ranker asynchronously
        result = await ranker.run_async(query="test query", documents=docs)

        # Check that async_request_with_retry was called with correct parameters
        mock_request.assert_called_once_with(
            method="POST",
            url="https://api.my-tei-service.com/rerank",
            json={"query": "test query", "texts": ["Document A", "Document B", "Document C"], "raw_scores": False},
            timeout=30,
            headers={"Authorization": "Bearer test_token"},
            attempts=4,
            status_codes_to_retry=[500, 502],
        )

        # Check that documents are ranked correctly
        assert len(result["documents"]) == 3
        assert result["documents"][0].content == "Document C"
        assert result["documents"][0].score == 0.95
        assert result["documents"][1].content == "Document B"
        assert result["documents"][1].score == 0.85
        assert result["documents"][2].content == "Document A"
        assert result["documents"][2].score == 0.75

    @pytest.mark.asyncio
    @patch("haystack.components.rankers.hugging_face_tei.async_request_with_retry")
    async def test_run_async_empty_documents(self, mock_request):
        """Test run_async with empty documents list"""
        ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com")
        result = await ranker.run_async(query="test query", documents=[])

        # Check that no API call was made
        mock_request.assert_not_called()
        assert result == {"documents": []}

View File

@ -0,0 +1,226 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
import httpx
import requests
from unittest.mock import patch, MagicMock
from haystack.utils.requests_utils import request_with_retry, async_request_with_retry
@pytest.fixture
def mock_requests_response():
    """Provide a mocked `requests.Response` with status 200 whose `raise_for_status()` returns None."""
    ok_response = MagicMock(spec=requests.Response, status_code=200)
    ok_response.raise_for_status.return_value = None
    return ok_response
@pytest.fixture
def mock_httpx_response():
    """Provide a mocked `httpx.Response` with status 200 whose `raise_for_status()` returns None."""
    ok_response = MagicMock(spec=httpx.Response, status_code=200)
    ok_response.raise_for_status.return_value = None
    return ok_response
class TestRequestWithRetry:
def test_request_with_retry_success(self, mock_requests_response):
"""Test that request_with_retry works with default parameters"""
with patch("requests.request", return_value=mock_requests_response) as mock_request:
response = request_with_retry(method="GET", url="https://example.com")
assert response == mock_requests_response
mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10)
def test_request_with_retry_custom_attempts(self, mock_requests_response):
"""Test that request_with_retry respects custom attempts parameter"""
with patch("requests.request", return_value=mock_requests_response) as mock_request:
response = request_with_retry(method="GET", url="https://example.com", attempts=5)
assert response == mock_requests_response
mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10)
def test_request_with_retry_custom_status_codes(self, mock_requests_response):
"""Test that request_with_retry respects custom status_codes_to_retry parameter"""
with patch("requests.request", return_value=mock_requests_response) as mock_request:
response = request_with_retry(method="GET", url="https://example.com", status_codes_to_retry=[500, 502])
assert response == mock_requests_response
mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10)
def test_request_with_retry_custom_timeout(self, mock_requests_response):
"""Test that request_with_retry respects custom timeout parameter"""
with patch("requests.request", return_value=mock_requests_response) as mock_request:
response = request_with_retry(method="GET", url="https://example.com", timeout=30)
assert response == mock_requests_response
mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=30)
def test_request_with_retry_with_headers(self, mock_requests_response):
"""Test that request_with_retry passes headers correctly"""
headers = {"Authorization": "Bearer token123"}
with patch("requests.request", return_value=mock_requests_response) as mock_request:
response = request_with_retry(method="GET", url="https://example.com", headers=headers)
assert response == mock_requests_response
mock_request.assert_called_once_with(method="GET", url="https://example.com", headers=headers, timeout=10)
def test_request_with_retry_with_json(self, mock_requests_response):
"""Test that request_with_retry passes JSON data correctly"""
json_data = {"key": "value"}
with patch("requests.request", return_value=mock_requests_response) as mock_request:
response = request_with_retry(method="POST", url="https://example.com", json=json_data)
assert response == mock_requests_response
mock_request.assert_called_once_with(method="POST", url="https://example.com", json=json_data, timeout=10)
def test_request_with_retry_retries_on_error(self):
    """Test that request_with_retry retries after requests.request raises an HTTPError.

    The patched requests.request raises on its first call and returns a 200
    response on its second, so with attempts=2 exactly two calls are expected
    and the second response must be returned.
    """
    # No error *response* object is needed here: the failure is simulated by a
    # raised exception, not by a returned status code.
    success_response = requests.Response()
    success_response.status_code = 200

    with patch("requests.request") as mock_request:
        # First call raises an error, second call succeeds
        mock_request.side_effect = [requests.exceptions.HTTPError("Server error"), success_response]

        response = request_with_retry(method="GET", url="https://example.com", attempts=2)

        assert response == success_response
        assert mock_request.call_count == 2
def test_request_with_retry_retries_on_status_code(self):
    """Check that a response whose status code is listed for retry triggers a second call.

    The patched requests.request first returns a 503 response whose
    raise_for_status raises HTTPError, then a 200 response; with attempts=2
    and status_codes_to_retry=[503] the 200 response must be returned after
    exactly two calls.
    """
    failing = requests.Response()
    failing.status_code = 503

    def _raise_if_unavailable():
        if failing.status_code == 503:
            raise requests.exceptions.HTTPError("Service Unavailable")

    failing.raise_for_status = _raise_if_unavailable

    succeeding = requests.Response()
    succeeding.status_code = 200
    succeeding.raise_for_status = lambda: None

    with patch("requests.request") as patched_request:
        # The 503 response is served first, the 200 response second.
        patched_request.side_effect = [failing, succeeding]

        result = request_with_retry(
            method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503]
        )

    assert result == succeeding
    assert patched_request.call_count == 2
class TestAsyncRequestWithRetry:
    """Tests for async_request_with_retry with httpx.AsyncClient.request patched out."""

    @pytest.mark.asyncio
    async def test_async_request_with_retry_success(self, mock_httpx_response):
        """Test that async_request_with_retry works with default parameters"""
        with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request:
            response = await async_request_with_retry(method="GET", url="https://example.com")

            assert response == mock_httpx_response
            mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10)

    @pytest.mark.asyncio
    async def test_async_request_with_retry_custom_attempts(self, mock_httpx_response):
        """Test that a custom attempts value is accepted and not forwarded to the HTTP client"""
        with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request:
            response = await async_request_with_retry(method="GET", url="https://example.com", attempts=5)

            assert response == mock_httpx_response
            mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10)

    @pytest.mark.asyncio
    async def test_async_request_with_retry_custom_status_codes(self, mock_httpx_response):
        """Test that a custom status_codes_to_retry list is accepted and not forwarded to the HTTP client"""
        with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request:
            response = await async_request_with_retry(
                method="GET", url="https://example.com", status_codes_to_retry=[500, 502]
            )

            assert response == mock_httpx_response
            mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10)

    @pytest.mark.asyncio
    async def test_async_request_with_retry_custom_timeout(self, mock_httpx_response):
        """Test that async_request_with_retry forwards a custom timeout to the HTTP client"""
        with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request:
            response = await async_request_with_retry(method="GET", url="https://example.com", timeout=30)

            assert response == mock_httpx_response
            mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=30)

    @pytest.mark.asyncio
    async def test_async_request_with_retry_with_headers(self, mock_httpx_response):
        """Test that async_request_with_retry passes headers correctly"""
        headers = {"Authorization": "Bearer token123"}
        with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request:
            response = await async_request_with_retry(method="GET", url="https://example.com", headers=headers)

            assert response == mock_httpx_response
            mock_request.assert_called_once_with(method="GET", url="https://example.com", headers=headers, timeout=10)

    @pytest.mark.asyncio
    async def test_async_request_with_retry_with_json(self, mock_httpx_response):
        """Test that async_request_with_retry passes JSON data correctly"""
        json_data = {"key": "value"}
        with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request:
            response = await async_request_with_retry(method="POST", url="https://example.com", json=json_data)

            assert response == mock_httpx_response
            mock_request.assert_called_once_with(method="POST", url="https://example.com", json=json_data, timeout=10)

    @pytest.mark.asyncio
    async def test_async_request_with_retry_retries_on_error(self):
        """Test that async_request_with_retry retries after the client raises httpx.RequestError

        The patched client raises on its first call and returns a 200 response
        on its second, so with attempts=2 exactly two calls are expected.
        """
        # No error *response* object is needed here: the failure is simulated by
        # a raised exception, not by a returned status code.
        success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))

        with patch("httpx.AsyncClient.request") as mock_request:
            # First call raises an error, second call succeeds
            mock_request.side_effect = [
                httpx.RequestError("Server error", request=httpx.Request("GET", "https://example.com")),
                success_response,
            ]

            response = await async_request_with_retry(method="GET", url="https://example.com", attempts=2)

            assert response == success_response
            assert mock_request.call_count == 2

    @pytest.mark.asyncio
    async def test_async_request_with_retry_retries_on_status_code(self):
        """Test that async_request_with_retry retries on specified status codes

        The patched client first returns a 503 response whose raise_for_status
        raises httpx.HTTPStatusError, then a 200 response; with attempts=2 and
        status_codes_to_retry=[503] the 200 response must be returned after
        exactly two calls.
        """
        error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com"))

        def raise_for_status():
            if error_response.status_code in [503]:
                raise httpx.HTTPStatusError(
                    "Service Unavailable", request=error_response.request, response=error_response
                )

        error_response.raise_for_status = raise_for_status

        success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com"))
        success_response.raise_for_status = lambda: None

        with patch("httpx.AsyncClient.request") as mock_request:
            # First call returns error status code, second call succeeds
            mock_request.side_effect = [error_response, success_response]

            response = await async_request_with_retry(
                method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503]
            )

            assert response == success_response
            assert mock_request.call_count == 2