mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-01 18:29:32 +00:00
fix: preserve root_node in JoinNode's output (#4820)
* preserve root_node and add tests * Added if statement to fix failing tests --------- Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Co-authored-by: Sebastian Husch Lee <sjrl423@gmail.com>
This commit is contained in:
parent
f660f41c06
commit
6e982e9283
@ -21,7 +21,10 @@ class JoinNode(BaseComponent):
|
||||
top_k_join: Optional[int] = None,
|
||||
) -> Tuple[Dict, str]:
|
||||
if inputs:
|
||||
return self.run_accumulated(inputs, top_k_join=top_k_join)
|
||||
results = self.run_accumulated(inputs, top_k_join=top_k_join)
|
||||
if "root_node" in inputs[0]:
|
||||
results[0]["root_node"] = inputs[0]["root_node"]
|
||||
return results
|
||||
warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
|
||||
return self.run_accumulated(
|
||||
inputs=[
|
||||
@ -55,7 +58,10 @@ class JoinNode(BaseComponent):
|
||||
top_k_join: Optional[int] = None,
|
||||
) -> Tuple[Dict, str]:
|
||||
if inputs:
|
||||
return self.run_batch_accumulated(inputs=inputs, top_k_join=top_k_join)
|
||||
results = self.run_batch_accumulated(inputs=inputs, top_k_join=top_k_join)
|
||||
if "root_node" in inputs[0]:
|
||||
results[0]["root_node"] = inputs[0]["root_node"]
|
||||
return results
|
||||
warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
|
||||
return self.run_batch_accumulated(
|
||||
inputs=[
|
||||
|
||||
@ -17,3 +17,15 @@ def test_joinanswers(join_mode):
|
||||
result, _ = join_answers.run(inputs, top_k_join=1)
|
||||
assert len(result["answers"]) == 1
|
||||
assert result["answers"][0].answer == "answer 2"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_joinanswers_preserves_root_node():
|
||||
# https://github.com/deepset-ai/haystack-private/issues/51
|
||||
inputs = [
|
||||
{"answers": [Answer(answer="answer 1", score=0.7)], "root_node": "Query"},
|
||||
{"answers": [Answer(answer="answer 2", score=0.8)], "root_node": "Query"},
|
||||
]
|
||||
join_docs = JoinAnswers()
|
||||
result, _ = join_docs.run(inputs)
|
||||
assert result["root_node"] == "Query"
|
||||
|
||||
@ -42,3 +42,15 @@ def test_joindocuments_score_none(join_mode, sort_by_score):
|
||||
|
||||
result, _ = join_docs.run(inputs, top_k_join=1)
|
||||
assert len(result["documents"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_joindocuments_preserves_root_node():
|
||||
# https://github.com/deepset-ai/haystack-private/issues/51
|
||||
inputs = [
|
||||
{"documents": [Document(content="text document 1", content_type="text", score=0.2)], "root_node": "File"},
|
||||
{"documents": [Document(content="text document 2", content_type="text", score=None)], "root_node": "File"},
|
||||
]
|
||||
join_docs = JoinDocuments()
|
||||
result, _ = join_docs.run(inputs)
|
||||
assert result["root_node"] == "File"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user