Fix execution of Pipelines with parallel nodes (#901)
commit e9f0076dbd (parent 24d0c4d42d)

@@ -346,7 +346,7 @@ Note that pipelines with split or merge nodes are currently not supported.

 ## JoinDocuments Objects

 ```python
-class JoinDocuments()
+class JoinDocuments(BaseComponent)
 ```

 A node to join documents outputted by multiple retriever nodes.

@@ -1,5 +1,7 @@
-from abc import ABC
+import logging
 import os
+import traceback
+from abc import ABC
 from copy import deepcopy
 from pathlib import Path
 from typing import List, Optional, Dict

@@ -17,6 +19,9 @@ from haystack.summarizer.base import BaseSummarizer
 from haystack.translator.base import BaseTranslator


+logger = logging.getLogger(__name__)
+
+
 class Pipeline(ABC):
     """
     Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components.

@@ -100,33 +105,37 @@ class Pipeline(ABC):

     def run(self, **kwargs):
         node_output = None
-        stack = {
+        queue = {
             self.root_node_id: {"pipeline_type": self.pipeline_type, **kwargs}
-        }  # ordered dict with "node_id" -> "input" mapping that acts as a FIFO stack
-        nodes_executed = set()
-        i = -1  # the last item is popped off the stack unless it is a join node with unprocessed predecessors
-        while stack:
-            node_id = list(stack.keys())[i]
-            node_input = stack[node_id]
-            predecessors = set(self.graph.predecessors(node_id))
-            if predecessors.issubset(nodes_executed):  # only execute if predecessor nodes are executed
-                nodes_executed.add(node_id)
-                node_output, stream_id = self.graph.nodes[node_id]["component"].run(**node_input)
-                stack.pop(node_id)
+        }  # ordered dict with "node_id" -> "input" mapping that acts as a FIFO queue
+        i = 0  # the first item is popped off the queue unless it is a "join" node with unprocessed predecessors
+        while queue:
+            node_id = list(queue.keys())[i]
+            node_input = queue[node_id]
+            predecessors = set(nx.ancestors(self.graph, node_id))
+            if predecessors.isdisjoint(set(queue.keys())):  # only execute if predecessor nodes are executed
+                try:
+                    logger.debug(f"Running node `{node_id}` with input `{node_input}`")
+                    node_output, stream_id = self.graph.nodes[node_id]["component"].run(**node_input)
+                except Exception as e:
+                    tb = traceback.format_exc()
+                    raise Exception(f"Exception while running node `{node_id}` with input `{node_input}`: {e}, full stack trace: {tb}")
+                queue.pop(node_id)
                 next_nodes = self.get_next_nodes(node_id, stream_id)
-                for n in next_nodes:  # add successor nodes with corresponding inputs to the stack
-                    if stack.get(n):  # concatenate inputs if it's a join node
-                        existing_input = stack[n]
+                for n in next_nodes:  # add successor nodes with corresponding inputs to the queue
+                    if queue.get(n):  # concatenate inputs if it's a join node
+                        existing_input = queue[n]
                         if "inputs" not in existing_input.keys():
                             updated_input = {"inputs": [existing_input, node_output]}
                         else:
-                            updated_input = existing_input["inputs"].append(node_output)
-                        stack[n] = updated_input
+                            existing_input["inputs"].append(node_output)
+                            updated_input = existing_input
+                        queue[n] = updated_input
                     else:
-                        stack[n] = node_output
-                i = -1
-            else:  # attempt executing lower nodes in the stack as `node_id` has unprocessed predecessors
-                i -= 1
+                        queue[n] = node_output
+                i = 0
+            else:
+                i += 1  # attempt executing next node in the queue as current `node_id` has unprocessed predecessors
         return node_output

     def get_next_nodes(self, node_id: str, stream_id: str):

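This is the heart of the fix. The old loop walked the ordered dict from the back like a LIFO stack and gated execution on direct predecessors only; it also assigned the return value of `list.append` (which is `None`) when a third branch reached a join node, corrupting the join input. The rewrite scans front-to-back as a FIFO queue, defers any node whose ancestors (via `nx.ancestors`) are still queued, fixes the `append` bug, and wraps node execution for clearer errors. Below is a minimal, self-contained sketch of that scheduling idea, not Haystack's API: the same loop over a bare `networkx` DiGraph whose nodes carry a `run` callable.

```python
import networkx as nx


def run_dag(graph: nx.DiGraph, root: str, root_input: dict) -> dict:
    """Execute a DAG of callables in dependency order, FIFO-style."""
    queue = {root: root_input}  # insertion-ordered dict used as a FIFO queue
    node_output: dict = {}
    i = 0
    while queue:
        node_id = list(queue.keys())[i]
        # Defer this node if any ancestor is still waiting in the queue:
        # this is what keeps a join node from running before all branches.
        if not nx.ancestors(graph, node_id).isdisjoint(queue):
            i += 1  # try the next queued node instead
            continue
        node_input = queue.pop(node_id)
        node_output = graph.nodes[node_id]["run"](node_input)
        for successor in graph.successors(node_id):
            existing_input = queue.get(successor)
            if existing_input is not None:  # join node: gather one input per branch
                if "inputs" in existing_input:
                    existing_input["inputs"].append(node_output)
                else:
                    queue[successor] = {"inputs": [existing_input, node_output]}
            else:
                queue[successor] = node_output
        i = 0  # restart the scan from the front of the queue
    return node_output


# A diamond A -> (B, C) -> D: D only runs once both branches are done.
graph = nx.DiGraph()
for name in "ABCD":
    graph.add_node(name, run=lambda data, name=name: {**data, "trace": data.get("trace", "") + name})
graph.add_edges_from([("A", "B"), ("A", "C"), ("B", "D"), ("C", "D")])
print(run_dag(graph, "A", {}))  # D receives {"inputs": [<AB branch>, <AC branch>]}
```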
@@ -134,7 +143,7 @@ class Pipeline(ABC):
         next_nodes = [
             next_node
             for _, next_node, data in current_node_edges
-            if not stream_id or data["label"] == stream_id
+            if not stream_id or data["label"] == stream_id or stream_id == "output_all"
         ]
         return next_nodes

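The second change rides along with the new `AWithOutputAll` test below: besides picking a single decision branch, a node may now return the stream id `output_all` to send its output down every outgoing edge. A tiny illustration of the predicate, with edges shaped like `networkx`'s `(u, v, data)` triples (node names here are illustrative):

```python
# Outgoing edges of a decision node, labelled by the stream they carry.
current_node_edges = [
    ("A", "B", {"label": "output_1"}),
    ("A", "C", {"label": "output_2"}),
]


def next_nodes(stream_id: str):
    return [
        next_node
        for _, next_node, data in current_node_edges
        if not stream_id or data["label"] == stream_id or stream_id == "output_all"
    ]


assert next_nodes("output_1") == ["B"]          # follow only the chosen branch
assert next_nodes("output_2") == ["C"]
assert next_nodes("output_all") == ["B", "C"]   # new: fan out to all branches
```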
@@ -239,22 +248,26 @@ class Pipeline(ABC):
         :param definitions: dict containing definitions of all components retrieved from the YAML.
         :param components: dict containing component objects.
         """
-        if name in components.keys():  # check if component is already loaded.
-            return components[name]
+        try:
+            if name in components.keys():  # check if component is already loaded.
+                return components[name]

-        component_params = definitions[name]["params"]
-        component_type = definitions[name]["type"]
+            component_params = definitions[name]["params"]
+            component_type = definitions[name]["type"]
+            logger.debug(f"Loading component `{name}` of type `{definitions[name]['type']}`")

-        for key, value in component_params.items():
-            # Component params can reference to other components. For instance, a Retriever can reference a
-            # DocumentStore defined in the YAML. All references should be recursively resolved.
-            if value in definitions.keys():  # check if the param value is a reference to another component.
-                if value not in components.keys():  # check if the referenced component is already loaded.
-                    cls._load_or_get_component(name=value, definitions=definitions, components=components)
-                component_params[key] = components[value]  # substitute reference (string) with the component object.
+            for key, value in component_params.items():
+                # Component params can reference to other components. For instance, a Retriever can reference a
+                # DocumentStore defined in the YAML. All references should be recursively resolved.
+                if value in definitions.keys():  # check if the param value is a reference to another component.
+                    if value not in components.keys():  # check if the referenced component is already loaded.
+                        cls._load_or_get_component(name=value, definitions=definitions, components=components)
+                    component_params[key] = components[value]  # substitute reference (string) with the component object.

-        instance = BaseComponent.load_from_args(component_type=component_type, **component_params)
-        components[name] = instance
+            instance = BaseComponent.load_from_args(component_type=component_type, **component_params)
+            components[name] = instance
+        except Exception as e:
+            raise Exception(f"Failed loading pipeline component '{name}': {e}")
         return instance

     @classmethod

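The loading logic itself is unchanged; the commit only wraps it in a try/except so a broken YAML entry fails with the offending component's name instead of a bare `KeyError`. For context, a hypothetical `definitions` dict as parsed from a pipeline YAML: because the retriever's `document_store` value matches another definition's key, that component is loaded first and the string is replaced by the instantiated object. Component and parameter names here are illustrative, not taken from the diff.

```python
definitions = {
    "MyDocumentStore": {
        "type": "ElasticsearchDocumentStore",  # instantiated via BaseComponent.load_from_args
        "params": {"host": "localhost"},
    },
    "MyRetriever": {
        "type": "ElasticsearchRetriever",
        "params": {
            "document_store": "MyDocumentStore",  # reference: a key in `definitions`
            "top_k": 10,
        },
    },
}
components: dict = {}  # cache, filled as _load_or_get_component instantiates each name
```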
@@ -524,7 +537,7 @@ class RootNode:
         return kwargs, "output_1"


-class JoinDocuments:
+class JoinDocuments(BaseComponent):
     """
     A node to join documents outputted by multiple retriever nodes.

@@ -583,5 +596,5 @@ class JoinDocuments:
         documents = sorted(document_map.values(), key=lambda d: d.score, reverse=True)
         if self.top_k:
             documents = documents[: self.top_k]
-        output = {"query": inputs[0]["query"], "documents": documents}
+        output = {"query": inputs[0]["query"], "documents": documents, "labels": inputs[0].get("labels", None)}
         return output, "output_1"

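`JoinDocuments` previously dropped any evaluation labels at the merge point; forwarding `labels` from the first input keeps them flowing to downstream eval nodes. As a rough, self-contained sketch of the merge semantics shown above (deduplicate by document id across branches, rank by score, cut to `top_k`); `Doc` stands in for Haystack's `Document`, and "last writer wins" deduplication is this sketch's simplification:

```python
from dataclasses import dataclass


@dataclass
class Doc:  # stand-in for haystack.Document
    id: str
    score: float


def join_documents(inputs: list, top_k: int = None) -> dict:
    document_map = {}  # dedupe by id: last writer wins in this sketch
    for branch in inputs:
        for doc in branch["documents"]:
            document_map[doc.id] = doc
    documents = sorted(document_map.values(), key=lambda d: d.score, reverse=True)
    if top_k:
        documents = documents[:top_k]
    return {
        "query": inputs[0]["query"],
        "documents": documents,
        "labels": inputs[0].get("labels", None),  # the new pass-through
    }


merged = join_documents([
    {"query": "q", "documents": [Doc("1", 0.9), Doc("2", 0.4)]},
    {"query": "q", "documents": [Doc("2", 0.7), Doc("3", 0.5)]},
])
assert [d.id for d in merged["documents"]] == ["1", "2", "3"]
```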
@@ -264,4 +264,85 @@ def test_parallel_paths_in_pipeline_graph():
     pipeline.add_node(name="D", component=D(), inputs=["B"])
     pipeline.add_node(name="E", component=JoinNode(), inputs=["C", "D"])
     output = pipeline.run(query="test")
-    assert output["output"] == "ABDABC"
+    assert output["output"] == "ABCABD"
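The flipped expectation is a direct consequence of the new traversal order: `B` enqueues its successors in `add_node` order (`C`, then `D`); the old loop popped from the back of the ordered dict, the new one scans from the front, so the join now receives the `C` branch first. A two-line sketch of that difference:

```python
from collections import OrderedDict

queue = OrderedDict([("C", "input from B"), ("D", "input from B")])
assert list(queue)[0] == "C"    # new FIFO behaviour: C runs first -> "ABCABD"
assert list(queue)[-1] == "D"   # the old stack popped here first  -> "ABDABC"
```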
+
+
+def test_parallel_paths_in_pipeline_graph_with_branching():
+    class AWithOutput1(RootNode):
+        outgoing_edges = 2
+        def run(self, **kwargs):
+            kwargs["output"] = "A"
+            return kwargs, "output_1"
+
+    class AWithOutput2(RootNode):
+        outgoing_edges = 2
+        def run(self, **kwargs):
+            kwargs["output"] = "A"
+            return kwargs, "output_2"
+
+    class AWithOutputAll(RootNode):
+        outgoing_edges = 2
+        def run(self, **kwargs):
+            kwargs["output"] = "A"
+            return kwargs, "output_all"
+
+    class B(RootNode):
+        def run(self, **kwargs):
+            kwargs["output"] += "B"
+            return kwargs, "output_1"
+
+    class C(RootNode):
+        def run(self, **kwargs):
+            kwargs["output"] += "C"
+            return kwargs, "output_1"
+
+    class D(RootNode):
+        def run(self, **kwargs):
+            kwargs["output"] += "D"
+            return kwargs, "output_1"
+
+    class E(RootNode):
+        def run(self, **kwargs):
+            kwargs["output"] += "E"
+            return kwargs, "output_1"
+
+    class JoinNode(RootNode):
+        def run(self, **kwargs):
+            if kwargs.get("inputs"):
+                kwargs["output"] = ""
+                for input_dict in kwargs["inputs"]:
+                    kwargs["output"] += (input_dict["output"])
+            return kwargs, "output_1"
+
+    pipeline = Pipeline()
+    pipeline.add_node(name="A", component=AWithOutput1(), inputs=["Query"])
+    pipeline.add_node(name="B", component=B(), inputs=["A.output_1"])
+    pipeline.add_node(name="C", component=C(), inputs=["A.output_2"])
+    pipeline.add_node(name="D", component=E(), inputs=["B"])
+    pipeline.add_node(name="E", component=D(), inputs=["B"])
+    pipeline.add_node(name="F", component=JoinNode(), inputs=["D", "E", "C"])
+    output = pipeline.run(query="test")
+    assert output["output"] == "ABEABD"
+
+    pipeline = Pipeline()
+    pipeline.add_node(name="A", component=AWithOutput2(), inputs=["Query"])
+    pipeline.add_node(name="B", component=B(), inputs=["A.output_1"])
+    pipeline.add_node(name="C", component=C(), inputs=["A.output_2"])
+    pipeline.add_node(name="D", component=E(), inputs=["B"])
+    pipeline.add_node(name="E", component=D(), inputs=["B"])
+    pipeline.add_node(name="F", component=JoinNode(), inputs=["D", "E", "C"])
+    output = pipeline.run(query="test")
+    assert output["output"] == "AC"
+
+    pipeline = Pipeline()
+    pipeline.add_node(name="A", component=AWithOutputAll(), inputs=["Query"])
+    pipeline.add_node(name="B", component=B(), inputs=["A.output_1"])
+    pipeline.add_node(name="C", component=C(), inputs=["A.output_2"])
+    pipeline.add_node(name="D", component=E(), inputs=["B"])
+    pipeline.add_node(name="E", component=D(), inputs=["B"])
+    pipeline.add_node(name="F", component=JoinNode(), inputs=["D", "E", "C"])
+    output = pipeline.run(query="test")
+    assert output["output"] == "ACABEABD"