mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-06-26 22:00:13 +00:00
fix: Fix Tool and ComponentTool serialization when specifying outputs_to_string
(#9524)
* Fix serialization of outputs_to_string in Tool and ComponentTool * Add reno * Fix mypy, simplify logic * fix pylint * Fix test --------- Co-authored-by: David S. Batista <dsbatista@gmail.com>
This commit is contained in:
parent
a16ee96003
commit
3784889e5d
@ -206,14 +206,14 @@ class ComponentTool(Tool):
|
||||
"""
|
||||
serialized_component = component_to_dict(obj=self._component, name=self.name)
|
||||
|
||||
serialized = {
|
||||
serialized: Dict[str, Any] = {
|
||||
"component": serialized_component,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self._unresolved_parameters,
|
||||
"outputs_to_string": self.outputs_to_string,
|
||||
"inputs_from_state": self.inputs_from_state,
|
||||
"outputs_to_state": self.outputs_to_state,
|
||||
# This is soft-copied as to not modify the attributes in place
|
||||
"outputs_to_state": self.outputs_to_state.copy() if self.outputs_to_state else None,
|
||||
}
|
||||
|
||||
if self.outputs_to_state is not None:
|
||||
@ -226,7 +226,11 @@ class ComponentTool(Tool):
|
||||
serialized["outputs_to_state"] = serialized_outputs
|
||||
|
||||
if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None:
|
||||
serialized["outputs_to_string"] = serialize_callable(self.outputs_to_string["handler"])
|
||||
# This is soft-copied as to not modify the attributes in place
|
||||
serialized["outputs_to_string"] = self.outputs_to_string.copy()
|
||||
serialized["outputs_to_string"]["handler"] = serialize_callable(self.outputs_to_string["handler"])
|
||||
else:
|
||||
serialized["outputs_to_string"] = None
|
||||
|
||||
return {"type": generate_qualified_class_name(type(self)), "data": serialized}
|
||||
|
||||
|
@ -122,7 +122,7 @@ class Tool:
|
||||
data["outputs_to_state"] = serialized_outputs
|
||||
|
||||
if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None:
|
||||
data["outputs_to_string"] = serialize_callable(self.outputs_to_string["handler"])
|
||||
data["outputs_to_string"]["handler"] = serialize_callable(self.outputs_to_string["handler"])
|
||||
|
||||
return {"type": generate_qualified_class_name(type(self)), "data": data}
|
||||
|
||||
|
@ -0,0 +1,4 @@
|
||||
---
|
||||
fixes:
|
||||
- |
|
||||
Fix the serialization of ComponentTool and Tool when specifying outputs_to_string. Previously an error occurred on deserialization right after serializing if outputs_to_string is not None.
|
@ -60,6 +60,10 @@ class SimpleComponent:
|
||||
return {"reply": f"Hello, {text}!"}
|
||||
|
||||
|
||||
def reply_formatter(input_text: str) -> str:
|
||||
return f"Formatted reply: {input_text}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class User:
|
||||
"""A simple user dataclass."""
|
||||
@ -593,24 +597,33 @@ class TestToolComponentInPipelineWithOpenAI:
|
||||
component=SimpleComponent(),
|
||||
name="simple_tool",
|
||||
description="A simple tool",
|
||||
outputs_to_string={"source": "reply", "handler": reply_formatter},
|
||||
inputs_from_state={"test": "input"},
|
||||
outputs_to_state={"output": {"source": "out", "handler": output_handler}},
|
||||
)
|
||||
|
||||
# Test serialization
|
||||
expected_tool_dict = {
|
||||
"type": "haystack.tools.component_tool.ComponentTool",
|
||||
"data": {
|
||||
"component": {"type": "test_component_tool.SimpleComponent", "init_parameters": {}},
|
||||
"name": "simple_tool",
|
||||
"description": "A simple tool",
|
||||
"parameters": None,
|
||||
"outputs_to_string": {"source": "reply", "handler": "test_component_tool.reply_formatter"},
|
||||
"inputs_from_state": {"test": "input"},
|
||||
"outputs_to_state": {"output": {"source": "out", "handler": "test_component_tool.output_handler"}},
|
||||
},
|
||||
}
|
||||
tool_dict = tool.to_dict()
|
||||
assert tool_dict["type"] == "haystack.tools.component_tool.ComponentTool"
|
||||
assert tool_dict["data"]["name"] == "simple_tool"
|
||||
assert tool_dict["data"]["description"] == "A simple tool"
|
||||
assert "component" in tool_dict["data"]
|
||||
assert tool_dict["data"]["inputs_from_state"] == {"test": "input"}
|
||||
assert tool_dict["data"]["outputs_to_state"]["output"]["handler"] == "test_component_tool.output_handler"
|
||||
assert tool_dict == expected_tool_dict
|
||||
|
||||
# Test deserialization
|
||||
new_tool = ComponentTool.from_dict(tool_dict)
|
||||
new_tool = ComponentTool.from_dict(expected_tool_dict)
|
||||
assert new_tool.name == tool.name
|
||||
assert new_tool.description == tool.description
|
||||
assert new_tool.parameters == tool.parameters
|
||||
assert new_tool.outputs_to_string == tool.outputs_to_string
|
||||
assert new_tool.inputs_from_state == tool.inputs_from_state
|
||||
assert new_tool.outputs_to_state == tool.outputs_to_state
|
||||
assert isinstance(new_tool._component, SimpleComponent)
|
||||
|
@ -13,6 +13,10 @@ def get_weather_report(city: str) -> str:
|
||||
return f"Weather report for {city}: 20°C, sunny"
|
||||
|
||||
|
||||
def format_string(text: str) -> str:
|
||||
return f"Formatted: {text}"
|
||||
|
||||
|
||||
parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
|
||||
|
||||
|
||||
@ -84,6 +88,8 @@ class TestTool:
|
||||
description="Get weather report",
|
||||
parameters=parameters,
|
||||
function=get_weather_report,
|
||||
outputs_to_string={"handler": format_string},
|
||||
inputs_from_state={"state_key": "tool_input_key"},
|
||||
outputs_to_state={"documents": {"handler": get_weather_report, "source": "docs"}},
|
||||
)
|
||||
|
||||
@ -94,8 +100,8 @@ class TestTool:
|
||||
"description": "Get weather report",
|
||||
"parameters": parameters,
|
||||
"function": "test_tool.get_weather_report",
|
||||
"outputs_to_string": None,
|
||||
"inputs_from_state": None,
|
||||
"outputs_to_string": {"handler": "test_tool.format_string"},
|
||||
"inputs_from_state": {"state_key": "tool_input_key"},
|
||||
"outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}},
|
||||
},
|
||||
}
|
||||
@ -108,6 +114,8 @@ class TestTool:
|
||||
"description": "Get weather report",
|
||||
"parameters": parameters,
|
||||
"function": "test_tool.get_weather_report",
|
||||
"outputs_to_string": {"handler": "test_tool.format_string"},
|
||||
"inputs_from_state": {"state_key": "tool_input_key"},
|
||||
"outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}},
|
||||
},
|
||||
}
|
||||
@ -118,8 +126,9 @@ class TestTool:
|
||||
assert tool.description == "Get weather report"
|
||||
assert tool.parameters == parameters
|
||||
assert tool.function == get_weather_report
|
||||
assert tool.outputs_to_state["documents"]["source"] == "docs"
|
||||
assert tool.outputs_to_state["documents"]["handler"] == get_weather_report
|
||||
assert tool.outputs_to_string == {"handler": format_string}
|
||||
assert tool.inputs_from_state == {"state_key": "tool_input_key"}
|
||||
assert tool.outputs_to_state == {"documents": {"source": "docs", "handler": get_weather_report}}
|
||||
|
||||
|
||||
def test_check_duplicate_tool_names():
|
||||
|
Loading…
x
Reference in New Issue
Block a user