haystack/test/core/pipeline/test_pipeline.py
Sebastian Husch Lee 85258f0654
fix: Fix types and formatting pipeline test_run.py (#9575)
* Fix types in test_run.py

* Get test_run.py to pass fmt-check

* Add test_run to mypy checks

* Update test folder to pass ruff linting

* Fix merge

* Fix HF tests

* Fix hf test

* Try to fix tests

* Another attempt

* minor fix

* fix SentenceTransformersDiversityRanker

* skip integrations tests due to model unavailable on HF inference

---------

Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
2025-07-03 09:49:09 +02:00

126 lines
4.6 KiB
Python

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from concurrent.futures import ThreadPoolExecutor
import pytest
from haystack.components.joiners import BranchJoiner
from haystack.core.component import component
from haystack.core.errors import PipelineRuntimeError
from haystack.core.pipeline import Pipeline
class TestPipeline:
"""
This class contains only unit tests for the Pipeline class.
It doesn't test Pipeline.run(), that is done separately in a different way.
"""
def test_pipeline_thread_safety(self, waiting_component, spying_tracer):
# Initialize pipeline with synchronous components
pp = Pipeline()
pp.add_component("wait", waiting_component())
run_data = [{"wait_for": 0.001}, {"wait_for": 0.002}]
# Use ThreadPoolExecutor to run pipeline calls in parallel
with ThreadPoolExecutor(max_workers=len(run_data)) as executor:
# Submit pipeline runs to the executor
futures = [executor.submit(pp.run, data) for data in run_data]
# Wait for all futures to complete
for future in futures:
future.result()
# Verify component visits using tracer
component_spans = [sp for sp in spying_tracer.spans if sp.operation_name == "haystack.component.run"]
for span in component_spans:
assert span.tags["haystack.component.visits"] == 1
def test_prepare_component_inputs(self):
joiner_1 = BranchJoiner(type_=str)
joiner_2 = BranchJoiner(type_=str)
pp = Pipeline()
component_name = "joiner_1"
pp.add_component(component_name, joiner_1)
pp.add_component("joiner_2", joiner_2)
pp.connect(component_name, "joiner_2")
inputs = {"joiner_1": {"value": [{"sender": None, "value": "test_value"}]}}
comp_dict = pp._get_component_with_graph_metadata_and_visits(component_name, 0)
_ = pp._consume_component_inputs(component_name=component_name, component=comp_dict, inputs=inputs)
# We remove input in greedy variadic sockets, even if they are from the user
assert inputs == {"joiner_1": {}}
def test__run_component_success(self):
"""Test successful component execution"""
joiner_1 = BranchJoiner(type_=str)
joiner_2 = BranchJoiner(type_=str)
pp = Pipeline()
component_name = "joiner_1"
pp.add_component(component_name, joiner_1)
pp.add_component("joiner_2", joiner_2)
pp.connect(component_name, "joiner_2")
inputs = {"value": ["test_value"]}
outputs = pp._run_component(
component_name=component_name,
component=pp._get_component_with_graph_metadata_and_visits(component_name, 0),
inputs=inputs,
component_visits={component_name: 0, "joiner_2": 0},
)
assert outputs == {"value": "test_value"}
def test__run_component_fail(self):
"""Test error when component doesn't return a dictionary"""
@component
class WrongOutput:
@component.output_types(output=str)
def run(self, value: str):
return "not_a_dict"
wrong = WrongOutput()
pp = Pipeline()
pp.add_component("wrong", wrong)
inputs = {"value": "test_value"}
with pytest.raises(PipelineRuntimeError) as exc_info:
pp._run_component(
component_name="wrong",
component=pp._get_component_with_graph_metadata_and_visits("wrong", 0),
inputs=inputs,
component_visits={"wrong": 0},
)
assert "Expected a dict" in str(exc_info.value)
def test_run_component_error(self):
"""Test error when component fails to run"""
@component
class ErroringComponent:
@component.output_types(output=str)
def run(self):
raise ValueError("Test error")
erroring_component = ErroringComponent()
pp = Pipeline()
pp.add_component("erroring_component", erroring_component)
inputs = {"wrong": {"value": [{"sender": None, "value": "test_value"}]}}
with pytest.raises(PipelineRuntimeError) as exc_info:
pp._run_component(
component_name="erroring_component",
component=pp._get_component_with_graph_metadata_and_visits("erroring_component", 0),
inputs=inputs,
component_visits={"erroring_component": 0},
)
assert "Component name: 'erroring_component'" in str(exc_info.value)