diff --git a/haystack/tools/component_tool.py b/haystack/tools/component_tool.py index 5198ff0da..01a0851f1 100644 --- a/haystack/tools/component_tool.py +++ b/haystack/tools/component_tool.py @@ -206,14 +206,14 @@ class ComponentTool(Tool): """ serialized_component = component_to_dict(obj=self._component, name=self.name) - serialized = { + serialized: Dict[str, Any] = { "component": serialized_component, "name": self.name, "description": self.description, "parameters": self._unresolved_parameters, - "outputs_to_string": self.outputs_to_string, "inputs_from_state": self.inputs_from_state, - "outputs_to_state": self.outputs_to_state, + # This is soft-copied as to not modify the attributes in place + "outputs_to_state": self.outputs_to_state.copy() if self.outputs_to_state else None, } if self.outputs_to_state is not None: @@ -226,7 +226,11 @@ class ComponentTool(Tool): serialized["outputs_to_state"] = serialized_outputs if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None: - serialized["outputs_to_string"] = serialize_callable(self.outputs_to_string["handler"]) + # This is soft-copied as to not modify the attributes in place + serialized["outputs_to_string"] = self.outputs_to_string.copy() + serialized["outputs_to_string"]["handler"] = serialize_callable(self.outputs_to_string["handler"]) + else: + serialized["outputs_to_string"] = None return {"type": generate_qualified_class_name(type(self)), "data": serialized} diff --git a/haystack/tools/tool.py b/haystack/tools/tool.py index c1d767044..39e3cc143 100644 --- a/haystack/tools/tool.py +++ b/haystack/tools/tool.py @@ -122,7 +122,7 @@ class Tool: data["outputs_to_state"] = serialized_outputs if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None: - data["outputs_to_string"] = serialize_callable(self.outputs_to_string["handler"]) + data["outputs_to_string"]["handler"] = serialize_callable(self.outputs_to_string["handler"]) return {"type": generate_qualified_class_name(type(self)), "data": data} diff --git a/releasenotes/notes/fix-serialization-tool-and-comp-tool-017caac6bb56e744.yaml b/releasenotes/notes/fix-serialization-tool-and-comp-tool-017caac6bb56e744.yaml new file mode 100644 index 000000000..55f86ec31 --- /dev/null +++ b/releasenotes/notes/fix-serialization-tool-and-comp-tool-017caac6bb56e744.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Fix the serialization of ComponentTool and Tool when specifying outputs_to_string. Previously an error occurred on deserialization right after serializing if outputs_to_string is not None. diff --git a/test/tools/test_component_tool.py b/test/tools/test_component_tool.py index 470891349..2db721c44 100644 --- a/test/tools/test_component_tool.py +++ b/test/tools/test_component_tool.py @@ -60,6 +60,10 @@ class SimpleComponent: return {"reply": f"Hello, {text}!"} +def reply_formatter(input_text: str) -> str: + return f"Formatted reply: {input_text}" + + @dataclass class User: """A simple user dataclass.""" @@ -593,24 +597,33 @@ class TestToolComponentInPipelineWithOpenAI: component=SimpleComponent(), name="simple_tool", description="A simple tool", + outputs_to_string={"source": "reply", "handler": reply_formatter}, inputs_from_state={"test": "input"}, outputs_to_state={"output": {"source": "out", "handler": output_handler}}, ) # Test serialization + expected_tool_dict = { + "type": "haystack.tools.component_tool.ComponentTool", + "data": { + "component": {"type": "test_component_tool.SimpleComponent", "init_parameters": {}}, + "name": "simple_tool", + "description": "A simple tool", + "parameters": None, + "outputs_to_string": {"source": "reply", "handler": "test_component_tool.reply_formatter"}, + "inputs_from_state": {"test": "input"}, + "outputs_to_state": {"output": {"source": "out", "handler": "test_component_tool.output_handler"}}, + }, + } tool_dict = tool.to_dict() - assert tool_dict["type"] == "haystack.tools.component_tool.ComponentTool" - assert tool_dict["data"]["name"] == "simple_tool" - assert tool_dict["data"]["description"] == "A simple tool" - assert "component" in tool_dict["data"] - assert tool_dict["data"]["inputs_from_state"] == {"test": "input"} - assert tool_dict["data"]["outputs_to_state"]["output"]["handler"] == "test_component_tool.output_handler" + assert tool_dict == expected_tool_dict # Test deserialization - new_tool = ComponentTool.from_dict(tool_dict) + new_tool = ComponentTool.from_dict(expected_tool_dict) assert new_tool.name == tool.name assert new_tool.description == tool.description assert new_tool.parameters == tool.parameters + assert new_tool.outputs_to_string == tool.outputs_to_string assert new_tool.inputs_from_state == tool.inputs_from_state assert new_tool.outputs_to_state == tool.outputs_to_state assert isinstance(new_tool._component, SimpleComponent) diff --git a/test/tools/test_tool.py b/test/tools/test_tool.py index 49d01aa0e..352cd5871 100644 --- a/test/tools/test_tool.py +++ b/test/tools/test_tool.py @@ -13,6 +13,10 @@ def get_weather_report(city: str) -> str: return f"Weather report for {city}: 20°C, sunny" +def format_string(text: str) -> str: + return f"Formatted: {text}" + + parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} @@ -84,6 +88,8 @@ class TestTool: description="Get weather report", parameters=parameters, function=get_weather_report, + outputs_to_string={"handler": format_string}, + inputs_from_state={"state_key": "tool_input_key"}, outputs_to_state={"documents": {"handler": get_weather_report, "source": "docs"}}, ) @@ -94,8 +100,8 @@ class TestTool: "description": "Get weather report", "parameters": parameters, "function": "test_tool.get_weather_report", - "outputs_to_string": None, - "inputs_from_state": None, + "outputs_to_string": {"handler": "test_tool.format_string"}, + "inputs_from_state": {"state_key": "tool_input_key"}, "outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}}, }, } @@ -108,6 +114,8 @@ class TestTool: "description": "Get weather report", "parameters": parameters, "function": "test_tool.get_weather_report", + "outputs_to_string": {"handler": "test_tool.format_string"}, + "inputs_from_state": {"state_key": "tool_input_key"}, "outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}}, }, } @@ -118,8 +126,9 @@ class TestTool: assert tool.description == "Get weather report" assert tool.parameters == parameters assert tool.function == get_weather_report - assert tool.outputs_to_state["documents"]["source"] == "docs" - assert tool.outputs_to_state["documents"]["handler"] == get_weather_report + assert tool.outputs_to_string == {"handler": format_string} + assert tool.inputs_from_state == {"state_key": "tool_input_key"} + assert tool.outputs_to_state == {"documents": {"source": "docs", "handler": get_weather_report}} def test_check_duplicate_tool_names():