Remove retry_with_exponential_backoff in favor of tenacity (#4460)

This commit is contained in:
Silvano Cerza 2023-03-24 11:14:11 +01:00 committed by GitHub
parent dda350088b
commit b70715a74d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 12 additions and 86 deletions

View File

@ -3,6 +3,7 @@ import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from tenacity import retry, retry_if_exception_type, wait_exponential, stop_after_attempt
try:
from typing import Literal
@ -30,7 +31,6 @@ from haystack.modeling.infer import Inferencer
from haystack.nodes.retriever._losses import _TRAINING_LOSSES
from haystack.nodes.retriever._openai_encoder import _OpenAIEmbeddingEncoder
from haystack.schema import Document
from haystack.utils.reflection import retry_with_exponential_backoff
from haystack.telemetry_2 import send_event
from ._base_embedding_encoder import _BaseEmbeddingEncoder
@ -40,7 +40,7 @@ if TYPE_CHECKING:
# Remote-API settings for the Cohere encoder. Each value is looked up in
# os.environ under the key held by the matching HAYSTACK_REMOTE_API_* constant,
# falling back to the literal default when the variable is unset.
COHERE_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
# NOTE(review): COHERE_BACKOFF is assigned on two consecutive lines here; this
# looks like the removed (float) and added (int) sides of the diff shown
# together - confirm which one the resulting file keeps.
COHERE_BACKOFF = float(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10))
# int() raises ValueError for a fractional string such as "0.5", which the
# float() form above accepts. This value is passed as `multiplier` to
# wait_exponential() on the `embed` retry decorator.
COHERE_BACKOFF = int(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10))
# Passed to stop_after_attempt() on the `embed` retry decorator.
COHERE_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))
@ -380,8 +380,10 @@ class _CohereEmbeddingEncoder(_BaseEmbeddingEncoder):
"multilingual-22-12",
)
@retry_with_exponential_backoff(
backoff_in_seconds=COHERE_BACKOFF, max_retries=COHERE_MAX_RETRIES, errors=(CohereError,)
@retry(
retry=retry_if_exception_type(CohereError),
wait=wait_exponential(multiplier=COHERE_BACKOFF),
stop=stop_after_attempt(COHERE_MAX_RETRIES),
)
def embed(self, model: str, text: List[str]) -> np.ndarray:
payload = {"model": model, "texts": text, "truncate": "END"}

View File

@ -6,11 +6,10 @@ import sys
import json
from typing import Dict, Union, Tuple, Optional, List
import requests
from tenacity import retry, retry_if_exception_type, wait_exponential, stop_after_attempt
from transformers import GPT2TokenizerFast
from haystack.errors import OpenAIError, OpenAIRateLimitError, OpenAIUnauthorizedError
from haystack.utils.reflection import retry_with_exponential_backoff
from haystack.environment import (
HAYSTACK_REMOTE_API_BACKOFF_SEC,
HAYSTACK_REMOTE_API_MAX_RETRIES,
@ -25,7 +24,7 @@ system = platform.system()
# Remote-API settings for OpenAI requests. Each value is looked up in
# os.environ under the key held by the matching HAYSTACK_REMOTE_API_* constant,
# falling back to the literal default when the variable is unset.
OPENAI_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
# NOTE(review): OPENAI_BACKOFF is assigned on two consecutive lines here; this
# looks like the removed (float) and added (int) sides of the diff shown
# together - confirm which one the resulting file keeps.
OPENAI_BACKOFF = float(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10))
# int() raises ValueError for a fractional string such as "0.5", which the
# float() form above accepts. This value is passed as `multiplier` to
# wait_exponential() on the `openai_request` retry decorator.
OPENAI_BACKOFF = int(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10))
# Passed to stop_after_attempt() on the `openai_request` retry decorator.
OPENAI_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))
@ -124,8 +123,10 @@ def _openai_text_completion_tokenization_details(model_name: str):
return tokenizer_name, max_tokens_limit
@retry_with_exponential_backoff(
backoff_in_seconds=OPENAI_BACKOFF, max_retries=OPENAI_MAX_RETRIES, errors=(OpenAIRateLimitError, OpenAIError)
@retry(
retry=retry_if_exception_type(OpenAIRateLimitError),
wait=wait_exponential(multiplier=OPENAI_BACKOFF),
stop=stop_after_attempt(OPENAI_MAX_RETRIES),
)
def openai_request(
url: str,

View File

@ -1,11 +1,7 @@
import inspect
import logging
import time
from random import random
from typing import Any, Dict, Tuple, Callable
from haystack.errors import OpenAIRateLimitError
logger = logging.getLogger(__name__)
@ -17,54 +13,3 @@ def args_to_kwargs(args: Tuple, func: Callable) -> Dict[str, Any]:
arg_names = arg_names[1 : 1 + len(args)]
args_as_kwargs = {arg_name: arg for arg, arg_name in zip(args, arg_names)}
return args_as_kwargs
def retry_with_exponential_backoff(
    backoff_in_seconds: float = 1, max_retries: int = 10, errors: tuple = (OpenAIRateLimitError,)
):
    """
    Decorator to retry a function with exponential backoff.

    The decorated function is called once and, whenever it raises one of `errors`, called again
    after sleeping `backoff_in_seconds * 2**n + jitter` seconds, where `n` is the number of retries
    already performed and `jitter` is a random value in [0, 1). Exceptions that are not listed in
    `errors` propagate immediately without any retry.

    :param backoff_in_seconds: The initial backoff in seconds.
    :param max_retries: The maximum number of retries performed after the initial call.
    :param errors: The errors to catch retry on.
    :raises Exception: If the function still fails after `max_retries` retries. The last caught
        error is attached as `__cause__`.
    """
    # Imported here so the decorator does not rely on a module-level functools import.
    from functools import wraps

    def decorator(function):
        # Preserve the wrapped function's __name__/__doc__ for logging and introspection.
        @wraps(function)
        def wrapper(*args, **kwargs):
            num_retries = 0

            # Loop until a successful response, or until max_retries is exhausted.
            while True:
                try:
                    return function(*args, **kwargs)
                except errors as e:
                    # `>=` rather than `>`: with `>` the function was retried max_retries + 1 times.
                    if num_retries >= max_retries:
                        raise Exception(f"Maximum number of retries ({max_retries}) exceeded.") from e

                    # Double the delay on every retry and add jitter to spread out concurrent callers.
                    sleep_time = backoff_in_seconds * 2**num_retries + random()
                    logger.warning(
                        "%s - %s, retry %s in %s seconds...",
                        e.__class__.__name__,
                        e,
                        function.__name__,
                        "{0:.2f}".format(sleep_time),
                    )
                    time.sleep(sleep_time)
                    num_retries += 1

        return wrapper

    return decorator

View File

@ -20,7 +20,6 @@ from haystack.utils.deepsetcloud import DeepsetCloud, DeepsetCloudExperiments
from haystack.utils.labels import aggregate_labels
from haystack.utils.preprocessing import convert_files_to_docs, tika_convert_files_to_docs
from haystack.utils.cleaning import clean_wiki_text
from haystack.utils.reflection import retry_with_exponential_backoff
from haystack.utils.context_matching import calculate_context_similarity, match_context, match_contexts
from .. import conftest
@ -1276,27 +1275,6 @@ def test_get_eval_run_results():
assert first_result["answer"] == "This"
def test_exponential_backoff():
    """Check that retry_with_exponential_backoff gives up after max_retries and is transparent on success."""
    # Should raise an exception whose message reports the configured retry limit.
    # Raw string: "\(" is a regex escape, not a valid Python string escape sequence.
    with pytest.raises(Exception, match=r"retries \(2\)"):

        @retry_with_exponential_backoff(backoff_in_seconds=1, max_retries=2)
        def greet(name: str):
            # random() returns a value in [0.0, 1.0), so this branch is always taken.
            if random() < 1.1:
                raise OpenAIRateLimitError("Too many requests")
            return f"Hello {name}"

        greet("John")

    # A function that never raises must return its result unchanged.
    @retry_with_exponential_backoff(backoff_in_seconds=1, max_retries=1)
    def greet2(name: str):
        return f"Hello {name}"

    assert greet2("John") == "Hello John"
def test_secure_model_loading(monkeypatch, caplog):
caplog.set_level(logging.INFO)
monkeypatch.setenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0")