diff --git a/haystack/components/generators/azure.py b/haystack/components/generators/azure.py index caa3e0aae..ffe6f8a94 100644 --- a/haystack/components/generators/azure.py +++ b/haystack/components/generators/azure.py @@ -63,6 +63,7 @@ class AzureOpenAIGenerator(OpenAIGenerator): streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, system_prompt: Optional[str] = None, timeout: Optional[float] = None, + max_retries: Optional[int] = None, generation_kwargs: Optional[Dict[str, Any]] = None, ): """ @@ -78,7 +79,10 @@ class AzureOpenAIGenerator(OpenAIGenerator): :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. :param system_prompt: The prompt to use for the system. If not provided, the system prompt will be - :param timeout: The timeout to be passed to the underlying `AzureOpenAI` client. + :param timeout: The timeout to be passed to the underlying `AzureOpenAI` client, if not set it is + inferred from the `OPENAI_TIMEOUT` environment variable or set to 30. + :param max_retries: Maximum retries to establish contact with AzureOpenAI if it returns an internal error, + if not set it is inferred from the `OPENAI_MAX_RETRIES` environment variable or set to 5. :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for more details. 
@@ -125,7 +129,8 @@ class AzureOpenAIGenerator(OpenAIGenerator): self.azure_deployment = azure_deployment self.organization = organization self.model: str = azure_deployment or "gpt-35-turbo" - self.timeout = timeout + self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0)) + self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5)) self.client = AzureOpenAI( api_version=api_version, @@ -134,7 +139,8 @@ class AzureOpenAIGenerator(OpenAIGenerator): api_key=api_key.resolve_value() if api_key is not None else None, azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None, organization=organization, - timeout=timeout, + timeout=self.timeout, + max_retries=self.max_retries, ) def to_dict(self) -> Dict[str, Any]: @@ -157,6 +163,7 @@ class AzureOpenAIGenerator(OpenAIGenerator): api_key=self.api_key.to_dict() if self.api_key is not None else None, azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None, timeout=self.timeout, + max_retries=self.max_retries, ) @classmethod diff --git a/releasenotes/notes/max-retries-for-AzureOpenAIGenerator-0f1a1807dd2af041.yaml b/releasenotes/notes/max-retries-for-AzureOpenAIGenerator-0f1a1807dd2af041.yaml new file mode 100644 index 000000000..7a00710c3 --- /dev/null +++ b/releasenotes/notes/max-retries-for-AzureOpenAIGenerator-0f1a1807dd2af041.yaml @@ -0,0 +1,9 @@ +--- +enhancements: + - | + Add max_retries to AzureOpenAIGenerator. + AzureOpenAIGenerator can now be initialised by setting max_retries. + If not set, it is inferred from the `OPENAI_MAX_RETRIES` + environment variable or set to 5. + The timeout for AzureOpenAIGenerator, if not set, is inferred + from the `OPENAI_TIMEOUT` environment variable or set to 30. 
diff --git a/test/components/generators/test_azure.py b/test/components/generators/test_azure.py index 39bef98ae..b4638a167 100644 --- a/test/components/generators/test_azure.py +++ b/test/components/generators/test_azure.py @@ -39,7 +39,7 @@ class TestAzureOpenAIGenerator: assert component.client.api_key == "fake-api-key" assert component.azure_deployment == "gpt-35-turbo" assert component.streaming_callback is print_streaming_chunk - assert component.timeout is None + assert component.timeout == 30.0 assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} def test_to_dict_default(self, monkeypatch): @@ -57,7 +57,8 @@ class TestAzureOpenAIGenerator: "azure_endpoint": "some-non-existing-endpoint", "organization": None, "system_prompt": None, - "timeout": None, + "timeout": 30.0, + "max_retries": 5, "generation_kwargs": {}, }, } @@ -69,6 +70,7 @@ class TestAzureOpenAIGenerator: azure_ad_token=Secret.from_env_var("ENV_VAR1", strict=False), azure_endpoint="some-non-existing-endpoint", timeout=3.5, + max_retries=10, generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, ) @@ -85,6 +87,7 @@ class TestAzureOpenAIGenerator: "organization": None, "system_prompt": None, "timeout": 3.5, + "max_retries": 10, "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, }