fix: Fix corner case when running Pipeline that causes it to get stuck in a loop (#7531)

* Fix corner case when running Pipeline that causes it to get stuck in a loop

* Update haystack/core/pipeline/pipeline.py

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>

---------

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
This commit is contained in:
Silvano Cerza 2024-04-11 16:39:38 +02:00 committed by GitHub
parent b90a005b85
commit 6a8834e43e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 82 additions and 1 deletions

View File

@ -846,6 +846,7 @@ class Pipeline:
) as span:
span.set_content_tag("haystack.component.input", last_inputs[name])
logger.info("Running component {name}", name=name)
res = comp.run(**last_inputs[name])
self.graph.nodes[name]["visits"] += 1
@ -959,6 +960,15 @@ class Pipeline:
# There was a lazy variadic or a component with only default waiting for input, we can run it
waiting_for_input.remove((name, comp))
to_run.append((name, comp))
# Let's use the default value for the inputs that are still missing, or the component
# won't run and will be put back in the waiting list, causing an infinite loop.
for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
if input_socket.is_mandatory:
continue
if input_socket.name not in last_inputs[name]:
last_inputs[name][input_socket.name] = input_socket.default_value
continue
before_last_waiting_for_input = (

View File

@ -0,0 +1,4 @@
---
fixes:
- |
Fix a bug when running a Pipeline that would cause it to get stuck in an infinite loop

View File

@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import List, Optional
from typing import Any, Dict, List, Optional
from unittest.mock import patch
import pytest
@ -12,6 +12,7 @@ from haystack.components.builders import PromptBuilder
from haystack.components.builders.answer_builder import AnswerBuilder
from haystack.components.others import Multiplexer
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.components.routers import ConditionalRouter
from haystack.core.component import component
from haystack.core.component.types import InputSocket, OutputSocket
from haystack.core.errors import PipelineDrawingError, PipelineError, PipelineMaxLoops, PipelineRuntimeError
@ -916,3 +917,69 @@ def test_pipeline_is_not_stuck_with_components_with_only_defaults():
answers = res["answer_builder"]["answers"]
assert len(answers) == 1
assert answers[0].data == "Paris"
def test_pipeline_is_not_stuck_with_components_with_only_defaults_as_first_components():
"""
This tests verifies that a Pipeline doesn't get stuck running in a loop if
it has all the following characterics:
- The first Component has all defaults for its inputs
- The first Component receives one input from the user
- The first Component receives one input from a loop in the Pipeline
- The second Component has at least one default input
"""
def fake_generator_run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
# Simple hack to simulate a model returning a different reply after the
# the first time it's called
if getattr(fake_generator_run, "called", False):
return {"replies": ["Rome"]}
fake_generator_run.called = True
return {"replies": ["Paris"]}
FakeGenerator = component_class(
"FakeGenerator",
input_types={"prompt": str, "generation_kwargs": Optional[Dict[str, Any]]},
output_types={"replies": List[str]},
extra_fields={"run": fake_generator_run},
)
template = (
"Answer the following question.\n"
"{% if previous_replies %}\n"
"Previously you replied incorrectly this:\n"
"{% for reply in previous_replies %}\n"
" - {{ reply }}\n"
"{% endfor %}\n"
"{% endif %}\n"
"Question: {{ query }}"
)
router = ConditionalRouter(
routes=[
{
"condition": "{{ replies == ['Rome'] }}",
"output": "{{ replies }}",
"output_name": "correct_replies",
"output_type": List[int],
},
{
"condition": "{{ replies == ['Paris'] }}",
"output": "{{ replies }}",
"output_name": "incorrect_replies",
"output_type": List[int],
},
]
)
pipe = Pipeline()
pipe.add_component("prompt_builder", PromptBuilder(template=template))
pipe.add_component("generator", FakeGenerator())
pipe.add_component("router", router)
pipe.connect("prompt_builder.prompt", "generator.prompt")
pipe.connect("generator.replies", "router.replies")
pipe.connect("router.incorrect_replies", "prompt_builder.previous_replies")
res = pipe.run({"prompt_builder": {"query": "What is the capital of Italy?"}})
assert res == {"router": {"correct_replies": ["Rome"]}}