diff --git a/haystack/pipeline.py b/haystack/pipeline.py index bb9abde49..c1560395b 100644 --- a/haystack/pipeline.py +++ b/haystack/pipeline.py @@ -24,7 +24,7 @@ from haystack.graph_retriever.base import BaseGraphRetriever logger = logging.getLogger(__name__) -class Pipeline(ABC): +class Pipeline: """ Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components. @@ -65,6 +65,9 @@ class Pipeline(ABC): self.graph.add_node(name, component=component, inputs=inputs) if len(self.graph.nodes) == 2: # first node added; connect with Root + assert len(inputs) == 1 and inputs[0].split(".")[0] == self.root_node_id, \ + f"The '{name}' node can only input from {self.root_node_id}. " \ + f"Set the 'inputs' parameter to ['{self.root_node_id}']" self.graph.add_edge(self.root_node_id, name, label="output_1") return diff --git a/test/test_pipeline.py b/test/test_pipeline.py index 41a6bb951..8a8d44634 100644 --- a/test/test_pipeline.py +++ b/test/test_pipeline.py @@ -29,7 +29,9 @@ def test_load_yaml(document_store_with_docs): @pytest.mark.slow @pytest.mark.elasticsearch -@pytest.mark.parametrize("retriever_with_docs", ["elasticsearch"], indirect=True) +@pytest.mark.parametrize( + "retriever_with_docs, document_store_with_docs", [("elasticsearch", "elasticsearch")], indirect=True +) def test_graph_creation(reader, retriever_with_docs, document_store_with_docs): pipeline = Pipeline() pipeline.add_node(name="ES", component=retriever_with_docs, inputs=["Query"]) @@ -43,6 +45,10 @@ def test_graph_creation(reader, retriever_with_docs, document_store_with_docs): with pytest.raises(Exception): pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["InvalidNode"]) + with pytest.raises(Exception): + pipeline = Pipeline() + pipeline.add_node(name="ES", component=retriever_with_docs, inputs=["InvalidNode"]) + @pytest.mark.slow @pytest.mark.elasticsearch