mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-06 03:57:19 +00:00
feat: Expose default_headers and add kwargs for Azure Client (#8244)
* default_headers and azure_kwargs added * update docstrings * dont forget about chat generator * Remove azure_kwargs argument --------- Co-authored-by: Silvano Cerza <silvanocerza@gmail.com>
This commit is contained in:
parent
b126c14e51
commit
145ca89a3f
@ -68,6 +68,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
|
||||
timeout: Optional[float] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
default_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the Azure OpenAI Generator.
|
||||
@ -107,6 +108,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
|
||||
Higher values make the model less likely to repeat the token.
|
||||
- `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are tokens, and the
|
||||
values are the bias to add to that token.
|
||||
:param default_headers: Default headers to use for the AzureOpenAI client.
|
||||
"""
|
||||
# We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
|
||||
# with the API.
|
||||
@ -136,6 +138,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
|
||||
self.model: str = azure_deployment or "gpt-35-turbo"
|
||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
||||
self.default_headers = default_headers or {}
|
||||
|
||||
self.client = AzureOpenAI(
|
||||
api_version=api_version,
|
||||
@ -146,6 +149,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
|
||||
organization=organization,
|
||||
timeout=self.timeout,
|
||||
max_retries=self.max_retries,
|
||||
default_headers=self.default_headers,
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
@ -169,6 +173,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
|
||||
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
|
||||
timeout=self.timeout,
|
||||
max_retries=self.max_retries,
|
||||
default_headers=self.default_headers,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -74,6 +74,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
timeout: Optional[float] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
default_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the Azure OpenAI Chat Generator component.
|
||||
@ -110,6 +111,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
Higher values make the model less likely to repeat the token.
|
||||
- `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are tokens, and the
|
||||
values are the bias to add to that token.
|
||||
:param default_headers: Default headers to use for the AzureOpenAI client.
|
||||
"""
|
||||
# We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
|
||||
# with the API.
|
||||
@ -138,6 +140,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
self.model = azure_deployment or "gpt-35-turbo"
|
||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
||||
self.default_headers = default_headers or {}
|
||||
|
||||
self.client = AzureOpenAI(
|
||||
api_version=api_version,
|
||||
@ -148,6 +151,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
organization=organization,
|
||||
timeout=self.timeout,
|
||||
max_retries=self.max_retries,
|
||||
default_headers=self.default_headers,
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
@ -170,6 +174,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
max_retries=self.max_retries,
|
||||
api_key=self.api_key.to_dict() if self.api_key is not None else None,
|
||||
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
|
||||
default_headers=self.default_headers,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -0,0 +1,6 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Expose default_headers to pass custom headers to Azure API including APIM subscription key.
|
||||
- |
|
||||
Add optional azure_kwargs dictionary parameter to pass in parameters undefined in Haystack but supported by AzureOpenAI.
|
||||
@ -57,6 +57,7 @@ class TestOpenAIChatGenerator:
|
||||
"generation_kwargs": {},
|
||||
"timeout": 30.0,
|
||||
"max_retries": 5,
|
||||
"default_headers": {},
|
||||
},
|
||||
}
|
||||
|
||||
@ -84,6 +85,7 @@ class TestOpenAIChatGenerator:
|
||||
"timeout": 2.5,
|
||||
"max_retries": 10,
|
||||
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
|
||||
"default_headers": {},
|
||||
},
|
||||
}
|
||||
|
||||
@ -98,7 +100,7 @@ class TestOpenAIChatGenerator:
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
|
||||
not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
|
||||
reason=(
|
||||
"Please export env variables called AZURE_OPENAI_API_KEY containing "
|
||||
"the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing "
|
||||
|
||||
@ -60,6 +60,7 @@ class TestAzureOpenAIGenerator:
|
||||
"timeout": 30.0,
|
||||
"max_retries": 5,
|
||||
"generation_kwargs": {},
|
||||
"default_headers": {},
|
||||
},
|
||||
}
|
||||
|
||||
@ -89,6 +90,7 @@ class TestAzureOpenAIGenerator:
|
||||
"timeout": 3.5,
|
||||
"max_retries": 10,
|
||||
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
|
||||
"default_headers": {},
|
||||
},
|
||||
}
|
||||
|
||||
@ -103,7 +105,7 @@ class TestAzureOpenAIGenerator:
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
|
||||
not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
|
||||
reason=(
|
||||
"Please export env variables called AZURE_OPENAI_API_KEY containing "
|
||||
"the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing "
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user