fix: Fix corner case when running Pipeline that causes it to get stuck in a loop (#7531)

* Fix corner case when running Pipeline that causes it to get stuck in a loop

* Update haystack/core/pipeline/pipeline.py

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>

---------

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
This commit is contained in:
Silvano Cerza 2024-04-11 16:39:38 +02:00 committed by GitHub
parent b90a005b85
commit 6a8834e43e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 82 additions and 1 deletions

View File

@ -846,6 +846,7 @@ class Pipeline:
) as span: ) as span:
span.set_content_tag("haystack.component.input", last_inputs[name]) span.set_content_tag("haystack.component.input", last_inputs[name])
logger.info("Running component {name}", name=name)
res = comp.run(**last_inputs[name]) res = comp.run(**last_inputs[name])
self.graph.nodes[name]["visits"] += 1 self.graph.nodes[name]["visits"] += 1
@ -959,6 +960,15 @@ class Pipeline:
# There was a lazy variadic or a component with only default waiting for input, we can run it # There was a lazy variadic or a component with only default waiting for input, we can run it
waiting_for_input.remove((name, comp)) waiting_for_input.remove((name, comp))
to_run.append((name, comp)) to_run.append((name, comp))
# Let's use the default value for the inputs that are still missing, or the component
# won't run and will be put back in the waiting list, causing an infinite loop.
for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
if input_socket.is_mandatory:
continue
if input_socket.name not in last_inputs[name]:
last_inputs[name][input_socket.name] = input_socket.default_value
continue continue
before_last_waiting_for_input = ( before_last_waiting_for_input = (

View File

@ -0,0 +1,4 @@
---
fixes:
- |
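Fix a bug that could cause a Pipeline to get stuck in an infinite loop while running: a waiting Component that was scheduled to run but was still missing some optional inputs never ran and was put back in the waiting list. Missing optional inputs now fall back to their default values.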

View File

@ -2,7 +2,7 @@
# #
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import logging import logging
from typing import List, Optional from typing import Any, Dict, List, Optional
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -12,6 +12,7 @@ from haystack.components.builders import PromptBuilder
from haystack.components.builders.answer_builder import AnswerBuilder from haystack.components.builders.answer_builder import AnswerBuilder
from haystack.components.others import Multiplexer from haystack.components.others import Multiplexer
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.components.routers import ConditionalRouter
from haystack.core.component import component from haystack.core.component import component
from haystack.core.component.types import InputSocket, OutputSocket from haystack.core.component.types import InputSocket, OutputSocket
from haystack.core.errors import PipelineDrawingError, PipelineError, PipelineMaxLoops, PipelineRuntimeError from haystack.core.errors import PipelineDrawingError, PipelineError, PipelineMaxLoops, PipelineRuntimeError
@ -916,3 +917,69 @@ def test_pipeline_is_not_stuck_with_components_with_only_defaults():
answers = res["answer_builder"]["answers"] answers = res["answer_builder"]["answers"]
assert len(answers) == 1 assert len(answers) == 1
assert answers[0].data == "Paris" assert answers[0].data == "Paris"
def test_pipeline_is_not_stuck_with_components_with_only_defaults_as_first_components():
    """
    This test verifies that a Pipeline doesn't get stuck running in a loop if
    it has all the following characteristics:
    - The first Component has all defaults for its inputs
    - The first Component receives one input from the user
    - The first Component receives one input from a loop in the Pipeline
    - The second Component has at least one default input
    """

    def fake_generator_run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
        # Simple hack to simulate a model returning a different reply after
        # the first time it's called: the "called" flag is stored on the
        # function object itself.
        if getattr(fake_generator_run, "called", False):
            return {"replies": ["Rome"]}
        fake_generator_run.called = True
        return {"replies": ["Paris"]}

    FakeGenerator = component_class(
        "FakeGenerator",
        input_types={"prompt": str, "generation_kwargs": Optional[Dict[str, Any]]},
        output_types={"replies": List[str]},
        extra_fields={"run": fake_generator_run},
    )

    template = (
        "Answer the following question.\n"
        "{% if previous_replies %}\n"
        "Previously you replied incorrectly this:\n"
        "{% for reply in previous_replies %}\n"
        " - {{ reply }}\n"
        "{% endfor %}\n"
        "{% endif %}\n"
        "Question: {{ query }}"
    )

    # The routed value is the generator's `replies`, a list of strings, so the
    # declared output type of both routes is List[str].
    router = ConditionalRouter(
        routes=[
            {
                "condition": "{{ replies == ['Rome'] }}",
                "output": "{{ replies }}",
                "output_name": "correct_replies",
                "output_type": List[str],
            },
            {
                "condition": "{{ replies == ['Paris'] }}",
                "output": "{{ replies }}",
                "output_name": "incorrect_replies",
                "output_type": List[str],
            },
        ]
    )

    pipe = Pipeline()
    pipe.add_component("prompt_builder", PromptBuilder(template=template))
    pipe.add_component("generator", FakeGenerator())
    pipe.add_component("router", router)

    pipe.connect("prompt_builder.prompt", "generator.prompt")
    pipe.connect("generator.replies", "router.replies")
    # Close the loop: an incorrect reply is fed back into the prompt builder.
    pipe.connect("router.incorrect_replies", "prompt_builder.previous_replies")

    # Only `query` is provided by the user; `previous_replies` arrives from the loop.
    res = pipe.run({"prompt_builder": {"query": "What is the capital of Italy?"}})

    # The first reply ("Paris") is routed back, the second ("Rome") leaves the Pipeline.
    assert res == {"router": {"correct_replies": ["Rome"]}}