From 32e87d37c153cc8e8c5de5c72e5f939b2efd1022 Mon Sep 17 00:00:00 2001 From: Nicola Procopio Date: Mon, 16 Oct 2023 09:31:52 +0200 Subject: [PATCH] fixed join_docs.py concatenate (#5970) * added hybrid search example Added an example about hybrid search for faq pipeline on covid dataset * formatted with black formatter * renamed document * fixed * fixed typos * added test added test for hybrid search * fixed whitespaces * removed test for hybrid search * fixed pylint * commented logging * fixed bug in join_docs.py _concatenate_results * Update join_docs.py updated comment * format with black * added release note on PR * updated release notes * updated test_join_documents * updated test * updated test * Update test_join_documents.py * formatted with black * fixed test * fixed --------- Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> --- haystack/nodes/other/join_docs.py | 24 +++++++++++++------ ...ocuments-concatenate-56a7cdba00a7248e.yaml | 4 ++++ test/nodes/test_join_documents.py | 24 +++++++++++++++++++ 3 files changed, 45 insertions(+), 7 deletions(-) create mode 100644 releasenotes/notes/fix-joinDocuments-concatenate-56a7cdba00a7248e.yaml diff --git a/haystack/nodes/other/join_docs.py b/haystack/nodes/other/join_docs.py index 4185873a7..7ce0de819 100644 --- a/haystack/nodes/other/join_docs.py +++ b/haystack/nodes/other/join_docs.py @@ -1,11 +1,10 @@ -from collections import defaultdict import logging +from collections import defaultdict from math import inf +from typing import List, Optional -from typing import Optional, List - -from haystack.schema import Document from haystack.nodes.other.join import JoinNode +from haystack.schema import Document logger = logging.getLogger(__name__) @@ -64,7 +63,7 @@ class JoinDocuments(JoinNode): document_map = {doc.id: doc for result in results for doc in result} if self.join_mode == "concatenate": - scores_map = self._concatenate_results(results) + scores_map = self._concatenate_results(results, 
document_map) elif self.join_mode == "merge": scores_map = self._calculate_comb_sum(results) elif self.join_mode == "reciprocal_rank_fusion": @@ -118,11 +117,22 @@ class JoinDocuments(JoinNode): return output, "output_1" - def _concatenate_results(self, results): + def _concatenate_results(self, results, document_map): """ Concatenates multiple document result lists. + Return the documents with the higher score. """ - return {doc.id: doc.score for result in results for doc in result} + list_id = list(document_map.keys()) + scores_map = {} + for idx in list_id: + tmp = [] + for result in results: + for doc in result: + if doc.id == idx: + tmp.append(doc) + item_best_score = max(tmp, key=lambda x: x.score) + scores_map.update({idx: item_best_score.score}) + return scores_map def _calculate_comb_sum(self, results): """ diff --git a/releasenotes/notes/fix-joinDocuments-concatenate-56a7cdba00a7248e.yaml b/releasenotes/notes/fix-joinDocuments-concatenate-56a7cdba00a7248e.yaml new file mode 100644 index 000000000..e282f7b9f --- /dev/null +++ b/releasenotes/notes/fix-joinDocuments-concatenate-56a7cdba00a7248e.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Make JoinDocuments return only the document with the highest score if there are duplicate documents in the list. 
diff --git a/test/nodes/test_join_documents.py b/test/nodes/test_join_documents.py index 92cf5c882..aa303e26b 100644 --- a/test/nodes/test_join_documents.py +++ b/test/nodes/test_join_documents.py @@ -54,3 +54,27 @@ def test_joindocuments_preserves_root_node(): join_docs = JoinDocuments() result, _ = join_docs.run(inputs) assert result["root_node"] == "File" + + +@pytest.mark.unit +def test_joindocuments_concatenate_keep_only_highest_ranking_duplicate(): + inputs = [ + { + "documents": [ + Document(content="text document 1", content_type="text", score=0.2), + Document(content="text document 2", content_type="text", score=0.3), + ] + }, + {"documents": [Document(content="text document 2", content_type="text", score=0.7)]}, + ] + expected_outputs = { + "documents": [ + Document(content="text document 2", content_type="text", score=0.7), + Document(content="text document 1", content_type="text", score=0.2), + ] + } + + join_docs = JoinDocuments(join_mode="concatenate") + result, _ = join_docs.run(inputs) + assert len(result["documents"]) == 2 + assert result["documents"] == expected_outputs["documents"]