feat: Add Toolset to tooling architecture (#9161)

* Add Toolset abstraction

* Add reno note

* More pydoc improvements

* Update test

* Simplify, Toolset is a dataclass

* Wrap toolset instance with list

* Add example

* Toolset pydoc serde enhancement

* Toolset as init param

* Fix types

* Linting

* Minor updates

* PR feedback

* Add to pydoc config, minor import fixes

* Improve pydoc example

* Improve coverage for test_toolset.py

* Improve test_toolset.py, test custom toolset serde properly

* Update haystack/utils/misc.py

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>

* Rework Toolset pydoc

* Another minor pydoc improvement

* Prevent single Tool instantiating Toolset

* Reduce number of integration tests

* Remove some toolset tests from openai

* Rework tests

---------

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
This commit is contained in:
Vladimir Blagojevic 2025-04-04 10:09:46 -04:00 committed by GitHub
parent a2f73d134d
commit c81d68402c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 1046 additions and 23 deletions

View File

@ -2,7 +2,7 @@ loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../haystack/tools]
modules:
["tool", "from_function", "component_tool"]
["tool", "from_function", "component_tool", "toolset"]
ignore_when_discovered: ["__init__"]
processors:
- type: filter

View File

@ -23,7 +23,9 @@ from haystack.dataclasses import (
select_streaming_callback,
)
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.tools.toolset import Toolset
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.misc import serialize_tools_or_toolset
logger = logging.getLogger(__name__)
@ -81,7 +83,7 @@ class OpenAIChatGenerator:
generation_kwargs: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
tools: Optional[List[Tool]] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
tools_strict: bool = False,
):
"""
@ -127,7 +129,8 @@ class OpenAIChatGenerator:
Maximum number of retries to contact OpenAI after an internal error.
If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or set to 5.
:param tools:
A list of tools for which the model can prepare calls.
A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
list of `Tool` objects or a `Toolset` instance.
:param tools_strict:
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
@ -140,10 +143,11 @@ class OpenAIChatGenerator:
self.organization = organization
self.timeout = timeout
self.max_retries = max_retries
self.tools = tools
self.tools = tools # Store tools as-is, whether it's a list or a Toolset
self.tools_strict = tools_strict
_check_duplicate_tool_names(tools)
# Check for duplicate tool names
_check_duplicate_tool_names(list(self.tools or []))
if timeout is None:
timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
@ -185,7 +189,7 @@ class OpenAIChatGenerator:
api_key=self.api_key.to_dict(),
timeout=self.timeout,
max_retries=self.max_retries,
tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
tools=serialize_tools_or_toolset(self.tools),
tools_strict=self.tools_strict,
)
@ -213,7 +217,7 @@ class OpenAIChatGenerator:
streaming_callback: Optional[StreamingCallbackT] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
*,
tools: Optional[List[Tool]] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
tools_strict: Optional[bool] = None,
):
"""
@ -228,8 +232,9 @@ class OpenAIChatGenerator:
override the parameters passed during component initialization.
For details on OpenAI API parameters, see [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create).
:param tools:
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
during component initialization.
A list of tools or a Toolset for which the model can prepare calls. If set, it will override the
`tools` parameter set during component initialization. This parameter can accept either a list of
`Tool` objects or a `Toolset` instance.
:param tools_strict:
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
@ -285,7 +290,7 @@ class OpenAIChatGenerator:
streaming_callback: Optional[StreamingCallbackT] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
*,
tools: Optional[List[Tool]] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
tools_strict: Optional[bool] = None,
):
"""
@ -304,8 +309,9 @@ class OpenAIChatGenerator:
override the parameters passed during component initialization.
For details on OpenAI API parameters, see [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create).
:param tools:
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
during component initialization.
A list of tools or a Toolset for which the model can prepare calls. If set, it will override the
`tools` parameter set during component initialization. This parameter can accept either a list of
`Tool` objects or a `Toolset` instance.
:param tools_strict:
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
@ -362,7 +368,7 @@ class OpenAIChatGenerator:
messages: List[ChatMessage],
streaming_callback: Optional[StreamingCallbackT] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
tools: Optional[List[Tool]] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
tools_strict: Optional[bool] = None,
) -> Dict[str, Any]:
# update generation kwargs by merging with the generation kwargs passed to the run method
@ -372,6 +378,8 @@ class OpenAIChatGenerator:
openai_formatted_messages = [message.to_openai_dict_format() for message in messages]
tools = tools or self.tools
if isinstance(tools, Toolset):
tools = list(tools)
tools_strict = tools_strict if tools_strict is not None else self.tools_strict
_check_duplicate_tool_names(tools)

View File

@ -4,13 +4,15 @@
import inspect
import json
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.core.component.sockets import Sockets
from haystack.dataclasses import ChatMessage, State, ToolCall
from haystack.tools.component_tool import ComponentTool
from haystack.tools.tool import Tool, ToolInvocationError, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.tools.toolset import Toolset
from haystack.utils.misc import serialize_tools_or_toolset
logger = logging.getLogger(__name__)
@ -57,7 +59,8 @@ class ToolInvoker:
Usage example:
```python
from haystack.dataclasses import ChatMessage, ToolCall, Tool
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools import Tool
from haystack.components.tools import ToolInvoker
# Tool definition
@ -108,14 +111,55 @@ class ToolInvoker:
>> ]
>> }
```
Usage example with a Toolset:
```python
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools import Tool, Toolset
from haystack.components.tools import ToolInvoker
# Tool definition
def dummy_weather_function(city: str):
return f"The weather in {city} is 20 degrees."
parameters = {"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"]}
tool = Tool(name="weather_tool",
description="A tool to get the weather",
function=dummy_weather_function,
parameters=parameters)
# Create a Toolset
toolset = Toolset([tool])
# Usually, the ChatMessage with tool_calls is generated by a Language Model
# Here, we create it manually for demonstration purposes
tool_call = ToolCall(
tool_name="weather_tool",
arguments={"city": "Berlin"}
)
message = ChatMessage.from_assistant(tool_calls=[tool_call])
# ToolInvoker initialization and run with Toolset
invoker = ToolInvoker(tools=toolset)
result = invoker.run(messages=[message])
print(result)
"""
def __init__(self, tools: List[Tool], raise_on_failure: bool = True, convert_result_to_json_string: bool = False):
def __init__(
self,
tools: Union[List[Tool], Toolset],
raise_on_failure: bool = True,
convert_result_to_json_string: bool = False,
):
"""
Initialize the ToolInvoker component.
:param tools:
A list of tools that can be invoked.
A list of tools that can be invoked or a Toolset instance that can resolve tools.
:param raise_on_failure:
If True, the component will raise an exception in case of errors
(tool not found, tool invocation errors, tool result conversion errors).
@ -129,13 +173,20 @@ class ToolInvoker:
"""
if not tools:
raise ValueError("ToolInvoker requires at least one tool.")
# could be a Toolset instance or a list of Tools
self.tools = tools
# Convert Toolset to list for internal use
if isinstance(tools, Toolset):
tools = list(tools)
_check_duplicate_tool_names(tools)
tool_names = [tool.name for tool in tools]
duplicates = {name for name in tool_names if tool_names.count(name) > 1}
if duplicates:
raise ValueError(f"Duplicate tool names found: {duplicates}")
self.tools = tools
self._tools_with_names = dict(zip(tool_names, tools))
self.raise_on_failure = raise_on_failure
self.convert_result_to_json_string = convert_result_to_json_string
@ -385,10 +436,9 @@ class ToolInvoker:
:returns:
Dictionary with serialized data.
"""
serialized_tools = [tool.to_dict() for tool in self.tools]
return default_to_dict(
self,
tools=serialized_tools,
tools=serialize_tools_or_toolset(self.tools),
raise_on_failure=self.raise_on_failure,
convert_result_to_json_string=self.convert_result_to_json_string,
)

View File

@ -8,6 +8,7 @@
from .from_function import create_tool_from_function, tool
from .tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from .component_tool import ComponentTool
from .toolset import Toolset
__all__ = [
@ -17,4 +18,5 @@ __all__ = [
"create_tool_from_function",
"tool",
"ComponentTool",
"Toolset",
]

View File

@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional
from jsonschema import Draft202012Validator
from jsonschema.exceptions import SchemaError
from haystack.core.errors import DeserializationError
from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
from haystack.tools.errors import ToolInvocationError
from haystack.utils import deserialize_callable, serialize_callable
@ -190,8 +191,23 @@ def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"):
if serialized_tools is None:
return
# Check if it's a serialized Toolset (a dict with "type" and "data" keys)
if isinstance(serialized_tools, dict) and all(k in serialized_tools for k in ["type", "data"]):
toolset_class_name = serialized_tools.get("type")
if not toolset_class_name:
raise DeserializationError("The 'type' key is missing or None in the serialized toolset data")
toolset_class = import_class_by_name(toolset_class_name)
from haystack.tools.toolset import Toolset # avoid circular import
if not issubclass(toolset_class, Toolset):
raise TypeError(f"Class '{toolset_class}' is not a subclass of Toolset")
data[key] = toolset_class.from_dict(serialized_tools)
return
if not isinstance(serialized_tools, list):
raise TypeError(f"The value of '{key}' is not a list")
raise TypeError(f"The value of '{key}' is not a list or a dictionary")
deserialized_tools = []
for tool in serialized_tools:

291
haystack/tools/toolset.py Normal file
View File

@ -0,0 +1,291 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Union
from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
from haystack.tools.tool import Tool, _check_duplicate_tool_names
@dataclass
class Toolset:
"""
A collection of related Tools that can be used and managed as a cohesive unit.
Toolset serves two main purposes:
1. Group related tools together:
Toolset allows you to organize related tools into a single collection, making it easier
to manage and use them as a unit in Haystack pipelines.
Example:
```python
from haystack.tools import Tool, Toolset
from haystack.components.tools import ToolInvoker
# Define math functions
def add_numbers(a: int, b: int) -> int:
return a + b
def subtract_numbers(a: int, b: int) -> int:
return a - b
# Create tools with proper schemas
add_tool = Tool(
name="add",
description="Add two numbers",
parameters={
"type": "object",
"properties": {
"a": {"type": "integer"},
"b": {"type": "integer"}
},
"required": ["a", "b"]
},
function=add_numbers
)
subtract_tool = Tool(
name="subtract",
description="Subtract b from a",
parameters={
"type": "object",
"properties": {
"a": {"type": "integer"},
"b": {"type": "integer"}
},
"required": ["a", "b"]
},
function=subtract_numbers
)
# Create a toolset with the math tools
math_toolset = Toolset([add_tool, subtract_tool])
# Use the toolset with a ToolInvoker or ChatGenerator component
invoker = ToolInvoker(tools=math_toolset)
```
2. Base class for dynamic tool loading:
By subclassing Toolset, you can create implementations that dynamically load tools
from external sources like OpenAPI URLs, MCP servers, or other resources.
Example:
```python
from haystack.core.serialization import generate_qualified_class_name
from haystack.tools import Tool, Toolset
from haystack.components.tools import ToolInvoker
class CalculatorToolset(Toolset):
'''A toolset for calculator operations.'''
def __init__(self):
tools = self._create_tools()
super().__init__(tools)
def _create_tools(self):
# These Tool instances are obviously defined statically and for illustration purposes only.
# In a real-world scenario, you would dynamically load tools from an external source here.
tools = []
add_tool = Tool(
name="add",
description="Add two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=lambda a, b: a + b,
)
multiply_tool = Tool(
name="multiply",
description="Multiply two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=lambda a, b: a * b,
)
tools.append(add_tool)
tools.append(multiply_tool)
return tools
def to_dict(self):
return {
"type": generate_qualified_class_name(type(self)),
"data": {}, # no data to serialize as we define the tools dynamically
}
@classmethod
def from_dict(cls, data):
return cls() # Recreate the tools dynamically during deserialization
# Create the dynamic toolset and use it with ToolInvoker
calculator_toolset = CalculatorToolset()
invoker = ToolInvoker(tools=calculator_toolset)
```
Toolset implements the collection interface (__iter__, __contains__, __len__, __getitem__),
making it behave like a list of Tools. This makes it compatible with components that expect
iterable tools, such as ToolInvoker or Haystack chat generators.
When implementing a custom Toolset subclass for dynamic tool loading:
- Perform the dynamic loading in the __init__ method
- Override to_dict() and from_dict() methods if your tools are defined dynamically
- Serialize endpoint descriptors rather than tool instances if your tools
are loaded from external sources
"""
# Use field() with default_factory to initialize the list
tools: List[Tool] = field(default_factory=list)
def __post_init__(self):
"""
Validate and set up the toolset after initialization.
This handles the case when tools are provided during initialization.
"""
# If initialization was done a single Tool, raise an error
if isinstance(self.tools, Tool):
raise TypeError("A single Tool cannot be directly passed to Toolset. Please use a list: Toolset([tool])")
# Check for duplicate tool names in the initial set
_check_duplicate_tool_names(self.tools)
def __iter__(self) -> Iterator[Tool]:
"""
Return an iterator over the Tools in this Toolset.
This allows the Toolset to be used wherever a list of Tools is expected.
:returns: An iterator yielding Tool instances
"""
return iter(self.tools)
def __contains__(self, item: Any) -> bool:
"""
Check if a tool is in this Toolset.
Supports checking by:
- Tool instance: tool in toolset
- Tool name: "tool_name" in toolset
:param item: Tool instance or tool name string
:returns: True if contained, False otherwise
"""
if isinstance(item, str):
return any(tool.name == item for tool in self.tools)
if isinstance(item, Tool):
return item in self.tools
return False
def add(self, tool: Union[Tool, "Toolset"]) -> None:
"""
Add a new Tool or merge another Toolset.
:param tool: A Tool instance or another Toolset to add
:raises ValueError: If adding the tool would result in duplicate tool names
:raises TypeError: If the provided object is not a Tool or Toolset
"""
new_tools = []
if isinstance(tool, Tool):
new_tools = [tool]
elif isinstance(tool, Toolset):
new_tools = list(tool)
else:
raise TypeError(f"Expected Tool or Toolset, got {type(tool).__name__}")
# Check for duplicates before adding
combined_tools = self.tools + new_tools
_check_duplicate_tool_names(combined_tools)
self.tools.extend(new_tools)
def to_dict(self) -> Dict[str, Any]:
"""
Serialize the Toolset to a dictionary.
:returns: A dictionary representation of the Toolset
Note for subclass implementers:
The default implementation is ideal for scenarios where Tool resolution is static. However, if your subclass
of Toolset dynamically resolves Tool instances from external sourcessuch as an MCP server, OpenAPI URL, or
a local OpenAPI specificationyou should consider serializing the endpoint descriptor instead of the Tool
instances themselves. This strategy preserves the dynamic nature of your Toolset and minimizes the overhead
associated with serializing potentially large collections of Tool objects. Moreover, by serializing the
descriptor, you ensure that the deserialization process can accurately reconstruct the Tool instances, even
if they have been modified or removed since the last serialization. Failing to serialize the descriptor may
lead to issues where outdated or incorrect Tool configurations are loaded, potentially causing errors or
unexpected behavior.
"""
return {
"type": generate_qualified_class_name(type(self)),
"data": {"tools": [tool.to_dict() for tool in self.tools]},
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Toolset":
"""
Deserialize a Toolset from a dictionary.
:param data: Dictionary representation of the Toolset
:returns: A new Toolset instance
"""
inner_data = data["data"]
tools_data = inner_data.get("tools", [])
tools = []
for tool_data in tools_data:
tool_class = import_class_by_name(tool_data["type"])
if not issubclass(tool_class, Tool):
raise TypeError(f"Class '{tool_class}' is not a subclass of Tool")
tools.append(tool_class.from_dict(tool_data))
return cls(tools=tools)
def __add__(self, other: Union[Tool, "Toolset", List[Tool]]) -> "Toolset":
"""
Concatenate this Toolset with another Tool, Toolset, or list of Tools.
:param other: Another Tool, Toolset, or list of Tools to concatenate
:returns: A new Toolset containing all tools
:raises TypeError: If the other parameter is not a Tool, Toolset, or list of Tools
:raises ValueError: If the combination would result in duplicate tool names
"""
if isinstance(other, Tool):
combined_tools = self.tools + [other]
elif isinstance(other, Toolset):
combined_tools = self.tools + list(other)
elif isinstance(other, list) and all(isinstance(item, Tool) for item in other):
combined_tools = self.tools + other
else:
raise TypeError(f"Cannot add {type(other).__name__} to Toolset")
# Check for duplicates
_check_duplicate_tool_names(combined_tools)
return Toolset(tools=combined_tools)
def __len__(self) -> int:
"""
Return the number of Tools in this Toolset.
:returns: Number of Tools
"""
return len(self.tools)
def __getitem__(self, index):
"""
Get a Tool by index.
:param index: Index of the Tool to get
:returns: The Tool at the specified index
"""
return self.tools[index]

View File

@ -2,10 +2,12 @@
#
# SPDX-License-Identifier: Apache-2.0
from typing import List, Union
from typing import Any, Dict, List, Union
from numpy import exp
from haystack.tools import Tool, Toolset
def expand_page_range(page_range: List[Union[str, int]]) -> List[int]:
"""
@ -52,3 +54,19 @@ def expit(x) -> float:
:param x: input value. Can be a scalar or a numpy array.
"""
return 1 / (1 + exp(-x))
def serialize_tools_or_toolset(
tools: Union[Toolset, List[Tool], None],
) -> Union[Dict[str, Any], List[Dict[str, Any]], None]:
"""
Serialize a Toolset or a list of Tools to a dictionary or a list of tool dictionaries.
:param tools: A Toolset, a list of Tools, or None
:returns: A dictionary, a list of tool dictionaries, or None if tools is None
"""
if tools is None:
return None
if isinstance(tools, Toolset):
return tools.to_dict()
return [tool.to_dict() for tool in tools]

View File

@ -0,0 +1,3 @@
features:
- |
Introduced the Toolset class, allowing for the grouping and management of related tool functionalities. This new abstraction supports dynamic tool loading and registration.

View File

@ -22,6 +22,7 @@ from haystack.utils.auth import Secret
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools import Tool
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.tools.toolset import Toolset
@pytest.fixture
@ -71,6 +72,10 @@ def mock_chat_completion_chunk_with_tools(openai_mock_stream):
yield mock_chat_completion_create
def mock_tool_function(x):
return x
@pytest.fixture
def tools():
tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
@ -78,7 +83,7 @@ def tools():
name="weather",
description="useful to determine the weather in a given location",
parameters=tool_parameters,
function=lambda x: x,
function=mock_tool_function,
)
return [tool]
@ -269,6 +274,7 @@ class TestOpenAIChatGenerator:
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"tools": None,
},
}
with pytest.raises(ValueError):
@ -937,3 +943,45 @@ class TestOpenAIChatGenerator:
assert tool_call.tool_name == "weather"
assert tool_call.arguments == {"city": "Paris"}
assert message.meta["finish_reason"] == "tool_calls"
def test_openai_chat_generator_with_toolset_initialization(self, tools, monkeypatch):
"""Test that the OpenAIChatGenerator can be initialized with a Toolset."""
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
toolset = Toolset(tools)
generator = OpenAIChatGenerator(tools=toolset)
assert generator.tools == toolset
def test_from_dict_with_toolset(self, tools, monkeypatch):
"""Test that the OpenAIChatGenerator can be deserialized from a dictionary with a Toolset."""
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
toolset = Toolset(tools)
component = OpenAIChatGenerator(tools=toolset)
data = component.to_dict()
deserialized_component = OpenAIChatGenerator.from_dict(data)
assert isinstance(deserialized_component.tools, Toolset)
assert len(deserialized_component.tools) == len(tools)
assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_live_run_with_toolset(self, tools):
chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
toolset = Toolset(tools)
component = OpenAIChatGenerator(tools=toolset)
results = component.run(chat_messages)
assert len(results["replies"]) == 1
message = results["replies"][0]
assert not message.texts
assert not message.text
assert message.tool_calls
tool_call = message.tool_call
assert isinstance(tool_call, ToolCall)
assert tool_call.tool_name == "weather"
assert tool_call.arguments == {"city": "Paris"}
assert message.meta["finish_reason"] == "tool_calls"

587
test/tools/test_toolset.py Normal file
View File

@ -0,0 +1,587 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
import os
from haystack import Pipeline
from haystack.core.serialization import generate_qualified_class_name
from haystack.dataclasses import ChatMessage
from haystack.components.tools import ToolInvoker
from haystack.dataclasses.chat_message import ToolCall
from haystack.tools import Tool, Toolset
from haystack.tools.errors import ToolInvocationError
# Common functions for tests
def add_numbers(a: int, b: int) -> int:
"""Add two numbers."""
return a + b
def multiply_numbers(a: int, b: int) -> int:
"""Multiply two numbers."""
return a * b
def subtract_numbers(a: int, b: int) -> int:
"""Subtract b from a."""
return a - b
class CustomToolset(Toolset):
def __init__(self, tools, custom_attr):
super().__init__(tools)
self.custom_attr = custom_attr
def to_dict(self):
data = super().to_dict()
data["custom_attr"] = self.custom_attr
return data
@classmethod
def from_dict(cls, data):
tools = [Tool.from_dict(tool_data) for tool_data in data["data"]["tools"]]
custom_attr = data["custom_attr"]
return cls(tools=tools, custom_attr=custom_attr)
class CalculatorToolset(Toolset):
"""A toolset for calculator operations."""
def __init__(self):
super().__init__([])
self._create_tools()
def _create_tools(self):
add_tool = Tool(
name="add",
description="Add two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=add_numbers,
)
multiply_tool = Tool(
name="multiply",
description="Multiply two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=multiply_numbers,
)
self.add(add_tool)
self.add(multiply_tool)
def to_dict(self):
return {
"type": generate_qualified_class_name(type(self)),
"data": {}, # no data to serialize as we define the tools dynamically
}
@classmethod
def from_dict(cls, data):
return cls()
def weather_function(location):
weather_info = {
"Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
"Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
"Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
}
return weather_info.get(location, {"weather": "unknown", "temperature": 0, "unit": "celsius"})
weather_parameters = {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}
@pytest.fixture
def weather_tool():
return Tool(
name="weather_tool",
description="Provides weather information for a given location.",
parameters=weather_parameters,
function=weather_function,
)
@pytest.fixture
def faulty_tool():
def faulty_tool_func(location):
raise Exception("This tool always fails.")
faulty_tool_parameters = {
"type": "object",
"properties": {"location": {"type": "string"}},
"required": ["location"],
}
return Tool(
name="faulty_tool",
description="A tool that always fails when invoked.",
parameters=faulty_tool_parameters,
function=faulty_tool_func,
)
class TestToolset:
def test_toolset_with_multiple_tools(self):
"""Test that a Toolset with multiple tools works properly."""
add_tool = Tool(
name="add",
description="Add two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=add_numbers,
)
multiply_tool = Tool(
name="multiply",
description="Multiply two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=multiply_numbers,
)
toolset = Toolset([add_tool, multiply_tool])
assert len(toolset) == 2
assert toolset[0].name == "add"
assert toolset[1].name == "multiply"
invoker = ToolInvoker(tools=toolset)
add_call = ToolCall(tool_name="add", arguments={"a": 2, "b": 3})
add_message = ChatMessage.from_assistant(tool_calls=[add_call])
multiply_call = ToolCall(tool_name="multiply", arguments={"a": 4, "b": 5})
multiply_message = ChatMessage.from_assistant(tool_calls=[multiply_call])
result = invoker.run(messages=[add_message, multiply_message])
assert len(result["tool_messages"]) == 2
tool_results = [message.tool_call_result.result for message in result["tool_messages"]]
assert "5" in tool_results
assert "20" in tool_results
def test_toolset_adding(self):
"""Test that tools can be added to a Toolset."""
toolset = Toolset()
assert len(toolset) == 0
add_tool = Tool(
name="add",
description="Add two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=add_numbers,
)
toolset.add(add_tool)
assert len(toolset) == 1
assert toolset[0].name == "add"
invoker = ToolInvoker(tools=toolset)
tool_call = ToolCall(tool_name="add", arguments={"a": 2, "b": 3})
message = ChatMessage.from_assistant(tool_calls=[tool_call])
result = invoker.run(messages=[message])
assert len(result["tool_messages"]) == 1
assert result["tool_messages"][0].tool_call_result.result == "5"
def test_toolset_addition(self):
"""Test that toolsets can be combined."""
add_tool = Tool(
name="add",
description="Add two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=add_numbers,
)
multiply_tool = Tool(
name="multiply",
description="Multiply two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=multiply_numbers,
)
subtract_tool = Tool(
name="subtract",
description="Subtract two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=subtract_numbers,
)
toolset1 = Toolset([add_tool])
toolset2 = Toolset([multiply_tool])
combined_toolset = toolset1 + toolset2
assert len(combined_toolset) == 2
combined_toolset = combined_toolset + subtract_tool
assert len(combined_toolset) == 3
tool_names = [tool.name for tool in combined_toolset]
assert "add" in tool_names
assert "multiply" in tool_names
assert "subtract" in tool_names
invoker = ToolInvoker(tools=combined_toolset)
add_call = ToolCall(tool_name="add", arguments={"a": 10, "b": 5})
multiply_call = ToolCall(tool_name="multiply", arguments={"a": 10, "b": 5})
subtract_call = ToolCall(tool_name="subtract", arguments={"a": 10, "b": 5})
message = ChatMessage.from_assistant(tool_calls=[add_call, multiply_call, subtract_call])
result = invoker.run(messages=[message])
assert len(result["tool_messages"]) == 3
tool_results = [message.tool_call_result.result for message in result["tool_messages"]]
assert "15" in tool_results
assert "50" in tool_results
assert "5" in tool_results
def test_toolset_contains(self):
"""Test that the __contains__ method works correctly."""
add_tool = Tool(
name="add",
description="Add two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=add_numbers,
)
multiply_tool = Tool(
name="multiply",
description="Multiply two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=multiply_numbers,
)
toolset = Toolset([add_tool])
# Test with a tool instance
assert add_tool in toolset
assert multiply_tool not in toolset
# Test with a tool name
assert "add" in toolset
assert "multiply" not in toolset
assert "non_existent_tool" not in toolset
def test_toolset_add_various_types(self):
"""Test that the __add__ method works with various object types."""
add_tool = Tool(
name="add",
description="Add two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=add_numbers,
)
multiply_tool = Tool(
name="multiply",
description="Multiply two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=multiply_numbers,
)
subtract_tool = Tool(
name="subtract",
description="Subtract two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=subtract_numbers,
)
# Test adding a single tool
toolset1 = Toolset([add_tool])
result1 = toolset1 + multiply_tool
assert len(result1) == 2
assert add_tool in result1
assert multiply_tool in result1
# Test adding another toolset
toolset2 = Toolset([subtract_tool])
result2 = toolset1 + toolset2
assert len(result2) == 2
assert add_tool in result2
assert subtract_tool in result2
# Test adding a list of tools
result3 = toolset1 + [multiply_tool, subtract_tool]
assert len(result3) == 3
assert add_tool in result3
assert multiply_tool in result3
assert subtract_tool in result3
# Test adding types that aren't supported
with pytest.raises(TypeError):
toolset1 + "not_a_tool"
with pytest.raises(TypeError):
toolset1 + 123
def test_toolset_serialization(self):
"""Test that a Toolset can be serialized and deserialized."""
add_tool = Tool(
name="add",
description="Add two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=add_numbers,
)
toolset = Toolset([add_tool])
serialized = toolset.to_dict()
deserialized = Toolset.from_dict(serialized)
assert len(deserialized) == 1
assert deserialized[0].name == "add"
assert deserialized[0].description == "Add two numbers"
invoker = ToolInvoker(tools=deserialized)
tool_call = ToolCall(tool_name="add", arguments={"a": 2, "b": 3})
message = ChatMessage.from_assistant(tool_calls=[tool_call])
result = invoker.run(messages=[message])
assert len(result["tool_messages"]) == 1
assert result["tool_messages"][0].tool_call_result.result == "5"
def test_custom_toolset_serialization(self):
"""Test serialization and deserialization of a custom Toolset subclass."""
add_tool = Tool(
name="add",
description="Add two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=add_numbers,
)
custom_attr_value = "custom_value"
custom_toolset = CustomToolset(tools=[add_tool], custom_attr=custom_attr_value)
serialized = custom_toolset.to_dict()
assert serialized["type"].endswith("CustomToolset")
assert serialized["custom_attr"] == custom_attr_value
assert len(serialized["data"]["tools"]) == 1
assert serialized["data"]["tools"][0]["data"]["name"] == "add"
deserialized = CustomToolset.from_dict(serialized)
assert isinstance(deserialized, CustomToolset)
assert deserialized.custom_attr == custom_attr_value
assert len(deserialized) == 1
assert deserialized[0].name == "add"
invoker = ToolInvoker(tools=deserialized)
tool_call = ToolCall(tool_name="add", arguments={"a": 2, "b": 3})
message = ChatMessage.from_assistant(tool_calls=[tool_call])
result = invoker.run(messages=[message])
assert len(result["tool_messages"]) == 1
assert result["tool_messages"][0].tool_call_result.result == "5"
def test_toolset_duplicate_tool_names(self):
"""Test that a Toolset raises an error for duplicate tool names."""
add_tool1 = Tool(
name="add",
description="Add two numbers (first)",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=add_numbers,
)
add_tool2 = Tool(
name="add",
description="Add two numbers (second)",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=add_numbers,
)
with pytest.raises(ValueError, match="Duplicate tool names found"):
Toolset([add_tool1, add_tool2])
toolset = Toolset([add_tool1])
with pytest.raises(ValueError, match="Duplicate tool names found"):
toolset.add(add_tool2)
toolset2 = Toolset([add_tool2])
with pytest.raises(ValueError, match="Duplicate tool names found"):
combined = toolset + toolset2
class TestToolsetIntegration:
"""Integration tests for Toolset in complete pipelines."""
def test_custom_toolset_serde_in_pipeline(self):
"""Test serialization and deserialization of a custom toolset within a pipeline."""
pipeline = Pipeline()
pipeline.add_component("tool_invoker", ToolInvoker(tools=CalculatorToolset()))
pipeline_dict = pipeline.to_dict()
tool_invoker_dict = pipeline_dict["components"]["tool_invoker"]
assert tool_invoker_dict["type"] == "haystack.components.tools.tool_invoker.ToolInvoker"
assert len(tool_invoker_dict["init_parameters"]["tools"]["data"]) == 0
new_pipeline = Pipeline.from_dict(pipeline_dict)
assert new_pipeline == pipeline
def test_regular_toolset_serde_in_pipeline(self):
"""Test serialization and deserialization of regular Toolsets within a pipeline."""
add_tool = Tool(
name="add",
description="Add two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=add_numbers,
)
multiply_tool = Tool(
name="multiply",
description="Multiply two numbers",
parameters={
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
"required": ["a", "b"],
},
function=multiply_numbers,
)
toolset = Toolset([add_tool, multiply_tool])
pipeline = Pipeline()
pipeline.add_component("tool_invoker", ToolInvoker(tools=toolset))
pipeline_dict = pipeline.to_dict()
tool_invoker_dict = pipeline_dict["components"]["tool_invoker"]
assert tool_invoker_dict["type"] == "haystack.components.tools.tool_invoker.ToolInvoker"
# Verify the serialized toolset
tools_dict = tool_invoker_dict["init_parameters"]["tools"]
assert tools_dict["type"] == "haystack.tools.toolset.Toolset"
assert len(tools_dict["data"]["tools"]) == 2
tool_names = [tool["data"]["name"] for tool in tools_dict["data"]["tools"]]
assert "add" in tool_names
assert "multiply" in tool_names
# Deserialize and verify
new_pipeline = Pipeline.from_dict(pipeline_dict)
assert new_pipeline == pipeline
class TestToolsetWithToolInvoker:
def test_init_with_toolset(self, weather_tool):
"""Test initializing ToolInvoker with a Toolset."""
toolset = Toolset(tools=[weather_tool])
invoker = ToolInvoker(tools=toolset)
assert invoker.tools == toolset
assert invoker._tools_with_names == {tool.name: tool for tool in toolset}
def test_serde_with_toolset(self, weather_tool):
"""Test serialization and deserialization of ToolInvoker with a Toolset."""
toolset = Toolset(tools=[weather_tool])
invoker = ToolInvoker(tools=toolset)
data = invoker.to_dict()
deserialized_invoker = ToolInvoker.from_dict(data)
assert deserialized_invoker.tools == invoker.tools
assert deserialized_invoker._tools_with_names == invoker._tools_with_names
assert deserialized_invoker.raise_on_failure == invoker.raise_on_failure
assert deserialized_invoker.convert_result_to_json_string == invoker.convert_result_to_json_string
def test_tool_invocation_error_with_toolset(self, faulty_tool):
"""Test tool invocation errors with a Toolset."""
toolset = Toolset(tools=[faulty_tool])
invoker = ToolInvoker(tools=toolset)
tool_call = ToolCall(tool_name="faulty_tool", arguments={"location": "Berlin"})
tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call])
with pytest.raises(ToolInvocationError):
invoker.run(messages=[tool_call_message])
def test_toolinvoker_deserialization_with_custom_toolset(self, weather_tool):
"""Test deserialization of ToolInvoker with a custom Toolset."""
custom_toolset = CustomToolset(tools=[weather_tool], custom_attr="custom_value")
invoker = ToolInvoker(tools=custom_toolset)
data = invoker.to_dict()
assert isinstance(data, dict)
assert "type" in data and "init_parameters" in data
tools_data = data["init_parameters"]["tools"]
assert isinstance(tools_data, dict)
assert len(tools_data["data"]["tools"]) == 1
assert tools_data["data"]["tools"][0]["type"] == "haystack.tools.tool.Tool"
assert tools_data.get("custom_attr") == "custom_value"
deserialized_invoker = ToolInvoker.from_dict(data)
assert deserialized_invoker.tools == invoker.tools
assert deserialized_invoker._tools_with_names == invoker._tools_with_names
assert deserialized_invoker.raise_on_failure == invoker.raise_on_failure
assert deserialized_invoker.convert_result_to_json_string == invoker.convert_result_to_json_string