From 3deaa20cb61661700e212bb81b4f255237481f8b Mon Sep 17 00:00:00 2001 From: atopx <58352247+atopx@users.noreply.github.com> Date: Tue, 27 May 2025 18:44:54 +0800 Subject: [PATCH] feat: Add HuggingFace API (text-embeddings-inference for rerank model) for component.rankers (#9414) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(component.rankers): Add HuggingFace API (text-embeddings-inference for rerank) ranker component * update test flow & doc loaders * Support run_async for HuggingFaceAPIRanker * Add release note for HuggingFace API support in component.rankers * Add release note for HuggingFace API support in component.rankers * Add release note for HuggingFace API support in component.rankers * Add release note for HuggingFace API support in component.rankers * fix: 1. `hugging_face_api.HuggingFaceAPIRanker` rename to `hugging_face_tei.HuggingFaceAPIRanker` 2. HuggingFaceAPIRanker: use our Secret API for token 3. add the missing modules for `docs/pydoc/config/rankers_api.yml` 4. added function `async_request_with_retry` for `haystack/utils/requests_utils.py` and added unittest on `test/utils/test_requests_utils.py` 4. HuggingFaceAPIRanker: refactor the retry function to support configuration based on attempts and status code. 5. HuggingFaceAPIRanker: refactor the test into unit tests using mocks * fix(HuggingFaceTEIRanker): change the token check logic to use the resolve_value method. 
* fix(format): run `hatch run format` * fix: - Force keyword-only arguments in __init__ method by adding *, - Clarify token docstring that it's not always required - Copy documents to avoid modifying original objects - Remove test file from slow workflow - Add monkeypatch environment variable cleanup in tests - Fix missing module in rankers_api.yml and sort modules alphabetically - Remove unnecessary test info from release notes * fix HuggingFaceTEIRanker: - "None" of "Optional[Secret]" has no attribute "resolve_value" - run/run_async: too many parameters * fix(HuggingFaceTEIRanker) :Revise the docstring of the HuggingFaceTEIRanker, improve the parameter descriptions, ensure consistency and clarity. Add error handling information to enhance the readability of the API response. * fix:unit test for HuggingFaceTEIRanker raise message * fix fmt * minor refinements * refine release note --------- Co-authored-by: anakin87 --- .gitignore | 3 + docs/pydoc/config/rankers_api.yml | 10 +- haystack/components/rankers/__init__.py | 2 + .../components/rankers/hugging_face_tei.py | 270 ++++++++++++++ haystack/utils/__init__.py | 4 +- haystack/utils/requests_utils.py | 111 ++++++ ...to-component-rankers-0e3f54e523e42141.yaml | 6 + .../rankers/test_hugging_face_tei.py | 331 ++++++++++++++++++ test/utils/test_requests_utils.py | 226 ++++++++++++ 9 files changed, 959 insertions(+), 4 deletions(-) create mode 100644 haystack/components/rankers/hugging_face_tei.py create mode 100644 releasenotes/notes/add-huggingface-api-text-embeddings-inference-to-component-rankers-0e3f54e523e42141.yaml create mode 100644 test/components/rankers/test_hugging_face_tei.py create mode 100644 test/utils/test_requests_utils.py diff --git a/.gitignore b/.gitignore index 12b220303..5be2a1fde 100644 --- a/.gitignore +++ b/.gitignore @@ -163,3 +163,6 @@ haystack/json-schemas # Zed configs .zed/* + +# uv +uv.lock diff --git a/docs/pydoc/config/rankers_api.yml b/docs/pydoc/config/rankers_api.yml index 
1806daa88..64601c0c3 100644 --- a/docs/pydoc/config/rankers_api.yml +++ b/docs/pydoc/config/rankers_api.yml @@ -1,8 +1,14 @@ loaders: - type: haystack_pydoc_tools.loaders.CustomPythonLoader search_path: [../../../haystack/components/rankers] - modules: ["lost_in_the_middle", "meta_field", "meta_field_grouping_ranker", "transformers_similarity", - "sentence_transformers_diversity", "sentence_transformers_similarity"] + modules: [ + "hugging_face_tei", + "lost_in_the_middle", + "meta_field", + "meta_field_grouping_ranker", + "sentence_transformers_diversity", + "sentence_transformers_similarity", + "transformers_similarity"] ignore_when_discovered: ["__init__"] processors: - type: filter diff --git a/haystack/components/rankers/__init__.py b/haystack/components/rankers/__init__.py index 128408aaf..d6be5240f 100644 --- a/haystack/components/rankers/__init__.py +++ b/haystack/components/rankers/__init__.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING from lazy_imports import LazyImporter _import_structure = { + "hugging_face_tei": ["HuggingFaceTEIRanker"], "lost_in_the_middle": ["LostInTheMiddleRanker"], "meta_field": ["MetaFieldRanker"], "meta_field_grouping_ranker": ["MetaFieldGroupingRanker"], @@ -17,6 +18,7 @@ _import_structure = { } if TYPE_CHECKING: + from .hugging_face_tei import HuggingFaceTEIRanker from .lost_in_the_middle import LostInTheMiddleRanker from .meta_field import MetaFieldRanker from .meta_field_grouping_ranker import MetaFieldGroupingRanker diff --git a/haystack/components/rankers/hugging_face_tei.py b/haystack/components/rankers/hugging_face_tei.py new file mode 100644 index 000000000..ffea65162 --- /dev/null +++ b/haystack/components/rankers/hugging_face_tei.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import copy +from enum import Enum +from typing import Any, Dict, List, Optional, Union +from urllib.parse import urljoin + +from haystack import Document, component, 
default_from_dict, default_to_dict +from haystack.utils import Secret, deserialize_secrets_inplace +from haystack.utils.requests_utils import async_request_with_retry, request_with_retry + + +class TruncationDirection(str, Enum): + """ + Defines the direction to truncate text when input length exceeds the model's limit. + + Attributes: + LEFT: Truncate text from the left side (start of text). + RIGHT: Truncate text from the right side (end of text). + """ + + LEFT = "Left" + RIGHT = "Right" + + +@component +class HuggingFaceTEIRanker: + """ + Ranks documents based on their semantic similarity to the query. + + It can be used with a Text Embeddings Inference (TEI) API endpoint: + - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference) + - [Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints) + + Usage example: + ```python + from haystack import Document + from haystack.components.rankers import HuggingFaceTEIRanker + from haystack.utils import Secret + + reranker = HuggingFaceTEIRanker( + url="http://localhost:8080", + top_k=5, + timeout=30, + token=Secret.from_token("my_api_token") + ) + + docs = [Document(content="The capital of France is Paris"), Document(content="The capital of Germany is Berlin")] + + result = reranker.run(query="What is the capital of France?", documents=docs) + + ranked_docs = result["documents"] + print(ranked_docs) + >> {'documents': [Document(id=..., content: 'the capital of France is Paris', score: 0.9979767), + >> Document(id=..., content: 'the capital of Germany is Berlin', score: 0.13982213)]} + ``` + """ + + def __init__( + self, + *, + url: str, + top_k: int = 10, + raw_scores: bool = False, + timeout: Optional[int] = 30, + max_retries: int = 3, + retry_status_codes: Optional[List[int]] = None, + token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False), + ) -> None: + """ + Initializes the TEI reranker component. 
+ + :param url: Base URL of the TEI reranking service (for example, "https://api.example.com"). + :param top_k: Maximum number of top documents to return. + :param raw_scores: If True, include raw relevance scores in the API payload. + :param timeout: Request timeout in seconds. + :param max_retries: Maximum number of retry attempts for failed requests. + :param retry_status_codes: List of HTTP status codes that will trigger a retry. + When None, HTTP 408, 418, 429 and 503 will be retried (default: None). + :param token: The Hugging Face token to use as HTTP bearer authorization. Not always required + depending on your TEI server configuration. + Check your HF token in your [account settings](https://huggingface.co/settings/tokens). + """ + self.url = url + self.top_k = top_k + self.timeout = timeout + self.token = token + self.max_retries = max_retries + self.retry_status_codes = retry_status_codes + self.raw_scores = raw_scores + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + url=self.url, + top_k=self.top_k, + timeout=self.timeout, + token=self.token.to_dict() if self.token else None, + max_retries=self.max_retries, + retry_status_codes=self.retry_status_codes, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTEIRanker": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + return default_from_dict(cls, data) + + def _compose_response( + self, result: Union[Dict[str, str], List[Dict[str, Any]]], top_k: Optional[int], documents: List[Document] + ) -> Dict[str, List[Document]]: + """ + Processes the API response into a structured format. + + :param result: The raw response from the API. 
+ + :returns: A dictionary with the following keys: + - `documents`: A list of reranked documents. + + :raises requests.exceptions.RequestException: + - If the API request fails. + + :raises RuntimeError: + - If the API returns an error response. + """ + if isinstance(result, dict) and "error" in result: + error_type = result.get("error_type", "UnknownError") + error_msg = result.get("error", "No additional information.") + raise RuntimeError(f"HuggingFaceTEIRanker API call failed ({error_type}): {error_msg}") + + # Ensure we have a list of score dicts + if not isinstance(result, list): + # Expected list or dict, but encountered an unknown response format. + error_msg = f"Expected a list of score dictionaries, but got `{type(result).__name__}`. " + error_msg += f"Response content: {result}" + raise RuntimeError(f"Unexpected response format from text-embeddings-inference rerank API: {error_msg}") + + # Determine number of docs to return + final_k = min(top_k or self.top_k, len(result)) + + # Select and return the top_k documents + ranked_docs = [] + for item in result[:final_k]: + index: int = item["index"] + doc_copy = copy.copy(documents[index]) + doc_copy.score = item["score"] + ranked_docs.append(doc_copy) + return {"documents": ranked_docs} + + @component.output_types(documents=List[Document]) + def run( + self, + query: str, + documents: List[Document], + top_k: Optional[int] = None, + truncation_direction: Optional[TruncationDirection] = None, + ) -> Dict[str, List[Document]]: + """ + Reranks the provided documents by relevance to the query using the TEI API. + + :param query: The user query string to guide reranking. + :param documents: List of `Document` objects to rerank. + :param top_k: Optional override for the maximum number of documents to return. + :param truncation_direction: If set, enables text truncation in the specified direction. + + :returns: A dictionary with the following keys: + - `documents`: A list of reranked documents. 
+ + :raises requests.exceptions.RequestException: + - If the API request fails. + + :raises RuntimeError: + - If the API returns an error response. + """ + # Return empty if no documents provided + if not documents: + return {"documents": []} + + # Prepare the payload + texts = [doc.content for doc in documents] + payload: Dict[str, Any] = {"query": query, "texts": texts, "raw_scores": self.raw_scores} + if truncation_direction: + payload.update({"truncate": True, "truncation_direction": truncation_direction.value}) + + headers = {} + if self.token and self.token.resolve_value(): + headers["Authorization"] = f"Bearer {self.token.resolve_value()}" + + # Call the external service with retry + response = request_with_retry( + method="POST", + url=urljoin(self.url, "/rerank"), + json=payload, + timeout=self.timeout, + headers=headers, + attempts=self.max_retries, + status_codes_to_retry=self.retry_status_codes, + ) + + result: Union[Dict[str, str], List[Dict[str, Any]]] = response.json() + + return self._compose_response(result, top_k, documents) + + @component.output_types(documents=List[Document]) + async def run_async( + self, + query: str, + documents: List[Document], + top_k: Optional[int] = None, + truncation_direction: Optional[TruncationDirection] = None, + ) -> Dict[str, List[Document]]: + """ + Asynchronously reranks the provided documents by relevance to the query using the TEI API. + + :param query: The user query string to guide reranking. + :param documents: List of `Document` objects to rerank. + :param top_k: Optional override for the maximum number of documents to return. + :param truncation_direction: If set, enables text truncation in the specified direction. + + :returns: A dictionary with the following keys: + - `documents`: A list of reranked documents. + + :raises httpx.RequestError: + - If the API request fails. + :raises RuntimeError: + - If the API returns an error response. 
+ """ + # Return empty if no documents provided + if not documents: + return {"documents": []} + + # Prepare the payload + texts = [doc.content for doc in documents] + payload: Dict[str, Any] = {"query": query, "texts": texts, "raw_scores": self.raw_scores} + if truncation_direction: + payload.update({"truncate": True, "truncation_direction": truncation_direction.value}) + + headers = {} + if self.token and self.token.resolve_value(): + headers["Authorization"] = f"Bearer {self.token.resolve_value()}" + + # Call the external service with retry + response = await async_request_with_retry( + method="POST", + url=urljoin(self.url, "/rerank"), + json=payload, + timeout=self.timeout, + headers=headers, + attempts=self.max_retries, + status_codes_to_retry=self.retry_status_codes, + ) + + result: Union[Dict[str, str], List[Dict[str, Any]]] = response.json() + + return self._compose_response(result, top_k, documents) diff --git a/haystack/utils/__init__.py b/haystack/utils/__init__.py index 1b07ac487..62a67ff04 100644 --- a/haystack/utils/__init__.py +++ b/haystack/utils/__init__.py @@ -18,7 +18,7 @@ _import_structure = { "jinja2_extensions": ["Jinja2TimeExtension"], "jupyter": ["is_in_jupyter"], "misc": ["expit", "expand_page_range"], - "requests_utils": ["request_with_retry"], + "requests_utils": ["request_with_retry", "async_request_with_retry"], "type_serialization": ["deserialize_type", "serialize_type"], } @@ -33,7 +33,7 @@ if TYPE_CHECKING: from .jinja2_extensions import Jinja2TimeExtension from .jupyter import is_in_jupyter from .misc import expand_page_range, expit - from .requests_utils import request_with_retry + from .requests_utils import async_request_with_retry, request_with_retry from .type_serialization import deserialize_type, serialize_type else: sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure) diff --git a/haystack/utils/requests_utils.py b/haystack/utils/requests_utils.py index 
a684c3763..ec6ec4f4d 100644 --- a/haystack/utils/requests_utils.py +++ b/haystack/utils/requests_utils.py @@ -5,6 +5,7 @@ import logging from typing import Any, List, Optional +import httpx import requests from tenacity import after_log, before_log, retry, retry_if_exception_type, stop_after_attempt, wait_exponential @@ -95,3 +96,113 @@ def request_with_retry( # won't trigger a retry, this way the call will still cause an explicit exception res.raise_for_status() return res + + +async def async_request_with_retry( + attempts: int = 3, status_codes_to_retry: Optional[List[int]] = None, **kwargs: Any +) -> httpx.Response: + """ + Executes an asynchronous HTTP request with a configurable exponential backoff retry on failures. + + Usage example: + ```python + import asyncio + from haystack.utils import async_request_with_retry + + # Sending an async HTTP request with default retry configs + async def example(): + res = await async_request_with_retry(method="GET", url="https://example.com") + return res + + # Sending an async HTTP request with custom number of attempts + async def example_with_attempts(): + res = await async_request_with_retry(method="GET", url="https://example.com", attempts=10) + return res + + # Sending an async HTTP request with custom HTTP codes to retry + async def example_with_status_codes(): + res = await async_request_with_retry(method="GET", url="https://example.com", status_codes_to_retry=[408, 503]) + return res + + # Sending an async HTTP request with custom timeout in seconds + async def example_with_timeout(): + res = await async_request_with_retry(method="GET", url="https://example.com", timeout=5) + return res + + # Sending an async HTTP request with custom headers + async def example_with_headers(): + headers = {"Authorization": "Bearer "} + res = await async_request_with_retry(method="GET", url="https://example.com", headers=headers) + return res + + # All of the above combined + async def example_combined(): + headers = 
{"Authorization": "Bearer "} + res = await async_request_with_retry( + method="GET", + url="https://example.com", + headers=headers, + attempts=10, + status_codes_to_retry=[408, 503], + timeout=5 + ) + return res + + # Sending an async POST request + async def example_post(): + res = await async_request_with_retry( + method="POST", + url="https://example.com", + json={"key": "value"}, + attempts=10 + ) + return res + + # Retry all 5xx status codes + async def example_5xx(): + res = await async_request_with_retry( + method="GET", + url="https://example.com", + status_codes_to_retry=list(range(500, 600)) + ) + return res + ``` + + :param attempts: + Maximum number of attempts to retry the request. + :param status_codes_to_retry: + List of HTTP status codes that will trigger a retry. + When param is `None`, HTTP 408, 418, 429 and 503 will be retried. + :param kwargs: + Optional arguments that `httpx.AsyncClient.request` accepts. + :returns: + The `httpx.Response` object. + """ + + if status_codes_to_retry is None: + status_codes_to_retry = [408, 418, 429, 503] + + @retry( + reraise=True, + wait=wait_exponential(), + retry=retry_if_exception_type((httpx.HTTPError, TimeoutError)), + stop=stop_after_attempt(attempts), + before=before_log(logger, logging.DEBUG), + after=after_log(logger, logging.DEBUG), + ) + async def run(): + timeout = kwargs.pop("timeout", 10) + async with httpx.AsyncClient() as client: + res = await client.request(**kwargs, timeout=timeout) + + if res.status_code in status_codes_to_retry: + # We raise only for the status codes that must trigger a retry + res.raise_for_status() + + return res + + res = await run() + # We raise here too in case the request failed with a status code that + # won't trigger a retry, this way the call will still cause an explicit exception + res.raise_for_status() + return res diff --git a/releasenotes/notes/add-huggingface-api-text-embeddings-inference-to-component-rankers-0e3f54e523e42141.yaml 
b/releasenotes/notes/add-huggingface-api-text-embeddings-inference-to-component-rankers-0e3f54e523e42141.yaml new file mode 100644 index 000000000..aff1d2798 --- /dev/null +++ b/releasenotes/notes/add-huggingface-api-text-embeddings-inference-to-component-rankers-0e3f54e523e42141.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added new `HuggingFaceTEIRanker` component to enable reranking with Text Embeddings Inference (TEI) API. + This component supports both self-hosted Text Embeddings Inference services and Hugging Face Inference + Endpoints. diff --git a/test/components/rankers/test_hugging_face_tei.py b/test/components/rankers/test_hugging_face_tei.py new file mode 100644 index 000000000..d42d33459 --- /dev/null +++ b/test/components/rankers/test_hugging_face_tei.py @@ -0,0 +1,331 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from unittest.mock import patch, MagicMock + +import requests +import httpx + +from haystack import Document +from haystack.components.rankers.hugging_face_tei import HuggingFaceTEIRanker, TruncationDirection +from haystack.utils import Secret + + +class TestHuggingFaceTEIRanker: + def test_init(self, monkeypatch): + """Test initialization with default and custom parameters""" + # Ensure we're not using system environment variables + monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) + + # Default parameters + ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com") + assert ranker.url == "https://api.my-tei-service.com" + assert ranker.top_k == 10 + assert ranker.timeout == 30 + assert not ranker.token.resolve_value() + assert ranker.max_retries == 3 + assert ranker.retry_status_codes is None + + # Custom parameters + token = Secret.from_token("my_api_token") + ranker = HuggingFaceTEIRanker( + url="https://api.my-tei-service.com", + top_k=5, + timeout=60, + token=token, + max_retries=5, + 
retry_status_codes=[500, 502, 503], + ) + assert ranker.url == "https://api.my-tei-service.com" + assert ranker.top_k == 5 + assert ranker.timeout == 60 + assert ranker.token == token + assert ranker.max_retries == 5 + assert ranker.retry_status_codes == [500, 502, 503] + + def test_to_dict(self, monkeypatch): + """Test serialization to dict with Secret token""" + # Ensure we're not using system environment variables + monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) + + component = HuggingFaceTEIRanker( + url="https://api.my-tei-service.com", top_k=5, timeout=30, max_retries=4, retry_status_codes=[500, 502] + ) + data = component.to_dict() + + assert data["type"] == "haystack.components.rankers.hugging_face_tei.HuggingFaceTEIRanker" + assert data["init_parameters"]["url"] == "https://api.my-tei-service.com" + assert data["init_parameters"]["top_k"] == 5 + assert data["init_parameters"]["timeout"] == 30 + assert data["init_parameters"]["token"] == { + "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], + "strict": False, + "type": "env_var", + } + assert data["init_parameters"]["max_retries"] == 4 + assert data["init_parameters"]["retry_status_codes"] == [500, 502] + + def test_from_dict(self, monkeypatch): + """Test deserialization from dict with environment variable token""" + # Ensure we're not using system environment variables + monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) + + data = { + "type": "haystack.components.rankers.hugging_face_tei.HuggingFaceTEIRanker", + "init_parameters": { + "url": "https://api.my-tei-service.com", + "top_k": 5, + "timeout": 30, + "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False}, + "max_retries": 4, + "retry_status_codes": [500, 502], + }, + } + + component = HuggingFaceTEIRanker.from_dict(data) + + assert component.url == "https://api.my-tei-service.com" + assert component.top_k == 5 + assert 
component.timeout == 30 + assert component.max_retries == 4 + assert component.retry_status_codes == [500, 502] + + def test_empty_documents(self, monkeypatch): + """Test that empty documents list returns empty result""" + # Ensure we're not using system environment variables + monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) + + ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com") + result = ranker.run(query="test query", documents=[]) + assert result == {"documents": []} + + @patch("haystack.components.rankers.hugging_face_tei.request_with_retry") + def test_run_with_mock(self, mock_request, monkeypatch): + """Test run method with mocked API response""" + # Ensure we're not using system environment variables + monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) + + # Setup mock response + mock_response = MagicMock(spec=requests.Response) + mock_response.json.return_value = [ + {"index": 2, "score": 0.95}, + {"index": 1, "score": 0.85}, + {"index": 0, "score": 0.75}, + ] + mock_request.return_value = mock_response + + # Create ranker and test documents + token = Secret.from_token("test_token") + ranker = HuggingFaceTEIRanker( + url="https://api.my-tei-service.com", + top_k=3, + timeout=30, + token=token, + max_retries=4, + retry_status_codes=[500, 502], + ) + + docs = [Document(content="Document A"), Document(content="Document B"), Document(content="Document C")] + + # Run the ranker + result = ranker.run(query="test query", documents=docs) + + # Check that request_with_retry was called with correct parameters + mock_request.assert_called_once_with( + method="POST", + url="https://api.my-tei-service.com/rerank", + json={"query": "test query", "texts": ["Document A", "Document B", "Document C"], "raw_scores": False}, + timeout=30, + headers={"Authorization": "Bearer test_token"}, + attempts=4, + status_codes_to_retry=[500, 502], + ) + + # Check that 
documents are ranked correctly + assert len(result["documents"]) == 3 + assert result["documents"][0].content == "Document C" + assert result["documents"][0].score == 0.95 + assert result["documents"][1].content == "Document B" + assert result["documents"][1].score == 0.85 + assert result["documents"][2].content == "Document A" + assert result["documents"][2].score == 0.75 + + @patch("haystack.components.rankers.hugging_face_tei.request_with_retry") + def test_run_with_truncation_direction(self, mock_request, monkeypatch): + """Test run method with truncation direction parameter""" + # Ensure we're not using system environment variables + monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) + + # Setup mock response + mock_response = MagicMock(spec=requests.Response) + mock_response.json.return_value = [{"index": 0, "score": 0.95}] + mock_request.return_value = mock_response + + # Create ranker and test documents + token = Secret.from_token("test_token") + ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com", token=token) + docs = [Document(content="Document A")] + + # Run the ranker with truncation direction + ranker.run(query="test query", documents=docs, truncation_direction=TruncationDirection.LEFT) + + # Check that request includes truncation parameters + mock_request.assert_called_once_with( + method="POST", + url="https://api.my-tei-service.com/rerank", + json={ + "query": "test query", + "texts": ["Document A"], + "raw_scores": False, + "truncate": True, + "truncation_direction": "Left", + }, + timeout=30, + headers={"Authorization": "Bearer test_token"}, + attempts=3, + status_codes_to_retry=None, + ) + + @patch("haystack.components.rankers.hugging_face_tei.request_with_retry") + def test_run_with_custom_top_k(self, mock_request, monkeypatch): + """Test run method with custom top_k parameter""" + # Ensure we're not using system environment variables + monkeypatch.delenv("HF_API_TOKEN", 
raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) + + # Setup mock response with 5 documents + mock_response = MagicMock(spec=requests.Response) + mock_response.json.return_value = [ + {"index": 4, "score": 0.95}, + {"index": 3, "score": 0.90}, + {"index": 2, "score": 0.85}, + {"index": 1, "score": 0.80}, + {"index": 0, "score": 0.75}, + ] + mock_request.return_value = mock_response + + # Create ranker with top_k=3 + ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com", top_k=3) + + # Create 5 test documents + docs = [Document(content=f"Document {i}") for i in range(5)] + + # Run the ranker + result = ranker.run(query="test query", documents=docs) + + # Check that only top 3 documents are returned + assert len(result["documents"]) == 3 + assert result["documents"][0].content == "Document 4" + assert result["documents"][1].content == "Document 3" + assert result["documents"][2].content == "Document 2" + + # Test with run-time top_k override + result = ranker.run(query="test query", documents=docs, top_k=2) + + # Check that only top 2 documents are returned + assert len(result["documents"]) == 2 + assert result["documents"][0].content == "Document 4" + assert result["documents"][1].content == "Document 3" + + @patch("haystack.components.rankers.hugging_face_tei.request_with_retry") + def test_error_handling(self, mock_request, monkeypatch): + """Test error handling in the ranker""" + # Ensure we're not using system environment variables + monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) + + # Setup mock response with error + mock_response = MagicMock(spec=requests.Response) + mock_response.json.return_value = {"error": "Some error occurred", "error_type": "TestError"} + mock_request.return_value = mock_response + + # Create ranker and test documents + ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com") + docs = [Document(content="Document A")] + + # Test that RuntimeError is 
raised with the correct message + with pytest.raises( + RuntimeError, match=r"HuggingFaceTEIRanker API call failed \(TestError\): Some error occurred" + ): + ranker.run(query="test query", documents=docs) + + # Test unexpected response format + mock_response.json.return_value = {"unexpected": "format"} + with pytest.raises(RuntimeError, match="Unexpected response format from text-embeddings-inference rerank API"): + ranker.run(query="test query", documents=docs) + + @pytest.mark.asyncio + @patch("haystack.components.rankers.hugging_face_tei.async_request_with_retry") + async def test_run_async_with_mock(self, mock_request, monkeypatch): + """Test run_async method with mocked API response""" + # Ensure we're not using system environment variables + monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) + + # Setup mock response + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = [ + {"index": 2, "score": 0.95}, + {"index": 1, "score": 0.85}, + {"index": 0, "score": 0.75}, + ] + mock_request.return_value = mock_response + + # Create ranker and test documents + token = Secret.from_token("test_token") + ranker = HuggingFaceTEIRanker( + url="https://api.my-tei-service.com", + top_k=3, + timeout=30, + token=token, + max_retries=4, + retry_status_codes=[500, 502], + ) + + docs = [Document(content="Document A"), Document(content="Document B"), Document(content="Document C")] + + # Run the ranker asynchronously + result = await ranker.run_async(query="test query", documents=docs) + + # Check that async_request_with_retry was called with correct parameters + mock_request.assert_called_once_with( + method="POST", + url="https://api.my-tei-service.com/rerank", + json={"query": "test query", "texts": ["Document A", "Document B", "Document C"], "raw_scores": False}, + timeout=30, + headers={"Authorization": "Bearer test_token"}, + attempts=4, + status_codes_to_retry=[500, 502], + ) + + # Check that 
documents are ranked correctly + assert len(result["documents"]) == 3 + assert result["documents"][0].content == "Document C" + assert result["documents"][0].score == 0.95 + assert result["documents"][1].content == "Document B" + assert result["documents"][1].score == 0.85 + assert result["documents"][2].content == "Document A" + assert result["documents"][2].score == 0.75 + + @pytest.mark.asyncio + @patch("haystack.components.rankers.hugging_face_tei.async_request_with_retry") + async def test_run_async_empty_documents(self, mock_request, monkeypatch): + """Test run_async with empty documents list""" + # Ensure we're not using system environment variables + monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) + + ranker = HuggingFaceTEIRanker(url="https://api.my-tei-service.com") + result = await ranker.run_async(query="test query", documents=[]) + + # Check that no API call was made + mock_request.assert_not_called() + assert result == {"documents": []} diff --git a/test/utils/test_requests_utils.py b/test/utils/test_requests_utils.py new file mode 100644 index 000000000..7b25e4788 --- /dev/null +++ b/test/utils/test_requests_utils.py @@ -0,0 +1,226 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import httpx +import requests +from unittest.mock import patch, MagicMock + +from haystack.utils.requests_utils import request_with_retry, async_request_with_retry + + +@pytest.fixture +def mock_requests_response(): + response = MagicMock(spec=requests.Response) + response.status_code = 200 + response.raise_for_status.return_value = None + return response + + +@pytest.fixture +def mock_httpx_response(): + response = MagicMock(spec=httpx.Response) + response.status_code = 200 + response.raise_for_status.return_value = None + return response + + +class TestRequestWithRetry: + def test_request_with_retry_success(self, mock_requests_response): + """Test that 
request_with_retry works with default parameters""" + with patch("requests.request", return_value=mock_requests_response) as mock_request: + response = request_with_retry(method="GET", url="https://example.com") + + assert response == mock_requests_response + mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10) + + def test_request_with_retry_custom_attempts(self, mock_requests_response): + """Test that request_with_retry respects custom attempts parameter""" + with patch("requests.request", return_value=mock_requests_response) as mock_request: + response = request_with_retry(method="GET", url="https://example.com", attempts=5) + + assert response == mock_requests_response + mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10) + + def test_request_with_retry_custom_status_codes(self, mock_requests_response): + """Test that request_with_retry respects custom status_codes_to_retry parameter""" + with patch("requests.request", return_value=mock_requests_response) as mock_request: + response = request_with_retry(method="GET", url="https://example.com", status_codes_to_retry=[500, 502]) + + assert response == mock_requests_response + mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10) + + def test_request_with_retry_custom_timeout(self, mock_requests_response): + """Test that request_with_retry respects custom timeout parameter""" + with patch("requests.request", return_value=mock_requests_response) as mock_request: + response = request_with_retry(method="GET", url="https://example.com", timeout=30) + + assert response == mock_requests_response + mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=30) + + def test_request_with_retry_with_headers(self, mock_requests_response): + """Test that request_with_retry passes headers correctly""" + headers = {"Authorization": "Bearer token123"} + with patch("requests.request", 
return_value=mock_requests_response) as mock_request: + response = request_with_retry(method="GET", url="https://example.com", headers=headers) + + assert response == mock_requests_response + mock_request.assert_called_once_with(method="GET", url="https://example.com", headers=headers, timeout=10) + + def test_request_with_retry_with_json(self, mock_requests_response): + """Test that request_with_retry passes JSON data correctly""" + json_data = {"key": "value"} + with patch("requests.request", return_value=mock_requests_response) as mock_request: + response = request_with_retry(method="POST", url="https://example.com", json=json_data) + + assert response == mock_requests_response + mock_request.assert_called_once_with(method="POST", url="https://example.com", json=json_data, timeout=10) + + def test_request_with_retry_retries_on_error(self): + """Test that request_with_retry retries on HTTP errors""" + error_response = requests.Response() + error_response.status_code = 503 + + success_response = requests.Response() + success_response.status_code = 200 + + with patch("requests.request") as mock_request: + # First call raises an error, second call succeeds + mock_request.side_effect = [requests.exceptions.HTTPError("Server error"), success_response] + + response = request_with_retry(method="GET", url="https://example.com", attempts=2) + + assert response == success_response + assert mock_request.call_count == 2 + + def test_request_with_retry_retries_on_status_code(self): + """Test that request_with_retry retries on specified status codes""" + error_response = requests.Response() + error_response.status_code = 503 + + def raise_for_status(): + if error_response.status_code in [503]: + raise requests.exceptions.HTTPError("Service Unavailable") + + error_response.raise_for_status = raise_for_status + + success_response = requests.Response() + success_response.status_code = 200 + success_response.raise_for_status = lambda: None + + with patch("requests.request") as 
mock_request: + # First call returns error status code, second call succeeds + mock_request.side_effect = [error_response, success_response] + + response = request_with_retry( + method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503] + ) + + assert response == success_response + assert mock_request.call_count == 2 + + +class TestAsyncRequestWithRetry: + @pytest.mark.asyncio + async def test_async_request_with_retry_success(self, mock_httpx_response): + """Test that async_request_with_retry works with default parameters""" + with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request: + response = await async_request_with_retry(method="GET", url="https://example.com") + + assert response == mock_httpx_response + mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10) + + @pytest.mark.asyncio + async def test_async_request_with_retry_custom_attempts(self, mock_httpx_response): + """Test that async_request_with_retry respects custom attempts parameter""" + with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request: + response = await async_request_with_retry(method="GET", url="https://example.com", attempts=5) + + assert response == mock_httpx_response + mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10) + + @pytest.mark.asyncio + async def test_async_request_with_retry_custom_status_codes(self, mock_httpx_response): + """Test that async_request_with_retry respects custom status_codes_to_retry parameter""" + with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request: + response = await async_request_with_retry( + method="GET", url="https://example.com", status_codes_to_retry=[500, 502] + ) + + assert response == mock_httpx_response + mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10) + + @pytest.mark.asyncio + async def 
test_async_request_with_retry_custom_timeout(self, mock_httpx_response): + """Test that async_request_with_retry respects custom timeout parameter""" + with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request: + response = await async_request_with_retry(method="GET", url="https://example.com", timeout=30) + + assert response == mock_httpx_response + mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=30) + + @pytest.mark.asyncio + async def test_async_request_with_retry_with_headers(self, mock_httpx_response): + """Test that async_request_with_retry passes headers correctly""" + headers = {"Authorization": "Bearer token123"} + with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request: + response = await async_request_with_retry(method="GET", url="https://example.com", headers=headers) + + assert response == mock_httpx_response + mock_request.assert_called_once_with(method="GET", url="https://example.com", headers=headers, timeout=10) + + @pytest.mark.asyncio + async def test_async_request_with_retry_with_json(self, mock_httpx_response): + """Test that async_request_with_retry passes JSON data correctly""" + json_data = {"key": "value"} + with patch("httpx.AsyncClient.request", return_value=mock_httpx_response) as mock_request: + response = await async_request_with_retry(method="POST", url="https://example.com", json=json_data) + + assert response == mock_httpx_response + mock_request.assert_called_once_with(method="POST", url="https://example.com", json=json_data, timeout=10) + + @pytest.mark.asyncio + async def test_async_request_with_retry_retries_on_error(self): + """Test that async_request_with_retry retries on HTTP errors""" + error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com")) + success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com")) + + with 
patch("httpx.AsyncClient.request") as mock_request: + # First call raises an error, second call succeeds + mock_request.side_effect = [ + httpx.RequestError("Server error", request=httpx.Request("GET", "https://example.com")), + success_response, + ] + + response = await async_request_with_retry(method="GET", url="https://example.com", attempts=2) + + assert response == success_response + assert mock_request.call_count == 2 + + @pytest.mark.asyncio + async def test_async_request_with_retry_retries_on_status_code(self): + """Test that async_request_with_retry retries on specified status codes""" + error_response = httpx.Response(status_code=503, request=httpx.Request("GET", "https://example.com")) + + def raise_for_status(): + if error_response.status_code in [503]: + raise httpx.HTTPStatusError( + "Service Unavailable", request=error_response.request, response=error_response + ) + + error_response.raise_for_status = raise_for_status + + success_response = httpx.Response(status_code=200, request=httpx.Request("GET", "https://example.com")) + success_response.raise_for_status = lambda: None + + with patch("httpx.AsyncClient.request") as mock_request: + # First call returns error status code, second call succeeds + mock_request.side_effect = [error_response, success_response] + + response = await async_request_with_retry( + method="GET", url="https://example.com", attempts=2, status_codes_to_retry=[503] + ) + + assert response == success_response + assert mock_request.call_count == 2