feat: Configure max_retries & timeout for AzureOpenAITextEmbedder (#7993)

max_retries: if not set is read from the OPENAI_MAX_RETRIES
env variable or set to 5.

timeout: if not set is read from the OPENAI_TIMEOUT
env variable or set to 30.

Signed-off-by: Nitanshu Vashistha <nitanshu.vzard@gmail.com>
This commit is contained in:
Nitanshu Vashistha 2024-07-09 13:26:46 +05:30 committed by GitHub
parent f9d53c5ca8
commit cd8a5b98fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 0 deletions

View File

@ -41,6 +41,8 @@ class AzureOpenAITextEmbedder:
api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
organization: Optional[str] = None,
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
prefix: str = "",
suffix: str = "",
):
@ -67,6 +69,10 @@ class AzureOpenAITextEmbedder:
The Organization ID. See OpenAI's
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization)
for more information.
:param timeout: The timeout in seconds to be passed to the underlying `AzureOpenAI` client, if not set it is
inferred from the `OPENAI_TIMEOUT` environment variable or set to 30.
:param max_retries: Maximum retries to establish a connection with AzureOpenAI if it returns an internal error,
if not set it is inferred from the `OPENAI_MAX_RETRIES` environment variable or set to 5.
:param prefix:
A string to add at the beginning of each text.
:param suffix:
@ -90,6 +96,8 @@ class AzureOpenAITextEmbedder:
self.azure_deployment = azure_deployment
self.dimensions = dimensions
self.organization = organization
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
self.prefix = prefix
self.suffix = suffix
@ -100,6 +108,8 @@ class AzureOpenAITextEmbedder:
api_key=api_key.resolve_value() if api_key is not None else None,
azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
organization=organization,
timeout=self.timeout,
max_retries=self.max_retries,
)
def _get_telemetry_data(self) -> Dict[str, Any]:
@ -126,6 +136,8 @@ class AzureOpenAITextEmbedder:
suffix=self.suffix,
api_key=self.api_key.to_dict() if self.api_key is not None else None,
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
timeout=self.timeout,
max_retries=self.max_retries,
)
@classmethod

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Add `max_retries` and `timeout` parameters to the AzureOpenAITextEmbedder initializations.

View File

@ -34,6 +34,8 @@ class TestAzureOpenAITextEmbedder:
"organization": None,
"azure_endpoint": "https://example-resource.azure.openai.com/",
"api_version": "2023-05-15",
"max_retries": 5,
"timeout": 30.0,
"prefix": "",
"suffix": "",
},