Add validation for root node in Pipeline (#987)

This commit is contained in:
oryx1729 2021-04-21 12:18:33 +02:00 committed by GitHub
parent 8c1e411380
commit 7269530e45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 2 deletions

View File

@ -24,7 +24,7 @@ from haystack.graph_retriever.base import BaseGraphRetriever
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Pipeline(ABC): class Pipeline:
""" """
Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components. Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components.
@ -65,6 +65,9 @@ class Pipeline(ABC):
self.graph.add_node(name, component=component, inputs=inputs) self.graph.add_node(name, component=component, inputs=inputs)
if len(self.graph.nodes) == 2: # first node added; connect with Root if len(self.graph.nodes) == 2: # first node added; connect with Root
assert len(inputs) == 1 and inputs[0].split(".")[0] == self.root_node_id, \
f"The '{name}' node can only input from {self.root_node_id}. " \
f"Set the 'inputs' parameter to ['{self.root_node_id}']"
self.graph.add_edge(self.root_node_id, name, label="output_1") self.graph.add_edge(self.root_node_id, name, label="output_1")
return return

View File

@ -29,7 +29,9 @@ def test_load_yaml(document_store_with_docs):
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.elasticsearch @pytest.mark.elasticsearch
@pytest.mark.parametrize("retriever_with_docs", ["elasticsearch"], indirect=True) @pytest.mark.parametrize(
"retriever_with_docs, document_store_with_docs", [("elasticsearch", "elasticsearch")], indirect=True
)
def test_graph_creation(reader, retriever_with_docs, document_store_with_docs): def test_graph_creation(reader, retriever_with_docs, document_store_with_docs):
pipeline = Pipeline() pipeline = Pipeline()
pipeline.add_node(name="ES", component=retriever_with_docs, inputs=["Query"]) pipeline.add_node(name="ES", component=retriever_with_docs, inputs=["Query"])
@ -43,6 +45,10 @@ def test_graph_creation(reader, retriever_with_docs, document_store_with_docs):
with pytest.raises(Exception): with pytest.raises(Exception):
pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["InvalidNode"]) pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["InvalidNode"])
with pytest.raises(Exception):
pipeline = Pipeline()
pipeline.add_node(name="ES", component=retriever_with_docs, inputs=["InvalidNode"])
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.elasticsearch @pytest.mark.elasticsearch