From f6867ebaeeaaa0e2eb62b4c745d344f094153aeb Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> Date: Mon, 26 May 2025 09:40:09 +0200 Subject: [PATCH] =?UTF-8?q?fix:=20Fix=20invoker=20to=20work=20when=20using?= =?UTF-8?q?=20dataclass=20with=20from=5Fdict=20but=20dataclass=E2=80=A6=20?= =?UTF-8?q?(#9434)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix invoker to work when using dataclass with from_dict but dataclass is already given * add reno * Add unit test * Remove line --- haystack/tools/component_tool.py | 12 ++++++---- ...nt-invoker-dataclass-2efe773f03df8a93.yaml | 5 +++++ test/tools/test_component_tool.py | 22 +++++++++++++++++++ 3 files changed, 35 insertions(+), 4 deletions(-) create mode 100644 releasenotes/notes/fix-component-invoker-dataclass-2efe773f03df8a93.yaml diff --git a/haystack/tools/component_tool.py b/haystack/tools/component_tool.py index 7cac568d4..b0f9a3356 100644 --- a/haystack/tools/component_tool.py +++ b/haystack/tools/component_tool.py @@ -159,15 +159,19 @@ class ComponentTool(Tool): target_type = get_args(param_type)[0] if get_origin(param_type) is list else param_type if hasattr(target_type, "from_dict"): if isinstance(param_value, list): - param_value = [target_type.from_dict(item) for item in param_value if isinstance(item, dict)] + resolved_param_value = [ + target_type.from_dict(item) if isinstance(item, dict) else item for item in param_value + ] elif isinstance(param_value, dict): - param_value = target_type.from_dict(param_value) + resolved_param_value = target_type.from_dict(param_value) + else: + resolved_param_value = param_value else: # Let TypeAdapter handle both single values and lists type_adapter = TypeAdapter(param_type) - param_value = type_adapter.validate_python(param_value) + resolved_param_value = type_adapter.validate_python(param_value) - converted_kwargs[param_name] = param_value + converted_kwargs[param_name] = resolved_param_value logger.debug(f"Invoking component {type(component)} with kwargs: {converted_kwargs}") return component.run(**converted_kwargs) diff --git a/releasenotes/notes/fix-component-invoker-dataclass-2efe773f03df8a93.yaml b/releasenotes/notes/fix-component-invoker-dataclass-2efe773f03df8a93.yaml new file mode 100644 index 000000000..d8ad94f5d --- /dev/null +++ b/releasenotes/notes/fix-component-invoker-dataclass-2efe773f03df8a93.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + Fix component_invoker used by ComponentTool to work when a dataclass like ChatMessage is directly passed to `component_tool.invoke(...)`. + Previously this would either cause an error or silently skip your input. diff --git a/test/tools/test_component_tool.py b/test/tools/test_component_tool.py index f6975992a..470891349 100644 --- a/test/tools/test_component_tool.py +++ b/test/tools/test_component_tool.py @@ -30,6 +30,21 @@ from test.tools.test_parameters_schema_utils import BYTE_STREAM_SCHEMA, DOCUMENT # Component and Model Definitions +@component +class SimpleComponentUsingChatMessages: + """A simple component that generates text.""" + + @component.output_types(reply=str) + def run(self, messages: List[ChatMessage]) -> Dict[str, str]: + """ + A simple component that generates text. + + :param messages: Users messages + :return: A dictionary with the generated text. + """ + return {"reply": f"Hello, {messages[0].text}!"} + + @component class SimpleComponent: """A simple component that generates text.""" @@ -306,6 +321,13 @@ class TestComponentTool: with pytest.raises(ValueError): ComponentTool(component=not_a_component, name="invalid_tool", description="This should fail") + def test_component_invoker_with_chat_message_input(self): + tool = ComponentTool( + component=SimpleComponentUsingChatMessages(), name="simple_tool", description="A simple tool" + ) + result = tool.invoke(messages=[ChatMessage.from_user(text="world")]) + assert result == {"reply": "Hello, world!"} + # Integration tests class TestToolComponentInPipelineWithOpenAI: