# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field
from typing import Tuple, List, Dict, Any, Set, Union
from pathlib import Path
import re
import asyncio

import pytest
from pandas import DataFrame
from pytest_bdd import when, then, parsers

from haystack import Pipeline, AsyncPipeline

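# Extracts the parametrized scenario name from a pytest node id such as
# "test_pipeline[my pipeline name]".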
PIPELINE_NAME_REGEX = re.compile(r"\[(.*)\]")


@pytest.fixture(params=[AsyncPipeline, Pipeline])
def pipeline_class(request):
    """
    A parametrized fixture that will yield AsyncPipeline for one test run
    and Pipeline for the next test run.
    """
    return request.param


@dataclass
class PipelineRunData:
    """
    Holds the inputs and expected outputs for a single Pipeline run.
    """

    inputs: Dict[str, Any]
    include_outputs_from: Set[str] = field(default_factory=set)
    expected_outputs: Dict[str, Any] = field(default_factory=dict)
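    # Maps each expected call, identified by (component name, visit count),
    # to the inputs that call should receive.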
    expected_component_calls: Dict[Tuple[str, int], Dict[str, Any]] = field(default_factory=dict)
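
    # Illustrative example (hypothetical component name and payloads):
    # PipelineRunData(
    #     inputs={"greeter": {"query": "hello"}},
    #     expected_outputs={"greeter": {"greeting": "Hello, world!"}},
    #     expected_component_calls={("greeter", 1): {"query": "hello"}},
    # )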


@dataclass
class _PipelineResult:
    """
    Holds the outputs and the component calls of a single Pipeline run.
    """

    outputs: Dict[str, Any]
    component_calls: Dict[Tuple[str, int], Dict[str, Any]] = field(default_factory=dict)


@when("I run the Pipeline", target_fixture="pipeline_result")
def run_pipeline(
    pipeline_data: Tuple[Union[AsyncPipeline, Pipeline], List[PipelineRunData]], spying_tracer
) -> Union[List[Tuple[_PipelineResult, PipelineRunData]], Exception]:
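    # Dispatch on the concrete pipeline type supplied by the pipeline_class
    # fixture so a single BDD step covers both execution models.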
    if isinstance(pipeline_data[0], AsyncPipeline):
        return run_async_pipeline(pipeline_data, spying_tracer)
    else:
        return run_sync_pipeline(pipeline_data, spying_tracer)


def run_async_pipeline(
    pipeline_data: Tuple[AsyncPipeline, List[PipelineRunData]], spying_tracer
) -> Union[List[Tuple[_PipelineResult, PipelineRunData]], Exception]:
    """
    Attempts to run a pipeline with the given inputs.

    `pipeline_data` is a tuple that must contain:
    * An AsyncPipeline instance
    * The data to run the pipeline with

    If successful, returns a list of tuples pairing each run's outputs with its
    expected outputs. If an exception is raised, returns that exception instead.
    """
    pipeline, pipeline_run_data = pipeline_data[0], pipeline_data[1]

    results: List[_PipelineResult] = []

    async def run_inner(data, include_outputs_from):
        """Wrapper coroutine that calls pipeline.run_async with the required params."""
        return await pipeline.run_async(data=data.inputs, include_outputs_from=include_outputs_from)

    for data in pipeline_run_data:
        try:
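            # asyncio.run starts a fresh event loop for each run, mirroring the
            # blocking behaviour of the synchronous Pipeline.run call.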
            outputs = asyncio.run(run_inner(data, data.include_outputs_from))

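            # Rebuild the call history from the spying tracer: each component
            # invocation is keyed by (component name, visit count) and mapped
            # to the inputs it received.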
            component_calls = {
                (span.tags["haystack.component.name"], span.tags["haystack.component.visits"]): span.tags[
                    "haystack.component.input"
                ]
                for span in spying_tracer.spans
                if "haystack.component.name" in span.tags and "haystack.component.visits" in span.tags
            }
            results.append(_PipelineResult(outputs=outputs, component_calls=component_calls))
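            # Clear the recorded spans so the next run's call history starts fresh.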
            spying_tracer.spans.clear()
        except Exception as e:
            return e

    return list(zip(results, pipeline_run_data))


def run_sync_pipeline(
    pipeline_data: Tuple[Pipeline, List[PipelineRunData]], spying_tracer
) -> Union[List[Tuple[_PipelineResult, PipelineRunData]], Exception]:
    """
    Attempts to run a pipeline with the given inputs.

    `pipeline_data` is a tuple that must contain:
    * A Pipeline instance
    * The data to run the pipeline with

    If successful, returns a list of tuples pairing each run's outputs with its
    expected outputs. If an exception is raised, returns that exception instead.
    """
    pipeline, pipeline_run_data = pipeline_data[0], pipeline_data[1]

    results: List[_PipelineResult] = []

    for data in pipeline_run_data:
        try:
            outputs = pipeline.run(data=data.inputs, include_outputs_from=data.include_outputs_from)

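            # Same span bookkeeping as in run_async_pipeline above: key each
            # component invocation by (component name, visit count).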
            component_calls = {
                (span.tags["haystack.component.name"], span.tags["haystack.component.visits"]): span.tags[
                    "haystack.component.input"
                ]
                for span in spying_tracer.spans
                if "haystack.component.name" in span.tags and "haystack.component.visits" in span.tags
            }
            results.append(_PipelineResult(outputs=outputs, component_calls=component_calls))
            spying_tracer.spans.clear()
        except Exception as e:
            return e

    return list(zip(results, pipeline_run_data))


@then("draw it to file")
def draw_pipeline(pipeline_data: Tuple[Pipeline, List[PipelineRunData]], request):
    """
    Draw the pipeline to a file with the same name as the test.
    """
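    # request.node.name looks like "test_pipeline[some scenario name]"; the
    # regex pulls out the bracketed scenario name to use as the file name.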
    if m := PIPELINE_NAME_REGEX.search(request.node.name):
        name = m.group(1).replace(" ", "_")
        pipeline = pipeline_data[0]
        graphs_dir = Path(request.config.rootpath) / "test_pipeline_graphs"
        graphs_dir.mkdir(exist_ok=True)
        pipeline.draw(graphs_dir / f"{name}.png")


@then("it should return the expected result")
def check_pipeline_result(pipeline_result: List[Tuple[_PipelineResult, PipelineRunData]]):
    for res, data in pipeline_result:
        compare_outputs_with_dataframes(res.outputs, data.expected_outputs)


@then("components are called with the expected inputs")
def check_component_calls(pipeline_result: List[Tuple[_PipelineResult, PipelineRunData]]):
    for res, data in pipeline_result:
        assert compare_outputs_with_dataframes(res.component_calls, data.expected_component_calls)


@then(parsers.parse("it must have raised {exception_class_name}"))
def check_pipeline_raised(pipeline_result: Exception, exception_class_name: str):
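    # Compare by class name since the Gherkin step provides the exception type
    # as a plain string.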
    assert pipeline_result.__class__.__name__ == exception_class_name


def compare_outputs_with_dataframes(actual: Dict, expected: Dict) -> bool:
    """
    Compare two component_calls or pipeline outputs dictionaries where values may contain DataFrames.
    """
    assert actual.keys() == expected.keys()

    for key in actual:
        actual_data = actual[key]
        expected_data = expected[key]

        assert actual_data.keys() == expected_data.keys()

        for data_key in actual_data:
            actual_value = actual_data[data_key]
            expected_value = expected_data[data_key]

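            # DataFrame defines == element-wise and its truth value is ambiguous,
            # so DataFrame.equals is needed for a whole-frame comparison.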
            if isinstance(actual_value, DataFrame) and isinstance(expected_value, DataFrame):
                assert actual_value.equals(expected_value)
            else:
                assert actual_value == expected_value

    return True