2023-11-27 15:16:35 +01:00
|
|
|
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
|
|
#
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
2025-02-06 15:19:47 +01:00
|
|
|
|
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
2023-11-27 15:16:35 +01:00
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
2024-07-12 10:35:23 +02:00
|
|
|
from haystack.components.joiners import BranchJoiner
|
2024-02-12 18:25:28 +01:00
|
|
|
from haystack.core.component import component
|
2025-02-06 15:19:47 +01:00
|
|
|
from haystack.core.errors import PipelineRuntimeError
|
|
|
|
from haystack.core.pipeline import Pipeline
|
2024-06-21 10:29:37 +02:00
|
|
|
|
|
|
|
|
2024-05-21 16:12:28 +02:00
|
|
|
class TestPipeline:
    """
    This class contains only unit tests for the Pipeline class.
    The behaviour of Pipeline.run() itself is tested separately in a different way;
    here run() is only invoked to exercise concurrent use of one Pipeline instance.
    """

    def test_pipeline_thread_safety(self, waiting_component, spying_tracer):
        """Run one pipeline from several threads and check each run counts its own visits."""
        # Initialize pipeline with synchronous components
        pp = Pipeline()
        pp.add_component("wait", waiting_component())

        # One input dict per concurrent run.
        # NOTE(review): `waiting_component` is a fixture defined outside this file;
        # presumably `wait_for` controls how long the component blocks - confirm.
        run_data = [{"wait_for": 1}, {"wait_for": 2}]

        # Use ThreadPoolExecutor to run pipeline calls in parallel
        with ThreadPoolExecutor(max_workers=len(run_data)) as executor:
            # Submit pipeline runs to the executor
            futures = [executor.submit(pp.run, data) for data in run_data]

            # Wait for all futures to complete; result() re-raises any exception
            # that occurred inside a worker thread.
            for future in futures:
                future.result()

        # Verify component visits using tracer
        component_spans = [sp for sp in spying_tracer.spans if sp.operation_name == "haystack.component.run"]

        # The pipeline holds a single component, so every run should yield exactly
        # one component span. Without this check the loop below would pass
        # vacuously if no spans were recorded at all.
        assert len(component_spans) == len(run_data), "expected one component span per pipeline run"

        # If visit counters leaked between concurrent runs, some span would report
        # more than one visit.
        for span in component_spans:
            assert span.tags["haystack.component.visits"] == 1

    def test__run_component_success(self):
        """Test successful component execution"""
        joiner_1 = BranchJoiner(type_=str)
        joiner_2 = BranchJoiner(type_=str)
        pp = Pipeline()
        pp.add_component("joiner_1", joiner_1)
        pp.add_component("joiner_2", joiner_2)
        pp.connect("joiner_1", "joiner_2")

        # Inputs are keyed by component name, then socket name; each socket holds
        # a list of {"sender", "value"} entries (sender=None is used here for the
        # user-provided value).
        inputs = {"joiner_1": {"value": [{"sender": None, "value": "test_value"}]}}

        outputs = pp._run_component(
            component=pp._get_component_with_graph_metadata_and_visits("joiner_1", 0),
            inputs=inputs,
            component_visits={"joiner_1": 0, "joiner_2": 0},
        )

        assert outputs == {"value": "test_value"}
        # We remove input in greedy variadic sockets, even if they are from the user
        assert "value" not in inputs["joiner_1"]

    def test__run_component_fail(self):
        """Test error when component doesn't return a dictionary"""

        @component
        class WrongOutput:
            # Deliberately violates the component contract by returning a bare
            # string instead of a dict of outputs.
            @component.output_types(output=str)
            def run(self, value: str):
                return "not_a_dict"

        wrong = WrongOutput()
        pp = Pipeline()
        pp.add_component("wrong", wrong)

        inputs = {"wrong": {"value": [{"sender": None, "value": "test_value"}]}}

        with pytest.raises(PipelineRuntimeError) as exc_info:
            pp._run_component(
                component=pp._get_component_with_graph_metadata_and_visits("wrong", 0),
                inputs=inputs,
                component_visits={"wrong": 0},
            )

        assert "didn't return a dictionary" in str(exc_info.value)
|