mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-22 22:53:41 +00:00
refactor: Isolate logic that finds next runnable component waiting for input (#7880)
* Fix formatting * Isolate logic that finds next runnable component waiting for input * Explain more lazy variadics * Enhance logic following review suggestions * Simplify code to use a single for * Fix test
This commit is contained in:
parent
ff79da5f55
commit
15ee622b3c
@ -133,9 +133,7 @@ class PipelineBase:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs
|
||||
) -> T:
|
||||
def from_dict(cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs) -> T:
|
||||
"""
|
||||
Deserializes the pipeline from a dictionary.
|
||||
|
||||
@ -169,10 +167,7 @@ class PipelineBase:
|
||||
importlib.import_module(module)
|
||||
# ...then try again
|
||||
if component_data["type"] not in component.registry:
|
||||
raise PipelineError(
|
||||
f"Successfully imported module {module} but can't find it in the component registry."
|
||||
"This is unexpected and most likely a bug."
|
||||
)
|
||||
raise PipelineError(f"Successfully imported module {module} but can't find it in the component registry." "This is unexpected and most likely a bug.")
|
||||
except (ImportError, PipelineError) as e:
|
||||
raise PipelineError(f"Component '{component_data['type']}' not imported.") from e
|
||||
|
||||
@ -284,15 +279,10 @@ class PipelineBase:
|
||||
|
||||
# Component instances must be components
|
||||
if not isinstance(instance, Component):
|
||||
raise PipelineValidationError(
|
||||
f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?"
|
||||
)
|
||||
raise PipelineValidationError(f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?")
|
||||
|
||||
if getattr(instance, "__haystack_added_to_pipeline__", None):
|
||||
msg = (
|
||||
"Component has already been added in another Pipeline. "
|
||||
"Components can't be shared between Pipelines. Create a new instance instead."
|
||||
)
|
||||
msg = "Component has already been added in another Pipeline. Components can't be shared between Pipelines. Create a new instance instead."
|
||||
raise PipelineError(msg)
|
||||
|
||||
setattr(instance, "__haystack_added_to_pipeline__", self)
|
||||
@ -390,35 +380,21 @@ class PipelineBase:
|
||||
if sender_socket_name:
|
||||
sender_socket = from_sockets.get(sender_socket_name)
|
||||
if not sender_socket:
|
||||
raise PipelineConnectError(
|
||||
f"'{sender} does not exist. "
|
||||
f"Output connections of {sender_component_name} are: "
|
||||
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()])
|
||||
)
|
||||
raise PipelineConnectError(f"'{sender} does not exist. " f"Output connections of {sender_component_name} are: " + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()]))
|
||||
|
||||
receiver_socket: Optional[InputSocket] = None
|
||||
if receiver_socket_name:
|
||||
receiver_socket = to_sockets.get(receiver_socket_name)
|
||||
if not receiver_socket:
|
||||
raise PipelineConnectError(
|
||||
f"'{receiver} does not exist. "
|
||||
f"Input connections of {receiver_component_name} are: "
|
||||
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()])
|
||||
)
|
||||
raise PipelineConnectError(f"'{receiver} does not exist. " f"Input connections of {receiver_component_name} are: " + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()]))
|
||||
|
||||
# Look for a matching connection among the possible ones.
|
||||
# Note that if there is more than one possible connection but two sockets match by name, they're paired.
|
||||
sender_socket_candidates: List[OutputSocket] = [sender_socket] if sender_socket else list(from_sockets.values())
|
||||
receiver_socket_candidates: List[InputSocket] = (
|
||||
[receiver_socket] if receiver_socket else list(to_sockets.values())
|
||||
)
|
||||
receiver_socket_candidates: List[InputSocket] = [receiver_socket] if receiver_socket else list(to_sockets.values())
|
||||
|
||||
# Find all possible connections between these two components
|
||||
possible_connections = [
|
||||
(sender_sock, receiver_sock)
|
||||
for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates)
|
||||
if _types_are_compatible(sender_sock.type, receiver_sock.type)
|
||||
]
|
||||
possible_connections = [(sender_sock, receiver_sock) for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates) if _types_are_compatible(sender_sock.type, receiver_sock.type)]
|
||||
|
||||
# We need this status for error messages, since we might need it in multiple places we calculate it here
|
||||
status = _connections_status(
|
||||
@ -431,15 +407,9 @@ class PipelineBase:
|
||||
if not possible_connections:
|
||||
# There's no possible connection between these two components
|
||||
if len(sender_socket_candidates) == len(receiver_socket_candidates) == 1:
|
||||
msg = (
|
||||
f"Cannot connect '{sender_component_name}.{sender_socket_candidates[0].name}' with '{receiver_component_name}.{receiver_socket_candidates[0].name}': "
|
||||
f"their declared input and output types do not match.\n{status}"
|
||||
)
|
||||
msg = f"Cannot connect '{sender_component_name}.{sender_socket_candidates[0].name}' with '{receiver_component_name}.{receiver_socket_candidates[0].name}': " f"their declared input and output types do not match.\n{status}"
|
||||
else:
|
||||
msg = (
|
||||
f"Cannot connect '{sender_component_name}' with '{receiver_component_name}': "
|
||||
f"no matching connections available.\n{status}"
|
||||
)
|
||||
msg = f"Cannot connect '{sender_component_name}' with '{receiver_component_name}': " f"no matching connections available.\n{status}"
|
||||
raise PipelineConnectError(msg)
|
||||
|
||||
if len(possible_connections) == 1:
|
||||
@ -449,9 +419,7 @@ class PipelineBase:
|
||||
|
||||
if len(possible_connections) > 1:
|
||||
# There are multiple possible connections, let's try to match them by name
|
||||
name_matches = [
|
||||
(out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name
|
||||
]
|
||||
name_matches = [(out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name]
|
||||
if len(name_matches) != 1:
|
||||
# There are either no matches or more than one, we can't pick one reliably
|
||||
msg = (
|
||||
@ -494,10 +462,7 @@ class PipelineBase:
|
||||
|
||||
if receiver_socket.senders and not receiver_socket.is_variadic:
|
||||
# Only variadic input sockets can receive from multiple senders
|
||||
msg = (
|
||||
f"Cannot connect '{sender_component_name}.{sender_socket.name}' with '{receiver_component_name}.{receiver_socket.name}': "
|
||||
f"{receiver_component_name}.{receiver_socket.name} is already connected to {receiver_socket.senders}.\n"
|
||||
)
|
||||
msg = f"Cannot connect '{sender_component_name}.{sender_socket.name}' with '{receiver_component_name}.{receiver_socket.name}': " f"{receiver_component_name}.{receiver_socket.name} is already connected to {receiver_socket.senders}.\n"
|
||||
raise PipelineConnectError(msg)
|
||||
|
||||
# Update the sockets with the new connection
|
||||
@ -587,11 +552,7 @@ class PipelineBase:
|
||||
A dictionary where each key is a pipeline component name and each value is a dictionary of
|
||||
output sockets of that component.
|
||||
"""
|
||||
outputs = {
|
||||
comp: {socket.name: {"type": socket.type} for socket in data}
|
||||
for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items()
|
||||
if data
|
||||
}
|
||||
outputs = {comp: {socket.name: {"type": socket.type} for socket in data} for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items() if data}
|
||||
return outputs
|
||||
|
||||
def show(self) -> None:
|
||||
@ -679,9 +640,7 @@ class PipelineBase:
|
||||
if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
|
||||
raise ValueError(f"Missing input for component {component_name}: {socket_name}")
|
||||
if socket.senders and socket_name in component_inputs and not socket.is_variadic:
|
||||
raise ValueError(
|
||||
f"Input {socket_name} for component {component_name} is already sent by {socket.senders}."
|
||||
)
|
||||
raise ValueError(f"Input {socket_name} for component {component_name} is already sent by {socket.senders}.")
|
||||
|
||||
def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
@ -785,9 +744,7 @@ class PipelineBase:
|
||||
return to_run
|
||||
|
||||
@classmethod
|
||||
def from_template(
|
||||
cls, predefined_pipeline: PredefinedPipeline, template_params: Optional[Dict[str, Any]] = None
|
||||
) -> "PipelineBase":
|
||||
def from_template(cls, predefined_pipeline: PredefinedPipeline, template_params: Optional[Dict[str, Any]] = None) -> "PipelineBase":
|
||||
"""
|
||||
Create a Pipeline from a predefined template. See `PredefinedPipeline` for available options.
|
||||
|
||||
@ -932,10 +889,99 @@ class PipelineBase:
|
||||
# Returns the output without the keys that were distributed to other Components
|
||||
return {k: v for k, v in component_result.items() if k not in to_remove_from_component_result}
|
||||
|
||||
def _enqueue_next_runnable_component(
    self,
    inputs_by_component: Dict[str, Dict[str, Any]],
    to_run: List[Tuple[str, Component]],
    waiting_for_input: List[Tuple[str, Component]],
):
    """
    Finds the next Component that can be run and adds it to the queue of Components to run.

    One `(name, Component)` pair is moved from `waiting_for_input` to the end of `to_run`.
    If `waiting_for_input` is empty nothing happens. The pair is chosen as follows:

    1. If every waiting Component is a lazy variadic (has a Variadic input and is not greedy), or every
       waiting Component has defaults for all its inputs, the first Component in the queue is chosen and
       its missing inputs are filled with the sockets' default values.
    2. Otherwise the first Component that is neither lazy variadic nor all-defaults and that already has
       all its mandatory inputs is chosen.
    3. If no such Component has enough inputs, the last candidate examined is chosen anyway; if there are
       no candidates at all, the last Component in `waiting_for_input` is chosen.

    All three arguments are modified in place. An entry is created in `inputs_by_component` for every
    waiting Component that doesn't have one yet.

    :param inputs_by_component: The current state of the inputs divided by Component name
    :param to_run: Queue of Components to run
    :param waiting_for_input: Queue of Components waiting for input
    """
    if not waiting_for_input:
        # Nothing is waiting, so there is nothing to enqueue.
        return

    for name, _ in waiting_for_input:
        inputs_by_component.setdefault(name, {})

    # Small utility function to check if a Component has a Variadic input that is not greedy.
    def is_lazy_variadic(c: Component) -> bool:
        is_variadic = any(
            socket.is_variadic
            for socket in c.__haystack_input__._sockets_dict.values()  # type: ignore
        )
        if not is_variadic:
            return False
        return not getattr(c, "__haystack_is_greedy__", False)

    # Small utility function to check if a Component has all inputs with defaults.
    def has_all_inputs_with_defaults(c: Component) -> bool:
        return all(
            not socket.is_mandatory
            for socket in c.__haystack_input__._sockets_dict.values()  # type: ignore
        )

    # Updates the inputs with the default values for the inputs that are missing
    def add_missing_input_defaults(name: str, comp: Component, inputs_by_component: Dict[str, Dict[str, Any]]):
        for input_socket in comp.__haystack_input__._sockets_dict.values():  # type: ignore
            if input_socket.name not in inputs_by_component[name]:
                inputs_by_component[name][input_socket.name] = input_socket.default_value

    all_lazy_variadic = True
    all_with_default_inputs = True

    filtered_waiting_for_input = []

    for name, comp in waiting_for_input:
        # Evaluate each predicate once per Component; both scan all the input sockets.
        lazy_variadic = is_lazy_variadic(comp)
        only_default_inputs = has_all_inputs_with_defaults(comp)

        if not lazy_variadic:
            # Components with variadic inputs that are not greedy must be removed only if there's nothing
            # else to run at this stage.
            # We need to wait as long as possible to run them, so we can collect as many inputs as we can.
            all_lazy_variadic = False

        if not only_default_inputs:
            # Components that have defaults for all their inputs must be treated exactly the same way as
            # lazy variadic components. If there are only components with defaults we can run them.
            # If we don't do this the order of execution of the Pipeline's Components will be affected
            # because we enqueue the Components in `to_run` at the start using the order they are added
            # in the Pipeline.
            # If a Component A with defaults is added before a Component B that has no defaults, but in
            # the Pipeline logic A must be executed after B, A could run before B without this check.
            all_with_default_inputs = False

        if not lazy_variadic and not only_default_inputs:
            # Keep track of the Components that are not lazy variadic and don't have all inputs with defaults.
            # We'll handle these later if necessary.
            filtered_waiting_for_input.append((name, comp))

    # If all Components are lazy variadic or all Components have all inputs with defaults we can get one to run
    if all_lazy_variadic or all_with_default_inputs:
        pair = waiting_for_input.pop(0)
        to_run.append(pair)
        # Fill the defaults of the Component that was actually dequeued. The loop variables above still
        # point at the last waiting Component, so they must not be reused here.
        # This is a no-op for Components with Variadic inputs.
        popped_name, popped_comp = pair
        add_missing_input_defaults(popped_name, popped_comp, inputs_by_component)
        return

    # When there are no candidates, fall back to the last Component waiting for input.
    selected = waiting_for_input[-1]

    for name, comp in filtered_waiting_for_input:
        # Find the first component that has all the inputs it needs to run.
        # If none has, the last candidate examined stays selected.
        selected = (name, comp)
        has_enough_inputs = True
        for input_socket in comp.__haystack_input__._sockets_dict.values():  # type: ignore
            if input_socket.name not in inputs_by_component[name]:
                if input_socket.is_mandatory:
                    has_enough_inputs = False
                    break
                # Optional input that was not provided: use the socket's default value
                inputs_by_component[name][input_socket.name] = input_socket.default_value

        if has_enough_inputs:
            break

    waiting_for_input.remove(selected)
    to_run.append(selected)
|
||||
|
||||
|
||||
def _connections_status(sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]):
|
||||
"""
|
||||
Lists the status of the sockets, for error messages.
|
||||
"""
|
||||
@ -950,9 +996,7 @@ def _connections_status(
|
||||
sender_status = f"sent by {','.join(receiver_socket.senders)}"
|
||||
else:
|
||||
sender_status = "available"
|
||||
receiver_sockets_entries.append(
|
||||
f" - {receiver_socket.name}: {_type_name(receiver_socket.type)} ({sender_status})"
|
||||
)
|
||||
receiver_sockets_entries.append(f" - {receiver_socket.name}: {_type_name(receiver_socket.type)} ({sender_status})")
|
||||
receiver_sockets_list = "\n".join(receiver_sockets_entries)
|
||||
|
||||
return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}"
|
||||
|
@ -86,10 +86,7 @@ class Pipeline(PipelineBase):
|
||||
inputs[socket.name] = []
|
||||
|
||||
if not isinstance(res, Mapping):
|
||||
raise PipelineRuntimeError(
|
||||
f"Component '{name}' didn't return a dictionary. "
|
||||
"Components must always return dictionaries: check the the documentation."
|
||||
)
|
||||
raise PipelineRuntimeError(f"Component '{name}' didn't return a dictionary. " "Components must always return dictionaries: check the the documentation.")
|
||||
span.set_tag("haystack.component.visits", self.graph.nodes[name]["visits"])
|
||||
span.set_content_tag("haystack.component.output", res)
|
||||
|
||||
@ -274,18 +271,15 @@ class Pipeline(PipelineBase):
|
||||
# Check if we're stuck in a loop.
|
||||
# It's important to check whether previous waitings are None as it could be that no
|
||||
# Component has actually been run yet.
|
||||
if (
|
||||
before_last_waiting_for_input is not None
|
||||
and last_waiting_for_input is not None
|
||||
and before_last_waiting_for_input == last_waiting_for_input
|
||||
):
|
||||
if before_last_waiting_for_input is not None and last_waiting_for_input is not None and before_last_waiting_for_input == last_waiting_for_input:
|
||||
# Are we actually stuck, or is there a lazy variadic or a component that has only default inputs waiting for input?
|
||||
# This is our last resort, if there's no lazy variadic or component with only default inputs waiting for input
|
||||
# we're stuck for real and we can't make any progress.
|
||||
for name, comp in waiting_for_input:
|
||||
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
|
||||
has_only_defaults = all(
|
||||
not socket.is_mandatory for socket in comp.__haystack_input__._sockets_dict.values() # type: ignore
|
||||
not socket.is_mandatory
|
||||
for socket in comp.__haystack_input__._sockets_dict.values() # type: ignore
|
||||
)
|
||||
if is_variadic and not comp.__haystack_is_greedy__ or has_only_defaults: # type: ignore[attr-defined]
|
||||
break
|
||||
@ -315,69 +309,10 @@ class Pipeline(PipelineBase):
|
||||
|
||||
continue
|
||||
|
||||
before_last_waiting_for_input = (
|
||||
last_waiting_for_input.copy() if last_waiting_for_input is not None else None
|
||||
)
|
||||
before_last_waiting_for_input = last_waiting_for_input.copy() if last_waiting_for_input is not None else None
|
||||
last_waiting_for_input = {item[0] for item in waiting_for_input}
|
||||
|
||||
# Remove from waiting only if there is actually enough input to run
|
||||
for name, comp in waiting_for_input:
|
||||
if name not in last_inputs:
|
||||
last_inputs[name] = {}
|
||||
|
||||
# Lazy variadics must be removed only if there's nothing else to run at this stage
|
||||
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
|
||||
if is_variadic and not comp.__haystack_is_greedy__: # type: ignore[attr-defined]
|
||||
there_are_only_lazy_variadics = True
|
||||
for other_name, other_comp in waiting_for_input:
|
||||
if name == other_name:
|
||||
continue
|
||||
there_are_only_lazy_variadics &= (
|
||||
any(
|
||||
socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values() # type: ignore
|
||||
)
|
||||
and not other_comp.__haystack_is_greedy__ # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
if not there_are_only_lazy_variadics:
|
||||
continue
|
||||
|
||||
# Components that have defaults for all their inputs must be treated the same identical way as we treat
|
||||
# lazy variadic components. If there are only components with defaults we can run them.
|
||||
# If we don't do this the order of execution of the Pipeline's Components will be affected cause we
|
||||
# enqueue the Components in `to_run` at the start using the order they are added in the Pipeline.
|
||||
# If a Component A with defaults is added before a Component B that has no defaults, but in the Pipeline
|
||||
# logic A must be executed after B it could run instead before if we don't do this check.
|
||||
has_only_defaults = all(
|
||||
not socket.is_mandatory for socket in comp.__haystack_input__._sockets_dict.values() # type: ignore
|
||||
)
|
||||
if has_only_defaults:
|
||||
there_are_only_components_with_defaults = True
|
||||
for other_name, other_comp in waiting_for_input:
|
||||
if name == other_name:
|
||||
continue
|
||||
there_are_only_components_with_defaults &= all(
|
||||
not s.is_mandatory for s in other_comp.__haystack_input__._sockets_dict.values() # type: ignore
|
||||
)
|
||||
if not there_are_only_components_with_defaults:
|
||||
continue
|
||||
|
||||
# Find the first component that has all the inputs it needs to run
|
||||
has_enough_inputs = True
|
||||
for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore
|
||||
if input_socket.is_mandatory and input_socket.name not in last_inputs[name]:
|
||||
has_enough_inputs = False
|
||||
break
|
||||
if input_socket.is_mandatory:
|
||||
continue
|
||||
|
||||
if input_socket.name not in last_inputs[name]:
|
||||
last_inputs[name][input_socket.name] = input_socket.default_value
|
||||
if has_enough_inputs:
|
||||
break
|
||||
|
||||
waiting_for_input.remove((name, comp))
|
||||
to_run.append((name, comp))
|
||||
self._enqueue_next_runnable_component(last_inputs, to_run, waiting_for_input)
|
||||
|
||||
if len(include_outputs_from) > 0:
|
||||
for name, output in extra_outputs.items():
|
||||
|
@ -77,9 +77,7 @@ class TestPipeline:
|
||||
@patch("haystack.core.pipeline.base.is_in_jupyter")
|
||||
@patch("IPython.display.Image")
|
||||
@patch("IPython.display.display")
|
||||
def test_show_in_notebook(
|
||||
self, mock_ipython_display, mock_ipython_image, mock_is_in_jupyter, mock_to_mermaid_image
|
||||
):
|
||||
def test_show_in_notebook(self, mock_ipython_display, mock_ipython_image, mock_is_in_jupyter, mock_to_mermaid_image):
|
||||
pipe = Pipeline()
|
||||
|
||||
mock_to_mermaid_image.return_value = b"some_image_data"
|
||||
@ -163,9 +161,7 @@ class TestPipeline:
|
||||
|
||||
assert component_2.__haystack_added_to_pipeline__ is None
|
||||
assert component_2.__haystack_input__._sockets_dict == {"in": InputSocket(name="in", type=int, senders=[])}
|
||||
assert component_2.__haystack_output__._sockets_dict == {
|
||||
"out": OutputSocket(name="out", type=int, receivers=[])
|
||||
}
|
||||
assert component_2.__haystack_output__._sockets_dict == {"out": OutputSocket(name="out", type=int, receivers=[])}
|
||||
|
||||
pipe2 = Pipeline()
|
||||
pipe2.add_component("component_4", Some())
|
||||
@ -175,12 +171,8 @@ class TestPipeline:
|
||||
pipe2.connect("component_4", "component_2")
|
||||
pipe2.connect("component_2", "component_5")
|
||||
assert component_2.__haystack_added_to_pipeline__ is pipe2
|
||||
assert component_2.__haystack_input__._sockets_dict == {
|
||||
"in": InputSocket(name="in", type=int, senders=["component_4"])
|
||||
}
|
||||
assert component_2.__haystack_output__._sockets_dict == {
|
||||
"out": OutputSocket(name="out", type=int, receivers=["component_5"])
|
||||
}
|
||||
assert component_2.__haystack_input__._sockets_dict == {"in": InputSocket(name="in", type=int, senders=["component_4"])}
|
||||
assert component_2.__haystack_output__._sockets_dict == {"out": OutputSocket(name="out", type=int, receivers=["component_5"])}
|
||||
|
||||
# instance = pipe2.get_component("some")
|
||||
# assert instance == component
|
||||
@ -210,16 +202,7 @@ class TestPipeline:
|
||||
pipe.connect("double", "add_default")
|
||||
|
||||
expected_repr = (
|
||||
f"{object.__repr__(pipe)}\n"
|
||||
"🧱 Metadata\n"
|
||||
" - test: test\n"
|
||||
"🚅 Components\n"
|
||||
" - add_two: AddFixedValue\n"
|
||||
" - add_default: AddFixedValue\n"
|
||||
" - double: Double\n"
|
||||
"🛤️ Connections\n"
|
||||
" - add_two.result -> double.value (int)\n"
|
||||
" - double.value -> add_default.value (int)\n"
|
||||
f"{object.__repr__(pipe)}\n" "🧱 Metadata\n" " - test: test\n" "🚅 Components\n" " - add_two: AddFixedValue\n" " - add_default: AddFixedValue\n" " - double: Double\n" "🛤️ Connections\n" " - add_two.result -> double.value (int)\n" " - double.value -> add_default.value (int)\n"
|
||||
)
|
||||
|
||||
assert repr(pipe) == expected_repr
|
||||
@ -379,9 +362,7 @@ class TestPipeline:
|
||||
|
||||
components_seen_in_callback.append(name)
|
||||
|
||||
pipe = Pipeline.from_dict(
|
||||
data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback)
|
||||
)
|
||||
pipe = Pipeline.from_dict(data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback))
|
||||
assert components_seen_in_callback == ["add_two", "add_default", "double", "greet"]
|
||||
add_two = pipe.graph.nodes["add_two"]["instance"]
|
||||
assert add_two.add == 2
|
||||
@ -403,9 +384,7 @@ class TestPipeline:
|
||||
init_params["message"] = "modified test"
|
||||
init_params["log_level"] = "DEBUG"
|
||||
|
||||
pipe = Pipeline.from_dict(
|
||||
data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback_modify)
|
||||
)
|
||||
pipe = Pipeline.from_dict(data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback_modify))
|
||||
add_two = pipe.graph.nodes["add_two"]["instance"]
|
||||
assert add_two.add == 3
|
||||
add_default = pipe.graph.nodes["add_default"]["instance"]
|
||||
@ -558,9 +537,7 @@ class TestPipeline:
|
||||
p.connect("a.x", "c.x")
|
||||
p.connect("b.y", "c.y")
|
||||
assert p.inputs() == {}
|
||||
assert p.inputs(include_components_with_connected_inputs=True) == {
|
||||
"c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}}
|
||||
}
|
||||
assert p.inputs(include_components_with_connected_inputs=True) == {"c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}}}
|
||||
|
||||
def test_describe_input_some_components_with_no_inputs(self):
|
||||
A = component_class("A", input_types={}, output={"x": 0})
|
||||
@ -742,16 +719,10 @@ class TestPipeline:
|
||||
assert pipe.graph.nodes[node]["visits"] == 0
|
||||
|
||||
def test__init_to_run(self):
|
||||
ComponentWithVariadic = component_class(
|
||||
"ComponentWithVariadic", input_types={"in": Variadic[int]}, output_types={"out": int}
|
||||
)
|
||||
ComponentWithVariadic = component_class("ComponentWithVariadic", input_types={"in": Variadic[int]}, output_types={"out": int})
|
||||
ComponentWithNoInputs = component_class("ComponentWithNoInputs", input_types={}, output_types={"out": int})
|
||||
ComponentWithSingleInput = component_class(
|
||||
"ComponentWithSingleInput", input_types={"in": int}, output_types={"out": int}
|
||||
)
|
||||
ComponentWithMultipleInputs = component_class(
|
||||
"ComponentWithMultipleInputs", input_types={"in1": int, "in2": int}, output_types={"out": int}
|
||||
)
|
||||
ComponentWithSingleInput = component_class("ComponentWithSingleInput", input_types={"in": int}, output_types={"out": int})
|
||||
ComponentWithMultipleInputs = component_class("ComponentWithMultipleInputs", input_types={"in1": int, "in2": int}, output_types={"out": int})
|
||||
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("with_variadic", ComponentWithVariadic())
|
||||
@ -812,9 +783,7 @@ class TestPipeline:
|
||||
assert id(res["first_mock"]["x"]) != id(res["second_mock"]["x"])
|
||||
|
||||
def test__prepare_component_input_data_with_connected_inputs(self):
|
||||
MockComponent = component_class(
|
||||
"MockComponent", input_types={"x": List[str], "y": str}, output_types={"z": str}
|
||||
)
|
||||
MockComponent = component_class("MockComponent", input_types={"x": List[str], "y": str}, output_types={"z": str})
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("first_mock", MockComponent())
|
||||
pipe.add_component("second_mock", MockComponent())
|
||||
@ -828,10 +797,7 @@ class TestPipeline:
|
||||
pipe = Pipeline()
|
||||
res = pipe._prepare_component_input_data({"input_name": 1})
|
||||
assert res == {}
|
||||
assert (
|
||||
"Inputs ['input_name'] were not matched to any component inputs, "
|
||||
"please check your run parameters." in caplog.text
|
||||
)
|
||||
assert "Inputs ['input_name'] were not matched to any component inputs, " "please check your run parameters." in caplog.text
|
||||
|
||||
def test_connect(self):
|
||||
comp1 = component_class("Comp1", output_types={"value": int})()
|
||||
@ -1007,12 +973,8 @@ class TestPipeline:
|
||||
|
||||
def test__run_component(self, spying_tracer, caplog):
|
||||
caplog.set_level(logging.INFO)
|
||||
sentence_builder = component_class(
|
||||
"SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"}
|
||||
)()
|
||||
document_builder = component_class(
|
||||
"DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")}
|
||||
)()
|
||||
sentence_builder = component_class("SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"})()
|
||||
document_builder = component_class("DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")})()
|
||||
document_cleaner = component_class(
|
||||
"DocumentCleaner",
|
||||
input_types={"doc": Document},
|
||||
@ -1058,20 +1020,12 @@ class TestPipeline:
|
||||
pipe.add_component("sentence_builder", sentence_builder)
|
||||
|
||||
assert not pipe._component_has_enough_inputs_to_run("sentence_builder", {})
|
||||
assert not pipe._component_has_enough_inputs_to_run(
|
||||
"sentence_builder", {"sentence_builder": {"wrong_input_name": "blah blah"}}
|
||||
)
|
||||
assert pipe._component_has_enough_inputs_to_run(
|
||||
"sentence_builder", {"sentence_builder": {"words": ["blah blah"]}}
|
||||
)
|
||||
assert not pipe._component_has_enough_inputs_to_run("sentence_builder", {"sentence_builder": {"wrong_input_name": "blah blah"}})
|
||||
assert pipe._component_has_enough_inputs_to_run("sentence_builder", {"sentence_builder": {"words": ["blah blah"]}})
|
||||
|
||||
def test__dequeue_components_that_received_no_input(self):
|
||||
sentence_builder = component_class(
|
||||
"SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"}
|
||||
)()
|
||||
document_builder = component_class(
|
||||
"DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")}
|
||||
)()
|
||||
sentence_builder = component_class("SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"})()
|
||||
document_builder = component_class("DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")})()
|
||||
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("sentence_builder", sentence_builder)
|
||||
@ -1085,12 +1039,8 @@ class TestPipeline:
|
||||
assert waiting_for_input == []
|
||||
|
||||
def test__distribute_output(self):
|
||||
document_builder = component_class(
|
||||
"DocumentBuilder", input_types={"text": str}, output_types={"doc": Document, "another_doc": Document}
|
||||
)()
|
||||
document_cleaner = component_class(
|
||||
"DocumentCleaner", input_types={"doc": Document}, output_types={"cleaned_doc": Document}
|
||||
)()
|
||||
document_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document, "another_doc": Document})()
|
||||
document_cleaner = component_class("DocumentCleaner", input_types={"doc": Document}, output_types={"cleaned_doc": Document})()
|
||||
document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()
|
||||
|
||||
pipe = Pipeline()
|
||||
@ -1119,3 +1069,93 @@ class TestPipeline:
|
||||
}
|
||||
assert to_run == [("document_cleaner", document_cleaner)]
|
||||
assert waiting_for_input == [("document_joiner", document_joiner)]
|
||||
|
||||
def test__enqueue_next_runnable_component(self):
    """A waiting Component whose input is already recorded is moved to the run queue."""
    builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document})()
    pipeline = Pipeline()
    run_queue = []
    waiting_queue = [("document_builder", builder)]

    pipeline._enqueue_next_runnable_component(
        {"document_builder": {"text": "some text"}}, run_queue, waiting_queue
    )

    assert waiting_queue == []
    assert run_queue == [("document_builder", builder)]
|
||||
|
||||
def test__enqueue_next_runnable_component_without_component_inputs(self):
    """With no inputs recorded at all, the only waiting Component is still moved to the run queue."""
    builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document})()
    pipeline = Pipeline()
    run_queue = []
    waiting_queue = [("document_builder", builder)]

    pipeline._enqueue_next_runnable_component({}, run_queue, waiting_queue)

    assert waiting_queue == []
    assert run_queue == [("document_builder", builder)]
|
||||
|
||||
def test__enqueue_next_runnable_component_with_component_with_only_variadic_non_greedy_input(self):
    """A lone waiting Component with a single Variadic input is moved to the run queue."""
    joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()
    pipeline = Pipeline()
    run_queue = []
    waiting_queue = [("document_joiner", joiner)]

    pipeline._enqueue_next_runnable_component({}, run_queue, waiting_queue)

    assert waiting_queue == []
    assert run_queue == [("document_joiner", joiner)]
|
||||
|
||||
def test__enqueue_next_runnable_component_with_component_with_only_default_input(self):
    """A lone waiting PromptBuilder with no recorded inputs is moved to the run queue."""
    builder = PromptBuilder(template="{{ questions | join('\n') }}")
    pipeline = Pipeline()
    run_queue = []
    waiting_queue = [("prompt_builder", builder)]

    pipeline._enqueue_next_runnable_component({}, run_queue, waiting_queue)

    assert waiting_queue == []
    assert run_queue == [("prompt_builder", builder)]
|
||||
|
||||
def test__enqueue_next_runnable_component_with_component_with_variadic_non_greedy_and_default_input(self):
    """With a PromptBuilder and a Variadic joiner waiting, the joiner is enqueued and the builder keeps waiting."""
    joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()
    builder = PromptBuilder(template="{{ questions | join('\n') }}")
    pipeline = Pipeline()
    run_queue = []
    waiting_queue = [("prompt_builder", builder), ("document_joiner", joiner)]

    pipeline._enqueue_next_runnable_component({}, run_queue, waiting_queue)

    assert waiting_queue == [("prompt_builder", builder)]
    assert run_queue == [("document_joiner", joiner)]
|
||||
|
||||
def test__enqueue_next_runnable_component_with_different_components_inputs(self):
    """Among three waiting Components, the one whose input is already recorded is the one enqueued."""
    doc_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document})()
    joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()
    prompt = PromptBuilder(template="{{ questions | join('\n') }}")
    pipeline = Pipeline()
    run_queue = []
    waiting_queue = [
        ("prompt_builder", prompt),
        ("document_builder", doc_builder),
        ("document_joiner", joiner),
    ]

    pipeline._enqueue_next_runnable_component(
        {"document_builder": {"text": "some text"}}, run_queue, waiting_queue
    )

    assert waiting_queue == [("prompt_builder", prompt), ("document_joiner", joiner)]
    assert run_queue == [("document_builder", doc_builder)]
|
||||
|
||||
def test__enqueue_next_runnable_component_with_different_components_without_any_input(self):
    """With no inputs recorded for any of three waiting Components, the DocumentBuilder is the one enqueued."""
    doc_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document})()
    joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()
    prompt = PromptBuilder(template="{{ questions | join('\n') }}")
    pipeline = Pipeline()
    run_queue = []
    waiting_queue = [
        ("prompt_builder", prompt),
        ("document_builder", doc_builder),
        ("document_joiner", joiner),
    ]

    pipeline._enqueue_next_runnable_component({}, run_queue, waiting_queue)

    assert waiting_queue == [("prompt_builder", prompt), ("document_joiner", joiner)]
    assert run_queue == [("document_builder", doc_builder)]
|
||||
|
Loading…
x
Reference in New Issue
Block a user