diff --git a/haystack/nodes/retriever/_embedding_encoder.py b/haystack/nodes/retriever/_embedding_encoder.py
index 2769e19ce..61be9f975 100644
--- a/haystack/nodes/retriever/_embedding_encoder.py
+++ b/haystack/nodes/retriever/_embedding_encoder.py
@@ -3,6 +3,7 @@ import logging
 import os
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from tenacity import retry, retry_if_exception_type, wait_exponential, stop_after_attempt
 
 try:
     from typing import Literal
@@ -30,7 +31,6 @@ from haystack.modeling.infer import Inferencer
 from haystack.nodes.retriever._losses import _TRAINING_LOSSES
 from haystack.nodes.retriever._openai_encoder import _OpenAIEmbeddingEncoder
 from haystack.schema import Document
-from haystack.utils.reflection import retry_with_exponential_backoff
 from haystack.telemetry_2 import send_event
 
 from ._base_embedding_encoder import _BaseEmbeddingEncoder
@@ -40,7 +40,7 @@ if TYPE_CHECKING:
 
 
 COHERE_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
-COHERE_BACKOFF = float(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10))
+COHERE_BACKOFF = int(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10))
 COHERE_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))
 
 
@@ -380,8 +380,10 @@ class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
         "multilingual-22-12",
     )
 
-    @retry_with_exponential_backoff(
-        backoff_in_seconds=COHERE_BACKOFF, max_retries=COHERE_MAX_RETRIES, errors=(CohereError,)
+    @retry(
+        retry=retry_if_exception_type(CohereError),
+        wait=wait_exponential(multiplier=COHERE_BACKOFF),
+        stop=stop_after_attempt(COHERE_MAX_RETRIES),
     )
     def embed(self, model: str, text: List[str]) -> np.ndarray:
         payload = {"model": model, "texts": text, "truncate": "END"}
diff --git a/haystack/utils/openai_utils.py b/haystack/utils/openai_utils.py
index c658af661..66fef9c5c 100644
--- a/haystack/utils/openai_utils.py
+++ b/haystack/utils/openai_utils.py
@@ -6,11 +6,10 @@ import sys
 import json
 from typing import Dict, Union, Tuple, Optional, List
 import requests
-
+from tenacity import retry, retry_if_exception_type, wait_exponential, stop_after_attempt
 from transformers import GPT2TokenizerFast
 
 from haystack.errors import OpenAIError, OpenAIRateLimitError, OpenAIUnauthorizedError
-from haystack.utils.reflection import retry_with_exponential_backoff
 from haystack.environment import (
     HAYSTACK_REMOTE_API_BACKOFF_SEC,
     HAYSTACK_REMOTE_API_MAX_RETRIES,
@@ -25,7 +24,7 @@ system = platform.system()
 
 
 OPENAI_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
-OPENAI_BACKOFF = float(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10))
+OPENAI_BACKOFF = int(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10))
 OPENAI_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))
 
 
@@ -124,8 +123,10 @@ def _openai_text_completion_tokenization_details(model_name: str):
     return tokenizer_name, max_tokens_limit
 
 
-@retry_with_exponential_backoff(
-    backoff_in_seconds=OPENAI_BACKOFF, max_retries=OPENAI_MAX_RETRIES, errors=(OpenAIRateLimitError, OpenAIError)
+@retry(
+    retry=retry_if_exception_type(OpenAIRateLimitError),
+    wait=wait_exponential(multiplier=OPENAI_BACKOFF),
+    stop=stop_after_attempt(OPENAI_MAX_RETRIES),
 )
 def openai_request(
     url: str,
diff --git a/haystack/utils/reflection.py b/haystack/utils/reflection.py
index ff37592ef..73d9de673 100644
--- a/haystack/utils/reflection.py
+++ b/haystack/utils/reflection.py
@@ -1,11 +1,7 @@
 import inspect
 import logging
-import time
-from random import random
 from typing import Any, Dict, Tuple, Callable
 
-from haystack.errors import OpenAIRateLimitError
-
 logger = logging.getLogger(__name__)
 
 
@@ -17,54 +13,3 @@ def args_to_kwargs(args: Tuple, func: Callable) -> Dict[str, Any]:
     arg_names = arg_names[1 : 1 + len(args)]
     args_as_kwargs = {arg_name: arg for arg, arg_name in zip(args, arg_names)}
     return args_as_kwargs
-
-
-def retry_with_exponential_backoff(
-    backoff_in_seconds: float = 1, max_retries: int = 10, errors: tuple = (OpenAIRateLimitError,)
-):
-    """
-    Decorator to retry a function with exponential backoff.
-    :param backoff_in_seconds: The initial backoff in seconds.
-    :param max_retries: The maximum number of retries.
-    :param errors: The errors to catch retry on.
-    """
-
-    def decorator(function):
-        def wrapper(*args, **kwargs):
-            # Initialize variables
-            num_retries = 0
-
-            # Loop until a successful response or max_retries is hit or an exception is raised
-            while True:
-                try:
-                    return function(*args, **kwargs)
-
-                # Retry on specified errors
-                except errors as e:
-                    # Check if max retries has been reached
-                    if num_retries > max_retries:
-                        raise Exception(f"Maximum number of retries ({max_retries}) exceeded.")
-
-                    # Increment the delay
-                    sleep_time = backoff_in_seconds * 2**num_retries + random()
-
-                    # Sleep for the delay
-                    logger.warning(
-                        "%s - %s, retry %s in %s seconds...",
-                        e.__class__.__name__,
-                        e,
-                        function.__name__,
-                        "{0:.2f}".format(sleep_time),
-                    )
-                    time.sleep(sleep_time)
-
-                    # Increment retries
-                    num_retries += 1
-
-                # Raise exceptions for any errors not specified
-                except Exception as e:
-                    raise e
-
-        return wrapper
-
-    return decorator
diff --git a/test/others/test_utils.py b/test/others/test_utils.py
index 29abb771d..9bc756c4b 100644
--- a/test/others/test_utils.py
+++ b/test/others/test_utils.py
@@ -20,7 +20,6 @@ from haystack.utils.deepsetcloud import DeepsetCloud, DeepsetCloudExperiments
 from haystack.utils.labels import aggregate_labels
 from haystack.utils.preprocessing import convert_files_to_docs, tika_convert_files_to_docs
 from haystack.utils.cleaning import clean_wiki_text
-from haystack.utils.reflection import retry_with_exponential_backoff
 from haystack.utils.context_matching import calculate_context_similarity, match_context, match_contexts
 
 from .. import conftest
@@ -1276,27 +1275,6 @@ def test_get_eval_run_results():
     assert first_result["answer"] == "This"
 
 
-def test_exponential_backoff():
-    # Test that the exponential backoff works as expected
-    # should raise exception, check the exception contains the correct message
-    with pytest.raises(Exception, match="retries \(2\)"):
-
-        @retry_with_exponential_backoff(backoff_in_seconds=1, max_retries=2)
-        def greet(name: str):
-            if random() < 1.1:
-                raise OpenAIRateLimitError("Too many requests")
-            return f"Hello {name}"
-
-        greet("John")
-
-    # this should not raise exception and should print "Hello John"
-    @retry_with_exponential_backoff(backoff_in_seconds=1, max_retries=1)
-    def greet2(name: str):
-        return f"Hello {name}"
-
-    assert greet2("John") == "Hello John"
-
-
 def test_secure_model_loading(monkeypatch, caplog):
     caplog.set_level(logging.INFO)
     monkeypatch.setenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0")
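
For reviewers who have not used tenacity before, here is a minimal standalone sketch of the decorator semantics this patch switches to. The names FlakyAPIError, BACKOFF, and call_flaky_api are illustrative only and are not part of the patch:

    from tenacity import retry, retry_if_exception_type, wait_exponential, stop_after_attempt

    BACKOFF = 10  # illustrative; mirrors the HAYSTACK_REMOTE_API_BACKOFF_SEC default used above


    class FlakyAPIError(Exception):
        """Hypothetical stand-in for CohereError / OpenAIRateLimitError."""


    @retry(
        retry=retry_if_exception_type(FlakyAPIError),  # only this error type triggers a retry
        wait=wait_exponential(multiplier=BACKOFF),  # exponentially growing sleep, scaled by BACKOFF
        stop=stop_after_attempt(5),  # stop after 5 total attempts; tenacity then raises RetryError
    )
    def call_flaky_api() -> str:
        # Always failing here, so the decorator retries 5 times and then raises RetryError.
        raise FlakyAPIError("simulated transient failure")

Two behavioral differences are worth flagging: the removed decorator counted retries (so max_retries=5 allowed additional calls beyond the first), while stop_after_attempt(5) counts total attempts including the first; and openai_request now retries only on OpenAIRateLimitError, where it previously also retried on any OpenAIError.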