mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-12 23:37:36 +00:00
feat: Update list joiner (#8851)
* Update ListJoiner to have default type List * Add reno * Add more tests * Remove unused import * Fix mypy * Update docstrings * Update haystack/components/joiners/list_joiner.py Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com> --------- Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>
This commit is contained in:
parent
e6c503dbb9
commit
2f383bce25
@ -3,7 +3,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Any, Dict, Type
|
from typing import Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
from haystack import component, default_from_dict, default_to_dict
|
from haystack import component, default_from_dict, default_to_dict
|
||||||
from haystack.core.component.types import Variadic
|
from haystack.core.component.types import Variadic
|
||||||
@ -65,15 +65,19 @@ class ListJoiner:
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, list_type_: Type):
|
def __init__(self, list_type_: Optional[Type] = None):
|
||||||
"""
|
"""
|
||||||
Creates a ListJoiner component.
|
Creates a ListJoiner component.
|
||||||
|
|
||||||
:param list_type_: The type of list that this joiner will handle (e.g., List[ChatMessage]).
|
:param list_type_: The expected type of the lists this component will join (e.g., List[ChatMessage]).
|
||||||
All input lists must be of this type.
|
If specified, all input lists must conform to this type. If None, the component defaults to handling
|
||||||
|
lists of any type including mixed types.
|
||||||
"""
|
"""
|
||||||
self.list_type_ = list_type_
|
self.list_type_ = list_type_
|
||||||
component.set_output_types(self, values=list_type_)
|
if list_type_ is not None:
|
||||||
|
component.set_output_types(self, values=list_type_)
|
||||||
|
else:
|
||||||
|
component.set_output_types(self, values=List)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@ -81,7 +85,9 @@ class ListJoiner:
|
|||||||
|
|
||||||
:returns: Dictionary with serialized data.
|
:returns: Dictionary with serialized data.
|
||||||
"""
|
"""
|
||||||
return default_to_dict(self, list_type_=serialize_type(self.list_type_))
|
return default_to_dict(
|
||||||
|
self, list_type_=serialize_type(self.list_type_) if self.list_type_ is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> "ListJoiner":
|
def from_dict(cls, data: Dict[str, Any]) -> "ListJoiner":
|
||||||
@ -91,14 +97,16 @@ class ListJoiner:
|
|||||||
:param data: Dictionary to deserialize from.
|
:param data: Dictionary to deserialize from.
|
||||||
:returns: Deserialized component.
|
:returns: Deserialized component.
|
||||||
"""
|
"""
|
||||||
data["init_parameters"]["list_type_"] = deserialize_type(data["init_parameters"]["list_type_"])
|
init_parameters = data.get("init_parameters")
|
||||||
|
if init_parameters is not None and init_parameters.get("list_type_") is not None:
|
||||||
|
data["init_parameters"]["list_type_"] = deserialize_type(data["init_parameters"]["list_type_"])
|
||||||
return default_from_dict(cls, data)
|
return default_from_dict(cls, data)
|
||||||
|
|
||||||
def run(self, values: Variadic[Any]) -> Dict[str, Any]:
|
def run(self, values: Variadic[List[Any]]) -> Dict[str, List[Any]]:
|
||||||
"""
|
"""
|
||||||
Joins multiple lists into a single flat list.
|
Joins multiple lists into a single flat list.
|
||||||
|
|
||||||
:param values:The list to be joined.
|
:param values: The list to be joined.
|
||||||
:returns: Dictionary with 'values' key containing the joined list.
|
:returns: Dictionary with 'values' key containing the joined list.
|
||||||
"""
|
"""
|
||||||
result = list(chain(*values))
|
result = list(chain(*values))
|
||||||
|
|||||||
@ -0,0 +1,6 @@
|
|||||||
|
---
|
||||||
|
enhancements:
|
||||||
|
- |
|
||||||
|
Update ListJoiner to only optionally need list_type_ to be passed. By default it uses type List which acts like List[Any].
|
||||||
|
- This allows the ListJoiner to combine any incoming lists into a single flattened list.
|
||||||
|
- Users can still pass list_type_ if they would like to have stricter type validation in their pipelines.
|
||||||
@ -3,11 +3,17 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
import pytest
|
||||||
|
|
||||||
from haystack import Document
|
from haystack import Document, Pipeline
|
||||||
from haystack.dataclasses import ChatMessage
|
from haystack.dataclasses import ChatMessage
|
||||||
from haystack.dataclasses.answer import GeneratedAnswer
|
from haystack.dataclasses.answer import GeneratedAnswer
|
||||||
|
from haystack.components.builders import AnswerBuilder, ChatPromptBuilder
|
||||||
|
from haystack.components.generators.chat.openai import OpenAIChatGenerator
|
||||||
from haystack.components.joiners.list_joiner import ListJoiner
|
from haystack.components.joiners.list_joiner import ListJoiner
|
||||||
|
from haystack.components.embedders import SentenceTransformersTextEmbedder
|
||||||
|
from haystack.core.errors import PipelineConnectError
|
||||||
|
from haystack.utils.auth import Secret
|
||||||
|
|
||||||
|
|
||||||
class TestListJoiner:
|
class TestListJoiner:
|
||||||
@ -16,7 +22,15 @@ class TestListJoiner:
|
|||||||
assert isinstance(joiner, ListJoiner)
|
assert isinstance(joiner, ListJoiner)
|
||||||
assert joiner.list_type_ == List[ChatMessage]
|
assert joiner.list_type_ == List[ChatMessage]
|
||||||
|
|
||||||
def test_to_dict(self):
|
def test_to_dict_defaults(self):
|
||||||
|
joiner = ListJoiner()
|
||||||
|
data = joiner.to_dict()
|
||||||
|
assert data == {
|
||||||
|
"type": "haystack.components.joiners.list_joiner.ListJoiner",
|
||||||
|
"init_parameters": {"list_type_": None},
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_to_dict_non_default(self):
|
||||||
joiner = ListJoiner(List[ChatMessage])
|
joiner = ListJoiner(List[ChatMessage])
|
||||||
data = joiner.to_dict()
|
data = joiner.to_dict()
|
||||||
assert data == {
|
assert data == {
|
||||||
@ -24,7 +38,13 @@ class TestListJoiner:
|
|||||||
"init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
|
"init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_from_dict(self):
|
def test_from_dict_default(self):
|
||||||
|
data = {"type": "haystack.components.joiners.list_joiner.ListJoiner", "init_parameters": {"list_type_": None}}
|
||||||
|
list_joiner = ListJoiner.from_dict(data)
|
||||||
|
assert isinstance(list_joiner, ListJoiner)
|
||||||
|
assert list_joiner.list_type_ is None
|
||||||
|
|
||||||
|
def test_from_dict_non_default(self):
|
||||||
data = {
|
data = {
|
||||||
"type": "haystack.components.joiners.list_joiner.ListJoiner",
|
"type": "haystack.components.joiners.list_joiner.ListJoiner",
|
||||||
"init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
|
"init_parameters": {"list_type_": "typing.List[haystack.dataclasses.chat_message.ChatMessage]"},
|
||||||
@ -64,8 +84,65 @@ class TestListJoiner:
|
|||||||
result = joiner.run([answers1, answers2])
|
result = joiner.run([answers1, answers2])
|
||||||
assert result == {"values": answers1 + answers2}
|
assert result == {"values": answers1 + answers2}
|
||||||
|
|
||||||
|
def test_list_two_different_types(self):
|
||||||
|
joiner = ListJoiner()
|
||||||
|
result = joiner.run([["a", "b"], [1, 2]])
|
||||||
|
assert result == {"values": ["a", "b", 1, 2]}
|
||||||
|
|
||||||
def test_mixed_empty_and_non_empty_lists(self):
|
def test_mixed_empty_and_non_empty_lists(self):
|
||||||
joiner = ListJoiner(List[ChatMessage])
|
joiner = ListJoiner(List[ChatMessage])
|
||||||
messages = [ChatMessage.from_user("Hello")]
|
messages = [ChatMessage.from_user("Hello")]
|
||||||
result = joiner.run([messages, [], messages])
|
result = joiner.run([messages, [], messages])
|
||||||
assert result == {"values": messages + messages}
|
assert result == {"values": messages + messages}
|
||||||
|
|
||||||
|
def test_pipeline_connection_validation(self):
|
||||||
|
joiner = ListJoiner()
|
||||||
|
llm = OpenAIChatGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key"))
|
||||||
|
pipe = Pipeline()
|
||||||
|
pipe.add_component("joiner", joiner)
|
||||||
|
pipe.add_component("llm", llm)
|
||||||
|
pipe.connect("joiner.values", "llm.messages")
|
||||||
|
assert pipe is not None
|
||||||
|
|
||||||
|
def test_pipeline_connection_validation_list_chatmessage(self):
|
||||||
|
joiner = ListJoiner(List[ChatMessage])
|
||||||
|
llm = OpenAIChatGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key"))
|
||||||
|
pipe = Pipeline()
|
||||||
|
pipe.add_component("joiner", joiner)
|
||||||
|
pipe.add_component("llm", llm)
|
||||||
|
pipe.connect("joiner", "llm.messages")
|
||||||
|
assert pipe is not None
|
||||||
|
|
||||||
|
def test_pipeline_bad_connection(self):
|
||||||
|
with pytest.raises(PipelineConnectError):
|
||||||
|
joiner = ListJoiner()
|
||||||
|
query_embedder = SentenceTransformersTextEmbedder()
|
||||||
|
pipe = Pipeline()
|
||||||
|
pipe.add_component("joiner", joiner)
|
||||||
|
pipe.add_component("query_embedder", query_embedder)
|
||||||
|
pipe.connect("joiner.values", "query_embedder.text")
|
||||||
|
|
||||||
|
def test_pipeline_bad_connection_different_list_types(self):
|
||||||
|
with pytest.raises(PipelineConnectError):
|
||||||
|
joiner = ListJoiner(List[str])
|
||||||
|
llm = OpenAIChatGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key"))
|
||||||
|
pipe = Pipeline()
|
||||||
|
pipe.add_component("joiner", joiner)
|
||||||
|
pipe.add_component("llm", llm)
|
||||||
|
pipe.connect("joiner.values", "llm.messages")
|
||||||
|
|
||||||
|
def test_result_two_different_types(self):
|
||||||
|
pipe = Pipeline()
|
||||||
|
pipe.add_component("answer_builder", AnswerBuilder())
|
||||||
|
pipe.add_component("chat_prompt_builder", ChatPromptBuilder())
|
||||||
|
pipe.add_component("joiner", ListJoiner())
|
||||||
|
pipe.connect("answer_builder", "joiner.values")
|
||||||
|
pipe.connect("chat_prompt_builder", "joiner.values")
|
||||||
|
result = pipe.run(
|
||||||
|
data={
|
||||||
|
"answer_builder": {"query": "What is nuclear physics?", "replies": ["This is an answer."]},
|
||||||
|
"chat_prompt_builder": {"template": [ChatMessage.from_user("Hello")]},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert isinstance(result["joiner"]["values"][0], GeneratedAnswer)
|
||||||
|
assert isinstance(result["joiner"]["values"][1], ChatMessage)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user