haystack/test/tools/test_tool.py
Stefano Fiorucci f6fceb1b56
refactor: reorganize Tool serde utility functions (#9185)
* refactor: reorganize Tool serde utility functions

* license header

* rm unused import

* HF local update
2025-04-08 08:09:54 +02:00

138 lines
5.2 KiB
Python

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import re
import pytest
from haystack.tools import Tool, _check_duplicate_tool_names
from haystack.tools.errors import ToolInvocationError
def get_weather_report(city: str) -> str:
    """Return a fixed, canned weather report for `city` (test fixture; no real lookup is performed)."""
    conditions = "20°C, sunny"
    return f"Weather report for {city}: {conditions}"
# JSON-schema-style parameter spec reused by the tests in this module:
# an object with a single required string property, `city`.
parameters = {
    "type": "object",
    "properties": {
        "city": {"type": "string"},
    },
    "required": ["city"],
}
class TestTool:
    """Unit tests for `Tool`: construction, validation, invocation and (de)serialization."""

    @staticmethod
    def _weather_tool(**extra):
        """Build the "weather" Tool shared by several tests; any `extra` kwargs are forwarded to `Tool`."""
        return Tool(
            name="weather",
            description="Get weather report",
            parameters=parameters,
            function=get_weather_report,
            **extra,
        )

    def test_init(self):
        weather_tool = self._weather_tool()

        assert weather_tool.name == "weather"
        assert weather_tool.description == "Get weather report"
        assert weather_tool.parameters == parameters
        assert weather_tool.function == get_weather_report
        # State mappings were not supplied, so both must remain unset.
        assert weather_tool.inputs_from_state is None
        assert weather_tool.outputs_to_state is None

    def test_init_invalid_parameters(self):
        # "invalid" is not a valid JSON Schema type; constructing the Tool is expected to raise ValueError.
        bad_schema = {"type": "invalid", "properties": {"city": {"type": "string"}}}
        with pytest.raises(ValueError):
            Tool(name="irrelevant", description="irrelevant", parameters=bad_schema, function=get_weather_report)

    @pytest.mark.parametrize(
        "outputs_to_state",
        [
            pytest.param({"documents": ["some_value"]}, id="config-not-a-dict"),
            pytest.param({"documents": {"source": get_weather_report}}, id="source-not-a-string"),
            pytest.param({"documents": {"handler": "some_string", "source": "docs"}}, id="handler-not-callable"),
        ],
    )
    def test_init_invalid_output_structure(self, outputs_to_state):
        # Each parametrized case is a malformed `outputs_to_state` mapping that Tool must reject.
        schema = {"type": "object", "properties": {"city": {"type": "string"}}}
        with pytest.raises(ValueError):
            Tool(
                name="irrelevant",
                description="irrelevant",
                parameters=schema,
                function=get_weather_report,
                outputs_to_state=outputs_to_state,
            )

    def test_tool_spec(self):
        expected_spec = {"name": "weather", "description": "Get weather report", "parameters": parameters}
        assert self._weather_tool().tool_spec == expected_spec

    def test_invoke(self):
        report = self._weather_tool().invoke(city="Berlin")
        assert report == "Weather report for Berlin: 20°C, sunny"

    def test_invoke_fail(self):
        weather_tool = self._weather_tool()
        # Invoking without the required `city` argument must raise ToolInvocationError, whose message
        # embeds the original Python error text. re.escape makes the braces/parentheses match literally.
        expected_message = (
            "Failed to invoke Tool `weather` with parameters {}. "
            "Error: get_weather_report() missing 1 required positional argument: 'city'"
        )
        with pytest.raises(ToolInvocationError, match=re.escape(expected_message)):
            weather_tool.invoke()

    def test_to_dict(self):
        weather_tool = self._weather_tool(
            outputs_to_state={"documents": {"handler": get_weather_report, "source": "docs"}},
        )
        # Callables are serialized as dotted "module.name" strings. The module part presumably depends on
        # how pytest imports this file (here: "test_tool") — NOTE(review): confirm against the test runner config.
        expected = {
            "type": "haystack.tools.tool.Tool",
            "data": {
                "name": "weather",
                "description": "Get weather report",
                "parameters": parameters,
                "function": "test_tool.get_weather_report",
                "outputs_to_string": None,
                "inputs_from_state": None,
                "outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}},
            },
        }
        assert weather_tool.to_dict() == expected

    def test_from_dict(self):
        serialized = {
            "type": "haystack.tools.tool.Tool",
            "data": {
                "name": "weather",
                "description": "Get weather report",
                "parameters": parameters,
                "function": "test_tool.get_weather_report",
                "outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}},
            },
        }
        restored = Tool.from_dict(serialized)

        assert restored.name == "weather"
        assert restored.description == "Get weather report"
        assert restored.parameters == parameters
        assert restored.function == get_weather_report
        # The dotted handler path must be resolved back into the actual function object.
        docs_config = restored.outputs_to_state["documents"]
        assert docs_config["source"] == "docs"
        assert docs_config["handler"] == get_weather_report
def test_check_duplicate_tool_names():
    """`_check_duplicate_tool_names` raises ValueError on a repeated name and accepts unique names."""
    # Same name, different descriptions: still a duplicate, so it must be rejected.
    clashing_tools = [
        Tool(name="weather", description=desc, parameters=parameters, function=get_weather_report)
        for desc in ("Get weather report", "A different description")
    ]
    with pytest.raises(ValueError):
        _check_duplicate_tool_names(clashing_tools)

    # Distinct names: the check must complete without raising.
    unique_tools = [
        Tool(name=tool_name, description="Get weather report", parameters=parameters, function=get_weather_report)
        for tool_name in ("weather", "weather2")
    ]
    _check_duplicate_tool_names(unique_tools)