From fe03ca70de0d47ce6a9be9dc68e66a201ae1f1d4 Mon Sep 17 00:00:00 2001
From: tstadel <60758086+tstadel@users.noreply.github.com>
Date: Tue, 22 Feb 2022 15:01:07 +0100
Subject: [PATCH] Fix Pipeline.components (#2215)

* add components property, improve get_document_store()

* Update Documentation & Code Style

* use pipeline.get_document_store() instead of retriever.document_store

* add tests

* Update Documentation & Code Style

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 haystack/pipelines/base.py      |  15 +++-
 rest_api/controller/search.py   |   4 +-
 test/test_pipeline.py           | 137 ++++++++++++++++++++++++++++++++
 test/test_standard_pipelines.py |   3 +-
 4 files changed, 153 insertions(+), 6 deletions(-)

diff --git a/haystack/pipelines/base.py b/haystack/pipelines/base.py
index b733d1949..7e404e285 100644
--- a/haystack/pipelines/base.py
+++ b/haystack/pipelines/base.py
@@ -32,6 +32,7 @@ except:
 from haystack import __version__
 from haystack.schema import EvaluationResult, MultiLabel, Document
 from haystack.nodes.base import BaseComponent
+from haystack.nodes.retriever.base import BaseRetriever
 from haystack.document_stores.base import BaseDocumentStore
 
 
@@ -421,7 +422,14 @@ class Pipeline(BasePipeline):
     def __init__(self):
         self.graph = DiGraph()
         self.root_node = None
-        self.components: dict = {}
+
+    @property
+    def components(self):
+        return {
+            name: attributes["component"]
+            for name, attributes in self.graph.nodes.items()
+            if not isinstance(attributes["component"], RootNode)
+        }
 
     def add_node(self, component, name: str, inputs: List[str]):
         """
@@ -864,6 +872,11 @@ class Pipeline(BasePipeline):
         :return: Instance of DocumentStore or None
         """
         matches = self.get_nodes_by_class(class_type=BaseDocumentStore)
+        if len(matches) == 0:
+            matches = list(
+                set([retriever.document_store for retriever in self.get_nodes_by_class(class_type=BaseRetriever)])
+            )
+
         if len(matches) > 1:
             raise Exception(f"Multiple Document Stores found in Pipeline: {matches}")
         if len(matches) == 0:
diff --git a/rest_api/controller/search.py b/rest_api/controller/search.py
index e624e7110..0c134b4a6 100644
--- a/rest_api/controller/search.py
+++ b/rest_api/controller/search.py
@@ -27,9 +27,7 @@ router = APIRouter()
 
 PIPELINE = Pipeline.load_from_yaml(Path(PIPELINE_YAML_PATH), pipeline_name=QUERY_PIPELINE_NAME)
 
-# TODO make this generic for other pipelines with different naming
-RETRIEVER = PIPELINE.get_node(name="Retriever")
-DOCUMENT_STORE = RETRIEVER.document_store if RETRIEVER else None
+DOCUMENT_STORE = PIPELINE.get_document_store()
 
 logging.info(f"Loaded pipeline nodes: {PIPELINE.graph.nodes.keys()}")
 concurrency_limiter = RequestLimiter(CONCURRENT_REQUEST_PER_WORKER)
diff --git a/test/test_pipeline.py b/test/test_pipeline.py
index 695631977..7f65b07eb 100644
--- a/test/test_pipeline.py
+++ b/test/test_pipeline.py
@@ -7,8 +7,10 @@ import pytest
 import responses
 
 from haystack import __version__
+from haystack.document_stores.base import BaseDocumentStore
 from haystack.document_stores.deepsetcloud import DeepsetCloudDocumentStore
 from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
+from haystack.nodes.retriever.base import BaseRetriever
 from haystack.nodes.retriever.sparse import ElasticsearchRetriever
 from haystack.pipelines import (
     Pipeline,
@@ -191,6 +193,7 @@ def test_load_from_deepset_cloud_query():
     document_store = retriever.document_store
     assert isinstance(retriever, ElasticsearchRetriever)
     assert isinstance(document_store, DeepsetCloudDocumentStore)
+    assert document_store == query_pipeline.get_document_store()
 
     prediction = query_pipeline.run(query="man on horse", params={})
 
@@ -552,6 +555,140 @@ def test_parallel_paths_in_pipeline_graph_with_branching():
     assert output["output"] == "ACABEABD"
 
 
+def test_pipeline_components():
+    class Node(RootNode):
+        def run(self):
+            test = "test"
+            return {"test": test}, "output_1"
+
+    a = Node()
+    b = Node()
+    c = Node()
+    d = Node()
+    e = Node()
+    pipeline = Pipeline()
+    pipeline.add_node(name="A", component=a, inputs=["Query"])
+    pipeline.add_node(name="B", component=b, inputs=["A"])
+    pipeline.add_node(name="C", component=c, inputs=["B"])
+    pipeline.add_node(name="D", component=d, inputs=["C"])
+    pipeline.add_node(name="E", component=e, inputs=["D"])
+    assert len(pipeline.components) == 5
+    assert pipeline.components["A"] == a
+    assert pipeline.components["B"] == b
+    assert pipeline.components["C"] == c
+    assert pipeline.components["D"] == d
+    assert pipeline.components["E"] == e
+
+
+def test_pipeline_get_document_store_from_components():
+    class DummyDocumentStore(BaseDocumentStore):
+        pass
+
+    doc_store = DummyDocumentStore()
+    pipeline = Pipeline()
+    pipeline.add_node(name="A", component=doc_store, inputs=["File"])
+
+    assert doc_store == pipeline.get_document_store()
+
+
+def test_pipeline_get_document_store_from_components_multiple_doc_stores():
+    class DummyDocumentStore(BaseDocumentStore):
+        pass
+
+    doc_store_a = DummyDocumentStore()
+    doc_store_b = DummyDocumentStore()
+    pipeline = Pipeline()
+    pipeline.add_node(name="A", component=doc_store_a, inputs=["File"])
+    pipeline.add_node(name="B", component=doc_store_b, inputs=["File"])
+
+    with pytest.raises(Exception, match="Multiple Document Stores found in Pipeline"):
+        pipeline.get_document_store()
+
+
+def test_pipeline_get_document_store_from_retriever():
+    class DummyRetriever(BaseRetriever):
+        def __init__(self, document_store):
+            self.document_store = document_store
+
+        def run(self):
+            test = "test"
+            return {"test": test}, "output_1"
+
+    class DummyDocumentStore(BaseDocumentStore):
+        pass
+
+    doc_store = DummyDocumentStore()
+    retriever = DummyRetriever(document_store=doc_store)
+    pipeline = Pipeline()
+    pipeline.add_node(name="A", component=retriever, inputs=["Query"])
+
+    assert doc_store == pipeline.get_document_store()
+
+
+def test_pipeline_get_document_store_from_dual_retriever():
+    class DummyRetriever(BaseRetriever):
+        def __init__(self, document_store):
+            self.document_store = document_store
+
+        def run(self):
+            test = "test"
+            return {"test": test}, "output_1"
+
+    class DummyDocumentStore(BaseDocumentStore):
+        pass
+
+    class JoinNode(RootNode):
+        def run(self, output=None, inputs=None):
+            if inputs:
+                output = ""
+                for input_dict in inputs:
+                    output += input_dict["output"]
+            return {"output": output}, "output_1"
+
+    doc_store = DummyDocumentStore()
+    retriever_a = DummyRetriever(document_store=doc_store)
+    retriever_b = DummyRetriever(document_store=doc_store)
+    pipeline = Pipeline()
+    pipeline.add_node(name="A", component=retriever_a, inputs=["Query"])
+    pipeline.add_node(name="B", component=retriever_b, inputs=["Query"])
+    pipeline.add_node(name="C", component=JoinNode(), inputs=["A", "B"])
+
+    assert doc_store == pipeline.get_document_store()
+
+
+def test_pipeline_get_document_store_multiple_doc_stores_from_dual_retriever():
+    class DummyRetriever(BaseRetriever):
+        def __init__(self, document_store):
+            self.document_store = document_store
+
+        def run(self):
+            test = "test"
+            return {"test": test}, "output_1"
+
+    class DummyDocumentStore(BaseDocumentStore):
+        pass
+
+    class JoinNode(RootNode):
+        def run(self, output=None, inputs=None):
+            if inputs:
+                output = ""
+                for input_dict in inputs:
+                    output += input_dict["output"]
+            return {"output": output}, "output_1"
+
+    doc_store_a = DummyDocumentStore()
+    doc_store_b = DummyDocumentStore()
+    retriever_a = DummyRetriever(document_store=doc_store_a)
+    retriever_b = DummyRetriever(document_store=doc_store_b)
+    pipeline = Pipeline()
+    pipeline.add_node(name="A", component=retriever_a, inputs=["Query"])
+    pipeline.add_node(name="B", component=retriever_b, inputs=["Query"])
+    pipeline.add_node(name="C", component=JoinNode(), inputs=["A", "B"])
+
+    with pytest.raises(Exception, match="Multiple Document Stores found in Pipeline"):
+        pipeline.get_document_store()
+
+
 def test_existing_faiss_document_store():
     clean_faiss_document_store()
 
diff --git a/test/test_standard_pipelines.py b/test/test_standard_pipelines.py
index 7109942b6..830df6dfe 100644
--- a/test/test_standard_pipelines.py
+++ b/test/test_standard_pipelines.py
@@ -405,8 +405,7 @@ def test_existing_faiss_document_store():
         SAMPLES_PATH / "pipeline" / "test_pipeline_faiss_retrieval.yaml", pipeline_name="query_pipeline"
     )
 
-    retriever = pipeline.get_node("DPRRetriever")
-    existing_document_store = retriever.document_store
+    existing_document_store = pipeline.get_document_store()
 
     faiss_index = existing_document_store.faiss_indexes["document"]
     assert faiss_index.ntotal == 2