diff --git a/haystack/utils/requests.py b/haystack/utils/requests.py
new file mode 100644
index 000000000..daf75ec3f
--- /dev/null
+++ b/haystack/utils/requests.py
@@ -0,0 +1,98 @@
+from typing import Optional, List
+
+import logging
+
+from tenacity import retry, wait_exponential, retry_if_exception_type, stop_after_attempt, before_log, after_log
+import requests
+
+logger = logging.getLogger(__name__)
+
+
+def request_with_retry(attempts: int = 3, status_codes: Optional[List[int]] = None, **kwargs) -> requests.Response:
+    """
+    Executes an HTTP request with a configurable exponential backoff retry on failures.
+
+    The request is retried when the response has one of the given ``status_codes`` or when
+    it times out (``requests.Timeout`` or the builtin ``TimeoutError``).
+
+    All kwargs will be passed to ``requests.request``, so it accepts the same arguments.
+    If no ``timeout`` is passed, a default of 10 seconds is used.
+
+    Example Usage:
+    --------------
+
+    # Sending an HTTP request with default retry configs
+    res = request_with_retry(method="GET", url="https://example.com")
+
+    # Sending an HTTP request with custom number of attempts
+    res = request_with_retry(method="GET", url="https://example.com", attempts=10)
+
+    # Sending an HTTP request with custom HTTP codes to retry
+    res = request_with_retry(method="GET", url="https://example.com", status_codes=[408, 503])
+
+    # Sending an HTTP request with custom timeout in seconds
+    res = request_with_retry(method="GET", url="https://example.com", timeout=5)
+
+    # Sending an HTTP request with custom authorization handling
+    class CustomAuth(requests.auth.AuthBase):
+        def __call__(self, r):
+            r.headers["authorization"] = "Basic my_token_here"
+            return r
+
+    res = request_with_retry(method="GET", url="https://example.com", auth=CustomAuth())
+
+    # All of the above combined
+    res = request_with_retry(
+        method="GET",
+        url="https://example.com",
+        auth=CustomAuth(),
+        attempts=10,
+        status_codes=[408, 503],
+        timeout=5
+    )
+
+    # Sending a POST request
+    res = request_with_retry(method="POST", url="https://example.com", data={"key": "value"}, attempts=10)
+
+    # Retry all 5xx status codes
+    res = request_with_retry(method="GET", url="https://example.com", status_codes=list(range(500, 600)))
+
+    :param attempts: Maximum number of attempts to retry the request, defaults to 3
+    :param status_codes: List of HTTP status codes that will trigger a retry, defaults to [408, 418, 429]
+    :param **kwargs: Optional arguments that ``request`` takes.
+    :return: The ``requests.Response`` object returned by the request
+    :raises requests.HTTPError: If the final response has an error status code, whether or not it was retried
+    :raises requests.Timeout: If the request still times out on the last attempt
+    """
+
+    if status_codes is None:
+        status_codes = [408, 418, 429]
+
+    # The default is set once, before the first attempt, rather than on every retry
+    kwargs.setdefault("timeout", 10)
+
+    @retry(
+        reraise=True,
+        wait=wait_exponential(),
+        # requests signals timeouts with requests.Timeout, which is not a subclass of the
+        # builtin TimeoutError, so both are listed to make real request timeouts retryable
+        retry=retry_if_exception_type((requests.HTTPError, requests.Timeout, TimeoutError)),
+        stop=stop_after_attempt(attempts),
+        before=before_log(logger, logging.DEBUG),
+        after=after_log(logger, logging.DEBUG),
+    )
+    def run():
+        # We ignore the missing-timeout Pylint rule as a default timeout is always set above
+        res = requests.request(**kwargs)  # pylint: disable=missing-timeout
+
+        if res.status_code in status_codes:
+            # We raise only for the status codes that must trigger a retry
+            res.raise_for_status()
+
+        return res
+
+    res = run()
+    # We raise here too in case the request failed with a status code that
+    # won't trigger a retry, this way the call will still cause an explicit exception
+    res.raise_for_status()
+    return res
diff --git a/test/utils/__init__.py b/test/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/utils/test_requests.py b/test/utils/test_requests.py
new file mode 100644
index 000000000..a2207d8f3
--- /dev/null
+++ b/test/utils/test_requests.py
@@ -0,0 +1,86 @@
+from unittest.mock import patch, Mock
+
+import pytest
+import requests
+
+from haystack.utils.requests import request_with_retry
+
+
+@pytest.mark.unit
+@patch("haystack.utils.requests.requests.request")
+def test_request_with_retry_defaults_successfully(mock_request):
+    # Make request with default retry configuration
+    request_with_retry(method="GET", url="https://example.com")
+
+    # Verifies request has not been retried and the default timeout has been set
+    mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=10)
+
+
+@pytest.mark.unit
+@patch("haystack.utils.requests.requests.request")
+def test_request_with_retry_custom_timeout(mock_request):
+    # Make request with a custom timeout
+    request_with_retry(method="GET", url="https://example.com", timeout=5)
+
+    # Verifies request has not been retried and the custom timeout has been kept
+    mock_request.assert_called_once_with(method="GET", url="https://example.com", timeout=5)
+
+
+@pytest.mark.unit
+@patch("haystack.utils.requests.requests.request")
+def test_request_with_retry_failing_request_and_expected_status_code(mock_request):
+    # Create fake failed response with status code that triggers retry
+    fake_response = requests.Response()
+    fake_response.status_code = 408
+    mock_request.return_value = fake_response
+
+    # Make request with expected status code and verify error is raised
+    with pytest.raises(requests.HTTPError):
+        request_with_retry(method="GET", url="https://example.com", timeout=1, attempts=2, status_codes=[408])
+
+    # Verifies request has been retried the expected number of times
+    assert mock_request.call_count == 2
+
+
+@pytest.mark.unit
+@patch("haystack.utils.requests.requests.request")
+def test_request_with_retry_failing_request_and_ignored_status_code(mock_request):
+    # Create fake failed response with status code that doesn't trigger retry
+    fake_response = requests.Response()
+    fake_response.status_code = 500
+    mock_request.return_value = fake_response
+
+    # Make request with status code that won't trigger a retry and verify error is raised
+    with pytest.raises(requests.HTTPError):
+        request_with_retry(method="GET", url="https://example.com", timeout=1, status_codes=[404])
+
+    # Verify request has not been retried
+    mock_request.assert_called_once()
+
+
+@pytest.mark.unit
+@patch("haystack.utils.requests.requests.request")
+def test_request_with_retry_timed_out_request(mock_request: Mock):
+    # Make request fail because of a builtin timeout error
+    mock_request.side_effect = TimeoutError()
+
+    # Make request and verifies it fails
+    with pytest.raises(TimeoutError):
+        request_with_retry(method="GET", url="https://example.com", timeout=1, attempts=2)
+
+    # Verifies request has been retried the expected number of times
+    assert mock_request.call_count == 2
+
+
+@pytest.mark.unit
+@patch("haystack.utils.requests.requests.request")
+def test_request_with_retry_requests_timeout(mock_request: Mock):
+    # Make request fail with the timeout exception that requests itself raises
+    mock_request.side_effect = requests.Timeout()
+
+    # Make request and verifies it fails
+    with pytest.raises(requests.Timeout):
+        request_with_retry(method="GET", url="https://example.com", timeout=1, attempts=2)
+
+    # Verifies request has been retried the expected number of times
+    assert mock_request.call_count == 2