mathislucka 8c54f06a19
fix: component checks failing for components that return dataframes (#8873)
* fix: use is not to compare to sentinel value

* chore: release notes

* Update releasenotes/notes/fix-component-checks-with-ambiguous-truth-values-949c447b3702e427.yaml

Co-authored-by: David S. Batista <dsbatista@gmail.com>

* fix: another sentinel value

* test: also test base class

* add pandas as test dependency

* format

* Trigger CI

* mark test with xfail strict=False

---------

Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>
Co-authored-by: David S. Batista <dsbatista@gmail.com>
Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
2025-02-19 09:10:48 +00:00

182 lines
6.4 KiB
Python

from dataclasses import dataclass, field
from typing import Tuple, List, Dict, Any, Set, Union
from pathlib import Path
import re
import pytest
import asyncio
import pandas as pd
from pytest_bdd import when, then, parsers
from haystack import Pipeline, AsyncPipeline
PIPELINE_NAME_REGEX = re.compile(r"\[(.*)\]")
@pytest.fixture(params=[AsyncPipeline, Pipeline])
def pipeline_class(request):
    """
    Parametrized fixture yielding each pipeline class in turn, so any test
    that uses it runs once with AsyncPipeline and once with Pipeline.
    """
    return request.param
@dataclass
class PipelineRunData:
    """
    Holds the inputs and expected outputs for a single Pipeline run.
    """

    # Data passed to Pipeline.run / run_async as the `data` argument.
    inputs: Dict[str, Any]
    # Component names whose outputs should be included in the run result.
    include_outputs_from: Set[str] = field(default_factory=set)
    # Pipeline outputs the run is expected to produce.
    expected_outputs: Dict[str, Any] = field(default_factory=dict)
    # Expected component inputs, keyed by (component name, visit count).
    expected_component_calls: Dict[Tuple[str, int], Dict[str, Any]] = field(default_factory=dict)
@dataclass
class _PipelineResult:
    """
    Holds the outputs and the run order of a single Pipeline run.
    """

    # The outputs returned by Pipeline.run / run_async.
    outputs: Dict[str, Any]
    # Actual component inputs observed via tracing spans, keyed by
    # (component name, visit count).
    component_calls: Dict[Tuple[str, int], Dict[str, Any]] = field(default_factory=dict)
@when("I run the Pipeline", target_fixture="pipeline_result")
def run_pipeline(
    pipeline_data: Tuple[Union[AsyncPipeline, Pipeline], List[PipelineRunData]], spying_tracer
) -> Union[List[Tuple[_PipelineResult, PipelineRunData]], Exception]:
    """
    Dispatch the run to the async or sync runner depending on the pipeline type.
    """
    runner = run_async_pipeline if isinstance(pipeline_data[0], AsyncPipeline) else run_sync_pipeline
    return runner(pipeline_data, spying_tracer)
def run_async_pipeline(
    pipeline_data: Tuple[Union[AsyncPipeline], List[PipelineRunData]], spying_tracer
) -> Union[List[Tuple[_PipelineResult, PipelineRunData]], Exception]:
    """
    Attempts to run an AsyncPipeline with the given inputs.

    `pipeline_data` is a tuple that must contain:
    * An AsyncPipeline instance
    * The data to run the pipeline with

    If successful returns a list pairing each run's result with its expected data.
    In case an exception is raised returns that.
    """
    pipeline, pipeline_run_data = pipeline_data[0], pipeline_data[1]
    results: List[_PipelineResult] = []

    async def run_inner(data, include_outputs_from):
        """Wrapper function to call pipeline.run_async method with required params."""
        return await pipeline.run_async(data=data.inputs, include_outputs_from=include_outputs_from)

    for data in pipeline_run_data:
        try:
            outputs = asyncio.run(run_inner(data, data.include_outputs_from))
            # Reconstruct (component name, visit count) -> component input from the
            # tracing spans recorded during this run.
            component_calls = {
                (span.tags["haystack.component.name"], span.tags["haystack.component.visits"]): span.tags[
                    "haystack.component.input"
                ]
                for span in spying_tracer.spans
                if "haystack.component.name" in span.tags and "haystack.component.visits" in span.tags
            }
            results.append(_PipelineResult(outputs=outputs, component_calls=component_calls))
            # Clear recorded spans so the next run's calls are not mixed with this one's.
            spying_tracer.spans.clear()
        except Exception as e:
            return e
    # list(zip(...)) instead of the redundant [e for e in zip(...)] identity comprehension.
    return list(zip(results, pipeline_run_data))
def run_sync_pipeline(
    pipeline_data: Tuple[Pipeline, List[PipelineRunData]], spying_tracer
) -> Union[List[Tuple[_PipelineResult, PipelineRunData]], Exception]:
    """
    Attempts to run a pipeline with the given inputs.

    `pipeline_data` is a tuple that must contain:
    * A Pipeline instance
    * The data to run the pipeline with

    If successful returns a list pairing each run's result with its expected data.
    In case an exception is raised returns that.
    """
    pipeline, pipeline_run_data = pipeline_data[0], pipeline_data[1]
    results: List[_PipelineResult] = []

    for data in pipeline_run_data:
        try:
            outputs = pipeline.run(data=data.inputs, include_outputs_from=data.include_outputs_from)
            # Reconstruct (component name, visit count) -> component input from the
            # tracing spans recorded during this run.
            component_calls = {
                (span.tags["haystack.component.name"], span.tags["haystack.component.visits"]): span.tags[
                    "haystack.component.input"
                ]
                for span in spying_tracer.spans
                if "haystack.component.name" in span.tags and "haystack.component.visits" in span.tags
            }
            results.append(_PipelineResult(outputs=outputs, component_calls=component_calls))
            # Clear recorded spans so the next run's calls are not mixed with this one's.
            spying_tracer.spans.clear()
        except Exception as e:
            return e
    # list(zip(...)) instead of the redundant [e for e in zip(...)] identity comprehension.
    return list(zip(results, pipeline_run_data))
@then("draw it to file")
def draw_pipeline(pipeline_data: Tuple[Pipeline, List[PipelineRunData]], request):
    """
    Draw the pipeline to a file with the same name as the test.
    """
    match = PIPELINE_NAME_REGEX.search(request.node.name)
    if match is None:
        return
    file_name = match.group(1).replace(" ", "_")
    pipeline = pipeline_data[0]
    output_dir = Path(request.config.rootpath) / "test_pipeline_graphs"
    output_dir.mkdir(exist_ok=True)
    pipeline.draw(output_dir / f"{file_name}.png")
@then("it should return the expected result")
def check_pipeline_result(pipeline_result: List[Tuple[_PipelineResult, PipelineRunData]]):
    """
    Assert that each run's pipeline outputs match the expected outputs.
    """
    for res, data in pipeline_result:
        # compare_outputs_with_dataframes asserts internally, but also assert its
        # return value for consistency with check_component_calls.
        assert compare_outputs_with_dataframes(res.outputs, data.expected_outputs)
@then("components are called with the expected inputs")
def check_component_calls(pipeline_result: List[Tuple[_PipelineResult, PipelineRunData]]):
    """
    Assert that the recorded component calls of every run match the expected calls.
    """
    for result, run_data in pipeline_result:
        assert compare_outputs_with_dataframes(result.component_calls, run_data.expected_component_calls)
@then(parsers.parse("it must have raised {exception_class_name}"))
def check_pipeline_raised(pipeline_result: Exception, exception_class_name: str):
    """
    Assert that the pipeline run raised an exception of the named class.
    """
    actual_name = type(pipeline_result).__name__
    assert actual_name == exception_class_name
def compare_outputs_with_dataframes(actual: Dict, expected: Dict) -> bool:
    """
    Compare two component_calls or pipeline outputs dictionaries where values may contain DataFrames.
    """
    assert actual.keys() == expected.keys()
    for outer_key, actual_entry in actual.items():
        expected_entry = expected[outer_key]
        assert actual_entry.keys() == expected_entry.keys()
        for inner_key, actual_value in actual_entry.items():
            expected_value = expected_entry[inner_key]
            # DataFrames have ambiguous truth values, so compare them with .equals()
            # instead of ==; everything else uses plain equality.
            both_frames = isinstance(actual_value, pd.DataFrame) and isinstance(expected_value, pd.DataFrame)
            if both_frames:
                assert actual_value.equals(expected_value)
            else:
                assert actual_value == expected_value
    return True