fix: Fix running Pipeline with conditional branch and Component with default inputs (#7799)

* Fix running Pipeline with conditional branch and Component with default inputs

* Add release notes

* Change arg name of _init_to_run so it's clearer

* Enhance release note
This commit is contained in:
Silvano Cerza 2024-06-06 15:19:07 +02:00 committed by GitHub
parent ce9b0ecb19
commit 3c8569e12c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 249 additions and 9 deletions

View File

@ -719,7 +719,7 @@ class PipelineBase:
return {**data}
def _init_to_run(self) -> List[Tuple[str, Component]]:
def _init_to_run(self, pipeline_inputs: Dict[str, Any]) -> List[Tuple[str, Component]]:
to_run: List[Tuple[str, Component]] = []
for node_name in self.graph.nodes:
component = self.graph.nodes[node_name]["instance"]
@ -729,6 +729,11 @@ class PipelineBase:
to_run.append((node_name, component))
continue
if node_name in pipeline_inputs:
# This component is in the input data, if it has enough inputs it can run right away
to_run.append((node_name, component))
continue
for socket in component.__haystack_input__._sockets_dict.values():
if not socket.senders or socket.is_variadic:
# Component has at least one input not connected or is variadic, can run right away.

View File

@ -112,9 +112,12 @@ class Pipeline(PipelineBase):
# Initialize the inputs state
last_inputs: Dict[str, Dict[str, Any]] = self._init_inputs_state(data)
# Take all components that have at least 1 input not connected or is variadic,
# and all components that have no inputs at all
to_run: List[Tuple[str, Component]] = self._init_to_run()
# Take all components that:
# - have no inputs
# - receive input from the user
# - have at least one input not connected
# - have at least one input that is variadic
to_run: List[Tuple[str, Component]] = self._init_to_run(data)
# These variables are used to detect when we're stuck in a loop.
# Stuck loops can happen when one or more components are waiting for input but
@ -232,8 +235,15 @@ class Pipeline(PipelineBase):
if name != sender_component_name:
continue
pair = (receiver_component_name, self.graph.nodes[receiver_component_name]["instance"])
if edge_data["from_socket"].name not in res:
# This output has not been produced by the component, skip it
# The component didn't produce any output for this socket.
# We can't run the receiver, let's remove it from the list of components to run
# or we risk running it if it's in those lists.
if pair in to_run:
to_run.remove(pair)
if pair in waiting_for_input:
waiting_for_input.remove(pair)
continue
if receiver_component_name not in last_inputs:
@ -249,7 +259,6 @@ class Pipeline(PipelineBase):
else:
last_inputs[receiver_component_name][edge_data["to_socket"].name] = value
pair = (receiver_component_name, self.graph.nodes[receiver_component_name]["instance"])
is_greedy = pair[1].__haystack_is_greedy__
is_variadic = edge_data["to_socket"].is_variadic
if is_variadic and is_greedy:

View File

@ -0,0 +1,7 @@
---
fixes:
- |
Fix some bugs running a Pipeline that has Components with conditional outputs.
Some branches that were expected not to run would run anyway, even if they received no inputs.
Some branches instead would cause the Pipeline to get stuck waiting to run that branch, even if they received no inputs.
The behaviour would depend whether the Component not receiving the input has a optional input or not.

View File

@ -34,6 +34,9 @@ Feature: Pipeline running
| that is linear and returns intermediate outputs |
| that has a loop and returns intermediate outputs from it |
| that is linear and returns intermediate outputs from multiple sockets |
| that has a component with default inputs that doesn't receive anything from its sender |
| that has a component with default inputs that doesn't receive anything from its sender but receives input from user |
| that has a loop and a component with default inputs that doesn't receive anything from its sender but receives input from user |
Scenario Outline: Running a bad Pipeline
Given a pipeline <kind>

View File

@ -1161,3 +1161,215 @@ def pipeline_that_is_linear_and_returns_intermediate_outputs_from_multiple_socke
),
],
)
@given(
"a pipeline that has a component with default inputs that doesn't receive anything from its sender",
target_fixture="pipeline_data",
)
def pipeline_that_has_a_component_with_default_inputs_that_doesnt_receive_anything_from_its_sender():
routes = [
{"condition": "{{'reisen' in sentence}}", "output": "German", "output_name": "language_1", "output_type": str},
{"condition": "{{'viajar' in sentence}}", "output": "Spanish", "output_name": "language_2", "output_type": str},
]
router = ConditionalRouter(routes)
pipeline = Pipeline()
pipeline.add_component("router", router)
pipeline.add_component("pb", PromptBuilder(template="Ok, I know, that's {{language}}"))
pipeline.connect("router.language_2", "pb.language")
return (
pipeline,
[
PipelineRunData(
inputs={"router": {"sentence": "Wir mussen reisen"}},
expected_outputs={"router": {"language_1": "German"}},
expected_run_order=["router"],
),
PipelineRunData(
inputs={"router": {"sentence": "Yo tengo que viajar"}},
expected_outputs={"pb": {"prompt": "Ok, I know, that's Spanish"}},
expected_run_order=["router", "pb"],
),
],
)
@given(
"a pipeline that has a component with default inputs that doesn't receive anything from its sender but receives input from user",
target_fixture="pipeline_data",
)
def pipeline_that_has_a_component_with_default_inputs_that_doesnt_receive_anything_from_its_sender_but_receives_input_from_user():
prompt = PromptBuilder(
template="""Please generate an SQL query. The query should answer the following Question: {{ question }};
If the question cannot be answered given the provided table and columns, return 'no_answer'
The query is to be answered for the table is called 'absenteeism' with the following
Columns: {{ columns }};
Answer:"""
)
@component
class FakeGenerator:
@component.output_types(replies=List[str])
def run(self, prompt: str):
if "no_answer" in prompt:
return {"replies": ["There's simply no_answer to this question"]}
return {"replies": ["Some SQL query"]}
@component
class FakeSQLQuerier:
@component.output_types(results=str)
def run(self, query: str):
return {"results": "This is the query result", "query": query}
llm = FakeGenerator()
sql_querier = FakeSQLQuerier()
routes = [
{
"condition": "{{'no_answer' not in replies[0]}}",
"output": "{{replies[0]}}",
"output_name": "sql",
"output_type": str,
},
{
"condition": "{{'no_answer' in replies[0]}}",
"output": "{{question}}",
"output_name": "go_to_fallback",
"output_type": str,
},
]
router = ConditionalRouter(routes)
fallback_prompt = PromptBuilder(
template="""User entered a query that cannot be answered with the given table.
The query was: {{ question }} and the table had columns: {{ columns }}.
Let the user know why the question cannot be answered"""
)
fallback_llm = FakeGenerator()
pipeline = Pipeline()
pipeline.add_component("prompt", prompt)
pipeline.add_component("llm", llm)
pipeline.add_component("router", router)
pipeline.add_component("fallback_prompt", fallback_prompt)
pipeline.add_component("fallback_llm", fallback_llm)
pipeline.add_component("sql_querier", sql_querier)
pipeline.connect("prompt", "llm")
pipeline.connect("llm.replies", "router.replies")
pipeline.connect("router.sql", "sql_querier.query")
pipeline.connect("router.go_to_fallback", "fallback_prompt.question")
pipeline.connect("fallback_prompt", "fallback_llm")
columns = "Age, Absenteeism_time_in_hours, Days, Disciplinary_failure"
return (
pipeline,
[
PipelineRunData(
inputs={
"prompt": {"question": "This is a question with no_answer", "columns": columns},
"router": {"question": "This is a question with no_answer"},
},
expected_outputs={"fallback_llm": {"replies": ["There's simply no_answer to this question"]}},
expected_run_order=["prompt", "llm", "router", "fallback_prompt", "fallback_llm"],
)
],
[
PipelineRunData(
inputs={
"prompt": {"question": "This is a question that has an answer", "columns": columns},
"router": {"question": "This is a question that has an answer"},
},
expected_outputs={"sql_querier": {"results": "This is the query result", "query": "Some SQL query"}},
expected_run_order=["prompt", "llm", "router", "sql_querier"],
)
],
)
@given(
"a pipeline that has a loop and a component with default inputs that doesn't receive anything from its sender but receives input from user",
target_fixture="pipeline_data",
)
def pipeline_that_has_a_loop_and_a_component_with_default_inputs_that_doesnt_receive_anything_from_its_sender_but_receives_input_from_user():
template = """
You are an experienced and accurate Turkish CX speacialist that classifies customer comments into pre-defined categories below:\n
Negative experience labels:
- Late delivery
- Rotten/spoilt item
- Bad Courier behavior
Positive experience labels:
- Good courier behavior
- Thanks & appreciation
- Love message to courier
- Fast delivery
- Quality of products
Create a JSON object as a response. The fields are: 'positive_experience', 'negative_experience'.
Assign at least one of the pre-defined labels to the given customer comment under positive and negative experience fields.
If the comment has a positive experience, list the label under 'positive_experience' field.
If the comments has a negative_experience, list it under the 'negative_experience' field.
Here is the comment:\n{{ comment }}\n. Just return the category names in the list. If there aren't any, return an empty list.
{% if invalid_replies and error_message %}
You already created the following output in a previous attempt: {{ invalid_replies }}
However, this doesn't comply with the format requirements from above and triggered this Python exception: {{ error_message }}
Correct the output and try again. Just return the corrected output without any extra explanations.
{% endif %}
"""
prompt_builder = PromptBuilder(template=template)
@component
class FakeOutputValidator:
@component.output_types(
valid_replies=List[str], invalid_replies=Optional[List[str]], error_message=Optional[str]
)
def run(self, replies: List[str]):
if not getattr(self, "called", False):
self.called = True
return {"invalid_replies": ["This is an invalid reply"], "error_message": "this is an error message"}
return {"valid_replies": replies}
@component
class FakeGenerator:
@component.output_types(replies=List[str])
def run(self, prompt: str):
return {"replies": ["This is a valid reply"]}
llm = FakeGenerator()
validator = FakeOutputValidator()
pipeline = Pipeline()
pipeline.add_component("prompt_builder", prompt_builder)
pipeline.add_component("llm", llm)
pipeline.add_component("output_validator", validator)
pipeline.connect("prompt_builder.prompt", "llm.prompt")
pipeline.connect("llm.replies", "output_validator.replies")
pipeline.connect("output_validator.invalid_replies", "prompt_builder.invalid_replies")
pipeline.connect("output_validator.error_message", "prompt_builder.error_message")
comment = "I loved the quality of the meal but the courier was rude"
return (
pipeline,
[
PipelineRunData(
inputs={"prompt_builder": {"template_variables": {"comment": comment}}},
expected_outputs={"output_validator": {"valid_replies": ["This is a valid reply"]}},
expected_run_order=[
"prompt_builder",
"llm",
"output_validator",
"prompt_builder",
"llm",
"output_validator",
],
)
],
)

View File

@ -692,19 +692,23 @@ class TestPipeline:
pipe.add_component("with_no_inputs", ComponentWithNoInputs())
pipe.add_component("with_single_input", ComponentWithSingleInput())
pipe.add_component("another_with_single_input", ComponentWithSingleInput())
pipe.add_component("yet_another_with_single_input", ComponentWithSingleInput())
pipe.add_component("with_multiple_inputs", ComponentWithMultipleInputs())
pipe.connect("yet_another_with_single_input.out", "with_variadic.in")
pipe.connect("with_no_inputs.out", "with_variadic.in")
pipe.connect("with_single_input.out", "another_with_single_input.in")
pipe.connect("another_with_single_input.out", "with_multiple_inputs.in1")
pipe.connect("with_multiple_inputs.out", "with_variadic.in")
to_run = pipe._init_to_run()
assert len(to_run) == 4
data = {"yet_another_with_single_input": {"in": 1}}
to_run = pipe._init_to_run(data)
assert len(to_run) == 5
assert to_run[0][0] == "with_variadic"
assert to_run[1][0] == "with_no_inputs"
assert to_run[2][0] == "with_single_input"
assert to_run[3][0] == "with_multiple_inputs"
assert to_run[3][0] == "yet_another_with_single_input"
assert to_run[4][0] == "with_multiple_inputs"
def test__init_inputs_state(self):
pipe = Pipeline()