2023-11-27 15:16:35 +01:00
|
|
|
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
|
|
#
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from pathlib import Path
|
|
|
|
from pprint import pprint
|
|
|
|
import logging
|
|
|
|
|
2023-11-28 09:58:56 +01:00
|
|
|
from haystack.core.pipeline import Pipeline
|
|
|
|
from haystack.testing.sample_components import (
|
2023-11-27 15:16:35 +01:00
|
|
|
Accumulate,
|
|
|
|
AddFixedValue,
|
|
|
|
Greet,
|
|
|
|
Parity,
|
|
|
|
Threshold,
|
|
|
|
Double,
|
|
|
|
Sum,
|
|
|
|
Repeat,
|
|
|
|
Subtract,
|
|
|
|
MergeLoop,
|
|
|
|
)
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.DEBUG)
|
|
|
|
|
|
|
|
|
|
|
|
def test_complex_pipeline(tmp_path):
|
|
|
|
loop_merger = MergeLoop(expected_type=int, inputs=["in_1", "in_2"])
|
|
|
|
summer = Sum()
|
|
|
|
|
|
|
|
pipeline = Pipeline(max_loops_allowed=2)
|
|
|
|
pipeline.add_component("greet_first", Greet(message="Hello, the value is {value}."))
|
|
|
|
pipeline.add_component("accumulate_1", Accumulate())
|
|
|
|
pipeline.add_component("add_two", AddFixedValue(add=2))
|
|
|
|
pipeline.add_component("parity", Parity())
|
|
|
|
pipeline.add_component("add_one", AddFixedValue(add=1))
|
|
|
|
pipeline.add_component("accumulate_2", Accumulate())
|
|
|
|
|
|
|
|
pipeline.add_component("loop_merger", loop_merger)
|
|
|
|
pipeline.add_component("below_10", Threshold(threshold=10))
|
|
|
|
pipeline.add_component("double", Double())
|
|
|
|
|
|
|
|
pipeline.add_component("greet_again", Greet(message="Hello again, now the value is {value}."))
|
|
|
|
pipeline.add_component("sum", summer)
|
|
|
|
|
|
|
|
pipeline.add_component("greet_enumerator", Greet(message="Hello from enumerator, here the value became {value}."))
|
|
|
|
pipeline.add_component("enumerate", Repeat(outputs=["first", "second"]))
|
|
|
|
pipeline.add_component("add_three", AddFixedValue(add=3))
|
|
|
|
|
|
|
|
pipeline.add_component("diff", Subtract())
|
|
|
|
pipeline.add_component("greet_one_last_time", Greet(message="Bye bye! The value here is {value}!"))
|
|
|
|
pipeline.add_component("replicate", Repeat(outputs=["first", "second"]))
|
|
|
|
pipeline.add_component("add_five", AddFixedValue(add=5))
|
|
|
|
pipeline.add_component("add_four", AddFixedValue(add=4))
|
|
|
|
pipeline.add_component("accumulate_3", Accumulate())
|
|
|
|
|
|
|
|
pipeline.connect("greet_first", "accumulate_1")
|
|
|
|
pipeline.connect("accumulate_1", "add_two")
|
|
|
|
pipeline.connect("add_two", "parity")
|
|
|
|
|
|
|
|
pipeline.connect("parity.even", "greet_again")
|
|
|
|
pipeline.connect("greet_again", "sum.values")
|
|
|
|
pipeline.connect("sum", "diff.first_value")
|
|
|
|
pipeline.connect("diff", "greet_one_last_time")
|
|
|
|
pipeline.connect("greet_one_last_time", "replicate")
|
|
|
|
pipeline.connect("replicate.first", "add_five.value")
|
|
|
|
pipeline.connect("replicate.second", "add_four.value")
|
|
|
|
pipeline.connect("add_four", "accumulate_3")
|
|
|
|
|
|
|
|
pipeline.connect("parity.odd", "add_one.value")
|
|
|
|
pipeline.connect("add_one", "loop_merger.in_1")
|
|
|
|
pipeline.connect("loop_merger", "below_10")
|
|
|
|
|
|
|
|
pipeline.connect("below_10.below", "double")
|
|
|
|
pipeline.connect("double", "loop_merger.in_2")
|
|
|
|
|
|
|
|
pipeline.connect("below_10.above", "accumulate_2")
|
|
|
|
pipeline.connect("accumulate_2", "diff.second_value")
|
|
|
|
|
|
|
|
pipeline.connect("greet_enumerator", "enumerate")
|
|
|
|
pipeline.connect("enumerate.second", "sum.values")
|
|
|
|
|
|
|
|
pipeline.connect("enumerate.first", "add_three.value")
|
|
|
|
pipeline.connect("add_three", "sum.values")
|
|
|
|
|
|
|
|
pipeline.draw(tmp_path / "complex_pipeline.png")
|
|
|
|
|
|
|
|
results = pipeline.run({"greet_first": {"value": 1}, "greet_enumerator": {"value": 1}})
|
|
|
|
assert results == {"accumulate_3": {"value": -7}, "add_five": {"result": -6}}
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
test_complex_pipeline(Path(__file__).parent)
|