Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-07-24 17:30:38 +00:00.
feat: Add OpenAIError to retry mechanism (#4178)

* Add OpenAIError to retry mechanism. Use an environment variable for the timeout of OpenAI requests in PromptNode.
* Updated retry in the OpenAI embedding encoder as well.
* Empty commit
This commit is contained in:
parent
7eeb3e07bf
commit
44509cd6a1
@ -186,7 +186,9 @@ class OpenAIAnswerGenerator(BaseGenerator):
|
||||
logger.debug("Using GPT2TokenizerFast")
|
||||
self._hf_tokenizer: PreTrainedTokenizerFast = GPT2TokenizerFast.from_pretrained(tokenizer)
|
||||
|
||||
@retry_with_exponential_backoff(backoff_in_seconds=OPENAI_BACKOFF, max_retries=OPENAI_MAX_RETRIES)
|
||||
@retry_with_exponential_backoff(
|
||||
backoff_in_seconds=OPENAI_BACKOFF, max_retries=OPENAI_MAX_RETRIES, errors=(OpenAIRateLimitError, OpenAIError)
|
||||
)
|
||||
def predict(
|
||||
self,
|
||||
query: str,
|
||||
|
@ -20,7 +20,11 @@ from transformers import (
|
||||
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
|
||||
|
||||
from haystack import MultiLabel
|
||||
from haystack.environment import HAYSTACK_REMOTE_API_BACKOFF_SEC, HAYSTACK_REMOTE_API_MAX_RETRIES
|
||||
from haystack.environment import (
|
||||
HAYSTACK_REMOTE_API_BACKOFF_SEC,
|
||||
HAYSTACK_REMOTE_API_MAX_RETRIES,
|
||||
HAYSTACK_REMOTE_API_TIMEOUT_SEC,
|
||||
)
|
||||
from haystack.errors import OpenAIError, OpenAIRateLimitError
|
||||
from haystack.modeling.utils import initialize_device_settings
|
||||
from haystack.nodes.base import BaseComponent
|
||||
@ -435,8 +439,9 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
|
||||
}
|
||||
|
||||
@retry_with_exponential_backoff(
|
||||
backoff_in_seconds=int(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 5)),
|
||||
backoff_in_seconds=float(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 5)),
|
||||
max_retries=int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5)),
|
||||
errors=(OpenAIRateLimitError, OpenAIError),
|
||||
)
|
||||
def invoke(self, *args, **kwargs):
|
||||
"""
|
||||
@ -478,7 +483,13 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
|
||||
"logit_bias": kwargs_with_defaults.get("logit_bias", {}),
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
response = requests.request("POST", self.url, headers=headers, data=json.dumps(payload), timeout=30)
|
||||
response = requests.request(
|
||||
"POST",
|
||||
self.url,
|
||||
headers=headers,
|
||||
data=json.dumps(payload),
|
||||
timeout=float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30)),
|
||||
)
|
||||
res = json.loads(response.text)
|
||||
|
||||
if response.status_code != 200:
|
||||
|
@ -104,7 +104,9 @@ class _OpenAIEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
|
||||
return decoded_string
|
||||
|
||||
@retry_with_exponential_backoff(backoff_in_seconds=OPENAI_BACKOFF, max_retries=OPENAI_MAX_RETRIES)
|
||||
@retry_with_exponential_backoff(
|
||||
backoff_in_seconds=OPENAI_BACKOFF, max_retries=OPENAI_MAX_RETRIES, errors=(OpenAIRateLimitError, OpenAIError)
|
||||
)
|
||||
def embed(self, model: str, text: List[str]) -> np.ndarray:
|
||||
payload = {"model": model, "input": text}
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
|
Loading…
x
Reference in New Issue
Block a user