Fix execution of Pipelines with parallel nodes (#901)

oryx1729 2021-03-18 12:41:30 +01:00 committed by GitHub
parent 24d0c4d42d
commit e9f0076dbd
3 changed files with 134 additions and 40 deletions


@@ -346,7 +346,7 @@ Note that pipelines with split or merge nodes are currently not supported.
 ## JoinDocuments Objects
 
 ```python
-class JoinDocuments()
+class JoinDocuments(BaseComponent)
 ```
 
 A node to join documents outputted by multiple retriever nodes.
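For orientation, here is a minimal, hypothetical usage sketch of the node being documented. The stand-in retriever classes are placeholders modelled on the test fixtures in this commit, and the default `JoinDocuments()` constructor is an assumption, not something shown in this diff:

```python
from haystack.pipeline import JoinDocuments, Pipeline, RootNode

# Stand-in "retrievers" that only emit the dict shape JoinDocuments expects;
# a real pipeline would use actual retriever components here.
class FakeRetrieverA(RootNode):
    def run(self, **kwargs):
        return {"query": kwargs["query"], "documents": []}, "output_1"

class FakeRetrieverB(FakeRetrieverA):
    pass

pipeline = Pipeline()
pipeline.add_node(name="RetrieverA", component=FakeRetrieverA(), inputs=["Query"])
pipeline.add_node(name="RetrieverB", component=FakeRetrieverB(), inputs=["Query"])
pipeline.add_node(name="JoinResults", component=JoinDocuments(),
                  inputs=["RetrieverA", "RetrieverB"])
output = pipeline.run(query="test")
```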


@@ -1,5 +1,7 @@
-from abc import ABC
 import logging
 import os
+import traceback
+from abc import ABC
 from copy import deepcopy
 from pathlib import Path
 from typing import List, Optional, Dict

@@ -17,6 +19,9 @@ from haystack.summarizer.base import BaseSummarizer
 from haystack.translator.base import BaseTranslator
 
+logger = logging.getLogger(__name__)
+
+
 class Pipeline(ABC):
     """
     Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components.
@@ -100,33 +105,37 @@ class Pipeline(ABC):
     def run(self, **kwargs):
         node_output = None
-        stack = {
+        queue = {
             self.root_node_id: {"pipeline_type": self.pipeline_type, **kwargs}
-        }  # ordered dict with "node_id" -> "input" mapping that acts as a FIFO stack
-        nodes_executed = set()
-        i = -1  # the last item is popped off the stack unless it is a join node with unprocessed predecessors
-        while stack:
-            node_id = list(stack.keys())[i]
-            node_input = stack[node_id]
-            predecessors = set(self.graph.predecessors(node_id))
-            if predecessors.issubset(nodes_executed):  # only execute if predecessor nodes are executed
-                nodes_executed.add(node_id)
-                node_output, stream_id = self.graph.nodes[node_id]["component"].run(**node_input)
-                stack.pop(node_id)
+        }  # ordered dict with "node_id" -> "input" mapping that acts as a FIFO queue
+        i = 0  # the first item is popped off the queue unless it is a "join" node with unprocessed predecessors
+        while queue:
+            node_id = list(queue.keys())[i]
+            node_input = queue[node_id]
+            predecessors = set(nx.ancestors(self.graph, node_id))
+            if predecessors.isdisjoint(set(queue.keys())):  # only execute if predecessor nodes are executed
+                try:
+                    logger.debug(f"Running node `{node_id}` with input `{node_input}`")
+                    node_output, stream_id = self.graph.nodes[node_id]["component"].run(**node_input)
+                except Exception as e:
+                    tb = traceback.format_exc()
+                    raise Exception(f"Exception while running node `{node_id}` with input `{node_input}`: {e}, full stack trace: {tb}")
+                queue.pop(node_id)
                 next_nodes = self.get_next_nodes(node_id, stream_id)
-                for n in next_nodes:  # add successor nodes with corresponding inputs to the stack
-                    if stack.get(n):  # concatenate inputs if it's a join node
-                        existing_input = stack[n]
+                for n in next_nodes:  # add successor nodes with corresponding inputs to the queue
+                    if queue.get(n):  # concatenate inputs if it's a join node
+                        existing_input = queue[n]
                         if "inputs" not in existing_input.keys():
                             updated_input = {"inputs": [existing_input, node_output]}
                         else:
-                            updated_input = existing_input["inputs"].append(node_output)
-                        stack[n] = updated_input
+                            existing_input["inputs"].append(node_output)
+                            updated_input = existing_input
+                        queue[n] = updated_input
                     else:
-                        stack[n] = node_output
-                i = -1
-            else:  # attempt executing lower nodes in the stack as `node_id` has unprocessed predecessors
-                i -= 1
+                        queue[n] = node_output
+                i = 0
+            else:
+                i += 1  # attempt executing next node in the queue as current `node_id` has unprocessed predecessors
         return node_output
 
     def get_next_nodes(self, node_id: str, stream_id: str):
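This scheduling change is the heart of the fix: nodes now execute in FIFO order, and a node whose ancestors are still waiting in the queue (e.g. a join node that has not yet received input from every branch) is skipped until those ancestors have run. A standalone sketch of the same idea, independent of Haystack; `run_fifo` and the `node_fn` attribute are illustrative names, not part of the codebase:

```python
import networkx as nx

def run_fifo(graph: nx.DiGraph, root: str, root_input: dict) -> dict:
    """Execute a DAG of callables (stored as a `node_fn` node attribute) in FIFO order."""
    queue = {root: root_input}  # insertion-ordered dict used as a FIFO queue
    node_output: dict = {}
    i = 0
    while queue:
        node_id = list(queue.keys())[i]
        if not nx.ancestors(graph, node_id).isdisjoint(queue.keys()):
            i += 1  # an ancestor is still queued: defer this node and try the next one
            continue
        node_input = queue.pop(node_id)
        node_output = graph.nodes[node_id]["node_fn"](node_input)
        for successor in graph.successors(node_id):
            if successor in queue:  # a second branch reaching a join node
                existing = queue[successor]
                if "inputs" not in existing:
                    queue[successor] = {"inputs": [existing, node_output]}
                else:
                    existing["inputs"].append(node_output)
            else:
                queue[successor] = node_output
        i = 0  # restart the scan from the front of the queue
    return node_output
```

Because deferral only skips nodes whose ancestors are still queued, a well-formed DAG always contains at least one runnable node, so the scan index never runs off the end of the queue.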
@@ -134,7 +143,7 @@ class Pipeline(ABC):
         next_nodes = [
             next_node
             for _, next_node, data in current_node_edges
-            if not stream_id or data["label"] == stream_id
+            if not stream_id or data["label"] == stream_id or stream_id == "output_all"
         ]
         return next_nodes
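The extra `stream_id == "output_all"` clause lets a decision node broadcast to every outgoing edge instead of picking a single labelled stream. A minimal sketch of the filter semantics, with edge data shaped as in the diff and hypothetical node names:

```python
# Each outgoing edge carries a "label" such as "output_1" or "output_2".
edges = [("A", "B", {"label": "output_1"}),
         ("A", "C", {"label": "output_2"})]

def next_nodes(edges, stream_id):
    return [dst for _, dst, data in edges
            if not stream_id or data["label"] == stream_id or stream_id == "output_all"]

assert next_nodes(edges, "output_1") == ["B"]         # routed to one branch
assert next_nodes(edges, "output_all") == ["B", "C"]  # broadcast to all branches
```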
@@ -239,22 +248,26 @@ class Pipeline(ABC):
         :param definitions: dict containing definitions of all components retrieved from the YAML.
         :param components: dict containing component objects.
         """
-        if name in components.keys():  # check if component is already loaded.
-            return components[name]
-
-        component_params = definitions[name]["params"]
-        component_type = definitions[name]["type"]
-
-        for key, value in component_params.items():
-            # Component params can reference to other components. For instance, a Retriever can reference a
-            # DocumentStore defined in the YAML. All references should be recursively resolved.
-            if value in definitions.keys():  # check if the param value is a reference to another component.
-                if value not in components.keys():  # check if the referenced component is already loaded.
-                    cls._load_or_get_component(name=value, definitions=definitions, components=components)
-                component_params[key] = components[value]  # substitute reference (string) with the component object.
-
-        instance = BaseComponent.load_from_args(component_type=component_type, **component_params)
-        components[name] = instance
+        try:
+            if name in components.keys():  # check if component is already loaded.
+                return components[name]
+
+            component_params = definitions[name]["params"]
+            component_type = definitions[name]["type"]
+            logger.debug(f"Loading component `{name}` of type `{definitions[name]['type']}`")
+
+            for key, value in component_params.items():
+                # Component params can reference to other components. For instance, a Retriever can reference a
+                # DocumentStore defined in the YAML. All references should be recursively resolved.
+                if value in definitions.keys():  # check if the param value is a reference to another component.
+                    if value not in components.keys():  # check if the referenced component is already loaded.
+                        cls._load_or_get_component(name=value, definitions=definitions, components=components)
+                    component_params[key] = components[value]  # substitute reference (string) with the component object.
+
+            instance = BaseComponent.load_from_args(component_type=component_type, **component_params)
+            components[name] = instance
+        except Exception as e:
+            raise Exception(f"Failed loading pipeline component '{name}': {e}")
         return instance
 
     @classmethod
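To make the recursive reference resolution concrete, here is a hypothetical `definitions` dict as it might be parsed from pipeline YAML; the component names, types, and parameters are illustrative, not taken from this diff:

```python
# Illustrative only: names and types stand in for whatever the YAML declares.
definitions = {
    "MyDocumentStore": {"type": "ElasticsearchDocumentStore", "params": {}},
    "MyRetriever": {
        "type": "ElasticsearchRetriever",
        # The string "MyDocumentStore" is a reference to another component, so
        # _load_or_get_component() resolves it recursively before instantiating.
        "params": {"document_store": "MyDocumentStore"},
    },
}
components: dict = {}
# Pipeline._load_or_get_component(name="MyRetriever", definitions=definitions,
#                                 components=components)
# -> loads MyDocumentStore first, then MyRetriever with the store object injected.
```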
@@ -524,7 +537,7 @@ class RootNode:
         return kwargs, "output_1"
 
 
-class JoinDocuments:
+class JoinDocuments(BaseComponent):
     """
     A node to join documents outputted by multiple retriever nodes.

@@ -583,5 +596,5 @@ class JoinDocuments:
         documents = sorted(document_map.values(), key=lambda d: d.score, reverse=True)
         if self.top_k:
             documents = documents[: self.top_k]
-        output = {"query": inputs[0]["query"], "documents": documents}
+        output = {"query": inputs[0]["query"], "documents": documents, "labels": inputs[0].get("labels", None)}
         return output, "output_1"
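The added `labels` key forwards evaluation labels from the incoming stream alongside the merged documents. A condensed sketch of the join logic around this change, with Haystack `Document` objects simplified to dicts and the dedup-by-id behaviour mirroring the `document_map` seen above:

```python
def join_documents(inputs, top_k=None):
    # Deduplicate documents arriving from all branches by id.
    document_map = {}
    for input_dict in inputs:
        for doc in input_dict["documents"]:
            document_map[doc["id"]] = doc
    documents = sorted(document_map.values(), key=lambda d: d["score"], reverse=True)
    if top_k:
        documents = documents[:top_k]
    # Forward the query and any evaluation labels with the merged documents.
    return {"query": inputs[0]["query"],
            "documents": documents,
            "labels": inputs[0].get("labels", None)}
```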


@@ -264,4 +264,85 @@ def test_parallel_paths_in_pipeline_graph():
     pipeline.add_node(name="D", component=D(), inputs=["B"])
     pipeline.add_node(name="E", component=JoinNode(), inputs=["C", "D"])
     output = pipeline.run(query="test")
-    assert output["output"] == "ABDABC"
+    assert output["output"] == "ABCABD"
+
+
+def test_parallel_paths_in_pipeline_graph_with_branching():
+    class AWithOutput1(RootNode):
+        outgoing_edges = 2
+
+        def run(self, **kwargs):
+            kwargs["output"] = "A"
+            return kwargs, "output_1"
+
+    class AWithOutput2(RootNode):
+        outgoing_edges = 2
+
+        def run(self, **kwargs):
+            kwargs["output"] = "A"
+            return kwargs, "output_2"
+
+    class AWithOutputAll(RootNode):
+        outgoing_edges = 2
+
+        def run(self, **kwargs):
+            kwargs["output"] = "A"
+            return kwargs, "output_all"
+
+    class B(RootNode):
+        def run(self, **kwargs):
+            kwargs["output"] += "B"
+            return kwargs, "output_1"
+
+    class C(RootNode):
+        def run(self, **kwargs):
+            kwargs["output"] += "C"
+            return kwargs, "output_1"
+
+    class D(RootNode):
+        def run(self, **kwargs):
+            kwargs["output"] += "D"
+            return kwargs, "output_1"
+
+    class E(RootNode):
+        def run(self, **kwargs):
+            kwargs["output"] += "E"
+            return kwargs, "output_1"
+
+    class JoinNode(RootNode):
+        def run(self, **kwargs):
+            if kwargs.get("inputs"):
+                kwargs["output"] = ""
+                for input_dict in kwargs["inputs"]:
+                    kwargs["output"] += input_dict["output"]
+            return kwargs, "output_1"
+
+    pipeline = Pipeline()
+    pipeline.add_node(name="A", component=AWithOutput1(), inputs=["Query"])
+    pipeline.add_node(name="B", component=B(), inputs=["A.output_1"])
+    pipeline.add_node(name="C", component=C(), inputs=["A.output_2"])
+    pipeline.add_node(name="D", component=E(), inputs=["B"])
+    pipeline.add_node(name="E", component=D(), inputs=["B"])
+    pipeline.add_node(name="F", component=JoinNode(), inputs=["D", "E", "C"])
+    output = pipeline.run(query="test")
+    assert output["output"] == "ABEABD"
+
+    pipeline = Pipeline()
+    pipeline.add_node(name="A", component=AWithOutput2(), inputs=["Query"])
+    pipeline.add_node(name="B", component=B(), inputs=["A.output_1"])
+    pipeline.add_node(name="C", component=C(), inputs=["A.output_2"])
+    pipeline.add_node(name="D", component=E(), inputs=["B"])
+    pipeline.add_node(name="E", component=D(), inputs=["B"])
+    pipeline.add_node(name="F", component=JoinNode(), inputs=["D", "E", "C"])
+    output = pipeline.run(query="test")
+    assert output["output"] == "AC"
+
+    pipeline = Pipeline()
+    pipeline.add_node(name="A", component=AWithOutputAll(), inputs=["Query"])
+    pipeline.add_node(name="B", component=B(), inputs=["A.output_1"])
+    pipeline.add_node(name="C", component=C(), inputs=["A.output_2"])
+    pipeline.add_node(name="D", component=E(), inputs=["B"])
+    pipeline.add_node(name="E", component=D(), inputs=["B"])
+    pipeline.add_node(name="F", component=JoinNode(), inputs=["D", "E", "C"])
+    output = pipeline.run(query="test")
+    assert output["output"] == "ACABEABD"