Mirror of https://github.com/deepset-ai/haystack.git (synced 2026-01-08 13:06:29 +00:00)
feat: Configure max_retries & timeout for AzureOpenAIGenerator (#7983)
max_retries: if not set, it is read from the OPENAI_MAX_RETRIES env variable or defaults to 5.
timeout: if not set, it is read from the OPENAI_TIMEOUT env variable or defaults to 30.

Signed-off-by: Nitanshu Vashistha <nitanshu.vzard@gmail.com>
Parent: e92a0e4beb
Commit: 167e886f2c
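For context before the diff, a minimal usage sketch of the updated constructor. It is not part of the commit: it assumes a Haystack 2.x environment where AzureOpenAIGenerator is imported from haystack.components.generators and where the Azure API key is available in the environment (typically AZURE_OPENAI_API_KEY); the endpoint and parameter values are illustrative.

from haystack.components.generators import AzureOpenAIGenerator

# Both new parameters are optional; when omitted they fall back to the
# OPENAI_TIMEOUT / OPENAI_MAX_RETRIES environment variables, or to 30.0 / 5.
generator = AzureOpenAIGenerator(
    azure_endpoint="https://example-resource.openai.azure.com",  # illustrative endpoint
    azure_deployment="gpt-35-turbo",
    timeout=10.0,    # per-request timeout forwarded to the underlying AzureOpenAI client
    max_retries=3,   # retries on internal/transient errors before giving up
)

result = generator.run(prompt="Say hello in one sentence.")
print(result["replies"][0])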
@@ -63,6 +63,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         system_prompt: Optional[str] = None,
         timeout: Optional[float] = None,
+        max_retries: Optional[int] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
@@ -78,7 +79,10 @@ class AzureOpenAIGenerator(OpenAIGenerator):
         :param streaming_callback: A callback function that is called when a new token is received from the stream.
             The callback function accepts StreamingChunk as an argument.
         :param system_prompt: The prompt to use for the system. If not provided, the system prompt will be
-        :param timeout: The timeout to be passed to the underlying `AzureOpenAI` client.
+        :param timeout: The timeout to be passed to the underlying `AzureOpenAI` client, if not set it is
+            inferred from the `OPENAI_TIMEOUT` environment variable or set to 30.
+        :param max_retries: Maximum retries to establish contact with AzureOpenAI if it returns an internal error,
+            if not set it is inferred from the `OPENAI_MAX_RETRIES` environment variable or set to 5.
         :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to
             the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for
             more details.
@@ -125,7 +129,8 @@ class AzureOpenAIGenerator(OpenAIGenerator):
         self.azure_deployment = azure_deployment
         self.organization = organization
         self.model: str = azure_deployment or "gpt-35-turbo"
-        self.timeout = timeout
+        self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
+        self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
 
         self.client = AzureOpenAI(
             api_version=api_version,
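The fallback order introduced in the hunk above can be restated as a standalone sketch; the helper names are hypothetical and only mirror the expressions used in the constructor:

import os
from typing import Optional

def resolve_timeout(timeout: Optional[float] = None) -> float:
    # Explicit argument first, then the OPENAI_TIMEOUT environment variable,
    # then the hard-coded 30.0 default.
    return timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))

def resolve_max_retries(max_retries: Optional[int] = None) -> int:
    # Same pattern for retries: OPENAI_MAX_RETRIES, then a default of 5.
    return max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))

Note that because the fallback chains with `or`, an explicit falsy value such as `timeout=0.0` or `max_retries=0` would also fall through to the environment variable or the default.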
@@ -134,7 +139,8 @@ class AzureOpenAIGenerator(OpenAIGenerator):
             api_key=api_key.resolve_value() if api_key is not None else None,
             azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
             organization=organization,
-            timeout=timeout,
+            timeout=self.timeout,
+            max_retries=self.max_retries,
         )
 
     def to_dict(self) -> Dict[str, Any]:
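For reference, the two keyword arguments the component now forwards are native options of the `openai` v1 client; a sketch of direct construction with placeholder credentials, endpoint, and API version:

from openai import AzureOpenAI

# Placeholder values throughout; timeout and max_retries apply to every
# request issued by this client.
client = AzureOpenAI(
    api_version="2023-05-15",
    azure_endpoint="https://example-resource.openai.azure.com",
    api_key="placeholder-key",
    timeout=30.0,
    max_retries=5,
)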
@@ -157,6 +163,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
             api_key=self.api_key.to_dict() if self.api_key is not None else None,
             azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
             timeout=self.timeout,
+            max_retries=self.max_retries,
         )
 
     @classmethod
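With the `to_dict()` change above, both new fields survive serialization. A hedged round-trip sketch, assuming Haystack's standard `to_dict()` layout with an `init_parameters` mapping and assuming `AZURE_OPENAI_API_KEY` is set so the client can be constructed; the endpoint is illustrative:

from haystack.components.generators import AzureOpenAIGenerator

generator = AzureOpenAIGenerator(
    azure_endpoint="https://example-resource.openai.azure.com",
    timeout=3.5,
    max_retries=10,
)

data = generator.to_dict()
print(data["init_parameters"]["timeout"])      # 3.5
print(data["init_parameters"]["max_retries"])  # 10

# from_dict (presumably the @classmethod that follows to_dict in the hunk
# above) restores the component with the same settings.
restored = AzureOpenAIGenerator.from_dict(data)
assert restored.timeout == 3.5 and restored.max_retries == 10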
@@ -0,0 +1,9 @@
+---
+enhancements:
+  - |
+    Add max_retries to AzureOpenAIGenerator.
+    AzureOpenAIGenerator can now be initialised by setting max_retries.
+    If not set, it is inferred from the `OPENAI_MAX_RETRIES`
+    environment variable or set to 5.
+    The timeout for AzureOpenAIGenerator, if not set, it is inferred
+    from the `OPENAI_TIMEOUT` environment variable or set to 30.
@@ -39,7 +39,7 @@ class TestAzureOpenAIGenerator:
         assert component.client.api_key == "fake-api-key"
         assert component.azure_deployment == "gpt-35-turbo"
         assert component.streaming_callback is print_streaming_chunk
-        assert component.timeout is None
+        assert component.timeout == 30.0
         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
 
     def test_to_dict_default(self, monkeypatch):
@@ -57,7 +57,8 @@ class TestAzureOpenAIGenerator:
                 "azure_endpoint": "some-non-existing-endpoint",
                 "organization": None,
                 "system_prompt": None,
-                "timeout": None,
+                "timeout": 30.0,
+                "max_retries": 5,
                 "generation_kwargs": {},
             },
         }
@@ -69,6 +70,7 @@ class TestAzureOpenAIGenerator:
             azure_ad_token=Secret.from_env_var("ENV_VAR1", strict=False),
             azure_endpoint="some-non-existing-endpoint",
             timeout=3.5,
+            max_retries=10,
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
         )
 
@@ -85,6 +87,7 @@ class TestAzureOpenAIGenerator:
                 "organization": None,
                 "system_prompt": None,
                 "timeout": 3.5,
+                "max_retries": 10,
                 "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
             },
         }
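A hypothetical companion test, not part of this commit, sketching how the environment-variable fallback could be exercised with pytest's monkeypatch fixture; the endpoint string mirrors the one used in the existing tests:

from haystack.components.generators import AzureOpenAIGenerator

def test_init_timeout_and_max_retries_from_env(monkeypatch):
    # With no explicit arguments, the values come from the environment
    # instead of the 30.0 / 5 defaults.
    monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
    monkeypatch.setenv("OPENAI_TIMEOUT", "100")
    monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
    component = AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint")
    assert component.timeout == 100.0
    assert component.max_retries == 10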