diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index 9a61b70dd..ca3c683b9 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -846,6 +846,7 @@ class Pipeline: ) as span: span.set_content_tag("haystack.component.input", last_inputs[name]) + logger.info("Running component {name}", name=name) res = comp.run(**last_inputs[name]) self.graph.nodes[name]["visits"] += 1 @@ -959,6 +960,15 @@ class Pipeline: # There was a lazy variadic or a component with only default waiting for input, we can run it waiting_for_input.remove((name, comp)) to_run.append((name, comp)) + + # Let's use the default value for the inputs that are still missing, or the component + # won't run and will be put back in the waiting list, causing an infinite loop. + for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore + if input_socket.is_mandatory: + continue + if input_socket.name not in last_inputs[name]: + last_inputs[name][input_socket.name] = input_socket.default_value + continue before_last_waiting_for_input = ( diff --git a/releasenotes/notes/fix-pipeline-run-loop-99f7ff9db16544d4.yaml b/releasenotes/notes/fix-pipeline-run-loop-99f7ff9db16544d4.yaml new file mode 100644 index 000000000..61ad81c78 --- /dev/null +++ b/releasenotes/notes/fix-pipeline-run-loop-99f7ff9db16544d4.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Fix a bug when running a Pipeline that would cause it to get stuck in an infinite loop diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index 95dae8151..1e34a9292 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 import logging -from typing import List, Optional +from typing import Any, Dict, List, Optional from unittest.mock import patch import pytest @@ -12,6 +12,7 @@ from haystack.components.builders import PromptBuilder from 
haystack.components.builders.answer_builder import AnswerBuilder
 from haystack.components.others import Multiplexer
 from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
+from haystack.components.routers import ConditionalRouter
 from haystack.core.component import component
 from haystack.core.component.types import InputSocket, OutputSocket
 from haystack.core.errors import PipelineDrawingError, PipelineError, PipelineMaxLoops, PipelineRuntimeError
@@ -916,3 +917,69 @@ def test_pipeline_is_not_stuck_with_components_with_only_defaults():
     answers = res["answer_builder"]["answers"]
     assert len(answers) == 1
     assert answers[0].data == "Paris"
+
+
+def test_pipeline_is_not_stuck_with_components_with_only_defaults_as_first_components():
+    """
+    This test verifies that a Pipeline doesn't get stuck running in a loop if
+    it has all the following characteristics:
+    - The first Component has all defaults for its inputs
+    - The first Component receives one input from the user
+    - The first Component receives one input from a loop in the Pipeline
+    - The second Component has at least one default input
+    """
+
+    def fake_generator_run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
+        # Simple hack to simulate a model returning a different reply after the
+        # first time it's called
+        if getattr(fake_generator_run, "called", False):
+            return {"replies": ["Rome"]}
+        fake_generator_run.called = True
+        return {"replies": ["Paris"]}
+
+    FakeGenerator = component_class(
+        "FakeGenerator",
+        input_types={"prompt": str, "generation_kwargs": Optional[Dict[str, Any]]},
+        output_types={"replies": List[str]},
+        extra_fields={"run": fake_generator_run},
+    )
+    template = (
+        "Answer the following question.\n"
+        "{% if previous_replies %}\n"
+        "Previously you replied incorrectly this:\n"
+        "{% for reply in previous_replies %}\n"
+        " - {{ reply }}\n"
+        "{% endfor %}\n"
+        "{% endif %}\n"
+        "Question: {{ query }}"
+    )
+    router = ConditionalRouter(
+        routes=[
+            {
+                "condition": "{{ replies == ['Rome'] }}",
+                "output": "{{ replies }}",
+                "output_name": "correct_replies",
+                "output_type": List[str],
+            },
+            {
+                "condition": "{{ replies == ['Paris'] }}",
+                "output": "{{ replies }}",
+                "output_name": "incorrect_replies",
+                "output_type": List[str],
+            },
+        ]
+    )
+
+    pipe = Pipeline()
+
+    pipe.add_component("prompt_builder", PromptBuilder(template=template))
+    pipe.add_component("generator", FakeGenerator())
+    pipe.add_component("router", router)
+
+    pipe.connect("prompt_builder.prompt", "generator.prompt")
+    pipe.connect("generator.replies", "router.replies")
+    pipe.connect("router.incorrect_replies", "prompt_builder.previous_replies")
+
+    res = pipe.run({"prompt_builder": {"query": "What is the capital of Italy?"}})
+
+    assert res == {"router": {"correct_replies": ["Rome"]}}