mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-07 20:46:31 +00:00
enhancement: support multiple Tool string outputs
This commit is contained in:
parent
8091034fb5
commit
6e45e9cc3e
@ -313,24 +313,35 @@ class ToolInvoker:
|
||||
and `raise_on_failure` is True.
|
||||
"""
|
||||
outputs_config = tool_to_invoke.outputs_to_string or {}
|
||||
source_key = outputs_config.get("source")
|
||||
|
||||
# If no handler is provided, we use the default handler
|
||||
output_to_string_handler = outputs_config.get("handler", self._default_output_to_string_handler)
|
||||
|
||||
# If a source key is provided, we extract the result from the source key
|
||||
result_to_convert = result.get(source_key) if source_key is not None else result
|
||||
|
||||
try:
|
||||
tool_result_str = output_to_string_handler(result_to_convert)
|
||||
chat_message = ChatMessage.from_tool(tool_result=tool_result_str, origin=tool_call)
|
||||
# Root level single output configuration
|
||||
if not outputs_config or "source" in outputs_config or "handler" in outputs_config:
|
||||
source_key = outputs_config.get("source", None)
|
||||
# If a source key is provided, we extract the result from the source key
|
||||
value = result.get(source_key) if source_key is not None else result
|
||||
# If no handler is provided, we use the default handler
|
||||
output_to_string_handler = outputs_config.get("handler", self._default_output_to_string_handler)
|
||||
tool_result_str = output_to_string_handler(value)
|
||||
return ChatMessage.from_tool(tool_result=tool_result_str, origin=tool_call)
|
||||
|
||||
# Multiple outputs configuration
|
||||
tool_result = {}
|
||||
for output_key, config in outputs_config.items():
|
||||
# If no source key is provided, we use the output key itself
|
||||
source_key = config.get("source", output_key)
|
||||
value = result[source_key]
|
||||
# If no handler is provided, we use the default handler
|
||||
output_to_string_handler = config.get("handler", self._default_output_to_string_handler)
|
||||
key_result_str = output_to_string_handler(value)
|
||||
tool_result[output_key] = key_result_str
|
||||
tool_result_str = self._default_output_to_string_handler(tool_result)
|
||||
return ChatMessage.from_tool(tool_result=tool_result_str, origin=tool_call)
|
||||
except Exception as e:
|
||||
error = StringConversionError(tool_call.tool_name, output_to_string_handler.__name__, e)
|
||||
if self.raise_on_failure:
|
||||
raise error from e
|
||||
logger.error("{error_exception}", error_exception=error)
|
||||
chat_message = ChatMessage.from_tool(tool_result=str(error), origin=tool_call, error=True)
|
||||
return chat_message
|
||||
return ChatMessage.from_tool(tool_result=str(error), origin=tool_call, error=True)
|
||||
|
||||
@staticmethod
|
||||
def _get_func_params(tool: Tool) -> set:
|
||||
|
||||
@ -105,6 +105,22 @@ class Tool:
|
||||
raise ValueError("outputs_to_string source must be a string.")
|
||||
if "handler" in self.outputs_to_string and not callable(self.outputs_to_string["handler"]):
|
||||
raise ValueError("outputs_to_string handler must be callable")
|
||||
if "source" in self.outputs_to_string or "handler" in self.outputs_to_string:
|
||||
for key in self.outputs_to_string:
|
||||
if key not in {"source", "handler"}:
|
||||
raise ValueError(
|
||||
"Invalid outputs_to_string config. "
|
||||
"When using 'source' or 'handler' at the root level, no other keys are allowed. "
|
||||
"Use individual output configs instead."
|
||||
)
|
||||
else:
|
||||
for key, config in self.outputs_to_string.items():
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError(f"outputs_to_string configuration for key '{key}' must be a dictionary")
|
||||
if "source" in config and not isinstance(config["source"], str):
|
||||
raise ValueError(f"outputs_to_string source for key '{key}' must be a string.")
|
||||
if "handler" in config and not callable(config["handler"]):
|
||||
raise ValueError(f"outputs_to_string handler for key '{key}' must be callable")
|
||||
|
||||
@property
|
||||
def tool_spec(self) -> dict[str, Any]:
|
||||
|
||||
@ -0,0 +1,12 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Support Multiple Tool String Outputs
|
||||
|
||||
Added support for tools to define multiple string outputs using the `outputs_to_string` configuration.
|
||||
This allows users to specify how different parts of a tool's output should be converted to strings,
|
||||
enhancing flexibility in handling tool results.
|
||||
|
||||
- Updated `ToolInvoker` to handle multiple output configurations.
|
||||
- Updated `Tool` to validate and store multiple output configurations.
|
||||
- Added tests to verify the functionality of multiple string outputs.
|
||||
@ -784,6 +784,25 @@ class TestToolInvokerErrorHandling:
|
||||
assert tool_message.tool_call_results[0].error
|
||||
assert "Failed to invoke" in tool_message.tool_call_results[0].result
|
||||
|
||||
def test_outputs_to_string_with_multiple_outputs(self):
|
||||
weather_tool = Tool(
|
||||
name="weather_tool",
|
||||
description="Provides weather information for a given location.",
|
||||
parameters=weather_parameters,
|
||||
function=weather_function,
|
||||
# Pass custom handler that will throw an error when trying to convert tool_result
|
||||
outputs_to_string={"weather": {"source": "weather"}, "temp": {"source": "temperature"}},
|
||||
)
|
||||
invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=True)
|
||||
|
||||
tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
|
||||
|
||||
tool_result = {"weather": "sunny", "temperature": 25, "unit": "celsius"}
|
||||
chat_message = invoker._prepare_tool_result_message(
|
||||
result=tool_result, tool_call=tool_call, tool_to_invoke=weather_tool
|
||||
)
|
||||
assert chat_message.tool_call_results[0].result == "{'weather': 'sunny', 'temp': '25'}"
|
||||
|
||||
def test_string_conversion_error(self):
|
||||
weather_tool = Tool(
|
||||
name="weather_tool",
|
||||
|
||||
@ -65,6 +65,26 @@ class TestTool:
|
||||
outputs_to_state=outputs_to_state,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"outputs_to_string",
|
||||
[
|
||||
pytest.param({"source": get_weather_report}, id="source-not-a-string"),
|
||||
pytest.param({"handler": "some_string"}, id="handler-not-callable"),
|
||||
pytest.param({"documents": ["some_value"]}, id="multi-value-config-not-a-dict"),
|
||||
pytest.param({"documents": {"source": get_weather_report}}, id="multi-value-source-not-a-string"),
|
||||
pytest.param({"documents": {"handler": "some_string"}}, id="multi-value-handler-not-callable"),
|
||||
],
|
||||
)
|
||||
def test_init_invalid_output_to_string_structure(self, outputs_to_string):
|
||||
with pytest.raises(ValueError):
|
||||
Tool(
|
||||
name="irrelevant",
|
||||
description="irrelevant",
|
||||
parameters={"type": "object", "properties": {"city": {"type": "string"}}},
|
||||
function=get_weather_report,
|
||||
outputs_to_string=outputs_to_string,
|
||||
)
|
||||
|
||||
def test_tool_spec(self):
|
||||
tool = Tool(
|
||||
name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user