diff --git a/haystack/nodes/base.py b/haystack/nodes/base.py index e21fc8a8e..d43d92430 100644 --- a/haystack/nodes/base.py +++ b/haystack/nodes/base.py @@ -265,9 +265,9 @@ class BaseComponent(ABC): if all_debug: output["_debug"] = all_debug - # add "extra" args that were not used by the node + # add "extra" args that were not used by the node, but not the 'inputs' value for k, v in arguments.items(): - if k not in output.keys(): + if k not in output.keys() and k != "inputs": output[k] = v output["params"] = params diff --git a/test/pipelines/test_pipeline.py b/test/pipelines/test_pipeline.py index 08a930b1b..b64da8680 100644 --- a/test/pipelines/test_pipeline.py +++ b/test/pipelines/test_pipeline.py @@ -18,9 +18,11 @@ import pandas as pd from haystack import __version__ from haystack.document_stores.deepsetcloud import DeepsetCloudDocumentStore from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore +from haystack.document_stores.memory import InMemoryDocumentStore from haystack.nodes.other.join_docs import JoinDocuments from haystack.nodes.base import BaseComponent from haystack.nodes.retriever.sparse import BM25Retriever +from haystack.nodes.retriever.sparse import FilterRetriever from haystack.pipelines import ( Pipeline, RootNode, @@ -1999,3 +2001,66 @@ def test_batch_querying_multiple_queries(document_store_with_docs): assert isinstance(result["answers"][0][0], Answer) assert len(result["answers"]) == 2 # Predictions for 2 collections of documents assert len(result["answers"][0]) == 5 # top-k of 5 for collection of docs + + +def test_fix_to_pipeline_execution_when_join_follows_join(): + # wire up 4 retrievers, each with one document + document_store_1 = InMemoryDocumentStore() + retriever_1 = FilterRetriever(document_store_1, scale_score=True) + dicts_1 = [{"content": "Alpha", "score": 0.552}] + document_store_1.write_documents(dicts_1) + + document_store_2 = InMemoryDocumentStore() + retriever_2 = FilterRetriever(document_store_2, scale_score=True) + dicts_2 = [{"content": "Beta", "score": 0.542}] + document_store_2.write_documents(dicts_2) + + document_store_3 = InMemoryDocumentStore() + retriever_3 = FilterRetriever(document_store_3, scale_score=True) + dicts_3 = [{"content": "Gamma", "score": 0.532}] + document_store_3.write_documents(dicts_3) + + document_store_4 = InMemoryDocumentStore() + retriever_4 = FilterRetriever(document_store_4, scale_score=True) + dicts_4 = [{"content": "Delta", "score": 0.512}] + document_store_4.write_documents(dicts_4) + + # wire up a pipeline of the retrievers, with 4-way join + pipeline = Pipeline() + pipeline.add_node(component=retriever_1, name="Retriever1", inputs=["Query"]) + pipeline.add_node(component=retriever_2, name="Retriever2", inputs=["Query"]) + pipeline.add_node(component=retriever_3, name="Retriever3", inputs=["Query"]) + pipeline.add_node(component=retriever_4, name="Retriever4", inputs=["Query"]) + pipeline.add_node( + component=JoinDocuments(weights=[0.25, 0.25, 0.25, 0.25], join_mode="merge"), + name="Join", + inputs=["Retriever1", "Retriever2", "Retriever3", "Retriever4"], + ) + + res = pipeline.run(query="Alpha Beta Gamma Delta") + documents = res["documents"] + assert len(documents) == 4 # all four documents should be found + + # wire up a pipeline of the retrievers, with join following join + pipeline = Pipeline() + pipeline.add_node(component=retriever_1, name="Retriever1", inputs=["Query"]) + pipeline.add_node(component=retriever_2, name="Retriever2", inputs=["Query"]) + pipeline.add_node(component=retriever_3, name="Retriever3", inputs=["Query"]) + pipeline.add_node(component=retriever_4, name="Retriever4", inputs=["Query"]) + pipeline.add_node( + component=JoinDocuments(weights=[0.5, 0.5], join_mode="merge"), + name="Join12", + inputs=["Retriever1", "Retriever2"], + ) + pipeline.add_node( + component=JoinDocuments(weights=[0.5, 0.5], join_mode="merge"), + name="Join34", + inputs=["Retriever3", "Retriever4"], + ) + pipeline.add_node( + component=JoinDocuments(weights=[0.5, 0.5], join_mode="merge"), name="JoinFinal", inputs=["Join12", "Join34"] + ) + + res = pipeline.run(query="Alpha Beta Gamma Delta") + documents = res["documents"] + assert len(documents) == 4 # all four documents should be found