feat: Update list joiner (#8851)

* Update ListJoiner to have default type List

* Add reno

* Add more tests

* Remove unused import

* Fix mypy

* Update docstrings

* Update haystack/components/joiners/list_joiner.py

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>

---------

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>
This commit is contained in:
Sebastian Husch Lee 2025-02-14 09:47:19 +01:00 committed by GitHub
parent e6c503dbb9
commit 2f383bce25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 103 additions and 12 deletions

View File

@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
from itertools import chain
from typing import Any, Dict, Type
from typing import Any, Dict, List, Optional, Type
from haystack import component, default_from_dict, default_to_dict
from haystack.core.component.types import Variadic
@ -65,15 +65,19 @@ class ListJoiner:
```
"""
def __init__(self, list_type_: Type):
def __init__(self, list_type_: Optional[Type] = None):
"""
Creates a ListJoiner component.
:param list_type_: The type of list that this joiner will handle (e.g., List[ChatMessage]).
All input lists must be of this type.
:param list_type_: The expected type of the lists this component will join (e.g., List[ChatMessage]).
If specified, all input lists must conform to this type. If None, the component defaults to handling
lists of any type including mixed types.
"""
self.list_type_ = list_type_
component.set_output_types(self, values=list_type_)
if list_type_ is not None:
component.set_output_types(self, values=list_type_)
else:
component.set_output_types(self, values=List)
def to_dict(self) -> Dict[str, Any]:
"""
@ -81,7 +85,9 @@ class ListJoiner:
:returns: Dictionary with serialized data.
"""
return default_to_dict(self, list_type_=serialize_type(self.list_type_))
return default_to_dict(
self, list_type_=serialize_type(self.list_type_) if self.list_type_ is not None else None
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ListJoiner":
@ -91,14 +97,16 @@ class ListJoiner:
:param data: Dictionary to deserialize from.
:returns: Deserialized component.
"""
data["init_parameters"]["list_type_"] = deserialize_type(data["init_parameters"]["list_type_"])
init_parameters = data.get("init_parameters")
if init_parameters is not None and init_parameters.get("list_type_") is not None:
data["init_parameters"]["list_type_"] = deserialize_type(data["init_parameters"]["list_type_"])
return default_from_dict(cls, data)
def run(self, values: Variadic[Any]) -> Dict[str, Any]:
def run(self, values: Variadic[List[Any]]) -> Dict[str, List[Any]]:
"""
Joins multiple lists into a single flat list.
:param values:The list to be joined.
:param values: The list to be joined.
:returns: Dictionary with 'values' key containing the joined list.
"""
result = list(chain(*values))

View File

@ -0,0 +1,6 @@
---
enhancements:
- |
The list_type_ parameter of ListJoiner is now optional. When it is omitted, the component's output type defaults to List, which behaves like List[Any].
- This allows the ListJoiner to combine any incoming lists into a single flattened list.
- Users can still pass list_type_ if they would like to have stricter type validation in their pipelines.

View File

@ -3,11 +3,17 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List
import pytest
from haystack import Document
from haystack import Document, Pipeline
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.answer import GeneratedAnswer
from haystack.components.builders import AnswerBuilder, ChatPromptBuilder
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.joiners.list_joiner import ListJoiner
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.core.errors import PipelineConnectError
from haystack.utils.auth import Secret
class TestListJoiner:
@ -16,7 +22,15 @@ class TestListJoiner:
assert isinstance(joiner, ListJoiner)
assert joiner.list_type_ == List[ChatMessage]
def test_to_dict(self):
def test_to_dict_defaults(self):
    """Serializing a joiner built without arguments records ``list_type_`` as None."""
    expected = {
        "type": "haystack.components.joiners.list_joiner.ListJoiner",
        "init_parameters": {"list_type_": None},
    }
    assert ListJoiner().to_dict() == expected
def test_to_dict_non_default(self):
joiner = ListJoiner(List[ChatMessage])
data = joiner.to_dict()
assert data == {
@ -24,7 +38,13 @@ class TestListJoiner:
"init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
}
def test_from_dict(self):
def test_from_dict_default(self):
    """Deserializing a payload whose ``list_type_`` is None yields an untyped joiner."""
    serialized = {
        "type": "haystack.components.joiners.list_joiner.ListJoiner",
        "init_parameters": {"list_type_": None},
    }
    restored = ListJoiner.from_dict(serialized)
    assert isinstance(restored, ListJoiner)
    assert restored.list_type_ is None
def test_from_dict_non_default(self):
data = {
"type": "haystack.components.joiners.list_joiner.ListJoiner",
"init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
@ -64,8 +84,65 @@ class TestListJoiner:
result = joiner.run([answers1, answers2])
assert result == {"values": answers1 + answers2}
def test_list_two_different_types(self):
    """An untyped joiner flattens a list of strings and a list of ints into one list."""
    strings, numbers = ["a", "b"], [1, 2]
    outcome = ListJoiner().run([strings, numbers])
    assert outcome == {"values": ["a", "b", 1, 2]}
def test_mixed_empty_and_non_empty_lists(self):
    """An empty input list contributes nothing to the joined output."""
    typed_joiner = ListJoiner(List[ChatMessage])
    greeting = [ChatMessage.from_user("Hello")]
    inputs = [greeting, [], greeting]
    outcome = typed_joiner.run(inputs)
    assert outcome == {"values": greeting + greeting}
def test_pipeline_connection_validation(self):
    """The output of an untyped joiner can be connected to the chat generator's ``messages`` input."""
    untyped_joiner = ListJoiner()
    chat_llm = OpenAIChatGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key"))
    pipeline = Pipeline()
    for name, instance in (("joiner", untyped_joiner), ("llm", chat_llm)):
        pipeline.add_component(name, instance)
    # The meaningful check is that connect() completes without raising;
    # the assertion below only confirms the test reached this point.
    pipeline.connect("joiner.values", "llm.messages")
    assert pipeline is not None
def test_pipeline_connection_validation_list_chatmessage(self):
    """A joiner typed as ``List[ChatMessage]`` can be connected to the chat generator's ``messages`` input."""
    typed_joiner = ListJoiner(List[ChatMessage])
    chat_llm = OpenAIChatGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key"))
    pipeline = Pipeline()
    for name, instance in (("joiner", typed_joiner), ("llm", chat_llm)):
        pipeline.add_component(name, instance)
    # The sender is given by component name only (no explicit output socket).
    # The meaningful check is that connect() completes without raising.
    pipeline.connect("joiner", "llm.messages")
    assert pipeline is not None
def test_pipeline_bad_connection(self):
    """Connecting an untyped joiner's output to the text embedder's ``text`` input must be rejected."""
    joiner = ListJoiner()
    query_embedder = SentenceTransformersTextEmbedder()
    pipe = Pipeline()
    pipe.add_component("joiner", joiner)
    pipe.add_component("query_embedder", query_embedder)
    # Only connect() is expected to raise. Keeping the setup outside the
    # context manager ensures an error raised while building the components
    # or the pipeline cannot be mistaken for the failure under test.
    with pytest.raises(PipelineConnectError):
        pipe.connect("joiner.values", "query_embedder.text")
def test_pipeline_bad_connection_different_list_types(self):
    """Connecting a ``List[str]`` joiner to the chat generator's ``messages`` input must be rejected."""
    joiner = ListJoiner(List[str])
    llm = OpenAIChatGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key"))
    pipe = Pipeline()
    pipe.add_component("joiner", joiner)
    pipe.add_component("llm", llm)
    # Only connect() is expected to raise. Keeping the setup outside the
    # context manager ensures an error raised while building the components
    # or the pipeline cannot be mistaken for the failure under test.
    with pytest.raises(PipelineConnectError):
        pipe.connect("joiner.values", "llm.messages")
def test_result_two_different_types(self):
    """Run a pipeline in which two components with different output element types feed one untyped joiner."""
    pipeline = Pipeline()
    pipeline.add_component("answer_builder", AnswerBuilder())
    pipeline.add_component("chat_prompt_builder", ChatPromptBuilder())
    pipeline.add_component("joiner", ListJoiner())

    # Both producers are wired into the same variadic joiner input.
    for sender in ("answer_builder", "chat_prompt_builder"):
        pipeline.connect(sender, "joiner.values")

    run_inputs = {
        "answer_builder": {"query": "What is nuclear physics?", "replies": ["This is an answer."]},
        "chat_prompt_builder": {"template": [ChatMessage.from_user("Hello")]},
    }
    result = pipeline.run(data=run_inputs)

    joined = result["joiner"]["values"]
    assert isinstance(joined[0], GeneratedAnswer)
    assert isinstance(joined[1], ChatMessage)