mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-31 03:46:30 +00:00
Fix Pipeline.components (#2215)
* add components property, improve get_document_store()
* Update Documentation & Code Style
* use pipeline.get_document_store() instead of retriever.document_store
* add tests
* Update Documentation & Code Style
* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
d1b7761504
commit
fe03ca70de
@ -32,6 +32,7 @@ except:
|
||||
from haystack import __version__
|
||||
from haystack.schema import EvaluationResult, MultiLabel, Document
|
||||
from haystack.nodes.base import BaseComponent
|
||||
from haystack.nodes.retriever.base import BaseRetriever
|
||||
from haystack.document_stores.base import BaseDocumentStore
|
||||
|
||||
|
||||
@ -421,7 +422,14 @@ class Pipeline(BasePipeline):
|
||||
def __init__(self):
|
||||
self.graph = DiGraph()
|
||||
self.root_node = None
|
||||
self.components: dict = {}
|
||||
|
||||
@property
|
||||
def components(self):
|
||||
return {
|
||||
name: attributes["component"]
|
||||
for name, attributes in self.graph.nodes.items()
|
||||
if not isinstance(attributes["component"], RootNode)
|
||||
}
|
||||
|
||||
def add_node(self, component, name: str, inputs: List[str]):
|
||||
"""
|
||||
@ -864,6 +872,11 @@ class Pipeline(BasePipeline):
|
||||
:return: Instance of DocumentStore or None
|
||||
"""
|
||||
matches = self.get_nodes_by_class(class_type=BaseDocumentStore)
|
||||
if len(matches) == 0:
|
||||
matches = list(
|
||||
set([retriever.document_store for retriever in self.get_nodes_by_class(class_type=BaseRetriever)])
|
||||
)
|
||||
|
||||
if len(matches) > 1:
|
||||
raise Exception(f"Multiple Document Stores found in Pipeline: {matches}")
|
||||
if len(matches) == 0:
|
||||
|
@ -27,9 +27,7 @@ router = APIRouter()
|
||||
|
||||
|
||||
PIPELINE = Pipeline.load_from_yaml(Path(PIPELINE_YAML_PATH), pipeline_name=QUERY_PIPELINE_NAME)
|
||||
# TODO make this generic for other pipelines with different naming
|
||||
RETRIEVER = PIPELINE.get_node(name="Retriever")
|
||||
DOCUMENT_STORE = RETRIEVER.document_store if RETRIEVER else None
|
||||
DOCUMENT_STORE = PIPELINE.get_document_store()
|
||||
logging.info(f"Loaded pipeline nodes: {PIPELINE.graph.nodes.keys()}")
|
||||
|
||||
concurrency_limiter = RequestLimiter(CONCURRENT_REQUEST_PER_WORKER)
|
||||
|
@ -7,8 +7,10 @@ import pytest
|
||||
import responses
|
||||
|
||||
from haystack import __version__
|
||||
from haystack.document_stores.base import BaseDocumentStore
|
||||
from haystack.document_stores.deepsetcloud import DeepsetCloudDocumentStore
|
||||
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
|
||||
from haystack.nodes.retriever.base import BaseRetriever
|
||||
from haystack.nodes.retriever.sparse import ElasticsearchRetriever
|
||||
from haystack.pipelines import (
|
||||
Pipeline,
|
||||
@ -191,6 +193,7 @@ def test_load_from_deepset_cloud_query():
|
||||
document_store = retriever.document_store
|
||||
assert isinstance(retriever, ElasticsearchRetriever)
|
||||
assert isinstance(document_store, DeepsetCloudDocumentStore)
|
||||
assert document_store == query_pipeline.get_document_store()
|
||||
|
||||
prediction = query_pipeline.run(query="man on horse", params={})
|
||||
|
||||
@ -552,6 +555,140 @@ def test_parallel_paths_in_pipeline_graph_with_branching():
|
||||
assert output["output"] == "ACABEABD"
|
||||
|
||||
|
||||
def test_pipeline_components():
|
||||
class Node(RootNode):
|
||||
def run(self):
|
||||
test = "test"
|
||||
return {"test": test}, "output_1"
|
||||
|
||||
a = Node()
|
||||
b = Node()
|
||||
c = Node()
|
||||
d = Node()
|
||||
e = Node()
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(name="A", component=a, inputs=["Query"])
|
||||
pipeline.add_node(name="B", component=b, inputs=["A"])
|
||||
pipeline.add_node(name="C", component=c, inputs=["B"])
|
||||
pipeline.add_node(name="D", component=d, inputs=["C"])
|
||||
pipeline.add_node(name="E", component=e, inputs=["D"])
|
||||
assert len(pipeline.components) == 5
|
||||
assert pipeline.components["A"] == a
|
||||
assert pipeline.components["B"] == b
|
||||
assert pipeline.components["C"] == c
|
||||
assert pipeline.components["D"] == d
|
||||
assert pipeline.components["E"] == e
|
||||
|
||||
|
||||
def test_pipeline_get_document_store_from_components():
|
||||
class DummyDocumentStore(BaseDocumentStore):
|
||||
pass
|
||||
|
||||
doc_store = DummyDocumentStore()
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(name="A", component=doc_store, inputs=["File"])
|
||||
|
||||
assert doc_store == pipeline.get_document_store()
|
||||
|
||||
|
||||
def test_pipeline_get_document_store_from_components_multiple_doc_stores():
|
||||
class DummyDocumentStore(BaseDocumentStore):
|
||||
pass
|
||||
|
||||
doc_store_a = DummyDocumentStore()
|
||||
doc_store_b = DummyDocumentStore()
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(name="A", component=doc_store_a, inputs=["File"])
|
||||
pipeline.add_node(name="B", component=doc_store_b, inputs=["File"])
|
||||
|
||||
with pytest.raises(Exception, match="Multiple Document Stores found in Pipeline"):
|
||||
pipeline.get_document_store()
|
||||
|
||||
|
||||
def test_pipeline_get_document_store_from_retriever():
|
||||
class DummyRetriever(BaseRetriever):
|
||||
def __init__(self, document_store):
|
||||
self.document_store = document_store
|
||||
|
||||
def run(self):
|
||||
test = "test"
|
||||
return {"test": test}, "output_1"
|
||||
|
||||
class DummyDocumentStore(BaseDocumentStore):
|
||||
pass
|
||||
|
||||
doc_store = DummyDocumentStore()
|
||||
retriever = DummyRetriever(document_store=doc_store)
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(name="A", component=retriever, inputs=["Query"])
|
||||
|
||||
assert doc_store == pipeline.get_document_store()
|
||||
|
||||
|
||||
def test_pipeline_get_document_store_from_dual_retriever():
|
||||
class DummyRetriever(BaseRetriever):
|
||||
def __init__(self, document_store):
|
||||
self.document_store = document_store
|
||||
|
||||
def run(self):
|
||||
test = "test"
|
||||
return {"test": test}, "output_1"
|
||||
|
||||
class DummyDocumentStore(BaseDocumentStore):
|
||||
pass
|
||||
|
||||
class JoinNode(RootNode):
|
||||
def run(self, output=None, inputs=None):
|
||||
if inputs:
|
||||
output = ""
|
||||
for input_dict in inputs:
|
||||
output += input_dict["output"]
|
||||
return {"output": output}, "output_1"
|
||||
|
||||
doc_store = DummyDocumentStore()
|
||||
retriever_a = DummyRetriever(document_store=doc_store)
|
||||
retriever_b = DummyRetriever(document_store=doc_store)
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(name="A", component=retriever_a, inputs=["Query"])
|
||||
pipeline.add_node(name="B", component=retriever_b, inputs=["Query"])
|
||||
pipeline.add_node(name="C", component=JoinNode(), inputs=["A", "B"])
|
||||
|
||||
assert doc_store == pipeline.get_document_store()
|
||||
|
||||
|
||||
def test_pipeline_get_document_store_multiple_doc_stores_from_dual_retriever():
|
||||
class DummyRetriever(BaseRetriever):
|
||||
def __init__(self, document_store):
|
||||
self.document_store = document_store
|
||||
|
||||
def run(self):
|
||||
test = "test"
|
||||
return {"test": test}, "output_1"
|
||||
|
||||
class DummyDocumentStore(BaseDocumentStore):
|
||||
pass
|
||||
|
||||
class JoinNode(RootNode):
|
||||
def run(self, output=None, inputs=None):
|
||||
if inputs:
|
||||
output = ""
|
||||
for input_dict in inputs:
|
||||
output += input_dict["output"]
|
||||
return {"output": output}, "output_1"
|
||||
|
||||
doc_store_a = DummyDocumentStore()
|
||||
doc_store_b = DummyDocumentStore()
|
||||
retriever_a = DummyRetriever(document_store=doc_store_a)
|
||||
retriever_b = DummyRetriever(document_store=doc_store_b)
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(name="A", component=retriever_a, inputs=["Query"])
|
||||
pipeline.add_node(name="B", component=retriever_b, inputs=["Query"])
|
||||
pipeline.add_node(name="C", component=JoinNode(), inputs=["A", "B"])
|
||||
|
||||
with pytest.raises(Exception, match="Multiple Document Stores found in Pipeline"):
|
||||
pipeline.get_document_store()
|
||||
|
||||
|
||||
def test_existing_faiss_document_store():
|
||||
clean_faiss_document_store()
|
||||
|
||||
|
@ -405,8 +405,7 @@ def test_existing_faiss_document_store():
|
||||
SAMPLES_PATH / "pipeline" / "test_pipeline_faiss_retrieval.yaml", pipeline_name="query_pipeline"
|
||||
)
|
||||
|
||||
retriever = pipeline.get_node("DPRRetriever")
|
||||
existing_document_store = retriever.document_store
|
||||
existing_document_store = pipeline.get_document_store()
|
||||
faiss_index = existing_document_store.faiss_indexes["document"]
|
||||
assert faiss_index.ntotal == 2
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user