from dataclasses import dataclass, field from typing import Tuple, List, Dict, Any, Set, Union from pathlib import Path import re from pytest_bdd import when, then, parsers from haystack import Pipeline PIPELINE_NAME_REGEX = re.compile(r"\[(.*)\]") @dataclass class PipelineRunData: """ Holds the inputs and expected outputs for a single Pipeline run. """ inputs: Dict[str, Any] include_outputs_from: Set[str] = field(default_factory=set) expected_outputs: Dict[str, Any] = field(default_factory=dict) expected_component_calls: Dict[Tuple[str, int], Dict[str, Any]] = field(default_factory=dict) @dataclass class _PipelineResult: """ Holds the outputs and the run order of a single Pipeline run. """ outputs: Dict[str, Any] component_calls: Dict[Tuple[str, int], Dict[str, Any]] = field(default_factory=dict) @when("I run the Pipeline", target_fixture="pipeline_result") def run_pipeline( pipeline_data: Tuple[Pipeline, List[PipelineRunData]], spying_tracer ) -> Union[List[Tuple[_PipelineResult, PipelineRunData]], Exception]: """ Attempts to run a pipeline with the given inputs. `pipeline_data` is a tuple that must contain: * A Pipeline instance * The data to run the pipeline with If successful returns a tuple of the run outputs and the expected outputs. In case an exceptions is raised returns that. """ pipeline, pipeline_run_data = pipeline_data[0], pipeline_data[1] results: List[_PipelineResult] = [] for data in pipeline_run_data: try: outputs = pipeline.run(data=data.inputs, include_outputs_from=data.include_outputs_from) component_calls = { (span.tags["haystack.component.name"], span.tags["haystack.component.visits"]): span.tags[ "haystack.component.input" ] for span in spying_tracer.spans if "haystack.component.name" in span.tags and "haystack.component.visits" in span.tags } results.append(_PipelineResult(outputs=outputs, component_calls=component_calls)) spying_tracer.spans.clear() except Exception as e: return e return [e for e in zip(results, pipeline_run_data)] @then("draw it to file") def draw_pipeline(pipeline_data: Tuple[Pipeline, List[PipelineRunData]], request): """ Draw the pipeline to a file with the same name as the test. """ if m := PIPELINE_NAME_REGEX.search(request.node.name): name = m.group(1).replace(" ", "_") pipeline = pipeline_data[0] graphs_dir = Path(request.config.rootpath) / "test_pipeline_graphs" graphs_dir.mkdir(exist_ok=True) pipeline.draw(graphs_dir / f"{name}.png") @then("it should return the expected result") def check_pipeline_result(pipeline_result: List[Tuple[_PipelineResult, PipelineRunData]]): for res, data in pipeline_result: assert res.outputs == data.expected_outputs @then("components are called with the expected inputs") def check_component_calls(pipeline_result: List[Tuple[_PipelineResult, PipelineRunData]]): for res, data in pipeline_result: assert res.component_calls == data.expected_component_calls @then(parsers.parse("it must have raised {exception_class_name}")) def check_pipeline_raised(pipeline_result: Exception, exception_class_name: str): assert pipeline_result.__class__.__name__ == exception_class_name