Add BranchJoiner and deprecate Multiplexer (#7765)

This commit is contained in:
Massimiliano Pippi 2024-05-30 15:34:52 +02:00 committed by GitHub
parent 5c468feecf
commit 8d80ff86d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 199 additions and 3 deletions

View File

@ -1,7 +1,7 @@
loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../haystack/components/joiners]
modules: ["document_joiner"]
modules: ["document_joiner", "branch"]
ignore_when_discovered: ["__init__"]
processors:
- type: filter

View File

@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
from haystack.components.joiners.document_joiner import DocumentJoiner
from .branch import BranchJoiner
from .document_joiner import DocumentJoiner
__all__ = ["DocumentJoiner"]
__all__ = ["DocumentJoiner", "BranchJoiner"]

View File

@ -0,0 +1,141 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, Type
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.core.component.types import Variadic
from haystack.utils import deserialize_type, serialize_type
logger = logging.getLogger(__name__)
@component(is_greedy=True)
class BranchJoiner:
    """
    A component to join different branches of a pipeline into one single output.

    `BranchJoiner` receives multiple data connections of the same type from other components and passes the first
    value coming to its single output, possibly distributing it to various other components.

    `BranchJoiner` is fundamental to close loops in a pipeline, where the two branches it joins are the ones
    coming from the previous component and one coming back from a loop. For example, `BranchJoiner` could be used
    to send data to a component evaluating errors. `BranchJoiner` would receive two connections, one to get the
    original data and another one to get modified data in case there was an error. In both cases, `BranchJoiner`
    would send (or re-send in case of a loop) data to the component evaluating errors. See "Usage example" below.

    Another use case with a need for `BranchJoiner` is to reconcile multiple branches coming out of a decision
    or Classifier component. For example, in a RAG pipeline, there might be a "query language classifier" component
    sending the query to different retrievers, selecting one specifically according to the detected language. After the
    retrieval step the pipeline would ideally continue with a `PromptBuilder`, and since we don't know in advance the
    language of the query, all the retrievers should be ideally connected to the single `PromptBuilder`. Since the
    `PromptBuilder` won't accept more than one connection in input, we would connect all the retrievers to a
    `BranchJoiner` component and reconcile them in a single output that can be connected to the `PromptBuilder`
    downstream.

    Usage example:

    ```python
    import json
    from typing import List

    from haystack import Pipeline
    from haystack.components.converters import OutputAdapter
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.components.joiners import BranchJoiner
    from haystack.components.validators import JsonSchemaValidator
    from haystack.dataclasses import ChatMessage

    person_schema = {
        "type": "object",
        "properties": {
            "first_name": {"type": "string", "pattern": "^[A-Z][a-z]+$"},
            "last_name": {"type": "string", "pattern": "^[A-Z][a-z]+$"},
            "nationality": {"type": "string", "enum": ["Italian", "Portuguese", "American"]},
        },
        "required": ["first_name", "last_name", "nationality"]
    }

    # Initialize a pipeline
    pipe = Pipeline()

    # Add components to the pipeline
    pipe.add_component('joiner', BranchJoiner(List[ChatMessage]))
    pipe.add_component('fc_llm', OpenAIChatGenerator(model="gpt-3.5-turbo-0125"))
    pipe.add_component('validator', JsonSchemaValidator(json_schema=person_schema))
    pipe.add_component('adapter', OutputAdapter("{{chat_message}}", List[ChatMessage]))

    # And connect them
    pipe.connect("adapter", "joiner")
    pipe.connect("joiner", "fc_llm")
    pipe.connect("fc_llm.replies", "validator.messages")
    pipe.connect("validator.validation_error", "joiner")

    result = pipe.run(data={"fc_llm": {"generation_kwargs": {"response_format": {"type": "json_object"}}},
                            "adapter": {"chat_message": [ChatMessage.from_user("Create json object from Peter Parker")]}})

    print(json.loads(result["validator"]["validated"][0].content))

    # >> {'first_name': 'Peter', 'last_name': 'Parker', 'nationality': 'American', 'name': 'Spider-Man', 'occupation':
    # >> 'Superhero', 'age': 23, 'location': 'New York City'}
    ```

    Note that `BranchJoiner` can manage only one data type at a time. In this case, `BranchJoiner` is created for passing
    `List[ChatMessage]`. This determines the type of data that `BranchJoiner` will receive from the upstream connected
    components and also the type of data that `BranchJoiner` will send through its output.

    In the code example, `BranchJoiner` receives a looped back `List[ChatMessage]` from the `JsonSchemaValidator` and
    sends it down to the `OpenAIChatGenerator` for re-generation. We can have multiple loopback connections in the
    pipeline. In this instance, the downstream component is only one (the `OpenAIChatGenerator`), but the pipeline might
    have more than one downstream component.
    """

    def __init__(self, type_: Type):
        """
        Create a `BranchJoiner` component.

        :param type_: The type of data that the `BranchJoiner` will receive from the upstream connected components and
            distribute to the downstream connected components.
        """
        self.type_ = type_
        # type_'s type can't be determined statically
        component.set_input_types(self, value=Variadic[type_])  # type: ignore
        component.set_output_types(self, value=type_)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(self, type_=serialize_type(self.type_))

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "BranchJoiner":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from. It is not modified.
        :returns:
            Deserialized component.
        """
        # Work on copies so the caller's dictionary keeps the serialized `type_` value and can be
        # passed to `from_dict` again.
        init_params = dict(data["init_parameters"])
        init_params["type_"] = deserialize_type(init_params["type_"])
        return default_from_dict(cls, {**data, "init_parameters": init_params})

    def run(self, **kwargs) -> Dict[str, Any]:
        """
        The run method of the `BranchJoiner` component.

        Takes the single value received from the upstream connected components and passes it on, unchanged, to the
        downstream connected components.

        :param **kwargs: The input data. Must be of the type declared in `__init__`.
        :returns: A dictionary with the following keys:
            - `value`: The input data.
        :raises ValueError: If the number of values received under `value` is not exactly one.
        """
        if (inputs_count := len(kwargs["value"])) != 1:
            raise ValueError(f"BranchJoiner expects only one input, but {inputs_count} were received.")
        return {"value": kwargs["value"][0]}

View File

@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
import sys
import warnings
from typing import Any, Dict
from haystack import component, default_from_dict, default_to_dict, logging
@ -103,6 +104,10 @@ class Multiplexer:
:param type_: The type of data that the `Multiplexer` will receive from the upstream connected components and
distribute to the downstream connected components.
"""
warnings.warn(
"`Multiplexer` is deprecated and will be removed in Haystack 2.4.0. Use `joiners.BranchJoiner` instead.",
DeprecationWarning,
)
self.type_ = type_
component.set_input_types(self, value=Variadic[type_])
component.set_output_types(self, value=type_)

View File

@ -0,0 +1,14 @@
---
highlights: >
The `Multiplexer` component proved to be hard to explain and to understand. After reviewing its use cases, the documentation
was rewritten and the component was renamed to `BranchJoiner` to better explain its functionalities.
upgrade:
- |
`BranchJoiner` has the very same interface as `Multiplexer`. To upgrade your code, just rename any occurrence
of `Multiplexer` to `BranchJoiner` and adjust the imports accordingly.
features:
- |
Add `BranchJoiner` to eventually replace `Multiplexer`
deprecations:
- |
`Multiplexer` is now deprecated and will be removed in Haystack 2.4.0. Use `BranchJoiner` instead.

View File

@ -0,0 +1,35 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
from haystack.components.joiners import BranchJoiner
class TestBranchJoiner:
    """Unit tests for `BranchJoiner.run`."""

    def test_one_value(self):
        """A single received value is returned under the `value` key."""
        branch_joiner = BranchJoiner(int)
        assert branch_joiner.run(value=[2]) == {"value": 2}

    def test_one_value_of_wrong_type(self):
        """A value whose type differs from the declared one is passed through untouched."""
        # BranchJoiner does not type check the input
        branch_joiner = BranchJoiner(int)
        assert branch_joiner.run(value=["hello"]) == {"value": "hello"}

    def test_one_value_of_none_type(self):
        """`None` counts as a value and is passed through untouched."""
        # BranchJoiner does not type check the input
        branch_joiner = BranchJoiner(int)
        assert branch_joiner.run(value=[None]) == {"value": None}

    def test_more_values_of_expected_type(self):
        """Receiving more than one value raises a `ValueError` reporting the count."""
        branch_joiner = BranchJoiner(int)
        expected_message = "BranchJoiner expects only one input, but 3 were received."
        with pytest.raises(ValueError, match=expected_message):
            branch_joiner.run(value=[2, 3, 4])

    def test_no_values(self):
        """Receiving an empty list of values raises a `ValueError` reporting the count."""
        branch_joiner = BranchJoiner(int)
        expected_message = "BranchJoiner expects only one input, but 0 were received."
        with pytest.raises(ValueError, match=expected_message):
            branch_joiner.run(value=[])