diff --git a/docs/_src/api/api/other.md b/docs/_src/api/api/other.md index 17522cc88..0c4324d15 100644 --- a/docs/_src/api/api/other.md +++ b/docs/_src/api/api/other.md @@ -70,7 +70,7 @@ A node to join `Answer`s produced by multiple `Reader` nodes. #### JoinAnswers.\_\_init\_\_ ```python -def __init__(join_mode: str = "concatenate", weights: Optional[List[float]] = None, top_k_join: Optional[int] = None) +def __init__(join_mode: str = "concatenate", weights: Optional[List[float]] = None, top_k_join: Optional[int] = None, sort_by_score: bool = True) ``` **Arguments**: @@ -81,6 +81,9 @@ of individual `Answer`s. adjusting `Answer` scores when using the `"merge"` join_mode. By default, equal weight is assigned to each `Reader` score. This parameter is not compatible with the `"concatenate"` join_mode. - `top_k_join`: Limit `Answer`s to top_k based on the resulting scored of the join. +- `sort_by_score`: Whether to sort the incoming answers by their score. Set this to True if your Answers +are coming from a Reader or TableReader. Set this to False if any Answers come from a Generator, since a Generator +assigns None as the score of each Answer. 
diff --git a/haystack/json-schemas/haystack-pipeline-master.schema.json b/haystack/json-schemas/haystack-pipeline-master.schema.json index af54e3db0..0b918aada 100644 --- a/haystack/json-schemas/haystack-pipeline-master.schema.json +++ b/haystack/json-schemas/haystack-pipeline-master.schema.json @@ -2742,6 +2742,11 @@ "top_k_join": { "title": "Top K Join", "type": "integer" + }, + "sort_by_score": { + "title": "Sort By Score", + "default": true, + "type": "boolean" } }, "additionalProperties": false, diff --git a/haystack/nodes/other/join_answers.py b/haystack/nodes/other/join_answers.py index 64d0adb29..4f02b22f4 100644 --- a/haystack/nodes/other/join_answers.py +++ b/haystack/nodes/other/join_answers.py @@ -10,7 +10,11 @@ class JoinAnswers(BaseComponent): """ def __init__( - self, join_mode: str = "concatenate", weights: Optional[List[float]] = None, top_k_join: Optional[int] = None + self, + join_mode: str = "concatenate", + weights: Optional[List[float]] = None, + top_k_join: Optional[int] = None, + sort_by_score: bool = True, ): """ :param join_mode: `"concatenate"` to combine documents from multiple `Reader`s. `"merge"` to aggregate scores @@ -19,6 +23,9 @@ class JoinAnswers(BaseComponent): adjusting `Answer` scores when using the `"merge"` join_mode. By default, equal weight is assigned to each `Reader` score. This parameter is not compatible with the `"concatenate"` join_mode. :param top_k_join: Limit `Answer`s to top_k based on the resulting scored of the join. + :param sort_by_score: Whether to sort the incoming answers by their score. Set this to True if your Answers + are coming from a Reader or TableReader. Set this to False if any Answers come from a Generator, since a Generator + assigns None as the score of each Answer. """ assert join_mode in ["concatenate", "merge"], f"JoinAnswers node does not support '{join_mode}' join_mode." 
@@ -31,6 +38,7 @@ class JoinAnswers(BaseComponent): self.join_mode = join_mode self.weights = [float(i) / sum(weights) for i in weights] if weights else None self.top_k_join = top_k_join + self.sort_by_score = sort_by_score def run(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore reader_results = [inp["answers"] for inp in inputs] @@ -40,12 +48,13 @@ class JoinAnswers(BaseComponent): if self.join_mode == "concatenate": concatenated_answers = [answer for cur_reader_result in reader_results for answer in cur_reader_result] - concatenated_answers = sorted(concatenated_answers, reverse=True)[:top_k_join] + if self.sort_by_score: + concatenated_answers = sorted(concatenated_answers, reverse=True) + concatenated_answers = concatenated_answers[:top_k_join] return {"answers": concatenated_answers, "labels": inputs[0].get("labels", None)}, "output_1" elif self.join_mode == "merge": merged_answers = self._merge_answers(reader_results) - merged_answers = merged_answers[:top_k_join] return {"answers": merged_answers, "labels": inputs[0].get("labels", None)}, "output_1" @@ -59,5 +68,9 @@ class JoinAnswers(BaseComponent): for answer in result: if isinstance(answer.score, float): answer.score *= weight + merged_answers = [answer for cur_reader_result in reader_results for answer in cur_reader_result] + if self.sort_by_score: + merged_answers = sorted(merged_answers, reverse=True) - return sorted([answer for cur_reader_result in reader_results for answer in cur_reader_result], reverse=True) + # Return the list built above (sorted only when requested) so that sort_by_score=False is honored in "merge" mode. + return merged_answers