diff --git a/haystack/nodes/other/join.py b/haystack/nodes/other/join.py index f96ef03a2..d2809648e 100644 --- a/haystack/nodes/other/join.py +++ b/haystack/nodes/other/join.py @@ -21,7 +21,10 @@ class JoinNode(BaseComponent): top_k_join: Optional[int] = None, ) -> Tuple[Dict, str]: if inputs: - return self.run_accumulated(inputs, top_k_join=top_k_join) + results = self.run_accumulated(inputs, top_k_join=top_k_join) + if "root_node" in inputs[0]: + results[0]["root_node"] = inputs[0]["root_node"] + return results warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.") return self.run_accumulated( inputs=[ @@ -55,7 +58,10 @@ class JoinNode(BaseComponent): top_k_join: Optional[int] = None, ) -> Tuple[Dict, str]: if inputs: - return self.run_batch_accumulated(inputs=inputs, top_k_join=top_k_join) + results = self.run_batch_accumulated(inputs=inputs, top_k_join=top_k_join) + if "root_node" in inputs[0]: + results[0]["root_node"] = inputs[0]["root_node"] + return results warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.") return self.run_batch_accumulated( inputs=[ diff --git a/test/nodes/test_join_answers.py b/test/nodes/test_join_answers.py index 9ba802383..67567ceab 100644 --- a/test/nodes/test_join_answers.py +++ b/test/nodes/test_join_answers.py @@ -17,3 +17,15 @@ def test_joinanswers(join_mode): result, _ = join_answers.run(inputs, top_k_join=1) assert len(result["answers"]) == 1 assert result["answers"][0].answer == "answer 2" + + +@pytest.mark.unit +def test_joinanswers_preserves_root_node(): + # https://github.com/deepset-ai/haystack-private/issues/51 + inputs = [ + {"answers": [Answer(answer="answer 1", score=0.7)], "root_node": "Query"}, + {"answers": [Answer(answer="answer 2", score=0.8)], "root_node": "Query"}, + ] + join_docs = JoinAnswers() + result, _ = join_docs.run(inputs) + assert result["root_node"] == "Query" diff --git a/test/nodes/test_join_documents.py b/test/nodes/test_join_documents.py index 0a8b88a7d..92cf5c882 100644 --- a/test/nodes/test_join_documents.py +++ b/test/nodes/test_join_documents.py @@ -42,3 +42,15 @@ def test_joindocuments_score_none(join_mode, sort_by_score): result, _ = join_docs.run(inputs, top_k_join=1) assert len(result["documents"]) == 1 + + +@pytest.mark.unit +def test_joindocuments_preserves_root_node(): + # https://github.com/deepset-ai/haystack-private/issues/51 + inputs = [ + {"documents": [Document(content="text document 1", content_type="text", score=0.2)], "root_node": "File"}, + {"documents": [Document(content="text document 2", content_type="text", score=None)], "root_node": "File"}, + ] + join_docs = JoinDocuments() + result, _ = join_docs.run(inputs) + assert result["root_node"] == "File"