feat: Allow OpenAI client config in OpenAIChatGenerator and AzureOpenAIChatGenerator (#9215)

* Allow OpenAI client config in chat generator

* Add init_http_client as a util method

* Update azure chat gen

* Fix linting
This commit is contained in:
Amna Mubashar 2025-04-16 21:32:13 +05:00 committed by GitHub
parent e5dc4ef94d
commit 498637788a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 138 additions and 90 deletions

View File

@ -5,12 +5,12 @@
import os
from typing import Any, Dict, List, Optional
import httpx
from openai.lib.azure import AsyncAzureOpenAI, AzureADTokenProvider, AzureOpenAI
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.embedders import OpenAIDocumentEmbedder
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.http_client import init_http_client
logger = logging.getLogger(__name__)
@ -106,7 +106,9 @@ class AzureOpenAIDocumentEmbedder(OpenAIDocumentEmbedder):
:param default_headers: Default headers to send to the AzureOpenAI client.
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on
every request.
:param http_client_kwargs: A dictionary of keyword arguments to configure a custom httpx.Client.
:param http_client_kwargs:
A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
"""
# We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
# with the API.
@ -152,18 +154,12 @@ class AzureOpenAIDocumentEmbedder(OpenAIDocumentEmbedder):
"default_headers": self.default_headers,
}
self.client = AzureOpenAI(http_client=self._init_http_client(async_client=False), **client_args)
self.async_client = AsyncAzureOpenAI(http_client=self._init_http_client(async_client=True), **client_args)
def _init_http_client(self, async_client: bool = False):
"""Internal method to initialize the httpx.Client."""
if not self.http_client_kwargs:
return None
if not isinstance(self.http_client_kwargs, dict):
raise TypeError("The parameter 'http_client_kwargs' must be a dictionary.")
if async_client:
return httpx.AsyncClient(**self.http_client_kwargs)
return httpx.Client(**self.http_client_kwargs)
self.client = AzureOpenAI(
http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_args
)
self.async_client = AsyncAzureOpenAI(
http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_args
)
def to_dict(self) -> Dict[str, Any]:
"""

View File

@ -5,12 +5,12 @@
import os
from typing import Any, Dict, Optional
import httpx
from openai.lib.azure import AsyncAzureOpenAI, AzureADTokenProvider, AzureOpenAI
from haystack import component, default_from_dict, default_to_dict
from haystack.components.embedders import OpenAITextEmbedder
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.http_client import init_http_client
@component
@ -92,7 +92,9 @@ class AzureOpenAITextEmbedder(OpenAITextEmbedder):
:param default_headers: Default headers to send to the AzureOpenAI client.
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on
every request.
:param http_client_kwargs: A dictionary of keyword arguments to configure a custom httpx.Client.
:param http_client_kwargs:
A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
"""
# We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
@ -138,17 +140,12 @@ class AzureOpenAITextEmbedder(OpenAITextEmbedder):
"default_headers": self.default_headers,
}
self.client = AzureOpenAI(http_client=self._init_http_client(async_client=False), **client_kwargs)
self.async_client = AsyncAzureOpenAI(http_client=self._init_http_client(async_client=True), **client_kwargs)
def _init_http_client(self, async_client: bool = False):
if not self.http_client_kwargs:
return None
if not isinstance(self.http_client_kwargs, dict):
raise TypeError("The parameter 'http_client_kwargs' must be a dictionary.")
if async_client:
return httpx.AsyncClient(**self.http_client_kwargs)
return httpx.Client(**self.http_client_kwargs)
self.client = AzureOpenAI(
http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs
)
self.async_client = AsyncAzureOpenAI(
http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs
)
def to_dict(self) -> Dict[str, Any]:
"""

View File

@ -18,6 +18,7 @@ from haystack.tools import (
serialize_tools_or_toolset,
)
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.http_client import init_http_client
@component
@ -66,6 +67,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
"""
# pylint: disable=super-init-not-called
# ruff: noqa: PLR0913
def __init__( # pylint: disable=too-many-positional-arguments
self,
azure_endpoint: Optional[str] = None,
@ -83,6 +85,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
tools_strict: bool = False,
*,
azure_ad_token_provider: Optional[Union[AzureADTokenProvider, AsyncAzureADTokenProvider]] = None,
http_client_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Initialize the Azure OpenAI Chat Generator component.
@ -128,6 +131,9 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on
every request.
:param http_client_kwargs:
A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
"""
# We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
# with the API.
@ -158,7 +164,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
self.default_headers = default_headers or {}
self.azure_ad_token_provider = azure_ad_token_provider
self.http_client_kwargs = http_client_kwargs
_check_duplicate_tool_names(list(tools or []))
self.tools = tools
self.tools_strict = tools_strict
@ -176,8 +182,12 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
"azure_ad_token_provider": azure_ad_token_provider,
}
self.client = AzureOpenAI(**client_args)
self.async_client = AsyncAzureOpenAI(**client_args)
self.client = AzureOpenAI(
http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_args
)
self.async_client = AsyncAzureOpenAI(
http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_args
)
def to_dict(self) -> Dict[str, Any]:
"""
@ -206,6 +216,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
tools=serialize_tools_or_toolset(self.tools),
tools_strict=self.tools_strict,
azure_ad_token_provider=azure_ad_token_provider_name,
http_client_kwargs=self.http_client_kwargs,
)
@classmethod

View File

@ -30,6 +30,7 @@ from haystack.tools import (
serialize_tools_or_toolset,
)
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.http_client import init_http_client
logger = logging.getLogger(__name__)
@ -89,6 +90,7 @@ class OpenAIChatGenerator:
max_retries: Optional[int] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
tools_strict: bool = False,
http_client_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Creates an instance of OpenAIChatGenerator. Unless specified otherwise in `model`, uses OpenAI's gpt-4o-mini
@ -138,6 +140,9 @@ class OpenAIChatGenerator:
:param tools_strict:
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
:param http_client_kwargs:
A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
"""
self.api_key = api_key
self.model = model
@ -149,7 +154,7 @@ class OpenAIChatGenerator:
self.max_retries = max_retries
self.tools = tools # Store tools as-is, whether it's a list or a Toolset
self.tools_strict = tools_strict
self.http_client_kwargs = http_client_kwargs
# Check for duplicate tool names
_check_duplicate_tool_names(list(self.tools or []))
@ -158,7 +163,7 @@ class OpenAIChatGenerator:
if max_retries is None:
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
client_args: Dict[str, Any] = {
client_kwargs: Dict[str, Any] = {
"api_key": api_key.resolve_value(),
"organization": organization,
"base_url": api_base_url,
@ -166,8 +171,10 @@ class OpenAIChatGenerator:
"max_retries": max_retries,
}
self.client = OpenAI(**client_args)
self.async_client = AsyncOpenAI(**client_args)
self.client = OpenAI(http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs)
self.async_client = AsyncOpenAI(
http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs
)
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@ -195,6 +202,7 @@ class OpenAIChatGenerator:
max_retries=self.max_retries,
tools=serialize_tools_or_toolset(self.tools),
tools_strict=self.tools_strict,
http_client_kwargs=self.http_client_kwargs,
)
@classmethod

View File

@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, Optional
import httpx
def init_http_client(
    http_client_kwargs: Optional[Dict[str, Any]] = None, async_client: bool = False
) -> "Optional[Union[httpx.Client, httpx.AsyncClient]]":
    """
    Initialize an httpx client based on the http_client_kwargs.

    Returns ``None`` when no kwargs are provided, which lets callers fall back to the
    OpenAI SDK's default HTTP client.

    :param http_client_kwargs:
        The kwargs to pass to the httpx client constructor
        (see https://www.python-httpx.org/api/#client).
    :param async_client:
        Whether to initialize an `httpx.AsyncClient` instead of a sync `httpx.Client`.
    :returns:
        A configured `httpx.Client` or `httpx.AsyncClient`, or ``None`` if no kwargs were given.
    :raises TypeError:
        If `http_client_kwargs` is neither ``None`` nor a dictionary.
    """
    if http_client_kwargs is None:
        return None
    # Validate the type before the emptiness shortcut so that falsy non-dict values
    # (e.g. "" or 0) are rejected instead of being silently accepted.
    if not isinstance(http_client_kwargs, dict):
        raise TypeError("The parameter 'http_client_kwargs' must be a dictionary.")
    if not http_client_kwargs:
        # An empty dict carries no configuration; use the SDK default client.
        return None
    if async_client:
        return httpx.AsyncClient(**http_client_kwargs)
    return httpx.Client(**http_client_kwargs)

View File

@ -0,0 +1,4 @@
---
features:
- |
`OpenAIChatGenerator` and `AzureOpenAIChatGenerator` now support custom HTTP client config via `http_client_kwargs`, enabling proxy and SSL setup.

View File

@ -190,6 +190,7 @@ class TestAgent:
"max_retries": None,
"tools": None,
"tools_strict": False,
"http_client_kwargs": None,
},
},
"tools": [
@ -256,6 +257,7 @@ class TestAgent:
"max_retries": None,
"tools": None,
"tools_strict": False,
"http_client_kwargs": None,
},
},
"tools": [

View File

@ -7,7 +7,6 @@ from openai import APIError
from haystack.utils.auth import Secret
import pytest
import httpx
from haystack import Document
from haystack.components.embedders import AzureOpenAIDocumentEmbedder
@ -215,33 +214,6 @@ class TestAzureOpenAIDocumentEmbedder:
assert len(caplog.records) == 1
assert "Failed embedding of documents 1, 2 caused by Mocked error" in caplog.text
def test_init_http_client(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com")
embedder = AzureOpenAIDocumentEmbedder()
client = embedder._init_http_client()
assert client is None
embedder.http_client_kwargs = {"proxy": "http://example.com:3128"}
client = embedder._init_http_client(async_client=False)
assert isinstance(client, httpx.Client)
client = embedder._init_http_client(async_client=True)
assert isinstance(client, httpx.AsyncClient)
def test_http_client_kwargs_type_validation(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com")
with pytest.raises(TypeError, match="The parameter 'http_client_kwargs' must be a dictionary."):
AzureOpenAIDocumentEmbedder(http_client_kwargs="invalid_argument")
def test_http_client_kwargs_with_invalid_params(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com")
with pytest.raises(TypeError, match="unexpected keyword argument"):
AzureOpenAIDocumentEmbedder(http_client_kwargs={"invalid_key": "invalid_value"})
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),

View File

@ -3,7 +3,6 @@
# SPDX-License-Identifier: Apache-2.0
import os
import httpx
import pytest
from haystack.components.embedders import AzureOpenAITextEmbedder
@ -169,33 +168,6 @@ class TestAzureOpenAITextEmbedder:
assert component.azure_ad_token_provider is not None
assert component.http_client_kwargs == {"proxy": "http://example.com:3128", "verify": False}
def test_init_http_client(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com")
embedder = AzureOpenAITextEmbedder()
client = embedder._init_http_client()
assert client is None
embedder.http_client_kwargs = {"proxy": "http://example.com:3128"}
client = embedder._init_http_client(async_client=False)
assert isinstance(client, httpx.Client)
client = embedder._init_http_client(async_client=True)
assert isinstance(client, httpx.AsyncClient)
def test_http_client_kwargs_type_validation(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com")
with pytest.raises(TypeError, match="The parameter 'http_client_kwargs' must be a dictionary."):
AzureOpenAITextEmbedder(http_client_kwargs="invalid_argument")
def test_http_client_kwargs_with_invalid_params(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com")
with pytest.raises(TypeError, match="unexpected keyword argument"):
AzureOpenAITextEmbedder(http_client_kwargs={"invalid_key": "invalid_value"})
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),

View File

@ -110,6 +110,7 @@ class TestAzureOpenAIChatGenerator:
"tools": None,
"tools_strict": False,
"azure_ad_token_provider": None,
"http_client_kwargs": None,
},
}
@ -124,6 +125,7 @@ class TestAzureOpenAIChatGenerator:
max_retries=10,
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
azure_ad_token_provider=default_azure_ad_token_provider,
http_client_kwargs={"proxy": "http://localhost:8080"},
)
data = component.to_dict()
assert data == {
@ -143,6 +145,7 @@ class TestAzureOpenAIChatGenerator:
"tools_strict": False,
"default_headers": {},
"azure_ad_token_provider": "haystack.utils.azure.default_azure_ad_token_provider",
"http_client_kwargs": {"proxy": "http://localhost:8080"},
},
}
@ -175,6 +178,7 @@ class TestAzureOpenAIChatGenerator:
}
],
"tools_strict": False,
"http_client_kwargs": None,
},
}
@ -196,6 +200,7 @@ class TestAzureOpenAIChatGenerator:
Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)
]
assert generator.tools_strict == False
assert generator.http_client_kwargs is None
def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
@ -225,6 +230,7 @@ class TestAzureOpenAIChatGenerator:
"tools": None,
"tools_strict": False,
"azure_ad_token_provider": None,
"http_client_kwargs": None,
},
}
},

View File

@ -101,6 +101,7 @@ class TestOpenAIChatGenerator:
assert component.client.max_retries == 5
assert component.tools is None
assert not component.tools_strict
assert component.http_client_kwargs is None
def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
@ -129,6 +130,7 @@ class TestOpenAIChatGenerator:
max_retries=1,
tools=[tool],
tools_strict=True,
http_client_kwargs={"proxy": "http://example.com:8080", "verify": False},
)
assert component.client.api_key == "test-api-key"
assert component.model == "gpt-4o-mini"
@ -138,6 +140,7 @@ class TestOpenAIChatGenerator:
assert component.client.max_retries == 1
assert component.tools == [tool]
assert component.tools_strict
assert component.http_client_kwargs == {"proxy": "http://example.com:8080", "verify": False}
def test_init_with_parameters_and_env_vars(self, monkeypatch):
monkeypatch.setenv("OPENAI_TIMEOUT", "100")
@ -173,6 +176,7 @@ class TestOpenAIChatGenerator:
"tools_strict": False,
"max_retries": None,
"timeout": None,
"http_client_kwargs": None,
},
}
@ -190,6 +194,7 @@ class TestOpenAIChatGenerator:
tools_strict=True,
max_retries=10,
timeout=100.0,
http_client_kwargs={"proxy": "http://example.com:8080", "verify": False},
)
data = component.to_dict()
@ -219,6 +224,7 @@ class TestOpenAIChatGenerator:
}
],
"tools_strict": True,
"http_client_kwargs": {"proxy": "http://example.com:8080", "verify": False},
},
}
@ -246,6 +252,7 @@ class TestOpenAIChatGenerator:
}
],
"tools_strict": True,
"http_client_kwargs": {"proxy": "http://example.com:8080", "verify": False},
},
}
component = OpenAIChatGenerator.from_dict(data)
@ -262,6 +269,7 @@ class TestOpenAIChatGenerator:
assert component.tools_strict
assert component.client.timeout == 100.0
assert component.client.max_retries == 10
assert component.http_client_kwargs == {"proxy": "http://example.com:8080", "verify": False}
def test_from_dict_fail_wo_env_var(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)

View File

@ -320,6 +320,7 @@ class TestToolInvoker:
"api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
"tools": None,
"tools_strict": False,
"http_client_kwargs": None,
},
},
},

View File

@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
from haystack.utils.http_client import init_http_client
import httpx
def test_init_http_client():
    # Without any configuration no client should be built.
    assert init_http_client() is None

    # Passing kwargs yields a sync httpx.Client configured with them.
    client = init_http_client(http_client_kwargs={"base_url": "https://example.com"})
    assert client is not None
    assert isinstance(client, httpx.Client)
    assert client.base_url == "https://example.com"
def test_init_http_client_async():
    # Without any configuration no async client should be built.
    assert init_http_client(async_client=True) is None

    # Passing kwargs yields an httpx.AsyncClient configured with them.
    async_client = init_http_client(
        http_client_kwargs={"base_url": "https://example.com"}, async_client=True
    )
    assert async_client is not None
    assert isinstance(async_client, httpx.AsyncClient)
    assert async_client.base_url == "https://example.com"
def test_http_client_kwargs_type_validation():
    # A non-dict http_client_kwargs must be rejected with a TypeError.
    expected = "The parameter 'http_client_kwargs' must be a dictionary."
    with pytest.raises(TypeError, match=expected):
        init_http_client(http_client_kwargs="invalid")
def test_http_client_kwargs_with_invalid_params():
    # Unknown keys are forwarded to httpx.Client, which raises a TypeError.
    bad_kwargs = {"invalid_key": "invalid"}
    with pytest.raises(TypeError, match="unexpected keyword argument"):
        init_http_client(http_client_kwargs=bad_kwargs)