mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-09 14:23:43 +00:00
fix: preserve root_node in JoinNode's output (#4820)
* preserve root_node and add tests * Added if statement to fix failing tests --------- Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Co-authored-by: Sebastian Husch Lee <sjrl423@gmail.com>
This commit is contained in:
parent
f660f41c06
commit
6e982e9283
@ -21,7 +21,10 @@ class JoinNode(BaseComponent):
|
|||||||
top_k_join: Optional[int] = None,
|
top_k_join: Optional[int] = None,
|
||||||
) -> Tuple[Dict, str]:
|
) -> Tuple[Dict, str]:
|
||||||
if inputs:
|
if inputs:
|
||||||
return self.run_accumulated(inputs, top_k_join=top_k_join)
|
results = self.run_accumulated(inputs, top_k_join=top_k_join)
|
||||||
|
if "root_node" in inputs[0]:
|
||||||
|
results[0]["root_node"] = inputs[0]["root_node"]
|
||||||
|
return results
|
||||||
warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
|
warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
|
||||||
return self.run_accumulated(
|
return self.run_accumulated(
|
||||||
inputs=[
|
inputs=[
|
||||||
@ -55,7 +58,10 @@ class JoinNode(BaseComponent):
|
|||||||
top_k_join: Optional[int] = None,
|
top_k_join: Optional[int] = None,
|
||||||
) -> Tuple[Dict, str]:
|
) -> Tuple[Dict, str]:
|
||||||
if inputs:
|
if inputs:
|
||||||
return self.run_batch_accumulated(inputs=inputs, top_k_join=top_k_join)
|
results = self.run_batch_accumulated(inputs=inputs, top_k_join=top_k_join)
|
||||||
|
if "root_node" in inputs[0]:
|
||||||
|
results[0]["root_node"] = inputs[0]["root_node"]
|
||||||
|
return results
|
||||||
warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
|
warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
|
||||||
return self.run_batch_accumulated(
|
return self.run_batch_accumulated(
|
||||||
inputs=[
|
inputs=[
|
||||||
|
|||||||
@ -17,3 +17,15 @@ def test_joinanswers(join_mode):
|
|||||||
result, _ = join_answers.run(inputs, top_k_join=1)
|
result, _ = join_answers.run(inputs, top_k_join=1)
|
||||||
assert len(result["answers"]) == 1
|
assert len(result["answers"]) == 1
|
||||||
assert result["answers"][0].answer == "answer 2"
|
assert result["answers"][0].answer == "answer 2"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_joinanswers_preserves_root_node():
|
||||||
|
# https://github.com/deepset-ai/haystack-private/issues/51
|
||||||
|
inputs = [
|
||||||
|
{"answers": [Answer(answer="answer 1", score=0.7)], "root_node": "Query"},
|
||||||
|
{"answers": [Answer(answer="answer 2", score=0.8)], "root_node": "Query"},
|
||||||
|
]
|
||||||
|
join_docs = JoinAnswers()
|
||||||
|
result, _ = join_docs.run(inputs)
|
||||||
|
assert result["root_node"] == "Query"
|
||||||
|
|||||||
@ -42,3 +42,15 @@ def test_joindocuments_score_none(join_mode, sort_by_score):
|
|||||||
|
|
||||||
result, _ = join_docs.run(inputs, top_k_join=1)
|
result, _ = join_docs.run(inputs, top_k_join=1)
|
||||||
assert len(result["documents"]) == 1
|
assert len(result["documents"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_joindocuments_preserves_root_node():
|
||||||
|
# https://github.com/deepset-ai/haystack-private/issues/51
|
||||||
|
inputs = [
|
||||||
|
{"documents": [Document(content="text document 1", content_type="text", score=0.2)], "root_node": "File"},
|
||||||
|
{"documents": [Document(content="text document 2", content_type="text", score=None)], "root_node": "File"},
|
||||||
|
]
|
||||||
|
join_docs = JoinDocuments()
|
||||||
|
result, _ = join_docs.run(inputs)
|
||||||
|
assert result["root_node"] == "File"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user