haystack/haystack/components/tools/tool_invoker.py
Stefano Fiorucci 7dcbf25bd7
feat: add Tool Invoker component (#8664)
* port toolinvoker

* release note
2024-12-20 14:02:42 +01:00

247 lines
9.3 KiB
Python

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import json
import warnings
from typing import Any, Dict, List
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses.chat_message import ChatMessage, ToolCall
from haystack.dataclasses.tool import Tool, ToolInvocationError, _check_duplicate_tool_names, deserialize_tools_inplace
logger = logging.getLogger(__name__)

# Templates for the error messages ToolInvoker builds when a tool cannot be found,
# when a tool invocation fails, or when a tool result cannot be turned into a string.
_TOOL_INVOCATION_FAILURE = "Tool invocation failed with error: {error}."
_TOOL_NOT_FOUND = (
    "Tool {tool_name} not found in the list of tools. "
    "Available tools are: {available_tools}."
)
_TOOL_RESULT_CONVERSION_FAILURE = (
    "Failed to convert tool result to string using '{conversion_function}'. "
    "Error: {error}."
)
class ToolNotFoundException(Exception):
    """
    Raised when a tool call names a tool that is not among the tools the invoker was given.
    """
class StringConversionError(Exception):
    """
    Raised when a tool's return value cannot be turned into a string.
    """
@component
class ToolInvoker:
    """
    Invokes tools based on prepared tool calls and returns the results as a list of ChatMessage objects.

    At initialization, the ToolInvoker component is provided with a list of available tools.
    At runtime, the component processes a list of ChatMessage objects containing tool calls
    and invokes the corresponding tools.
    The results of the tool invocations are returned as a list of ChatMessage objects with tool role.

    Usage example:
    ```python
    from haystack.dataclasses import ChatMessage, ToolCall, Tool
    from haystack.components.tools import ToolInvoker

    # Tool definition
    def dummy_weather_function(city: str):
        return f"The weather in {city} is 20 degrees."

    parameters = {"type": "object",
                  "properties": {"city": {"type": "string"}},
                  "required": ["city"]}

    tool = Tool(name="weather_tool",
                description="A tool to get the weather",
                function=dummy_weather_function,
                parameters=parameters)

    # Usually, the ChatMessage with tool_calls is generated by a Language Model
    # Here, we create it manually for demonstration purposes
    tool_call = ToolCall(
        tool_name="weather_tool",
        arguments={"city": "Berlin"}
    )
    message = ChatMessage.from_assistant(tool_calls=[tool_call])

    # ToolInvoker initialization and run
    invoker = ToolInvoker(tools=[tool])
    result = invoker.run(messages=[message])

    print(result)
    ```

    ```
    >>  {
    >>      'tool_messages': [
    >>          ChatMessage(
    >>              _role=<ChatRole.TOOL: 'tool'>,
    >>              _content=[
    >>                  ToolCallResult(
    >>                      result='The weather in Berlin is 20 degrees.',
    >>                      origin=ToolCall(
    >>                          tool_name='weather_tool',
    >>                          arguments={'city': 'Berlin'},
    >>                          id=None
    >>                      )
    >>                  )
    >>              ],
    >>              _meta={}
    >>          )
    >>      ]
    >>  }
    ```
    """

    def __init__(self, tools: List[Tool], raise_on_failure: bool = True, convert_result_to_json_string: bool = False):
        """
        Initialize the ToolInvoker component.

        :param tools:
            A list of tools that can be invoked.
        :param raise_on_failure:
            If True, the component will raise an exception in case of errors
            (tool not found, tool invocation errors, tool result conversion errors).
            If False, the component will return a ChatMessage object with `error=True`
            and a description of the error in `result`.
        :param convert_result_to_json_string:
            If True, the tool invocation result will be converted to a string using `json.dumps`.
            If False, the tool invocation result will be converted to a string using `str`.
        :raises ValueError:
            If no tools are provided or if duplicate tool names are found.
        """
        msg = "The `ToolInvoker` component is experimental and its API may change in the future."
        warnings.warn(msg)

        if not tools:
            raise ValueError("ToolInvoker requires at least one tool to be provided.")
        _check_duplicate_tool_names(tools)

        # Kept as given for serialization in `to_dict`.
        self.tools = tools
        # Name -> Tool lookup used by `run` to dispatch each tool call.
        self._tools_with_names = {tool.name: tool for tool in tools}
        self.raise_on_failure = raise_on_failure
        self.convert_result_to_json_string = convert_result_to_json_string

    def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall) -> ChatMessage:
        """
        Prepares a ChatMessage with the result of a tool invocation.

        :param result:
            The tool result.
        :param tool_call:
            The ToolCall that produced `result`; it is attached to the message as its origin.
        :returns:
            A ChatMessage object containing the tool result as a string.
        :raises StringConversionError:
            If the conversion of the tool result to a string fails and `raise_on_failure` is True.
        """
        conversion_function = "json.dumps" if self.convert_result_to_json_string else "str"
        error = False
        try:
            if self.convert_result_to_json_string:
                # ensure_ascii=False keeps special chars like emojis as-is instead of \u-escaping them
                tool_result_str = json.dumps(result, ensure_ascii=False)
            else:
                tool_result_str = str(result)
        except Exception as e:
            if self.raise_on_failure:
                raise StringConversionError(
                    f"Failed to convert tool result to string using `{conversion_function}`"
                ) from e
            # Best-effort mode: report the conversion failure to the caller as an error message.
            tool_result_str = _TOOL_RESULT_CONVERSION_FAILURE.format(
                error=e, conversion_function=conversion_function
            )
            error = True

        return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call)

    @component.output_types(tool_messages=List[ChatMessage])
    def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
        """
        Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.

        :param messages:
            A list of ChatMessage objects.
        :returns:
            A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
            Each ChatMessage object wraps the result of a tool invocation.

        :raises ToolNotFoundException:
            If the tool is not found in the list of available tools and `raise_on_failure` is True.
        :raises ToolInvocationError:
            If the tool invocation fails and `raise_on_failure` is True.
        :raises StringConversionError:
            If the conversion of the tool result to a string fails and `raise_on_failure` is True.
        """
        tool_messages = []

        for message in messages:
            tool_calls = message.tool_calls
            # Messages without tool calls produce no output.
            if not tool_calls:
                continue

            for tool_call in tool_calls:
                tool_name = tool_call.tool_name

                if tool_name not in self._tools_with_names:
                    # list(...) renders as "['a', 'b']" rather than "dict_keys(['a', 'b'])"
                    msg = _TOOL_NOT_FOUND.format(tool_name=tool_name, available_tools=list(self._tools_with_names))
                    if self.raise_on_failure:
                        raise ToolNotFoundException(msg)
                    tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True))
                    continue

                tool_to_invoke = self._tools_with_names[tool_name]
                try:
                    tool_result = tool_to_invoke.invoke(**tool_call.arguments)
                except ToolInvocationError as e:
                    if self.raise_on_failure:
                        raise
                    msg = _TOOL_INVOCATION_FAILURE.format(error=e)
                    tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True))
                    continue

                tool_messages.append(self._prepare_tool_result_message(tool_result, tool_call))

        return {"tool_messages": tool_messages}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialized_tools = [tool.to_dict() for tool in self.tools]
        return default_to_dict(
            self,
            tools=serialized_tools,
            raise_on_failure=self.raise_on_failure,
            convert_result_to_json_string=self.convert_result_to_json_string,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ToolInvoker":
        """
        Deserializes the component from a dictionary.

        The serialized tools under `data["init_parameters"]["tools"]` are turned back into Tool
        objects by `deserialize_tools_inplace` before the component is constructed.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        deserialize_tools_inplace(data["init_parameters"], key="tools")
        return default_from_dict(cls, data)