feat: Add Toolset support in ChatGenerator(s) (#9177)

* Add Toolset support in ChatGenerator(s)

* Add reno note

* Update azure test

* Updates

* Minor fix

* Add more tests

* Remove some integration tests

* PR feedback

* rm unused fixture

---------

Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
Authored by Vladimir Blagojevic on 2025-04-07 16:12:09 +02:00; committed by GitHub.
Commit e1e797206d, parent 63781afd8f.
7 changed files with 233 additions and 36 deletions
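In short: the `tools` init parameter of these chat generators now accepts a `Toolset` anywhere a `List[Tool]` was accepted before. A minimal usage sketch of the new behavior — the `weather` tool mirrors the fixtures added in the tests below; the endpoint is illustrative and `AZURE_OPENAI_API_KEY` is assumed to be set in the environment:

from haystack.components.generators.chat import AzureOpenAIChatGenerator
from haystack.tools import Tool
from haystack.tools.toolset import Toolset

def get_weather(city: str) -> str:
    """Get weather information for a city."""
    return f"Weather info for {city}"

weather_tool = Tool(
    name="weather",
    description="useful to determine the weather in a given location",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=get_weather,
)

# Previously only tools=[weather_tool] was accepted; a Toolset now works too.
generator = AzureOpenAIChatGenerator(
    azure_endpoint="https://example-resource.openai.azure.com",  # illustrative endpoint
    tools=Toolset([weather_tool]),
)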

File: haystack/components/generators/chat/azure.py

@@ -11,7 +11,9 @@ from haystack import component, default_from_dict, default_to_dict
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.tools.toolset import Toolset
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.misc import serialize_tools_or_toolset
@component
@@ -73,7 +75,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
max_retries: Optional[int] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
default_headers: Optional[Dict[str, str]] = None,
tools: Optional[List[Tool]] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
tools_strict: bool = False,
*,
azure_ad_token_provider: Optional[Union[AzureADTokenProvider, AsyncAzureADTokenProvider]] = None,
@@ -115,7 +117,8 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
values are the bias to add to that token.
:param default_headers: Default headers to use for the AzureOpenAI client.
:param tools:
A list of tools for which the model can prepare calls.
A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
list of `Tool` objects or a `Toolset` instance.
:param tools_strict:
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
@@ -152,7 +155,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
self.default_headers = default_headers or {}
self.azure_ad_token_provider = azure_ad_token_provider
_check_duplicate_tool_names(tools)
_check_duplicate_tool_names(list(tools or []))
self.tools = tools
self.tools_strict = tools_strict
@@ -196,7 +199,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
api_key=self.api_key.to_dict() if self.api_key is not None else None,
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
default_headers=self.default_headers,
tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
tools=serialize_tools_or_toolset(self.tools),
tools_strict=self.tools_strict,
azure_ad_token_provider=azure_ad_token_provider_name,
)
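Note the shared helper above: instead of the per-class `[tool.to_dict() for tool in self.tools] if self.tools else None`, serialization now routes through `serialize_tools_or_toolset` from `haystack.utils.misc`. A plausible sketch of what such a helper does, inferred from the serialized shapes asserted in the tests below (the actual implementation may differ):

from typing import Any, Dict, List, Optional, Union

from haystack.tools import Tool
from haystack.tools.toolset import Toolset

def serialize_tools_or_toolset_sketch(
    tools: Optional[Union[List[Tool], Toolset]],
) -> Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]:
    if not tools:
        return None
    if isinstance(tools, Toolset):
        # a Toolset serializes as one envelope: {"type": "...Toolset", "data": {"tools": [...]}}
        return tools.to_dict()
    # a plain list of Tool objects serializes tool by tool, as before
    return [tool.to_dict() for tool in tools]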

File: haystack/components/generators/chat/hugging_face_api.py

@@ -10,8 +10,10 @@ from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
from haystack.lazy_imports import LazyImport
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.tools.toolset import Toolset
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
from haystack.utils.misc import serialize_tools_or_toolset
from haystack.utils.url_validation import is_valid_http_url
with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
@@ -103,7 +105,7 @@ class HuggingFaceAPIChatGenerator:
generation_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
tools: Optional[List[Tool]] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
):
"""
Initialize the HuggingFaceAPIChatGenerator instance.
@@ -130,10 +132,10 @@ class HuggingFaceAPIChatGenerator:
:param streaming_callback:
An optional callable for handling streaming responses.
:param tools:
A list of tools for which the model can prepare calls.
A list of tools or a Toolset for which the model can prepare calls.
The chosen model should support tool/function calling, according to the model card.
Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience
unexpected behavior.
unexpected behavior. This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
"""
huggingface_hub_import.check()
@@ -166,7 +168,7 @@ class HuggingFaceAPIChatGenerator:
if tools and streaming_callback is not None:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)
_check_duplicate_tool_names(list(tools or []))
# handle generation kwargs setup
generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
@@ -191,7 +193,6 @@ class HuggingFaceAPIChatGenerator:
A dictionary containing the serialized component.
"""
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
return default_to_dict(
self,
api_type=str(self.api_type),
@@ -199,7 +200,7 @@ class HuggingFaceAPIChatGenerator:
token=self.token.to_dict() if self.token else None,
generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
tools=serialized_tools,
tools=serialize_tools_or_toolset(self.tools),
)
@classmethod
@@ -220,7 +221,7 @@ class HuggingFaceAPIChatGenerator:
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
tools: Optional[List[Tool]] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
):
"""
@@ -231,8 +232,9 @@ class HuggingFaceAPIChatGenerator:
:param generation_kwargs:
Additional keyword arguments for text generation.
:param tools:
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
during component initialization.
A list of tools or a Toolset for which the model can prepare calls. If set, it will override
the `tools` parameter set during component initialization. This parameter can accept either a
list of `Tool` objects or a `Toolset` instance.
:param streaming_callback:
An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
parameter set during component initialization.
@@ -248,7 +250,7 @@ class HuggingFaceAPIChatGenerator:
tools = tools or self.tools
if tools and self.streaming_callback:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)
_check_duplicate_tool_names(list(tools or []))
# validate and select the streaming callback
streaming_callback = select_streaming_callback(
@@ -260,6 +262,8 @@ class HuggingFaceAPIChatGenerator:
hf_tools = None
if tools:
if isinstance(tools, Toolset):
tools = list(tools)
hf_tools = [
ChatCompletionInputTool(
function=ChatCompletionInputFunctionDefinition(
@@ -276,7 +280,7 @@ class HuggingFaceAPIChatGenerator:
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
tools: Optional[List[Tool]] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
):
"""
@@ -290,8 +294,9 @@ class HuggingFaceAPIChatGenerator:
:param generation_kwargs:
Additional keyword arguments for text generation.
:param tools:
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
during component initialization.
A list of tools or a Toolset for which the model can prepare calls. If set, it will override the `tools`
parameter set during component initialization. This parameter can accept either a list of `Tool` objects
or a `Toolset` instance.
:param streaming_callback:
An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
parameter set during component initialization.
@@ -307,7 +312,7 @@ class HuggingFaceAPIChatGenerator:
tools = tools or self.tools
if tools and self.streaming_callback:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)
_check_duplicate_tool_names(list(tools or []))
# validate and select the streaming callback
streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)
@@ -317,6 +322,8 @@ class HuggingFaceAPIChatGenerator:
hf_tools = None
if tools:
if isinstance(tools, Toolset):
tools = list(tools)
hf_tools = [
ChatCompletionInputTool(
function=ChatCompletionInputFunctionDefinition(

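Both the sync and async `run` methods above normalize a `Toolset` to a plain list (`tools = list(tools)`) before building the `ChatCompletionInputTool` payload, so a `Toolset` can also be passed per call to override whatever was set at init. A hedged usage sketch (`weather_tool` as in the first example; the model name is illustrative):

from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.tools.toolset import Toolset
from haystack.utils.hf import HFGenerationAPIType

generator = HuggingFaceAPIChatGenerator(
    api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
    api_params={"model": "meta-llama/Llama-3.1-8B-Instruct"},  # illustrative model
)

# Per-call override: this Toolset takes precedence over any tools set at init.
result = generator.run(
    messages=[ChatMessage.from_user("What is the weather in Paris?")],
    tools=Toolset([weather_tool]),
)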
File: haystack/components/generators/chat/hugging_face_local.py

@@ -13,6 +13,7 @@ from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
from haystack.lazy_imports import LazyImport
from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.tools.toolset import Toolset
from haystack.utils import (
ComponentDevice,
Secret,
@@ -20,6 +21,7 @@ from haystack.utils import (
deserialize_secrets_inplace,
serialize_callable,
)
from haystack.utils.misc import serialize_tools_or_toolset
logger = logging.getLogger(__name__)
@@ -123,7 +125,7 @@ class HuggingFaceLocalChatGenerator:
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
tools: Optional[List[Tool]] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
async_executor: Optional[ThreadPoolExecutor] = None,
):
@@ -164,7 +166,8 @@ class HuggingFaceLocalChatGenerator:
For some chat models, the output includes both the new text and the original prompt.
In these cases, make sure your prompt has no stop words.
:param streaming_callback: An optional callable for handling streaming responses.
:param tools: A list of tools for which the model can prepare calls.
:param tools: A list of tools or a Toolset for which the model can prepare calls.
This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
:param tool_parsing_function:
A callable that takes a string and returns a list of ToolCall objects or None.
If None, the default_tool_parser will be used which extracts tool calls using a predefined pattern.
@@ -176,7 +179,7 @@ class HuggingFaceLocalChatGenerator:
if tools and streaming_callback is not None:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)
_check_duplicate_tool_names(list(tools or []))
huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
generation_kwargs = generation_kwargs or {}
@@ -273,7 +276,6 @@ class HuggingFaceLocalChatGenerator:
Dictionary with serialized data.
"""
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
serialization_dict = default_to_dict(
self,
huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
@@ -281,7 +283,7 @@ class HuggingFaceLocalChatGenerator:
streaming_callback=callback_name,
token=self.token.to_dict() if self.token else None,
chat_template=self.chat_template,
tools=serialized_tools,
tools=serialize_tools_or_toolset(self.tools),
tool_parsing_function=serialize_callable(self.tool_parsing_function),
)
@@ -323,7 +325,7 @@ class HuggingFaceLocalChatGenerator:
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
tools: Optional[List[Tool]] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
):
"""
Invoke text generation inference based on the provided messages and generation parameters.
@@ -332,8 +334,9 @@ class HuggingFaceLocalChatGenerator:
:param generation_kwargs: Additional keyword arguments for text generation.
:param streaming_callback: An optional callable for handling streaming responses.
:param tools:
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter
provided during initialization.
A list of tools or a Toolset for which the model can prepare calls. If set, it will override
the `tools` parameter provided during initialization. This parameter can accept either a list
of `Tool` objects or a `Toolset` instance.
:returns:
A list containing the generated responses as ChatMessage instances.
"""
@@ -377,6 +380,10 @@ class HuggingFaceLocalChatGenerator:
# convert messages to HF format
hf_messages = [convert_message_to_hf_format(message) for message in messages]
if isinstance(tools, Toolset):
tools = list(tools)
prepared_prompt = tokenizer.apply_chat_template(
hf_messages,
tokenize=False,
@@ -480,7 +487,7 @@ class HuggingFaceLocalChatGenerator:
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
tools: Optional[List[Tool]] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
):
"""
Asynchronously invokes text generation inference based on the provided messages and generation parameters.
@@ -491,7 +498,8 @@ class HuggingFaceLocalChatGenerator:
:param messages: A list of ChatMessage objects representing the input messages.
:param generation_kwargs: Additional keyword arguments for text generation.
:param streaming_callback: An optional callable for handling streaming responses.
:param tools: A list of tools for which the model can prepare calls.
:param tools: A list of tools or a Toolset for which the model can prepare calls.
This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
:returns: A dictionary with the following keys:
- `replies`: A list containing the generated responses as ChatMessage instances.
"""
@@ -576,13 +584,17 @@ class HuggingFaceLocalChatGenerator:
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
generation_kwargs: Dict[str, Any],
stop_words: Optional[List[str]],
tools: Optional[List[Tool]] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
):
"""
Handles async non-streaming generation of responses.
"""
# convert messages to HF format
hf_messages = [convert_message_to_hf_format(message) for message in messages]
if isinstance(tools, Toolset):
tools = list(tools)
prepared_prompt = tokenizer.apply_chat_template(
hf_messages,
tokenize=False,

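The `isinstance(tools, Toolset)` normalization in these hunks relies only on `Toolset` being an iterable, sized container of `Tool` objects. A hypothetical stand-in (not the real haystack class) capturing just the surface this diff and its tests depend on:

from dataclasses import dataclass, field
from typing import Iterator, List

from haystack.tools import Tool

@dataclass
class ToolsetSurface:
    """Hypothetical stand-in for the Toolset behavior this diff depends on."""

    tools: List[Tool] = field(default_factory=list)

    def __iter__(self) -> Iterator[Tool]:
        # makes list(toolset) and iteration work, as used in the hunks above
        return iter(self.tools)

    def __len__(self) -> int:
        # makes len(component.tools) work, as used in the tests below
        return len(self.tools)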
File: releasenotes/notes/<new-release-note>.yaml

@@ -0,0 +1,4 @@
---
features:
- |
Add support for initializing chat generators with a Toolset, allowing for more flexible tool management. The tools parameter can now accept either a list of Tool objects or a Toolset instance.
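Serialization round trips accordingly: a generator configured with a `Toolset` deserializes back to a `Toolset`, as the new tests below verify. A condensed sketch of the same round trip (`weather_tool` as in the first example; `AZURE_OPENAI_API_KEY` assumed set):

from haystack.components.generators.chat import AzureOpenAIChatGenerator
from haystack.tools.toolset import Toolset

component = AzureOpenAIChatGenerator(
    azure_endpoint="https://example-resource.openai.azure.com",  # illustrative
    tools=Toolset([weather_tool]),
)
restored = AzureOpenAIChatGenerator.from_dict(component.to_dict())

assert isinstance(restored.tools, Toolset)
assert len(restored.tools) == 1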

File: test/components/generators/chat/test_azure.py

@@ -11,10 +11,16 @@ from haystack.components.generators.chat import AzureOpenAIChatGenerator
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools.tool import Tool
from haystack.tools.toolset import Toolset
from haystack.utils.auth import Secret
from haystack.utils.azure import default_azure_ad_token_provider
def get_weather(city: str) -> str:
"""Get weather information for a city."""
return f"Weather info for {city}"
@pytest.fixture
def tools():
tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
@@ -22,7 +28,7 @@ def tools():
name="weather",
description="useful to determine the weather in a given location",
parameters=tool_parameters,
function=lambda x: x,
function=get_weather,
)
return [tool]
@@ -228,6 +234,26 @@ class TestAzureOpenAIChatGenerator:
q = Pipeline.loads(p_str)
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization w/ AzureOpenAIChatGenerator failed."
def test_azure_chat_generator_with_toolset_initialization(self, tools, monkeypatch):
"""Test that the AzureOpenAIChatGenerator can be initialized with a Toolset."""
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
toolset = Toolset(tools)
generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset)
assert generator.tools == toolset
def test_from_dict_with_toolset(self, tools, monkeypatch):
"""Test that the AzureOpenAIChatGenerator can be deserialized from a dictionary with a Toolset."""
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
toolset = Toolset(tools)
component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset)
data = component.to_dict()
deserialized_component = AzureOpenAIChatGenerator.from_dict(data)
assert isinstance(deserialized_component.tools, Toolset)
assert len(deserialized_component.tools) == len(tools)
assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
@@ -272,7 +298,37 @@ class TestAzureOpenAIChatGenerator:
assert tool_call.arguments == {"city": "Paris"}
assert message.meta["finish_reason"] == "tool_calls"
# additional tests intentionally omitted as they are covered by test_openai.py
def test_to_dict_with_toolset(self, tools, monkeypatch):
"""Test that the AzureOpenAIChatGenerator can be serialized to a dictionary with a Toolset."""
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
toolset = Toolset(tools)
component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=toolset)
data = component.to_dict()
expected_tools_data = {
"type": "haystack.tools.toolset.Toolset",
"data": {
"tools": [
{
"type": "haystack.tools.tool.Tool",
"data": {
"name": "weather",
"description": "useful to determine the weather in a given location",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
"function": "generators.chat.test_azure.get_weather",
"outputs_to_string": None,
"inputs_from_state": None,
"outputs_to_state": None,
},
}
]
},
}
assert data["init_parameters"]["tools"] == expected_tools_data
class TestAzureOpenAIChatGeneratorAsync:

File: test/components/generators/chat/test_hugging_face_api.py

@@ -26,6 +26,7 @@ from huggingface_hub.utils import RepositoryNotFoundError
from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator
from haystack.tools import Tool
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools.toolset import Toolset
@pytest.fixture
@@ -36,6 +37,11 @@ def chat_messages():
]
def get_weather(city: str) -> str:
"""Get weather information for a city."""
return f"Weather info for {city}"
@pytest.fixture
def tools():
tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
@@ -43,7 +49,7 @@ def tools():
name="weather",
description="useful to determine the weather in a given location",
parameters=tool_parameters,
function=lambda x: x,
function=get_weather,
)
return [tool]
@@ -851,3 +857,58 @@ class TestHuggingFaceAPIChatGenerator:
assert "completion_tokens" in response["replies"][0].meta["usage"]
finally:
await generator._async_client.close()
def test_hugging_face_api_generator_with_toolset_initialization(self, mock_check_valid_model, tools):
"""Test that the HuggingFaceAPIChatGenerator can be initialized with a Toolset."""
toolset = Toolset(tools)
generator = HuggingFaceAPIChatGenerator(
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
)
assert generator.tools == toolset
def test_from_dict_with_toolset(self, mock_check_valid_model, tools):
"""Test that the HuggingFaceAPIChatGenerator can be deserialized from a dictionary with a Toolset."""
toolset = Toolset(tools)
component = HuggingFaceAPIChatGenerator(
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
)
data = component.to_dict()
deserialized_component = HuggingFaceAPIChatGenerator.from_dict(data)
assert isinstance(deserialized_component.tools, Toolset)
assert len(deserialized_component.tools) == len(tools)
assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
def test_to_dict_with_toolset(self, mock_check_valid_model, tools):
"""Test that the HuggingFaceAPIChatGenerator can be serialized to a dictionary with a Toolset."""
toolset = Toolset(tools)
generator = HuggingFaceAPIChatGenerator(
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "irrelevant"}, tools=toolset
)
data = generator.to_dict()
expected_tools_data = {
"type": "haystack.tools.toolset.Toolset",
"data": {
"tools": [
{
"type": "haystack.tools.tool.Tool",
"data": {
"name": "weather",
"description": "useful to determine the weather in a given location",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
"function": "generators.chat.test_hugging_face_api.get_weather",
"outputs_to_string": None,
"inputs_from_state": None,
"outputs_to_state": None,
},
}
]
},
}
assert data["init_parameters"]["tools"] == expected_tools_data

File: test/components/generators/chat/test_hugging_face_local.py

@@ -16,6 +16,7 @@ from haystack.dataclasses.streaming_chunk import StreamingChunk
from haystack.tools import Tool
from haystack.utils import ComponentDevice
from haystack.utils.auth import Secret
from haystack.tools.toolset import Toolset
# used to test serialization of streaming_callback
@@ -94,7 +95,7 @@ class TestHuggingFaceLocalChatGenerator:
assert generator.generation_kwargs == {**generation_kwargs, **{"stop_sequences": ["stop"]}}
assert generator.streaming_callback == streaming_callback
def test_init_custom_token(self):
def test_init_custom_token(self, model_info_mock):
generator = HuggingFaceLocalChatGenerator(
model="mistralai/Mistral-7B-Instruct-v0.2",
task="text2text-generation",
@@ -109,7 +110,7 @@ class TestHuggingFaceLocalChatGenerator:
"device": "cpu",
}
def test_init_custom_device(self):
def test_init_custom_device(self, model_info_mock):
generator = HuggingFaceLocalChatGenerator(
model="mistralai/Mistral-7B-Instruct-v0.2",
task="text2text-generation",
@@ -124,7 +125,7 @@ class TestHuggingFaceLocalChatGenerator:
"device": "cpu",
}
def test_init_task_parameter(self):
def test_init_task_parameter(self, model_info_mock):
generator = HuggingFaceLocalChatGenerator(
task="text2text-generation", device=ComponentDevice.from_str("cpu"), token=None
)
@@ -136,7 +137,7 @@ class TestHuggingFaceLocalChatGenerator:
"device": "cpu",
}
def test_init_task_in_huggingface_pipeline_kwargs(self):
def test_init_task_in_huggingface_pipeline_kwargs(self, model_info_mock):
generator = HuggingFaceLocalChatGenerator(
huggingface_pipeline_kwargs={"task": "text2text-generation"},
device=ComponentDevice.from_str("cpu"),
@@ -553,3 +554,56 @@ class TestHuggingFaceLocalChatGenerator:
del generator
gc.collect()
mock_shutdown.assert_called_once_with(wait=True)
def test_hugging_face_local_generator_with_toolset_initialization(
self, model_info_mock, mock_pipeline_tokenizer, tools
):
"""Test that the HuggingFaceLocalChatGenerator can be initialized with a Toolset."""
toolset = Toolset(tools)
generator = HuggingFaceLocalChatGenerator(model="irrelevant", tools=toolset)
generator.pipeline = mock_pipeline_tokenizer
assert generator.tools == toolset
def test_from_dict_with_toolset(self, model_info_mock, tools):
"""Test that the HuggingFaceLocalChatGenerator can be deserialized from a dictionary with a Toolset."""
toolset = Toolset(tools)
component = HuggingFaceLocalChatGenerator(model="irrelevant", tools=toolset)
data = component.to_dict()
deserialized_component = HuggingFaceLocalChatGenerator.from_dict(data)
assert isinstance(deserialized_component.tools, Toolset)
assert len(deserialized_component.tools) == len(tools)
assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
def test_to_dict_with_toolset(self, model_info_mock, mock_pipeline_tokenizer, tools):
"""Test that the HuggingFaceLocalChatGenerator can be serialized to a dictionary with a Toolset."""
toolset = Toolset(tools)
generator = HuggingFaceLocalChatGenerator(huggingface_pipeline_kwargs={"model": "irrelevant"}, tools=toolset)
generator.pipeline = mock_pipeline_tokenizer
data = generator.to_dict()
expected_tools_data = {
"type": "haystack.tools.toolset.Toolset",
"data": {
"tools": [
{
"type": "haystack.tools.tool.Tool",
"data": {
"name": "weather",
"description": "useful to determine the weather in a given location",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
"function": "generators.chat.test_hugging_face_local.get_weather",
"outputs_to_string": None,
"inputs_from_state": None,
"outputs_to_state": None,
},
}
]
},
}
assert data["init_parameters"]["tools"] == expected_tools_data