diff --git a/docs/_src/api/api/pipelines.md b/docs/_src/api/api/pipelines.md
index e73f3e053..c672750e8 100644
--- a/docs/_src/api/api/pipelines.md
+++ b/docs/_src/api/api/pipelines.md
@@ -346,7 +346,7 @@ Note that pipelines with split or merge nodes are currently not supported.
 ## JoinDocuments Objects
 
 ```python
-class JoinDocuments()
+class JoinDocuments(BaseComponent)
 ```
 
 A node to join documents outputted by multiple retriever nodes.
diff --git a/haystack/pipeline.py b/haystack/pipeline.py
index 9edddd240..e01d4ff4d 100644
--- a/haystack/pipeline.py
+++ b/haystack/pipeline.py
@@ -1,5 +1,7 @@
-from abc import ABC
+import logging
 import os
+import traceback
+from abc import ABC
 from copy import deepcopy
 from pathlib import Path
 from typing import List, Optional, Dict
@@ -17,6 +19,9 @@
 from haystack.summarizer.base import BaseSummarizer
 from haystack.translator.base import BaseTranslator
 
+logger = logging.getLogger(__name__)
+
+
 class Pipeline(ABC):
     """
     Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components.
@@ -100,33 +105,37 @@ class Pipeline(ABC):
     def run(self, **kwargs):
         node_output = None
-        stack = {
+        queue = {
             self.root_node_id: {"pipeline_type": self.pipeline_type, **kwargs}
-        }  # ordered dict with "node_id" -> "input" mapping that acts as a FIFO stack
-        nodes_executed = set()
-        i = -1  # the last item is popped off the stack unless it is a join node with unprocessed predecessors
-        while stack:
-            node_id = list(stack.keys())[i]
-            node_input = stack[node_id]
-            predecessors = set(self.graph.predecessors(node_id))
-            if predecessors.issubset(nodes_executed):  # only execute if predecessor nodes are executed
-                nodes_executed.add(node_id)
-                node_output, stream_id = self.graph.nodes[node_id]["component"].run(**node_input)
-                stack.pop(node_id)
+        }  # ordered dict with "node_id" -> "input" mapping that acts as a FIFO queue
+        i = 0  # the first item is popped off the queue unless it is a "join" node with unprocessed predecessors
+        while queue:
+            node_id = list(queue.keys())[i]
+            node_input = queue[node_id]
+            predecessors = set(nx.ancestors(self.graph, node_id))
+            if predecessors.isdisjoint(set(queue.keys())):  # only execute if predecessor nodes are executed
+                try:
+                    logger.debug(f"Running node `{node_id}` with input `{node_input}`")
+                    node_output, stream_id = self.graph.nodes[node_id]["component"].run(**node_input)
+                except Exception as e:
+                    tb = traceback.format_exc()
+                    raise Exception(f"Exception while running node `{node_id}` with input `{node_input}`: {e}, full stack trace: {tb}")
+                queue.pop(node_id)
                 next_nodes = self.get_next_nodes(node_id, stream_id)
-                for n in next_nodes:  # add successor nodes with corresponding inputs to the stack
-                    if stack.get(n):  # concatenate inputs if it's a join node
-                        existing_input = stack[n]
+                for n in next_nodes:  # add successor nodes with corresponding inputs to the queue
+                    if queue.get(n):  # concatenate inputs if it's a join node
+                        existing_input = queue[n]
                         if "inputs" not in existing_input.keys():
                             updated_input = {"inputs": [existing_input, node_output]}
                         else:
-                            updated_input = existing_input["inputs"].append(node_output)
-                        stack[n] = updated_input
+                            existing_input["inputs"].append(node_output)
+                            updated_input = existing_input
+                        queue[n] = updated_input
                     else:
-                        stack[n] = node_output
-                i = -1
-            else:  # attempt executing lower nodes in the stack as `node_id` has unprocessed predecessors
-                i -= 1
+                        queue[n] = node_output
+                i = 0
+            else:
+                i += 1  # attempt executing next node in the queue as current `node_id` has unprocessed predecessors
 
         return node_output
 
     def get_next_nodes(self, node_id: str, stream_id: str):
@@ -134,7 +143,7 @@ class Pipeline(ABC):
         next_nodes = [
             next_node
             for _, next_node, data in current_node_edges
-            if not stream_id or data["label"] == stream_id
+            if not stream_id or data["label"] == stream_id or stream_id == "output_all"
         ]
         return next_nodes
 
@@ -239,22 +248,26 @@ class Pipeline(ABC):
         :param definitions: dict containing definitions of all components retrieved from the YAML.
         :param components: dict containing component objects.
         """
-        if name in components.keys():  # check if component is already loaded.
-            return components[name]
+        try:
+            if name in components.keys():  # check if component is already loaded.
+                return components[name]
 
-        component_params = definitions[name]["params"]
-        component_type = definitions[name]["type"]
+            component_params = definitions[name]["params"]
+            component_type = definitions[name]["type"]
+            logger.debug(f"Loading component `{name}` of type `{definitions[name]['type']}`")
 
-        for key, value in component_params.items():
-            # Component params can reference to other components. For instance, a Retriever can reference a
-            # DocumentStore defined in the YAML. All references should be recursively resolved.
-            if value in definitions.keys():  # check if the param value is a reference to another component.
-                if value not in components.keys():  # check if the referenced component is already loaded.
-                    cls._load_or_get_component(name=value, definitions=definitions, components=components)
-                component_params[key] = components[value]  # substitute reference (string) with the component object.
+            for key, value in component_params.items():
+                # Component params can reference to other components. For instance, a Retriever can reference a
+                # DocumentStore defined in the YAML. All references should be recursively resolved.
+                if value in definitions.keys():  # check if the param value is a reference to another component.
+                    if value not in components.keys():  # check if the referenced component is already loaded.
+                        cls._load_or_get_component(name=value, definitions=definitions, components=components)
+                    component_params[key] = components[value]  # substitute reference (string) with the component object.
 
-        instance = BaseComponent.load_from_args(component_type=component_type, **component_params)
-        components[name] = instance
+            instance = BaseComponent.load_from_args(component_type=component_type, **component_params)
+            components[name] = instance
+        except Exception as e:
+            raise Exception(f"Failed loading pipeline component '{name}': {e}")
         return instance
 
     @classmethod
@@ -524,7 +537,7 @@ class RootNode:
         return kwargs, "output_1"
 
 
-class JoinDocuments:
+class JoinDocuments(BaseComponent):
     """
     A node to join documents outputted by multiple retriever nodes.
 
@@ -583,5 +596,5 @@ class JoinDocuments:
         documents = sorted(document_map.values(), key=lambda d: d.score, reverse=True)
         if self.top_k:
             documents = documents[: self.top_k]
-        output = {"query": inputs[0]["query"], "documents": documents}
+        output = {"query": inputs[0]["query"], "documents": documents, "labels": inputs[0].get("labels", None)}
         return output, "output_1"
diff --git a/test/test_pipeline.py b/test/test_pipeline.py
index 98c666af3..a6093e268 100644
--- a/test/test_pipeline.py
+++ b/test/test_pipeline.py
@@ -264,4 +264,85 @@ def test_parallel_paths_in_pipeline_graph():
     pipeline.add_node(name="D", component=D(), inputs=["B"])
     pipeline.add_node(name="E", component=JoinNode(), inputs=["C", "D"])
     output = pipeline.run(query="test")
-    assert output["output"] == "ABDABC"
+    assert output["output"] == "ABCABD"
+
+
+def test_parallel_paths_in_pipeline_graph_with_branching():
+    class AWithOutput1(RootNode):
+        outgoing_edges = 2
+        def run(self, **kwargs):
+            kwargs["output"] = "A"
+            return kwargs, "output_1"
+
+    class AWithOutput2(RootNode):
+        outgoing_edges = 2
+        def run(self, **kwargs):
+            kwargs["output"] = "A"
+            return kwargs, "output_2"
+
+    class AWithOutputAll(RootNode):
+        outgoing_edges = 2
+        def run(self, **kwargs):
+            kwargs["output"] = "A"
+            return kwargs, "output_all"
+
+    class B(RootNode):
+        def run(self, **kwargs):
+            kwargs["output"] += "B"
+            return kwargs, "output_1"
+
+    class C(RootNode):
+        def run(self, **kwargs):
+            kwargs["output"] += "C"
+            return kwargs, "output_1"
+
+    class D(RootNode):
+        def run(self, **kwargs):
+            kwargs["output"] += "D"
+            return kwargs, "output_1"
+
+    class E(RootNode):
+        def run(self, **kwargs):
+            kwargs["output"] += "E"
+            return kwargs, "output_1"
+
+    class JoinNode(RootNode):
+        def run(self, **kwargs):
+            if kwargs.get("inputs"):
+                kwargs["output"] = ""
+                for input_dict in kwargs["inputs"]:
+                    kwargs["output"] += (input_dict["output"])
+            return kwargs, "output_1"
+
+    pipeline = Pipeline()
+    pipeline.add_node(name="A", component=AWithOutput1(), inputs=["Query"])
+    pipeline.add_node(name="B", component=B(), inputs=["A.output_1"])
+    pipeline.add_node(name="C", component=C(), inputs=["A.output_2"])
+    pipeline.add_node(name="D", component=E(), inputs=["B"])
+    pipeline.add_node(name="E", component=D(), inputs=["B"])
+    pipeline.add_node(name="F", component=JoinNode(), inputs=["D", "E", "C"])
+    output = pipeline.run(query="test")
+    assert output["output"] == "ABEABD"
+
+    pipeline = Pipeline()
+    pipeline.add_node(name="A", component=AWithOutput2(), inputs=["Query"])
+    pipeline.add_node(name="B", component=B(), inputs=["A.output_1"])
+    pipeline.add_node(name="C", component=C(), inputs=["A.output_2"])
+    pipeline.add_node(name="D", component=E(), inputs=["B"])
+    pipeline.add_node(name="E", component=D(), inputs=["B"])
+    pipeline.add_node(name="F", component=JoinNode(), inputs=["D", "E", "C"])
+    output = pipeline.run(query="test")
+    assert output["output"] == "AC"
+
+    pipeline = Pipeline()
+    pipeline.add_node(name="A", component=AWithOutputAll(), inputs=["Query"])
+    pipeline.add_node(name="B", component=B(), inputs=["A.output_1"])
+    pipeline.add_node(name="C", component=C(), inputs=["A.output_2"])
+    pipeline.add_node(name="D", component=E(), inputs=["B"])
+    pipeline.add_node(name="E", component=D(), inputs=["B"])
+    pipeline.add_node(name="F", component=JoinNode(), inputs=["D", "E", "C"])
+    output = pipeline.run(query="test")
+    assert output["output"] == "ACABEABD"
+
+
+
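
For readers who want the gist of the patch without tracing the hunks: `Pipeline.run` now treats its ordered dict as a FIFO queue and defers any node whose graph ancestors are still queued, which is what lets a join node wait for all of its inbound branches. Below is a minimal standalone sketch of that traversal under a toy `component.run(**kwargs) -> (output, stream)` contract; `run_dag`, `Upper`, and the one-node example graph are illustrative stand-ins, not Haystack's actual classes.

```python
import networkx as nx


class Upper:
    """Toy component: uppercases the query, emits on the edge labelled "output_1"."""
    def run(self, **kwargs):
        kwargs["query"] = kwargs["query"].upper()
        return kwargs, "output_1"


def run_dag(graph: nx.DiGraph, root: str, **kwargs):
    queue = {root: kwargs}  # insertion-ordered dict used as a FIFO queue
    node_output, i = None, 0
    while queue:
        node_id = list(queue.keys())[i]
        # Defer a node (e.g. a join) while any of its ancestors is still queued.
        if not nx.ancestors(graph, node_id).isdisjoint(queue.keys()):
            i += 1  # try the next queued node instead
            continue
        node_input = queue.pop(node_id)
        node_output, stream = graph.nodes[node_id]["component"].run(**node_input)
        for _, nxt, data in graph.edges(node_id, data=True):
            if stream != data["label"] and stream != "output_all":
                continue  # "output_all" fans out on every outgoing edge
            if nxt in queue:  # join node: collect this as an additional input
                existing = queue[nxt]
                if "inputs" not in existing:
                    queue[nxt] = {"inputs": [existing, node_output]}
                else:
                    existing["inputs"].append(node_output)
            else:
                queue[nxt] = node_output
        i = 0  # a node ran, so restart the scan from the front of the queue
    return node_output


graph = nx.DiGraph()
graph.add_node("A", component=Upper())
assert run_dag(graph, "A", query="test") == {"query": "TEST"}
```

In a DAG some queued node always has no queued ancestors, so the scan from the front of the queue always finds a runnable node and the loop terminates.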
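The `get_next_nodes` change adds a third routing mode next to the existing two (follow the one edge whose label matches the returned stream, or all edges when no stream is given): a node may return the special stream `"output_all"` to send its output down every outgoing edge, as the new `AWithOutputAll` test exercises. Here is a hedged sketch of a decision node using that convention; the `QueryRouter` class and its word-count heuristic are hypothetical, not part of this patch.

```python
class QueryRouter(RootNode):
    """Hypothetical decision node: short queries go to output_1, long ones to
    output_2, and mid-length queries fan out to both branches via output_all."""
    outgoing_edges = 2

    def run(self, **kwargs):
        n_words = len(kwargs["query"].split())
        if n_words <= 3:
            return kwargs, "output_1"   # e.g. route to a sparse retriever
        if n_words >= 10:
            return kwargs, "output_2"   # e.g. route to a dense retriever
        return kwargs, "output_all"     # ambiguous: query both branches
```

A downstream join node (such as `JoinDocuments`, which this patch makes a `BaseComponent` and which now also forwards `labels`) then merges whichever branches actually ran.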