feat: Reintroduce max_loops_allowed check in Pipeline.run() (#7010)

* Reintroduce max_loops_allowed check in Pipeline.run()

* Add release notes
This commit is contained in:
Silvano Cerza 2024-02-19 10:05:35 +01:00 committed by GitHub
parent 3cc8e54f41
commit 5f97e08feb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 31 additions and 1 deletions

View File

@ -17,6 +17,7 @@ from haystack.core.errors import (
PipelineConnectError,
PipelineDrawingError,
PipelineError,
PipelineMaxLoops,
PipelineRuntimeError,
PipelineValidationError,
)
@ -661,6 +662,10 @@ class Pipeline:
# "input1": 1, "input2": 2,
# }
# Reset the visits count for each component
for node in self.graph.nodes:
self.graph.nodes[node]["visits"] = 0
# TODO: Remove this warmup once we can check reliably whether a component has been warmed up or not
# As of now it's here to make sure we don't have failing tests that assume warm_up() is called in run()
self.warm_up()
@ -759,8 +764,12 @@ class Pipeline:
continue
if name in last_inputs and len(comp.__haystack_input__._sockets_dict) == len(last_inputs[name]): # type: ignore
if self.graph.nodes[name]["visits"] > self.max_loops_allowed:
msg = f"Maximum loops count ({self.max_loops_allowed}) exceeded for component '{name}'"
raise PipelineMaxLoops(msg)
# This component has all the inputs it needs to run
res = comp.run(**last_inputs[name])
self.graph.nodes[name]["visits"] += 1
if not isinstance(res, Mapping):
raise PipelineRuntimeError(

View File

@ -0,0 +1,5 @@
---
enhancements:
- |
Change `Pipeline.run()` to check if `max_loops_allowed` has been reached.
If we attempt to run a Component whose run count has already exceeded `max_loops_allowed`, a `PipelineMaxLoops` exception will be raised.

View File

@ -9,7 +9,7 @@ import pytest
from haystack.core.component import component
from haystack.core.component.types import InputSocket, OutputSocket
from haystack.core.errors import PipelineDrawingError, PipelineError, PipelineRuntimeError
from haystack.core.errors import PipelineDrawingError, PipelineError, PipelineMaxLoops, PipelineRuntimeError
from haystack.core.pipeline import Pipeline
from haystack.testing.factory import component_class
from haystack.testing.sample_components import AddFixedValue, Double
@ -280,6 +280,22 @@ def test_repr_in_notebook(mock_is_in_jupyter):
mock_show.assert_called_once_with()
def test_run_raises_if_max_visits_reached():
    """Check that a looping pipeline raises ``PipelineMaxLoops``.

    Two instances of the same fake component are wired into a cycle
    (``first.a -> second.x`` and ``second.b -> first.y``) inside a pipeline
    created with ``max_loops_allowed=1``; running it is expected to raise.
    """

    def init_sockets(self):
        # Declare the sockets by hand: inputs "x" and "y", outputs "a" and "b".
        # NOTE(review): the trailing ``1`` passed for "y" is presumably its
        # default value -- confirm against ``component.set_input_type``.
        component.set_input_type(self, "x", int)
        component.set_input_type(self, "y", int, 1)
        component.set_output_types(self, a=int, b=int)

    fixed_output = {"a": 1, "b": 1}
    looping_component_cls = component_class(
        "FakeComponent",
        output=fixed_output,
        extra_fields={"__init__": init_sockets},
    )

    pipeline = Pipeline(max_loops_allowed=1)
    for component_name in ("first", "second"):
        pipeline.add_component(component_name, looping_component_cls())

    # Close the cycle: first feeds second, and second feeds back into first.
    pipeline.connect("first.a", "second.x")
    pipeline.connect("second.b", "first.y")

    initial_inputs = {"first": {"x": 1}}
    with pytest.raises(PipelineMaxLoops):
        pipeline.run(initial_inputs)
def test_run_with_component_that_does_not_return_dict():
BrokenComponent = component_class(
"BrokenComponent", input_types={"a": int}, output_types={"b": int}, output=1 # type:ignore