fix: preserve root_node in JoinNode's output (#4820)

* preserve root_node and add tests

* Added an `if "root_node" in inputs[0]` check so root_node is only copied when the first input provides it (fixes failing tests)

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Co-authored-by: Sebastian Husch Lee <sjrl423@gmail.com>
This commit is contained in:
ZanSara 2023-05-08 10:17:36 +02:00 committed by GitHub
parent f660f41c06
commit 6e982e9283
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 2 deletions

View File

@ -21,7 +21,10 @@ class JoinNode(BaseComponent):
top_k_join: Optional[int] = None,
) -> Tuple[Dict, str]:
if inputs:
return self.run_accumulated(inputs, top_k_join=top_k_join)
results = self.run_accumulated(inputs, top_k_join=top_k_join)
if "root_node" in inputs[0]:
results[0]["root_node"] = inputs[0]["root_node"]
return results
warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
return self.run_accumulated(
inputs=[
@ -55,7 +58,10 @@ class JoinNode(BaseComponent):
top_k_join: Optional[int] = None,
) -> Tuple[Dict, str]:
if inputs:
return self.run_batch_accumulated(inputs=inputs, top_k_join=top_k_join)
results = self.run_batch_accumulated(inputs=inputs, top_k_join=top_k_join)
if "root_node" in inputs[0]:
results[0]["root_node"] = inputs[0]["root_node"]
return results
warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
return self.run_batch_accumulated(
inputs=[

View File

@ -17,3 +17,15 @@ def test_joinanswers(join_mode):
result, _ = join_answers.run(inputs, top_k_join=1)
assert len(result["answers"]) == 1
assert result["answers"][0].answer == "answer 2"
@pytest.mark.unit
def test_joinanswers_preserves_root_node():
    """The output of ``JoinAnswers.run`` keeps the ``root_node`` of its inputs."""
    # Regression test for https://github.com/deepset-ai/haystack-private/issues/51
    first_branch = {"answers": [Answer(answer="answer 1", score=0.7)], "root_node": "Query"}
    second_branch = {"answers": [Answer(answer="answer 2", score=0.8)], "root_node": "Query"}

    join_answers = JoinAnswers()
    joined, _ = join_answers.run([first_branch, second_branch])

    # Both branches come from the same root, so the joined output must report it too.
    assert joined["root_node"] == "Query"

View File

@ -42,3 +42,15 @@ def test_joindocuments_score_none(join_mode, sort_by_score):
result, _ = join_docs.run(inputs, top_k_join=1)
assert len(result["documents"]) == 1
@pytest.mark.unit
def test_joindocuments_preserves_root_node():
    """The output of ``JoinDocuments.run`` keeps the ``root_node`` of its inputs."""
    # Regression test for https://github.com/deepset-ai/haystack-private/issues/51
    scored_branch = {
        "documents": [Document(content="text document 1", content_type="text", score=0.2)],
        "root_node": "File",
    }
    unscored_branch = {
        "documents": [Document(content="text document 2", content_type="text", score=None)],
        "root_node": "File",
    }

    joiner = JoinDocuments()
    joined, _ = joiner.run([scored_branch, unscored_branch])

    # Both branches come from the same root, so the joined output must report it too.
    assert joined["root_node"] == "File"