mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-06 11:57:14 +00:00
Add BranchJoiner and deprecate Multiplexer (#7765)
This commit is contained in:
parent
5c468feecf
commit
8d80ff86d9
@ -1,7 +1,7 @@
|
|||||||
loaders:
|
loaders:
|
||||||
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
|
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
|
||||||
search_path: [../../../haystack/components/joiners]
|
search_path: [../../../haystack/components/joiners]
|
||||||
modules: ["document_joiner"]
|
modules: ["document_joiner", "branch"]
|
||||||
ignore_when_discovered: ["__init__"]
|
ignore_when_discovered: ["__init__"]
|
||||||
processors:
|
processors:
|
||||||
- type: filter
|
- type: filter
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
#
|
#
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from haystack.components.joiners.document_joiner import DocumentJoiner
|
from .branch import BranchJoiner
|
||||||
|
from .document_joiner import DocumentJoiner
|
||||||
|
|
||||||
__all__ = ["DocumentJoiner"]
|
__all__ = ["DocumentJoiner", "BranchJoiner"]
|
||||||
|
|||||||
141
haystack/components/joiners/branch.py
Normal file
141
haystack/components/joiners/branch.py
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Any, Dict, Type
|
||||||
|
|
||||||
|
from haystack import component, default_from_dict, default_to_dict, logging
|
||||||
|
from haystack.core.component.types import Variadic
|
||||||
|
from haystack.utils import deserialize_type, serialize_type
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@component(is_greedy=True)
|
||||||
|
class BranchJoiner:
|
||||||
|
"""
|
||||||
|
A component to join different branches of a pipeline into one single output.
|
||||||
|
|
||||||
|
`BranchJoiner` receives multiple data connections of the same type from other components and passes the first
|
||||||
|
value coming to its single output, possibly distributing it to various other components.
|
||||||
|
|
||||||
|
`BranchJoiner` is fundamental to close loops in a pipeline, where the two branches it joins are the ones
|
||||||
|
coming from the previous component and one coming back from a loop. For example, `BranchJoiner` could be used
|
||||||
|
to send data to a component evaluating errors. `BranchJoiner` would receive two connections, one to get the
|
||||||
|
original data and another one to get modified data in case there was an error. In both cases, `BranchJoiner`
|
||||||
|
would send (or re-send in case of a loop) data to the component evaluating errors. See "Usage example" below.
|
||||||
|
|
||||||
|
Another use case with a need for `BranchJoiner` is to reconcile multiple branches coming out of a decision
|
||||||
|
or Classifier component. For example, in a RAG pipeline, there might be a "query language classifier" component
|
||||||
|
sending the query to different retrievers, selecting one specifically according to the detected language. After the
|
||||||
|
retrieval step the pipeline would ideally continue with a `PromptBuilder`, and since we don't know in advance the
|
||||||
|
language of the query, all the retrievers should be ideally connected to the single `PromptBuilder`. Since the
|
||||||
|
`PromptBuilder` won't accept more than one connection in input, we would connect all the retrievers to a
|
||||||
|
`BranchJoiner` component and reconcile them in a single output that can be connected to the `PromptBuilder`
|
||||||
|
downstream.
|
||||||
|
|
||||||
|
Usage example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import json
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from haystack import Pipeline
|
||||||
|
from haystack.components.converters import OutputAdapter
|
||||||
|
from haystack.components.generators.chat import OpenAIChatGenerator
|
||||||
|
from haystack.components.joiners import BranchJoiner
|
||||||
|
from haystack.components.validators import JsonSchemaValidator
|
||||||
|
from haystack.dataclasses import ChatMessage
|
||||||
|
|
||||||
|
person_schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"first_name": {"type": "string", "pattern": "^[A-Z][a-z]+$"},
|
||||||
|
"last_name": {"type": "string", "pattern": "^[A-Z][a-z]+$"},
|
||||||
|
"nationality": {"type": "string", "enum": ["Italian", "Portuguese", "American"]},
|
||||||
|
},
|
||||||
|
"required": ["first_name", "last_name", "nationality"]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Initialize a pipeline
|
||||||
|
pipe = Pipeline()
|
||||||
|
|
||||||
|
# Add components to the pipeline
|
||||||
|
pipe.add_component('joiner', BranchJoiner(List[ChatMessage]))
|
||||||
|
pipe.add_component('fc_llm', OpenAIChatGenerator(model="gpt-3.5-turbo-0125"))
|
||||||
|
pipe.add_component('validator', JsonSchemaValidator(json_schema=person_schema))
|
||||||
|
pipe.add_component('adapter', OutputAdapter("{{chat_message}}", List[ChatMessage]))
|
||||||
|
# And connect them
|
||||||
|
pipe.connect("adapter", "joiner")
|
||||||
|
pipe.connect("joiner", "fc_llm")
|
||||||
|
pipe.connect("fc_llm.replies", "validator.messages")
|
||||||
|
pipe.connect("validator.validation_error", "joiner")
|
||||||
|
|
||||||
|
result = pipe.run(data={"fc_llm": {"generation_kwargs": {"response_format": {"type": "json_object"}}},
|
||||||
|
"adapter": {"chat_message": [ChatMessage.from_user("Create json object from Peter Parker")]}})
|
||||||
|
|
||||||
|
print(json.loads(result["validator"]["validated"][0].content))
|
||||||
|
|
||||||
|
|
||||||
|
>> {'first_name': 'Peter', 'last_name': 'Parker', 'nationality': 'American', 'name': 'Spider-Man', 'occupation':
|
||||||
|
>> 'Superhero', 'age': 23, 'location': 'New York City'}
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that `BranchJoiner` can manage only one data type at a time. In this case, `BranchJoiner` is created for passing
|
||||||
|
`List[ChatMessage]`. This determines the type of data that `BranchJoiner` will receive from the upstream connected
|
||||||
|
components and also the type of data that `BranchJoiner` will send through its output.
|
||||||
|
|
||||||
|
In the code example, `BranchJoiner` receives a looped back `List[ChatMessage]` from the `JsonSchemaValidator` and
|
||||||
|
sends it down to the `OpenAIChatGenerator` for re-generation. We can have multiple loopback connections in the
|
||||||
|
pipeline. In this instance, the downstream component is only one (the `OpenAIChatGenerator`), but the pipeline might
|
||||||
|
have more than one downstream component.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, type_: Type):
|
||||||
|
"""
|
||||||
|
Create a `BranchJoiner` component.
|
||||||
|
|
||||||
|
:param type_: The type of data that the `BranchJoiner` will receive from the upstream connected components and
|
||||||
|
distribute to the downstream connected components.
|
||||||
|
"""
|
||||||
|
self.type_ = type_
|
||||||
|
# type_'s type can't be determined statically
|
||||||
|
component.set_input_types(self, value=Variadic[type_]) # type: ignore
|
||||||
|
component.set_output_types(self, value=type_)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
"""
|
||||||
|
Serializes the component to a dictionary.
|
||||||
|
|
||||||
|
:returns:
|
||||||
|
Dictionary with serialized data.
|
||||||
|
"""
|
||||||
|
return default_to_dict(self, type_=serialize_type(self.type_))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "BranchJoiner":
|
||||||
|
"""
|
||||||
|
Deserializes the component from a dictionary.
|
||||||
|
|
||||||
|
:param data:
|
||||||
|
Dictionary to deserialize from.
|
||||||
|
:returns:
|
||||||
|
Deserialized component.
|
||||||
|
"""
|
||||||
|
data["init_parameters"]["type_"] = deserialize_type(data["init_parameters"]["type_"])
|
||||||
|
return default_from_dict(cls, data)
|
||||||
|
|
||||||
|
def run(self, **kwargs):
|
||||||
|
"""
|
||||||
|
The run method of the `BranchJoiner` component.
|
||||||
|
|
||||||
|
Multiplexes the input data from the upstream connected components and distributes it to the downstream connected
|
||||||
|
components.
|
||||||
|
|
||||||
|
:param **kwargs: The input data. Must be of the type declared in `__init__`.
|
||||||
|
:returns: A dictionary with the following keys:
|
||||||
|
- `value`: The input data.
|
||||||
|
"""
|
||||||
|
if (inputs_count := len(kwargs["value"])) != 1:
|
||||||
|
raise ValueError(f"BranchJoiner expects only one input, but {inputs_count} were received.")
|
||||||
|
return {"value": kwargs["value"][0]}
|
||||||
@ -3,6 +3,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
import warnings
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from haystack import component, default_from_dict, default_to_dict, logging
|
from haystack import component, default_from_dict, default_to_dict, logging
|
||||||
@ -103,6 +104,10 @@ class Multiplexer:
|
|||||||
:param type_: The type of data that the `Multiplexer` will receive from the upstream connected components and
|
:param type_: The type of data that the `Multiplexer` will receive from the upstream connected components and
|
||||||
distribute to the downstream connected components.
|
distribute to the downstream connected components.
|
||||||
"""
|
"""
|
||||||
|
warnings.warn(
|
||||||
|
"`Multiplexer` is deprecated and will be removed in Haystack 2.4.0. Use `joiners.BranchJoiner` instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
)
|
||||||
self.type_ = type_
|
self.type_ = type_
|
||||||
component.set_input_types(self, value=Variadic[type_])
|
component.set_input_types(self, value=Variadic[type_])
|
||||||
component.set_output_types(self, value=type_)
|
component.set_output_types(self, value=type_)
|
||||||
|
|||||||
14
releasenotes/notes/add-branch-joiner-037298459ca74077.yaml
Normal file
14
releasenotes/notes/add-branch-joiner-037298459ca74077.yaml
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
---
|
||||||
|
highlights: >
|
||||||
|
The `Multiplexer` component proved to be hard to explain and to understand. After reviewing its use cases, the documentation
|
||||||
|
was rewritten and the component was renamed to `BranchJoiner` to better explain its functionalities.
|
||||||
|
upgrade:
|
||||||
|
- |
|
||||||
|
`BranchJoiner` has the very same interface as `Multiplexer`. To upgrade your code, just rename any occurrence
|
||||||
|
of `Multiplexer` to `BranchJoiner` and adjust the imports accordingly.
|
||||||
|
features:
|
||||||
|
- |
|
||||||
|
Add `BranchJoiner` to eventually replace `Multiplexer`
|
||||||
|
deprecations:
|
||||||
|
- |
|
||||||
|
`Multiplexer` is now deprecated.
|
||||||
35
test/components/joiners/test_branch_joiner.py
Normal file
35
test/components/joiners/test_branch_joiner.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from haystack.components.joiners import BranchJoiner
|
||||||
|
|
||||||
|
|
||||||
|
class TestBranchJoiner:
|
||||||
|
def test_one_value(self):
|
||||||
|
joiner = BranchJoiner(int)
|
||||||
|
output = joiner.run(value=[2])
|
||||||
|
assert output == {"value": 2}
|
||||||
|
|
||||||
|
def test_one_value_of_wrong_type(self):
|
||||||
|
# BranchJoiner does not type check the input
|
||||||
|
joiner = BranchJoiner(int)
|
||||||
|
output = joiner.run(value=["hello"])
|
||||||
|
assert output == {"value": "hello"}
|
||||||
|
|
||||||
|
def test_one_value_of_none_type(self):
|
||||||
|
# BranchJoiner does not type check the input
|
||||||
|
joiner = BranchJoiner(int)
|
||||||
|
output = joiner.run(value=[None])
|
||||||
|
assert output == {"value": None}
|
||||||
|
|
||||||
|
def test_more_values_of_expected_type(self):
|
||||||
|
joiner = BranchJoiner(int)
|
||||||
|
with pytest.raises(ValueError, match="BranchJoiner expects only one input, but 3 were received."):
|
||||||
|
joiner.run(value=[2, 3, 4])
|
||||||
|
|
||||||
|
def test_no_values(self):
|
||||||
|
joiner = BranchJoiner(int)
|
||||||
|
with pytest.raises(ValueError, match="BranchJoiner expects only one input, but 0 were received."):
|
||||||
|
joiner.run(value=[])
|
||||||
Loading…
x
Reference in New Issue
Block a user