Add sort arg to JoinAnswers (#2436)

* Add sort arg to JoinAnswers

* Update Documentation & Code Style

* Change naming and docstring

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Branden Chan 2022-05-10 11:47:00 +02:00 committed by GitHub
parent 15a9ff6f67
commit 43bfea6f3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 4 deletions

View File

@ -70,7 +70,7 @@ A node to join `Answer`s produced by multiple `Reader` nodes.
#### JoinAnswers.\_\_init\_\_
```python
def __init__(join_mode: str = "concatenate", weights: Optional[List[float]] = None, top_k_join: Optional[int] = None)
def __init__(join_mode: str = "concatenate", weights: Optional[List[float]] = None, top_k_join: Optional[int] = None, sort_by_score: bool = True)
```
**Arguments**:
@ -81,6 +81,9 @@ of individual `Answer`s.
adjusting `Answer` scores when using the `"merge"` join_mode. By default, equal weight is assigned to each
`Reader` score. This parameter is not compatible with the `"concatenate"` join_mode.
- `top_k_join`: Limit `Answer`s to top_k based on the resulting scored of the join.
- `sort_by_score`: Whether to sort the incoming answers by their score. Set this to True if your Answers
are coming from a Reader or TableReader. Set to False if any Answers come from a Generator since this assigns
None as a score to each.
<a id="route_documents"></a>

View File

@ -2742,6 +2742,11 @@
"top_k_join": {
"title": "Top K Join",
"type": "integer"
},
"sort_by_score": {
"title": "Sort By Score",
"default": true,
"type": "boolean"
}
},
"additionalProperties": false,

View File

@ -10,7 +10,11 @@ class JoinAnswers(BaseComponent):
"""
def __init__(
self, join_mode: str = "concatenate", weights: Optional[List[float]] = None, top_k_join: Optional[int] = None
self,
join_mode: str = "concatenate",
weights: Optional[List[float]] = None,
top_k_join: Optional[int] = None,
sort_by_score: bool = True,
):
"""
:param join_mode: `"concatenate"` to combine documents from multiple `Reader`s. `"merge"` to aggregate scores
@ -19,6 +23,9 @@ class JoinAnswers(BaseComponent):
adjusting `Answer` scores when using the `"merge"` join_mode. By default, equal weight is assigned to each
`Reader` score. This parameter is not compatible with the `"concatenate"` join_mode.
:param top_k_join: Limit `Answer`s to top_k based on the resulting scored of the join.
:param sort_by_score: Whether to sort the incoming answers by their score. Set this to True if your Answers
are coming from a Reader or TableReader. Set to False if any Answers come from a Generator since this assigns
None as a score to each.
"""
assert join_mode in ["concatenate", "merge"], f"JoinAnswers node does not support '{join_mode}' join_mode."
@ -31,6 +38,7 @@ class JoinAnswers(BaseComponent):
self.join_mode = join_mode
self.weights = [float(i) / sum(weights) for i in weights] if weights else None
self.top_k_join = top_k_join
self.sort_by_score = sort_by_score
def run(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
reader_results = [inp["answers"] for inp in inputs]
@ -40,12 +48,13 @@ class JoinAnswers(BaseComponent):
if self.join_mode == "concatenate":
concatenated_answers = [answer for cur_reader_result in reader_results for answer in cur_reader_result]
concatenated_answers = sorted(concatenated_answers, reverse=True)[:top_k_join]
if self.sort_by_score:
concatenated_answers = sorted(concatenated_answers, reverse=True)
concatenated_answers = concatenated_answers[:top_k_join]
return {"answers": concatenated_answers, "labels": inputs[0].get("labels", None)}, "output_1"
elif self.join_mode == "merge":
merged_answers = self._merge_answers(reader_results)
merged_answers = merged_answers[:top_k_join]
return {"answers": merged_answers, "labels": inputs[0].get("labels", None)}, "output_1"
@ -59,5 +68,8 @@ class JoinAnswers(BaseComponent):
for answer in result:
if isinstance(answer.score, float):
answer.score *= weight
merged_answers = [answer for cur_reader_result in reader_results for answer in cur_reader_result]
if self.sort_by_score:
merged_answers = sorted(merged_answers, reverse=True)
return sorted([answer for cur_reader_result in reader_results for answer in cur_reader_result], reverse=True)