From 637dcb4599541778a5f0f0a3eeba830cdbe8d6cd Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 28 Mar 2025 14:21:07 +0100 Subject: [PATCH] fix: `DALLEImageGenerator` - ensure `max_retries` is correctly set when 0 (#9131) * fix: DALLEImageGenerator - ensure max_retries correctly set when 0 * other small fixes * wording --- haystack/components/embedders/azure_document_embedder.py | 2 +- haystack/components/embedders/azure_text_embedder.py | 2 +- haystack/components/generators/azure.py | 2 +- haystack/components/generators/chat/azure.py | 2 +- haystack/components/generators/openai_dalle.py | 4 ++-- .../notes/dalle-fix-max-retries-703c36ce354a2e45.yaml | 5 +++++ test/components/generators/test_openai_dalle.py | 7 +++++++ 7 files changed, 18 insertions(+), 6 deletions(-) create mode 100644 releasenotes/notes/dalle-fix-max-retries-703c36ce354a2e45.yaml diff --git a/haystack/components/embedders/azure_document_embedder.py b/haystack/components/embedders/azure_document_embedder.py index c8834eeb8..6275f1241 100644 --- a/haystack/components/embedders/azure_document_embedder.py +++ b/haystack/components/embedders/azure_document_embedder.py @@ -127,7 +127,7 @@ class AzureOpenAIDocumentEmbedder: self.progress_bar = progress_bar self.meta_fields_to_embed = meta_fields_to_embed or [] self.embedding_separator = embedding_separator - self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0")) + self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0")) self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5")) self.default_headers = default_headers or {} self.azure_ad_token_provider = azure_ad_token_provider diff --git a/haystack/components/embedders/azure_text_embedder.py b/haystack/components/embedders/azure_text_embedder.py index cf1f7826e..7cd0049a9 100644 --- a/haystack/components/embedders/azure_text_embedder.py +++ b/haystack/components/embedders/azure_text_embedder.py @@ -107,7 +107,7 @@ class AzureOpenAITextEmbedder: self.azure_deployment = azure_deployment self.dimensions = dimensions self.organization = organization - self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0")) + self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0")) self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5")) self.prefix = prefix self.suffix = suffix diff --git a/haystack/components/generators/azure.py b/haystack/components/generators/azure.py index 3c63709ab..5324907fb 100644 --- a/haystack/components/generators/azure.py +++ b/haystack/components/generators/azure.py @@ -139,7 +139,7 @@ class AzureOpenAIGenerator(OpenAIGenerator): self.azure_deployment = azure_deployment self.organization = organization self.model: str = azure_deployment or "gpt-4o-mini" - self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0")) + self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0")) self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5")) self.default_headers = default_headers or {} self.azure_ad_token_provider = azure_ad_token_provider diff --git a/haystack/components/generators/chat/azure.py b/haystack/components/generators/chat/azure.py index 17a187fbd..b9daaffe8 100644 --- a/haystack/components/generators/chat/azure.py +++ b/haystack/components/generators/chat/azure.py @@ -147,7 +147,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator): self.azure_deployment = azure_deployment self.organization = organization self.model = azure_deployment or "gpt-4o-mini" - self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0")) + self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0")) self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5")) self.default_headers = default_headers or {} self.azure_ad_token_provider = azure_ad_token_provider diff --git a/haystack/components/generators/openai_dalle.py b/haystack/components/generators/openai_dalle.py index 94b5d87f9..7b1198c6c 100644 --- a/haystack/components/generators/openai_dalle.py +++ b/haystack/components/generators/openai_dalle.py @@ -69,8 +69,8 @@ class DALLEImageGenerator: self.api_base_url = api_base_url self.organization = organization - self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0")) - self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", "5")) + self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0")) + self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5")) self.client: Optional[OpenAI] = None diff --git a/releasenotes/notes/dalle-fix-max-retries-703c36ce354a2e45.yaml b/releasenotes/notes/dalle-fix-max-retries-703c36ce354a2e45.yaml new file mode 100644 index 000000000..ffc97a8a3 --- /dev/null +++ b/releasenotes/notes/dalle-fix-max-retries-703c36ce354a2e45.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + In `DALLEImageGenerator`, ensure that the `max_retries` initialization parameter is correctly set when it is equal + to 0. diff --git a/test/components/generators/test_openai_dalle.py b/test/components/generators/test_openai_dalle.py index 0c319d020..79cfe9ce5 100644 --- a/test/components/generators/test_openai_dalle.py +++ b/test/components/generators/test_openai_dalle.py @@ -54,6 +54,13 @@ class TestDALLEImageGenerator: assert pytest.approx(component.timeout) == 60.0 assert component.max_retries == 10 + def test_init_max_retries_0(self, monkeypatch): + """ + Test that the max_retries parameter is taken into account even if it is 0. + """ + component = DALLEImageGenerator(max_retries=0) + assert component.max_retries == 0 + def test_warm_up(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") component = DALLEImageGenerator()