Add validation for root node in Pipeline (#987)

This commit is contained in:
oryx1729 2021-04-21 12:18:33 +02:00 committed by GitHub
parent 8c1e411380
commit 7269530e45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 2 deletions

View File

@ -24,7 +24,7 @@ from haystack.graph_retriever.base import BaseGraphRetriever
logger = logging.getLogger(__name__)
class Pipeline(ABC):
class Pipeline:
"""
Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components.
@ -65,6 +65,9 @@ class Pipeline(ABC):
self.graph.add_node(name, component=component, inputs=inputs)
if len(self.graph.nodes) == 2: # first node added; connect with Root
assert len(inputs) == 1 and inputs[0].split(".")[0] == self.root_node_id, \
f"The '{name}' node can only input from {self.root_node_id}. " \
f"Set the 'inputs' parameter to ['{self.root_node_id}']"
self.graph.add_edge(self.root_node_id, name, label="output_1")
return

View File

@ -29,7 +29,9 @@ def test_load_yaml(document_store_with_docs):
@pytest.mark.slow
@pytest.mark.elasticsearch
@pytest.mark.parametrize("retriever_with_docs", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize(
"retriever_with_docs, document_store_with_docs", [("elasticsearch", "elasticsearch")], indirect=True
)
def test_graph_creation(reader, retriever_with_docs, document_store_with_docs):
pipeline = Pipeline()
pipeline.add_node(name="ES", component=retriever_with_docs, inputs=["Query"])
@ -43,6 +45,10 @@ def test_graph_creation(reader, retriever_with_docs, document_store_with_docs):
with pytest.raises(Exception):
pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["InvalidNode"])
with pytest.raises(Exception):
pipeline = Pipeline()
pipeline.add_node(name="ES", component=retriever_with_docs, inputs=["InvalidNode"])
@pytest.mark.slow
@pytest.mark.elasticsearch