Fix Pipeline skipping a Component with Variadic input (#8347)

* Fix Pipeline skipping a Component with Variadic input

* Simplify _find_components_that_will_receive_no_input
This commit is contained in:
Silvano Cerza 2024-09-10 14:59:53 +02:00 committed by GitHub
parent 145ca89a3f
commit 4d67b552e1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 109 additions and 6 deletions

View File

@ -1037,18 +1037,37 @@ class PipelineBase:
return name, comp
def _find_components_that_will_receive_no_input(
self, component_name: str, component_result: Dict[str, Any]
self, component_name: str, component_result: Dict[str, Any], components_inputs: Dict[str, Dict[str, Any]]
) -> Set[Tuple[str, Component]]:
"""
Find all the Components that are connected to component_name and didn't receive any input from it.
Components that have a Variadic input and received already some input from other Components
but not from component_name won't be returned as they have enough inputs to run.
This includes the descendants of the Components that didn't receive any input from component_name.
That is necessary to avoid getting stuck into infinite loops waiting for inputs that will never arrive.
:param component_name: Name of the Component that created the output
:param component_result: Output of the Component
:param components_inputs: The current state of the inputs divided by Component name
:return: A set of Components that didn't receive any input from component_name
"""
# Simplifies the check if a Component is Variadic and received some input from other Components.
def is_variadic_with_existing_inputs(comp: Component) -> bool:
for receiver_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
if component_name not in receiver_socket.senders:
continue
if (
receiver_socket.is_variadic
and len(components_inputs.get(receiver, {}).get(receiver_socket.name, [])) > 0
):
# This Component already received some input to its Variadic socket from other Components.
# It should be able to run even if it doesn't receive any input from component_name.
return True
return False
components = set()
instance: Component = self.graph.nodes[component_name]["instance"]
for socket_name, socket in instance.__haystack_output__._sockets_dict.items(): # type: ignore
@ -1056,6 +1075,10 @@ class PipelineBase:
continue
for receiver in socket.receivers:
receiver_instance: Component = self.graph.nodes[receiver]["instance"]
if is_variadic_with_existing_inputs(receiver_instance):
continue
components.add((receiver, receiver_instance))
# Get the descendants too. When we remove a Component that received no input
# it's extremely likely that its descendants will receive no input as well.

View File

@ -240,7 +240,7 @@ class Pipeline(PipelineBase):
# This happens when a component was put in the waiting list but we reached it from another edge.
_dequeue_waiting_component((name, comp), waiting_queue)
for pair in self._find_components_that_will_receive_no_input(name, res):
for pair in self._find_components_that_will_receive_no_input(name, res, components_inputs):
_dequeue_component(pair, run_queue, waiting_queue)
res = self._distribute_output(name, res, components_inputs, run_queue, waiting_queue)

View File

@ -0,0 +1,4 @@
---
fixes:
- |
Fix `Pipeline` skipping a Component with a Variadic input when it has received inputs from only a subset of its senders

View File

@ -39,6 +39,7 @@ Feature: Pipeline running
| that has a loop and a component with default inputs that doesn't receive anything from its sender but receives input from user |
| that has multiple components with only default inputs and are added in a different order from the order of execution |
| that is linear with conditional branching and multiple joins |
| that has a variadic component that receives partial inputs |
Scenario Outline: Running a bad Pipeline
Given a pipeline <kind>

View File

@ -1579,3 +1579,59 @@ def that_is_linear_with_conditional_branching_and_multiple_joins():
),
],
)
@given("a pipeline that has a variadic component that receives partial inputs", target_fixture="pipeline_data")
def that_has_a_variadic_component_that_receives_partial_inputs():
    """
    Build the Pipeline and run data for the "variadic component that receives partial inputs" scenario.

    Three `ConditionalDocumentCreator` components are all connected to the `documents` input of a
    single `DocumentJoiner`. In each run only two of the three creators are asked to create a
    Document; the remaining one returns just its `noop` output. The expected outputs assert that
    the joiner still runs with the two Documents it did receive.

    :return: A tuple with the Pipeline and the list of `PipelineRunData` describing each run
    """

    @component
    class ConditionalDocumentCreator:
        """Returns one Document built from `content` when asked to, otherwise only `noop`."""

        def __init__(self, content: str):
            self._content = content

        @component.output_types(documents=List[Document], noop=None)
        def run(self, create_document: bool = False):
            # Guard clause: nothing to create, so only the `noop` output is returned.
            if not create_document:
                return {"noop": None}
            return {"documents": [Document(id=self._content, content=self._content)]}

    def expected_document(text: str) -> Document:
        # The creators use their content both as the Document id and as its content.
        return Document(id=text, content=text)

    pipe = Pipeline()

    # Registration order is kept: the three creators first, the joiner last.
    for creator_name, creator_content in (
        ("first_creator", "First document"),
        ("second_creator", "Second document"),
        ("third_creator", "Third document"),
    ):
        pipe.add_component(creator_name, ConditionalDocumentCreator(content=creator_content))
    pipe.add_component("documents_joiner", DocumentJoiner())

    # Every creator sends its `documents` output to the same joiner input.
    for sender in (
        "first_creator.documents",
        "second_creator.documents",
        "third_creator.documents",
    ):
        pipe.connect(sender, "documents_joiner.documents")

    # First run: the second creator is not asked to create a Document.
    second_creator_idle = PipelineRunData(
        inputs={"first_creator": {"create_document": True}, "third_creator": {"create_document": True}},
        expected_outputs={
            "second_creator": {"noop": None},
            "documents_joiner": {
                "documents": [
                    expected_document("First document"),
                    expected_document("Third document"),
                ]
            },
        },
        expected_run_order=["first_creator", "third_creator", "second_creator", "documents_joiner"],
    )

    # Second run: the third creator is not asked to create a Document.
    third_creator_idle = PipelineRunData(
        inputs={"first_creator": {"create_document": True}, "second_creator": {"create_document": True}},
        expected_outputs={
            "third_creator": {"noop": None},
            "documents_joiner": {
                "documents": [
                    expected_document("First document"),
                    expected_document("Second document"),
                ]
            },
        },
        expected_run_order=["first_creator", "second_creator", "third_creator", "documents_joiner"],
    )

    return pipe, [second_creator_idle, third_creator_idle]

View File

@ -1138,24 +1138,43 @@ class TestPipeline:
def test__find_components_that_will_receive_no_input(self):
sentence_builder = component_class(
"SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"}
"SentenceBuilder", input_types={"words": List[str]}, output_types={"text": str}
)()
document_builder = component_class(
"DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")}
"DocumentBuilder", input_types={"text": str}, output_types={"doc": Document}
)()
conditional_document_builder = component_class(
"ConditionalDocumentBuilder", output_types={"doc": Document, "noop": None}
)()
document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()
pipe = Pipeline()
pipe.add_component("sentence_builder", sentence_builder)
pipe.add_component("document_builder", document_builder)
pipe.add_component("document_joiner", document_joiner)
pipe.add_component("conditional_document_builder", conditional_document_builder)
pipe.connect("sentence_builder.text", "document_builder.text")
pipe.connect("document_builder.doc", "document_joiner.docs")
pipe.connect("conditional_document_builder.doc", "document_joiner.docs")
res = pipe._find_components_that_will_receive_no_input("sentence_builder", {})
res = pipe._find_components_that_will_receive_no_input("sentence_builder", {}, {})
assert res == {("document_builder", document_builder), ("document_joiner", document_joiner)}
res = pipe._find_components_that_will_receive_no_input("sentence_builder", {"text": "some text"})
res = pipe._find_components_that_will_receive_no_input("sentence_builder", {"text": "some text"}, {})
assert res == set()
res = pipe._find_components_that_will_receive_no_input("conditional_document_builder", {"noop": None}, {})
assert res == {("document_joiner", document_joiner)}
res = pipe._find_components_that_will_receive_no_input(
"conditional_document_builder", {"noop": None}, {"document_joiner": {"docs": []}}
)
assert res == {("document_joiner", document_joiner)}
res = pipe._find_components_that_will_receive_no_input(
"conditional_document_builder", {"noop": None}, {"document_joiner": {"docs": [Document("some text")]}}
)
assert res == set()
def test__distribute_output(self):