mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-29 16:08:38 +00:00
feat: add default_headers for Azure embedders (#8699)
* Add default_headers param to azure embedders
This commit is contained in:
parent
4f73b192f8
commit
db76ae2847
@ -51,6 +51,8 @@ class AzureOpenAIDocumentEmbedder:
|
||||
embedding_separator: str = "\n",
|
||||
timeout: Optional[float] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
*,
|
||||
default_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""
|
||||
Creates an AzureOpenAIDocumentEmbedder component.
|
||||
@ -95,6 +97,7 @@ class AzureOpenAIDocumentEmbedder:
|
||||
`OPENAI_TIMEOUT` environment variable, or 30 seconds.
|
||||
:param max_retries: Maximum number of retries to contact AzureOpenAI after an internal error.
|
||||
If not set, defaults to either the `OPENAI_MAX_RETRIES` environment variable or to 5 retries.
|
||||
:param default_headers: Default headers to send to the AzureOpenAI client.
|
||||
"""
|
||||
# if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT
|
||||
azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||
@ -119,6 +122,7 @@ class AzureOpenAIDocumentEmbedder:
|
||||
self.embedding_separator = embedding_separator
|
||||
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
|
||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
||||
self.default_headers = default_headers or {}
|
||||
|
||||
self._client = AzureOpenAI(
|
||||
api_version=api_version,
|
||||
@ -129,6 +133,7 @@ class AzureOpenAIDocumentEmbedder:
|
||||
organization=organization,
|
||||
timeout=self.timeout,
|
||||
max_retries=self.max_retries,
|
||||
default_headers=self.default_headers,
|
||||
)
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
@ -161,6 +166,7 @@ class AzureOpenAIDocumentEmbedder:
|
||||
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
|
||||
timeout=self.timeout,
|
||||
max_retries=self.max_retries,
|
||||
default_headers=self.default_headers,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -46,6 +46,8 @@ class AzureOpenAITextEmbedder:
|
||||
max_retries: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
suffix: str = "",
|
||||
*,
|
||||
default_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""
|
||||
Creates an AzureOpenAITextEmbedder component.
|
||||
@ -82,6 +84,7 @@ class AzureOpenAITextEmbedder:
|
||||
A string to add at the beginning of each text.
|
||||
:param suffix:
|
||||
A string to add at the end of each text.
|
||||
:param default_headers: Default headers to send to the AzureOpenAI client.
|
||||
"""
|
||||
# Why is this here?
|
||||
# AzureOpenAI init is forcing us to use an init method that takes either base_url or azure_endpoint as not
|
||||
@ -105,6 +108,7 @@ class AzureOpenAITextEmbedder:
|
||||
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
|
||||
self.prefix = prefix
|
||||
self.suffix = suffix
|
||||
self.default_headers = default_headers or {}
|
||||
|
||||
self._client = AzureOpenAI(
|
||||
api_version=api_version,
|
||||
@ -115,6 +119,7 @@ class AzureOpenAITextEmbedder:
|
||||
organization=organization,
|
||||
timeout=self.timeout,
|
||||
max_retries=self.max_retries,
|
||||
default_headers=self.default_headers,
|
||||
)
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
@ -143,6 +148,7 @@ class AzureOpenAITextEmbedder:
|
||||
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
|
||||
timeout=self.timeout,
|
||||
max_retries=self.max_retries,
|
||||
default_headers=self.default_headers,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Added `default_headers` parameter to `AzureOpenAIDocumentEmbedder` and `AzureOpenAITextEmbedder`.
|
||||
@ -22,6 +22,7 @@ class TestAzureOpenAIDocumentEmbedder:
|
||||
assert embedder.progress_bar is True
|
||||
assert embedder.meta_fields_to_embed == []
|
||||
assert embedder.embedding_separator == "\n"
|
||||
assert embedder.default_headers == {}
|
||||
|
||||
def test_to_dict(self, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
|
||||
@ -45,9 +46,43 @@ class TestAzureOpenAIDocumentEmbedder:
|
||||
"embedding_separator": "\n",
|
||||
"max_retries": 5,
|
||||
"timeout": 30.0,
|
||||
"default_headers": {},
|
||||
},
|
||||
}
|
||||
|
||||
def test_from_dict(self, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
|
||||
data = {
|
||||
"type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
|
||||
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"api_version": "2023-05-15",
|
||||
"azure_deployment": "text-embedding-ada-002",
|
||||
"dimensions": None,
|
||||
"azure_endpoint": "https://example-resource.azure.openai.com/",
|
||||
"organization": None,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
"batch_size": 32,
|
||||
"progress_bar": True,
|
||||
"meta_fields_to_embed": [],
|
||||
"embedding_separator": "\n",
|
||||
"max_retries": 5,
|
||||
"timeout": 30.0,
|
||||
"default_headers": {},
|
||||
},
|
||||
}
|
||||
component = AzureOpenAIDocumentEmbedder.from_dict(data)
|
||||
assert component.azure_deployment == "text-embedding-ada-002"
|
||||
assert component.azure_endpoint == "https://example-resource.azure.openai.com/"
|
||||
assert component.api_version == "2023-05-15"
|
||||
assert component.max_retries == 5
|
||||
assert component.timeout == 30.0
|
||||
assert component.prefix == ""
|
||||
assert component.suffix == ""
|
||||
assert component.default_headers == {}
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
|
||||
|
||||
@ -19,6 +19,7 @@ class TestAzureOpenAITextEmbedder:
|
||||
assert embedder.organization is None
|
||||
assert embedder.prefix == ""
|
||||
assert embedder.suffix == ""
|
||||
assert embedder.default_headers == {}
|
||||
|
||||
def test_to_dict(self, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
|
||||
@ -38,9 +39,39 @@ class TestAzureOpenAITextEmbedder:
|
||||
"timeout": 30.0,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
"default_headers": {},
|
||||
},
|
||||
}
|
||||
|
||||
def test_from_dict(self, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
|
||||
data = {
|
||||
"type": "haystack.components.embedders.azure_text_embedder.AzureOpenAITextEmbedder",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
|
||||
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"azure_deployment": "text-embedding-ada-002",
|
||||
"dimensions": None,
|
||||
"organization": None,
|
||||
"azure_endpoint": "https://example-resource.azure.openai.com/",
|
||||
"api_version": "2023-05-15",
|
||||
"max_retries": 5,
|
||||
"timeout": 30.0,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
"default_headers": {},
|
||||
},
|
||||
}
|
||||
component = AzureOpenAITextEmbedder.from_dict(data)
|
||||
assert component.azure_deployment == "text-embedding-ada-002"
|
||||
assert component.azure_endpoint == "https://example-resource.azure.openai.com/"
|
||||
assert component.api_version == "2023-05-15"
|
||||
assert component.max_retries == 5
|
||||
assert component.timeout == 30.0
|
||||
assert component.prefix == ""
|
||||
assert component.suffix == ""
|
||||
assert component.default_headers == {}
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user