mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-06 14:53:19 +00:00
fix: Fix corner case when running Pipeline that causes it to get stuck in a loop (#7531)
* Fix corner case when running Pipeline that causes it to get stuck in a loop * Update haystack/core/pipeline/pipeline.py Co-authored-by: Massimiliano Pippi <mpippi@gmail.com> --------- Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
This commit is contained in:
parent
b90a005b85
commit
6a8834e43e
@ -846,6 +846,7 @@ class Pipeline:
|
||||
) as span:
|
||||
span.set_content_tag("haystack.component.input", last_inputs[name])
|
||||
|
||||
logger.info("Running component {name}", name=name)
|
||||
res = comp.run(**last_inputs[name])
|
||||
self.graph.nodes[name]["visits"] += 1
|
||||
|
||||
@ -959,6 +960,15 @@ class Pipeline:
|
||||
# There was a lazy variadic or a component with only default waiting for input, we can run it
|
||||
waiting_for_input.remove((name, comp))
|
||||
to_run.append((name, comp))
|
||||
|
||||
# Let's use the default value for the inputs that are still missing, or the component
|
||||
# won't run and will be put back in the waiting list, causing an infinite loop.
|
||||
for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
|
||||
if input_socket.is_mandatory:
|
||||
continue
|
||||
if input_socket.name not in last_inputs[name]:
|
||||
last_inputs[name][input_socket.name] = input_socket.default_value
|
||||
|
||||
continue
|
||||
|
||||
before_last_waiting_for_input = (
|
||||
|
@ -0,0 +1,4 @@
|
||||
---
|
||||
fixes:
|
||||
- |
|
||||
Fix a bug when running a Pipeline that would cause it to get stuck in an infinite loop
|
@ -2,7 +2,7 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -12,6 +12,7 @@ from haystack.components.builders import PromptBuilder
|
||||
from haystack.components.builders.answer_builder import AnswerBuilder
|
||||
from haystack.components.others import Multiplexer
|
||||
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
|
||||
from haystack.components.routers import ConditionalRouter
|
||||
from haystack.core.component import component
|
||||
from haystack.core.component.types import InputSocket, OutputSocket
|
||||
from haystack.core.errors import PipelineDrawingError, PipelineError, PipelineMaxLoops, PipelineRuntimeError
|
||||
@ -916,3 +917,69 @@ def test_pipeline_is_not_stuck_with_components_with_only_defaults():
|
||||
answers = res["answer_builder"]["answers"]
|
||||
assert len(answers) == 1
|
||||
assert answers[0].data == "Paris"
|
||||
|
||||
|
||||
def test_pipeline_is_not_stuck_with_components_with_only_defaults_as_first_components():
|
||||
"""
|
||||
This tests verifies that a Pipeline doesn't get stuck running in a loop if
|
||||
it has all the following characterics:
|
||||
- The first Component has all defaults for its inputs
|
||||
- The first Component receives one input from the user
|
||||
- The first Component receives one input from a loop in the Pipeline
|
||||
- The second Component has at least one default input
|
||||
"""
|
||||
|
||||
def fake_generator_run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
|
||||
# Simple hack to simulate a model returning a different reply after the
|
||||
# the first time it's called
|
||||
if getattr(fake_generator_run, "called", False):
|
||||
return {"replies": ["Rome"]}
|
||||
fake_generator_run.called = True
|
||||
return {"replies": ["Paris"]}
|
||||
|
||||
FakeGenerator = component_class(
|
||||
"FakeGenerator",
|
||||
input_types={"prompt": str, "generation_kwargs": Optional[Dict[str, Any]]},
|
||||
output_types={"replies": List[str]},
|
||||
extra_fields={"run": fake_generator_run},
|
||||
)
|
||||
template = (
|
||||
"Answer the following question.\n"
|
||||
"{% if previous_replies %}\n"
|
||||
"Previously you replied incorrectly this:\n"
|
||||
"{% for reply in previous_replies %}\n"
|
||||
" - {{ reply }}\n"
|
||||
"{% endfor %}\n"
|
||||
"{% endif %}\n"
|
||||
"Question: {{ query }}"
|
||||
)
|
||||
router = ConditionalRouter(
|
||||
routes=[
|
||||
{
|
||||
"condition": "{{ replies == ['Rome'] }}",
|
||||
"output": "{{ replies }}",
|
||||
"output_name": "correct_replies",
|
||||
"output_type": List[int],
|
||||
},
|
||||
{
|
||||
"condition": "{{ replies == ['Paris'] }}",
|
||||
"output": "{{ replies }}",
|
||||
"output_name": "incorrect_replies",
|
||||
"output_type": List[int],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
pipe = Pipeline()
|
||||
|
||||
pipe.add_component("prompt_builder", PromptBuilder(template=template))
|
||||
pipe.add_component("generator", FakeGenerator())
|
||||
pipe.add_component("router", router)
|
||||
|
||||
pipe.connect("prompt_builder.prompt", "generator.prompt")
|
||||
pipe.connect("generator.replies", "router.replies")
|
||||
pipe.connect("router.incorrect_replies", "prompt_builder.previous_replies")
|
||||
|
||||
res = pipe.run({"prompt_builder": {"query": "What is the capital of Italy?"}})
|
||||
|
||||
assert res == {"router": {"correct_replies": ["Rome"]}}
|
||||
|
Loading…
x
Reference in New Issue
Block a user