Fix Pipeline.components (#2215)

* add components property, improve get_document_store()

* Update Documentation & Code Style

* use pipeline.get_document_store() instead of retriever.document_store

* add tests

* Update Documentation & Code Style

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
tstadel 2022-02-22 15:01:07 +01:00 committed by GitHub
parent d1b7761504
commit fe03ca70de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 153 additions and 6 deletions

View File

@ -32,6 +32,7 @@ except:
from haystack import __version__
from haystack.schema import EvaluationResult, MultiLabel, Document
from haystack.nodes.base import BaseComponent
from haystack.nodes.retriever.base import BaseRetriever
from haystack.document_stores.base import BaseDocumentStore
@ -421,7 +422,14 @@ class Pipeline(BasePipeline):
def __init__(self):
self.graph = DiGraph()
self.root_node = None
self.components: dict = {}
@property
def components(self):
    """
    Return all components of this pipeline as a dict mapping node name to
    component instance. Root nodes (e.g. Query/File entry points) are excluded.
    The mapping is rebuilt from the underlying graph on every access, so it
    always reflects the current pipeline topology.
    """
    pipeline_components = {}
    for node_name, node_attributes in self.graph.nodes.items():
        component = node_attributes["component"]
        # Root nodes are internal plumbing, not user-added components.
        if isinstance(component, RootNode):
            continue
        pipeline_components[node_name] = component
    return pipeline_components
def add_node(self, component, name: str, inputs: List[str]):
"""
@ -864,6 +872,11 @@ class Pipeline(BasePipeline):
:return: Instance of DocumentStore or None
"""
matches = self.get_nodes_by_class(class_type=BaseDocumentStore)
if len(matches) == 0:
matches = list(
set([retriever.document_store for retriever in self.get_nodes_by_class(class_type=BaseRetriever)])
)
if len(matches) > 1:
raise Exception(f"Multiple Document Stores found in Pipeline: {matches}")
if len(matches) == 0:

View File

@ -27,9 +27,7 @@ router = APIRouter()
# Load the query pipeline once at module import time; all requests share it.
PIPELINE = Pipeline.load_from_yaml(Path(PIPELINE_YAML_PATH), pipeline_name=QUERY_PIPELINE_NAME)
# TODO make this generic for other pipelines with different naming
RETRIEVER = PIPELINE.get_node(name="Retriever")
# NOTE(review): the next two assignments both set DOCUMENT_STORE — this looks like
# diff residue (the first line was the pre-change version, the second its
# replacement via Pipeline.get_document_store()). Only one should remain; confirm.
DOCUMENT_STORE = RETRIEVER.document_store if RETRIEVER else None
DOCUMENT_STORE = PIPELINE.get_document_store()
logging.info(f"Loaded pipeline nodes: {PIPELINE.graph.nodes.keys()}")
# Limits the number of requests processed concurrently by this worker.
concurrency_limiter = RequestLimiter(CONCURRENT_REQUEST_PER_WORKER)

View File

@ -7,8 +7,10 @@ import pytest
import responses
from haystack import __version__
from haystack.document_stores.base import BaseDocumentStore
from haystack.document_stores.deepsetcloud import DeepsetCloudDocumentStore
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
from haystack.nodes.retriever.base import BaseRetriever
from haystack.nodes.retriever.sparse import ElasticsearchRetriever
from haystack.pipelines import (
Pipeline,
@ -191,6 +193,7 @@ def test_load_from_deepset_cloud_query():
document_store = retriever.document_store
assert isinstance(retriever, ElasticsearchRetriever)
assert isinstance(document_store, DeepsetCloudDocumentStore)
assert document_store == query_pipeline.get_document_store()
prediction = query_pipeline.run(query="man on horse", params={})
@ -552,6 +555,140 @@ def test_parallel_paths_in_pipeline_graph_with_branching():
assert output["output"] == "ACABEABD"
def test_pipeline_components():
    """Pipeline.components exposes every added node by name, excluding the root."""

    class Node(RootNode):
        def run(self):
            return {"test": "test"}, "output_1"

    # Build a linear pipeline Query -> A -> B -> C -> D -> E.
    nodes = {name: Node() for name in ("A", "B", "C", "D", "E")}
    pipeline = Pipeline()
    previous_node = "Query"
    for node_name, node in nodes.items():
        pipeline.add_node(name=node_name, component=node, inputs=[previous_node])
        previous_node = node_name

    assert len(pipeline.components) == 5
    # Each registered component must be retrievable under its node name.
    for node_name, node in nodes.items():
        assert pipeline.components[node_name] == node
def test_pipeline_get_document_store_from_components():
    """A document store added directly as a node is found by get_document_store()."""

    class DummyDocumentStore(BaseDocumentStore):
        pass

    store = DummyDocumentStore()
    pipeline = Pipeline()
    pipeline.add_node(name="A", component=store, inputs=["File"])

    assert pipeline.get_document_store() == store
def test_pipeline_get_document_store_from_components_multiple_doc_stores():
    """Two distinct document-store nodes make get_document_store() ambiguous -> error."""

    class DummyDocumentStore(BaseDocumentStore):
        pass

    first_store = DummyDocumentStore()
    second_store = DummyDocumentStore()
    pipeline = Pipeline()
    pipeline.add_node(name="A", component=first_store, inputs=["File"])
    pipeline.add_node(name="B", component=second_store, inputs=["File"])

    # Ambiguity must be reported instead of silently picking one store.
    with pytest.raises(Exception, match="Multiple Document Stores found in Pipeline"):
        pipeline.get_document_store()
def test_pipeline_get_document_store_from_retriever():
    """When no store node exists, get_document_store() falls back to the retriever's store."""

    class DummyRetriever(BaseRetriever):
        def __init__(self, document_store):
            self.document_store = document_store

        def run(self):
            return {"test": "test"}, "output_1"

    class DummyDocumentStore(BaseDocumentStore):
        pass

    store = DummyDocumentStore()
    pipeline = Pipeline()
    pipeline.add_node(name="A", component=DummyRetriever(document_store=store), inputs=["Query"])

    assert pipeline.get_document_store() == store
def test_pipeline_get_document_store_from_dual_retriever():
    """Two retrievers sharing one store still resolve to that single store."""

    class DummyRetriever(BaseRetriever):
        def __init__(self, document_store):
            self.document_store = document_store

        def run(self):
            return {"test": "test"}, "output_1"

    class DummyDocumentStore(BaseDocumentStore):
        pass

    class JoinNode(RootNode):
        def run(self, output=None, inputs=None):
            # Concatenate the "output" values of all incoming edges, if any.
            if inputs:
                output = "".join(input_dict["output"] for input_dict in inputs)
            return {"output": output}, "output_1"

    shared_store = DummyDocumentStore()
    pipeline = Pipeline()
    pipeline.add_node(name="A", component=DummyRetriever(document_store=shared_store), inputs=["Query"])
    pipeline.add_node(name="B", component=DummyRetriever(document_store=shared_store), inputs=["Query"])
    pipeline.add_node(name="C", component=JoinNode(), inputs=["A", "B"])

    # Both retrievers point at the same store, so lookup is unambiguous.
    assert pipeline.get_document_store() == shared_store
def test_pipeline_get_document_store_multiple_doc_stores_from_dual_retriever():
    """Two retrievers with different stores make get_document_store() ambiguous -> error."""

    class DummyRetriever(BaseRetriever):
        def __init__(self, document_store):
            self.document_store = document_store

        def run(self):
            return {"test": "test"}, "output_1"

    class DummyDocumentStore(BaseDocumentStore):
        pass

    class JoinNode(RootNode):
        def run(self, output=None, inputs=None):
            # Concatenate the "output" values of all incoming edges, if any.
            if inputs:
                output = "".join(input_dict["output"] for input_dict in inputs)
            return {"output": output}, "output_1"

    pipeline = Pipeline()
    pipeline.add_node(
        name="A", component=DummyRetriever(document_store=DummyDocumentStore()), inputs=["Query"]
    )
    pipeline.add_node(
        name="B", component=DummyRetriever(document_store=DummyDocumentStore()), inputs=["Query"]
    )
    pipeline.add_node(name="C", component=JoinNode(), inputs=["A", "B"])

    # Distinct stores behind the retrievers must be reported as a conflict.
    with pytest.raises(Exception, match="Multiple Document Stores found in Pipeline"):
        pipeline.get_document_store()
def test_existing_faiss_document_store():
clean_faiss_document_store()

View File

@ -405,8 +405,7 @@ def test_existing_faiss_document_store():
SAMPLES_PATH / "pipeline" / "test_pipeline_faiss_retrieval.yaml", pipeline_name="query_pipeline"
)
retriever = pipeline.get_node("DPRRetriever")
existing_document_store = retriever.document_store
existing_document_store = pipeline.get_document_store()
faiss_index = existing_document_store.faiss_indexes["document"]
assert faiss_index.ntotal == 2