feat: Update list joiner (#8851)

* Update ListJoiner to have default type List

* Add reno

* Add more tests

* Remove unused import

* Fix mypy

* Update docstrings

* Update haystack/components/joiners/list_joiner.py

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>

---------

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>
This commit is contained in:
Sebastian Husch Lee 2025-02-14 09:47:19 +01:00 committed by GitHub
parent e6c503dbb9
commit 2f383bce25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 103 additions and 12 deletions

View File

@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from itertools import chain from itertools import chain
from typing import Any, Dict, Type from typing import Any, Dict, List, Optional, Type
from haystack import component, default_from_dict, default_to_dict from haystack import component, default_from_dict, default_to_dict
from haystack.core.component.types import Variadic from haystack.core.component.types import Variadic
@ -65,15 +65,19 @@ class ListJoiner:
``` ```
""" """
def __init__(self, list_type_: Type): def __init__(self, list_type_: Optional[Type] = None):
""" """
Creates a ListJoiner component. Creates a ListJoiner component.
:param list_type_: The type of list that this joiner will handle (e.g., List[ChatMessage]). :param list_type_: The expected type of the lists this component will join (e.g., List[ChatMessage]).
All input lists must be of this type. If specified, all input lists must conform to this type. If None, the component defaults to handling
lists of any type including mixed types.
""" """
self.list_type_ = list_type_ self.list_type_ = list_type_
component.set_output_types(self, values=list_type_) if list_type_ is not None:
component.set_output_types(self, values=list_type_)
else:
component.set_output_types(self, values=List)
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
""" """
@ -81,7 +85,9 @@ class ListJoiner:
:returns: Dictionary with serialized data. :returns: Dictionary with serialized data.
""" """
return default_to_dict(self, list_type_=serialize_type(self.list_type_)) return default_to_dict(
self, list_type_=serialize_type(self.list_type_) if self.list_type_ is not None else None
)
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ListJoiner": def from_dict(cls, data: Dict[str, Any]) -> "ListJoiner":
@ -91,14 +97,16 @@ class ListJoiner:
:param data: Dictionary to deserialize from. :param data: Dictionary to deserialize from.
:returns: Deserialized component. :returns: Deserialized component.
""" """
data["init_parameters"]["list_type_"] = deserialize_type(data["init_parameters"]["list_type_"]) init_parameters = data.get("init_parameters")
if init_parameters is not None and init_parameters.get("list_type_") is not None:
data["init_parameters"]["list_type_"] = deserialize_type(data["init_parameters"]["list_type_"])
return default_from_dict(cls, data) return default_from_dict(cls, data)
def run(self, values: Variadic[Any]) -> Dict[str, Any]: def run(self, values: Variadic[List[Any]]) -> Dict[str, List[Any]]:
""" """
Joins multiple lists into a single flat list. Joins multiple lists into a single flat list.
:param values:The list to be joined. :param values: The list to be joined.
:returns: Dictionary with 'values' key containing the joined list. :returns: Dictionary with 'values' key containing the joined list.
""" """
result = list(chain(*values)) result = list(chain(*values))

View File

@ -0,0 +1,6 @@
---
enhancements:
- |
Update ListJoiner to only optionally need list_type_ to be passed. By default it uses type List which acts like List[Any].
- This allows the ListJoiner to combine any incoming lists into a single flattened list.
- Users can still pass list_type_ if they would like to have stricter type validation in their pipelines.

View File

@ -3,11 +3,17 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List from typing import List
import pytest
from haystack import Document from haystack import Document, Pipeline
from haystack.dataclasses import ChatMessage from haystack.dataclasses import ChatMessage
from haystack.dataclasses.answer import GeneratedAnswer from haystack.dataclasses.answer import GeneratedAnswer
from haystack.components.builders import AnswerBuilder, ChatPromptBuilder
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.joiners.list_joiner import ListJoiner from haystack.components.joiners.list_joiner import ListJoiner
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.core.errors import PipelineConnectError
from haystack.utils.auth import Secret
class TestListJoiner: class TestListJoiner:
@ -16,7 +22,15 @@ class TestListJoiner:
assert isinstance(joiner, ListJoiner) assert isinstance(joiner, ListJoiner)
assert joiner.list_type_ == List[ChatMessage] assert joiner.list_type_ == List[ChatMessage]
def test_to_dict(self): def test_to_dict_defaults(self):
joiner = ListJoiner()
data = joiner.to_dict()
assert data == {
"type": "haystack.components.joiners.list_joiner.ListJoiner",
"init_parameters": {"list_type_": None},
}
def test_to_dict_non_default(self):
joiner = ListJoiner(List[ChatMessage]) joiner = ListJoiner(List[ChatMessage])
data = joiner.to_dict() data = joiner.to_dict()
assert data == { assert data == {
@ -24,7 +38,13 @@ class TestListJoiner:
"init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"}, "init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
} }
def test_from_dict(self): def test_from_dict_default(self):
data = {"type": "haystack.components.joiners.list_joiner.ListJoiner", "init_parameters": {"list_type_": None}}
list_joiner = ListJoiner.from_dict(data)
assert isinstance(list_joiner, ListJoiner)
assert list_joiner.list_type_ is None
def test_from_dict_non_default(self):
data = { data = {
"type": "haystack.components.joiners.list_joiner.ListJoiner", "type": "haystack.components.joiners.list_joiner.ListJoiner",
"init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"}, "init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
@ -64,8 +84,65 @@ class TestListJoiner:
result = joiner.run([answers1, answers2]) result = joiner.run([answers1, answers2])
assert result == {"values": answers1 + answers2} assert result == {"values": answers1 + answers2}
def test_list_two_different_types(self):
joiner = ListJoiner()
result = joiner.run([["a", "b"], [1, 2]])
assert result == {"values": ["a", "b", 1, 2]}
def test_mixed_empty_and_non_empty_lists(self): def test_mixed_empty_and_non_empty_lists(self):
joiner = ListJoiner(List[ChatMessage]) joiner = ListJoiner(List[ChatMessage])
messages = [ChatMessage.from_user("Hello")] messages = [ChatMessage.from_user("Hello")]
result = joiner.run([messages, [], messages]) result = joiner.run([messages, [], messages])
assert result == {"values": messages + messages} assert result == {"values": messages + messages}
def test_pipeline_connection_validation(self):
joiner = ListJoiner()
llm = OpenAIChatGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key"))
pipe = Pipeline()
pipe.add_component("joiner", joiner)
pipe.add_component("llm", llm)
pipe.connect("joiner.values", "llm.messages")
assert pipe is not None
def test_pipeline_connection_validation_list_chatmessage(self):
joiner = ListJoiner(List[ChatMessage])
llm = OpenAIChatGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key"))
pipe = Pipeline()
pipe.add_component("joiner", joiner)
pipe.add_component("llm", llm)
pipe.connect("joiner", "llm.messages")
assert pipe is not None
def test_pipeline_bad_connection(self):
with pytest.raises(PipelineConnectError):
joiner = ListJoiner()
query_embedder = SentenceTransformersTextEmbedder()
pipe = Pipeline()
pipe.add_component("joiner", joiner)
pipe.add_component("query_embedder", query_embedder)
pipe.connect("joiner.values", "query_embedder.text")
def test_pipeline_bad_connection_different_list_types(self):
with pytest.raises(PipelineConnectError):
joiner = ListJoiner(List[str])
llm = OpenAIChatGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key"))
pipe = Pipeline()
pipe.add_component("joiner", joiner)
pipe.add_component("llm", llm)
pipe.connect("joiner.values", "llm.messages")
def test_result_two_different_types(self):
pipe = Pipeline()
pipe.add_component("answer_builder", AnswerBuilder())
pipe.add_component("chat_prompt_builder", ChatPromptBuilder())
pipe.add_component("joiner", ListJoiner())
pipe.connect("answer_builder", "joiner.values")
pipe.connect("chat_prompt_builder", "joiner.values")
result = pipe.run(
data={
"answer_builder": {"query": "What is nuclear physics?", "replies": ["This is an answer."]},
"chat_prompt_builder": {"template": [ChatMessage.from_user("Hello")]},
}
)
assert isinstance(result["joiner"]["values"][0], GeneratedAnswer)
assert isinstance(result["joiner"]["values"][1], ChatMessage)