diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index 73dbe59dc..2b8708b49 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -17,6 +17,7 @@ from haystack.core.errors import ( PipelineConnectError, PipelineDrawingError, PipelineError, + PipelineMaxLoops, PipelineRuntimeError, PipelineValidationError, ) @@ -661,6 +662,10 @@ class Pipeline: # "input1": 1, "input2": 2, # } + # Reset the visits count for each component + for node in self.graph.nodes: + self.graph.nodes[node]["visits"] = 0 + # TODO: Remove this warmup once we can check reliably whether a component has been warmed up or not # As of now it's here to make sure we don't have failing tests that assume warm_up() is called in run() self.warm_up() @@ -759,8 +764,12 @@ class Pipeline: continue if name in last_inputs and len(comp.__haystack_input__._sockets_dict) == len(last_inputs[name]): # type: ignore + if self.graph.nodes[name]["visits"] > self.max_loops_allowed: + msg = f"Maximum loops count ({self.max_loops_allowed}) exceeded for component '{name}'" + raise PipelineMaxLoops(msg) # This component has all the inputs it needs to run res = comp.run(**last_inputs[name]) + self.graph.nodes[name]["visits"] += 1 if not isinstance(res, Mapping): raise PipelineRuntimeError( diff --git a/releasenotes/notes/max-loops-in-run-df9f5c068a723f71.yaml b/releasenotes/notes/max-loops-in-run-df9f5c068a723f71.yaml new file mode 100644 index 000000000..00c6d3a5d --- /dev/null +++ b/releasenotes/notes/max-loops-in-run-df9f5c068a723f71.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + Change `Pipeline.run()` to check if `max_loops_allowed` has been reached. + If we attempt to run a Component that already ran the number of `max_loops_allowed` a `PipelineMaxLoops` will be raised. diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index 1876ee28b..d3690855a 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -9,7 +9,7 @@ import pytest from haystack.core.component import component from haystack.core.component.types import InputSocket, OutputSocket -from haystack.core.errors import PipelineDrawingError, PipelineError, PipelineRuntimeError +from haystack.core.errors import PipelineDrawingError, PipelineError, PipelineMaxLoops, PipelineRuntimeError from haystack.core.pipeline import Pipeline from haystack.testing.factory import component_class from haystack.testing.sample_components import AddFixedValue, Double @@ -280,6 +280,22 @@ def test_repr_in_notebook(mock_is_in_jupyter): mock_show.assert_called_once_with() +def test_run_raises_if_max_visits_reached(): + def custom_init(self): + component.set_input_type(self, "x", int) + component.set_input_type(self, "y", int, 1) + component.set_output_types(self, a=int, b=int) + + FakeComponent = component_class("FakeComponent", output={"a": 1, "b": 1}, extra_fields={"__init__": custom_init}) + pipe = Pipeline(max_loops_allowed=1) + pipe.add_component("first", FakeComponent()) + pipe.add_component("second", FakeComponent()) + pipe.connect("first.a", "second.x") + pipe.connect("second.b", "first.y") + with pytest.raises(PipelineMaxLoops): + pipe.run({"first": {"x": 1}}) + + def test_run_with_component_that_does_not_return_dict(): BrokenComponent = component_class( "BrokenComponent", input_types={"a": int}, output_types={"b": int}, output=1 # type:ignore