From 8d80ff86d98de70a2f6bf63c5aece7416c8534ee Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Thu, 30 May 2024 15:34:52 +0200 Subject: [PATCH] Add BranchJoiner and deprecate Multiplexer (#7765) --- docs/pydoc/config/joiners_api.yml | 2 +- haystack/components/joiners/__init__.py | 5 +- haystack/components/joiners/branch.py | 141 ++++++++++++++++++ haystack/components/others/multiplexer.py | 5 + .../add-branch-joiner-037298459ca74077.yaml | 14 ++ test/components/joiners/test_branch_joiner.py | 35 +++++ 6 files changed, 199 insertions(+), 3 deletions(-) create mode 100644 haystack/components/joiners/branch.py create mode 100644 releasenotes/notes/add-branch-joiner-037298459ca74077.yaml create mode 100644 test/components/joiners/test_branch_joiner.py diff --git a/docs/pydoc/config/joiners_api.yml b/docs/pydoc/config/joiners_api.yml index 9cbf2b161..ad6e89a52 100644 --- a/docs/pydoc/config/joiners_api.yml +++ b/docs/pydoc/config/joiners_api.yml @@ -1,7 +1,7 @@ loaders: - type: haystack_pydoc_tools.loaders.CustomPythonLoader search_path: [../../../haystack/components/joiners] - modules: ["document_joiner"] + modules: ["document_joiner", "branch"] ignore_when_discovered: ["__init__"] processors: - type: filter diff --git a/haystack/components/joiners/__init__.py b/haystack/components/joiners/__init__.py index 23f815050..a72f73c13 100644 --- a/haystack/components/joiners/__init__.py +++ b/haystack/components/joiners/__init__.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from haystack.components.joiners.document_joiner import DocumentJoiner +from .branch import BranchJoiner +from .document_joiner import DocumentJoiner -__all__ = ["DocumentJoiner"] +__all__ = ["DocumentJoiner", "BranchJoiner"] diff --git a/haystack/components/joiners/branch.py b/haystack/components/joiners/branch.py new file mode 100644 index 000000000..45673f1a4 --- /dev/null +++ b/haystack/components/joiners/branch.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, Type + +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.core.component.types import Variadic +from haystack.utils import deserialize_type, serialize_type + +logger = logging.getLogger(__name__) + + +@component(is_greedy=True) +class BranchJoiner: + """ + A component to join different branches of a pipeline into one single output. + + `BranchJoiner` receives multiple data connections of the same type from other components and passes the first + value coming to its single output, possibly distributing it to various other components. + + `BranchJoiner` is fundamental to close loops in a pipeline, where the two branches it joins are the ones + coming from the previous component and one coming back from a loop. For example, `BranchJoiner` could be used + to send data to a component evaluating errors. `BranchJoiner` would receive two connections, one to get the + original data and another one to get modified data in case there was an error. In both cases, `BranchJoiner` + would send (or re-send in case of a loop) data to the component evaluating errors. See "Usage example" below. + + Another use case with a need for `BranchJoiner` is to reconcile multiple branches coming out of a decision + or Classifier component. For example, in a RAG pipeline, there might be a "query language classifier" component + sending the query to different retrievers, selecting one specifically according to the detected language. After the + retrieval step the pipeline would ideally continue with a `PromptBuilder`, and since we don't know in advance the + language of the query, all the retrievers should be ideally connected to the single `PromptBuilder`. Since the + `PromptBuilder` won't accept more than one connection in input, we would connect all the retrievers to a + `BranchJoiner` component and reconcile them in a single output that can be connected to the `PromptBuilder` + downstream. + + Usage example: + + ```python + import json + from typing import List + + from haystack import Pipeline + from haystack.components.converters import OutputAdapter + from haystack.components.generators.chat import OpenAIChatGenerator + from haystack.components.joiners import BranchJoiner + from haystack.components.validators import JsonSchemaValidator + from haystack.dataclasses import ChatMessage + + person_schema = { + "type": "object", + "properties": { + "first_name": {"type": "string", "pattern": "^[A-Z][a-z]+$"}, + "last_name": {"type": "string", "pattern": "^[A-Z][a-z]+$"}, + "nationality": {"type": "string", "enum": ["Italian", "Portuguese", "American"]}, + }, + "required": ["first_name", "last_name", "nationality"] + } + + # Initialize a pipeline + pipe = Pipeline() + + # Add components to the pipeline + pipe.add_component('joiner', BranchJoiner(List[ChatMessage])) + pipe.add_component('fc_llm', OpenAIChatGenerator(model="gpt-3.5-turbo-0125")) + pipe.add_component('validator', JsonSchemaValidator(json_schema=person_schema)) + pipe.add_component('adapter', OutputAdapter("{{chat_message}}", List[ChatMessage])), + # And connect them + pipe.connect("adapter", "joiner") + pipe.connect("joiner", "fc_llm") + pipe.connect("fc_llm.replies", "validator.messages") + pipe.connect("validator.validation_error", "joiner") + + result = pipe.run(data={"fc_llm": {"generation_kwargs": {"response_format": {"type": "json_object"}}}, + "adapter": {"chat_message": [ChatMessage.from_user("Create json object from Peter Parker")]}}) + + print(json.loads(result["validator"]["validated"][0].content)) + + + >> {'first_name': 'Peter', 'last_name': 'Parker', 'nationality': 'American', 'name': 'Spider-Man', 'occupation': + >> 'Superhero', 'age': 23, 'location': 'New York City'} + ``` + + Note that `BranchJoiner` can manage only one data type at a time. In this case, `BranchJoiner` is created for passing + `List[ChatMessage]`. This determines the type of data that `BranchJoiner` will receive from the upstream connected + components and also the type of data that `BranchJoiner` will send through its output. + + In the code example, `BranchJoiner` receives a looped back `List[ChatMessage]` from the `JsonSchemaValidator` and + sends it down to the `OpenAIChatGenerator` for re-generation. We can have multiple loopback connections in the + pipeline. In this instance, the downstream component is only one (the `OpenAIChatGenerator`), but the pipeline might + have more than one downstream component. + """ + + def __init__(self, type_: Type): + """ + Create a `BranchJoiner` component. + + :param type_: The type of data that the `BranchJoiner` will receive from the upstream connected components and + distribute to the downstream connected components. + """ + self.type_ = type_ + # type_'s type can't be determined statically + component.set_input_types(self, value=Variadic[type_]) # type: ignore + component.set_output_types(self, value=type_) + + def to_dict(self): + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict(self, type_=serialize_type(self.type_)) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "BranchJoiner": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + data["init_parameters"]["type_"] = deserialize_type(data["init_parameters"]["type_"]) + return default_from_dict(cls, data) + + def run(self, **kwargs): + """ + The run method of the `BranchJoiner` component. + + Multiplexes the input data from the upstream connected components and distributes it to the downstream connected + components. + + :param **kwargs: The input data. Must be of the type declared in `__init__`. + :return: A dictionary with the following keys: + - `value`: The input data. + """ + if (inputs_count := len(kwargs["value"])) != 1: + raise ValueError(f"BranchJoiner expects only one input, but {inputs_count} were received.") + return {"value": kwargs["value"][0]} diff --git a/haystack/components/others/multiplexer.py b/haystack/components/others/multiplexer.py index 0f23debe8..015fcd732 100644 --- a/haystack/components/others/multiplexer.py +++ b/haystack/components/others/multiplexer.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import sys +import warnings from typing import Any, Dict from haystack import component, default_from_dict, default_to_dict, logging @@ -103,6 +104,10 @@ class Multiplexer: :param type_: The type of data that the `Multiplexer` will receive from the upstream connected components and distribute to the downstream connected components. """ + warnings.warn( + "`Multiplexer` is deprecated and will be removed in Haystack 2.4.0. Use `joiners.BranchJoiner` instead.", + DeprecationWarning, + ) self.type_ = type_ component.set_input_types(self, value=Variadic[type_]) component.set_output_types(self, value=type_) diff --git a/releasenotes/notes/add-branch-joiner-037298459ca74077.yaml b/releasenotes/notes/add-branch-joiner-037298459ca74077.yaml new file mode 100644 index 000000000..75893aa53 --- /dev/null +++ b/releasenotes/notes/add-branch-joiner-037298459ca74077.yaml @@ -0,0 +1,14 @@ +--- +highlights: > + The `Multiplexer` component proved to be hard to explain and to understand. After reviewing its use cases, the documentation + was rewritten and the component was renamed to `BranchJoiner` to better explain its functionalities. +upgrade: + - | + `BranchJoiner` has the very same interface as `Multiplexer`. To upgrade your code, just rename any occurrence + of `Multiplexer` to `BranchJoiner` and ajdust the imports accordingly. +features: + - | + Add `BranchJoiner` to eventually replace `Multiplexer` +deprecations: + - | + `Mulitplexer` is now deprecated. diff --git a/test/components/joiners/test_branch_joiner.py b/test/components/joiners/test_branch_joiner.py new file mode 100644 index 000000000..05d30ad6d --- /dev/null +++ b/test/components/joiners/test_branch_joiner.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import pytest + +from haystack.components.joiners import BranchJoiner + + +class TestBranchJoiner: + def test_one_value(self): + joiner = BranchJoiner(int) + output = joiner.run(value=[2]) + assert output == {"value": 2} + + def test_one_value_of_wrong_type(self): + # BranchJoiner does not type check the input + joiner = BranchJoiner(int) + output = joiner.run(value=["hello"]) + assert output == {"value": "hello"} + + def test_one_value_of_none_type(self): + # BranchJoiner does not type check the input + joiner = BranchJoiner(int) + output = joiner.run(value=[None]) + assert output == {"value": None} + + def test_more_values_of_expected_type(self): + joiner = BranchJoiner(int) + with pytest.raises(ValueError, match="BranchJoiner expects only one input, but 3 were received."): + joiner.run(value=[2, 3, 4]) + + def test_no_values(self): + joiner = BranchJoiner(int) + with pytest.raises(ValueError, match="BranchJoiner expects only one input, but 0 were received."): + joiner.run(value=[])