mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-05 19:47:45 +00:00
feat: Add AnswerJoiner new component (#8122)
* Initial AnswerJoiner * Initial tests * Add release note * Resove mypy warning * Add custom join function * Serialize custom join function * Handle all Answer types, add integration test, improve pydoc * Make fixes * Add to API docs * Add more tests * Update haystack/components/joiners/answer_joiner.py Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com> * Update docstrings and release notes * update docstrings --------- Co-authored-by: Sebastian Husch Lee <sjrl423@gmail.com> Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com> Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com> Co-authored-by: Darja Fokina <daria.fokina@deepset.ai>
This commit is contained in:
parent
3d1ad10385
commit
25d3520f5a
@ -1,7 +1,7 @@
|
||||
loaders:
|
||||
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
|
||||
search_path: [../../../haystack/components/joiners]
|
||||
modules: ["document_joiner", "branch"]
|
||||
modules: ["document_joiner", "branch", "answer_joiner"]
|
||||
ignore_when_discovered: ["__init__"]
|
||||
processors:
|
||||
- type: filter
|
||||
|
||||
@ -2,7 +2,8 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .answer_joiner import AnswerJoiner
|
||||
from .branch import BranchJoiner
|
||||
from .document_joiner import DocumentJoiner
|
||||
|
||||
__all__ = ["DocumentJoiner", "BranchJoiner"]
|
||||
__all__ = ["DocumentJoiner", "BranchJoiner", "AnswerJoiner"]
|
||||
|
||||
172
haystack/components/joiners/answer_joiner.py
Normal file
172
haystack/components/joiners/answer_joiner.py
Normal file
@ -0,0 +1,172 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import itertools
|
||||
from enum import Enum
|
||||
from math import inf
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.core.component.types import Variadic
|
||||
from haystack.dataclasses.answer import ExtractedAnswer, ExtractedTableAnswer, GeneratedAnswer
|
||||
|
||||
AnswerType = Union[GeneratedAnswer, ExtractedTableAnswer, ExtractedAnswer]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JoinMode(Enum):
|
||||
"""
|
||||
Enum for AnswerJoiner join modes.
|
||||
"""
|
||||
|
||||
CONCATENATE = "concatenate"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
@staticmethod
|
||||
def from_str(string: str) -> "JoinMode":
|
||||
"""
|
||||
Convert a string to a JoinMode enum.
|
||||
"""
|
||||
enum_map = {e.value: e for e in JoinMode}
|
||||
mode = enum_map.get(string)
|
||||
if mode is None:
|
||||
msg = f"Unknown join mode '{string}'. Supported modes in AnswerJoiner are: {list(enum_map.keys())}"
|
||||
raise ValueError(msg)
|
||||
return mode
|
||||
|
||||
|
||||
@component
|
||||
class AnswerJoiner:
|
||||
"""
|
||||
Merges multiple lists of `Answer` objects into a single list.
|
||||
|
||||
Use this component to combine answers from different Generators into a single list.
|
||||
Currently, the component supports only one join mode: `CONCATENATE`.
|
||||
This mode concatenates multiple lists of answers into a single list.
|
||||
|
||||
### Usage example
|
||||
|
||||
In this example, AnswerJoiner merges answers from two different Generators:
|
||||
|
||||
```python
|
||||
from haystack.components.builders import AnswerBuilder
|
||||
from haystack.components.joiners import AnswerJoiner
|
||||
|
||||
from haystack.core.pipeline import Pipeline
|
||||
|
||||
from haystack.components.generators.chat import OpenAIChatGenerator
|
||||
from haystack.dataclasses import ChatMessage
|
||||
|
||||
|
||||
query = "What's Natural Language Processing?"
|
||||
messages = [ChatMessage.from_system("You are a helpful, respectful and honest assistant. Be super concise."),
|
||||
ChatMessage.from_user(query)]
|
||||
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("gpt-4o", OpenAIChatGenerator(model="gpt-4o"))
|
||||
pipe.add_component("llama", OpenAIChatGenerator(model="gpt-3.5-turbo"))
|
||||
pipe.add_component("aba", AnswerBuilder())
|
||||
pipe.add_component("abb", AnswerBuilder())
|
||||
pipe.add_component("joiner", AnswerJoiner())
|
||||
|
||||
pipe.connect("gpt-4o.replies", "aba")
|
||||
pipe.connect("llama.replies", "abb")
|
||||
pipe.connect("aba.answers", "joiner")
|
||||
pipe.connect("abb.answers", "joiner")
|
||||
|
||||
results = pipe.run(data={"gpt-4o": {"messages": messages},
|
||||
"llama": {"messages": messages},
|
||||
"aba": {"query": query},
|
||||
"abb": {"query": query}})
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
join_mode: Union[str, JoinMode] = JoinMode.CONCATENATE,
|
||||
top_k: Optional[int] = None,
|
||||
sort_by_score: bool = False,
|
||||
):
|
||||
"""
|
||||
Creates an AnswerJoiner component.
|
||||
|
||||
:param join_mode:
|
||||
Specifies the join mode to use. Available modes:
|
||||
- `concatenate`: Concatenates multiple lists of Answers into a single list.
|
||||
:param top_k:
|
||||
The maximum number of Answers to return.
|
||||
:param sort_by_score:
|
||||
If `True`, sorts the documents by score in descending order.
|
||||
If a document has no score, it is handled as if its score is -infinity.
|
||||
"""
|
||||
if isinstance(join_mode, str):
|
||||
join_mode = JoinMode.from_str(join_mode)
|
||||
join_mode_functions: Dict[JoinMode, Callable[[List[List[AnswerType]]], List[AnswerType]]] = {
|
||||
JoinMode.CONCATENATE: self._concatenate
|
||||
}
|
||||
self.join_mode_function: Callable[[List[List[AnswerType]]], List[AnswerType]] = join_mode_functions[join_mode]
|
||||
self.join_mode = join_mode
|
||||
self.top_k = top_k
|
||||
self.sort_by_score = sort_by_score
|
||||
|
||||
@component.output_types(answers=List[AnswerType])
|
||||
def run(self, answers: Variadic[List[AnswerType]], top_k: Optional[int] = None):
|
||||
"""
|
||||
Joins multiple lists of Answers into a single list depending on the `join_mode` parameter.
|
||||
|
||||
:param answers:
|
||||
Nested list of Answers to be merged.
|
||||
|
||||
:param top_k:
|
||||
The maximum number of Answers to return. Overrides the instance's `top_k` if provided.
|
||||
|
||||
:returns:
|
||||
A dictionary with the following keys:
|
||||
- `answers`: Merged list of Answers
|
||||
"""
|
||||
answers_list = list(answers)
|
||||
join_function = self.join_mode_function
|
||||
output_answers: List[AnswerType] = join_function(answers_list)
|
||||
|
||||
if self.sort_by_score:
|
||||
output_answers = sorted(
|
||||
output_answers, key=lambda answer: answer.score if hasattr(answer, "score") else -inf, reverse=True
|
||||
)
|
||||
|
||||
top_k = top_k or self.top_k
|
||||
if top_k:
|
||||
output_answers = output_answers[:top_k]
|
||||
return {"answers": output_answers}
|
||||
|
||||
def _concatenate(self, answer_lists: List[List[AnswerType]]) -> List[AnswerType]:
|
||||
"""
|
||||
Concatenate multiple lists of Answers, flattening them into a single list and sorting by score.
|
||||
|
||||
:param answer_lists: List of lists of Answers to be flattened.
|
||||
"""
|
||||
return list(itertools.chain.from_iterable(answer_lists))
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes the component to a dictionary.
|
||||
|
||||
:returns:
|
||||
Dictionary with serialized data.
|
||||
"""
|
||||
return default_to_dict(self, join_mode=str(self.join_mode), top_k=self.top_k, sort_by_score=self.sort_by_score)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "AnswerJoiner":
|
||||
"""
|
||||
Deserializes the component from a dictionary.
|
||||
|
||||
:param data:
|
||||
The dictionary to deserialize from.
|
||||
:returns:
|
||||
The deserialized component.
|
||||
"""
|
||||
return default_from_dict(cls, data)
|
||||
@ -0,0 +1,5 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Introduced a new AnswerJoiner component that allows joining multiple lists of Answers into a single list using
|
||||
the Concatenate join mode.
|
||||
141
test/components/joiners/test_answer_joiner.py
Normal file
141
test/components/joiners/test_answer_joiner.py
Normal file
@ -0,0 +1,141 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.components.builders import AnswerBuilder
|
||||
|
||||
from haystack import Document, Pipeline
|
||||
from haystack.dataclasses.answer import ExtractedAnswer, GeneratedAnswer, ExtractedTableAnswer
|
||||
from haystack.components.generators.chat import OpenAIChatGenerator
|
||||
from haystack.components.joiners.answer_joiner import AnswerJoiner, JoinMode
|
||||
from haystack.dataclasses import ChatMessage
|
||||
|
||||
|
||||
class TestAnswerJoiner:
|
||||
def test_init(self):
|
||||
joiner = AnswerJoiner()
|
||||
assert joiner.join_mode == JoinMode.CONCATENATE
|
||||
assert joiner.top_k is None
|
||||
assert joiner.sort_by_score is False
|
||||
|
||||
def test_init_with_custom_parameters(self):
|
||||
joiner = AnswerJoiner(join_mode="concatenate", top_k=5, sort_by_score=True)
|
||||
assert joiner.join_mode == JoinMode.CONCATENATE
|
||||
assert joiner.top_k == 5
|
||||
assert joiner.sort_by_score is True
|
||||
|
||||
def test_to_dict(self):
|
||||
joiner = AnswerJoiner()
|
||||
data = joiner.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.joiners.answer_joiner.AnswerJoiner",
|
||||
"init_parameters": {"join_mode": "concatenate", "top_k": None, "sort_by_score": False},
|
||||
}
|
||||
|
||||
def test_to_from_dict_custom_parameters(self):
|
||||
joiner = AnswerJoiner("concatenate", top_k=5, sort_by_score=True)
|
||||
data = joiner.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.joiners.answer_joiner.AnswerJoiner",
|
||||
"init_parameters": {"join_mode": "concatenate", "top_k": 5, "sort_by_score": True},
|
||||
}
|
||||
|
||||
deserialized_joiner = AnswerJoiner.from_dict(data)
|
||||
assert deserialized_joiner.join_mode == JoinMode.CONCATENATE
|
||||
assert deserialized_joiner.top_k == 5
|
||||
assert deserialized_joiner.sort_by_score is True
|
||||
|
||||
def test_from_dict(self):
|
||||
data = {"type": "haystack.components.joiners.answer_joiner.AnswerJoiner", "init_parameters": {}}
|
||||
answer_joiner = AnswerJoiner.from_dict(data)
|
||||
assert answer_joiner.join_mode == JoinMode.CONCATENATE
|
||||
assert answer_joiner.top_k is None
|
||||
assert answer_joiner.sort_by_score is False
|
||||
|
||||
def test_from_dict_customs_parameters(self):
|
||||
data = {
|
||||
"type": "haystack.components.joiners.answer_joiner.AnswerJoiner",
|
||||
"init_parameters": {"join_mode": "concatenate", "top_k": 5, "sort_by_score": True},
|
||||
}
|
||||
answer_joiner = AnswerJoiner.from_dict(data)
|
||||
assert answer_joiner.join_mode == JoinMode.CONCATENATE
|
||||
assert answer_joiner.top_k == 5
|
||||
assert answer_joiner.sort_by_score is True
|
||||
|
||||
def test_empty_list(self):
|
||||
joiner = AnswerJoiner()
|
||||
result = joiner.run([])
|
||||
assert result == {"answers": []}
|
||||
|
||||
def test_list_of_empty_lists(self):
|
||||
joiner = AnswerJoiner()
|
||||
result = joiner.run([[], []])
|
||||
assert result == {"answers": []}
|
||||
|
||||
def test_list_of_single_answer(self):
|
||||
joiner = AnswerJoiner()
|
||||
answers = [
|
||||
GeneratedAnswer(query="a", data="a", meta={}, documents=[Document(content="a")]),
|
||||
GeneratedAnswer(query="b", data="b", meta={}, documents=[Document(content="b")]),
|
||||
GeneratedAnswer(query="c", data="c", meta={}, documents=[Document(content="c")]),
|
||||
]
|
||||
result = joiner.run([answers])
|
||||
assert result == {"answers": answers}
|
||||
|
||||
def test_two_lists_of_generated_answers(self):
|
||||
joiner = AnswerJoiner()
|
||||
answers1 = [GeneratedAnswer(query="a", data="a", meta={}, documents=[Document(content="a")])]
|
||||
answers2 = [GeneratedAnswer(query="d", data="d", meta={}, documents=[Document(content="d")])]
|
||||
result = joiner.run([answers1, answers2])
|
||||
assert result == {"answers": answers1 + answers2}
|
||||
|
||||
def test_multiple_lists_of_mixed_answers(self):
|
||||
joiner = AnswerJoiner()
|
||||
answers1 = [GeneratedAnswer(query="a", data="a", meta={}, documents=[Document(content="a")])]
|
||||
answers2 = [ExtractedAnswer(query="d", score=0.9, meta={}, document=Document(content="d"))]
|
||||
answers3 = [ExtractedTableAnswer(query="e", score=0.7, meta={}, document=Document(content="e"))]
|
||||
answers4 = [GeneratedAnswer(query="f", data="f", meta={}, documents=[Document(content="f")])]
|
||||
all_answers = answers1 + answers2 + answers3 + answers4 # type: ignore
|
||||
result = joiner.run([answers1, answers2, answers3, answers4])
|
||||
assert result == {"answers": all_answers}
|
||||
|
||||
def test_unsupported_join_mode(self):
|
||||
unsupported_mode = "unsupported_mode"
|
||||
with pytest.raises(ValueError):
|
||||
AnswerJoiner(join_mode=unsupported_mode)
|
||||
|
||||
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY", ""), reason="Needs OPENAI_API_KEY to run this test.")
|
||||
@pytest.mark.integration
|
||||
def test_with_pipeline(self):
|
||||
query = "What's Natural Language Processing?"
|
||||
messages = [
|
||||
ChatMessage.from_system("You are a helpful, respectful and honest assistant. Be super concise."),
|
||||
ChatMessage.from_user(query),
|
||||
]
|
||||
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("gpt-4o", OpenAIChatGenerator(model="gpt-4o"))
|
||||
pipe.add_component("llama", OpenAIChatGenerator(model="gpt-3.5-turbo"))
|
||||
pipe.add_component("aba", AnswerBuilder())
|
||||
pipe.add_component("abb", AnswerBuilder())
|
||||
pipe.add_component("joiner", AnswerJoiner())
|
||||
|
||||
pipe.connect("gpt-4o.replies", "aba")
|
||||
pipe.connect("llama.replies", "abb")
|
||||
pipe.connect("aba.answers", "joiner")
|
||||
pipe.connect("abb.answers", "joiner")
|
||||
|
||||
results = pipe.run(
|
||||
data={
|
||||
"gpt-4o": {"messages": messages},
|
||||
"llama": {"messages": messages},
|
||||
"aba": {"query": query},
|
||||
"abb": {"query": query},
|
||||
}
|
||||
)
|
||||
|
||||
assert "joiner" in results
|
||||
assert len(results["joiner"]["answers"]) == 2
|
||||
Loading…
x
Reference in New Issue
Block a user