feat: Add Toolset support in ChatGenerator(s) (#9177)

* Add Toolset support in ChatGenerator(s)

* Add reno note

* Update azure test

* Updates

* Minor fix

* Add more tests

* Remove some integration tests

* PR feedback

* rm unused fixture

---------

Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
Vladimir Blagojevic committed 2025-04-07 16:12:09 +02:00 (committed by GitHub)
parent 63781afd8f
commit e1e797206d
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
7 changed files with 233 additions and 36 deletions

haystack/components/generators/chat/azure.py

@@ -11,7 +11,9 @@ from haystack import component, default_from_dict, default_to_dict
 from haystack.components.generators.chat import OpenAIChatGenerator
 from haystack.dataclasses.streaming_chunk import StreamingCallbackT
 from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
+from haystack.tools.toolset import Toolset
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
+from haystack.utils.misc import serialize_tools_or_toolset

 @component
@@ -73,7 +75,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
         max_retries: Optional[int] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
         default_headers: Optional[Dict[str, str]] = None,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
         tools_strict: bool = False,
         *,
         azure_ad_token_provider: Optional[Union[AzureADTokenProvider, AsyncAzureADTokenProvider]] = None,
@@ -115,7 +117,8 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
             values are the bias to add to that token.
         :param default_headers: Default headers to use for the AzureOpenAI client.
         :param tools:
-            A list of tools for which the model can prepare calls.
+            A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
+            list of `Tool` objects or a `Toolset` instance.
         :param tools_strict:
             Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
             the schema provided in the `parameters` field of the tool definition, but this may increase latency.
@@ -152,7 +155,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
         self.default_headers = default_headers or {}
         self.azure_ad_token_provider = azure_ad_token_provider

-        _check_duplicate_tool_names(tools)
+        _check_duplicate_tool_names(list(tools or []))
         self.tools = tools
         self.tools_strict = tools_strict
@@ -196,7 +199,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
             api_key=self.api_key.to_dict() if self.api_key is not None else None,
             azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
             default_headers=self.default_headers,
-            tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
+            tools=serialize_tools_or_toolset(self.tools),
             tools_strict=self.tools_strict,
             azure_ad_token_provider=azure_ad_token_provider_name,
         )
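For context, a minimal usage sketch of the new parameter (not part of the diff): the endpoint and deployment name below are placeholders, and AZURE_OPENAI_API_KEY is assumed to be set in the environment.

from haystack.components.generators.chat import AzureOpenAIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.tools.tool import Tool
from haystack.tools.toolset import Toolset

weather_tool = Tool(
    name="weather",
    description="useful to determine the weather in a given location",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=lambda city: f"Weather info for {city}",
)

# A Toolset can now be passed wherever a List[Tool] was accepted before.
generator = AzureOpenAIChatGenerator(
    azure_endpoint="https://example-resource.openai.azure.com/",  # placeholder endpoint
    azure_deployment="gpt-4o-mini",  # placeholder deployment name
    tools=Toolset([weather_tool]),
)
result = generator.run(messages=[ChatMessage.from_user("What's the weather in Paris?")])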

haystack/components/generators/chat/hugging_face_api.py

@@ -10,8 +10,10 @@ from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_s
 from haystack.dataclasses.streaming_chunk import StreamingCallbackT
 from haystack.lazy_imports import LazyImport
 from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
+from haystack.tools.toolset import Toolset
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
 from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
+from haystack.utils.misc import serialize_tools_or_toolset
 from haystack.utils.url_validation import is_valid_http_url

 with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
@@ -103,7 +105,7 @@ class HuggingFaceAPIChatGenerator:
         generation_kwargs: Optional[Dict[str, Any]] = None,
         stop_words: Optional[List[str]] = None,
         streaming_callback: Optional[StreamingCallbackT] = None,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
     ):
         """
         Initialize the HuggingFaceAPIChatGenerator instance.
@@ -130,10 +132,10 @@ class HuggingFaceAPIChatGenerator:
         :param streaming_callback:
             An optional callable for handling streaming responses.
         :param tools:
-            A list of tools for which the model can prepare calls.
+            A list of tools or a Toolset for which the model can prepare calls.
             The chosen model should support tool/function calling, according to the model card.
             Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
-            unexpected behavior.
+            unexpected behavior. This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
         """

         huggingface_hub_import.check()
@@ -166,7 +168,7 @@ class HuggingFaceAPIChatGenerator:
         if tools and streaming_callback is not None:
             raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
-        _check_duplicate_tool_names(tools)
+        _check_duplicate_tool_names(list(tools or []))

         # handle generation kwargs setup
         generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
@@ -191,7 +193,6 @@ class HuggingFaceAPIChatGenerator:
             A dictionary containing the serialized component.
         """
         callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
-        serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
         return default_to_dict(
             self,
             api_type=str(self.api_type),
@@ -199,7 +200,7 @@ class HuggingFaceAPIChatGenerator:
             token=self.token.to_dict() if self.token else None,
             generation_kwargs=self.generation_kwargs,
             streaming_callback=callback_name,
-            tools=serialized_tools,
+            tools=serialize_tools_or_toolset(self.tools),
         )

     @classmethod
@@ -220,7 +221,7 @@ class HuggingFaceAPIChatGenerator:
         self,
         messages: List[ChatMessage],
         generation_kwargs: Optional[Dict[str, Any]] = None,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
         streaming_callback: Optional[StreamingCallbackT] = None,
     ):
         """
@@ -231,8 +232,9 @@ class HuggingFaceAPIChatGenerator:
         :param generation_kwargs:
             Additional keyword arguments for text generation.
         :param tools:
-            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
-            during component initialization.
+            A list of tools or a Toolset for which the model can prepare calls. If set, it will override
+            the `tools` parameter set during component initialization. This parameter can accept either a
+            list of `Tool` objects or a `Toolset` instance.
         :param streaming_callback:
             An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
             parameter set during component initialization.
@@ -248,7 +250,7 @@ class HuggingFaceAPIChatGenerator:
         tools = tools or self.tools
         if tools and self.streaming_callback:
             raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
-        _check_duplicate_tool_names(tools)
+        _check_duplicate_tool_names(list(tools or []))

         # validate and select the streaming callback
         streaming_callback = select_streaming_callback(
@@ -260,6 +262,8 @@ class HuggingFaceAPIChatGenerator:
         hf_tools = None
         if tools:
+            if isinstance(tools, Toolset):
+                tools = list(tools)
             hf_tools = [
                 ChatCompletionInputTool(
                     function=ChatCompletionInputFunctionDefinition(
@@ -276,7 +280,7 @@ class HuggingFaceAPIChatGenerator:
         self,
         messages: List[ChatMessage],
         generation_kwargs: Optional[Dict[str, Any]] = None,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
         streaming_callback: Optional[StreamingCallbackT] = None,
     ):
         """
@@ -290,8 +294,9 @@ class HuggingFaceAPIChatGenerator:
         :param generation_kwargs:
             Additional keyword arguments for text generation.
         :param tools:
-            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
-            during component initialization.
+            A list of tools or a Toolset for which the model can prepare calls. If set, it will override the `tools`
+            parameter set during component initialization. This parameter can accept either a list of `Tool` objects
+            or a `Toolset` instance.
         :param streaming_callback:
             An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
             parameter set during component initialization.
@@ -307,7 +312,7 @@ class HuggingFaceAPIChatGenerator:
         tools = tools or self.tools
         if tools and self.streaming_callback:
             raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
-        _check_duplicate_tool_names(tools)
+        _check_duplicate_tool_names(list(tools or []))

         # validate and select the streaming callback
         streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
@@ -317,6 +322,8 @@ class HuggingFaceAPIChatGenerator:
         hf_tools = None
         if tools:
+            if isinstance(tools, Toolset):
+                tools = list(tools)
             hf_tools = [
                 ChatCompletionInputTool(
                     function=ChatCompletionInputFunctionDefinition(
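A note on the recurring `list(tools or [])` call sites above (illustration, not part of the diff): `Toolset` iterates over its `Tool` objects, so a single `list()` call normalizes `None`, a plain list, or a `Toolset` into a plain list before duplicate-name checking. The helper name below is hypothetical.

from typing import List, Optional, Union

from haystack.tools.tool import Tool
from haystack.tools.toolset import Toolset

def normalize_tools(tools: Optional[Union[List[Tool], Toolset]]) -> List[Tool]:
    """Illustrative only: a Toolset iterates over its Tool objects,
    so list(tools or []) flattens None, a list, or a Toolset alike."""
    return list(tools or [])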

haystack/components/generators/chat/hugging_face_local.py

@@ -13,6 +13,7 @@ from haystack import component, default_from_dict, default_to_dict, logging
 from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
 from haystack.lazy_imports import LazyImport
 from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
+from haystack.tools.toolset import Toolset
 from haystack.utils import (
     ComponentDevice,
     Secret,
@@ -20,6 +21,7 @@ from haystack.utils import (
     deserialize_secrets_inplace,
     serialize_callable,
 )
+from haystack.utils.misc import serialize_tools_or_toolset

 logger = logging.getLogger(__name__)
@@ -123,7 +125,7 @@ class HuggingFaceLocalChatGenerator:
         huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
         stop_words: Optional[List[str]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
         tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
         async_executor: Optional[ThreadPoolExecutor] = None,
     ):
@@ -164,7 +166,8 @@ class HuggingFaceLocalChatGenerator:
             For some chat models, the output includes both the new text and the original prompt.
             In these cases, make sure your prompt has no stop words.
         :param streaming_callback: An optional callable for handling streaming responses.
-        :param tools: A list of tools for which the model can prepare calls.
+        :param tools: A list of tools or a Toolset for which the model can prepare calls.
+            This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
         :param tool_parsing_function:
             A callable that takes a string and returns a list of ToolCall objects or None.
             If None, the default_tool_parser will be used which extracts tool calls using a predefined pattern.
@@ -176,7 +179,7 @@ class HuggingFaceLocalChatGenerator:
         if tools and streaming_callback is not None:
             raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
-        _check_duplicate_tool_names(tools)
+        _check_duplicate_tool_names(list(tools or []))

         huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
         generation_kwargs = generation_kwargs or {}
@@ -273,7 +276,6 @@ class HuggingFaceLocalChatGenerator:
             Dictionary with serialized data.
         """
         callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
-        serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
         serialization_dict = default_to_dict(
             self,
             huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
@@ -281,7 +283,7 @@ class HuggingFaceLocalChatGenerator:
             streaming_callback=callback_name,
             token=self.token.to_dict() if self.token else None,
             chat_template=self.chat_template,
-            tools=serialized_tools,
+            tools=serialize_tools_or_toolset(self.tools),
             tool_parsing_function=serialize_callable(self.tool_parsing_function),
         )
@@ -323,7 +325,7 @@ class HuggingFaceLocalChatGenerator:
         messages: List[ChatMessage],
         generation_kwargs: Optional[Dict[str, Any]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
     ):
         """
         Invoke text generation inference based on the provided messages and generation parameters.
@@ -332,8 +334,9 @@ class HuggingFaceLocalChatGenerator:
         :param generation_kwargs: Additional keyword arguments for text generation.
         :param streaming_callback: An optional callable for handling streaming responses.
         :param tools:
-            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter
-            provided during initialization.
+            A list of tools or a Toolset for which the model can prepare calls. If set, it will override
+            the `tools` parameter provided during initialization. This parameter can accept either a list
+            of `Tool` objects or a `Toolset` instance.
         :returns:
             A list containing the generated responses as ChatMessage instances.
         """
@@ -377,6 +380,10 @@ class HuggingFaceLocalChatGenerator:
         # convert messages to HF format
         hf_messages = [convert_message_to_hf_format(message) for message in messages]
+
+        if isinstance(tools, Toolset):
+            tools = list(tools)
+
         prepared_prompt = tokenizer.apply_chat_template(
             hf_messages,
             tokenize=False,
@@ -480,7 +487,7 @@ class HuggingFaceLocalChatGenerator:
         messages: List[ChatMessage],
         generation_kwargs: Optional[Dict[str, Any]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
     ):
         """
         Asynchronously invokes text generation inference based on the provided messages and generation parameters.
@@ -491,7 +498,8 @@ class HuggingFaceLocalChatGenerator:
         :param messages: A list of ChatMessage objects representing the input messages.
         :param generation_kwargs: Additional keyword arguments for text generation.
         :param streaming_callback: An optional callable for handling streaming responses.
-        :param tools: A list of tools for which the model can prepare calls.
+        :param tools: A list of tools or a Toolset for which the model can prepare calls.
+            This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
         :returns: A dictionary with the following keys:
             - `replies`: A list containing the generated responses as ChatMessage instances.
         """
@@ -576,13 +584,17 @@ class HuggingFaceLocalChatGenerator:
         tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
         generation_kwargs: Dict[str, Any],
         stop_words: Optional[List[str]],
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
     ):
         """
         Handles async non-streaming generation of responses.
         """
         # convert messages to HF format
         hf_messages = [convert_message_to_hf_format(message) for message in messages]
+
+        if isinstance(tools, Toolset):
+            tools = list(tools)
+
         prepared_prompt = tokenizer.apply_chat_template(
             hf_messages,
             tokenize=False,
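The `serialize_tools_or_toolset` helper imported in the three files above is not shown in this diff. Judging from the old inline code it replaces and the serialized shape asserted in the tests further down, it plausibly behaves like this sketch (an assumption, not the actual implementation):

from typing import Any, Dict, List, Optional, Union

from haystack.tools.tool import Tool
from haystack.tools.toolset import Toolset

def serialize_tools_or_toolset_sketch(
    tools: Optional[Union[List[Tool], Toolset]],
) -> Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]:
    """Sketch: a Toolset serializes to a single dict via its own to_dict(),
    a list of Tools serializes item by item, and None stays None."""
    if tools is None:
        return None
    if isinstance(tools, Toolset):
        return tools.to_dict()
    return [tool.to_dict() for tool in tools]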

releasenotes/notes/… (new file)

@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Add support for initializing chat generators with a Toolset, allowing for more flexible tool management. The tools parameter can now accept either a list of Tool objects or a Toolset instance.
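To make the release note concrete, a sketch of the serialization round trip the tests below exercise (not part of the diff): the model name is a placeholder, and constructing the generator validates it against the Hugging Face Hub, so network access, or mocking as in the tests, is assumed.

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.tools.tool import Tool
from haystack.tools.toolset import Toolset
from haystack.utils.hf import HFGenerationAPIType

def get_weather(city: str) -> str:
    """Get weather information for a city."""
    return f"Weather info for {city}"

toolset = Toolset(
    [
        Tool(
            name="weather",
            description="useful to determine the weather in a given location",
            parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
            function=get_weather,
        )
    ]
)

# "HuggingFaceH4/zephyr-7b-beta" is a placeholder model name.
generator = HuggingFaceAPIChatGenerator(
    api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
    tools=toolset,
)

# The Toolset serializes as a single dict and deserializes back into a Toolset.
restored = HuggingFaceAPIChatGenerator.from_dict(generator.to_dict())
assert isinstance(restored.tools, Toolset)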

test/components/generators/chat/test_azure.py

@@ -11,10 +11,16 @@ from haystack.components.generators.chat import AzureOpenAIChatGenerator
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import ChatMessage, ToolCall
 from haystack.tools.tool import Tool
+from haystack.tools.toolset import Toolset
 from haystack.utils.auth import Secret
 from haystack.utils.azure import default_azure_ad_token_provider

+
+def get_weather(city: str) -> str:
+    """Get weather information for a city."""
+    return f"Weather info for {city}"
+
 @pytest.fixture
 def tools():
     tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
@@ -22,7 +28,7 @@ def tools():
         name="weather",
         description="useful to determine the weather in a given location",
         parameters=tool_parameters,
-        function=lambda x: x,
+        function=get_weather,
     )
     return [tool]
@@ -228,6 +234,26 @@ class TestAzureOpenAIChatGenerator:
         q = Pipeline.loads(p_str)
         assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization w/ AzureOpenAIChatGenerator failed."

+    def test_azure_chat_generator_with_toolset_initialization(self, tools, monkeypatch):
+        """Test that the AzureOpenAIChatGenerator can be initialized with a Toolset."""
+        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
+        toolset = Toolset(tools)
+        generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset)
+        assert generator.tools == toolset
+
+    def test_from_dict_with_toolset(self, tools, monkeypatch):
+        """Test that the AzureOpenAIChatGenerator can be deserialized from a dictionary with a Toolset."""
+        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
+        toolset = Toolset(tools)
+        component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset)
+        data = component.to_dict()
+        deserialized_component = AzureOpenAIChatGenerator.from_dict(data)
+        assert isinstance(deserialized_component.tools, Toolset)
+        assert len(deserialized_component.tools) == len(tools)
+        assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
+
     @pytest.mark.integration
     @pytest.mark.skipif(
         not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
@@ -272,7 +298,37 @@ class TestAzureOpenAIChatGenerator:
         assert tool_call.arguments == {"city": "Paris"}
         assert message.meta["finish_reason"] == "tool_calls"

-    # additional tests intentionally omitted as they are covered by test_openai.py
+    def test_to_dict_with_toolset(self, tools, monkeypatch):
+        """Test that the AzureOpenAIChatGenerator can be serialized to a dictionary with a Toolset."""
+        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
+        toolset = Toolset(tools)
+        component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset)
+        data = component.to_dict()
+        expected_tools_data = {
+            "type": "haystack.tools.toolset.Toolset",
+            "data": {
+                "tools": [
+                    {
+                        "type": "haystack.tools.tool.Tool",
+                        "data": {
+                            "name": "weather",
+                            "description": "useful to determine the weather in a given location",
+                            "parameters": {
+                                "type": "object",
+                                "properties": {"city": {"type": "string"}},
+                                "required": ["city"],
+                            },
+                            "function": "generators.chat.test_azure.get_weather",
+                            "outputs_to_string": None,
+                            "inputs_from_state": None,
+                            "outputs_to_state": None,
+                        },
+                    }
+                ]
+            },
+        }
+        assert data["init_parameters"]["tools"] == expected_tools_data

 class TestAzureOpenAIChatGeneratorAsync:

test/components/generators/chat/test_hugging_face_api.py

@@ -26,6 +26,7 @@ from huggingface_hub.utils import RepositoryNotFoundError
 from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator
 from haystack.tools import Tool
 from haystack.dataclasses import ChatMessage, ToolCall
+from haystack.tools.toolset import Toolset

 @pytest.fixture
@@ -36,6 +37,11 @@ def chat_messages():
     ]

+
+def get_weather(city: str) -> str:
+    """Get weather information for a city."""
+    return f"Weather info for {city}"
+
 @pytest.fixture
 def tools():
     tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
@@ -43,7 +49,7 @@ def tools():
         name="weather",
         description="useful to determine the weather in a given location",
         parameters=tool_parameters,
-        function=lambda x: x,
+        function=get_weather,
     )
     return [tool]
@@ -851,3 +857,58 @@ class TestHuggingFaceAPIChatGenerator:
             assert "completion_tokens" in response["replies"][0].meta["usage"]
         finally:
             await generator._async_client.close()
+
+    def test_hugging_face_api_generator_with_toolset_initialization(self, mock_check_valid_model, tools):
+        """Test that the HuggingFaceAPIChatGenerator can be initialized with a Toolset."""
+        toolset = Toolset(tools)
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
+        )
+        assert generator.tools == toolset
+
+    def test_from_dict_with_toolset(self, mock_check_valid_model, tools):
+        """Test that the HuggingFaceAPIChatGenerator can be deserialized from a dictionary with a Toolset."""
+        toolset = Toolset(tools)
+        component = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
+        )
+        data = component.to_dict()
+        deserialized_component = HuggingFaceAPIChatGenerator.from_dict(data)
+        assert isinstance(deserialized_component.tools, Toolset)
+        assert len(deserialized_component.tools) == len(tools)
+        assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
+
+    def test_to_dict_with_toolset(self, mock_check_valid_model, tools):
+        """Test that the HuggingFaceAPIChatGenerator can be serialized to a dictionary with a Toolset."""
+        toolset = Toolset(tools)
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
+        )
+        data = generator.to_dict()
+        expected_tools_data = {
+            "type": "haystack.tools.toolset.Toolset",
+            "data": {
+                "tools": [
+                    {
+                        "type": "haystack.tools.tool.Tool",
+                        "data": {
+                            "name": "weather",
+                            "description": "useful to determine the weather in a given location",
+                            "parameters": {
+                                "type": "object",
+                                "properties": {"city": {"type": "string"}},
+                                "required": ["city"],
+                            },
+                            "function": "generators.chat.test_hugging_face_api.get_weather",
+                            "outputs_to_string": None,
+                            "inputs_from_state": None,
+                            "outputs_to_state": None,
+                        },
+                    }
+                ]
+            },
+        }
+        assert data["init_parameters"]["tools"] == expected_tools_data

test/components/generators/chat/test_hugging_face_local.py

@@ -16,6 +16,7 @@ from haystack.dataclasses.streaming_chunk import StreamingChunk
 from haystack.tools import Tool
 from haystack.utils import ComponentDevice
 from haystack.utils.auth import Secret
+from haystack.tools.toolset import Toolset

 # used to test serialization of streaming_callback
@@ -94,7 +95,7 @@ class TestHuggingFaceLocalChatGenerator:
         assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
         assert generator.streaming_callback == streaming_callback

-    def test_init_custom_token(self):
+    def test_init_custom_token(self, model_info_mock):
         generator = HuggingFaceLocalChatGenerator(
             model="mistralai/Mistral-7B-Instruct-v0.2",
             task="text2text-generation",
@@ -109,7 +110,7 @@ class TestHuggingFaceLocalChatGenerator:
             "device": "cpu",
         }

-    def test_init_custom_device(self):
+    def test_init_custom_device(self, model_info_mock):
         generator = HuggingFaceLocalChatGenerator(
             model="mistralai/Mistral-7B-Instruct-v0.2",
             task="text2text-generation",
@@ -124,7 +125,7 @@ class TestHuggingFaceLocalChatGenerator:
             "device": "cpu",
         }

-    def test_init_task_parameter(self):
+    def test_init_task_parameter(self, model_info_mock):
         generator = HuggingFaceLocalChatGenerator(
             task="text2text-generation", device=ComponentDevice.from_str("cpu"), token=None
         )
@@ -136,7 +137,7 @@ class TestHuggingFaceLocalChatGenerator:
             "device": "cpu",
         }

-    def test_init_task_in_huggingface_pipeline_kwargs(self):
+    def test_init_task_in_huggingface_pipeline_kwargs(self, model_info_mock):
         generator = HuggingFaceLocalChatGenerator(
             huggingface_pipeline_kwargs={"task": "text2text-generation"},
             device=ComponentDevice.from_str("cpu"),
@@ -553,3 +554,56 @@ class TestHuggingFaceLocalChatGenerator:
         del generator
         gc.collect()
         mock_shutdown.assert_called_once_with(wait=True)
+
+    def test_hugging_face_local_generator_with_toolset_initialization(
+        self, model_info_mock, mock_pipeline_tokenizer, tools
+    ):
+        """Test that the HuggingFaceLocalChatGenerator can be initialized with a Toolset."""
+        toolset = Toolset(tools)
+        generator = HuggingFaceLocalChatGenerator(model="irrelevant", tools=toolset)
+        generator.pipeline = mock_pipeline_tokenizer
+        assert generator.tools == toolset
+
+    def test_from_dict_with_toolset(self, model_info_mock, tools):
+        """Test that the HuggingFaceLocalChatGenerator can be deserialized from a dictionary with a Toolset."""
+        toolset = Toolset(tools)
+        component = HuggingFaceLocalChatGenerator(model="irrelevant", tools=toolset)
+        data = component.to_dict()
+        deserialized_component = HuggingFaceLocalChatGenerator.from_dict(data)
+        assert isinstance(deserialized_component.tools, Toolset)
+        assert len(deserialized_component.tools) == len(tools)
+        assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
+
+    def test_to_dict_with_toolset(self, model_info_mock, mock_pipeline_tokenizer, tools):
+        """Test that the HuggingFaceLocalChatGenerator can be serialized to a dictionary with a Toolset."""
+        toolset = Toolset(tools)
+        generator = HuggingFaceLocalChatGenerator(huggingface_pipeline_kwargs={"model": "irrelevant"}, tools=toolset)
+        generator.pipeline = mock_pipeline_tokenizer
+        data = generator.to_dict()
+        expected_tools_data = {
+            "type": "haystack.tools.toolset.Toolset",
+            "data": {
+                "tools": [
+                    {
+                        "type": "haystack.tools.tool.Tool",
+                        "data": {
+                            "name": "weather",
+                            "description": "useful to determine the weather in a given location",
+                            "parameters": {
+                                "type": "object",
+                                "properties": {"city": {"type": "string"}},
+                                "required": ["city"],
+                            },
+                            "function": "generators.chat.test_hugging_face_local.get_weather",
+                            "outputs_to_string": None,
+                            "inputs_from_state": None,
+                            "outputs_to_state": None,
+                        },
+                    }
+                ]
+            },
+        }
+        assert data["init_parameters"]["tools"] == expected_tools_data