feat: add default_headers for Azure embedders (#8699)

* Add default_headers param to azure embedders
This commit is contained in:
Amna Mubashar 2025-01-12 17:41:38 +01:00 committed by GitHub
parent 4f73b192f8
commit db76ae2847
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 82 additions and 0 deletions

View File

@ -51,6 +51,8 @@ class AzureOpenAIDocumentEmbedder:
embedding_separator: str = "\n",
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
*,
default_headers: Optional[Dict[str, str]] = None,
):
"""
Creates an AzureOpenAIDocumentEmbedder component.
@ -95,6 +97,7 @@ class AzureOpenAIDocumentEmbedder:
`OPENAI_TIMEOUT` environment variable, or 30 seconds.
:param max_retries: Maximum number of retries to contact AzureOpenAI after an internal error.
If not set, defaults to either the `OPENAI_MAX_RETRIES` environment variable or to 5 retries.
:param default_headers: Default headers to send to the AzureOpenAI client.
"""
# if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT
azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
@ -119,6 +122,7 @@ class AzureOpenAIDocumentEmbedder:
self.embedding_separator = embedding_separator
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
self.default_headers = default_headers or {}
self._client = AzureOpenAI(
api_version=api_version,
@ -129,6 +133,7 @@ class AzureOpenAIDocumentEmbedder:
organization=organization,
timeout=self.timeout,
max_retries=self.max_retries,
default_headers=self.default_headers,
)
def _get_telemetry_data(self) -> Dict[str, Any]:
@ -161,6 +166,7 @@ class AzureOpenAIDocumentEmbedder:
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
timeout=self.timeout,
max_retries=self.max_retries,
default_headers=self.default_headers,
)
@classmethod

View File

@ -46,6 +46,8 @@ class AzureOpenAITextEmbedder:
max_retries: Optional[int] = None,
prefix: str = "",
suffix: str = "",
*,
default_headers: Optional[Dict[str, str]] = None,
):
"""
Creates an AzureOpenAITextEmbedder component.
@ -82,6 +84,7 @@ class AzureOpenAITextEmbedder:
A string to add at the beginning of each text.
:param suffix:
A string to add at the end of each text.
:param default_headers: Default headers to send to the AzureOpenAI client.
"""
# Why is this here?
# AzureOpenAI init is forcing us to use an init method that takes either base_url or azure_endpoint as not
@ -105,6 +108,7 @@ class AzureOpenAITextEmbedder:
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
self.prefix = prefix
self.suffix = suffix
self.default_headers = default_headers or {}
self._client = AzureOpenAI(
api_version=api_version,
@ -115,6 +119,7 @@ class AzureOpenAITextEmbedder:
organization=organization,
timeout=self.timeout,
max_retries=self.max_retries,
default_headers=self.default_headers,
)
def _get_telemetry_data(self) -> Dict[str, Any]:
@ -143,6 +148,7 @@ class AzureOpenAITextEmbedder:
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
timeout=self.timeout,
max_retries=self.max_retries,
default_headers=self.default_headers,
)
@classmethod

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Added `default_headers` parameter to `AzureOpenAIDocumentEmbedder` and `AzureOpenAITextEmbedder`.

View File

@ -22,6 +22,7 @@ class TestAzureOpenAIDocumentEmbedder:
assert embedder.progress_bar is True
assert embedder.meta_fields_to_embed == []
assert embedder.embedding_separator == "\n"
assert embedder.default_headers == {}
def test_to_dict(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
@ -45,9 +46,43 @@ class TestAzureOpenAIDocumentEmbedder:
"embedding_separator": "\n",
"max_retries": 5,
"timeout": 30.0,
"default_headers": {},
},
}
def test_from_dict(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
data = {
"type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder",
"init_parameters": {
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
"api_version": "2023-05-15",
"azure_deployment": "text-embedding-ada-002",
"dimensions": None,
"azure_endpoint": "https://example-resource.azure.openai.com/",
"organization": None,
"prefix": "",
"suffix": "",
"batch_size": 32,
"progress_bar": True,
"meta_fields_to_embed": [],
"embedding_separator": "\n",
"max_retries": 5,
"timeout": 30.0,
"default_headers": {},
},
}
component = AzureOpenAIDocumentEmbedder.from_dict(data)
assert component.azure_deployment == "text-embedding-ada-002"
assert component.azure_endpoint == "https://example-resource.azure.openai.com/"
assert component.api_version == "2023-05-15"
assert component.max_retries == 5
assert component.timeout == 30.0
assert component.prefix == ""
assert component.suffix == ""
assert component.default_headers == {}
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),

View File

@ -19,6 +19,7 @@ class TestAzureOpenAITextEmbedder:
assert embedder.organization is None
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.default_headers == {}
def test_to_dict(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
@ -38,9 +39,39 @@ class TestAzureOpenAITextEmbedder:
"timeout": 30.0,
"prefix": "",
"suffix": "",
"default_headers": {},
},
}
def test_from_dict(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
data = {
"type": "haystack.components.embedders.azure_text_embedder.AzureOpenAITextEmbedder",
"init_parameters": {
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
"azure_deployment": "text-embedding-ada-002",
"dimensions": None,
"organization": None,
"azure_endpoint": "https://example-resource.azure.openai.com/",
"api_version": "2023-05-15",
"max_retries": 5,
"timeout": 30.0,
"prefix": "",
"suffix": "",
"default_headers": {},
},
}
component = AzureOpenAITextEmbedder.from_dict(data)
assert component.azure_deployment == "text-embedding-ada-002"
assert component.azure_endpoint == "https://example-resource.azure.openai.com/"
assert component.api_version == "2023-05-15"
assert component.max_retries == 5
assert component.timeout == 30.0
assert component.prefix == ""
assert component.suffix == ""
assert component.default_headers == {}
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),