refactor: Isolate logic that finds next runnable component waiting for input (#7880)

* Fix formatting

* Isolate logic that finds next runnable component waiting for input

* Explain more lazy variadics

* Enhance logic following review suggestions

* Simplify code to use a single for

* Fix test
Silvano Cerza 2024-06-18 16:43:19 +02:00 committed by GitHub
parent ff79da5f55
commit 15ee622b3c
3 changed files with 225 additions and 206 deletions


@ -133,9 +133,7 @@ class PipelineBase:
}
@classmethod
def from_dict(
cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs
) -> T:
def from_dict(cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs) -> T:
"""
Deserializes the pipeline from a dictionary.
@ -169,10 +167,7 @@ class PipelineBase:
importlib.import_module(module)
# ...then try again
if component_data["type"] not in component.registry:
raise PipelineError(
f"Successfully imported module {module} but can't find it in the component registry."
"This is unexpected and most likely a bug."
)
raise PipelineError(f"Successfully imported module {module} but can't find it in the component registry." "This is unexpected and most likely a bug.")
except (ImportError, PipelineError) as e:
raise PipelineError(f"Component '{component_data['type']}' not imported.") from e
@ -284,15 +279,10 @@ class PipelineBase:
# Component instances must be components
if not isinstance(instance, Component):
raise PipelineValidationError(
f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?"
)
raise PipelineValidationError(f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?")
if getattr(instance, "__haystack_added_to_pipeline__", None):
msg = (
"Component has already been added in another Pipeline. "
"Components can't be shared between Pipelines. Create a new instance instead."
)
msg = "Component has already been added in another Pipeline. Components can't be shared between Pipelines. Create a new instance instead."
raise PipelineError(msg)
setattr(instance, "__haystack_added_to_pipeline__", self)
@ -390,35 +380,21 @@ class PipelineBase:
if sender_socket_name:
sender_socket = from_sockets.get(sender_socket_name)
if not sender_socket:
raise PipelineConnectError(
f"'{sender} does not exist. "
f"Output connections of {sender_component_name} are: "
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()])
)
raise PipelineConnectError(f"'{sender} does not exist. " f"Output connections of {sender_component_name} are: " + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()]))
receiver_socket: Optional[InputSocket] = None
if receiver_socket_name:
receiver_socket = to_sockets.get(receiver_socket_name)
if not receiver_socket:
raise PipelineConnectError(
f"'{receiver} does not exist. "
f"Input connections of {receiver_component_name} are: "
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()])
)
raise PipelineConnectError(f"'{receiver} does not exist. " f"Input connections of {receiver_component_name} are: " + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()]))
# Look for a matching connection among the possible ones.
# Note that if there is more than one possible connection but two sockets match by name, they're paired.
sender_socket_candidates: List[OutputSocket] = [sender_socket] if sender_socket else list(from_sockets.values())
receiver_socket_candidates: List[InputSocket] = (
[receiver_socket] if receiver_socket else list(to_sockets.values())
)
receiver_socket_candidates: List[InputSocket] = [receiver_socket] if receiver_socket else list(to_sockets.values())
# Find all possible connections between these two components
possible_connections = [
(sender_sock, receiver_sock)
for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates)
if _types_are_compatible(sender_sock.type, receiver_sock.type)
]
possible_connections = [(sender_sock, receiver_sock) for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates) if _types_are_compatible(sender_sock.type, receiver_sock.type)]
# We need this status for error messages, since we might need it in multiple places we calculate it here
status = _connections_status(
@ -431,15 +407,9 @@ class PipelineBase:
if not possible_connections:
# There's no possible connection between these two components
if len(sender_socket_candidates) == len(receiver_socket_candidates) == 1:
msg = (
f"Cannot connect '{sender_component_name}.{sender_socket_candidates[0].name}' with '{receiver_component_name}.{receiver_socket_candidates[0].name}': "
f"their declared input and output types do not match.\n{status}"
)
msg = f"Cannot connect '{sender_component_name}.{sender_socket_candidates[0].name}' with '{receiver_component_name}.{receiver_socket_candidates[0].name}': " f"their declared input and output types do not match.\n{status}"
else:
msg = (
f"Cannot connect '{sender_component_name}' with '{receiver_component_name}': "
f"no matching connections available.\n{status}"
)
msg = f"Cannot connect '{sender_component_name}' with '{receiver_component_name}': " f"no matching connections available.\n{status}"
raise PipelineConnectError(msg)
if len(possible_connections) == 1:
@ -449,9 +419,7 @@ class PipelineBase:
if len(possible_connections) > 1:
# There are multiple possible connections, let's try to match them by name
name_matches = [
(out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name
]
name_matches = [(out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name]
if len(name_matches) != 1:
# There are either no matches or more than one, so we can't pick one reliably
msg = (
@ -494,10 +462,7 @@ class PipelineBase:
if receiver_socket.senders and not receiver_socket.is_variadic:
# Only variadic input sockets can receive from multiple senders
msg = (
f"Cannot connect '{sender_component_name}.{sender_socket.name}' with '{receiver_component_name}.{receiver_socket.name}': "
f"{receiver_component_name}.{receiver_socket.name} is already connected to {receiver_socket.senders}.\n"
)
msg = f"Cannot connect '{sender_component_name}.{sender_socket.name}' with '{receiver_component_name}.{receiver_socket.name}': " f"{receiver_component_name}.{receiver_socket.name} is already connected to {receiver_socket.senders}.\n"
raise PipelineConnectError(msg)
# Update the sockets with the new connection
@ -587,11 +552,7 @@ class PipelineBase:
A dictionary where each key is a pipeline component name and each value is a dictionary of
output sockets of that component.
"""
outputs = {
comp: {socket.name: {"type": socket.type} for socket in data}
for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items()
if data
}
outputs = {comp: {socket.name: {"type": socket.type} for socket in data} for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items() if data}
return outputs
def show(self) -> None:
@ -679,9 +640,7 @@ class PipelineBase:
if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
raise ValueError(f"Missing input for component {component_name}: {socket_name}")
if socket.senders and socket_name in component_inputs and not socket.is_variadic:
raise ValueError(
f"Input {socket_name} for component {component_name} is already sent by {socket.senders}."
)
raise ValueError(f"Input {socket_name} for component {component_name} is already sent by {socket.senders}.")
def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
"""
@ -785,9 +744,7 @@ class PipelineBase:
return to_run
@classmethod
def from_template(
cls, predefined_pipeline: PredefinedPipeline, template_params: Optional[Dict[str, Any]] = None
) -> "PipelineBase":
def from_template(cls, predefined_pipeline: PredefinedPipeline, template_params: Optional[Dict[str, Any]] = None) -> "PipelineBase":
"""
Create a Pipeline from a predefined template. See `PredefinedPipeline` for available options.
@ -932,10 +889,99 @@ class PipelineBase:
# Returns the output without the keys that were distributed to other Components
return {k: v for k, v in component_result.items() if k not in to_remove_from_component_result}
def _connections_status(
sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
def _enqueue_next_runnable_component(
self,
inputs_by_component: Dict[str, Dict[str, Any]],
to_run: List[Tuple[str, Component]],
waiting_for_input: List[Tuple[str, Component]],
):
"""
Finds the next Component that can be run and adds it to the queue of Components to run.
:param inputs_by_component: The current state of the inputs divided by Component name
:param to_run: Queue of Components to run
:param waiting_for_input: Queue of Components waiting for input
"""
for name, comp in waiting_for_input:
if name not in inputs_by_component:
inputs_by_component[name] = {}
# Small utility function to check if a Component has a Variadic input that is not greedy.
def is_lazy_variadic(c: Component) -> bool:
is_variadic = any(
socket.is_variadic
for socket in c.__haystack_input__._sockets_dict.values() # type: ignore
)
if not is_variadic:
return False
return not getattr(c, "__haystack_is_greedy__", False)
# Small utility function to check if a Component has all inputs with defaults.
def has_all_inputs_with_defaults(c: Component) -> bool:
return all(
not socket.is_mandatory
for socket in c.__haystack_input__._sockets_dict.values() # type: ignore
)
# Updates the inputs with the default values for the inputs that are missing
def add_missing_input_defaults(name: str, comp: Component, inputs_by_component: Dict[str, Dict[str, Any]]):
for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
if input_socket.name not in inputs_by_component[name]:
inputs_by_component[name][input_socket.name] = input_socket.default_value
all_lazy_variadic = True
all_with_default_inputs = True
filtered_waiting_for_input = []
for name, comp in waiting_for_input:
if not is_lazy_variadic(comp):
# Components with variadic inputs that are not greedy must be removed only if there's nothing else to run at this stage.
# We need to wait as long as possible to run them, so we can collect as many inputs as we can.
all_lazy_variadic = False
if not has_all_inputs_with_defaults(comp):
# Components that have defaults for all their inputs must be treated the same way as
# lazy variadic components. If only components with defaults are left we can run them.
# Without this check the order of execution of the Pipeline's Components would be affected, because we
# enqueue the Components in `to_run` at the start using the order they are added in the Pipeline.
# If a Component A with defaults is added before a Component B that has no defaults, but the Pipeline
# logic requires A to run after B, A could end up running first without this check.
all_with_default_inputs = False
if not is_lazy_variadic(comp) and not has_all_inputs_with_defaults(comp):
# Keep track of the Components that are not lazy variadic and don't have all inputs with defaults.
# We'll handle these later if necessary.
filtered_waiting_for_input.append((name, comp))
# If all Components are lazy variadic or all Components have all inputs with defaults we can get one to run
if all_lazy_variadic or all_with_default_inputs:
pair = waiting_for_input.pop(0)
to_run.append(pair)
# Add missing input defaults if needed, this is a no-op for Components with Variadic inputs
add_missing_input_defaults(name, comp, inputs_by_component)
return
for name, comp in filtered_waiting_for_input:
# Find the first component that has all the inputs it needs to run
has_enough_inputs = True
for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
if input_socket.name not in inputs_by_component[name]:
if input_socket.is_mandatory:
has_enough_inputs = False
break
if input_socket.name not in inputs_by_component[name]:
inputs_by_component[name][input_socket.name] = input_socket.default_value
if has_enough_inputs:
break
waiting_for_input.remove((name, comp))
to_run.append((name, comp))
def _connections_status(sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]):
"""
Lists the status of the sockets, for error messages.
"""
@ -950,9 +996,7 @@ def _connections_status(
sender_status = f"sent by {','.join(receiver_socket.senders)}"
else:
sender_status = "available"
receiver_sockets_entries.append(
f" - {receiver_socket.name}: {_type_name(receiver_socket.type)} ({sender_status})"
)
receiver_sockets_entries.append(f" - {receiver_socket.name}: {_type_name(receiver_socket.type)} ({sender_status})")
receiver_sockets_list = "\n".join(receiver_sockets_entries)
return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}"


@ -86,10 +86,7 @@ class Pipeline(PipelineBase):
inputs[socket.name] = []
if not isinstance(res, Mapping):
raise PipelineRuntimeError(
f"Component '{name}' didn't return a dictionary. "
"Components must always return dictionaries: check the the documentation."
)
raise PipelineRuntimeError(f"Component '{name}' didn't return a dictionary. " "Components must always return dictionaries: check the the documentation.")
span.set_tag("haystack.component.visits", self.graph.nodes[name]["visits"])
span.set_content_tag("haystack.component.output", res)
@ -274,18 +271,15 @@ class Pipeline(PipelineBase):
# Check if we're stuck in a loop.
# It's important to check whether previous waitings are None as it could be that no
# Component has actually been run yet.
if (
before_last_waiting_for_input is not None
and last_waiting_for_input is not None
and before_last_waiting_for_input == last_waiting_for_input
):
if before_last_waiting_for_input is not None and last_waiting_for_input is not None and before_last_waiting_for_input == last_waiting_for_input:
# Are we actually stuck, or is there a lazy variadic or a component with only default inputs waiting for input?
# This is our last resort: if there's no lazy variadic or component with only default inputs waiting for input,
# we're stuck for real and can't make any progress.
for name, comp in waiting_for_input:
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
has_only_defaults = all(
not socket.is_mandatory for socket in comp.__haystack_input__._sockets_dict.values() # type: ignore
not socket.is_mandatory
for socket in comp.__haystack_input__._sockets_dict.values() # type: ignore
)
if is_variadic and not comp.__haystack_is_greedy__ or has_only_defaults: # type: ignore[attr-defined]
break
@ -315,69 +309,10 @@ class Pipeline(PipelineBase):
continue
before_last_waiting_for_input = (
last_waiting_for_input.copy() if last_waiting_for_input is not None else None
)
before_last_waiting_for_input = last_waiting_for_input.copy() if last_waiting_for_input is not None else None
last_waiting_for_input = {item[0] for item in waiting_for_input}
# Remove from waiting only if there is actually enough input to run
for name, comp in waiting_for_input:
if name not in last_inputs:
last_inputs[name] = {}
# Lazy variadics must be removed only if there's nothing else to run at this stage
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
if is_variadic and not comp.__haystack_is_greedy__: # type: ignore[attr-defined]
there_are_only_lazy_variadics = True
for other_name, other_comp in waiting_for_input:
if name == other_name:
continue
there_are_only_lazy_variadics &= (
any(
socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore
)
and not other_comp.__haystack_is_greedy__ # type: ignore[attr-defined]
)
if not there_are_only_lazy_variadics:
continue
# Components that have defaults for all their inputs must be treated the same identical way as we treat
# lazy variadic components. If there are only components with defaults we can run them.
# If we don't do this the order of execution of the Pipeline's Components will be affected cause we
# enqueue the Components in `to_run` at the start using the order they are added in the Pipeline.
# If a Component A with defaults is added before a Component B that has no defaults, but in the Pipeline
# logic A must be executed after B it could run instead before if we don't do this check.
has_only_defaults = all(
not socket.is_mandatory for socket in comp.__haystack_input__._sockets_dict.values() # type: ignore
)
if has_only_defaults:
there_are_only_components_with_defaults = True
for other_name, other_comp in waiting_for_input:
if name == other_name:
continue
there_are_only_components_with_defaults &= all(
not s.is_mandatory for s in other_comp.__haystack_input__._sockets_dict.values() # type: ignore
)
if not there_are_only_components_with_defaults:
continue
# Find the first component that has all the inputs it needs to run
has_enough_inputs = True
for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
if input_socket.is_mandatory and input_socket.name not in last_inputs[name]:
has_enough_inputs = False
break
if input_socket.is_mandatory:
continue
if input_socket.name not in last_inputs[name]:
last_inputs[name][input_socket.name] = input_socket.default_value
if has_enough_inputs:
break
waiting_for_input.remove((name, comp))
to_run.append((name, comp))
self._enqueue_next_runnable_component(last_inputs, to_run, waiting_for_input)
if len(include_outputs_from) > 0:
for name, output in extra_outputs.items():
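Aside from the diff: the run loop's stuck-in-a-loop check in isolation, reusing the Comp, is_lazy_variadic and has_all_defaults stand-ins from the sketch after the first file. The snapshots are the sets of component names waiting for input on the last two iterations; equal, non-None snapshots with no lazy variadic or all-defaults component left mean no further progress is possible.

from typing import List, Optional, Set, Tuple

def seems_stuck(
    before_last: Optional[Set[str]], last: Optional[Set[str]], waiting: List[Tuple[str, "Comp"]]
) -> bool:
    # No two comparable snapshots yet, or the waiting set changed: progress may still happen.
    if before_last is None or last is None or before_last != last:
        return False
    # The same components are waiting twice in a row: we are only not stuck if one of them can
    # still be run lazily (non-greedy variadic) or entirely from default inputs.
    return not any(is_lazy_variadic(c) or has_all_defaults(c) for _, c in waiting)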


@ -77,9 +77,7 @@ class TestPipeline:
@patch("haystack.core.pipeline.base.is_in_jupyter")
@patch("IPython.display.Image")
@patch("IPython.display.display")
def test_show_in_notebook(
self, mock_ipython_display, mock_ipython_image, mock_is_in_jupyter, mock_to_mermaid_image
):
def test_show_in_notebook(self, mock_ipython_display, mock_ipython_image, mock_is_in_jupyter, mock_to_mermaid_image):
pipe = Pipeline()
mock_to_mermaid_image.return_value = b"some_image_data"
@ -163,9 +161,7 @@ class TestPipeline:
assert component_2.__haystack_added_to_pipeline__ is None
assert component_2.__haystack_input__._sockets_dict == {"in": InputSocket(name="in", type=int, senders=[])}
assert component_2.__haystack_output__._sockets_dict == {
"out": OutputSocket(name="out", type=int, receivers=[])
}
assert component_2.__haystack_output__._sockets_dict == {"out": OutputSocket(name="out", type=int, receivers=[])}
pipe2 = Pipeline()
pipe2.add_component("component_4", Some())
@ -175,12 +171,8 @@ class TestPipeline:
pipe2.connect("component_4", "component_2")
pipe2.connect("component_2", "component_5")
assert component_2.__haystack_added_to_pipeline__ is pipe2
assert component_2.__haystack_input__._sockets_dict == {
"in": InputSocket(name="in", type=int, senders=["component_4"])
}
assert component_2.__haystack_output__._sockets_dict == {
"out": OutputSocket(name="out", type=int, receivers=["component_5"])
}
assert component_2.__haystack_input__._sockets_dict == {"in": InputSocket(name="in", type=int, senders=["component_4"])}
assert component_2.__haystack_output__._sockets_dict == {"out": OutputSocket(name="out", type=int, receivers=["component_5"])}
# instance = pipe2.get_component("some")
# assert instance == component
@ -210,16 +202,7 @@ class TestPipeline:
pipe.connect("double", "add_default")
expected_repr = (
f"{object.__repr__(pipe)}\n"
"🧱 Metadata\n"
" - test: test\n"
"🚅 Components\n"
" - add_two: AddFixedValue\n"
" - add_default: AddFixedValue\n"
" - double: Double\n"
"🛤️ Connections\n"
" - add_two.result -> double.value (int)\n"
" - double.value -> add_default.value (int)\n"
f"{object.__repr__(pipe)}\n" "🧱 Metadata\n" " - test: test\n" "🚅 Components\n" " - add_two: AddFixedValue\n" " - add_default: AddFixedValue\n" " - double: Double\n" "🛤️ Connections\n" " - add_two.result -> double.value (int)\n" " - double.value -> add_default.value (int)\n"
)
assert repr(pipe) == expected_repr
@ -379,9 +362,7 @@ class TestPipeline:
components_seen_in_callback.append(name)
pipe = Pipeline.from_dict(
data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback)
)
pipe = Pipeline.from_dict(data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback))
assert components_seen_in_callback == ["add_two", "add_default", "double", "greet"]
add_two = pipe.graph.nodes["add_two"]["instance"]
assert add_two.add == 2
@ -403,9 +384,7 @@ class TestPipeline:
init_params["message"] = "modified test"
init_params["log_level"] = "DEBUG"
pipe = Pipeline.from_dict(
data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback_modify)
)
pipe = Pipeline.from_dict(data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback_modify))
add_two = pipe.graph.nodes["add_two"]["instance"]
assert add_two.add == 3
add_default = pipe.graph.nodes["add_default"]["instance"]
@ -558,9 +537,7 @@ class TestPipeline:
p.connect("a.x", "c.x")
p.connect("b.y", "c.y")
assert p.inputs() == {}
assert p.inputs(include_components_with_connected_inputs=True) == {
"c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}}
}
assert p.inputs(include_components_with_connected_inputs=True) == {"c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}}}
def test_describe_input_some_components_with_no_inputs(self):
A = component_class("A", input_types={}, output={"x": 0})
@ -742,16 +719,10 @@ class TestPipeline:
assert pipe.graph.nodes[node]["visits"] == 0
def test__init_to_run(self):
ComponentWithVariadic = component_class(
"ComponentWithVariadic", input_types={"in": Variadic[int]}, output_types={"out": int}
)
ComponentWithVariadic = component_class("ComponentWithVariadic", input_types={"in": Variadic[int]}, output_types={"out": int})
ComponentWithNoInputs = component_class("ComponentWithNoInputs", input_types={}, output_types={"out": int})
ComponentWithSingleInput = component_class(
"ComponentWithSingleInput", input_types={"in": int}, output_types={"out": int}
)
ComponentWithMultipleInputs = component_class(
"ComponentWithMultipleInputs", input_types={"in1": int, "in2": int}, output_types={"out": int}
)
ComponentWithSingleInput = component_class("ComponentWithSingleInput", input_types={"in": int}, output_types={"out": int})
ComponentWithMultipleInputs = component_class("ComponentWithMultipleInputs", input_types={"in1": int, "in2": int}, output_types={"out": int})
pipe = Pipeline()
pipe.add_component("with_variadic", ComponentWithVariadic())
@ -812,9 +783,7 @@ class TestPipeline:
assert id(res["first_mock"]["x"]) != id(res["second_mock"]["x"])
def test__prepare_component_input_data_with_connected_inputs(self):
MockComponent = component_class(
"MockComponent", input_types={"x": List[str], "y": str}, output_types={"z": str}
)
MockComponent = component_class("MockComponent", input_types={"x": List[str], "y": str}, output_types={"z": str})
pipe = Pipeline()
pipe.add_component("first_mock", MockComponent())
pipe.add_component("second_mock", MockComponent())
@ -828,10 +797,7 @@ class TestPipeline:
pipe = Pipeline()
res = pipe._prepare_component_input_data({"input_name": 1})
assert res == {}
assert (
"Inputs ['input_name'] were not matched to any component inputs, "
"please check your run parameters." in caplog.text
)
assert "Inputs ['input_name'] were not matched to any component inputs, " "please check your run parameters." in caplog.text
def test_connect(self):
comp1 = component_class("Comp1", output_types={"value": int})()
@ -1007,12 +973,8 @@ class TestPipeline:
def test__run_component(self, spying_tracer, caplog):
caplog.set_level(logging.INFO)
sentence_builder = component_class(
"SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"}
)()
document_builder = component_class(
"DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")}
)()
sentence_builder = component_class("SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"})()
document_builder = component_class("DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")})()
document_cleaner = component_class(
"DocumentCleaner",
input_types={"doc": Document},
@ -1058,20 +1020,12 @@ class TestPipeline:
pipe.add_component("sentence_builder", sentence_builder)
assert not pipe._component_has_enough_inputs_to_run("sentence_builder", {})
assert not pipe._component_has_enough_inputs_to_run(
"sentence_builder", {"sentence_builder": {"wrong_input_name": "blah blah"}}
)
assert pipe._component_has_enough_inputs_to_run(
"sentence_builder", {"sentence_builder": {"words": ["blah blah"]}}
)
assert not pipe._component_has_enough_inputs_to_run("sentence_builder", {"sentence_builder": {"wrong_input_name": "blah blah"}})
assert pipe._component_has_enough_inputs_to_run("sentence_builder", {"sentence_builder": {"words": ["blah blah"]}})
def test__dequeue_components_that_received_no_input(self):
sentence_builder = component_class(
"SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"}
)()
document_builder = component_class(
"DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")}
)()
sentence_builder = component_class("SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"})()
document_builder = component_class("DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")})()
pipe = Pipeline()
pipe.add_component("sentence_builder", sentence_builder)
@ -1085,12 +1039,8 @@ class TestPipeline:
assert waiting_for_input == []
def test__distribute_output(self):
document_builder = component_class(
"DocumentBuilder", input_types={"text": str}, output_types={"doc": Document, "another_doc": Document}
)()
document_cleaner = component_class(
"DocumentCleaner", input_types={"doc": Document}, output_types={"cleaned_doc": Document}
)()
document_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document, "another_doc": Document})()
document_cleaner = component_class("DocumentCleaner", input_types={"doc": Document}, output_types={"cleaned_doc": Document})()
document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()
pipe = Pipeline()
@ -1119,3 +1069,93 @@ class TestPipeline:
}
assert to_run == [("document_cleaner", document_cleaner)]
assert waiting_for_input == [("document_joiner", document_joiner)]
def test__enqueue_next_runnable_component(self):
document_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document})()
pipe = Pipeline()
inputs_by_component = {"document_builder": {"text": "some text"}}
to_run = []
waiting_for_input = [("document_builder", document_builder)]
pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input)
assert to_run == [("document_builder", document_builder)]
assert waiting_for_input == []
def test__enqueue_next_runnable_component_without_component_inputs(self):
document_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document})()
pipe = Pipeline()
inputs_by_component = {}
to_run = []
waiting_for_input = [("document_builder", document_builder)]
pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input)
assert to_run == [("document_builder", document_builder)]
assert waiting_for_input == []
def test__enqueue_next_runnable_component_with_component_with_only_variadic_non_greedy_input(self):
document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()
pipe = Pipeline()
inputs_by_component = {}
to_run = []
waiting_for_input = [("document_joiner", document_joiner)]
pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input)
assert to_run == [("document_joiner", document_joiner)]
assert waiting_for_input == []
def test__enqueue_next_runnable_component_with_component_with_only_default_input(self):
prompt_builder = PromptBuilder(template="{{ questions | join('\n') }}")
pipe = Pipeline()
inputs_by_component = {}
to_run = []
waiting_for_input = [("prompt_builder", prompt_builder)]
pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input)
assert to_run == [("prompt_builder", prompt_builder)]
assert waiting_for_input == []
def test__enqueue_next_runnable_component_with_component_with_variadic_non_greedy_and_default_input(self):
document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()
prompt_builder = PromptBuilder(template="{{ questions | join('\n') }}")
pipe = Pipeline()
inputs_by_component = {}
to_run = []
waiting_for_input = [("prompt_builder", prompt_builder), ("document_joiner", document_joiner)]
pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input)
assert to_run == [("document_joiner", document_joiner)]
assert waiting_for_input == [("prompt_builder", prompt_builder)]
def test__enqueue_next_runnable_component_with_different_components_inputs(self):
document_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document})()
document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()
prompt_builder = PromptBuilder(template="{{ questions | join('\n') }}")
pipe = Pipeline()
inputs_by_component = {"document_builder": {"text": "some text"}}
to_run = []
waiting_for_input = [("prompt_builder", prompt_builder), ("document_builder", document_builder), ("document_joiner", document_joiner)]
pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input)
assert to_run == [("document_builder", document_builder)]
assert waiting_for_input == [("prompt_builder", prompt_builder), ("document_joiner", document_joiner)]
def test__enqueue_next_runnable_component_with_different_components_without_any_input(self):
document_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document})()
document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()
prompt_builder = PromptBuilder(template="{{ questions | join('\n') }}")
pipe = Pipeline()
inputs_by_component = {}
to_run = []
waiting_for_input = [("prompt_builder", prompt_builder), ("document_builder", document_builder), ("document_joiner", document_joiner)]
pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input)
assert to_run == [("document_builder", document_builder)]
assert waiting_for_input == [
("prompt_builder", prompt_builder),
("document_joiner", document_joiner),
]