mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-02 18:59:28 +00:00
Add sort arg to JoinAnswers (#2436)
* Add sort arg to JoinAnswers
* Update Documentation & Code Style
* Change naming and docstring
* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
15a9ff6f67
commit
43bfea6f3d
@ -70,7 +70,7 @@ A node to join `Answer`s produced by multiple `Reader` nodes.
|
||||
#### JoinAnswers.\_\_init\_\_
|
||||
|
||||
```python
|
||||
def __init__(join_mode: str = "concatenate", weights: Optional[List[float]] = None, top_k_join: Optional[int] = None)
|
||||
def __init__(join_mode: str = "concatenate", weights: Optional[List[float]] = None, top_k_join: Optional[int] = None, sort_by_score: bool = True)
|
||||
```
|
||||
|
||||
**Arguments**:
|
||||
@ -81,6 +81,9 @@ of individual `Answer`s.
|
||||
adjusting `Answer` scores when using the `"merge"` join_mode. By default, equal weight is assigned to each
|
||||
`Reader` score. This parameter is not compatible with the `"concatenate"` join_mode.
|
||||
- `top_k_join`: Limit `Answer`s to top_k based on the resulting scores of the join.
|
||||
- `sort_by_score`: Whether to sort the incoming answers by their score. Set this to True if your Answers
|
||||
are coming from a Reader or TableReader. Set to False if any Answers come from a Generator, since a Generator assigns
|
||||
None as a score to each.
|
||||
|
||||
<a id="route_documents"></a>
|
||||
|
||||
|
||||
@ -2742,6 +2742,11 @@
|
||||
"top_k_join": {
|
||||
"title": "Top K Join",
|
||||
"type": "integer"
|
||||
},
|
||||
"sort_by_score": {
|
||||
"title": "Sort By Score",
|
||||
"default": true,
|
||||
"type": "boolean"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
||||
@ -10,7 +10,11 @@ class JoinAnswers(BaseComponent):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, join_mode: str = "concatenate", weights: Optional[List[float]] = None, top_k_join: Optional[int] = None
|
||||
self,
|
||||
join_mode: str = "concatenate",
|
||||
weights: Optional[List[float]] = None,
|
||||
top_k_join: Optional[int] = None,
|
||||
sort_by_score: bool = True,
|
||||
):
|
||||
"""
|
||||
:param join_mode: `"concatenate"` to combine documents from multiple `Reader`s. `"merge"` to aggregate scores
|
||||
@ -19,6 +23,9 @@ class JoinAnswers(BaseComponent):
|
||||
adjusting `Answer` scores when using the `"merge"` join_mode. By default, equal weight is assigned to each
|
||||
`Reader` score. This parameter is not compatible with the `"concatenate"` join_mode.
|
||||
:param top_k_join: Limit `Answer`s to top_k based on the resulting scores of the join.
|
||||
:param sort_by_score: Whether to sort the incoming answers by their score. Set this to True if your Answers
|
||||
are coming from a Reader or TableReader. Set to False if any Answers come from a Generator, since a Generator assigns
|
||||
None as a score to each.
|
||||
"""
|
||||
|
||||
assert join_mode in ["concatenate", "merge"], f"JoinAnswers node does not support '{join_mode}' join_mode."
|
||||
@ -31,6 +38,7 @@ class JoinAnswers(BaseComponent):
|
||||
self.join_mode = join_mode
|
||||
self.weights = [float(i) / sum(weights) for i in weights] if weights else None
|
||||
self.top_k_join = top_k_join
|
||||
self.sort_by_score = sort_by_score
|
||||
|
||||
def run(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
|
||||
reader_results = [inp["answers"] for inp in inputs]
|
||||
@ -40,12 +48,13 @@ class JoinAnswers(BaseComponent):
|
||||
|
||||
if self.join_mode == "concatenate":
|
||||
concatenated_answers = [answer for cur_reader_result in reader_results for answer in cur_reader_result]
|
||||
concatenated_answers = sorted(concatenated_answers, reverse=True)[:top_k_join]
|
||||
if self.sort_by_score:
|
||||
concatenated_answers = sorted(concatenated_answers, reverse=True)
|
||||
concatenated_answers = concatenated_answers[:top_k_join]
|
||||
return {"answers": concatenated_answers, "labels": inputs[0].get("labels", None)}, "output_1"
|
||||
|
||||
elif self.join_mode == "merge":
|
||||
merged_answers = self._merge_answers(reader_results)
|
||||
|
||||
merged_answers = merged_answers[:top_k_join]
|
||||
return {"answers": merged_answers, "labels": inputs[0].get("labels", None)}, "output_1"
|
||||
|
||||
@ -59,5 +68,8 @@ class JoinAnswers(BaseComponent):
|
||||
for answer in result:
|
||||
if isinstance(answer.score, float):
|
||||
answer.score *= weight
|
||||
merged_answers = [answer for cur_reader_result in reader_results for answer in cur_reader_result]
|
||||
if self.sort_by_score:
|
||||
merged_answers = sorted(merged_answers, reverse=True)
|
||||
|
||||
return sorted([answer for cur_reader_result in reader_results for answer in cur_reader_result], reverse=True)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user