fix: DALLEImageGenerator - ensure max_retries is correctly set when 0 (#9131)

* fix: DALLEImageGenerator - ensure max_retries correctly set when 0

* other small fixes

* wording
This commit is contained in:
Stefano Fiorucci 2025-03-28 14:21:07 +01:00 committed by GitHub
parent 18367203a8
commit 637dcb4599
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 18 additions and 6 deletions

View File

@ -127,7 +127,7 @@ class AzureOpenAIDocumentEmbedder:
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
self.default_headers = default_headers or {}
self.azure_ad_token_provider = azure_ad_token_provider

View File

@ -107,7 +107,7 @@ class AzureOpenAITextEmbedder:
self.azure_deployment = azure_deployment
self.dimensions = dimensions
self.organization = organization
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
self.prefix = prefix
self.suffix = suffix

View File

@ -139,7 +139,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
self.azure_deployment = azure_deployment
self.organization = organization
self.model: str = azure_deployment or "gpt-4o-mini"
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
self.default_headers = default_headers or {}
self.azure_ad_token_provider = azure_ad_token_provider

View File

@ -147,7 +147,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
self.azure_deployment = azure_deployment
self.organization = organization
self.model = azure_deployment or "gpt-4o-mini"
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
self.default_headers = default_headers or {}
self.azure_ad_token_provider = azure_ad_token_provider

View File

@ -69,8 +69,8 @@ class DALLEImageGenerator:
self.api_base_url = api_base_url
self.organization = organization
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
self.client: Optional[OpenAI] = None

View File

@ -0,0 +1,5 @@
---
fixes:
- |
In `DALLEImageGenerator`, ensure that the `max_retries` initialization parameter is correctly set when it is equal
to 0.

View File

@ -54,6 +54,13 @@ class TestDALLEImageGenerator:
assert pytest.approx(component.timeout) == 60.0
assert component.max_retries == 10
def test_init_max_retries_0(self, monkeypatch):
"""
Test that the max_retries parameter is taken into account even if it is 0.
"""
component = DALLEImageGenerator(max_retries=0)
assert component.max_retries == 0
def test_warm_up(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = DALLEImageGenerator()