From c81d68402c209a4f84d032d9794aec80be70a2b6 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 4 Apr 2025 10:09:46 -0400 Subject: [PATCH] feat: Add Toolset to tooling architecture (#9161) * Add Toolset abstraction * Add reno note * More pydoc improvements * Update test * Simplify, Toolset is a dataclass * Wrap toolset instance with list * Add example * Toolset pydoc serde enhancement * Toolset as init param * Fix types * Linting * Minor updates * PR feedback * Add to pydoc config, minor import fixes * Improve pydoc example * Improve coverage for test_toolset.py * Improve test_toolset.py, test custom toolset serde properly * Update haystack/utils/misc.py Co-authored-by: Stefano Fiorucci * Rework Toolset pydoc * Another minor pydoc improvement * Prevent single Tool instantiating Toolset * Reduce number of integration tests * Remove some toolset tests from openai * Rework tests --------- Co-authored-by: Stefano Fiorucci --- docs/pydoc/config/tools_api.yml | 2 +- haystack/components/generators/chat/openai.py | 32 +- haystack/components/tools/tool_invoker.py | 64 +- haystack/tools/__init__.py | 2 + haystack/tools/tool.py | 18 +- haystack/tools/toolset.py | 291 +++++++++ haystack/utils/misc.py | 20 +- ...set-class-to-tooling-23376c72e31c5a9a.yaml | 3 + .../components/generators/chat/test_openai.py | 50 +- test/tools/test_toolset.py | 587 ++++++++++++++++++ 10 files changed, 1046 insertions(+), 23 deletions(-) create mode 100644 haystack/tools/toolset.py create mode 100644 releasenotes/notes/add-toolset-class-to-tooling-23376c72e31c5a9a.yaml create mode 100644 test/tools/test_toolset.py diff --git a/docs/pydoc/config/tools_api.yml b/docs/pydoc/config/tools_api.yml index 3050e6c58..e376b54ef 100644 --- a/docs/pydoc/config/tools_api.yml +++ b/docs/pydoc/config/tools_api.yml @@ -2,7 +2,7 @@ loaders: - type: haystack_pydoc_tools.loaders.CustomPythonLoader search_path: [../../../haystack/tools] modules: - ["tool", "from_function", "component_tool"] + ["tool", "from_function", "component_tool", "toolset"] ignore_when_discovered: ["__init__"] processors: - type: filter diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index b0febe4bd..fc70cc6c4 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -23,7 +23,9 @@ from haystack.dataclasses import ( select_streaming_callback, ) from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace +from haystack.tools.toolset import Toolset from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable +from haystack.utils.misc import serialize_tools_or_toolset logger = logging.getLogger(__name__) @@ -81,7 +83,7 @@ class OpenAIChatGenerator: generation_kwargs: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, max_retries: Optional[int] = None, - tools: Optional[List[Tool]] = None, + tools: Optional[Union[List[Tool], Toolset]] = None, tools_strict: bool = False, ): """ @@ -127,7 +129,8 @@ class OpenAIChatGenerator: Maximum number of retries to contact OpenAI after an internal error. If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or set to 5. :param tools: - A list of tools for which the model can prepare calls. + A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a + list of `Tool` objects or a `Toolset` instance. :param tools_strict: Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly the schema provided in the `parameters` field of the tool definition, but this may increase latency. @@ -140,10 +143,11 @@ class OpenAIChatGenerator: self.organization = organization self.timeout = timeout self.max_retries = max_retries - self.tools = tools + self.tools = tools # Store tools as-is, whether it's a list or a Toolset self.tools_strict = tools_strict - _check_duplicate_tool_names(tools) + # Check for duplicate tool names + _check_duplicate_tool_names(list(self.tools or [])) if timeout is None: timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0")) @@ -185,7 +189,7 @@ class OpenAIChatGenerator: api_key=self.api_key.to_dict(), timeout=self.timeout, max_retries=self.max_retries, - tools=[tool.to_dict() for tool in self.tools] if self.tools else None, + tools=serialize_tools_or_toolset(self.tools), tools_strict=self.tools_strict, ) @@ -213,7 +217,7 @@ class OpenAIChatGenerator: streaming_callback: Optional[StreamingCallbackT] = None, generation_kwargs: Optional[Dict[str, Any]] = None, *, - tools: Optional[List[Tool]] = None, + tools: Optional[Union[List[Tool], Toolset]] = None, tools_strict: Optional[bool] = None, ): """ @@ -228,8 +232,9 @@ class OpenAIChatGenerator: override the parameters passed during component initialization. For details on OpenAI API parameters, see [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create). :param tools: - A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set - during component initialization. + A list of tools or a Toolset for which the model can prepare calls. If set, it will override the + `tools` parameter set during component initialization. This parameter can accept either a list of + `Tool` objects or a `Toolset` instance. :param tools_strict: Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly the schema provided in the `parameters` field of the tool definition, but this may increase latency. @@ -285,7 +290,7 @@ class OpenAIChatGenerator: streaming_callback: Optional[StreamingCallbackT] = None, generation_kwargs: Optional[Dict[str, Any]] = None, *, - tools: Optional[List[Tool]] = None, + tools: Optional[Union[List[Tool], Toolset]] = None, tools_strict: Optional[bool] = None, ): """ @@ -304,8 +309,9 @@ class OpenAIChatGenerator: override the parameters passed during component initialization. For details on OpenAI API parameters, see [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create). :param tools: - A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set - during component initialization. + A list of tools or a Toolset for which the model can prepare calls. If set, it will override the + `tools` parameter set during component initialization. This parameter can accept either a list of + `Tool` objects or a `Toolset` instance. :param tools_strict: Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly the schema provided in the `parameters` field of the tool definition, but this may increase latency. @@ -362,7 +368,7 @@ class OpenAIChatGenerator: messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, generation_kwargs: Optional[Dict[str, Any]] = None, - tools: Optional[List[Tool]] = None, + tools: Optional[Union[List[Tool], Toolset]] = None, tools_strict: Optional[bool] = None, ) -> Dict[str, Any]: # update generation kwargs by merging with the generation kwargs passed to the run method @@ -372,6 +378,8 @@ class OpenAIChatGenerator: openai_formatted_messages = [message.to_openai_dict_format() for message in messages] tools = tools or self.tools + if isinstance(tools, Toolset): + tools = list(tools) tools_strict = tools_strict if tools_strict is not None else self.tools_strict _check_duplicate_tool_names(tools) diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 3bbcfb408..dece392b1 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -4,13 +4,15 @@ import inspect import json -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict, logging from haystack.core.component.sockets import Sockets from haystack.dataclasses import ChatMessage, State, ToolCall from haystack.tools.component_tool import ComponentTool from haystack.tools.tool import Tool, ToolInvocationError, _check_duplicate_tool_names, deserialize_tools_inplace +from haystack.tools.toolset import Toolset +from haystack.utils.misc import serialize_tools_or_toolset logger = logging.getLogger(__name__) @@ -57,7 +59,8 @@ class ToolInvoker: Usage example: ```python - from haystack.dataclasses import ChatMessage, ToolCall, Tool + from haystack.dataclasses import ChatMessage, ToolCall + from haystack.tools import Tool from haystack.components.tools import ToolInvoker # Tool definition @@ -108,14 +111,55 @@ class ToolInvoker: >> ] >> } ``` + + Usage example with a Toolset: + ```python + from haystack.dataclasses import ChatMessage, ToolCall + from haystack.tools import Tool, Toolset + from haystack.components.tools import ToolInvoker + + # Tool definition + def dummy_weather_function(city: str): + return f"The weather in {city} is 20 degrees." + + parameters = {"type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"]} + + tool = Tool(name="weather_tool", + description="A tool to get the weather", + function=dummy_weather_function, + parameters=parameters) + + # Create a Toolset + toolset = Toolset([tool]) + + # Usually, the ChatMessage with tool_calls is generated by a Language Model + # Here, we create it manually for demonstration purposes + tool_call = ToolCall( + tool_name="weather_tool", + arguments={"city": "Berlin"} + ) + message = ChatMessage.from_assistant(tool_calls=[tool_call]) + + # ToolInvoker initialization and run with Toolset + invoker = ToolInvoker(tools=toolset) + result = invoker.run(messages=[message]) + + print(result) """ - def __init__(self, tools: List[Tool], raise_on_failure: bool = True, convert_result_to_json_string: bool = False): + def __init__( + self, + tools: Union[List[Tool], Toolset], + raise_on_failure: bool = True, + convert_result_to_json_string: bool = False, + ): """ Initialize the ToolInvoker component. :param tools: - A list of tools that can be invoked. + A list of tools that can be invoked or a Toolset instance that can resolve tools. :param raise_on_failure: If True, the component will raise an exception in case of errors (tool not found, tool invocation errors, tool result conversion errors). @@ -129,13 +173,20 @@ class ToolInvoker: """ if not tools: raise ValueError("ToolInvoker requires at least one tool.") + + # could be a Toolset instance or a list of Tools + self.tools = tools + + # Convert Toolset to list for internal use + if isinstance(tools, Toolset): + tools = list(tools) + _check_duplicate_tool_names(tools) tool_names = [tool.name for tool in tools] duplicates = {name for name in tool_names if tool_names.count(name) > 1} if duplicates: raise ValueError(f"Duplicate tool names found: {duplicates}") - self.tools = tools self._tools_with_names = dict(zip(tool_names, tools)) self.raise_on_failure = raise_on_failure self.convert_result_to_json_string = convert_result_to_json_string @@ -385,10 +436,9 @@ class ToolInvoker: :returns: Dictionary with serialized data. """ - serialized_tools = [tool.to_dict() for tool in self.tools] return default_to_dict( self, - tools=serialized_tools, + tools=serialize_tools_or_toolset(self.tools), raise_on_failure=self.raise_on_failure, convert_result_to_json_string=self.convert_result_to_json_string, ) diff --git a/haystack/tools/__init__.py b/haystack/tools/__init__.py index 314ae63cb..c1eeed3cf 100644 --- a/haystack/tools/__init__.py +++ b/haystack/tools/__init__.py @@ -8,6 +8,7 @@ from .from_function import create_tool_from_function, tool from .tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace from .component_tool import ComponentTool +from .toolset import Toolset __all__ = [ @@ -17,4 +18,5 @@ __all__ = [ "create_tool_from_function", "tool", "ComponentTool", + "Toolset", ] diff --git a/haystack/tools/tool.py b/haystack/tools/tool.py index 6d50aa8a9..9cde27f9a 100644 --- a/haystack/tools/tool.py +++ b/haystack/tools/tool.py @@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional from jsonschema import Draft202012Validator from jsonschema.exceptions import SchemaError +from haystack.core.errors import DeserializationError from haystack.core.serialization import generate_qualified_class_name, import_class_by_name from haystack.tools.errors import ToolInvocationError from haystack.utils import deserialize_callable, serialize_callable @@ -190,8 +191,23 @@ def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"): if serialized_tools is None: return + # Check if it's a serialized Toolset (a dict with "type" and "data" keys) + if isinstance(serialized_tools, dict) and all(k in serialized_tools for k in ["type", "data"]): + toolset_class_name = serialized_tools.get("type") + if not toolset_class_name: + raise DeserializationError("The 'type' key is missing or None in the serialized toolset data") + + toolset_class = import_class_by_name(toolset_class_name) + from haystack.tools.toolset import Toolset # avoid circular import + + if not issubclass(toolset_class, Toolset): + raise TypeError(f"Class '{toolset_class}' is not a subclass of Toolset") + + data[key] = toolset_class.from_dict(serialized_tools) + return + if not isinstance(serialized_tools, list): - raise TypeError(f"The value of '{key}' is not a list") + raise TypeError(f"The value of '{key}' is not a list or a dictionary") deserialized_tools = [] for tool in serialized_tools: diff --git a/haystack/tools/toolset.py b/haystack/tools/toolset.py new file mode 100644 index 000000000..11fda7b14 --- /dev/null +++ b/haystack/tools/toolset.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass, field +from typing import Any, Dict, Iterator, List, Union + +from haystack.core.serialization import generate_qualified_class_name, import_class_by_name +from haystack.tools.tool import Tool, _check_duplicate_tool_names + + +@dataclass +class Toolset: + """ + A collection of related Tools that can be used and managed as a cohesive unit. + + Toolset serves two main purposes: + + 1. Group related tools together: + Toolset allows you to organize related tools into a single collection, making it easier + to manage and use them as a unit in Haystack pipelines. + + Example: + ```python + from haystack.tools import Tool, Toolset + from haystack.components.tools import ToolInvoker + + # Define math functions + def add_numbers(a: int, b: int) -> int: + return a + b + + def subtract_numbers(a: int, b: int) -> int: + return a - b + + # Create tools with proper schemas + add_tool = Tool( + name="add", + description="Add two numbers", + parameters={ + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "integer"} + }, + "required": ["a", "b"] + }, + function=add_numbers + ) + + subtract_tool = Tool( + name="subtract", + description="Subtract b from a", + parameters={ + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "integer"} + }, + "required": ["a", "b"] + }, + function=subtract_numbers + ) + + # Create a toolset with the math tools + math_toolset = Toolset([add_tool, subtract_tool]) + + # Use the toolset with a ToolInvoker or ChatGenerator component + invoker = ToolInvoker(tools=math_toolset) + ``` + + 2. Base class for dynamic tool loading: + By subclassing Toolset, you can create implementations that dynamically load tools + from external sources like OpenAPI URLs, MCP servers, or other resources. + + Example: + ```python + from haystack.core.serialization import generate_qualified_class_name + from haystack.tools import Tool, Toolset + from haystack.components.tools import ToolInvoker + + class CalculatorToolset(Toolset): + '''A toolset for calculator operations.''' + + def __init__(self): + tools = self._create_tools() + super().__init__(tools) + + def _create_tools(self): + # These Tool instances are obviously defined statically and for illustration purposes only. + # In a real-world scenario, you would dynamically load tools from an external source here. + tools = [] + add_tool = Tool( + name="add", + description="Add two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=lambda a, b: a + b, + ) + + multiply_tool = Tool( + name="multiply", + description="Multiply two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=lambda a, b: a * b, + ) + + tools.append(add_tool) + tools.append(multiply_tool) + + return tools + + def to_dict(self): + return { + "type": generate_qualified_class_name(type(self)), + "data": {}, # no data to serialize as we define the tools dynamically + } + + @classmethod + def from_dict(cls, data): + return cls() # Recreate the tools dynamically during deserialization + + # Create the dynamic toolset and use it with ToolInvoker + calculator_toolset = CalculatorToolset() + invoker = ToolInvoker(tools=calculator_toolset) + ``` + + Toolset implements the collection interface (__iter__, __contains__, __len__, __getitem__), + making it behave like a list of Tools. This makes it compatible with components that expect + iterable tools, such as ToolInvoker or Haystack chat generators. + + When implementing a custom Toolset subclass for dynamic tool loading: + - Perform the dynamic loading in the __init__ method + - Override to_dict() and from_dict() methods if your tools are defined dynamically + - Serialize endpoint descriptors rather than tool instances if your tools + are loaded from external sources + """ + + # Use field() with default_factory to initialize the list + tools: List[Tool] = field(default_factory=list) + + def __post_init__(self): + """ + Validate and set up the toolset after initialization. + + This handles the case when tools are provided during initialization. + """ + # If initialization was done a single Tool, raise an error + if isinstance(self.tools, Tool): + raise TypeError("A single Tool cannot be directly passed to Toolset. Please use a list: Toolset([tool])") + + # Check for duplicate tool names in the initial set + _check_duplicate_tool_names(self.tools) + + def __iter__(self) -> Iterator[Tool]: + """ + Return an iterator over the Tools in this Toolset. + + This allows the Toolset to be used wherever a list of Tools is expected. + + :returns: An iterator yielding Tool instances + """ + return iter(self.tools) + + def __contains__(self, item: Any) -> bool: + """ + Check if a tool is in this Toolset. + + Supports checking by: + - Tool instance: tool in toolset + - Tool name: "tool_name" in toolset + + :param item: Tool instance or tool name string + :returns: True if contained, False otherwise + """ + if isinstance(item, str): + return any(tool.name == item for tool in self.tools) + if isinstance(item, Tool): + return item in self.tools + return False + + def add(self, tool: Union[Tool, "Toolset"]) -> None: + """ + Add a new Tool or merge another Toolset. + + :param tool: A Tool instance or another Toolset to add + :raises ValueError: If adding the tool would result in duplicate tool names + :raises TypeError: If the provided object is not a Tool or Toolset + """ + new_tools = [] + + if isinstance(tool, Tool): + new_tools = [tool] + elif isinstance(tool, Toolset): + new_tools = list(tool) + else: + raise TypeError(f"Expected Tool or Toolset, got {type(tool).__name__}") + + # Check for duplicates before adding + combined_tools = self.tools + new_tools + _check_duplicate_tool_names(combined_tools) + + self.tools.extend(new_tools) + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize the Toolset to a dictionary. + + :returns: A dictionary representation of the Toolset + + Note for subclass implementers: + The default implementation is ideal for scenarios where Tool resolution is static. However, if your subclass + of Toolset dynamically resolves Tool instances from external sources—such as an MCP server, OpenAPI URL, or + a local OpenAPI specification—you should consider serializing the endpoint descriptor instead of the Tool + instances themselves. This strategy preserves the dynamic nature of your Toolset and minimizes the overhead + associated with serializing potentially large collections of Tool objects. Moreover, by serializing the + descriptor, you ensure that the deserialization process can accurately reconstruct the Tool instances, even + if they have been modified or removed since the last serialization. Failing to serialize the descriptor may + lead to issues where outdated or incorrect Tool configurations are loaded, potentially causing errors or + unexpected behavior. + """ + return { + "type": generate_qualified_class_name(type(self)), + "data": {"tools": [tool.to_dict() for tool in self.tools]}, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Toolset": + """ + Deserialize a Toolset from a dictionary. + + :param data: Dictionary representation of the Toolset + :returns: A new Toolset instance + """ + inner_data = data["data"] + tools_data = inner_data.get("tools", []) + + tools = [] + for tool_data in tools_data: + tool_class = import_class_by_name(tool_data["type"]) + if not issubclass(tool_class, Tool): + raise TypeError(f"Class '{tool_class}' is not a subclass of Tool") + tools.append(tool_class.from_dict(tool_data)) + + return cls(tools=tools) + + def __add__(self, other: Union[Tool, "Toolset", List[Tool]]) -> "Toolset": + """ + Concatenate this Toolset with another Tool, Toolset, or list of Tools. + + :param other: Another Tool, Toolset, or list of Tools to concatenate + :returns: A new Toolset containing all tools + :raises TypeError: If the other parameter is not a Tool, Toolset, or list of Tools + :raises ValueError: If the combination would result in duplicate tool names + """ + if isinstance(other, Tool): + combined_tools = self.tools + [other] + elif isinstance(other, Toolset): + combined_tools = self.tools + list(other) + elif isinstance(other, list) and all(isinstance(item, Tool) for item in other): + combined_tools = self.tools + other + else: + raise TypeError(f"Cannot add {type(other).__name__} to Toolset") + + # Check for duplicates + _check_duplicate_tool_names(combined_tools) + + return Toolset(tools=combined_tools) + + def __len__(self) -> int: + """ + Return the number of Tools in this Toolset. + + :returns: Number of Tools + """ + return len(self.tools) + + def __getitem__(self, index): + """ + Get a Tool by index. + + :param index: Index of the Tool to get + :returns: The Tool at the specified index + """ + return self.tools[index] diff --git a/haystack/utils/misc.py b/haystack/utils/misc.py index 8320f7aaf..812de8ecf 100644 --- a/haystack/utils/misc.py +++ b/haystack/utils/misc.py @@ -2,10 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import List, Union +from typing import Any, Dict, List, Union from numpy import exp +from haystack.tools import Tool, Toolset + def expand_page_range(page_range: List[Union[str, int]]) -> List[int]: """ @@ -52,3 +54,19 @@ def expit(x) -> float: :param x: input value. Can be a scalar or a numpy array. """ return 1 / (1 + exp(-x)) + + +def serialize_tools_or_toolset( + tools: Union[Toolset, List[Tool], None], +) -> Union[Dict[str, Any], List[Dict[str, Any]], None]: + """ + Serialize a Toolset or a list of Tools to a dictionary or a list of tool dictionaries. + + :param tools: A Toolset, a list of Tools, or None + :returns: A dictionary, a list of tool dictionaries, or None if tools is None + """ + if tools is None: + return None + if isinstance(tools, Toolset): + return tools.to_dict() + return [tool.to_dict() for tool in tools] diff --git a/releasenotes/notes/add-toolset-class-to-tooling-23376c72e31c5a9a.yaml b/releasenotes/notes/add-toolset-class-to-tooling-23376c72e31c5a9a.yaml new file mode 100644 index 000000000..43d47d7dd --- /dev/null +++ b/releasenotes/notes/add-toolset-class-to-tooling-23376c72e31c5a9a.yaml @@ -0,0 +1,3 @@ +features: + - | + Introduced the Toolset class, allowing for the grouping and management of related tool functionalities. This new abstraction supports dynamic tool loading and registration. diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 5a4da1162..d717f64fd 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -22,6 +22,7 @@ from haystack.utils.auth import Secret from haystack.dataclasses import ChatMessage, ToolCall from haystack.tools import Tool from haystack.components.generators.chat.openai import OpenAIChatGenerator +from haystack.tools.toolset import Toolset @pytest.fixture @@ -71,6 +72,10 @@ def mock_chat_completion_chunk_with_tools(openai_mock_stream): yield mock_chat_completion_create +def mock_tool_function(x): + return x + + @pytest.fixture def tools(): tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} @@ -78,7 +83,7 @@ def tools(): name="weather", description="useful to determine the weather in a given location", parameters=tool_parameters, - function=lambda x: x, + function=mock_tool_function, ) return [tool] @@ -269,6 +274,7 @@ class TestOpenAIChatGenerator: "api_base_url": "test-base-url", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "tools": None, }, } with pytest.raises(ValueError): @@ -937,3 +943,45 @@ class TestOpenAIChatGenerator: assert tool_call.tool_name == "weather" assert tool_call.arguments == {"city": "Paris"} assert message.meta["finish_reason"] == "tool_calls" + + def test_openai_chat_generator_with_toolset_initialization(self, tools, monkeypatch): + """Test that the OpenAIChatGenerator can be initialized with a Toolset.""" + monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") + toolset = Toolset(tools) + generator = OpenAIChatGenerator(tools=toolset) + assert generator.tools == toolset + + def test_from_dict_with_toolset(self, tools, monkeypatch): + """Test that the OpenAIChatGenerator can be deserialized from a dictionary with a Toolset.""" + monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") + toolset = Toolset(tools) + component = OpenAIChatGenerator(tools=toolset) + data = component.to_dict() + + deserialized_component = OpenAIChatGenerator.from_dict(data) + + assert isinstance(deserialized_component.tools, Toolset) + assert len(deserialized_component.tools) == len(tools) + assert all(isinstance(tool, Tool) for tool in deserialized_component.tools) + + @pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_with_toolset(self, tools): + chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + toolset = Toolset(tools) + component = OpenAIChatGenerator(tools=toolset) + results = component.run(chat_messages) + assert len(results["replies"]) == 1 + message = results["replies"][0] + + assert not message.texts + assert not message.text + assert message.tool_calls + tool_call = message.tool_call + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert message.meta["finish_reason"] == "tool_calls" diff --git a/test/tools/test_toolset.py b/test/tools/test_toolset.py new file mode 100644 index 000000000..f8c59ed96 --- /dev/null +++ b/test/tools/test_toolset.py @@ -0,0 +1,587 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import os +from haystack import Pipeline +from haystack.core.serialization import generate_qualified_class_name +from haystack.dataclasses import ChatMessage +from haystack.components.tools import ToolInvoker +from haystack.dataclasses.chat_message import ToolCall +from haystack.tools import Tool, Toolset +from haystack.tools.errors import ToolInvocationError + + +# Common functions for tests +def add_numbers(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + +def multiply_numbers(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + +def subtract_numbers(a: int, b: int) -> int: + """Subtract b from a.""" + return a - b + + +class CustomToolset(Toolset): + def __init__(self, tools, custom_attr): + super().__init__(tools) + self.custom_attr = custom_attr + + def to_dict(self): + data = super().to_dict() + data["custom_attr"] = self.custom_attr + return data + + @classmethod + def from_dict(cls, data): + tools = [Tool.from_dict(tool_data) for tool_data in data["data"]["tools"]] + custom_attr = data["custom_attr"] + return cls(tools=tools, custom_attr=custom_attr) + + +class CalculatorToolset(Toolset): + """A toolset for calculator operations.""" + + def __init__(self): + super().__init__([]) + self._create_tools() + + def _create_tools(self): + add_tool = Tool( + name="add", + description="Add two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=add_numbers, + ) + + multiply_tool = Tool( + name="multiply", + description="Multiply two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=multiply_numbers, + ) + + self.add(add_tool) + self.add(multiply_tool) + + def to_dict(self): + return { + "type": generate_qualified_class_name(type(self)), + "data": {}, # no data to serialize as we define the tools dynamically + } + + @classmethod + def from_dict(cls, data): + return cls() + + +def weather_function(location): + weather_info = { + "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"}, + "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"}, + "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"}, + } + return weather_info.get(location, {"weather": "unknown", "temperature": 0, "unit": "celsius"}) + + +weather_parameters = {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]} + + +@pytest.fixture +def weather_tool(): + return Tool( + name="weather_tool", + description="Provides weather information for a given location.", + parameters=weather_parameters, + function=weather_function, + ) + + +@pytest.fixture +def faulty_tool(): + def faulty_tool_func(location): + raise Exception("This tool always fails.") + + faulty_tool_parameters = { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + } + + return Tool( + name="faulty_tool", + description="A tool that always fails when invoked.", + parameters=faulty_tool_parameters, + function=faulty_tool_func, + ) + + +class TestToolset: + def test_toolset_with_multiple_tools(self): + """Test that a Toolset with multiple tools works properly.""" + add_tool = Tool( + name="add", + description="Add two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=add_numbers, + ) + + multiply_tool = Tool( + name="multiply", + description="Multiply two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=multiply_numbers, + ) + + toolset = Toolset([add_tool, multiply_tool]) + + assert len(toolset) == 2 + assert toolset[0].name == "add" + assert toolset[1].name == "multiply" + + invoker = ToolInvoker(tools=toolset) + + add_call = ToolCall(tool_name="add", arguments={"a": 2, "b": 3}) + add_message = ChatMessage.from_assistant(tool_calls=[add_call]) + + multiply_call = ToolCall(tool_name="multiply", arguments={"a": 4, "b": 5}) + multiply_message = ChatMessage.from_assistant(tool_calls=[multiply_call]) + + result = invoker.run(messages=[add_message, multiply_message]) + + assert len(result["tool_messages"]) == 2 + tool_results = [message.tool_call_result.result for message in result["tool_messages"]] + assert "5" in tool_results + assert "20" in tool_results + + def test_toolset_adding(self): + """Test that tools can be added to a Toolset.""" + toolset = Toolset() + assert len(toolset) == 0 + + add_tool = Tool( + name="add", + description="Add two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=add_numbers, + ) + + toolset.add(add_tool) + assert len(toolset) == 1 + assert toolset[0].name == "add" + + invoker = ToolInvoker(tools=toolset) + tool_call = ToolCall(tool_name="add", arguments={"a": 2, "b": 3}) + message = ChatMessage.from_assistant(tool_calls=[tool_call]) + result = invoker.run(messages=[message]) + + assert len(result["tool_messages"]) == 1 + assert result["tool_messages"][0].tool_call_result.result == "5" + + def test_toolset_addition(self): + """Test that toolsets can be combined.""" + add_tool = Tool( + name="add", + description="Add two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=add_numbers, + ) + + multiply_tool = Tool( + name="multiply", + description="Multiply two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=multiply_numbers, + ) + + subtract_tool = Tool( + name="subtract", + description="Subtract two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=subtract_numbers, + ) + + toolset1 = Toolset([add_tool]) + toolset2 = Toolset([multiply_tool]) + + combined_toolset = toolset1 + toolset2 + assert len(combined_toolset) == 2 + + combined_toolset = combined_toolset + subtract_tool + assert len(combined_toolset) == 3 + + tool_names = [tool.name for tool in combined_toolset] + assert "add" in tool_names + assert "multiply" in tool_names + assert "subtract" in tool_names + + invoker = ToolInvoker(tools=combined_toolset) + + add_call = ToolCall(tool_name="add", arguments={"a": 10, "b": 5}) + multiply_call = ToolCall(tool_name="multiply", arguments={"a": 10, "b": 5}) + subtract_call = ToolCall(tool_name="subtract", arguments={"a": 10, "b": 5}) + + message = ChatMessage.from_assistant(tool_calls=[add_call, multiply_call, subtract_call]) + + result = invoker.run(messages=[message]) + + assert len(result["tool_messages"]) == 3 + tool_results = [message.tool_call_result.result for message in result["tool_messages"]] + assert "15" in tool_results + assert "50" in tool_results + assert "5" in tool_results + + def test_toolset_contains(self): + """Test that the __contains__ method works correctly.""" + add_tool = Tool( + name="add", + description="Add two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=add_numbers, + ) + + multiply_tool = Tool( + name="multiply", + description="Multiply two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=multiply_numbers, + ) + + toolset = Toolset([add_tool]) + + # Test with a tool instance + assert add_tool in toolset + assert multiply_tool not in toolset + + # Test with a tool name + assert "add" in toolset + assert "multiply" not in toolset + assert "non_existent_tool" not in toolset + + def test_toolset_add_various_types(self): + """Test that the __add__ method works with various object types.""" + add_tool = Tool( + name="add", + description="Add two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=add_numbers, + ) + + multiply_tool = Tool( + name="multiply", + description="Multiply two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=multiply_numbers, + ) + + subtract_tool = Tool( + name="subtract", + description="Subtract two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=subtract_numbers, + ) + + # Test adding a single tool + toolset1 = Toolset([add_tool]) + result1 = toolset1 + multiply_tool + assert len(result1) == 2 + assert add_tool in result1 + assert multiply_tool in result1 + + # Test adding another toolset + toolset2 = Toolset([subtract_tool]) + result2 = toolset1 + toolset2 + assert len(result2) == 2 + assert add_tool in result2 + assert subtract_tool in result2 + + # Test adding a list of tools + result3 = toolset1 + [multiply_tool, subtract_tool] + assert len(result3) == 3 + assert add_tool in result3 + assert multiply_tool in result3 + assert subtract_tool in result3 + + # Test adding types that aren't supported + with pytest.raises(TypeError): + toolset1 + "not_a_tool" + + with pytest.raises(TypeError): + toolset1 + 123 + + def test_toolset_serialization(self): + """Test that a Toolset can be serialized and deserialized.""" + add_tool = Tool( + name="add", + description="Add two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=add_numbers, + ) + + toolset = Toolset([add_tool]) + + serialized = toolset.to_dict() + + deserialized = Toolset.from_dict(serialized) + + assert len(deserialized) == 1 + assert deserialized[0].name == "add" + assert deserialized[0].description == "Add two numbers" + + invoker = ToolInvoker(tools=deserialized) + tool_call = ToolCall(tool_name="add", arguments={"a": 2, "b": 3}) + message = ChatMessage.from_assistant(tool_calls=[tool_call]) + result = invoker.run(messages=[message]) + + assert len(result["tool_messages"]) == 1 + assert result["tool_messages"][0].tool_call_result.result == "5" + + def test_custom_toolset_serialization(self): + """Test serialization and deserialization of a custom Toolset subclass.""" + add_tool = Tool( + name="add", + description="Add two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=add_numbers, + ) + + custom_attr_value = "custom_value" + custom_toolset = CustomToolset(tools=[add_tool], custom_attr=custom_attr_value) + + serialized = custom_toolset.to_dict() + assert serialized["type"].endswith("CustomToolset") + assert serialized["custom_attr"] == custom_attr_value + assert len(serialized["data"]["tools"]) == 1 + assert serialized["data"]["tools"][0]["data"]["name"] == "add" + + deserialized = CustomToolset.from_dict(serialized) + assert isinstance(deserialized, CustomToolset) + assert deserialized.custom_attr == custom_attr_value + assert len(deserialized) == 1 + assert deserialized[0].name == "add" + + invoker = ToolInvoker(tools=deserialized) + tool_call = ToolCall(tool_name="add", arguments={"a": 2, "b": 3}) + message = ChatMessage.from_assistant(tool_calls=[tool_call]) + result = invoker.run(messages=[message]) + + assert len(result["tool_messages"]) == 1 + assert result["tool_messages"][0].tool_call_result.result == "5" + + def test_toolset_duplicate_tool_names(self): + """Test that a Toolset raises an error for duplicate tool names.""" + add_tool1 = Tool( + name="add", + description="Add two numbers (first)", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=add_numbers, + ) + + add_tool2 = Tool( + name="add", + description="Add two numbers (second)", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=add_numbers, + ) + + with pytest.raises(ValueError, match="Duplicate tool names found"): + Toolset([add_tool1, add_tool2]) + + toolset = Toolset([add_tool1]) + + with pytest.raises(ValueError, match="Duplicate tool names found"): + toolset.add(add_tool2) + + toolset2 = Toolset([add_tool2]) + with pytest.raises(ValueError, match="Duplicate tool names found"): + combined = toolset + toolset2 + + +class TestToolsetIntegration: + """Integration tests for Toolset in complete pipelines.""" + + def test_custom_toolset_serde_in_pipeline(self): + """Test serialization and deserialization of a custom toolset within a pipeline.""" + + pipeline = Pipeline() + pipeline.add_component("tool_invoker", ToolInvoker(tools=CalculatorToolset())) + + pipeline_dict = pipeline.to_dict() + + tool_invoker_dict = pipeline_dict["components"]["tool_invoker"] + assert tool_invoker_dict["type"] == "haystack.components.tools.tool_invoker.ToolInvoker" + assert len(tool_invoker_dict["init_parameters"]["tools"]["data"]) == 0 + + new_pipeline = Pipeline.from_dict(pipeline_dict) + assert new_pipeline == pipeline + + def test_regular_toolset_serde_in_pipeline(self): + """Test serialization and deserialization of regular Toolsets within a pipeline.""" + + add_tool = Tool( + name="add", + description="Add two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=add_numbers, + ) + + multiply_tool = Tool( + name="multiply", + description="Multiply two numbers", + parameters={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + function=multiply_numbers, + ) + + toolset = Toolset([add_tool, multiply_tool]) + + pipeline = Pipeline() + pipeline.add_component("tool_invoker", ToolInvoker(tools=toolset)) + + pipeline_dict = pipeline.to_dict() + + tool_invoker_dict = pipeline_dict["components"]["tool_invoker"] + assert tool_invoker_dict["type"] == "haystack.components.tools.tool_invoker.ToolInvoker" + + # Verify the serialized toolset + tools_dict = tool_invoker_dict["init_parameters"]["tools"] + assert tools_dict["type"] == "haystack.tools.toolset.Toolset" + assert len(tools_dict["data"]["tools"]) == 2 + tool_names = [tool["data"]["name"] for tool in tools_dict["data"]["tools"]] + assert "add" in tool_names + assert "multiply" in tool_names + + # Deserialize and verify + new_pipeline = Pipeline.from_dict(pipeline_dict) + assert new_pipeline == pipeline + + +class TestToolsetWithToolInvoker: + def test_init_with_toolset(self, weather_tool): + """Test initializing ToolInvoker with a Toolset.""" + toolset = Toolset(tools=[weather_tool]) + invoker = ToolInvoker(tools=toolset) + assert invoker.tools == toolset + assert invoker._tools_with_names == {tool.name: tool for tool in toolset} + + def test_serde_with_toolset(self, weather_tool): + """Test serialization and deserialization of ToolInvoker with a Toolset.""" + toolset = Toolset(tools=[weather_tool]) + invoker = ToolInvoker(tools=toolset) + data = invoker.to_dict() + deserialized_invoker = ToolInvoker.from_dict(data) + assert deserialized_invoker.tools == invoker.tools + assert deserialized_invoker._tools_with_names == invoker._tools_with_names + assert deserialized_invoker.raise_on_failure == invoker.raise_on_failure + assert deserialized_invoker.convert_result_to_json_string == invoker.convert_result_to_json_string + + def test_tool_invocation_error_with_toolset(self, faulty_tool): + """Test tool invocation errors with a Toolset.""" + toolset = Toolset(tools=[faulty_tool]) + invoker = ToolInvoker(tools=toolset) + tool_call = ToolCall(tool_name="faulty_tool", arguments={"location": "Berlin"}) + tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call]) + with pytest.raises(ToolInvocationError): + invoker.run(messages=[tool_call_message]) + + def test_toolinvoker_deserialization_with_custom_toolset(self, weather_tool): + """Test deserialization of ToolInvoker with a custom Toolset.""" + custom_toolset = CustomToolset(tools=[weather_tool], custom_attr="custom_value") + invoker = ToolInvoker(tools=custom_toolset) + data = invoker.to_dict() + + assert isinstance(data, dict) + assert "type" in data and "init_parameters" in data + tools_data = data["init_parameters"]["tools"] + assert isinstance(tools_data, dict) + assert len(tools_data["data"]["tools"]) == 1 + assert tools_data["data"]["tools"][0]["type"] == "haystack.tools.tool.Tool" + assert tools_data.get("custom_attr") == "custom_value" + + deserialized_invoker = ToolInvoker.from_dict(data) + assert deserialized_invoker.tools == invoker.tools + assert deserialized_invoker._tools_with_names == invoker._tools_with_names + assert deserialized_invoker.raise_on_failure == invoker.raise_on_failure + assert deserialized_invoker.convert_result_to_json_string == invoker.convert_result_to_json_string