feat: Expose default_headers and add kwargs for Azure Client (#8244)

* default_headers and azure_kwargs added

* update docstrings

* don't forget about the chat generator

* Remove azure_kwargs argument

---------

Co-authored-by: Silvano Cerza <silvanocerza@gmail.com>
This commit is contained in:
Ulises M 2024-09-10 03:29:56 -07:00 committed by GitHub
parent b126c14e51
commit 145ca89a3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 22 additions and 2 deletions

View File

@ -68,6 +68,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
default_headers: Optional[Dict[str, str]] = None,
):
"""
Initialize the Azure OpenAI Generator.
@ -107,6 +108,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
Higher values make the model less likely to repeat the token.
- `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token.
:param default_headers: Default headers to use for the AzureOpenAI client.
"""
# We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
# with the API.
@ -136,6 +138,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
self.model: str = azure_deployment or "gpt-35-turbo"
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
self.default_headers = default_headers or {}
self.client = AzureOpenAI(
api_version=api_version,
@ -146,6 +149,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
organization=organization,
timeout=self.timeout,
max_retries=self.max_retries,
default_headers=self.default_headers,
)
def to_dict(self) -> Dict[str, Any]:
@ -169,6 +173,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
timeout=self.timeout,
max_retries=self.max_retries,
default_headers=self.default_headers,
)
@classmethod

View File

@ -74,6 +74,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
default_headers: Optional[Dict[str, str]] = None,
):
"""
Initialize the Azure OpenAI Chat Generator component.
@ -110,6 +111,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
Higher values make the model less likely to repeat the token.
- `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token.
:param default_headers: Default headers to use for the AzureOpenAI client.
"""
# We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
# with the API.
@ -138,6 +140,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
self.model = azure_deployment or "gpt-35-turbo"
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
self.default_headers = default_headers or {}
self.client = AzureOpenAI(
api_version=api_version,
@ -148,6 +151,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
organization=organization,
timeout=self.timeout,
max_retries=self.max_retries,
default_headers=self.default_headers,
)
def to_dict(self) -> Dict[str, Any]:
@ -170,6 +174,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
max_retries=self.max_retries,
api_key=self.api_key.to_dict() if self.api_key is not None else None,
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
default_headers=self.default_headers,
)
@classmethod

View File

@ -0,0 +1,6 @@
---
enhancements:
- |
Expose `default_headers` on `AzureOpenAIGenerator` and `AzureOpenAIChatGenerator` to pass custom headers to the Azure API, including the APIM subscription key.
- |
Add optional azure_kwargs dictionary parameter to pass in parameters undefined in Haystack but supported by AzureOpenAI.

View File

@ -57,6 +57,7 @@ class TestOpenAIChatGenerator:
"generation_kwargs": {},
"timeout": 30.0,
"max_retries": 5,
"default_headers": {},
},
}
@ -84,6 +85,7 @@ class TestOpenAIChatGenerator:
"timeout": 2.5,
"max_retries": 10,
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"default_headers": {},
},
}
@ -98,7 +100,7 @@ class TestOpenAIChatGenerator:
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
reason=(
"Please export env variables called AZURE_OPENAI_API_KEY containing "
"the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing "

View File

@ -60,6 +60,7 @@ class TestAzureOpenAIGenerator:
"timeout": 30.0,
"max_retries": 5,
"generation_kwargs": {},
"default_headers": {},
},
}
@ -89,6 +90,7 @@ class TestAzureOpenAIGenerator:
"timeout": 3.5,
"max_retries": 10,
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"default_headers": {},
},
}
@ -103,7 +105,7 @@ class TestAzureOpenAIGenerator:
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
reason=(
"Please export env variables called AZURE_OPENAI_API_KEY containing "
"the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing "