mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-16 10:03:44 +00:00
Add sort arg to JoinAnswers (#2436)
* Add sort arg to JoinAnswers * Update Documentation & Code Style * Change naming and docstring * Update Documentation & Code Style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
15a9ff6f67
commit
43bfea6f3d
@ -70,7 +70,7 @@ A node to join `Answer`s produced by multiple `Reader` nodes.
|
|||||||
#### JoinAnswers.\_\_init\_\_
|
#### JoinAnswers.\_\_init\_\_
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def __init__(join_mode: str = "concatenate", weights: Optional[List[float]] = None, top_k_join: Optional[int] = None)
|
def __init__(join_mode: str = "concatenate", weights: Optional[List[float]] = None, top_k_join: Optional[int] = None, sort_by_score: bool = True)
|
||||||
```
|
```
|
||||||
|
|
||||||
**Arguments**:
|
**Arguments**:
|
||||||
@ -81,6 +81,9 @@ of individual `Answer`s.
|
|||||||
adjusting `Answer` scores when using the `"merge"` join_mode. By default, equal weight is assigned to each
|
adjusting `Answer` scores when using the `"merge"` join_mode. By default, equal weight is assigned to each
|
||||||
`Reader` score. This parameter is not compatible with the `"concatenate"` join_mode.
|
`Reader` score. This parameter is not compatible with the `"concatenate"` join_mode.
|
||||||
- `top_k_join`: Limit `Answer`s to top_k based on the resulting scored of the join.
|
- `top_k_join`: Limit `Answer`s to top_k based on the resulting scored of the join.
|
||||||
|
- `sort_by_score`: Whether to sort the incoming answers by their score. Set this to True if your Answers
|
||||||
|
are coming from a Reader or TableReader. Set to False if any Answers come from a Generator since this assigns
|
||||||
|
None as a score to each.
|
||||||
|
|
||||||
<a id="route_documents"></a>
|
<a id="route_documents"></a>
|
||||||
|
|
||||||
|
|||||||
@ -2742,6 +2742,11 @@
|
|||||||
"top_k_join": {
|
"top_k_join": {
|
||||||
"title": "Top K Join",
|
"title": "Top K Join",
|
||||||
"type": "integer"
|
"type": "integer"
|
||||||
|
},
|
||||||
|
"sort_by_score": {
|
||||||
|
"title": "Sort By Score",
|
||||||
|
"default": true,
|
||||||
|
"type": "boolean"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
|||||||
@ -10,7 +10,11 @@ class JoinAnswers(BaseComponent):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, join_mode: str = "concatenate", weights: Optional[List[float]] = None, top_k_join: Optional[int] = None
|
self,
|
||||||
|
join_mode: str = "concatenate",
|
||||||
|
weights: Optional[List[float]] = None,
|
||||||
|
top_k_join: Optional[int] = None,
|
||||||
|
sort_by_score: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param join_mode: `"concatenate"` to combine documents from multiple `Reader`s. `"merge"` to aggregate scores
|
:param join_mode: `"concatenate"` to combine documents from multiple `Reader`s. `"merge"` to aggregate scores
|
||||||
@ -19,6 +23,9 @@ class JoinAnswers(BaseComponent):
|
|||||||
adjusting `Answer` scores when using the `"merge"` join_mode. By default, equal weight is assigned to each
|
adjusting `Answer` scores when using the `"merge"` join_mode. By default, equal weight is assigned to each
|
||||||
`Reader` score. This parameter is not compatible with the `"concatenate"` join_mode.
|
`Reader` score. This parameter is not compatible with the `"concatenate"` join_mode.
|
||||||
:param top_k_join: Limit `Answer`s to top_k based on the resulting scored of the join.
|
:param top_k_join: Limit `Answer`s to top_k based on the resulting scored of the join.
|
||||||
|
:param sort_by_score: Whether to sort the incoming answers by their score. Set this to True if your Answers
|
||||||
|
are coming from a Reader or TableReader. Set to False if any Answers come from a Generator since this assigns
|
||||||
|
None as a score to each.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert join_mode in ["concatenate", "merge"], f"JoinAnswers node does not support '{join_mode}' join_mode."
|
assert join_mode in ["concatenate", "merge"], f"JoinAnswers node does not support '{join_mode}' join_mode."
|
||||||
@ -31,6 +38,7 @@ class JoinAnswers(BaseComponent):
|
|||||||
self.join_mode = join_mode
|
self.join_mode = join_mode
|
||||||
self.weights = [float(i) / sum(weights) for i in weights] if weights else None
|
self.weights = [float(i) / sum(weights) for i in weights] if weights else None
|
||||||
self.top_k_join = top_k_join
|
self.top_k_join = top_k_join
|
||||||
|
self.sort_by_score = sort_by_score
|
||||||
|
|
||||||
def run(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
|
def run(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
|
||||||
reader_results = [inp["answers"] for inp in inputs]
|
reader_results = [inp["answers"] for inp in inputs]
|
||||||
@ -40,12 +48,13 @@ class JoinAnswers(BaseComponent):
|
|||||||
|
|
||||||
if self.join_mode == "concatenate":
|
if self.join_mode == "concatenate":
|
||||||
concatenated_answers = [answer for cur_reader_result in reader_results for answer in cur_reader_result]
|
concatenated_answers = [answer for cur_reader_result in reader_results for answer in cur_reader_result]
|
||||||
concatenated_answers = sorted(concatenated_answers, reverse=True)[:top_k_join]
|
if self.sort_by_score:
|
||||||
|
concatenated_answers = sorted(concatenated_answers, reverse=True)
|
||||||
|
concatenated_answers = concatenated_answers[:top_k_join]
|
||||||
return {"answers": concatenated_answers, "labels": inputs[0].get("labels", None)}, "output_1"
|
return {"answers": concatenated_answers, "labels": inputs[0].get("labels", None)}, "output_1"
|
||||||
|
|
||||||
elif self.join_mode == "merge":
|
elif self.join_mode == "merge":
|
||||||
merged_answers = self._merge_answers(reader_results)
|
merged_answers = self._merge_answers(reader_results)
|
||||||
|
|
||||||
merged_answers = merged_answers[:top_k_join]
|
merged_answers = merged_answers[:top_k_join]
|
||||||
return {"answers": merged_answers, "labels": inputs[0].get("labels", None)}, "output_1"
|
return {"answers": merged_answers, "labels": inputs[0].get("labels", None)}, "output_1"
|
||||||
|
|
||||||
@ -59,5 +68,8 @@ class JoinAnswers(BaseComponent):
|
|||||||
for answer in result:
|
for answer in result:
|
||||||
if isinstance(answer.score, float):
|
if isinstance(answer.score, float):
|
||||||
answer.score *= weight
|
answer.score *= weight
|
||||||
|
merged_answers = [answer for cur_reader_result in reader_results for answer in cur_reader_result]
|
||||||
|
if self.sort_by_score:
|
||||||
|
merged_answers = sorted(merged_answers, reverse=True)
|
||||||
|
|
||||||
return sorted([answer for cur_reader_result in reader_results for answer in cur_reader_result], reverse=True)
|
return sorted([answer for cur_reader_result in reader_results for answer in cur_reader_result], reverse=True)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user