From db76ae28472da64ff05eba295c2fca72e5d7d3a0 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Sun, 12 Jan 2025 17:41:38 +0100 Subject: [PATCH] feat: add `default_headers` for Azure embedders (#8699) * Add default_headers param to azure embedders --- .../embedders/azure_document_embedder.py | 6 ++++ .../embedders/azure_text_embedder.py | 6 ++++ ...ders-azure-embedders-6ffd24ec1502c5e4.yaml | 4 +++ .../embedders/test_azure_document_embedder.py | 35 +++++++++++++++++++ .../embedders/test_azure_text_embedder.py | 31 ++++++++++++++++ 5 files changed, 82 insertions(+) create mode 100644 releasenotes/notes/add-default-headers-azure-embedders-6ffd24ec1502c5e4.yaml diff --git a/haystack/components/embedders/azure_document_embedder.py b/haystack/components/embedders/azure_document_embedder.py index b28fc1fda..2223c82bf 100644 --- a/haystack/components/embedders/azure_document_embedder.py +++ b/haystack/components/embedders/azure_document_embedder.py @@ -51,6 +51,8 @@ class AzureOpenAIDocumentEmbedder: embedding_separator: str = "\n", timeout: Optional[float] = None, max_retries: Optional[int] = None, + *, + default_headers: Optional[Dict[str, str]] = None, ): """ Creates an AzureOpenAIDocumentEmbedder component. @@ -95,6 +97,7 @@ class AzureOpenAIDocumentEmbedder: `OPENAI_TIMEOUT` environment variable, or 30 seconds. :param max_retries: Maximum number of retries to contact AzureOpenAI after an internal error. If not set, defaults to either the `OPENAI_MAX_RETRIES` environment variable or to 5 retries. + :param default_headers: Default headers to send to the AzureOpenAI client. """ # if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT") @@ -119,6 +122,7 @@ class AzureOpenAIDocumentEmbedder: self.embedding_separator = embedding_separator self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0)) self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5)) + self.default_headers = default_headers or {} self._client = AzureOpenAI( api_version=api_version, @@ -129,6 +133,7 @@ class AzureOpenAIDocumentEmbedder: organization=organization, timeout=self.timeout, max_retries=self.max_retries, + default_headers=self.default_headers, ) def _get_telemetry_data(self) -> Dict[str, Any]: @@ -161,6 +166,7 @@ class AzureOpenAIDocumentEmbedder: azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None, timeout=self.timeout, max_retries=self.max_retries, + default_headers=self.default_headers, ) @classmethod diff --git a/haystack/components/embedders/azure_text_embedder.py b/haystack/components/embedders/azure_text_embedder.py index bef34d6c3..c37bf2338 100644 --- a/haystack/components/embedders/azure_text_embedder.py +++ b/haystack/components/embedders/azure_text_embedder.py @@ -46,6 +46,8 @@ class AzureOpenAITextEmbedder: max_retries: Optional[int] = None, prefix: str = "", suffix: str = "", + *, + default_headers: Optional[Dict[str, str]] = None, ): """ Creates an AzureOpenAITextEmbedder component. @@ -82,6 +84,7 @@ class AzureOpenAITextEmbedder: A string to add at the beginning of each text. :param suffix: A string to add at the end of each text. + :param default_headers: Default headers to send to the AzureOpenAI client. """ # Why is this here? # AzureOpenAI init is forcing us to use an init method that takes either base_url or azure_endpoint as not @@ -105,6 +108,7 @@ class AzureOpenAITextEmbedder: self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5)) self.prefix = prefix self.suffix = suffix + self.default_headers = default_headers or {} self._client = AzureOpenAI( api_version=api_version, @@ -115,6 +119,7 @@ class AzureOpenAITextEmbedder: organization=organization, timeout=self.timeout, max_retries=self.max_retries, + default_headers=self.default_headers, ) def _get_telemetry_data(self) -> Dict[str, Any]: @@ -143,6 +148,7 @@ class AzureOpenAITextEmbedder: azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None, timeout=self.timeout, max_retries=self.max_retries, + default_headers=self.default_headers, ) @classmethod diff --git a/releasenotes/notes/add-default-headers-azure-embedders-6ffd24ec1502c5e4.yaml b/releasenotes/notes/add-default-headers-azure-embedders-6ffd24ec1502c5e4.yaml new file mode 100644 index 000000000..f8a7401c3 --- /dev/null +++ b/releasenotes/notes/add-default-headers-azure-embedders-6ffd24ec1502c5e4.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Added `default_headers` parameter to `AzureOpenAIDocumentEmbedder` and `AzureOpenAITextEmbedder`. diff --git a/test/components/embedders/test_azure_document_embedder.py b/test/components/embedders/test_azure_document_embedder.py index 354f35a0f..033ed3612 100644 --- a/test/components/embedders/test_azure_document_embedder.py +++ b/test/components/embedders/test_azure_document_embedder.py @@ -22,6 +22,7 @@ class TestAzureOpenAIDocumentEmbedder: assert embedder.progress_bar is True assert embedder.meta_fields_to_embed == [] assert embedder.embedding_separator == "\n" + assert embedder.default_headers == {} def test_to_dict(self, monkeypatch): monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") @@ -45,9 +46,43 @@ class TestAzureOpenAIDocumentEmbedder: "embedding_separator": "\n", "max_retries": 5, "timeout": 30.0, + "default_headers": {}, }, } + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") + data = { + "type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"}, + "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"}, + "api_version": "2023-05-15", + "azure_deployment": "text-embedding-ada-002", + "dimensions": None, + "azure_endpoint": "https://example-resource.azure.openai.com/", + "organization": None, + "prefix": "", + "suffix": "", + "batch_size": 32, + "progress_bar": True, + "meta_fields_to_embed": [], + "embedding_separator": "\n", + "max_retries": 5, + "timeout": 30.0, + "default_headers": {}, + }, + } + component = AzureOpenAIDocumentEmbedder.from_dict(data) + assert component.azure_deployment == "text-embedding-ada-002" + assert component.azure_endpoint == "https://example-resource.azure.openai.com/" + assert component.api_version == "2023-05-15" + assert component.max_retries == 5 + assert component.timeout == 30.0 + assert component.prefix == "" + assert component.suffix == "" + assert component.default_headers == {} + @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None), diff --git a/test/components/embedders/test_azure_text_embedder.py b/test/components/embedders/test_azure_text_embedder.py index 5f1f82e3d..f4aab7edf 100644 --- a/test/components/embedders/test_azure_text_embedder.py +++ b/test/components/embedders/test_azure_text_embedder.py @@ -19,6 +19,7 @@ class TestAzureOpenAITextEmbedder: assert embedder.organization is None assert embedder.prefix == "" assert embedder.suffix == "" + assert embedder.default_headers == {} def test_to_dict(self, monkeypatch): monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") @@ -38,9 +39,39 @@ class TestAzureOpenAITextEmbedder: "timeout": 30.0, "prefix": "", "suffix": "", + "default_headers": {}, }, } + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") + data = { + "type": "haystack.components.embedders.azure_text_embedder.AzureOpenAITextEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"}, + "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"}, + "azure_deployment": "text-embedding-ada-002", + "dimensions": None, + "organization": None, + "azure_endpoint": "https://example-resource.azure.openai.com/", + "api_version": "2023-05-15", + "max_retries": 5, + "timeout": 30.0, + "prefix": "", + "suffix": "", + "default_headers": {}, + }, + } + component = AzureOpenAITextEmbedder.from_dict(data) + assert component.azure_deployment == "text-embedding-ada-002" + assert component.azure_endpoint == "https://example-resource.azure.openai.com/" + assert component.api_version == "2023-05-15" + assert component.max_retries == 5 + assert component.timeout == 30.0 + assert component.prefix == "" + assert component.suffix == "" + assert component.default_headers == {} + @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),