mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-28 07:29:06 +00:00
feat: Allow OpenAI client config in OpenAIChatGenerator and AzureOpenAIChatGenerator (#9215)
* Allow OpenAI client config in chat generator * Add init_http_client as a util method * Update azure chat gen * Fix linting
This commit is contained in:
parent
e5dc4ef94d
commit
498637788a
@ -5,12 +5,12 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from openai.lib.azure import AsyncAzureOpenAI, AzureADTokenProvider, AzureOpenAI
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.components.embedders import OpenAIDocumentEmbedder
|
||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||
from haystack.utils.http_client import init_http_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -106,7 +106,9 @@ class AzureOpenAIDocumentEmbedder(OpenAIDocumentEmbedder):
|
||||
:param default_headers: Default headers to send to the AzureOpenAI client.
|
||||
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on
|
||||
every request.
|
||||
:param http_client_kwargs: A dictionary of keyword arguments to configure a custom httpx.Client.
|
||||
:param http_client_kwargs:
|
||||
A dictionary of keyword arguments to configure a custom `httpx.Client`or `httpx.AsyncClient`.
|
||||
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
|
||||
"""
|
||||
# We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
|
||||
# with the API.
|
||||
@ -152,18 +154,12 @@ class AzureOpenAIDocumentEmbedder(OpenAIDocumentEmbedder):
|
||||
"default_headers": self.default_headers,
|
||||
}
|
||||
|
||||
self.client = AzureOpenAI(http_client=self._init_http_client(async_client=False), **client_args)
|
||||
self.async_client = AsyncAzureOpenAI(http_client=self._init_http_client(async_client=True), **client_args)
|
||||
|
||||
def _init_http_client(self, async_client: bool = False):
|
||||
"""Internal method to initialize the httpx.Client."""
|
||||
if not self.http_client_kwargs:
|
||||
return None
|
||||
if not isinstance(self.http_client_kwargs, dict):
|
||||
raise TypeError("The parameter 'http_client_kwargs' must be a dictionary.")
|
||||
if async_client:
|
||||
return httpx.AsyncClient(**self.http_client_kwargs)
|
||||
return httpx.Client(**self.http_client_kwargs)
|
||||
self.client = AzureOpenAI(
|
||||
http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_args
|
||||
)
|
||||
self.async_client = AsyncAzureOpenAI(
|
||||
http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_args
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
@ -5,12 +5,12 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from openai.lib.azure import AsyncAzureOpenAI, AzureADTokenProvider, AzureOpenAI
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict
|
||||
from haystack.components.embedders import OpenAITextEmbedder
|
||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||
from haystack.utils.http_client import init_http_client
|
||||
|
||||
|
||||
@component
|
||||
@ -92,7 +92,9 @@ class AzureOpenAITextEmbedder(OpenAITextEmbedder):
|
||||
:param default_headers: Default headers to send to the AzureOpenAI client.
|
||||
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on
|
||||
every request.
|
||||
:param http_client_kwargs: A dictionary of keyword arguments to configure a custom httpx.Client.
|
||||
:param http_client_kwargs:
|
||||
A dictionary of keyword arguments to configure a custom `httpx.Client`or `httpx.AsyncClient`.
|
||||
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
|
||||
|
||||
"""
|
||||
# We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
|
||||
@ -138,17 +140,12 @@ class AzureOpenAITextEmbedder(OpenAITextEmbedder):
|
||||
"default_headers": self.default_headers,
|
||||
}
|
||||
|
||||
self.client = AzureOpenAI(http_client=self._init_http_client(async_client=False), **client_kwargs)
|
||||
self.async_client = AsyncAzureOpenAI(http_client=self._init_http_client(async_client=True), **client_kwargs)
|
||||
|
||||
def _init_http_client(self, async_client: bool = False):
|
||||
if not self.http_client_kwargs:
|
||||
return None
|
||||
if not isinstance(self.http_client_kwargs, dict):
|
||||
raise TypeError("The parameter 'http_client_kwargs' must be a dictionary.")
|
||||
if async_client:
|
||||
return httpx.AsyncClient(**self.http_client_kwargs)
|
||||
return httpx.Client(**self.http_client_kwargs)
|
||||
self.client = AzureOpenAI(
|
||||
http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs
|
||||
)
|
||||
self.async_client = AsyncAzureOpenAI(
|
||||
http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
@ -18,6 +18,7 @@ from haystack.tools import (
|
||||
serialize_tools_or_toolset,
|
||||
)
|
||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||
from haystack.utils.http_client import init_http_client
|
||||
|
||||
|
||||
@component
|
||||
@ -66,6 +67,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
"""
|
||||
|
||||
# pylint: disable=super-init-not-called
|
||||
# ruff: noqa: PLR0913
|
||||
def __init__( # pylint: disable=too-many-positional-arguments
|
||||
self,
|
||||
azure_endpoint: Optional[str] = None,
|
||||
@ -83,6 +85,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
tools_strict: bool = False,
|
||||
*,
|
||||
azure_ad_token_provider: Optional[Union[AzureADTokenProvider, AsyncAzureADTokenProvider]] = None,
|
||||
http_client_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the Azure OpenAI Chat Generator component.
|
||||
@ -128,6 +131,9 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
|
||||
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on
|
||||
every request.
|
||||
:param http_client_kwargs:
|
||||
A dictionary of keyword arguments to configure a custom `httpx.Client`or `httpx.AsyncClient`.
|
||||
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
|
||||
"""
|
||||
# We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
|
||||
# with the API.
|
||||
@ -158,7 +164,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||
self.default_headers = default_headers or {}
|
||||
self.azure_ad_token_provider = azure_ad_token_provider
|
||||
|
||||
self.http_client_kwargs = http_client_kwargs
|
||||
_check_duplicate_tool_names(list(tools or []))
|
||||
self.tools = tools
|
||||
self.tools_strict = tools_strict
|
||||
@ -176,8 +182,12 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
"azure_ad_token_provider": azure_ad_token_provider,
|
||||
}
|
||||
|
||||
self.client = AzureOpenAI(**client_args)
|
||||
self.async_client = AsyncAzureOpenAI(**client_args)
|
||||
self.client = AzureOpenAI(
|
||||
http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_args
|
||||
)
|
||||
self.async_client = AsyncAzureOpenAI(
|
||||
http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_args
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -206,6 +216,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
tools=serialize_tools_or_toolset(self.tools),
|
||||
tools_strict=self.tools_strict,
|
||||
azure_ad_token_provider=azure_ad_token_provider_name,
|
||||
http_client_kwargs=self.http_client_kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -30,6 +30,7 @@ from haystack.tools import (
|
||||
serialize_tools_or_toolset,
|
||||
)
|
||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||
from haystack.utils.http_client import init_http_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -89,6 +90,7 @@ class OpenAIChatGenerator:
|
||||
max_retries: Optional[int] = None,
|
||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||
tools_strict: bool = False,
|
||||
http_client_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Creates an instance of OpenAIChatGenerator. Unless specified otherwise in `model`, uses OpenAI's gpt-4o-mini
|
||||
@ -138,6 +140,9 @@ class OpenAIChatGenerator:
|
||||
:param tools_strict:
|
||||
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
|
||||
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
|
||||
:param http_client_kwargs:
|
||||
A dictionary of keyword arguments to configure a custom `httpx.Client`or `httpx.AsyncClient`.
|
||||
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
@ -149,7 +154,7 @@ class OpenAIChatGenerator:
|
||||
self.max_retries = max_retries
|
||||
self.tools = tools # Store tools as-is, whether it's a list or a Toolset
|
||||
self.tools_strict = tools_strict
|
||||
|
||||
self.http_client_kwargs = http_client_kwargs
|
||||
# Check for duplicate tool names
|
||||
_check_duplicate_tool_names(list(self.tools or []))
|
||||
|
||||
@ -158,7 +163,7 @@ class OpenAIChatGenerator:
|
||||
if max_retries is None:
|
||||
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||
|
||||
client_args: Dict[str, Any] = {
|
||||
client_kwargs: Dict[str, Any] = {
|
||||
"api_key": api_key.resolve_value(),
|
||||
"organization": organization,
|
||||
"base_url": api_base_url,
|
||||
@ -166,8 +171,10 @@ class OpenAIChatGenerator:
|
||||
"max_retries": max_retries,
|
||||
}
|
||||
|
||||
self.client = OpenAI(**client_args)
|
||||
self.async_client = AsyncOpenAI(**client_args)
|
||||
self.client = OpenAI(http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs)
|
||||
self.async_client = AsyncOpenAI(
|
||||
http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs
|
||||
)
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -195,6 +202,7 @@ class OpenAIChatGenerator:
|
||||
max_retries=self.max_retries,
|
||||
tools=serialize_tools_or_toolset(self.tools),
|
||||
tools_strict=self.tools_strict,
|
||||
http_client_kwargs=self.http_client_kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
28
haystack/utils/http_client.py
Normal file
28
haystack/utils/http_client.py
Normal file
@ -0,0 +1,28 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
def init_http_client(http_client_kwargs: Optional[Dict[str, Any]] = None, async_client: bool = False):
|
||||
"""
|
||||
Initialize an httpx client based on the http_client_kwargs.
|
||||
|
||||
:param http_client_kwargs:
|
||||
The kwargs to pass to the httpx client.
|
||||
:param async_client:
|
||||
Whether to initialize an async client.
|
||||
|
||||
:returns:
|
||||
A httpx client or an async httpx client.
|
||||
"""
|
||||
if not http_client_kwargs:
|
||||
return None
|
||||
if not isinstance(http_client_kwargs, dict):
|
||||
raise TypeError("The parameter 'http_client_kwargs' must be a dictionary.")
|
||||
if async_client:
|
||||
return httpx.AsyncClient(**http_client_kwargs)
|
||||
return httpx.Client(**http_client_kwargs)
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
`OpenAIChatGenerator` and `AzureOpenAIChatGenerator` now support custom HTTP client config via `http_client_kwargs`, enabling proxy and SSL setup.
|
||||
@ -190,6 +190,7 @@ class TestAgent:
|
||||
"max_retries": None,
|
||||
"tools": None,
|
||||
"tools_strict": False,
|
||||
"http_client_kwargs": None,
|
||||
},
|
||||
},
|
||||
"tools": [
|
||||
@ -256,6 +257,7 @@ class TestAgent:
|
||||
"max_retries": None,
|
||||
"tools": None,
|
||||
"tools_strict": False,
|
||||
"http_client_kwargs": None,
|
||||
},
|
||||
},
|
||||
"tools": [
|
||||
|
||||
@ -7,7 +7,6 @@ from openai import APIError
|
||||
|
||||
from haystack.utils.auth import Secret
|
||||
import pytest
|
||||
import httpx
|
||||
|
||||
from haystack import Document
|
||||
from haystack.components.embedders import AzureOpenAIDocumentEmbedder
|
||||
@ -215,33 +214,6 @@ class TestAzureOpenAIDocumentEmbedder:
|
||||
assert len(caplog.records) == 1
|
||||
assert "Failed embedding of documents 1, 2 caused by Mocked error" in caplog.text
|
||||
|
||||
def test_init_http_client(self, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
|
||||
monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com")
|
||||
|
||||
embedder = AzureOpenAIDocumentEmbedder()
|
||||
client = embedder._init_http_client()
|
||||
assert client is None
|
||||
|
||||
embedder.http_client_kwargs = {"proxy": "http://example.com:3128"}
|
||||
client = embedder._init_http_client(async_client=False)
|
||||
assert isinstance(client, httpx.Client)
|
||||
|
||||
client = embedder._init_http_client(async_client=True)
|
||||
assert isinstance(client, httpx.AsyncClient)
|
||||
|
||||
def test_http_client_kwargs_type_validation(self, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
|
||||
monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com")
|
||||
with pytest.raises(TypeError, match="The parameter 'http_client_kwargs' must be a dictionary."):
|
||||
AzureOpenAIDocumentEmbedder(http_client_kwargs="invalid_argument")
|
||||
|
||||
def test_http_client_kwargs_with_invalid_params(self, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
|
||||
monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com")
|
||||
with pytest.raises(TypeError, match="unexpected keyword argument"):
|
||||
AzureOpenAIDocumentEmbedder(http_client_kwargs={"invalid_key": "invalid_value"})
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import os
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from haystack.components.embedders import AzureOpenAITextEmbedder
|
||||
@ -169,33 +168,6 @@ class TestAzureOpenAITextEmbedder:
|
||||
assert component.azure_ad_token_provider is not None
|
||||
assert component.http_client_kwargs == {"proxy": "http://example.com:3128", "verify": False}
|
||||
|
||||
def test_init_http_client(self, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
|
||||
monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com")
|
||||
|
||||
embedder = AzureOpenAITextEmbedder()
|
||||
client = embedder._init_http_client()
|
||||
assert client is None
|
||||
|
||||
embedder.http_client_kwargs = {"proxy": "http://example.com:3128"}
|
||||
client = embedder._init_http_client(async_client=False)
|
||||
assert isinstance(client, httpx.Client)
|
||||
|
||||
client = embedder._init_http_client(async_client=True)
|
||||
assert isinstance(client, httpx.AsyncClient)
|
||||
|
||||
def test_http_client_kwargs_type_validation(self, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
|
||||
monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com")
|
||||
with pytest.raises(TypeError, match="The parameter 'http_client_kwargs' must be a dictionary."):
|
||||
AzureOpenAITextEmbedder(http_client_kwargs="invalid_argument")
|
||||
|
||||
def test_http_client_kwargs_with_invalid_params(self, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
|
||||
monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com")
|
||||
with pytest.raises(TypeError, match="unexpected keyword argument"):
|
||||
AzureOpenAITextEmbedder(http_client_kwargs={"invalid_key": "invalid_value"})
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
|
||||
|
||||
@ -110,6 +110,7 @@ class TestAzureOpenAIChatGenerator:
|
||||
"tools": None,
|
||||
"tools_strict": False,
|
||||
"azure_ad_token_provider": None,
|
||||
"http_client_kwargs": None,
|
||||
},
|
||||
}
|
||||
|
||||
@ -124,6 +125,7 @@ class TestAzureOpenAIChatGenerator:
|
||||
max_retries=10,
|
||||
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
|
||||
azure_ad_token_provider=default_azure_ad_token_provider,
|
||||
http_client_kwargs={"proxy": "http://localhost:8080"},
|
||||
)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
@ -143,6 +145,7 @@ class TestAzureOpenAIChatGenerator:
|
||||
"tools_strict": False,
|
||||
"default_headers": {},
|
||||
"azure_ad_token_provider": "haystack.utils.azure.default_azure_ad_token_provider",
|
||||
"http_client_kwargs": {"proxy": "http://localhost:8080"},
|
||||
},
|
||||
}
|
||||
|
||||
@ -175,6 +178,7 @@ class TestAzureOpenAIChatGenerator:
|
||||
}
|
||||
],
|
||||
"tools_strict": False,
|
||||
"http_client_kwargs": None,
|
||||
},
|
||||
}
|
||||
|
||||
@ -196,6 +200,7 @@ class TestAzureOpenAIChatGenerator:
|
||||
Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)
|
||||
]
|
||||
assert generator.tools_strict == False
|
||||
assert generator.http_client_kwargs is None
|
||||
|
||||
def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
|
||||
@ -225,6 +230,7 @@ class TestAzureOpenAIChatGenerator:
|
||||
"tools": None,
|
||||
"tools_strict": False,
|
||||
"azure_ad_token_provider": None,
|
||||
"http_client_kwargs": None,
|
||||
},
|
||||
}
|
||||
},
|
||||
|
||||
@ -101,6 +101,7 @@ class TestOpenAIChatGenerator:
|
||||
assert component.client.max_retries == 5
|
||||
assert component.tools is None
|
||||
assert not component.tools_strict
|
||||
assert component.http_client_kwargs is None
|
||||
|
||||
def test_init_fail_wo_api_key(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
@ -129,6 +130,7 @@ class TestOpenAIChatGenerator:
|
||||
max_retries=1,
|
||||
tools=[tool],
|
||||
tools_strict=True,
|
||||
http_client_kwargs={"proxy": "http://example.com:8080", "verify": False},
|
||||
)
|
||||
assert component.client.api_key == "test-api-key"
|
||||
assert component.model == "gpt-4o-mini"
|
||||
@ -138,6 +140,7 @@ class TestOpenAIChatGenerator:
|
||||
assert component.client.max_retries == 1
|
||||
assert component.tools == [tool]
|
||||
assert component.tools_strict
|
||||
assert component.http_client_kwargs == {"proxy": "http://example.com:8080", "verify": False}
|
||||
|
||||
def test_init_with_parameters_and_env_vars(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_TIMEOUT", "100")
|
||||
@ -173,6 +176,7 @@ class TestOpenAIChatGenerator:
|
||||
"tools_strict": False,
|
||||
"max_retries": None,
|
||||
"timeout": None,
|
||||
"http_client_kwargs": None,
|
||||
},
|
||||
}
|
||||
|
||||
@ -190,6 +194,7 @@ class TestOpenAIChatGenerator:
|
||||
tools_strict=True,
|
||||
max_retries=10,
|
||||
timeout=100.0,
|
||||
http_client_kwargs={"proxy": "http://example.com:8080", "verify": False},
|
||||
)
|
||||
data = component.to_dict()
|
||||
|
||||
@ -219,6 +224,7 @@ class TestOpenAIChatGenerator:
|
||||
}
|
||||
],
|
||||
"tools_strict": True,
|
||||
"http_client_kwargs": {"proxy": "http://example.com:8080", "verify": False},
|
||||
},
|
||||
}
|
||||
|
||||
@ -246,6 +252,7 @@ class TestOpenAIChatGenerator:
|
||||
}
|
||||
],
|
||||
"tools_strict": True,
|
||||
"http_client_kwargs": {"proxy": "http://example.com:8080", "verify": False},
|
||||
},
|
||||
}
|
||||
component = OpenAIChatGenerator.from_dict(data)
|
||||
@ -262,6 +269,7 @@ class TestOpenAIChatGenerator:
|
||||
assert component.tools_strict
|
||||
assert component.client.timeout == 100.0
|
||||
assert component.client.max_retries == 10
|
||||
assert component.http_client_kwargs == {"proxy": "http://example.com:8080", "verify": False}
|
||||
|
||||
def test_from_dict_fail_wo_env_var(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
|
||||
@ -320,6 +320,7 @@ class TestToolInvoker:
|
||||
"api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
|
||||
"tools": None,
|
||||
"tools_strict": False,
|
||||
"http_client_kwargs": None,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
43
test/utils/test_http_client.py
Normal file
43
test/utils/test_http_client.py
Normal file
@ -0,0 +1,43 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
from haystack.utils.http_client import init_http_client
|
||||
import httpx
|
||||
|
||||
|
||||
def test_init_http_client():
|
||||
# test without any params
|
||||
http_client = init_http_client()
|
||||
assert http_client is None
|
||||
|
||||
# test client is initialized with http_client_kwargs
|
||||
http_client = init_http_client(http_client_kwargs={"base_url": "https://example.com"})
|
||||
assert http_client is not None
|
||||
assert isinstance(http_client, httpx.Client)
|
||||
assert http_client.base_url == "https://example.com"
|
||||
|
||||
|
||||
def test_init_http_client_async():
|
||||
# test without any params
|
||||
http_async_client = init_http_client(async_client=True)
|
||||
assert http_async_client is None
|
||||
|
||||
# test async client is initialized with http_client_kwargs
|
||||
http_async_client = init_http_client(http_client_kwargs={"base_url": "https://example.com"}, async_client=True)
|
||||
assert http_async_client is not None
|
||||
assert isinstance(http_async_client, httpx.AsyncClient)
|
||||
assert http_async_client.base_url == "https://example.com"
|
||||
|
||||
|
||||
def test_http_client_kwargs_type_validation():
|
||||
# test http_client_kwargs is not a dictionary
|
||||
with pytest.raises(TypeError, match="The parameter 'http_client_kwargs' must be a dictionary."):
|
||||
init_http_client(http_client_kwargs="invalid")
|
||||
|
||||
|
||||
def test_http_client_kwargs_with_invalid_params():
|
||||
# test http_client_kwargs with invalid keys
|
||||
with pytest.raises(TypeError, match="unexpected keyword argument"):
|
||||
init_http_client(http_client_kwargs={"invalid_key": "invalid"})
|
||||
Loading…
x
Reference in New Issue
Block a user