mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-12 07:17:41 +00:00
feat: Update list joiner (#8851)
* Update ListJoiner to have default type List
* Add reno
* Add more tests
* Remove unused import
* Fix mypy
* Update docstrings
* Update haystack/components/joiners/list_joiner.py

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>

--------

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>
This commit is contained in:
parent
e6c503dbb9
commit
2f383bce25
@ -3,7 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from itertools import chain
|
||||
from typing import Any, Dict, Type
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict
|
||||
from haystack.core.component.types import Variadic
|
||||
@ -65,15 +65,19 @@ class ListJoiner:
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, list_type_: Type):
|
||||
def __init__(self, list_type_: Optional[Type] = None):
|
||||
"""
|
||||
Creates a ListJoiner component.
|
||||
|
||||
:param list_type_: The type of list that this joiner will handle (e.g., List[ChatMessage]).
|
||||
All input lists must be of this type.
|
||||
:param list_type_: The expected type of the lists this component will join (e.g., List[ChatMessage]).
|
||||
If specified, all input lists must conform to this type. If None, the component defaults to handling
|
||||
lists of any type including mixed types.
|
||||
"""
|
||||
self.list_type_ = list_type_
|
||||
component.set_output_types(self, values=list_type_)
|
||||
if list_type_ is not None:
|
||||
component.set_output_types(self, values=list_type_)
|
||||
else:
|
||||
component.set_output_types(self, values=List)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -81,7 +85,9 @@ class ListJoiner:
|
||||
|
||||
:returns: Dictionary with serialized data.
|
||||
"""
|
||||
return default_to_dict(self, list_type_=serialize_type(self.list_type_))
|
||||
return default_to_dict(
|
||||
self, list_type_=serialize_type(self.list_type_) if self.list_type_ is not None else None
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ListJoiner":
|
||||
@ -91,14 +97,16 @@ class ListJoiner:
|
||||
:param data: Dictionary to deserialize from.
|
||||
:returns: Deserialized component.
|
||||
"""
|
||||
data["init_parameters"]["list_type_"] = deserialize_type(data["init_parameters"]["list_type_"])
|
||||
init_parameters = data.get("init_parameters")
|
||||
if init_parameters is not None and init_parameters.get("list_type_") is not None:
|
||||
data["init_parameters"]["list_type_"] = deserialize_type(data["init_parameters"]["list_type_"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def run(self, values: Variadic[Any]) -> Dict[str, Any]:
|
||||
def run(self, values: Variadic[List[Any]]) -> Dict[str, List[Any]]:
|
||||
"""
|
||||
Joins multiple lists into a single flat list.
|
||||
|
||||
:param values:The list to be joined.
|
||||
:param values: The list to be joined.
|
||||
:returns: Dictionary with 'values' key containing the joined list.
|
||||
"""
|
||||
result = list(chain(*values))
|
||||
|
||||
@ -0,0 +1,6 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Make the `list_type_` parameter of ListJoiner optional. When it is not passed, the output type defaults to `List`, which acts like `List[Any]`.
|
||||
- This allows the ListJoiner to combine any incoming lists into a single flattened list.
|
||||
- Users can still pass list_type_ if they would like to have stricter type validation in their pipelines.
|
||||
@ -3,11 +3,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import List
|
||||
import pytest
|
||||
|
||||
from haystack import Document
|
||||
from haystack import Document, Pipeline
|
||||
from haystack.dataclasses import ChatMessage
|
||||
from haystack.dataclasses.answer import GeneratedAnswer
|
||||
from haystack.components.builders import AnswerBuilder, ChatPromptBuilder
|
||||
from haystack.components.generators.chat.openai import OpenAIChatGenerator
|
||||
from haystack.components.joiners.list_joiner import ListJoiner
|
||||
from haystack.components.embedders import SentenceTransformersTextEmbedder
|
||||
from haystack.core.errors import PipelineConnectError
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
|
||||
class TestListJoiner:
|
||||
@ -16,7 +22,15 @@ class TestListJoiner:
|
||||
assert isinstance(joiner, ListJoiner)
|
||||
assert joiner.list_type_ == List[ChatMessage]
|
||||
|
||||
def test_to_dict(self):
|
||||
def test_to_dict_defaults(self):
|
||||
joiner = ListJoiner()
|
||||
data = joiner.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.joiners.list_joiner.ListJoiner",
|
||||
"init_parameters": {"list_type_": None},
|
||||
}
|
||||
|
||||
def test_to_dict_non_default(self):
|
||||
joiner = ListJoiner(List[ChatMessage])
|
||||
data = joiner.to_dict()
|
||||
assert data == {
|
||||
@ -24,7 +38,13 @@ class TestListJoiner:
|
||||
"init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
|
||||
}
|
||||
|
||||
def test_from_dict(self):
|
||||
def test_from_dict_default(self):
|
||||
data = {"type": "haystack.components.joiners.list_joiner.ListJoiner", "init_parameters": {"list_type_": None}}
|
||||
list_joiner = ListJoiner.from_dict(data)
|
||||
assert isinstance(list_joiner, ListJoiner)
|
||||
assert list_joiner.list_type_ is None
|
||||
|
||||
def test_from_dict_non_default(self):
|
||||
data = {
|
||||
"type": "haystack.components.joiners.list_joiner.ListJoiner",
|
||||
"init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
|
||||
@ -64,8 +84,65 @@ class TestListJoiner:
|
||||
result = joiner.run([answers1, answers2])
|
||||
assert result == {"values": answers1 + answers2}
|
||||
|
||||
def test_list_two_different_types(self):
|
||||
joiner = ListJoiner()
|
||||
result = joiner.run([["a", "b"], [1, 2]])
|
||||
assert result == {"values": ["a", "b", 1, 2]}
|
||||
|
||||
def test_mixed_empty_and_non_empty_lists(self):
|
||||
joiner = ListJoiner(List[ChatMessage])
|
||||
messages = [ChatMessage.from_user("Hello")]
|
||||
result = joiner.run([messages, [], messages])
|
||||
assert result == {"values": messages + messages}
|
||||
|
||||
def test_pipeline_connection_validation(self):
|
||||
joiner = ListJoiner()
|
||||
llm = OpenAIChatGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key"))
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("joiner", joiner)
|
||||
pipe.add_component("llm", llm)
|
||||
pipe.connect("joiner.values", "llm.messages")
|
||||
assert pipe is not None
|
||||
|
||||
def test_pipeline_connection_validation_list_chatmessage(self):
|
||||
joiner = ListJoiner(List[ChatMessage])
|
||||
llm = OpenAIChatGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key"))
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("joiner", joiner)
|
||||
pipe.add_component("llm", llm)
|
||||
pipe.connect("joiner", "llm.messages")
|
||||
assert pipe is not None
|
||||
|
||||
def test_pipeline_bad_connection(self):
|
||||
with pytest.raises(PipelineConnectError):
|
||||
joiner = ListJoiner()
|
||||
query_embedder = SentenceTransformersTextEmbedder()
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("joiner", joiner)
|
||||
pipe.add_component("query_embedder", query_embedder)
|
||||
pipe.connect("joiner.values", "query_embedder.text")
|
||||
|
||||
def test_pipeline_bad_connection_different_list_types(self):
|
||||
with pytest.raises(PipelineConnectError):
|
||||
joiner = ListJoiner(List[str])
|
||||
llm = OpenAIChatGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key"))
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("joiner", joiner)
|
||||
pipe.add_component("llm", llm)
|
||||
pipe.connect("joiner.values", "llm.messages")
|
||||
|
||||
def test_result_two_different_types(self):
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("answer_builder", AnswerBuilder())
|
||||
pipe.add_component("chat_prompt_builder", ChatPromptBuilder())
|
||||
pipe.add_component("joiner", ListJoiner())
|
||||
pipe.connect("answer_builder", "joiner.values")
|
||||
pipe.connect("chat_prompt_builder", "joiner.values")
|
||||
result = pipe.run(
|
||||
data={
|
||||
"answer_builder": {"query": "What is nuclear physics?", "replies": ["This is an answer."]},
|
||||
"chat_prompt_builder": {"template": [ChatMessage.from_user("Hello")]},
|
||||
}
|
||||
)
|
||||
assert isinstance(result["joiner"]["values"][0], GeneratedAnswer)
|
||||
assert isinstance(result["joiner"]["values"][1], ChatMessage)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user