mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-10 14:54:10 +00:00
feat: Add support for returning intermediate outputs of pipeline components (#7504)
* feat: Add support for returning intermediate outputs of pipeline components The `pipeline.run` method has been extended to accept a set of component names whose outputs are returned in addition to the outputs of leaf components. * Add reno * Lint --------- Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
This commit is contained in:
parent
9a9c8aa1c8
commit
fd84cd5f9a
@ -613,7 +613,7 @@ class Pipeline:
|
|||||||
|
|
||||||
# TODO: We're ignoring these linting rules for the time being, after we properly optimize this function we'll remove the noqa
|
# TODO: We're ignoring these linting rules for the time being, after we properly optimize this function we'll remove the noqa
|
||||||
def run( # noqa: C901, PLR0912, PLR0915 pylint: disable=too-many-branches
|
def run( # noqa: C901, PLR0912, PLR0915 pylint: disable=too-many-branches
|
||||||
self, data: Dict[str, Any], debug: bool = False
|
self, data: Dict[str, Any], debug: bool = False, include_outputs_from: Optional[Set[str]] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Runs the pipeline with given input data.
|
Runs the pipeline with given input data.
|
||||||
@ -623,8 +623,16 @@ class Pipeline:
|
|||||||
and its value is a dictionary of that component's input parameters.
|
and its value is a dictionary of that component's input parameters.
|
||||||
:param debug:
|
:param debug:
|
||||||
Set to True to collect and return debug information.
|
Set to True to collect and return debug information.
|
||||||
|
:param include_outputs_from:
|
||||||
|
Set of component names whose individual outputs are to be
|
||||||
|
included in the pipeline's output. For components that are
|
||||||
|
invoked multiple times (in a loop), only the last-produced
|
||||||
|
output is included.
|
||||||
:returns:
|
:returns:
|
||||||
A dictionary containing the pipeline's output.
|
A dictionary where each entry corresponds to a component name
|
||||||
|
and its output. If `include_outputs_from` is `None`, this dictionary
|
||||||
|
will only contain the outputs of leaf components, i.e., components
|
||||||
|
without outgoing connections.
|
||||||
|
|
||||||
:raises PipelineRuntimeError:
|
:raises PipelineRuntimeError:
|
||||||
If a component fails or returns unexpected output.
|
If a component fails or returns unexpected output.
|
||||||
@ -756,6 +764,8 @@ class Pipeline:
|
|||||||
# The waiting_for_input list is used to keep track of components that are waiting for input.
|
# The waiting_for_input list is used to keep track of components that are waiting for input.
|
||||||
waiting_for_input: List[Tuple[str, Component]] = []
|
waiting_for_input: List[Tuple[str, Component]] = []
|
||||||
|
|
||||||
|
include_outputs_from = set() if include_outputs_from is None else include_outputs_from
|
||||||
|
|
||||||
with tracing.tracer.trace(
|
with tracing.tracer.trace(
|
||||||
"haystack.pipeline.run",
|
"haystack.pipeline.run",
|
||||||
tags={
|
tags={
|
||||||
@ -765,7 +775,11 @@ class Pipeline:
|
|||||||
},
|
},
|
||||||
):
|
):
|
||||||
# This is what we'll return at the end
|
# This is what we'll return at the end
|
||||||
final_outputs = {}
|
final_outputs: Dict[Any, Any] = {}
|
||||||
|
|
||||||
|
# Cache for extra outputs, if enabled.
|
||||||
|
extra_outputs: Dict[Any, Any] = {}
|
||||||
|
|
||||||
while len(to_run) > 0:
|
while len(to_run) > 0:
|
||||||
name, comp = to_run.pop(0)
|
name, comp = to_run.pop(0)
|
||||||
|
|
||||||
@ -826,6 +840,11 @@ class Pipeline:
|
|||||||
span.set_tags(tags={"haystack.component.visits": self.graph.nodes[name]["visits"]})
|
span.set_tags(tags={"haystack.component.visits": self.graph.nodes[name]["visits"]})
|
||||||
span.set_content_tag("haystack.component.output", res)
|
span.set_content_tag("haystack.component.output", res)
|
||||||
|
|
||||||
|
if name in include_outputs_from:
|
||||||
|
# Deepcopy the outputs to prevent downstream nodes from modifying them
|
||||||
|
# We don't care about loops - Always store the last output.
|
||||||
|
extra_outputs[name] = deepcopy(res)
|
||||||
|
|
||||||
# Reset the waiting for input previous states, we managed to run a component
|
# Reset the waiting for input previous states, we managed to run a component
|
||||||
before_last_waiting_for_input = None
|
before_last_waiting_for_input = None
|
||||||
last_waiting_for_input = None
|
last_waiting_for_input = None
|
||||||
@ -988,6 +1007,11 @@ class Pipeline:
|
|||||||
waiting_for_input.remove((name, comp))
|
waiting_for_input.remove((name, comp))
|
||||||
to_run.append((name, comp))
|
to_run.append((name, comp))
|
||||||
|
|
||||||
|
if len(include_outputs_from) > 0:
|
||||||
|
for name, output in extra_outputs.items():
|
||||||
|
if name not in final_outputs:
|
||||||
|
final_outputs[name] = output
|
||||||
|
|
||||||
return final_outputs
|
return final_outputs
|
||||||
|
|
||||||
def _prepare_component_input_data(self, data: Dict[str, Any]) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Any]]:
|
def _prepare_component_input_data(self, data: Dict[str, Any]) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Any]]:
|
||||||
|
|||||||
@ -0,0 +1,5 @@
|
|||||||
|
---
|
||||||
|
enhancements:
|
||||||
|
- |
|
||||||
|
`pipeline.run` accepts a set of component names whose intermediate outputs are returned in the final
|
||||||
|
pipeline output dictionary.
|
||||||
61
test/core/pipeline/test_intermediate_outputs.py
Normal file
61
test/core/pipeline/test_intermediate_outputs.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from haystack.components.others import Multiplexer
|
||||||
|
from haystack.core.pipeline import Pipeline
|
||||||
|
from haystack.testing.sample_components import Accumulate, AddFixedValue, Double, Threshold
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_intermediate_outputs():
    """
    Check that `Pipeline.run` reports the outputs of the components named in
    `include_outputs_from` next to the output of the leaf component.

    The pipeline is a linear chain: first_addition -> double -> second_addition.
    """
    chain = Pipeline()
    chain.add_component("first_addition", AddFixedValue(add=2))
    chain.add_component("second_addition", AddFixedValue())
    chain.add_component("double", Double())
    chain.connect("first_addition", "double")
    chain.connect("double", "second_addition")

    # Request every component: all three outputs must show up in the result.
    everything = chain.run(
        {"first_addition": {"value": 1}},
        include_outputs_from={"first_addition", "second_addition", "double"},
    )
    expected_everything = {
        "second_addition": {"result": 7},
        "first_addition": {"result": 3},
        "double": {"value": 6},
    }
    assert everything == expected_everything

    # Request only the middle component: the leaf output is still returned,
    # while the unrequested first component is left out.
    only_double = chain.run({"first_addition": {"value": 1}}, include_outputs_from={"double"})
    expected_only_double = {"second_addition": {"result": 7}, "double": {"value": 6}}
    assert only_double == expected_only_double
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_with_loops_intermediate_outputs():
    """
    Check `include_outputs_from` on a pipeline whose graph contains loops.

    Every component is requested, and the assertion expects exactly one
    output entry per component in the returned dictionary.
    """
    accumulator = Accumulate()

    looped = Pipeline(max_loops_allowed=10)
    looped.add_component("add_one", AddFixedValue(add=1))
    looped.add_component("multiplexer", Multiplexer(type_=int))
    looped.add_component("below_10", Threshold(threshold=10))
    looped.add_component("below_5", Threshold(threshold=5))
    looped.add_component("add_three", AddFixedValue(add=3))
    looped.add_component("accumulator", accumulator)
    looped.add_component("add_two", AddFixedValue(add=2))

    # Connections are made in this exact order; two of them feed back into
    # the multiplexer, which is what creates the loops.
    wiring = (
        ("add_one.result", "multiplexer"),
        ("multiplexer.value", "below_10.value"),
        ("below_10.below", "accumulator.value"),
        ("accumulator.value", "below_5.value"),
        ("below_5.above", "add_three.value"),
        ("below_5.below", "multiplexer"),
        ("add_three.result", "multiplexer"),
        ("below_10.above", "add_two.value"),
    )
    for sender, receiver in wiring:
        looped.connect(sender, receiver)

    observed = looped.run(
        {"add_one": {"value": 3}},
        include_outputs_from={"add_two", "add_one", "multiplexer", "below_10", "accumulator", "below_5", "add_three"},
    )

    expected = {
        "add_two": {"result": 13},
        "add_one": {"result": 4},
        "multiplexer": {"value": 11},
        "below_10": {"above": 11},
        "accumulator": {"value": 8},
        "below_5": {"above": 8},
        "add_three": {"result": 11},
    }
    assert observed == expected
|
||||||
Loading…
x
Reference in New Issue
Block a user