From f12e5a012700058e18097fac1d2a4cea8843b4c6 Mon Sep 17 00:00:00 2001
From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Date: Wed, 10 May 2023 10:31:07 +0200
Subject: [PATCH] fix: Fix missing error in openai_request retry strategy (#4802)

* Fix missing error in openai_request retry strategy

* Correctly handle OpenAIUnauthorizedError

* Combine the retry predicates with `&` instead of `and`, so that only
  OpenAIError subclasses other than OpenAIUnauthorizedError are retried

Co-authored-by: bogdankostic

---------

Co-authored-by: bogdankostic
---
 haystack/utils/openai_utils.py  |  6 +++++-
 test/utils/test_openai_utils.py | 57 +++++++++++++++++++++++++++++++++
 2 files changed, 62 insertions(+), 1 deletion(-)
 create mode 100644 test/utils/test_openai_utils.py

diff --git a/haystack/utils/openai_utils.py b/haystack/utils/openai_utils.py
index b3d0d552c..ad6638f71 100644
--- a/haystack/utils/openai_utils.py
+++ b/haystack/utils/openai_utils.py
@@ -128,7 +128,11 @@ def _openai_text_completion_tokenization_details(model_name: str):
 
 
 @tenacity.retry(
-    retry=tenacity.retry_if_exception_type(OpenAIRateLimitError),
+    reraise=True,
+    # Combine the predicates with `&`: a plain `and` evaluates to its right-hand operand only,
+    # which would retry on every exception that is not an OpenAIUnauthorizedError.
+    retry=tenacity.retry_if_exception_type(OpenAIError)
+    & tenacity.retry_if_not_exception_type(OpenAIUnauthorizedError),
     wait=tenacity.wait_exponential(multiplier=OPENAI_BACKOFF),
     stop=tenacity.stop_after_attempt(OPENAI_MAX_RETRIES),
 )
diff --git a/test/utils/test_openai_utils.py b/test/utils/test_openai_utils.py
new file mode 100644
index 000000000..cc2594fbc
--- /dev/null
+++ b/test/utils/test_openai_utils.py
@@ -0,0 +1,57 @@
+from unittest.mock import patch
+
+import pytest
+from tenacity import wait_none
+
+from haystack.errors import OpenAIError, OpenAIRateLimitError, OpenAIUnauthorizedError
+from haystack.utils.openai_utils import openai_request
+
+
+@pytest.mark.unit
+@patch("haystack.utils.openai_utils.requests")
+def test_openai_request_retries_generic_error(mock_requests):
+    mock_requests.request.return_value.status_code = 418
+
+    with pytest.raises(OpenAIError):
+        # We need to use a custom wait amount otherwise the test would take forever to run
+        # as the original wait time is exponential
+        openai_request.retry_with(wait=wait_none())(url="some_url", headers={}, payload={}, read_response=False)
+
+    assert mock_requests.request.call_count == 5
+
+
+@pytest.mark.unit
+@patch("haystack.utils.openai_utils.requests")
+def test_openai_request_retries_on_rate_limit_error(mock_requests):
+    mock_requests.request.return_value.status_code = 429
+
+    with pytest.raises(OpenAIRateLimitError):
+        # We need to use a custom wait amount otherwise the test would take forever to run
+        # as the original wait time is exponential
+        openai_request.retry_with(wait=wait_none())(url="some_url", headers={}, payload={}, read_response=False)
+
+    assert mock_requests.request.call_count == 5
+
+
+@pytest.mark.unit
+@patch("haystack.utils.openai_utils.requests")
+def test_openai_request_does_not_retry_on_unauthorized_error(mock_requests):
+    mock_requests.request.return_value.status_code = 401
+
+    with pytest.raises(OpenAIUnauthorizedError):
+        # We need to use a custom wait amount otherwise the test would take forever to run
+        # as the original wait time is exponential
+        openai_request.retry_with(wait=wait_none())(url="some_url", headers={}, payload={}, read_response=False)
+
+    assert mock_requests.request.call_count == 1
+
+
+@pytest.mark.unit
+@patch("haystack.utils.openai_utils.requests")
+def test_openai_request_does_not_retry_on_success(mock_requests):
+    mock_requests.request.return_value.status_code = 200
+    # We need to use a custom wait amount otherwise the test would take forever to run
+    # as the original wait time is exponential
+    openai_request.retry_with(wait=wait_none())(url="some_url", headers={}, payload={}, read_response=False)
+
+    assert mock_requests.request.call_count == 1