diff --git a/docs/pydoc/config/joiners_api.yml b/docs/pydoc/config/joiners_api.yml
index ad6e89a52..6b7d42216 100644
--- a/docs/pydoc/config/joiners_api.yml
+++ b/docs/pydoc/config/joiners_api.yml
@@ -1,7 +1,7 @@
 loaders:
   - type: haystack_pydoc_tools.loaders.CustomPythonLoader
     search_path: [../../../haystack/components/joiners]
-    modules: ["document_joiner", "branch"]
+    modules: ["document_joiner", "branch", "answer_joiner"]
     ignore_when_discovered: ["__init__"]
 processors:
   - type: filter
diff --git a/haystack/components/joiners/__init__.py b/haystack/components/joiners/__init__.py
index a72f73c13..57878c209 100644
--- a/haystack/components/joiners/__init__.py
+++ b/haystack/components/joiners/__init__.py
@@ -2,7 +2,8 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+from .answer_joiner import AnswerJoiner
 from .branch import BranchJoiner
 from .document_joiner import DocumentJoiner
 
-__all__ = ["DocumentJoiner", "BranchJoiner"]
+__all__ = ["DocumentJoiner", "BranchJoiner", "AnswerJoiner"]
diff --git a/haystack/components/joiners/answer_joiner.py b/haystack/components/joiners/answer_joiner.py
new file mode 100644
index 000000000..3a6a8b824
--- /dev/null
+++ b/haystack/components/joiners/answer_joiner.py
@@ -0,0 +1,185 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import itertools
+from enum import Enum
+from math import inf
+from typing import Any, Callable, Dict, List, Optional, Union
+
+from haystack import component, default_from_dict, default_to_dict, logging
+from haystack.core.component.types import Variadic
+from haystack.dataclasses.answer import ExtractedAnswer, ExtractedTableAnswer, GeneratedAnswer
+
+AnswerType = Union[GeneratedAnswer, ExtractedTableAnswer, ExtractedAnswer]
+
+logger = logging.getLogger(__name__)
+
+
+class JoinMode(Enum):
+    """
+    Enum for AnswerJoiner join modes.
+    """
+
+    CONCATENATE = "concatenate"
+
+    def __str__(self):
+        return self.value
+
+    @staticmethod
+    def from_str(string: str) -> "JoinMode":
+        """
+        Convert a string to a JoinMode enum.
+
+        :param string:
+            The value of the join mode, for example `concatenate`.
+        :returns:
+            The matching JoinMode member.
+        :raises ValueError:
+            If the string does not match the value of any JoinMode member.
+        """
+        enum_map = {e.value: e for e in JoinMode}
+        mode = enum_map.get(string)
+        if mode is None:
+            msg = f"Unknown join mode '{string}'. Supported modes in AnswerJoiner are: {list(enum_map.keys())}"
+            raise ValueError(msg)
+        return mode
+
+
+@component
+class AnswerJoiner:
+    """
+    Merges multiple lists of `Answer` objects into a single list.
+
+    Use this component to combine answers from different Generators into a single list.
+    Currently, the component supports only one join mode: `CONCATENATE`.
+    This mode concatenates multiple lists of answers into a single list.
+
+    ### Usage example
+
+    In this example, AnswerJoiner merges answers from two different Generators:
+
+    ```python
+    from haystack.components.builders import AnswerBuilder
+    from haystack.components.joiners import AnswerJoiner
+
+    from haystack.core.pipeline import Pipeline
+
+    from haystack.components.generators.chat import OpenAIChatGenerator
+    from haystack.dataclasses import ChatMessage
+
+
+    query = "What's Natural Language Processing?"
+    messages = [ChatMessage.from_system("You are a helpful, respectful and honest assistant. Be super concise."),
+                ChatMessage.from_user(query)]
+
+    pipe = Pipeline()
+    pipe.add_component("gpt-4o", OpenAIChatGenerator(model="gpt-4o"))
+    pipe.add_component("llama", OpenAIChatGenerator(model="gpt-3.5-turbo"))
+    pipe.add_component("aba", AnswerBuilder())
+    pipe.add_component("abb", AnswerBuilder())
+    pipe.add_component("joiner", AnswerJoiner())
+
+    pipe.connect("gpt-4o.replies", "aba")
+    pipe.connect("llama.replies", "abb")
+    pipe.connect("aba.answers", "joiner")
+    pipe.connect("abb.answers", "joiner")
+
+    results = pipe.run(data={"gpt-4o": {"messages": messages},
+                             "llama": {"messages": messages},
+                             "aba": {"query": query},
+                             "abb": {"query": query}})
+    ```
+    """
+
+    def __init__(
+        self,
+        join_mode: Union[str, JoinMode] = JoinMode.CONCATENATE,
+        top_k: Optional[int] = None,
+        sort_by_score: bool = False,
+    ):
+        """
+        Creates an AnswerJoiner component.
+
+        :param join_mode:
+            Specifies the join mode to use. Available modes:
+            - `concatenate`: Concatenates multiple lists of Answers into a single list.
+        :param top_k:
+            The maximum number of Answers to return. If `None`, all Answers are returned.
+        :param sort_by_score:
+            If `True`, sorts the Answers by score in descending order.
+            If an Answer has no score, it is handled as if its score is -infinity.
+        """
+        if isinstance(join_mode, str):
+            join_mode = JoinMode.from_str(join_mode)
+        join_mode_functions: Dict[JoinMode, Callable[[List[List[AnswerType]]], List[AnswerType]]] = {
+            JoinMode.CONCATENATE: self._concatenate
+        }
+        self.join_mode_function: Callable[[List[List[AnswerType]]], List[AnswerType]] = join_mode_functions[join_mode]
+        self.join_mode = join_mode
+        self.top_k = top_k
+        self.sort_by_score = sort_by_score
+
+    @component.output_types(answers=List[AnswerType])
+    def run(self, answers: Variadic[List[AnswerType]], top_k: Optional[int] = None):
+        """
+        Joins multiple lists of Answers into a single list depending on the `join_mode` parameter.
+
+        :param answers:
+            Nested list of Answers to be merged.
+
+        :param top_k:
+            The maximum number of Answers to return. Overrides the instance's `top_k` if provided.
+
+        :returns:
+            A dictionary with the following keys:
+            - `answers`: Merged list of Answers
+        """
+        answers_list = list(answers)
+        join_function = self.join_mode_function
+        output_answers: List[AnswerType] = join_function(answers_list)
+
+        if self.sort_by_score:
+            # Answers without a usable score (attribute missing or set to None) are ranked last.
+            output_answers = sorted(
+                output_answers,
+                key=lambda answer: -inf if getattr(answer, "score", None) is None else answer.score,
+                reverse=True,
+            )
+
+        # Compare against None rather than relying on truthiness, so an explicit `top_k=0` is honored.
+        if top_k is None:
+            top_k = self.top_k
+        if top_k is not None:
+            output_answers = output_answers[:top_k]
+        return {"answers": output_answers}
+
+    def _concatenate(self, answer_lists: List[List[AnswerType]]) -> List[AnswerType]:
+        """
+        Concatenate multiple lists of Answers, flattening them into a single list.
+
+        :param answer_lists: List of lists of Answers to be flattened.
+        :returns: A single flat list with the Answers in their original order.
+        """
+        return list(itertools.chain.from_iterable(answer_lists))
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes the component to a dictionary.
+
+        :returns:
+            Dictionary with serialized data.
+        """
+        return default_to_dict(self, join_mode=str(self.join_mode), top_k=self.top_k, sort_by_score=self.sort_by_score)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "AnswerJoiner":
+        """
+        Deserializes the component from a dictionary.
+
+        :param data:
+            The dictionary to deserialize from.
+        :returns:
+            The deserialized component.
+        """
+        return default_from_dict(cls, data)
diff --git a/releasenotes/notes/introduce-answer-joiner-component-885dd7846776f5cb.yaml b/releasenotes/notes/introduce-answer-joiner-component-885dd7846776f5cb.yaml
new file mode 100644
index 000000000..e1773183d
--- /dev/null
+++ b/releasenotes/notes/introduce-answer-joiner-component-885dd7846776f5cb.yaml
@@ -0,0 +1,5 @@
+---
+features:
+  - |
+    Introduced a new AnswerJoiner component that allows joining multiple lists of Answers into a single list using
+    the Concatenate join mode.
diff --git a/test/components/joiners/test_answer_joiner.py b/test/components/joiners/test_answer_joiner.py
new file mode 100644
index 000000000..b7ba764ec
--- /dev/null
+++ b/test/components/joiners/test_answer_joiner.py
@@ -0,0 +1,170 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+import os
+
+import pytest
+
+from haystack.components.builders import AnswerBuilder
+
+from haystack import Document, Pipeline
+from haystack.dataclasses.answer import ExtractedAnswer, GeneratedAnswer, ExtractedTableAnswer
+from haystack.components.generators.chat import OpenAIChatGenerator
+from haystack.components.joiners.answer_joiner import AnswerJoiner, JoinMode
+from haystack.dataclasses import ChatMessage
+
+
+class TestAnswerJoiner:
+    def test_init(self):
+        joiner = AnswerJoiner()
+        assert joiner.join_mode == JoinMode.CONCATENATE
+        assert joiner.top_k is None
+        assert joiner.sort_by_score is False
+
+    def test_init_with_custom_parameters(self):
+        joiner = AnswerJoiner(join_mode="concatenate", top_k=5, sort_by_score=True)
+        assert joiner.join_mode == JoinMode.CONCATENATE
+        assert joiner.top_k == 5
+        assert joiner.sort_by_score is True
+
+    def test_to_dict(self):
+        joiner = AnswerJoiner()
+        data = joiner.to_dict()
+        assert data == {
+            "type": "haystack.components.joiners.answer_joiner.AnswerJoiner",
+            "init_parameters": {"join_mode": "concatenate", "top_k": None, "sort_by_score": False},
+        }
+
+    def test_to_from_dict_custom_parameters(self):
+        joiner = AnswerJoiner("concatenate", top_k=5, sort_by_score=True)
+        data = joiner.to_dict()
+        assert data == {
+            "type": "haystack.components.joiners.answer_joiner.AnswerJoiner",
+            "init_parameters": {"join_mode": "concatenate", "top_k": 5, "sort_by_score": True},
+        }
+
+        deserialized_joiner = AnswerJoiner.from_dict(data)
+        assert deserialized_joiner.join_mode == JoinMode.CONCATENATE
+        assert deserialized_joiner.top_k == 5
+        assert deserialized_joiner.sort_by_score is True
+
+    def test_from_dict(self):
+        data = {"type": "haystack.components.joiners.answer_joiner.AnswerJoiner", "init_parameters": {}}
+        answer_joiner = AnswerJoiner.from_dict(data)
+        assert answer_joiner.join_mode == JoinMode.CONCATENATE
+        assert answer_joiner.top_k is None
+        assert answer_joiner.sort_by_score is False
+
+    def test_from_dict_custom_parameters(self):
+        data = {
+            "type": "haystack.components.joiners.answer_joiner.AnswerJoiner",
+            "init_parameters": {"join_mode": "concatenate", "top_k": 5, "sort_by_score": True},
+        }
+        answer_joiner = AnswerJoiner.from_dict(data)
+        assert answer_joiner.join_mode == JoinMode.CONCATENATE
+        assert answer_joiner.top_k == 5
+        assert answer_joiner.sort_by_score is True
+
+    def test_empty_list(self):
+        joiner = AnswerJoiner()
+        result = joiner.run([])
+        assert result == {"answers": []}
+
+    def test_list_of_empty_lists(self):
+        joiner = AnswerJoiner()
+        result = joiner.run([[], []])
+        assert result == {"answers": []}
+
+    def test_list_of_single_answer(self):
+        joiner = AnswerJoiner()
+        answers = [
+            GeneratedAnswer(query="a", data="a", meta={}, documents=[Document(content="a")]),
+            GeneratedAnswer(query="b", data="b", meta={}, documents=[Document(content="b")]),
+            GeneratedAnswer(query="c", data="c", meta={}, documents=[Document(content="c")]),
+        ]
+        result = joiner.run([answers])
+        assert result == {"answers": answers}
+
+    def test_two_lists_of_generated_answers(self):
+        joiner = AnswerJoiner()
+        answers1 = [GeneratedAnswer(query="a", data="a", meta={}, documents=[Document(content="a")])]
+        answers2 = [GeneratedAnswer(query="d", data="d", meta={}, documents=[Document(content="d")])]
+        result = joiner.run([answers1, answers2])
+        assert result == {"answers": answers1 + answers2}
+
+    def test_multiple_lists_of_mixed_answers(self):
+        joiner = AnswerJoiner()
+        answers1 = [GeneratedAnswer(query="a", data="a", meta={}, documents=[Document(content="a")])]
+        answers2 = [ExtractedAnswer(query="d", score=0.9, meta={}, document=Document(content="d"))]
+        answers3 = [ExtractedTableAnswer(query="e", score=0.7, meta={}, document=Document(content="e"))]
+        answers4 = [GeneratedAnswer(query="f", data="f", meta={}, documents=[Document(content="f")])]
+        all_answers = answers1 + answers2 + answers3 + answers4  # type: ignore
+        result = joiner.run([answers1, answers2, answers3, answers4])
+        assert result == {"answers": all_answers}
+
+    def test_unsupported_join_mode(self):
+        unsupported_mode = "unsupported_mode"
+        with pytest.raises(ValueError):
+            AnswerJoiner(join_mode=unsupported_mode)
+
+    def test_sort_by_score(self):
+        joiner = AnswerJoiner(sort_by_score=True)
+        low = ExtractedAnswer(query="a", score=0.2, meta={}, document=Document(content="a"))
+        high = ExtractedAnswer(query="b", score=0.9, meta={}, document=Document(content="b"))
+        unscored = GeneratedAnswer(query="c", data="c", meta={}, documents=[Document(content="c")])
+        result = joiner.run([[low, unscored], [high]])
+        # Scored answers come first in descending order; answers without a score are placed last.
+        assert result == {"answers": [high, low, unscored]}
+
+    def test_top_k_from_init(self):
+        joiner = AnswerJoiner(top_k=2)
+        answers = [
+            GeneratedAnswer(query="a", data="a", meta={}, documents=[Document(content="a")]),
+            GeneratedAnswer(query="b", data="b", meta={}, documents=[Document(content="b")]),
+            GeneratedAnswer(query="c", data="c", meta={}, documents=[Document(content="c")]),
+        ]
+        result = joiner.run([answers])
+        assert result == {"answers": answers[:2]}
+
+    def test_top_k_at_runtime_overrides_init(self):
+        joiner = AnswerJoiner(top_k=3)
+        answers = [
+            GeneratedAnswer(query="a", data="a", meta={}, documents=[Document(content="a")]),
+            GeneratedAnswer(query="b", data="b", meta={}, documents=[Document(content="b")]),
+            GeneratedAnswer(query="c", data="c", meta={}, documents=[Document(content="c")]),
+        ]
+        result = joiner.run([answers], top_k=1)
+        assert result == {"answers": answers[:1]}
+
+    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY", ""), reason="Needs OPENAI_API_KEY to run this test.")
+    @pytest.mark.integration
+    def test_with_pipeline(self):
+        query = "What's Natural Language Processing?"
+        messages = [
+            ChatMessage.from_system("You are a helpful, respectful and honest assistant. Be super concise."),
+            ChatMessage.from_user(query),
+        ]
+
+        pipe = Pipeline()
+        pipe.add_component("gpt-4o", OpenAIChatGenerator(model="gpt-4o"))
+        pipe.add_component("llama", OpenAIChatGenerator(model="gpt-3.5-turbo"))
+        pipe.add_component("aba", AnswerBuilder())
+        pipe.add_component("abb", AnswerBuilder())
+        pipe.add_component("joiner", AnswerJoiner())
+
+        pipe.connect("gpt-4o.replies", "aba")
+        pipe.connect("llama.replies", "abb")
+        pipe.connect("aba.answers", "joiner")
+        pipe.connect("abb.answers", "joiner")
+
+        results = pipe.run(
+            data={
+                "gpt-4o": {"messages": messages},
+                "llama": {"messages": messages},
+                "aba": {"query": query},
+                "abb": {"query": query},
+            }
+        )
+
+        assert "joiner" in results
+        assert len(results["joiner"]["answers"]) == 2