diff --git a/haystack/components/generators/azure.py b/haystack/components/generators/azure.py index 0e20dc3c3..f0c5da7be 100644 --- a/haystack/components/generators/azure.py +++ b/haystack/components/generators/azure.py @@ -68,6 +68,7 @@ class AzureOpenAIGenerator(OpenAIGenerator): timeout: Optional[float] = None, max_retries: Optional[int] = None, generation_kwargs: Optional[Dict[str, Any]] = None, + default_headers: Optional[Dict[str, str]] = None, ): """ Initialize the Azure OpenAI Generator. @@ -107,6 +108,7 @@ class AzureOpenAIGenerator(OpenAIGenerator): Higher values make the model less likely to repeat the token. - `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are tokens, and the values are the bias to add to that token. + :param default_headers: Default headers to use for the AzureOpenAI client. """ # We intentionally do not call super().__init__ here because we only need to instantiate the client to interact # with the API. @@ -136,6 +138,7 @@ class AzureOpenAIGenerator(OpenAIGenerator): self.model: str = azure_deployment or "gpt-35-turbo" self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0)) self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5)) + self.default_headers = default_headers or {} self.client = AzureOpenAI( api_version=api_version, @@ -146,6 +149,7 @@ class AzureOpenAIGenerator(OpenAIGenerator): organization=organization, timeout=self.timeout, max_retries=self.max_retries, + default_headers=self.default_headers, ) def to_dict(self) -> Dict[str, Any]: @@ -169,6 +173,7 @@ class AzureOpenAIGenerator(OpenAIGenerator): azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None, timeout=self.timeout, max_retries=self.max_retries, + default_headers=self.default_headers, ) @classmethod diff --git a/haystack/components/generators/chat/azure.py b/haystack/components/generators/chat/azure.py index 5fcf5ebf7..d8dc555cd 100644 --- 
a/haystack/components/generators/chat/azure.py +++ b/haystack/components/generators/chat/azure.py @@ -74,6 +74,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator): timeout: Optional[float] = None, max_retries: Optional[int] = None, generation_kwargs: Optional[Dict[str, Any]] = None, + default_headers: Optional[Dict[str, str]] = None, ): """ Initialize the Azure OpenAI Chat Generator component. @@ -110,6 +111,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator): Higher values make the model less likely to repeat the token. - `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are tokens, and the values are the bias to add to that token. + :param default_headers: Default headers to use for the AzureOpenAI client. """ # We intentionally do not call super().__init__ here because we only need to instantiate the client to interact # with the API. @@ -138,6 +140,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator): self.model = azure_deployment or "gpt-35-turbo" self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0)) self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5)) + self.default_headers = default_headers or {} self.client = AzureOpenAI( api_version=api_version, @@ -148,6 +151,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator): organization=organization, timeout=self.timeout, max_retries=self.max_retries, + default_headers=self.default_headers, ) def to_dict(self) -> Dict[str, Any]: @@ -170,6 +174,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator): max_retries=self.max_retries, api_key=self.api_key.to_dict() if self.api_key is not None else None, azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None, + default_headers=self.default_headers, ) @classmethod diff --git a/releasenotes/notes/add-azure-kwargs-6a5ab1358ef7f44c.yaml b/releasenotes/notes/add-azure-kwargs-6a5ab1358ef7f44c.yaml new file mode 100644 index 000000000..19e3bb0b6 --- 
/dev/null +++ b/releasenotes/notes/add-azure-kwargs-6a5ab1358ef7f44c.yaml @@ -0,0 +1,6 @@ +--- +enhancements: + - | + Expose `default_headers` to pass custom headers to the Azure API, including the APIM subscription key. + - | + Add optional azure_kwargs dictionary parameter to pass in parameters undefined in Haystack but supported by AzureOpenAI. diff --git a/test/components/generators/chat/test_azure.py b/test/components/generators/chat/test_azure.py index 80a290a65..0741c9849 100644 --- a/test/components/generators/chat/test_azure.py +++ b/test/components/generators/chat/test_azure.py @@ -57,6 +57,7 @@ class TestOpenAIChatGenerator: "generation_kwargs": {}, "timeout": 30.0, "max_retries": 5, + "default_headers": {}, }, } @@ -84,6 +85,7 @@ class TestOpenAIChatGenerator: "timeout": 2.5, "max_retries": 10, "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "default_headers": {}, }, } @@ -98,7 +100,7 @@ class TestOpenAIChatGenerator: @pytest.mark.integration @pytest.mark.skipif( - not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None), + not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None), reason=( "Please export env variables called AZURE_OPENAI_API_KEY containing " "the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing " diff --git a/test/components/generators/test_azure.py b/test/components/generators/test_azure.py index b4638a167..bca22e214 100644 --- a/test/components/generators/test_azure.py +++ b/test/components/generators/test_azure.py @@ -60,6 +60,7 @@ class TestAzureOpenAIGenerator: "timeout": 30.0, "max_retries": 5, "generation_kwargs": {}, + "default_headers": {}, }, } @@ -89,6 +90,7 @@ class TestAzureOpenAIGenerator: "timeout": 3.5, "max_retries": 10, "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "default_headers": {}, }, } @@ -103,7 +105,7 @@ class TestAzureOpenAIGenerator: @pytest.mark.integration 
@pytest.mark.skipif( - not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None), + not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None), reason=( "Please export env variables called AZURE_OPENAI_API_KEY containing " "the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing "