refactor: Isolate logic that finds next runnable component waiting for input (#7880)

* Fix formatting

* Isolate logic that finds next runnable component waiting for input

* Explain more lazy variadics

* Enhance logic following review suggestions

* Simplify code to use a single for

* Fix test
This commit is contained in:
Silvano Cerza 2024-06-18 16:43:19 +02:00 committed by GitHub
parent ff79da5f55
commit 15ee622b3c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 225 additions and 206 deletions

View File

@ -133,9 +133,7 @@ class PipelineBase:
} }
@classmethod @classmethod
def from_dict( def from_dict(cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs) -> T:
cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs
) -> T:
""" """
Deserializes the pipeline from a dictionary. Deserializes the pipeline from a dictionary.
@ -169,10 +167,7 @@ class PipelineBase:
importlib.import_module(module) importlib.import_module(module)
# ...then try again # ...then try again
if component_data["type"] not in component.registry: if component_data["type"] not in component.registry:
raise PipelineError( raise PipelineError(f"Successfully imported module {module} but can't find it in the component registry." "This is unexpected and most likely a bug.")
f"Successfully imported module {module} but can't find it in the component registry."
"This is unexpected and most likely a bug."
)
except (ImportError, PipelineError) as e: except (ImportError, PipelineError) as e:
raise PipelineError(f"Component '{component_data['type']}' not imported.") from e raise PipelineError(f"Component '{component_data['type']}' not imported.") from e
@ -284,15 +279,10 @@ class PipelineBase:
# Component instances must be components # Component instances must be components
if not isinstance(instance, Component): if not isinstance(instance, Component):
raise PipelineValidationError( raise PipelineValidationError(f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?")
f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?"
)
if getattr(instance, "__haystack_added_to_pipeline__", None): if getattr(instance, "__haystack_added_to_pipeline__", None):
msg = ( msg = "Component has already been added in another Pipeline. Components can't be shared between Pipelines. Create a new instance instead."
"Component has already been added in another Pipeline. "
"Components can't be shared between Pipelines. Create a new instance instead."
)
raise PipelineError(msg) raise PipelineError(msg)
setattr(instance, "__haystack_added_to_pipeline__", self) setattr(instance, "__haystack_added_to_pipeline__", self)
@ -390,35 +380,21 @@ class PipelineBase:
if sender_socket_name: if sender_socket_name:
sender_socket = from_sockets.get(sender_socket_name) sender_socket = from_sockets.get(sender_socket_name)
if not sender_socket: if not sender_socket:
raise PipelineConnectError( raise PipelineConnectError(f"'{sender} does not exist. " f"Output connections of {sender_component_name} are: " + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()]))
f"'{sender} does not exist. "
f"Output connections of {sender_component_name} are: "
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()])
)
receiver_socket: Optional[InputSocket] = None receiver_socket: Optional[InputSocket] = None
if receiver_socket_name: if receiver_socket_name:
receiver_socket = to_sockets.get(receiver_socket_name) receiver_socket = to_sockets.get(receiver_socket_name)
if not receiver_socket: if not receiver_socket:
raise PipelineConnectError( raise PipelineConnectError(f"'{receiver} does not exist. " f"Input connections of {receiver_component_name} are: " + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()]))
f"'{receiver} does not exist. "
f"Input connections of {receiver_component_name} are: "
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()])
)
# Look for a matching connection among the possible ones. # Look for a matching connection among the possible ones.
# Note that if there is more than one possible connection but two sockets match by name, they're paired. # Note that if there is more than one possible connection but two sockets match by name, they're paired.
sender_socket_candidates: List[OutputSocket] = [sender_socket] if sender_socket else list(from_sockets.values()) sender_socket_candidates: List[OutputSocket] = [sender_socket] if sender_socket else list(from_sockets.values())
receiver_socket_candidates: List[InputSocket] = ( receiver_socket_candidates: List[InputSocket] = [receiver_socket] if receiver_socket else list(to_sockets.values())
[receiver_socket] if receiver_socket else list(to_sockets.values())
)
# Find all possible connections between these two components # Find all possible connections between these two components
possible_connections = [ possible_connections = [(sender_sock, receiver_sock) for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates) if _types_are_compatible(sender_sock.type, receiver_sock.type)]
(sender_sock, receiver_sock)
for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates)
if _types_are_compatible(sender_sock.type, receiver_sock.type)
]
# We need this status for error messages, since we might need it in multiple places we calculate it here # We need this status for error messages, since we might need it in multiple places we calculate it here
status = _connections_status( status = _connections_status(
@ -431,15 +407,9 @@ class PipelineBase:
if not possible_connections: if not possible_connections:
# There's no possible connection between these two components # There's no possible connection between these two components
if len(sender_socket_candidates) == len(receiver_socket_candidates) == 1: if len(sender_socket_candidates) == len(receiver_socket_candidates) == 1:
msg = ( msg = f"Cannot connect '{sender_component_name}.{sender_socket_candidates[0].name}' with '{receiver_component_name}.{receiver_socket_candidates[0].name}': " f"their declared input and output types do not match.\n{status}"
f"Cannot connect '{sender_component_name}.{sender_socket_candidates[0].name}' with '{receiver_component_name}.{receiver_socket_candidates[0].name}': "
f"their declared input and output types do not match.\n{status}"
)
else: else:
msg = ( msg = f"Cannot connect '{sender_component_name}' with '{receiver_component_name}': " f"no matching connections available.\n{status}"
f"Cannot connect '{sender_component_name}' with '{receiver_component_name}': "
f"no matching connections available.\n{status}"
)
raise PipelineConnectError(msg) raise PipelineConnectError(msg)
if len(possible_connections) == 1: if len(possible_connections) == 1:
@ -449,9 +419,7 @@ class PipelineBase:
if len(possible_connections) > 1: if len(possible_connections) > 1:
# There are multiple possible connection, let's try to match them by name # There are multiple possible connection, let's try to match them by name
name_matches = [ name_matches = [(out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name]
(out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name
]
if len(name_matches) != 1: if len(name_matches) != 1:
# There's are either no matches or more than one, we can't pick one reliably # There's are either no matches or more than one, we can't pick one reliably
msg = ( msg = (
@ -494,10 +462,7 @@ class PipelineBase:
if receiver_socket.senders and not receiver_socket.is_variadic: if receiver_socket.senders and not receiver_socket.is_variadic:
# Only variadic input sockets can receive from multiple senders # Only variadic input sockets can receive from multiple senders
msg = ( msg = f"Cannot connect '{sender_component_name}.{sender_socket.name}' with '{receiver_component_name}.{receiver_socket.name}': " f"{receiver_component_name}.{receiver_socket.name} is already connected to {receiver_socket.senders}.\n"
f"Cannot connect '{sender_component_name}.{sender_socket.name}' with '{receiver_component_name}.{receiver_socket.name}': "
f"{receiver_component_name}.{receiver_socket.name} is already connected to {receiver_socket.senders}.\n"
)
raise PipelineConnectError(msg) raise PipelineConnectError(msg)
# Update the sockets with the new connection # Update the sockets with the new connection
@ -587,11 +552,7 @@ class PipelineBase:
A dictionary where each key is a pipeline component name and each value is a dictionary of A dictionary where each key is a pipeline component name and each value is a dictionary of
output sockets of that component. output sockets of that component.
""" """
outputs = { outputs = {comp: {socket.name: {"type": socket.type} for socket in data} for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items() if data}
comp: {socket.name: {"type": socket.type} for socket in data}
for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items()
if data
}
return outputs return outputs
def show(self) -> None: def show(self) -> None:
@ -679,9 +640,7 @@ class PipelineBase:
if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs: if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
raise ValueError(f"Missing input for component {component_name}: {socket_name}") raise ValueError(f"Missing input for component {component_name}: {socket_name}")
if socket.senders and socket_name in component_inputs and not socket.is_variadic: if socket.senders and socket_name in component_inputs and not socket.is_variadic:
raise ValueError( raise ValueError(f"Input {socket_name} for component {component_name} is already sent by {socket.senders}.")
f"Input {socket_name} for component {component_name} is already sent by {socket.senders}."
)
def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]: def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
""" """
@ -785,9 +744,7 @@ class PipelineBase:
return to_run return to_run
@classmethod @classmethod
def from_template( def from_template(cls, predefined_pipeline: PredefinedPipeline, template_params: Optional[Dict[str, Any]] = None) -> "PipelineBase":
cls, predefined_pipeline: PredefinedPipeline, template_params: Optional[Dict[str, Any]] = None
) -> "PipelineBase":
""" """
Create a Pipeline from a predefined template. See `PredefinedPipeline` for available options. Create a Pipeline from a predefined template. See `PredefinedPipeline` for available options.
@ -932,10 +889,99 @@ class PipelineBase:
# Returns the output without the keys that were distributed to other Components # Returns the output without the keys that were distributed to other Components
return {k: v for k, v in component_result.items() if k not in to_remove_from_component_result} return {k: v for k, v in component_result.items() if k not in to_remove_from_component_result}
def _enqueue_next_runnable_component(
self,
inputs_by_component: Dict[str, Dict[str, Any]],
to_run: List[Tuple[str, Component]],
waiting_for_input: List[Tuple[str, Component]],
):
"""
Finds the next Component that can be run and adds it to the queue of Components to run.
def _connections_status( :param inputs_by_component: The current state of the inputs divided by Component name
sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket] :param to_run: Queue of Components to run
): :param waiting_for_input: Queue of Components waiting for input
"""
for name, comp in waiting_for_input:
if name not in inputs_by_component:
inputs_by_component[name] = {}
# Small utility function to check if a Component has a Variadic input that is not greedy.
def is_lazy_variadic(c: Component) -> bool:
is_variadic = any(
socket.is_variadic
for socket in c.__haystack_input__._sockets_dict.values() # type: ignore
)
if not is_variadic:
return False
return not getattr(c, "__haystack_is_greedy__", False)
# Small utility function to check if a Component has all inputs with defaults.
def has_all_inputs_with_defaults(c: Component) -> bool:
return all(
not socket.is_mandatory
for socket in c.__haystack_input__._sockets_dict.values() # type: ignore
)
# Updates the inputs with the default values for the inputs that are missing
def add_missing_input_defaults(name: str, comp: Component, inputs_by_component: Dict[str, Dict[str, Any]]):
for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
if input_socket.name not in inputs_by_component[name]:
inputs_by_component[name][input_socket.name] = input_socket.default_value
all_lazy_variadic = True
all_with_default_inputs = True
filtered_waiting_for_input = []
for name, comp in waiting_for_input:
if not is_lazy_variadic(comp):
# Components with variadic inputs that are not greedy must be removed only if there's nothing else to run at this stage.
# We need to wait as long as possible to run them, so we can collect as most inputs as we can.
all_lazy_variadic = False
if not has_all_inputs_with_defaults(comp):
# Components that have defaults for all their inputs must be treated the same identical way as we treat
# lazy variadic components. If there are only components with defaults we can run them.
# If we don't do this the order of execution of the Pipeline's Components will be affected cause we
# enqueue the Components in `to_run` at the start using the order they are added in the Pipeline.
# If a Component A with defaults is added before a Component B that has no defaults, but in the Pipeline
# logic A must be executed after B it could run instead before if we don't do this check.
all_with_default_inputs = False
if not is_lazy_variadic(comp) and not has_all_inputs_with_defaults(comp):
# Keep track of the Components that are not lazy variadic and don't have all inputs with defaults.
# We'll handle these later if necessary.
filtered_waiting_for_input.append((name, comp))
# If all Components are lazy variadic or all Components have all inputs with defaults we can get one to run
if all_lazy_variadic or all_with_default_inputs:
pair = waiting_for_input.pop(0)
to_run.append(pair)
# Add missing input defaults if needed, this is a no-op for Components with Variadic inputs
add_missing_input_defaults(name, comp, inputs_by_component)
return
for name, comp in filtered_waiting_for_input:
# Find the first component that has all the inputs it needs to run
has_enough_inputs = True
for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
if input_socket.name not in inputs_by_component[name]:
if input_socket.is_mandatory:
has_enough_inputs = False
break
if input_socket.name not in inputs_by_component[name]:
inputs_by_component[name][input_socket.name] = input_socket.default_value
if has_enough_inputs:
break
waiting_for_input.remove((name, comp))
to_run.append((name, comp))
def _connections_status(sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]):
""" """
Lists the status of the sockets, for error messages. Lists the status of the sockets, for error messages.
""" """
@ -950,9 +996,7 @@ def _connections_status(
sender_status = f"sent by {','.join(receiver_socket.senders)}" sender_status = f"sent by {','.join(receiver_socket.senders)}"
else: else:
sender_status = "available" sender_status = "available"
receiver_sockets_entries.append( receiver_sockets_entries.append(f" - {receiver_socket.name}: {_type_name(receiver_socket.type)} ({sender_status})")
f" - {receiver_socket.name}: {_type_name(receiver_socket.type)} ({sender_status})"
)
receiver_sockets_list = "\n".join(receiver_sockets_entries) receiver_sockets_list = "\n".join(receiver_sockets_entries)
return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}" return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}"

View File

@ -86,10 +86,7 @@ class Pipeline(PipelineBase):
inputs[socket.name] = [] inputs[socket.name] = []
if not isinstance(res, Mapping): if not isinstance(res, Mapping):
raise PipelineRuntimeError( raise PipelineRuntimeError(f"Component '{name}' didn't return a dictionary. " "Components must always return dictionaries: check the the documentation.")
f"Component '{name}' didn't return a dictionary. "
"Components must always return dictionaries: check the the documentation."
)
span.set_tag("haystack.component.visits", self.graph.nodes[name]["visits"]) span.set_tag("haystack.component.visits", self.graph.nodes[name]["visits"])
span.set_content_tag("haystack.component.output", res) span.set_content_tag("haystack.component.output", res)
@ -274,18 +271,15 @@ class Pipeline(PipelineBase):
# Check if we're stuck in a loop. # Check if we're stuck in a loop.
# It's important to check whether previous waitings are None as it could be that no # It's important to check whether previous waitings are None as it could be that no
# Component has actually been run yet. # Component has actually been run yet.
if ( if before_last_waiting_for_input is not None and last_waiting_for_input is not None and before_last_waiting_for_input == last_waiting_for_input:
before_last_waiting_for_input is not None
and last_waiting_for_input is not None
and before_last_waiting_for_input == last_waiting_for_input
):
# Are we actually stuck or there's a lazy variadic or a component with has only default inputs waiting for input? # Are we actually stuck or there's a lazy variadic or a component with has only default inputs waiting for input?
# This is our last resort, if there's no lazy variadic or component with only default inputs waiting for input # This is our last resort, if there's no lazy variadic or component with only default inputs waiting for input
# we're stuck for real and we can't make any progress. # we're stuck for real and we can't make any progress.
for name, comp in waiting_for_input: for name, comp in waiting_for_input:
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
has_only_defaults = all( has_only_defaults = all(
not socket.is_mandatory for socket in comp.__haystack_input__._sockets_dict.values() # type: ignore not socket.is_mandatory
for socket in comp.__haystack_input__._sockets_dict.values() # type: ignore
) )
if is_variadic and not comp.__haystack_is_greedy__ or has_only_defaults: # type: ignore[attr-defined] if is_variadic and not comp.__haystack_is_greedy__ or has_only_defaults: # type: ignore[attr-defined]
break break
@ -315,69 +309,10 @@ class Pipeline(PipelineBase):
continue continue
before_last_waiting_for_input = ( before_last_waiting_for_input = last_waiting_for_input.copy() if last_waiting_for_input is not None else None
last_waiting_for_input.copy() if last_waiting_for_input is not None else None
)
last_waiting_for_input = {item[0] for item in waiting_for_input} last_waiting_for_input = {item[0] for item in waiting_for_input}
# Remove from waiting only if there is actually enough input to run self._enqueue_next_runnable_component(last_inputs, to_run, waiting_for_input)
for name, comp in waiting_for_input:
if name not in last_inputs:
last_inputs[name] = {}
# Lazy variadics must be removed only if there's nothing else to run at this stage
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
if is_variadic and not comp.__haystack_is_greedy__: # type: ignore[attr-defined]
there_are_only_lazy_variadics = True
for other_name, other_comp in waiting_for_input:
if name == other_name:
continue
there_are_only_lazy_variadics &= (
any(
socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore
)
and not other_comp.__haystack_is_greedy__ # type: ignore[attr-defined]
)
if not there_are_only_lazy_variadics:
continue
# Components that have defaults for all their inputs must be treated the same identical way as we treat
# lazy variadic components. If there are only components with defaults we can run them.
# If we don't do this the order of execution of the Pipeline's Components will be affected cause we
# enqueue the Components in `to_run` at the start using the order they are added in the Pipeline.
# If a Component A with defaults is added before a Component B that has no defaults, but in the Pipeline
# logic A must be executed after B it could run instead before if we don't do this check.
has_only_defaults = all(
not socket.is_mandatory for socket in comp.__haystack_input__._sockets_dict.values() # type: ignore
)
if has_only_defaults:
there_are_only_components_with_defaults = True
for other_name, other_comp in waiting_for_input:
if name == other_name:
continue
there_are_only_components_with_defaults &= all(
not s.is_mandatory for s in other_comp.__haystack_input__._sockets_dict.values() # type: ignore
)
if not there_are_only_components_with_defaults:
continue
# Find the first component that has all the inputs it needs to run
has_enough_inputs = True
for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
if input_socket.is_mandatory and input_socket.name not in last_inputs[name]:
has_enough_inputs = False
break
if input_socket.is_mandatory:
continue
if input_socket.name not in last_inputs[name]:
last_inputs[name][input_socket.name] = input_socket.default_value
if has_enough_inputs:
break
waiting_for_input.remove((name, comp))
to_run.append((name, comp))
if len(include_outputs_from) > 0: if len(include_outputs_from) > 0:
for name, output in extra_outputs.items(): for name, output in extra_outputs.items():

View File

@ -77,9 +77,7 @@ class TestPipeline:
@patch("haystack.core.pipeline.base.is_in_jupyter") @patch("haystack.core.pipeline.base.is_in_jupyter")
@patch("IPython.display.Image") @patch("IPython.display.Image")
@patch("IPython.display.display") @patch("IPython.display.display")
def test_show_in_notebook( def test_show_in_notebook(self, mock_ipython_display, mock_ipython_image, mock_is_in_jupyter, mock_to_mermaid_image):
self, mock_ipython_display, mock_ipython_image, mock_is_in_jupyter, mock_to_mermaid_image
):
pipe = Pipeline() pipe = Pipeline()
mock_to_mermaid_image.return_value = b"some_image_data" mock_to_mermaid_image.return_value = b"some_image_data"
@ -163,9 +161,7 @@ class TestPipeline:
assert component_2.__haystack_added_to_pipeline__ is None assert component_2.__haystack_added_to_pipeline__ is None
assert component_2.__haystack_input__._sockets_dict == {"in": InputSocket(name="in", type=int, senders=[])} assert component_2.__haystack_input__._sockets_dict == {"in": InputSocket(name="in", type=int, senders=[])}
assert component_2.__haystack_output__._sockets_dict == { assert component_2.__haystack_output__._sockets_dict == {"out": OutputSocket(name="out", type=int, receivers=[])}
"out": OutputSocket(name="out", type=int, receivers=[])
}
pipe2 = Pipeline() pipe2 = Pipeline()
pipe2.add_component("component_4", Some()) pipe2.add_component("component_4", Some())
@ -175,12 +171,8 @@ class TestPipeline:
pipe2.connect("component_4", "component_2") pipe2.connect("component_4", "component_2")
pipe2.connect("component_2", "component_5") pipe2.connect("component_2", "component_5")
assert component_2.__haystack_added_to_pipeline__ is pipe2 assert component_2.__haystack_added_to_pipeline__ is pipe2
assert component_2.__haystack_input__._sockets_dict == { assert component_2.__haystack_input__._sockets_dict == {"in": InputSocket(name="in", type=int, senders=["component_4"])}
"in": InputSocket(name="in", type=int, senders=["component_4"]) assert component_2.__haystack_output__._sockets_dict == {"out": OutputSocket(name="out", type=int, receivers=["component_5"])}
}
assert component_2.__haystack_output__._sockets_dict == {
"out": OutputSocket(name="out", type=int, receivers=["component_5"])
}
# instance = pipe2.get_component("some") # instance = pipe2.get_component("some")
# assert instance == component # assert instance == component
@ -210,16 +202,7 @@ class TestPipeline:
pipe.connect("double", "add_default") pipe.connect("double", "add_default")
expected_repr = ( expected_repr = (
f"{object.__repr__(pipe)}\n" f"{object.__repr__(pipe)}\n" "🧱 Metadata\n" " - test: test\n" "🚅 Components\n" " - add_two: AddFixedValue\n" " - add_default: AddFixedValue\n" " - double: Double\n" "🛤️ Connections\n" " - add_two.result -> double.value (int)\n" " - double.value -> add_default.value (int)\n"
"🧱 Metadata\n"
" - test: test\n"
"🚅 Components\n"
" - add_two: AddFixedValue\n"
" - add_default: AddFixedValue\n"
" - double: Double\n"
"🛤️ Connections\n"
" - add_two.result -> double.value (int)\n"
" - double.value -> add_default.value (int)\n"
) )
assert repr(pipe) == expected_repr assert repr(pipe) == expected_repr
@ -379,9 +362,7 @@ class TestPipeline:
components_seen_in_callback.append(name) components_seen_in_callback.append(name)
pipe = Pipeline.from_dict( pipe = Pipeline.from_dict(data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback))
data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback)
)
assert components_seen_in_callback == ["add_two", "add_default", "double", "greet"] assert components_seen_in_callback == ["add_two", "add_default", "double", "greet"]
add_two = pipe.graph.nodes["add_two"]["instance"] add_two = pipe.graph.nodes["add_two"]["instance"]
assert add_two.add == 2 assert add_two.add == 2
@ -403,9 +384,7 @@ class TestPipeline:
init_params["message"] = "modified test" init_params["message"] = "modified test"
init_params["log_level"] = "DEBUG" init_params["log_level"] = "DEBUG"
pipe = Pipeline.from_dict( pipe = Pipeline.from_dict(data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback_modify))
data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback_modify)
)
add_two = pipe.graph.nodes["add_two"]["instance"] add_two = pipe.graph.nodes["add_two"]["instance"]
assert add_two.add == 3 assert add_two.add == 3
add_default = pipe.graph.nodes["add_default"]["instance"] add_default = pipe.graph.nodes["add_default"]["instance"]
@ -558,9 +537,7 @@ class TestPipeline:
p.connect("a.x", "c.x") p.connect("a.x", "c.x")
p.connect("b.y", "c.y") p.connect("b.y", "c.y")
assert p.inputs() == {} assert p.inputs() == {}
assert p.inputs(include_components_with_connected_inputs=True) == { assert p.inputs(include_components_with_connected_inputs=True) == {"c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}}}
"c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}}
}
def test_describe_input_some_components_with_no_inputs(self): def test_describe_input_some_components_with_no_inputs(self):
A = component_class("A", input_types={}, output={"x": 0}) A = component_class("A", input_types={}, output={"x": 0})
@ -742,16 +719,10 @@ class TestPipeline:
assert pipe.graph.nodes[node]["visits"] == 0 assert pipe.graph.nodes[node]["visits"] == 0
def test__init_to_run(self): def test__init_to_run(self):
ComponentWithVariadic = component_class( ComponentWithVariadic = component_class("ComponentWithVariadic", input_types={"in": Variadic[int]}, output_types={"out": int})
"ComponentWithVariadic", input_types={"in": Variadic[int]}, output_types={"out": int}
)
ComponentWithNoInputs = component_class("ComponentWithNoInputs", input_types={}, output_types={"out": int}) ComponentWithNoInputs = component_class("ComponentWithNoInputs", input_types={}, output_types={"out": int})
ComponentWithSingleInput = component_class( ComponentWithSingleInput = component_class("ComponentWithSingleInput", input_types={"in": int}, output_types={"out": int})
"ComponentWithSingleInput", input_types={"in": int}, output_types={"out": int} ComponentWithMultipleInputs = component_class("ComponentWithMultipleInputs", input_types={"in1": int, "in2": int}, output_types={"out": int})
)
ComponentWithMultipleInputs = component_class(
"ComponentWithMultipleInputs", input_types={"in1": int, "in2": int}, output_types={"out": int}
)
pipe = Pipeline() pipe = Pipeline()
pipe.add_component("with_variadic", ComponentWithVariadic()) pipe.add_component("with_variadic", ComponentWithVariadic())
@ -812,9 +783,7 @@ class TestPipeline:
assert id(res["first_mock"]["x"]) != id(res["second_mock"]["x"]) assert id(res["first_mock"]["x"]) != id(res["second_mock"]["x"])
def test__prepare_component_input_data_with_connected_inputs(self): def test__prepare_component_input_data_with_connected_inputs(self):
MockComponent = component_class( MockComponent = component_class("MockComponent", input_types={"x": List[str], "y": str}, output_types={"z": str})
"MockComponent", input_types={"x": List[str], "y": str}, output_types={"z": str}
)
pipe = Pipeline() pipe = Pipeline()
pipe.add_component("first_mock", MockComponent()) pipe.add_component("first_mock", MockComponent())
pipe.add_component("second_mock", MockComponent()) pipe.add_component("second_mock", MockComponent())
@ -828,10 +797,7 @@ class TestPipeline:
pipe = Pipeline() pipe = Pipeline()
res = pipe._prepare_component_input_data({"input_name": 1}) res = pipe._prepare_component_input_data({"input_name": 1})
assert res == {} assert res == {}
assert ( assert "Inputs ['input_name'] were not matched to any component inputs, " "please check your run parameters." in caplog.text
"Inputs ['input_name'] were not matched to any component inputs, "
"please check your run parameters." in caplog.text
)
def test_connect(self): def test_connect(self):
comp1 = component_class("Comp1", output_types={"value": int})() comp1 = component_class("Comp1", output_types={"value": int})()
@ -1007,12 +973,8 @@ class TestPipeline:
def test__run_component(self, spying_tracer, caplog): def test__run_component(self, spying_tracer, caplog):
caplog.set_level(logging.INFO) caplog.set_level(logging.INFO)
sentence_builder = component_class( sentence_builder = component_class("SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"})()
"SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"} document_builder = component_class("DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")})()
)()
document_builder = component_class(
"DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")}
)()
document_cleaner = component_class( document_cleaner = component_class(
"DocumentCleaner", "DocumentCleaner",
input_types={"doc": Document}, input_types={"doc": Document},
@ -1058,20 +1020,12 @@ class TestPipeline:
pipe.add_component("sentence_builder", sentence_builder) pipe.add_component("sentence_builder", sentence_builder)
assert not pipe._component_has_enough_inputs_to_run("sentence_builder", {}) assert not pipe._component_has_enough_inputs_to_run("sentence_builder", {})
assert not pipe._component_has_enough_inputs_to_run( assert not pipe._component_has_enough_inputs_to_run("sentence_builder", {"sentence_builder": {"wrong_input_name": "blah blah"}})
"sentence_builder", {"sentence_builder": {"wrong_input_name": "blah blah"}} assert pipe._component_has_enough_inputs_to_run("sentence_builder", {"sentence_builder": {"words": ["blah blah"]}})
)
assert pipe._component_has_enough_inputs_to_run(
"sentence_builder", {"sentence_builder": {"words": ["blah blah"]}}
)
def test__dequeue_components_that_received_no_input(self): def test__dequeue_components_that_received_no_input(self):
sentence_builder = component_class( sentence_builder = component_class("SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"})()
"SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"} document_builder = component_class("DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")})()
)()
document_builder = component_class(
"DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")}
)()
pipe = Pipeline() pipe = Pipeline()
pipe.add_component("sentence_builder", sentence_builder) pipe.add_component("sentence_builder", sentence_builder)
@ -1085,12 +1039,8 @@ class TestPipeline:
assert waiting_for_input == [] assert waiting_for_input == []
def test__distribute_output(self): def test__distribute_output(self):
document_builder = component_class( document_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document, "another_doc": Document})()
"DocumentBuilder", input_types={"text": str}, output_types={"doc": Document, "another_doc": Document} document_cleaner = component_class("DocumentCleaner", input_types={"doc": Document}, output_types={"cleaned_doc": Document})()
)()
document_cleaner = component_class(
"DocumentCleaner", input_types={"doc": Document}, output_types={"cleaned_doc": Document}
)()
document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})() document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()
pipe = Pipeline() pipe = Pipeline()
@ -1119,3 +1069,93 @@ class TestPipeline:
} }
assert to_run == [("document_cleaner", document_cleaner)] assert to_run == [("document_cleaner", document_cleaner)]
assert waiting_for_input == [("document_joiner", document_joiner)] assert waiting_for_input == [("document_joiner", document_joiner)]
def test__enqueue_next_runnable_component(self):
document_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document})()
pipe = Pipeline()
inputs_by_component = {"document_builder": {"text": "some text"}}
to_run = []
waiting_for_input = [("document_builder", document_builder)]
pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input)
assert to_run == [("document_builder", document_builder)]
assert waiting_for_input == []
def test__enqueue_next_runnable_component_without_component_inputs(self):
document_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document})()
pipe = Pipeline()
inputs_by_component = {}
to_run = []
waiting_for_input = [("document_builder", document_builder)]
pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input)
assert to_run == [("document_builder", document_builder)]
assert waiting_for_input == []
def test__enqueue_next_runnable_component_with_component_with_only_variadic_non_greedy_input(self):
document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()
pipe = Pipeline()
inputs_by_component = {}
to_run = []
waiting_for_input = [("document_joiner", document_joiner)]
pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input)
assert to_run == [("document_joiner", document_joiner)]
assert waiting_for_input == []
def test__enqueue_next_runnable_component_with_component_with_only_default_input(self):
prompt_builder = PromptBuilder(template="{{ questions | join('\n') }}")
pipe = Pipeline()
inputs_by_component = {}
to_run = []
waiting_for_input = [("prompt_builder", prompt_builder)]
pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input)
assert to_run == [("prompt_builder", prompt_builder)]
assert waiting_for_input == []
def test__enqueue_next_runnable_component_with_component_with_variadic_non_greedy_and_default_input(self):
document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()
prompt_builder = PromptBuilder(template="{{ questions | join('\n') }}")
pipe = Pipeline()
inputs_by_component = {}
to_run = []
waiting_for_input = [("prompt_builder", prompt_builder), ("document_joiner", document_joiner)]
pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input)
assert to_run == [("document_joiner", document_joiner)]
assert waiting_for_input == [("prompt_builder", prompt_builder)]
def test__enqueue_next_runnable_component_with_different_components_inputs(self):
document_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document})()
document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()
prompt_builder = PromptBuilder(template="{{ questions | join('\n') }}")
pipe = Pipeline()
inputs_by_component = {"document_builder": {"text": "some text"}}
to_run = []
waiting_for_input = [("prompt_builder", prompt_builder), ("document_builder", document_builder), ("document_joiner", document_joiner)]
pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input)
assert to_run == [("document_builder", document_builder)]
assert waiting_for_input == [("prompt_builder", prompt_builder), ("document_joiner", document_joiner)]
def test__enqueue_next_runnable_component_with_different_components_without_any_input(self):
document_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document})()
document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()
prompt_builder = PromptBuilder(template="{{ questions | join('\n') }}")
pipe = Pipeline()
inputs_by_component = {}
to_run = []
waiting_for_input = [("prompt_builder", prompt_builder), ("document_builder", document_builder), ("document_joiner", document_joiner)]
pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input)
assert to_run == [("document_builder", document_builder)]
assert waiting_for_input == [
("prompt_builder", prompt_builder),
("document_joiner", document_joiner),
]