fix: component checks failing for components that return dataframes (#8873)

* fix: use is not to compare to sentinel value

* chore: release notes

* Update releasenotes/notes/fix-component-checks-with-ambiguous-truth-values-949c447b3702e427.yaml

Co-authored-by: David S. Batista <dsbatista@gmail.com>

* fix: another sentinel value

* test: also test base class

* add pandas as test dependency

* format

* Trigger CI

* mark test with xfail strict=False

---------

Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>
Co-authored-by: David S. Batista <dsbatista@gmail.com>
Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
This commit is contained in:
mathislucka 2025-02-19 10:10:48 +01:00 committed by GitHub
parent 93f361e1e1
commit 8c54f06a19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 115 additions and 9 deletions

View File

@ -913,7 +913,7 @@ class PipelineBase:
greedy_inputs_to_remove = set()
for socket_name, socket in component["input_sockets"].items():
socket_inputs = component_inputs.get(socket_name, [])
socket_inputs = [sock["value"] for sock in socket_inputs if sock["value"] != _NO_OUTPUT_PRODUCED]
socket_inputs = [sock["value"] for sock in socket_inputs if sock["value"] is not _NO_OUTPUT_PRODUCED]
if socket_inputs:
if not socket.is_variadic:
# We only care about the first input provided to the socket.

View File

@ -103,7 +103,7 @@ def any_socket_value_from_predecessor_received(socket_inputs: List[Dict[str, Any
:param socket_inputs: Inputs for the component's socket.
"""
# When sender is None, the input was provided from outside the pipeline.
return any(inp["value"] != _NO_OUTPUT_PRODUCED and inp["sender"] is not None for inp in socket_inputs)
return any(inp["value"] is not _NO_OUTPUT_PRODUCED and inp["sender"] is not None for inp in socket_inputs)
def has_user_input(inputs: Dict) -> bool:
@ -143,7 +143,7 @@ def any_socket_input_received(socket_inputs: List[Dict]) -> bool:
:param socket_inputs: Inputs for the socket.
"""
return any(inp["value"] != _NO_OUTPUT_PRODUCED for inp in socket_inputs)
return any(inp["value"] is not _NO_OUTPUT_PRODUCED for inp in socket_inputs)
def has_lazy_variadic_socket_received_all_inputs(socket: InputSocket, socket_inputs: List[Dict]) -> bool:
@ -155,7 +155,9 @@ def has_lazy_variadic_socket_received_all_inputs(socket: InputSocket, socket_inp
"""
expected_senders = set(socket.senders)
actual_senders = {
sock["sender"] for sock in socket_inputs if sock["value"] != _NO_OUTPUT_PRODUCED and sock["sender"] is not None
sock["sender"]
for sock in socket_inputs
if sock["value"] is not _NO_OUTPUT_PRODUCED and sock["sender"] is not None
}
return expected_senders == actual_senders
@ -182,7 +184,11 @@ def has_socket_received_all_inputs(socket: InputSocket, socket_inputs: List[Dict
return False
# The socket is greedy variadic and at least one input was produced, it is complete.
if socket.is_variadic and socket.is_greedy and any(sock["value"] != _NO_OUTPUT_PRODUCED for sock in socket_inputs):
if (
socket.is_variadic
and socket.is_greedy
and any(sock["value"] is not _NO_OUTPUT_PRODUCED for sock in socket_inputs)
):
return True
# The socket is lazy variadic and all expected inputs were produced.
@ -190,7 +196,7 @@ def has_socket_received_all_inputs(socket: InputSocket, socket_inputs: List[Dict
return True
# The socket is not variadic and the only expected input is complete.
return not socket.is_variadic and socket_inputs[0]["value"] != _NO_OUTPUT_PRODUCED
return not socket.is_variadic and socket_inputs[0]["value"] is not _NO_OUTPUT_PRODUCED
def all_predecessors_executed(component: Dict, inputs: Dict) -> bool:

View File

@ -93,6 +93,7 @@ extra-dependencies = [
"langdetect", # TextLanguageRouter and DocumentLanguageClassifier
"openai-whisper>=20231106", # LocalWhisperTranscriber
"arrow>=1.3.0", # Jinja2TimeExtension
"pandas", # Needed for pipeline tests with components that return dataframes
# NamedEntityExtractor
"spacy>=3.8,<3.9",

View File

@ -0,0 +1,5 @@
---
fixes:
- |
Pipelines with components that return plain pandas dataframes failed.
The comparison of socket values is now 'is not' instead of '!=' to avoid errors with dataframes.

View File

@ -583,7 +583,9 @@ class TestHuggingFaceAPIChatGenerator:
not os.environ.get("HF_API_TOKEN", None),
reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
)
@pytest.mark.flaky(reruns=3, reruns_delay=10)
@pytest.mark.xfail(
reason="The Hugging Face API can be unstable and this test may fail intermittently", strict=False
)
def test_live_run_with_tools(self, tools):
"""
We test the round trip: generate tool call, pass tool message, generate response.

View File

@ -4,6 +4,7 @@ from pathlib import Path
import re
import pytest
import asyncio
import pandas as pd
from pytest_bdd import when, then, parsers
@ -142,15 +143,39 @@ def draw_pipeline(pipeline_data: Tuple[Pipeline, List[PipelineRunData]], request
@then("it should return the expected result")
def check_pipeline_result(pipeline_result: List[Tuple[_PipelineResult, PipelineRunData]]):
for res, data in pipeline_result:
assert res.outputs == data.expected_outputs
compare_outputs_with_dataframes(res.outputs, data.expected_outputs)
@then("components are called with the expected inputs")
def check_component_calls(pipeline_result: List[Tuple[_PipelineResult, PipelineRunData]]):
for res, data in pipeline_result:
assert res.component_calls == data.expected_component_calls
assert compare_outputs_with_dataframes(res.component_calls, data.expected_component_calls)
@then(parsers.parse("it must have raised {exception_class_name}"))
def check_pipeline_raised(pipeline_result: Exception, exception_class_name: str):
assert pipeline_result.__class__.__name__ == exception_class_name
def compare_outputs_with_dataframes(actual: Dict, expected: Dict) -> bool:
"""
Compare two component_calls or pipeline outputs dictionaries where values may contain DataFrames.
"""
assert actual.keys() == expected.keys()
for key in actual:
actual_data = actual[key]
expected_data = expected[key]
assert actual_data.keys() == expected_data.keys()
for data_key in actual_data:
actual_value = actual_data[data_key]
expected_value = expected_data[data_key]
if isinstance(actual_value, pd.DataFrame) and isinstance(expected_value, pd.DataFrame):
assert actual_value.equals(expected_value)
else:
assert actual_value == expected_value
return True

View File

@ -52,6 +52,7 @@ Feature: Pipeline running
| with a component that has dynamic default inputs |
| with a component that has variadic dynamic default inputs |
| that is a file conversion pipeline with two joiners |
| that has components returning dataframes |
Scenario Outline: Running a bad Pipeline
Given a pipeline <kind>

View File

@ -4,6 +4,7 @@ import re
from pytest_bdd import scenarios, given
import pytest
import pandas as pd
from haystack import Document, component
from haystack.document_stores.types import DuplicatePolicy
@ -5079,3 +5080,33 @@ some,header,row
)
],
)
@given("a pipeline that has components returning dataframes", target_fixture="pipeline_data")
def pipeline_has_components_returning_dataframes(pipeline_class):
def get_df():
return pd.DataFrame({"a": [1, 2], "b": [1, 2]})
@component
class DataFramer:
@component.output_types(dataframe=pd.DataFrame)
def run(self, dataframe: pd.DataFrame) -> Dict[str, Any]:
return {"dataframe": get_df()}
pp = pipeline_class(max_runs_per_component=1)
pp.add_component("df_1", DataFramer())
pp.add_component("df_2", DataFramer())
pp.connect("df_1", "df_2")
return (
pp,
[
PipelineRunData(
inputs={"df_1": {"dataframe": get_df()}},
expected_outputs={"df_2": {"dataframe": get_df()}},
expected_component_calls={("df_1", 1): {"dataframe": get_df()}, ("df_2", 1): {"dataframe": get_df()}},
)
],
)

View File

@ -9,6 +9,9 @@ from haystack.core.pipeline.component_checks import _NO_OUTPUT_PRODUCED
from haystack.core.component.types import InputSocket, OutputSocket, Variadic, GreedyVariadic
import pandas as pd
@pytest.fixture
def basic_component():
"""Basic component with one mandatory and one optional input."""
@ -130,6 +133,26 @@ class TestCanComponentRun:
inputs = {"optional_input": [{"sender": "previous_component", "value": "test"}]}
assert can_component_run(basic_component, inputs) is False
# We added these tests because a component that returned a pandas dataframe caused the pipeline to fail.
# Previously, we compared the value of the socket using '!=' which leads to an error with dataframes.
# Instead, we use 'is not' to compare with the sentinel value.
def test_sockets_with_ambiguous_truth_value(self, basic_component, greedy_variadic_socket, regular_socket):
inputs = {
"mandatory_input": [{"sender": "previous_component", "value": pd.DataFrame.from_dict([{"value": 42}])}]
}
assert are_all_sockets_ready(basic_component, inputs, only_check_mandatory=True) is True
assert any_socket_value_from_predecessor_received(inputs["mandatory_input"]) is True
assert any_socket_input_received(inputs["mandatory_input"]) is True
assert (
has_lazy_variadic_socket_received_all_inputs(
basic_component["input_sockets"]["mandatory_input"], inputs["mandatory_input"]
)
is True
)
assert has_socket_received_all_inputs(greedy_variadic_socket, inputs["mandatory_input"]) is True
assert has_socket_received_all_inputs(regular_socket, inputs["mandatory_input"]) is True
def test_component_with_no_trigger_but_all_inputs(self, basic_component):
"""
Test case where all mandatory inputs are present with valid values,

View File

@ -8,6 +8,8 @@ from unittest.mock import patch
import pytest
import pandas as pd
from haystack import Document
from haystack.core.component import component
from haystack.core.component.types import InputSocket, OutputSocket, Variadic, GreedyVariadic, _empty
@ -1625,3 +1627,13 @@ class TestPipelineBase:
# Verify
assert consumed == expected_consumed
assert inputs["test_component"] == expected_remaining
def test__consume_component_inputs_with_df(self, regular_input_socket):
component = {"input_sockets": {"input1": regular_input_socket}}
inputs = {
"test_component": {"input1": [{"sender": "sender1", "value": pd.DataFrame({"a": [1, 2], "b": [1, 2]})}]}
}
consumed = PipelineBase._consume_component_inputs("test_component", component, inputs)
assert consumed["input1"].equals(pd.DataFrame({"a": [1, 2], "b": [1, 2]}))