feat: Configure max_retries & timeout for AzureOpenAIGenerator (#7983)

max_retries: if not set is read from the OPENAI_MAX_RETRIES
env variable or set to 5.

timeout: if not set is read from the OPENAI_TIMEOUT
env variable or set to 30.

Signed-off-by: Nitanshu Vashistha <nitanshu.vzard@gmail.com>
This commit is contained in:
Nitanshu Vashistha 2024-07-08 14:46:26 +05:30 committed by GitHub
parent e92a0e4beb
commit 167e886f2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 24 additions and 5 deletions

View File

@ -63,6 +63,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
system_prompt: Optional[str] = None,
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
):
"""
@ -78,7 +79,10 @@ class AzureOpenAIGenerator(OpenAIGenerator):
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
:param system_prompt: The prompt to use for the system. If not provided, the system prompt will be
:param timeout: The timeout to be passed to the underlying `AzureOpenAI` client.
:param timeout: The timeout to be passed to the underlying `AzureOpenAI` client, if not set it is
inferred from the `OPENAI_TIMEOUT` environment variable or set to 30.
:param max_retries: Maximum retries to establish contact with AzureOpenAI if it returns an internal error,
if not set it is inferred from the `OPENAI_MAX_RETRIES` environment variable or set to 5.
:param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to
the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for
more details.
@ -125,7 +129,8 @@ class AzureOpenAIGenerator(OpenAIGenerator):
self.azure_deployment = azure_deployment
self.organization = organization
self.model: str = azure_deployment or "gpt-35-turbo"
self.timeout = timeout
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
self.client = AzureOpenAI(
api_version=api_version,
@ -134,7 +139,8 @@ class AzureOpenAIGenerator(OpenAIGenerator):
api_key=api_key.resolve_value() if api_key is not None else None,
azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
organization=organization,
timeout=timeout,
timeout=self.timeout,
max_retries=self.max_retries,
)
def to_dict(self) -> Dict[str, Any]:
@ -157,6 +163,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
api_key=self.api_key.to_dict() if self.api_key is not None else None,
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
timeout=self.timeout,
max_retries=self.max_retries,
)
@classmethod

View File

@ -0,0 +1,9 @@
---
enhancements:
- |
Add max_retries to AzureOpenAIGenerator.
AzureOpenAIGenerator can now be initialised by setting max_retries.
If not set, it is inferred from the `OPENAI_MAX_RETRIES`
environment variable or set to 5.
The timeout for AzureOpenAIGenerator, if not set, it is inferred
from the `OPENAI_TIMEOUT` environment variable or set to 30.

View File

@ -39,7 +39,7 @@ class TestAzureOpenAIGenerator:
assert component.client.api_key == "fake-api-key"
assert component.azure_deployment == "gpt-35-turbo"
assert component.streaming_callback is print_streaming_chunk
assert component.timeout is None
assert component.timeout == 30.0
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
def test_to_dict_default(self, monkeypatch):
@ -57,7 +57,8 @@ class TestAzureOpenAIGenerator:
"azure_endpoint": "some-non-existing-endpoint",
"organization": None,
"system_prompt": None,
"timeout": None,
"timeout": 30.0,
"max_retries": 5,
"generation_kwargs": {},
},
}
@ -69,6 +70,7 @@ class TestAzureOpenAIGenerator:
azure_ad_token=Secret.from_env_var("ENV_VAR1", strict=False),
azure_endpoint="some-non-existing-endpoint",
timeout=3.5,
max_retries=10,
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
@ -85,6 +87,7 @@ class TestAzureOpenAIGenerator:
"organization": None,
"system_prompt": None,
"timeout": 3.5,
"max_retries": 10,
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}