mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-01 01:27:28 +00:00
247 lines
9.3 KiB
Python
247 lines
9.3 KiB
Python
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import json
|
|
import warnings
|
|
from typing import Any, Dict, List
|
|
|
|
from haystack import component, default_from_dict, default_to_dict, logging
|
|
from haystack.dataclasses.chat_message import ChatMessage, ToolCall
|
|
from haystack.dataclasses.tool import Tool, ToolInvocationError, _check_duplicate_tool_names, deserialize_tools_inplace
|
|
|
|
logger = logging.getLogger(__name__)

# Error-message templates used by ToolInvoker; placeholders are filled with `str.format`.
# `{tool_name}` / `{available_tools}`: the requested tool and the names the invoker knows about.
_TOOL_NOT_FOUND = "Tool {tool_name} not found in the list of tools. Available tools are: {available_tools}."
# `{error}`: the exception raised while invoking the tool.
_TOOL_INVOCATION_FAILURE = "Tool invocation failed with error: {error}."
# `{conversion_function}`: name of the function used to stringify the result; `{error}`: the exception it raised.
_TOOL_RESULT_CONVERSION_FAILURE = (
    "Failed to convert tool result to string using '{conversion_function}'. Error: {error}."
)
|
|
|
|
|
|
class ToolNotFoundException(Exception):
    """
    Raised when a tool call names a tool that is not among the tools available to the invoker.
    """
|
|
|
|
|
|
class StringConversionError(Exception):
    """
    Raised when a tool result cannot be turned into a string.
    """
|
|
|
|
|
|
@component
class ToolInvoker:
    """
    Invokes tools based on prepared tool calls and returns the results as a list of ChatMessage objects.

    At initialization, the ToolInvoker component is provided with a list of available tools.
    At runtime, the component processes a list of ChatMessage objects containing tool calls
    and invokes the corresponding tools.
    The results of the tool invocations are returned as a list of ChatMessage objects with tool role.

    Usage example:
    ```python
    from haystack.dataclasses import ChatMessage, ToolCall, Tool
    from haystack.components.tools import ToolInvoker

    # Tool definition
    def dummy_weather_function(city: str):
        return f"The weather in {city} is 20 degrees."

    parameters = {"type": "object",
                  "properties": {"city": {"type": "string"}},
                  "required": ["city"]}

    tool = Tool(name="weather_tool",
                description="A tool to get the weather",
                function=dummy_weather_function,
                parameters=parameters)

    # Usually, the ChatMessage with tool_calls is generated by a Language Model
    # Here, we create it manually for demonstration purposes
    tool_call = ToolCall(
        tool_name="weather_tool",
        arguments={"city": "Berlin"}
    )
    message = ChatMessage.from_assistant(tool_calls=[tool_call])

    # ToolInvoker initialization and run
    invoker = ToolInvoker(tools=[tool])
    result = invoker.run(messages=[message])

    print(result)
    ```

    ```
    >>  {
    >>      'tool_messages': [
    >>          ChatMessage(
    >>              _role=<ChatRole.TOOL: 'tool'>,
    >>              _content=[
    >>                  ToolCallResult(
    >>                      result='The weather in Berlin is 20 degrees.',
    >>                      origin=ToolCall(
    >>                          tool_name='weather_tool',
    >>                          arguments={'city': 'Berlin'},
    >>                          id=None
    >>                      )
    >>                  )
    >>              ],
    >>              _meta={}
    >>          )
    >>      ]
    >>  }
    ```
    """

    def __init__(self, tools: List[Tool], raise_on_failure: bool = True, convert_result_to_json_string: bool = False):
        """
        Initialize the ToolInvoker component.

        :param tools:
            A list of tools that can be invoked.
        :param raise_on_failure:
            If True, the component will raise an exception in case of errors
            (tool not found, tool invocation errors, tool result conversion errors).
            If False, the component will return a ChatMessage object with `error=True`
            and a description of the error in `result`.
        :param convert_result_to_json_string:
            If True, the tool invocation result will be converted to a string using `json.dumps`.
            If False, the tool invocation result will be converted to a string using `str`.

        :raises ValueError:
            If no tools are provided or if duplicate tool names are found.
        """
        msg = "The `ToolInvoker` component is experimental and its API may change in the future."
        warnings.warn(msg)

        if not tools:
            raise ValueError("ToolInvoker requires at least one tool to be provided.")
        _check_duplicate_tool_names(tools)

        self.tools = tools
        # Name -> Tool lookup table used by `run` to resolve each tool call.
        self._tools_with_names = {tool.name: tool for tool in tools}
        self.raise_on_failure = raise_on_failure
        self.convert_result_to_json_string = convert_result_to_json_string

    def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall) -> ChatMessage:
        """
        Prepares a ChatMessage with the result of a tool invocation.

        The result is stringified with `json.dumps` when `convert_result_to_json_string` is True,
        otherwise with `str`.

        :param result:
            The tool result.
        :param tool_call:
            The ToolCall that produced `result`; it is attached as the `origin` of the returned message.
        :returns:
            A ChatMessage object containing the tool result as a string. If the conversion fails and
            `raise_on_failure` is False, the message instead describes the failure and has `error=True`.

        :raises StringConversionError:
            If the conversion of the tool result to a string fails and `raise_on_failure` is True.
        """
        conversion_function = "json.dumps" if self.convert_result_to_json_string else "str"
        error = False

        try:
            if self.convert_result_to_json_string:
                # ensure_ascii=False keeps special chars such as emojis as-is instead of escaping them
                tool_result_str = json.dumps(result, ensure_ascii=False)
            else:
                tool_result_str = str(result)
        except Exception as e:
            if self.raise_on_failure:
                raise StringConversionError(
                    f"Failed to convert tool result to string using `{conversion_function}`"
                ) from e
            tool_result_str = _TOOL_RESULT_CONVERSION_FAILURE.format(error=e, conversion_function=conversion_function)
            error = True

        return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call)

    @component.output_types(tool_messages=List[ChatMessage])
    def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
        """
        Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.

        :param messages:
            A list of ChatMessage objects. Messages without tool calls are skipped.
        :returns:
            A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
            Each ChatMessage object wraps the result of a tool invocation.

        :raises ToolNotFoundException:
            If the tool is not found in the list of available tools and `raise_on_failure` is True.
        :raises ToolInvocationError:
            If the tool invocation fails and `raise_on_failure` is True.
        :raises StringConversionError:
            If the conversion of the tool result to a string fails and `raise_on_failure` is True.
        """
        tool_messages = []

        for message in messages:
            tool_calls = message.tool_calls
            if not tool_calls:
                continue

            for tool_call in tool_calls:
                tool_name = tool_call.tool_name
                tool_to_invoke = self._tools_with_names.get(tool_name)

                if tool_to_invoke is None:
                    # A list renders as "['a', 'b']" in the message, rather than "dict_keys(['a', 'b'])".
                    msg = _TOOL_NOT_FOUND.format(tool_name=tool_name, available_tools=list(self._tools_with_names))
                    if self.raise_on_failure:
                        raise ToolNotFoundException(msg)
                    tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True))
                    continue

                try:
                    tool_result = tool_to_invoke.invoke(**tool_call.arguments)
                except ToolInvocationError as e:
                    if self.raise_on_failure:
                        raise
                    msg = _TOOL_INVOCATION_FAILURE.format(error=e)
                    tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True))
                    continue

                tool_messages.append(self._prepare_tool_result_message(tool_result, tool_call))

        return {"tool_messages": tool_messages}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        Each tool is serialized with its own `to_dict` method.

        :returns:
            Dictionary with serialized data.
        """
        serialized_tools = [tool.to_dict() for tool in self.tools]
        return default_to_dict(
            self,
            tools=serialized_tools,
            raise_on_failure=self.raise_on_failure,
            convert_result_to_json_string=self.convert_result_to_json_string,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ToolInvoker":
        """
        Deserializes the component from a dictionary.

        NOTE(review): judging by its name, `deserialize_tools_inplace` presumably rewrites
        `data["init_parameters"]["tools"]` in place, so the caller's dict may be mutated. Confirm this.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        deserialize_tools_inplace(data["init_parameters"], key="tools")
        return default_from_dict(cls, data)
|