mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-24 22:16:34 +00:00
feat: Add Toolset to tooling architecture (#9161)
* Add Toolset abstraction * Add reno note * More pydoc improvements * Update test * Simplify, Toolset is a dataclass * Wrap toolset instance with list * Add example * Toolset pydoc serde enhancement * Toolset as init param * Fix types * Linting * Minor updates * PR feedback * Add to pydoc config, minor import fixes * Improve pydoc example * Improve coverage for test_toolset.py * Improve test_toolset.py, test custom toolset serde properly * Update haystack/utils/misc.py Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com> * Rework Toolset pydoc * Another minor pydoc improvement * Prevent single Tool instantiating Toolset * Reduce number of integration tests * Remove some toolset tests from openai * Rework tests --------- Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
This commit is contained in:
parent
a2f73d134d
commit
c81d68402c
@ -2,7 +2,7 @@ loaders:
|
||||
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
|
||||
search_path: [../../../haystack/tools]
|
||||
modules:
|
||||
["tool", "from_function", "component_tool"]
|
||||
["tool", "from_function", "component_tool", "toolset"]
|
||||
ignore_when_discovered: ["__init__"]
|
||||
processors:
|
||||
- type: filter
|
||||
|
||||
@ -23,7 +23,9 @@ from haystack.dataclasses import (
|
||||
select_streaming_callback,
|
||||
)
|
||||
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
|
||||
from haystack.tools.toolset import Toolset
|
||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||
from haystack.utils.misc import serialize_tools_or_toolset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -81,7 +83,7 @@ class OpenAIChatGenerator:
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||
tools_strict: bool = False,
|
||||
):
|
||||
"""
|
||||
@ -127,7 +129,8 @@ class OpenAIChatGenerator:
|
||||
Maximum number of retries to contact OpenAI after an internal error.
|
||||
If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or set to 5.
|
||||
:param tools:
|
||||
A list of tools for which the model can prepare calls.
|
||||
A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
|
||||
list of `Tool` objects or a `Toolset` instance.
|
||||
:param tools_strict:
|
||||
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
|
||||
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
|
||||
@ -140,10 +143,11 @@ class OpenAIChatGenerator:
|
||||
self.organization = organization
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
self.tools = tools
|
||||
self.tools = tools # Store tools as-is, whether it's a list or a Toolset
|
||||
self.tools_strict = tools_strict
|
||||
|
||||
_check_duplicate_tool_names(tools)
|
||||
# Check for duplicate tool names
|
||||
_check_duplicate_tool_names(list(self.tools or []))
|
||||
|
||||
if timeout is None:
|
||||
timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
|
||||
@ -185,7 +189,7 @@ class OpenAIChatGenerator:
|
||||
api_key=self.api_key.to_dict(),
|
||||
timeout=self.timeout,
|
||||
max_retries=self.max_retries,
|
||||
tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
|
||||
tools=serialize_tools_or_toolset(self.tools),
|
||||
tools_strict=self.tools_strict,
|
||||
)
|
||||
|
||||
@ -213,7 +217,7 @@ class OpenAIChatGenerator:
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||
tools_strict: Optional[bool] = None,
|
||||
):
|
||||
"""
|
||||
@ -228,8 +232,9 @@ class OpenAIChatGenerator:
|
||||
override the parameters passed during component initialization.
|
||||
For details on OpenAI API parameters, see [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create).
|
||||
:param tools:
|
||||
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
|
||||
during component initialization.
|
||||
A list of tools or a Toolset for which the model can prepare calls. If set, it will override the
|
||||
`tools` parameter set during component initialization. This parameter can accept either a list of
|
||||
`Tool` objects or a `Toolset` instance.
|
||||
:param tools_strict:
|
||||
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
|
||||
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
|
||||
@ -285,7 +290,7 @@ class OpenAIChatGenerator:
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||
tools_strict: Optional[bool] = None,
|
||||
):
|
||||
"""
|
||||
@ -304,8 +309,9 @@ class OpenAIChatGenerator:
|
||||
override the parameters passed during component initialization.
|
||||
For details on OpenAI API parameters, see [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create).
|
||||
:param tools:
|
||||
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
|
||||
during component initialization.
|
||||
A list of tools or a Toolset for which the model can prepare calls. If set, it will override the
|
||||
`tools` parameter set during component initialization. This parameter can accept either a list of
|
||||
`Tool` objects or a `Toolset` instance.
|
||||
:param tools_strict:
|
||||
Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
|
||||
the schema provided in the `parameters` field of the tool definition, but this may increase latency.
|
||||
@ -362,7 +368,7 @@ class OpenAIChatGenerator:
|
||||
messages: List[ChatMessage],
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tools: Optional[Union[List[Tool], Toolset]] = None,
|
||||
tools_strict: Optional[bool] = None,
|
||||
) -> Dict[str, Any]:
|
||||
# update generation kwargs by merging with the generation kwargs passed to the run method
|
||||
@ -372,6 +378,8 @@ class OpenAIChatGenerator:
|
||||
openai_formatted_messages = [message.to_openai_dict_format() for message in messages]
|
||||
|
||||
tools = tools or self.tools
|
||||
if isinstance(tools, Toolset):
|
||||
tools = list(tools)
|
||||
tools_strict = tools_strict if tools_strict is not None else self.tools_strict
|
||||
_check_duplicate_tool_names(tools)
|
||||
|
||||
|
||||
@ -4,13 +4,15 @@
|
||||
|
||||
import inspect
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.core.component.sockets import Sockets
|
||||
from haystack.dataclasses import ChatMessage, State, ToolCall
|
||||
from haystack.tools.component_tool import ComponentTool
|
||||
from haystack.tools.tool import Tool, ToolInvocationError, _check_duplicate_tool_names, deserialize_tools_inplace
|
||||
from haystack.tools.toolset import Toolset
|
||||
from haystack.utils.misc import serialize_tools_or_toolset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -57,7 +59,8 @@ class ToolInvoker:
|
||||
|
||||
Usage example:
|
||||
```python
|
||||
from haystack.dataclasses import ChatMessage, ToolCall, Tool
|
||||
from haystack.dataclasses import ChatMessage, ToolCall
|
||||
from haystack.tools import Tool
|
||||
from haystack.components.tools import ToolInvoker
|
||||
|
||||
# Tool definition
|
||||
@ -108,14 +111,55 @@ class ToolInvoker:
|
||||
>> ]
|
||||
>> }
|
||||
```
|
||||
|
||||
Usage example with a Toolset:
|
||||
```python
|
||||
from haystack.dataclasses import ChatMessage, ToolCall
|
||||
from haystack.tools import Tool, Toolset
|
||||
from haystack.components.tools import ToolInvoker
|
||||
|
||||
# Tool definition
|
||||
def dummy_weather_function(city: str):
|
||||
return f"The weather in {city} is 20 degrees."
|
||||
|
||||
parameters = {"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"]}
|
||||
|
||||
tool = Tool(name="weather_tool",
|
||||
description="A tool to get the weather",
|
||||
function=dummy_weather_function,
|
||||
parameters=parameters)
|
||||
|
||||
# Create a Toolset
|
||||
toolset = Toolset([tool])
|
||||
|
||||
# Usually, the ChatMessage with tool_calls is generated by a Language Model
|
||||
# Here, we create it manually for demonstration purposes
|
||||
tool_call = ToolCall(
|
||||
tool_name="weather_tool",
|
||||
arguments={"city": "Berlin"}
|
||||
)
|
||||
message = ChatMessage.from_assistant(tool_calls=[tool_call])
|
||||
|
||||
# ToolInvoker initialization and run with Toolset
|
||||
invoker = ToolInvoker(tools=toolset)
|
||||
result = invoker.run(messages=[message])
|
||||
|
||||
print(result)
|
||||
"""
|
||||
|
||||
def __init__(self, tools: List[Tool], raise_on_failure: bool = True, convert_result_to_json_string: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
tools: Union[List[Tool], Toolset],
|
||||
raise_on_failure: bool = True,
|
||||
convert_result_to_json_string: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the ToolInvoker component.
|
||||
|
||||
:param tools:
|
||||
A list of tools that can be invoked.
|
||||
A list of tools that can be invoked or a Toolset instance that can resolve tools.
|
||||
:param raise_on_failure:
|
||||
If True, the component will raise an exception in case of errors
|
||||
(tool not found, tool invocation errors, tool result conversion errors).
|
||||
@ -129,13 +173,20 @@ class ToolInvoker:
|
||||
"""
|
||||
if not tools:
|
||||
raise ValueError("ToolInvoker requires at least one tool.")
|
||||
|
||||
# could be a Toolset instance or a list of Tools
|
||||
self.tools = tools
|
||||
|
||||
# Convert Toolset to list for internal use
|
||||
if isinstance(tools, Toolset):
|
||||
tools = list(tools)
|
||||
|
||||
_check_duplicate_tool_names(tools)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
duplicates = {name for name in tool_names if tool_names.count(name) > 1}
|
||||
if duplicates:
|
||||
raise ValueError(f"Duplicate tool names found: {duplicates}")
|
||||
|
||||
self.tools = tools
|
||||
self._tools_with_names = dict(zip(tool_names, tools))
|
||||
self.raise_on_failure = raise_on_failure
|
||||
self.convert_result_to_json_string = convert_result_to_json_string
|
||||
@ -385,10 +436,9 @@ class ToolInvoker:
|
||||
:returns:
|
||||
Dictionary with serialized data.
|
||||
"""
|
||||
serialized_tools = [tool.to_dict() for tool in self.tools]
|
||||
return default_to_dict(
|
||||
self,
|
||||
tools=serialized_tools,
|
||||
tools=serialize_tools_or_toolset(self.tools),
|
||||
raise_on_failure=self.raise_on_failure,
|
||||
convert_result_to_json_string=self.convert_result_to_json_string,
|
||||
)
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
from .from_function import create_tool_from_function, tool
|
||||
from .tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
|
||||
from .component_tool import ComponentTool
|
||||
from .toolset import Toolset
|
||||
|
||||
|
||||
__all__ = [
|
||||
@ -17,4 +18,5 @@ __all__ = [
|
||||
"create_tool_from_function",
|
||||
"tool",
|
||||
"ComponentTool",
|
||||
"Toolset",
|
||||
]
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional
|
||||
from jsonschema import Draft202012Validator
|
||||
from jsonschema.exceptions import SchemaError
|
||||
|
||||
from haystack.core.errors import DeserializationError
|
||||
from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
|
||||
from haystack.tools.errors import ToolInvocationError
|
||||
from haystack.utils import deserialize_callable, serialize_callable
|
||||
@ -190,8 +191,23 @@ def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"):
|
||||
if serialized_tools is None:
|
||||
return
|
||||
|
||||
# Check if it's a serialized Toolset (a dict with "type" and "data" keys)
|
||||
if isinstance(serialized_tools, dict) and all(k in serialized_tools for k in ["type", "data"]):
|
||||
toolset_class_name = serialized_tools.get("type")
|
||||
if not toolset_class_name:
|
||||
raise DeserializationError("The 'type' key is missing or None in the serialized toolset data")
|
||||
|
||||
toolset_class = import_class_by_name(toolset_class_name)
|
||||
from haystack.tools.toolset import Toolset # avoid circular import
|
||||
|
||||
if not issubclass(toolset_class, Toolset):
|
||||
raise TypeError(f"Class '{toolset_class}' is not a subclass of Toolset")
|
||||
|
||||
data[key] = toolset_class.from_dict(serialized_tools)
|
||||
return
|
||||
|
||||
if not isinstance(serialized_tools, list):
|
||||
raise TypeError(f"The value of '{key}' is not a list")
|
||||
raise TypeError(f"The value of '{key}' is not a list or a dictionary")
|
||||
|
||||
deserialized_tools = []
|
||||
for tool in serialized_tools:
|
||||
|
||||
291
haystack/tools/toolset.py
Normal file
291
haystack/tools/toolset.py
Normal file
@ -0,0 +1,291 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Iterator, List, Union
|
||||
|
||||
from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
|
||||
from haystack.tools.tool import Tool, _check_duplicate_tool_names
|
||||
|
||||
|
||||
@dataclass
|
||||
class Toolset:
|
||||
"""
|
||||
A collection of related Tools that can be used and managed as a cohesive unit.
|
||||
|
||||
Toolset serves two main purposes:
|
||||
|
||||
1. Group related tools together:
|
||||
Toolset allows you to organize related tools into a single collection, making it easier
|
||||
to manage and use them as a unit in Haystack pipelines.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from haystack.tools import Tool, Toolset
|
||||
from haystack.components.tools import ToolInvoker
|
||||
|
||||
# Define math functions
|
||||
def add_numbers(a: int, b: int) -> int:
|
||||
return a + b
|
||||
|
||||
def subtract_numbers(a: int, b: int) -> int:
|
||||
return a - b
|
||||
|
||||
# Create tools with proper schemas
|
||||
add_tool = Tool(
|
||||
name="add",
|
||||
description="Add two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "integer"},
|
||||
"b": {"type": "integer"}
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
},
|
||||
function=add_numbers
|
||||
)
|
||||
|
||||
subtract_tool = Tool(
|
||||
name="subtract",
|
||||
description="Subtract b from a",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "integer"},
|
||||
"b": {"type": "integer"}
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
},
|
||||
function=subtract_numbers
|
||||
)
|
||||
|
||||
# Create a toolset with the math tools
|
||||
math_toolset = Toolset([add_tool, subtract_tool])
|
||||
|
||||
# Use the toolset with a ToolInvoker or ChatGenerator component
|
||||
invoker = ToolInvoker(tools=math_toolset)
|
||||
```
|
||||
|
||||
2. Base class for dynamic tool loading:
|
||||
By subclassing Toolset, you can create implementations that dynamically load tools
|
||||
from external sources like OpenAPI URLs, MCP servers, or other resources.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from haystack.core.serialization import generate_qualified_class_name
|
||||
from haystack.tools import Tool, Toolset
|
||||
from haystack.components.tools import ToolInvoker
|
||||
|
||||
class CalculatorToolset(Toolset):
|
||||
'''A toolset for calculator operations.'''
|
||||
|
||||
def __init__(self):
|
||||
tools = self._create_tools()
|
||||
super().__init__(tools)
|
||||
|
||||
def _create_tools(self):
|
||||
# These Tool instances are obviously defined statically and for illustration purposes only.
|
||||
# In a real-world scenario, you would dynamically load tools from an external source here.
|
||||
tools = []
|
||||
add_tool = Tool(
|
||||
name="add",
|
||||
description="Add two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=lambda a, b: a + b,
|
||||
)
|
||||
|
||||
multiply_tool = Tool(
|
||||
name="multiply",
|
||||
description="Multiply two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=lambda a, b: a * b,
|
||||
)
|
||||
|
||||
tools.append(add_tool)
|
||||
tools.append(multiply_tool)
|
||||
|
||||
return tools
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"type": generate_qualified_class_name(type(self)),
|
||||
"data": {}, # no data to serialize as we define the tools dynamically
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data):
|
||||
return cls() # Recreate the tools dynamically during deserialization
|
||||
|
||||
# Create the dynamic toolset and use it with ToolInvoker
|
||||
calculator_toolset = CalculatorToolset()
|
||||
invoker = ToolInvoker(tools=calculator_toolset)
|
||||
```
|
||||
|
||||
Toolset implements the collection interface (__iter__, __contains__, __len__, __getitem__),
|
||||
making it behave like a list of Tools. This makes it compatible with components that expect
|
||||
iterable tools, such as ToolInvoker or Haystack chat generators.
|
||||
|
||||
When implementing a custom Toolset subclass for dynamic tool loading:
|
||||
- Perform the dynamic loading in the __init__ method
|
||||
- Override to_dict() and from_dict() methods if your tools are defined dynamically
|
||||
- Serialize endpoint descriptors rather than tool instances if your tools
|
||||
are loaded from external sources
|
||||
"""
|
||||
|
||||
# Use field() with default_factory to initialize the list
|
||||
tools: List[Tool] = field(default_factory=list)
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Validate and set up the toolset after initialization.
|
||||
|
||||
This handles the case when tools are provided during initialization.
|
||||
"""
|
||||
# If initialization was done a single Tool, raise an error
|
||||
if isinstance(self.tools, Tool):
|
||||
raise TypeError("A single Tool cannot be directly passed to Toolset. Please use a list: Toolset([tool])")
|
||||
|
||||
# Check for duplicate tool names in the initial set
|
||||
_check_duplicate_tool_names(self.tools)
|
||||
|
||||
def __iter__(self) -> Iterator[Tool]:
|
||||
"""
|
||||
Return an iterator over the Tools in this Toolset.
|
||||
|
||||
This allows the Toolset to be used wherever a list of Tools is expected.
|
||||
|
||||
:returns: An iterator yielding Tool instances
|
||||
"""
|
||||
return iter(self.tools)
|
||||
|
||||
def __contains__(self, item: Any) -> bool:
|
||||
"""
|
||||
Check if a tool is in this Toolset.
|
||||
|
||||
Supports checking by:
|
||||
- Tool instance: tool in toolset
|
||||
- Tool name: "tool_name" in toolset
|
||||
|
||||
:param item: Tool instance or tool name string
|
||||
:returns: True if contained, False otherwise
|
||||
"""
|
||||
if isinstance(item, str):
|
||||
return any(tool.name == item for tool in self.tools)
|
||||
if isinstance(item, Tool):
|
||||
return item in self.tools
|
||||
return False
|
||||
|
||||
def add(self, tool: Union[Tool, "Toolset"]) -> None:
|
||||
"""
|
||||
Add a new Tool or merge another Toolset.
|
||||
|
||||
:param tool: A Tool instance or another Toolset to add
|
||||
:raises ValueError: If adding the tool would result in duplicate tool names
|
||||
:raises TypeError: If the provided object is not a Tool or Toolset
|
||||
"""
|
||||
new_tools = []
|
||||
|
||||
if isinstance(tool, Tool):
|
||||
new_tools = [tool]
|
||||
elif isinstance(tool, Toolset):
|
||||
new_tools = list(tool)
|
||||
else:
|
||||
raise TypeError(f"Expected Tool or Toolset, got {type(tool).__name__}")
|
||||
|
||||
# Check for duplicates before adding
|
||||
combined_tools = self.tools + new_tools
|
||||
_check_duplicate_tool_names(combined_tools)
|
||||
|
||||
self.tools.extend(new_tools)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serialize the Toolset to a dictionary.
|
||||
|
||||
:returns: A dictionary representation of the Toolset
|
||||
|
||||
Note for subclass implementers:
|
||||
The default implementation is ideal for scenarios where Tool resolution is static. However, if your subclass
|
||||
of Toolset dynamically resolves Tool instances from external sources—such as an MCP server, OpenAPI URL, or
|
||||
a local OpenAPI specification—you should consider serializing the endpoint descriptor instead of the Tool
|
||||
instances themselves. This strategy preserves the dynamic nature of your Toolset and minimizes the overhead
|
||||
associated with serializing potentially large collections of Tool objects. Moreover, by serializing the
|
||||
descriptor, you ensure that the deserialization process can accurately reconstruct the Tool instances, even
|
||||
if they have been modified or removed since the last serialization. Failing to serialize the descriptor may
|
||||
lead to issues where outdated or incorrect Tool configurations are loaded, potentially causing errors or
|
||||
unexpected behavior.
|
||||
"""
|
||||
return {
|
||||
"type": generate_qualified_class_name(type(self)),
|
||||
"data": {"tools": [tool.to_dict() for tool in self.tools]},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Toolset":
|
||||
"""
|
||||
Deserialize a Toolset from a dictionary.
|
||||
|
||||
:param data: Dictionary representation of the Toolset
|
||||
:returns: A new Toolset instance
|
||||
"""
|
||||
inner_data = data["data"]
|
||||
tools_data = inner_data.get("tools", [])
|
||||
|
||||
tools = []
|
||||
for tool_data in tools_data:
|
||||
tool_class = import_class_by_name(tool_data["type"])
|
||||
if not issubclass(tool_class, Tool):
|
||||
raise TypeError(f"Class '{tool_class}' is not a subclass of Tool")
|
||||
tools.append(tool_class.from_dict(tool_data))
|
||||
|
||||
return cls(tools=tools)
|
||||
|
||||
def __add__(self, other: Union[Tool, "Toolset", List[Tool]]) -> "Toolset":
|
||||
"""
|
||||
Concatenate this Toolset with another Tool, Toolset, or list of Tools.
|
||||
|
||||
:param other: Another Tool, Toolset, or list of Tools to concatenate
|
||||
:returns: A new Toolset containing all tools
|
||||
:raises TypeError: If the other parameter is not a Tool, Toolset, or list of Tools
|
||||
:raises ValueError: If the combination would result in duplicate tool names
|
||||
"""
|
||||
if isinstance(other, Tool):
|
||||
combined_tools = self.tools + [other]
|
||||
elif isinstance(other, Toolset):
|
||||
combined_tools = self.tools + list(other)
|
||||
elif isinstance(other, list) and all(isinstance(item, Tool) for item in other):
|
||||
combined_tools = self.tools + other
|
||||
else:
|
||||
raise TypeError(f"Cannot add {type(other).__name__} to Toolset")
|
||||
|
||||
# Check for duplicates
|
||||
_check_duplicate_tool_names(combined_tools)
|
||||
|
||||
return Toolset(tools=combined_tools)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Return the number of Tools in this Toolset.
|
||||
|
||||
:returns: Number of Tools
|
||||
"""
|
||||
return len(self.tools)
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Get a Tool by index.
|
||||
|
||||
:param index: Index of the Tool to get
|
||||
:returns: The Tool at the specified index
|
||||
"""
|
||||
return self.tools[index]
|
||||
@ -2,10 +2,12 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import List, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from numpy import exp
|
||||
|
||||
from haystack.tools import Tool, Toolset
|
||||
|
||||
|
||||
def expand_page_range(page_range: List[Union[str, int]]) -> List[int]:
|
||||
"""
|
||||
@ -52,3 +54,19 @@ def expit(x) -> float:
|
||||
:param x: input value. Can be a scalar or a numpy array.
|
||||
"""
|
||||
return 1 / (1 + exp(-x))
|
||||
|
||||
|
||||
def serialize_tools_or_toolset(
|
||||
tools: Union[Toolset, List[Tool], None],
|
||||
) -> Union[Dict[str, Any], List[Dict[str, Any]], None]:
|
||||
"""
|
||||
Serialize a Toolset or a list of Tools to a dictionary or a list of tool dictionaries.
|
||||
|
||||
:param tools: A Toolset, a list of Tools, or None
|
||||
:returns: A dictionary, a list of tool dictionaries, or None if tools is None
|
||||
"""
|
||||
if tools is None:
|
||||
return None
|
||||
if isinstance(tools, Toolset):
|
||||
return tools.to_dict()
|
||||
return [tool.to_dict() for tool in tools]
|
||||
|
||||
@ -0,0 +1,3 @@
|
||||
features:
|
||||
- |
|
||||
Introduced the Toolset class, allowing for the grouping and management of related tool functionalities. This new abstraction supports dynamic tool loading and registration.
|
||||
@ -22,6 +22,7 @@ from haystack.utils.auth import Secret
|
||||
from haystack.dataclasses import ChatMessage, ToolCall
|
||||
from haystack.tools import Tool
|
||||
from haystack.components.generators.chat.openai import OpenAIChatGenerator
|
||||
from haystack.tools.toolset import Toolset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -71,6 +72,10 @@ def mock_chat_completion_chunk_with_tools(openai_mock_stream):
|
||||
yield mock_chat_completion_create
|
||||
|
||||
|
||||
def mock_tool_function(x):
|
||||
return x
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tools():
|
||||
tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
|
||||
@ -78,7 +83,7 @@ def tools():
|
||||
name="weather",
|
||||
description="useful to determine the weather in a given location",
|
||||
parameters=tool_parameters,
|
||||
function=lambda x: x,
|
||||
function=mock_tool_function,
|
||||
)
|
||||
|
||||
return [tool]
|
||||
@ -269,6 +274,7 @@ class TestOpenAIChatGenerator:
|
||||
"api_base_url": "test-base-url",
|
||||
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
|
||||
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
|
||||
"tools": None,
|
||||
},
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
@ -937,3 +943,45 @@ class TestOpenAIChatGenerator:
|
||||
assert tool_call.tool_name == "weather"
|
||||
assert tool_call.arguments == {"city": "Paris"}
|
||||
assert message.meta["finish_reason"] == "tool_calls"
|
||||
|
||||
def test_openai_chat_generator_with_toolset_initialization(self, tools, monkeypatch):
|
||||
"""Test that the OpenAIChatGenerator can be initialized with a Toolset."""
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||
toolset = Toolset(tools)
|
||||
generator = OpenAIChatGenerator(tools=toolset)
|
||||
assert generator.tools == toolset
|
||||
|
||||
def test_from_dict_with_toolset(self, tools, monkeypatch):
|
||||
"""Test that the OpenAIChatGenerator can be deserialized from a dictionary with a Toolset."""
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||
toolset = Toolset(tools)
|
||||
component = OpenAIChatGenerator(tools=toolset)
|
||||
data = component.to_dict()
|
||||
|
||||
deserialized_component = OpenAIChatGenerator.from_dict(data)
|
||||
|
||||
assert isinstance(deserialized_component.tools, Toolset)
|
||||
assert len(deserialized_component.tools) == len(tools)
|
||||
assert all(isinstance(tool, Tool) for tool in deserialized_component.tools)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("OPENAI_API_KEY", None),
|
||||
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
|
||||
)
|
||||
@pytest.mark.integration
|
||||
def test_live_run_with_toolset(self, tools):
|
||||
chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
|
||||
toolset = Toolset(tools)
|
||||
component = OpenAIChatGenerator(tools=toolset)
|
||||
results = component.run(chat_messages)
|
||||
assert len(results["replies"]) == 1
|
||||
message = results["replies"][0]
|
||||
|
||||
assert not message.texts
|
||||
assert not message.text
|
||||
assert message.tool_calls
|
||||
tool_call = message.tool_call
|
||||
assert isinstance(tool_call, ToolCall)
|
||||
assert tool_call.tool_name == "weather"
|
||||
assert tool_call.arguments == {"city": "Paris"}
|
||||
assert message.meta["finish_reason"] == "tool_calls"
|
||||
|
||||
587
test/tools/test_toolset.py
Normal file
587
test/tools/test_toolset.py
Normal file
@ -0,0 +1,587 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import os
|
||||
from haystack import Pipeline
|
||||
from haystack.core.serialization import generate_qualified_class_name
|
||||
from haystack.dataclasses import ChatMessage
|
||||
from haystack.components.tools import ToolInvoker
|
||||
from haystack.dataclasses.chat_message import ToolCall
|
||||
from haystack.tools import Tool, Toolset
|
||||
from haystack.tools.errors import ToolInvocationError
|
||||
|
||||
|
||||
# Common functions for tests
|
||||
def add_numbers(a: int, b: int) -> int:
|
||||
"""Add two numbers."""
|
||||
return a + b
|
||||
|
||||
|
||||
def multiply_numbers(a: int, b: int) -> int:
|
||||
"""Multiply two numbers."""
|
||||
return a * b
|
||||
|
||||
|
||||
def subtract_numbers(a: int, b: int) -> int:
|
||||
"""Subtract b from a."""
|
||||
return a - b
|
||||
|
||||
|
||||
class CustomToolset(Toolset):
|
||||
def __init__(self, tools, custom_attr):
|
||||
super().__init__(tools)
|
||||
self.custom_attr = custom_attr
|
||||
|
||||
def to_dict(self):
|
||||
data = super().to_dict()
|
||||
data["custom_attr"] = self.custom_attr
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data):
|
||||
tools = [Tool.from_dict(tool_data) for tool_data in data["data"]["tools"]]
|
||||
custom_attr = data["custom_attr"]
|
||||
return cls(tools=tools, custom_attr=custom_attr)
|
||||
|
||||
|
||||
class CalculatorToolset(Toolset):
|
||||
"""A toolset for calculator operations."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__([])
|
||||
self._create_tools()
|
||||
|
||||
def _create_tools(self):
|
||||
add_tool = Tool(
|
||||
name="add",
|
||||
description="Add two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=add_numbers,
|
||||
)
|
||||
|
||||
multiply_tool = Tool(
|
||||
name="multiply",
|
||||
description="Multiply two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=multiply_numbers,
|
||||
)
|
||||
|
||||
self.add(add_tool)
|
||||
self.add(multiply_tool)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"type": generate_qualified_class_name(type(self)),
|
||||
"data": {}, # no data to serialize as we define the tools dynamically
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data):
|
||||
return cls()
|
||||
|
||||
|
||||
def weather_function(location):
|
||||
weather_info = {
|
||||
"Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
|
||||
"Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
|
||||
"Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
|
||||
}
|
||||
return weather_info.get(location, {"weather": "unknown", "temperature": 0, "unit": "celsius"})
|
||||
|
||||
|
||||
weather_parameters = {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def weather_tool():
|
||||
return Tool(
|
||||
name="weather_tool",
|
||||
description="Provides weather information for a given location.",
|
||||
parameters=weather_parameters,
|
||||
function=weather_function,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def faulty_tool():
|
||||
def faulty_tool_func(location):
|
||||
raise Exception("This tool always fails.")
|
||||
|
||||
faulty_tool_parameters = {
|
||||
"type": "object",
|
||||
"properties": {"location": {"type": "string"}},
|
||||
"required": ["location"],
|
||||
}
|
||||
|
||||
return Tool(
|
||||
name="faulty_tool",
|
||||
description="A tool that always fails when invoked.",
|
||||
parameters=faulty_tool_parameters,
|
||||
function=faulty_tool_func,
|
||||
)
|
||||
|
||||
|
||||
class TestToolset:
|
||||
def test_toolset_with_multiple_tools(self):
|
||||
"""Test that a Toolset with multiple tools works properly."""
|
||||
add_tool = Tool(
|
||||
name="add",
|
||||
description="Add two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=add_numbers,
|
||||
)
|
||||
|
||||
multiply_tool = Tool(
|
||||
name="multiply",
|
||||
description="Multiply two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=multiply_numbers,
|
||||
)
|
||||
|
||||
toolset = Toolset([add_tool, multiply_tool])
|
||||
|
||||
assert len(toolset) == 2
|
||||
assert toolset[0].name == "add"
|
||||
assert toolset[1].name == "multiply"
|
||||
|
||||
invoker = ToolInvoker(tools=toolset)
|
||||
|
||||
add_call = ToolCall(tool_name="add", arguments={"a": 2, "b": 3})
|
||||
add_message = ChatMessage.from_assistant(tool_calls=[add_call])
|
||||
|
||||
multiply_call = ToolCall(tool_name="multiply", arguments={"a": 4, "b": 5})
|
||||
multiply_message = ChatMessage.from_assistant(tool_calls=[multiply_call])
|
||||
|
||||
result = invoker.run(messages=[add_message, multiply_message])
|
||||
|
||||
assert len(result["tool_messages"]) == 2
|
||||
tool_results = [message.tool_call_result.result for message in result["tool_messages"]]
|
||||
assert "5" in tool_results
|
||||
assert "20" in tool_results
|
||||
|
||||
def test_toolset_adding(self):
|
||||
"""Test that tools can be added to a Toolset."""
|
||||
toolset = Toolset()
|
||||
assert len(toolset) == 0
|
||||
|
||||
add_tool = Tool(
|
||||
name="add",
|
||||
description="Add two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=add_numbers,
|
||||
)
|
||||
|
||||
toolset.add(add_tool)
|
||||
assert len(toolset) == 1
|
||||
assert toolset[0].name == "add"
|
||||
|
||||
invoker = ToolInvoker(tools=toolset)
|
||||
tool_call = ToolCall(tool_name="add", arguments={"a": 2, "b": 3})
|
||||
message = ChatMessage.from_assistant(tool_calls=[tool_call])
|
||||
result = invoker.run(messages=[message])
|
||||
|
||||
assert len(result["tool_messages"]) == 1
|
||||
assert result["tool_messages"][0].tool_call_result.result == "5"
|
||||
|
||||
def test_toolset_addition(self):
|
||||
"""Test that toolsets can be combined."""
|
||||
add_tool = Tool(
|
||||
name="add",
|
||||
description="Add two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=add_numbers,
|
||||
)
|
||||
|
||||
multiply_tool = Tool(
|
||||
name="multiply",
|
||||
description="Multiply two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=multiply_numbers,
|
||||
)
|
||||
|
||||
subtract_tool = Tool(
|
||||
name="subtract",
|
||||
description="Subtract two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=subtract_numbers,
|
||||
)
|
||||
|
||||
toolset1 = Toolset([add_tool])
|
||||
toolset2 = Toolset([multiply_tool])
|
||||
|
||||
combined_toolset = toolset1 + toolset2
|
||||
assert len(combined_toolset) == 2
|
||||
|
||||
combined_toolset = combined_toolset + subtract_tool
|
||||
assert len(combined_toolset) == 3
|
||||
|
||||
tool_names = [tool.name for tool in combined_toolset]
|
||||
assert "add" in tool_names
|
||||
assert "multiply" in tool_names
|
||||
assert "subtract" in tool_names
|
||||
|
||||
invoker = ToolInvoker(tools=combined_toolset)
|
||||
|
||||
add_call = ToolCall(tool_name="add", arguments={"a": 10, "b": 5})
|
||||
multiply_call = ToolCall(tool_name="multiply", arguments={"a": 10, "b": 5})
|
||||
subtract_call = ToolCall(tool_name="subtract", arguments={"a": 10, "b": 5})
|
||||
|
||||
message = ChatMessage.from_assistant(tool_calls=[add_call, multiply_call, subtract_call])
|
||||
|
||||
result = invoker.run(messages=[message])
|
||||
|
||||
assert len(result["tool_messages"]) == 3
|
||||
tool_results = [message.tool_call_result.result for message in result["tool_messages"]]
|
||||
assert "15" in tool_results
|
||||
assert "50" in tool_results
|
||||
assert "5" in tool_results
|
||||
|
||||
def test_toolset_contains(self):
|
||||
"""Test that the __contains__ method works correctly."""
|
||||
add_tool = Tool(
|
||||
name="add",
|
||||
description="Add two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=add_numbers,
|
||||
)
|
||||
|
||||
multiply_tool = Tool(
|
||||
name="multiply",
|
||||
description="Multiply two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=multiply_numbers,
|
||||
)
|
||||
|
||||
toolset = Toolset([add_tool])
|
||||
|
||||
# Test with a tool instance
|
||||
assert add_tool in toolset
|
||||
assert multiply_tool not in toolset
|
||||
|
||||
# Test with a tool name
|
||||
assert "add" in toolset
|
||||
assert "multiply" not in toolset
|
||||
assert "non_existent_tool" not in toolset
|
||||
|
||||
def test_toolset_add_various_types(self):
|
||||
"""Test that the __add__ method works with various object types."""
|
||||
add_tool = Tool(
|
||||
name="add",
|
||||
description="Add two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=add_numbers,
|
||||
)
|
||||
|
||||
multiply_tool = Tool(
|
||||
name="multiply",
|
||||
description="Multiply two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=multiply_numbers,
|
||||
)
|
||||
|
||||
subtract_tool = Tool(
|
||||
name="subtract",
|
||||
description="Subtract two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=subtract_numbers,
|
||||
)
|
||||
|
||||
# Test adding a single tool
|
||||
toolset1 = Toolset([add_tool])
|
||||
result1 = toolset1 + multiply_tool
|
||||
assert len(result1) == 2
|
||||
assert add_tool in result1
|
||||
assert multiply_tool in result1
|
||||
|
||||
# Test adding another toolset
|
||||
toolset2 = Toolset([subtract_tool])
|
||||
result2 = toolset1 + toolset2
|
||||
assert len(result2) == 2
|
||||
assert add_tool in result2
|
||||
assert subtract_tool in result2
|
||||
|
||||
# Test adding a list of tools
|
||||
result3 = toolset1 + [multiply_tool, subtract_tool]
|
||||
assert len(result3) == 3
|
||||
assert add_tool in result3
|
||||
assert multiply_tool in result3
|
||||
assert subtract_tool in result3
|
||||
|
||||
# Test adding types that aren't supported
|
||||
with pytest.raises(TypeError):
|
||||
toolset1 + "not_a_tool"
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
toolset1 + 123
|
||||
|
||||
def test_toolset_serialization(self):
|
||||
"""Test that a Toolset can be serialized and deserialized."""
|
||||
add_tool = Tool(
|
||||
name="add",
|
||||
description="Add two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=add_numbers,
|
||||
)
|
||||
|
||||
toolset = Toolset([add_tool])
|
||||
|
||||
serialized = toolset.to_dict()
|
||||
|
||||
deserialized = Toolset.from_dict(serialized)
|
||||
|
||||
assert len(deserialized) == 1
|
||||
assert deserialized[0].name == "add"
|
||||
assert deserialized[0].description == "Add two numbers"
|
||||
|
||||
invoker = ToolInvoker(tools=deserialized)
|
||||
tool_call = ToolCall(tool_name="add", arguments={"a": 2, "b": 3})
|
||||
message = ChatMessage.from_assistant(tool_calls=[tool_call])
|
||||
result = invoker.run(messages=[message])
|
||||
|
||||
assert len(result["tool_messages"]) == 1
|
||||
assert result["tool_messages"][0].tool_call_result.result == "5"
|
||||
|
||||
def test_custom_toolset_serialization(self):
|
||||
"""Test serialization and deserialization of a custom Toolset subclass."""
|
||||
add_tool = Tool(
|
||||
name="add",
|
||||
description="Add two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=add_numbers,
|
||||
)
|
||||
|
||||
custom_attr_value = "custom_value"
|
||||
custom_toolset = CustomToolset(tools=[add_tool], custom_attr=custom_attr_value)
|
||||
|
||||
serialized = custom_toolset.to_dict()
|
||||
assert serialized["type"].endswith("CustomToolset")
|
||||
assert serialized["custom_attr"] == custom_attr_value
|
||||
assert len(serialized["data"]["tools"]) == 1
|
||||
assert serialized["data"]["tools"][0]["data"]["name"] == "add"
|
||||
|
||||
deserialized = CustomToolset.from_dict(serialized)
|
||||
assert isinstance(deserialized, CustomToolset)
|
||||
assert deserialized.custom_attr == custom_attr_value
|
||||
assert len(deserialized) == 1
|
||||
assert deserialized[0].name == "add"
|
||||
|
||||
invoker = ToolInvoker(tools=deserialized)
|
||||
tool_call = ToolCall(tool_name="add", arguments={"a": 2, "b": 3})
|
||||
message = ChatMessage.from_assistant(tool_calls=[tool_call])
|
||||
result = invoker.run(messages=[message])
|
||||
|
||||
assert len(result["tool_messages"]) == 1
|
||||
assert result["tool_messages"][0].tool_call_result.result == "5"
|
||||
|
||||
def test_toolset_duplicate_tool_names(self):
|
||||
"""Test that a Toolset raises an error for duplicate tool names."""
|
||||
add_tool1 = Tool(
|
||||
name="add",
|
||||
description="Add two numbers (first)",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=add_numbers,
|
||||
)
|
||||
|
||||
add_tool2 = Tool(
|
||||
name="add",
|
||||
description="Add two numbers (second)",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=add_numbers,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Duplicate tool names found"):
|
||||
Toolset([add_tool1, add_tool2])
|
||||
|
||||
toolset = Toolset([add_tool1])
|
||||
|
||||
with pytest.raises(ValueError, match="Duplicate tool names found"):
|
||||
toolset.add(add_tool2)
|
||||
|
||||
toolset2 = Toolset([add_tool2])
|
||||
with pytest.raises(ValueError, match="Duplicate tool names found"):
|
||||
combined = toolset + toolset2
|
||||
|
||||
|
||||
class TestToolsetIntegration:
|
||||
"""Integration tests for Toolset in complete pipelines."""
|
||||
|
||||
def test_custom_toolset_serde_in_pipeline(self):
|
||||
"""Test serialization and deserialization of a custom toolset within a pipeline."""
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_component("tool_invoker", ToolInvoker(tools=CalculatorToolset()))
|
||||
|
||||
pipeline_dict = pipeline.to_dict()
|
||||
|
||||
tool_invoker_dict = pipeline_dict["components"]["tool_invoker"]
|
||||
assert tool_invoker_dict["type"] == "haystack.components.tools.tool_invoker.ToolInvoker"
|
||||
assert len(tool_invoker_dict["init_parameters"]["tools"]["data"]) == 0
|
||||
|
||||
new_pipeline = Pipeline.from_dict(pipeline_dict)
|
||||
assert new_pipeline == pipeline
|
||||
|
||||
def test_regular_toolset_serde_in_pipeline(self):
|
||||
"""Test serialization and deserialization of regular Toolsets within a pipeline."""
|
||||
|
||||
add_tool = Tool(
|
||||
name="add",
|
||||
description="Add two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=add_numbers,
|
||||
)
|
||||
|
||||
multiply_tool = Tool(
|
||||
name="multiply",
|
||||
description="Multiply two numbers",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
function=multiply_numbers,
|
||||
)
|
||||
|
||||
toolset = Toolset([add_tool, multiply_tool])
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_component("tool_invoker", ToolInvoker(tools=toolset))
|
||||
|
||||
pipeline_dict = pipeline.to_dict()
|
||||
|
||||
tool_invoker_dict = pipeline_dict["components"]["tool_invoker"]
|
||||
assert tool_invoker_dict["type"] == "haystack.components.tools.tool_invoker.ToolInvoker"
|
||||
|
||||
# Verify the serialized toolset
|
||||
tools_dict = tool_invoker_dict["init_parameters"]["tools"]
|
||||
assert tools_dict["type"] == "haystack.tools.toolset.Toolset"
|
||||
assert len(tools_dict["data"]["tools"]) == 2
|
||||
tool_names = [tool["data"]["name"] for tool in tools_dict["data"]["tools"]]
|
||||
assert "add" in tool_names
|
||||
assert "multiply" in tool_names
|
||||
|
||||
# Deserialize and verify
|
||||
new_pipeline = Pipeline.from_dict(pipeline_dict)
|
||||
assert new_pipeline == pipeline
|
||||
|
||||
|
||||
class TestToolsetWithToolInvoker:
|
||||
def test_init_with_toolset(self, weather_tool):
|
||||
"""Test initializing ToolInvoker with a Toolset."""
|
||||
toolset = Toolset(tools=[weather_tool])
|
||||
invoker = ToolInvoker(tools=toolset)
|
||||
assert invoker.tools == toolset
|
||||
assert invoker._tools_with_names == {tool.name: tool for tool in toolset}
|
||||
|
||||
def test_serde_with_toolset(self, weather_tool):
|
||||
"""Test serialization and deserialization of ToolInvoker with a Toolset."""
|
||||
toolset = Toolset(tools=[weather_tool])
|
||||
invoker = ToolInvoker(tools=toolset)
|
||||
data = invoker.to_dict()
|
||||
deserialized_invoker = ToolInvoker.from_dict(data)
|
||||
assert deserialized_invoker.tools == invoker.tools
|
||||
assert deserialized_invoker._tools_with_names == invoker._tools_with_names
|
||||
assert deserialized_invoker.raise_on_failure == invoker.raise_on_failure
|
||||
assert deserialized_invoker.convert_result_to_json_string == invoker.convert_result_to_json_string
|
||||
|
||||
def test_tool_invocation_error_with_toolset(self, faulty_tool):
|
||||
"""Test tool invocation errors with a Toolset."""
|
||||
toolset = Toolset(tools=[faulty_tool])
|
||||
invoker = ToolInvoker(tools=toolset)
|
||||
tool_call = ToolCall(tool_name="faulty_tool", arguments={"location": "Berlin"})
|
||||
tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call])
|
||||
with pytest.raises(ToolInvocationError):
|
||||
invoker.run(messages=[tool_call_message])
|
||||
|
||||
def test_toolinvoker_deserialization_with_custom_toolset(self, weather_tool):
|
||||
"""Test deserialization of ToolInvoker with a custom Toolset."""
|
||||
custom_toolset = CustomToolset(tools=[weather_tool], custom_attr="custom_value")
|
||||
invoker = ToolInvoker(tools=custom_toolset)
|
||||
data = invoker.to_dict()
|
||||
|
||||
assert isinstance(data, dict)
|
||||
assert "type" in data and "init_parameters" in data
|
||||
tools_data = data["init_parameters"]["tools"]
|
||||
assert isinstance(tools_data, dict)
|
||||
assert len(tools_data["data"]["tools"]) == 1
|
||||
assert tools_data["data"]["tools"][0]["type"] == "haystack.tools.tool.Tool"
|
||||
assert tools_data.get("custom_attr") == "custom_value"
|
||||
|
||||
deserialized_invoker = ToolInvoker.from_dict(data)
|
||||
assert deserialized_invoker.tools == invoker.tools
|
||||
assert deserialized_invoker._tools_with_names == invoker._tools_with_names
|
||||
assert deserialized_invoker.raise_on_failure == invoker.raise_on_failure
|
||||
assert deserialized_invoker.convert_result_to_json_string == invoker.convert_result_to_json_string
|
||||
Loading…
x
Reference in New Issue
Block a user