diff --git a/haystack/components/joiners/list_joiner.py b/haystack/components/joiners/list_joiner.py index af5b109e1..7d9d26579 100644 --- a/haystack/components/joiners/list_joiner.py +++ b/haystack/components/joiners/list_joiner.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from itertools import chain -from typing import Any, Dict, Type +from typing import Any, Dict, List, Optional, Type from haystack import component, default_from_dict, default_to_dict from haystack.core.component.types import Variadic @@ -65,15 +65,19 @@ class ListJoiner: ``` """ - def __init__(self, list_type_: Type): + def __init__(self, list_type_: Optional[Type] = None): """ Creates a ListJoiner component. - :param list_type_: The type of list that this joiner will handle (e.g., List[ChatMessage]). - All input lists must be of this type. + :param list_type_: The expected type of the lists this component will join (e.g., List[ChatMessage]). + If specified, all input lists must conform to this type. If None, the component defaults to handling + lists of any type including mixed types. """ self.list_type_ = list_type_ - component.set_output_types(self, values=list_type_) + if list_type_ is not None: + component.set_output_types(self, values=list_type_) + else: + component.set_output_types(self, values=List) def to_dict(self) -> Dict[str, Any]: """ @@ -81,7 +85,9 @@ class ListJoiner: :returns: Dictionary with serialized data. """ - return default_to_dict(self, list_type_=serialize_type(self.list_type_)) + return default_to_dict( + self, list_type_=serialize_type(self.list_type_) if self.list_type_ is not None else None + ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ListJoiner": @@ -91,14 +97,16 @@ class ListJoiner: :param data: Dictionary to deserialize from. :returns: Deserialized component. """ - data["init_parameters"]["list_type_"] = deserialize_type(data["init_parameters"]["list_type_"]) + init_parameters = data.get("init_parameters") + if init_parameters is not None and init_parameters.get("list_type_") is not None: + data["init_parameters"]["list_type_"] = deserialize_type(data["init_parameters"]["list_type_"]) return default_from_dict(cls, data) - def run(self, values: Variadic[Any]) -> Dict[str, Any]: + def run(self, values: Variadic[List[Any]]) -> Dict[str, List[Any]]: """ Joins multiple lists into a single flat list. - :param values:The list to be joined. + :param values: The list to be joined. :returns: Dictionary with 'values' key containing the joined list. """ result = list(chain(*values)) diff --git a/releasenotes/notes/update-list-joiner-0a068cfb058f3c35.yaml b/releasenotes/notes/update-list-joiner-0a068cfb058f3c35.yaml new file mode 100644 index 000000000..75a1f3f5f --- /dev/null +++ b/releasenotes/notes/update-list-joiner-0a068cfb058f3c35.yaml @@ -0,0 +1,6 @@ +--- +enhancements: + - | + Update ListJoiner to only optionally need list_type_ to be passed. By default it uses type List which acts like List[Any]. + - This allows the ListJoiner to combine any incoming lists into a single flattened list. + - Users can still pass list_type_ if they would like to have stricter type validation in their pipelines. diff --git a/test/components/joiners/test_list_joiner.py b/test/components/joiners/test_list_joiner.py index 9ea9fa57e..5d4f8849a 100644 --- a/test/components/joiners/test_list_joiner.py +++ b/test/components/joiners/test_list_joiner.py @@ -3,11 +3,17 @@ # SPDX-License-Identifier: Apache-2.0 from typing import List +import pytest -from haystack import Document +from haystack import Document, Pipeline from haystack.dataclasses import ChatMessage from haystack.dataclasses.answer import GeneratedAnswer +from haystack.components.builders import AnswerBuilder, ChatPromptBuilder +from haystack.components.generators.chat.openai import OpenAIChatGenerator from haystack.components.joiners.list_joiner import ListJoiner +from haystack.components.embedders import SentenceTransformersTextEmbedder +from haystack.core.errors import PipelineConnectError +from haystack.utils.auth import Secret class TestListJoiner: @@ -16,7 +22,15 @@ class TestListJoiner: assert isinstance(joiner, ListJoiner) assert joiner.list_type_ == List[ChatMessage] - def test_to_dict(self): + def test_to_dict_defaults(self): + joiner = ListJoiner() + data = joiner.to_dict() + assert data == { + "type": "haystack.components.joiners.list_joiner.ListJoiner", + "init_parameters": {"list_type_": None}, + } + + def test_to_dict_non_default(self): joiner = ListJoiner(List[ChatMessage]) data = joiner.to_dict() assert data == { @@ -24,7 +38,13 @@ class TestListJoiner: "init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"}, } - def test_from_dict(self): + def test_from_dict_default(self): + data = {"type": "haystack.components.joiners.list_joiner.ListJoiner", "init_parameters": {"list_type_": None}} + list_joiner = ListJoiner.from_dict(data) + assert isinstance(list_joiner, ListJoiner) + assert list_joiner.list_type_ is None + + def test_from_dict_non_default(self): data = { "type": "haystack.components.joiners.list_joiner.ListJoiner", "init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"}, @@ -64,8 +84,65 @@ class TestListJoiner: result = joiner.run([answers1, answers2]) assert result == {"values": answers1 + answers2} + def test_list_two_different_types(self): + joiner = ListJoiner() + result = joiner.run([["a", "b"], [1, 2]]) + assert result == {"values": ["a", "b", 1, 2]} + def test_mixed_empty_and_non_empty_lists(self): joiner = ListJoiner(List[ChatMessage]) messages = [ChatMessage.from_user("Hello")] result = joiner.run([messages, [], messages]) assert result == {"values": messages + messages} + + def test_pipeline_connection_validation(self): + joiner = ListJoiner() + llm = OpenAIChatGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key")) + pipe = Pipeline() + pipe.add_component("joiner", joiner) + pipe.add_component("llm", llm) + pipe.connect("joiner.values", "llm.messages") + assert pipe is not None + + def test_pipeline_connection_validation_list_chatmessage(self): + joiner = ListJoiner(List[ChatMessage]) + llm = OpenAIChatGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key")) + pipe = Pipeline() + pipe.add_component("joiner", joiner) + pipe.add_component("llm", llm) + pipe.connect("joiner", "llm.messages") + assert pipe is not None + + def test_pipeline_bad_connection(self): + with pytest.raises(PipelineConnectError): + joiner = ListJoiner() + query_embedder = SentenceTransformersTextEmbedder() + pipe = Pipeline() + pipe.add_component("joiner", joiner) + pipe.add_component("query_embedder", query_embedder) + pipe.connect("joiner.values", "query_embedder.text") + + def test_pipeline_bad_connection_different_list_types(self): + with pytest.raises(PipelineConnectError): + joiner = ListJoiner(List[str]) + llm = OpenAIChatGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key")) + pipe = Pipeline() + pipe.add_component("joiner", joiner) + pipe.add_component("llm", llm) + pipe.connect("joiner.values", "llm.messages") + + def test_result_two_different_types(self): + pipe = Pipeline() + pipe.add_component("answer_builder", AnswerBuilder()) + pipe.add_component("chat_prompt_builder", ChatPromptBuilder()) + pipe.add_component("joiner", ListJoiner()) + pipe.connect("answer_builder", "joiner.values") + pipe.connect("chat_prompt_builder", "joiner.values") + result = pipe.run( + data={ + "answer_builder": {"query": "What is nuclear physics?", "replies": ["This is an answer."]}, + "chat_prompt_builder": {"template": [ChatMessage.from_user("Hello")]}, + } + ) + assert isinstance(result["joiner"]["values"][0], GeneratedAnswer) + assert isinstance(result["joiner"]["values"][1], ChatMessage)