diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index 6c6113f69..9e2404da1 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -133,9 +133,7 @@ class PipelineBase: } @classmethod - def from_dict( - cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs - ) -> T: + def from_dict(cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs) -> T: """ Deserializes the pipeline from a dictionary. @@ -169,10 +167,7 @@ class PipelineBase: importlib.import_module(module) # ...then try again if component_data["type"] not in component.registry: - raise PipelineError( - f"Successfully imported module {module} but can't find it in the component registry." - "This is unexpected and most likely a bug." - ) + raise PipelineError(f"Successfully imported module {module} but can't find it in the component registry." "This is unexpected and most likely a bug.") except (ImportError, PipelineError) as e: raise PipelineError(f"Component '{component_data['type']}' not imported.") from e @@ -284,15 +279,10 @@ class PipelineBase: # Component instances must be components if not isinstance(instance, Component): - raise PipelineValidationError( - f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?" - ) + raise PipelineValidationError(f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?") if getattr(instance, "__haystack_added_to_pipeline__", None): - msg = ( - "Component has already been added in another Pipeline. " - "Components can't be shared between Pipelines. Create a new instance instead." - ) + msg = "Component has already been added in another Pipeline. Components can't be shared between Pipelines. Create a new instance instead." raise PipelineError(msg) setattr(instance, "__haystack_added_to_pipeline__", self) @@ -390,35 +380,21 @@ class PipelineBase: if sender_socket_name: sender_socket = from_sockets.get(sender_socket_name) if not sender_socket: - raise PipelineConnectError( - f"'{sender} does not exist. " - f"Output connections of {sender_component_name} are: " - + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()]) - ) + raise PipelineConnectError(f"'{sender} does not exist. " f"Output connections of {sender_component_name} are: " + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()])) receiver_socket: Optional[InputSocket] = None if receiver_socket_name: receiver_socket = to_sockets.get(receiver_socket_name) if not receiver_socket: - raise PipelineConnectError( - f"'{receiver} does not exist. " - f"Input connections of {receiver_component_name} are: " - + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()]) - ) + raise PipelineConnectError(f"'{receiver} does not exist. " f"Input connections of {receiver_component_name} are: " + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()])) # Look for a matching connection among the possible ones. # Note that if there is more than one possible connection but two sockets match by name, they're paired. 
         sender_socket_candidates: List[OutputSocket] = [sender_socket] if sender_socket else list(from_sockets.values())
-        receiver_socket_candidates: List[InputSocket] = (
-            [receiver_socket] if receiver_socket else list(to_sockets.values())
-        )
+        receiver_socket_candidates: List[InputSocket] = [receiver_socket] if receiver_socket else list(to_sockets.values())
 
         # Find all possible connections between these two components
-        possible_connections = [
-            (sender_sock, receiver_sock)
-            for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates)
-            if _types_are_compatible(sender_sock.type, receiver_sock.type)
-        ]
+        possible_connections = [(sender_sock, receiver_sock) for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates) if _types_are_compatible(sender_sock.type, receiver_sock.type)]
 
         # We need this status for error messages, since we might need it in multiple places we calculate it here
         status = _connections_status(
@@ -431,15 +407,9 @@ class PipelineBase:
         if not possible_connections:
             # There's no possible connection between these two components
             if len(sender_socket_candidates) == len(receiver_socket_candidates) == 1:
-                msg = (
-                    f"Cannot connect '{sender_component_name}.{sender_socket_candidates[0].name}' with '{receiver_component_name}.{receiver_socket_candidates[0].name}': "
-                    f"their declared input and output types do not match.\n{status}"
-                )
+                msg = f"Cannot connect '{sender_component_name}.{sender_socket_candidates[0].name}' with '{receiver_component_name}.{receiver_socket_candidates[0].name}': " f"their declared input and output types do not match.\n{status}"
             else:
-                msg = (
-                    f"Cannot connect '{sender_component_name}' with '{receiver_component_name}': "
-                    f"no matching connections available.\n{status}"
-                )
+                msg = f"Cannot connect '{sender_component_name}' with '{receiver_component_name}': " f"no matching connections available.\n{status}"
             raise PipelineConnectError(msg)
 
         if len(possible_connections) == 1:
@@ -449,9 +419,7 @@ class PipelineBase:
 
         if len(possible_connections) > 1:
             # There are multiple possible connection, let's try to match them by name
-            name_matches = [
-                (out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name
-            ]
+            name_matches = [(out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name]
             if len(name_matches) != 1:
                 # There's are either no matches or more than one, we can't pick one reliably
                 msg = (
@@ -494,10 +462,7 @@ class PipelineBase:
 
         if receiver_socket.senders and not receiver_socket.is_variadic:
             # Only variadic input sockets can receive from multiple senders
-            msg = (
-                f"Cannot connect '{sender_component_name}.{sender_socket.name}' with '{receiver_component_name}.{receiver_socket.name}': "
-                f"{receiver_component_name}.{receiver_socket.name} is already connected to {receiver_socket.senders}.\n"
-            )
+            msg = f"Cannot connect '{sender_component_name}.{sender_socket.name}' with '{receiver_component_name}.{receiver_socket.name}': " f"{receiver_component_name}.{receiver_socket.name} is already connected to {receiver_socket.senders}.\n"
             raise PipelineConnectError(msg)
 
         # Update the sockets with the new connection
@@ -587,11 +552,7 @@ class PipelineBase:
             A dictionary where each key is a pipeline component name and each value is
             a dictionary of output sockets of that component.
""" - outputs = { - comp: {socket.name: {"type": socket.type} for socket in data} - for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items() - if data - } + outputs = {comp: {socket.name: {"type": socket.type} for socket in data} for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items() if data} return outputs def show(self) -> None: @@ -679,9 +640,7 @@ class PipelineBase: if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs: raise ValueError(f"Missing input for component {component_name}: {socket_name}") if socket.senders and socket_name in component_inputs and not socket.is_variadic: - raise ValueError( - f"Input {socket_name} for component {component_name} is already sent by {socket.senders}." - ) + raise ValueError(f"Input {socket_name} for component {component_name} is already sent by {socket.senders}.") def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]: """ @@ -785,9 +744,7 @@ class PipelineBase: return to_run @classmethod - def from_template( - cls, predefined_pipeline: PredefinedPipeline, template_params: Optional[Dict[str, Any]] = None - ) -> "PipelineBase": + def from_template(cls, predefined_pipeline: PredefinedPipeline, template_params: Optional[Dict[str, Any]] = None) -> "PipelineBase": """ Create a Pipeline from a predefined template. See `PredefinedPipeline` for available options. @@ -932,10 +889,99 @@ class PipelineBase: # Returns the output without the keys that were distributed to other Components return {k: v for k, v in component_result.items() if k not in to_remove_from_component_result} + def _enqueue_next_runnable_component( + self, + inputs_by_component: Dict[str, Dict[str, Any]], + to_run: List[Tuple[str, Component]], + waiting_for_input: List[Tuple[str, Component]], + ): + """ + Finds the next Component that can be run and adds it to the queue of Components to run. -def _connections_status( - sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket] -): + :param inputs_by_component: The current state of the inputs divided by Component name + :param to_run: Queue of Components to run + :param waiting_for_input: Queue of Components waiting for input + """ + for name, comp in waiting_for_input: + if name not in inputs_by_component: + inputs_by_component[name] = {} + + # Small utility function to check if a Component has a Variadic input that is not greedy. + def is_lazy_variadic(c: Component) -> bool: + is_variadic = any( + socket.is_variadic + for socket in c.__haystack_input__._sockets_dict.values() # type: ignore + ) + if not is_variadic: + return False + return not getattr(c, "__haystack_is_greedy__", False) + + # Small utility function to check if a Component has all inputs with defaults. 
+        def has_all_inputs_with_defaults(c: Component) -> bool:
+            return all(
+                not socket.is_mandatory
+                for socket in c.__haystack_input__._sockets_dict.values()  # type: ignore
+            )
+
+        # Updates the inputs with the default values for the inputs that are missing
+        def add_missing_input_defaults(name: str, comp: Component, inputs_by_component: Dict[str, Dict[str, Any]]):
+            for input_socket in comp.__haystack_input__._sockets_dict.values():  # type: ignore
+                if input_socket.name not in inputs_by_component[name]:
+                    inputs_by_component[name][input_socket.name] = input_socket.default_value
+
+        all_lazy_variadic = True
+        all_with_default_inputs = True
+
+        filtered_waiting_for_input = []
+
+        for name, comp in waiting_for_input:
+            if not is_lazy_variadic(comp):
+                # Components with variadic inputs that are not greedy must be removed only if there's nothing else to run at this stage.
+                # We need to wait as long as possible to run them, so we can collect as many inputs as we can.
+                all_lazy_variadic = False
+
+            if not has_all_inputs_with_defaults(comp):
+                # Components that have defaults for all their inputs must be treated the same way as we treat
+                # lazy variadic components. If there are only components with defaults we can run them.
+                # If we don't do this the order of execution of the Pipeline's Components will be affected because we
+                # enqueue the Components in `to_run` at the start using the order they are added in the Pipeline.
+                # If a Component A with defaults is added before a Component B that has no defaults, but the Pipeline
+                # logic requires A to be executed after B, A could end up running first if we don't do this check.
+                all_with_default_inputs = False
+
+            if not is_lazy_variadic(comp) and not has_all_inputs_with_defaults(comp):
+                # Keep track of the Components that are not lazy variadic and don't have all inputs with defaults.
+                # We'll handle these later if necessary.
+                filtered_waiting_for_input.append((name, comp))
+
+        # If all Components are lazy variadic or all Components have all inputs with defaults we can get one to run
+        if all_lazy_variadic or all_with_default_inputs:
+            pair = waiting_for_input.pop(0)
+            to_run.append(pair)
+            # Add missing input defaults if needed; this is a no-op for Components with Variadic inputs
+            add_missing_input_defaults(name, comp, inputs_by_component)
+            return
+
+        for name, comp in filtered_waiting_for_input:
+            # Find the first component that has all the inputs it needs to run
+            has_enough_inputs = True
+            for input_socket in comp.__haystack_input__._sockets_dict.values():  # type: ignore
+                if input_socket.name not in inputs_by_component[name]:
+                    if input_socket.is_mandatory:
+                        has_enough_inputs = False
+                        break
+
+                if input_socket.name not in inputs_by_component[name]:
+                    inputs_by_component[name][input_socket.name] = input_socket.default_value
+
+            if has_enough_inputs:
+                break
+
+        waiting_for_input.remove((name, comp))
+        to_run.append((name, comp))
+
+
+def _connections_status(sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]):
    """
    Lists the status of the sockets, for error messages.
""" @@ -950,9 +996,7 @@ def _connections_status( sender_status = f"sent by {','.join(receiver_socket.senders)}" else: sender_status = "available" - receiver_sockets_entries.append( - f" - {receiver_socket.name}: {_type_name(receiver_socket.type)} ({sender_status})" - ) + receiver_sockets_entries.append(f" - {receiver_socket.name}: {_type_name(receiver_socket.type)} ({sender_status})") receiver_sockets_list = "\n".join(receiver_sockets_entries) return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}" diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index fc0783a43..76db59f6d 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -86,10 +86,7 @@ class Pipeline(PipelineBase): inputs[socket.name] = [] if not isinstance(res, Mapping): - raise PipelineRuntimeError( - f"Component '{name}' didn't return a dictionary. " - "Components must always return dictionaries: check the the documentation." - ) + raise PipelineRuntimeError(f"Component '{name}' didn't return a dictionary. " "Components must always return dictionaries: check the the documentation.") span.set_tag("haystack.component.visits", self.graph.nodes[name]["visits"]) span.set_content_tag("haystack.component.output", res) @@ -274,18 +271,15 @@ class Pipeline(PipelineBase): # Check if we're stuck in a loop. # It's important to check whether previous waitings are None as it could be that no # Component has actually been run yet. - if ( - before_last_waiting_for_input is not None - and last_waiting_for_input is not None - and before_last_waiting_for_input == last_waiting_for_input - ): + if before_last_waiting_for_input is not None and last_waiting_for_input is not None and before_last_waiting_for_input == last_waiting_for_input: # Are we actually stuck or there's a lazy variadic or a component with has only default inputs waiting for input? # This is our last resort, if there's no lazy variadic or component with only default inputs waiting for input # we're stuck for real and we can't make any progress. 
                 for name, comp in waiting_for_input:
                     is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values())  # type: ignore
                     has_only_defaults = all(
-                        not socket.is_mandatory for socket in comp.__haystack_input__._sockets_dict.values()  # type: ignore
+                        not socket.is_mandatory
+                        for socket in comp.__haystack_input__._sockets_dict.values()  # type: ignore
                     )
                     if is_variadic and not comp.__haystack_is_greedy__ or has_only_defaults:  # type: ignore[attr-defined]
                         break
@@ -315,69 +309,10 @@ class Pipeline(PipelineBase):
 
                 continue
 
-            before_last_waiting_for_input = (
-                last_waiting_for_input.copy() if last_waiting_for_input is not None else None
-            )
+            before_last_waiting_for_input = last_waiting_for_input.copy() if last_waiting_for_input is not None else None
             last_waiting_for_input = {item[0] for item in waiting_for_input}
 
-            # Remove from waiting only if there is actually enough input to run
-            for name, comp in waiting_for_input:
-                if name not in last_inputs:
-                    last_inputs[name] = {}
-
-                # Lazy variadics must be removed only if there's nothing else to run at this stage
-                is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values())  # type: ignore
-                if is_variadic and not comp.__haystack_is_greedy__:  # type: ignore[attr-defined]
-                    there_are_only_lazy_variadics = True
-                    for other_name, other_comp in waiting_for_input:
-                        if name == other_name:
-                            continue
-                        there_are_only_lazy_variadics &= (
-                            any(
-                                socket.is_variadic for socket in other_comp.__haystack_input__._sockets_dict.values()  # type: ignore
-                            )
-                            and not other_comp.__haystack_is_greedy__  # type: ignore[attr-defined]
-                        )
-
-                    if not there_are_only_lazy_variadics:
-                        continue
-
-                # Components that have defaults for all their inputs must be treated the same identical way as we treat
-                # lazy variadic components. If there are only components with defaults we can run them.
-                # If we don't do this the order of execution of the Pipeline's Components will be affected cause we
-                # enqueue the Components in `to_run` at the start using the order they are added in the Pipeline.
-                # If a Component A with defaults is added before a Component B that has no defaults, but in the Pipeline
-                # logic A must be executed after B it could run instead before if we don't do this check.
- has_only_defaults = all( - not socket.is_mandatory for socket in comp.__haystack_input__._sockets_dict.values() # type: ignore - ) - if has_only_defaults: - there_are_only_components_with_defaults = True - for other_name, other_comp in waiting_for_input: - if name == other_name: - continue - there_are_only_components_with_defaults &= all( - not s.is_mandatory for s in other_comp.__haystack_input__._sockets_dict.values() # type: ignore - ) - if not there_are_only_components_with_defaults: - continue - - # Find the first component that has all the inputs it needs to run - has_enough_inputs = True - for input_socket in comp.__haystack_input__._sockets_dict.values(): # type: ignore - if input_socket.is_mandatory and input_socket.name not in last_inputs[name]: - has_enough_inputs = False - break - if input_socket.is_mandatory: - continue - - if input_socket.name not in last_inputs[name]: - last_inputs[name][input_socket.name] = input_socket.default_value - if has_enough_inputs: - break - - waiting_for_input.remove((name, comp)) - to_run.append((name, comp)) + self._enqueue_next_runnable_component(last_inputs, to_run, waiting_for_input) if len(include_outputs_from) > 0: for name, output in extra_outputs.items(): diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index 1e228f2d5..b3c98be0b 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -77,9 +77,7 @@ class TestPipeline: @patch("haystack.core.pipeline.base.is_in_jupyter") @patch("IPython.display.Image") @patch("IPython.display.display") - def test_show_in_notebook( - self, mock_ipython_display, mock_ipython_image, mock_is_in_jupyter, mock_to_mermaid_image - ): + def test_show_in_notebook(self, mock_ipython_display, mock_ipython_image, mock_is_in_jupyter, mock_to_mermaid_image): pipe = Pipeline() mock_to_mermaid_image.return_value = b"some_image_data" @@ -163,9 +161,7 @@ class TestPipeline: assert component_2.__haystack_added_to_pipeline__ is None assert component_2.__haystack_input__._sockets_dict == {"in": InputSocket(name="in", type=int, senders=[])} - assert component_2.__haystack_output__._sockets_dict == { - "out": OutputSocket(name="out", type=int, receivers=[]) - } + assert component_2.__haystack_output__._sockets_dict == {"out": OutputSocket(name="out", type=int, receivers=[])} pipe2 = Pipeline() pipe2.add_component("component_4", Some()) @@ -175,12 +171,8 @@ class TestPipeline: pipe2.connect("component_4", "component_2") pipe2.connect("component_2", "component_5") assert component_2.__haystack_added_to_pipeline__ is pipe2 - assert component_2.__haystack_input__._sockets_dict == { - "in": InputSocket(name="in", type=int, senders=["component_4"]) - } - assert component_2.__haystack_output__._sockets_dict == { - "out": OutputSocket(name="out", type=int, receivers=["component_5"]) - } + assert component_2.__haystack_input__._sockets_dict == {"in": InputSocket(name="in", type=int, senders=["component_4"])} + assert component_2.__haystack_output__._sockets_dict == {"out": OutputSocket(name="out", type=int, receivers=["component_5"])} # instance = pipe2.get_component("some") # assert instance == component @@ -210,16 +202,7 @@ class TestPipeline: pipe.connect("double", "add_default") expected_repr = ( - f"{object.__repr__(pipe)}\n" - "🧱 Metadata\n" - " - test: test\n" - "🚅 Components\n" - " - add_two: AddFixedValue\n" - " - add_default: AddFixedValue\n" - " - double: Double\n" - "🛤️ Connections\n" - " - add_two.result -> double.value (int)\n" - " - double.value 
-> add_default.value (int)\n" + f"{object.__repr__(pipe)}\n" "🧱 Metadata\n" " - test: test\n" "🚅 Components\n" " - add_two: AddFixedValue\n" " - add_default: AddFixedValue\n" " - double: Double\n" "🛤️ Connections\n" " - add_two.result -> double.value (int)\n" " - double.value -> add_default.value (int)\n" ) assert repr(pipe) == expected_repr @@ -379,9 +362,7 @@ class TestPipeline: components_seen_in_callback.append(name) - pipe = Pipeline.from_dict( - data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback) - ) + pipe = Pipeline.from_dict(data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback)) assert components_seen_in_callback == ["add_two", "add_default", "double", "greet"] add_two = pipe.graph.nodes["add_two"]["instance"] assert add_two.add == 2 @@ -403,9 +384,7 @@ class TestPipeline: init_params["message"] = "modified test" init_params["log_level"] = "DEBUG" - pipe = Pipeline.from_dict( - data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback_modify) - ) + pipe = Pipeline.from_dict(data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback_modify)) add_two = pipe.graph.nodes["add_two"]["instance"] assert add_two.add == 3 add_default = pipe.graph.nodes["add_default"]["instance"] @@ -558,9 +537,7 @@ class TestPipeline: p.connect("a.x", "c.x") p.connect("b.y", "c.y") assert p.inputs() == {} - assert p.inputs(include_components_with_connected_inputs=True) == { - "c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}} - } + assert p.inputs(include_components_with_connected_inputs=True) == {"c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}}} def test_describe_input_some_components_with_no_inputs(self): A = component_class("A", input_types={}, output={"x": 0}) @@ -742,16 +719,10 @@ class TestPipeline: assert pipe.graph.nodes[node]["visits"] == 0 def test__init_to_run(self): - ComponentWithVariadic = component_class( - "ComponentWithVariadic", input_types={"in": Variadic[int]}, output_types={"out": int} - ) + ComponentWithVariadic = component_class("ComponentWithVariadic", input_types={"in": Variadic[int]}, output_types={"out": int}) ComponentWithNoInputs = component_class("ComponentWithNoInputs", input_types={}, output_types={"out": int}) - ComponentWithSingleInput = component_class( - "ComponentWithSingleInput", input_types={"in": int}, output_types={"out": int} - ) - ComponentWithMultipleInputs = component_class( - "ComponentWithMultipleInputs", input_types={"in1": int, "in2": int}, output_types={"out": int} - ) + ComponentWithSingleInput = component_class("ComponentWithSingleInput", input_types={"in": int}, output_types={"out": int}) + ComponentWithMultipleInputs = component_class("ComponentWithMultipleInputs", input_types={"in1": int, "in2": int}, output_types={"out": int}) pipe = Pipeline() pipe.add_component("with_variadic", ComponentWithVariadic()) @@ -812,9 +783,7 @@ class TestPipeline: assert id(res["first_mock"]["x"]) != id(res["second_mock"]["x"]) def test__prepare_component_input_data_with_connected_inputs(self): - MockComponent = component_class( - "MockComponent", input_types={"x": List[str], "y": str}, output_types={"z": str} - ) + MockComponent = component_class("MockComponent", input_types={"x": List[str], "y": str}, output_types={"z": str}) pipe = Pipeline() pipe.add_component("first_mock", MockComponent()) pipe.add_component("second_mock", MockComponent()) @@ -828,10 +797,7 @@ 
class TestPipeline: pipe = Pipeline() res = pipe._prepare_component_input_data({"input_name": 1}) assert res == {} - assert ( - "Inputs ['input_name'] were not matched to any component inputs, " - "please check your run parameters." in caplog.text - ) + assert "Inputs ['input_name'] were not matched to any component inputs, " "please check your run parameters." in caplog.text def test_connect(self): comp1 = component_class("Comp1", output_types={"value": int})() @@ -1007,12 +973,8 @@ class TestPipeline: def test__run_component(self, spying_tracer, caplog): caplog.set_level(logging.INFO) - sentence_builder = component_class( - "SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"} - )() - document_builder = component_class( - "DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")} - )() + sentence_builder = component_class("SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"})() + document_builder = component_class("DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")})() document_cleaner = component_class( "DocumentCleaner", input_types={"doc": Document}, @@ -1058,20 +1020,12 @@ class TestPipeline: pipe.add_component("sentence_builder", sentence_builder) assert not pipe._component_has_enough_inputs_to_run("sentence_builder", {}) - assert not pipe._component_has_enough_inputs_to_run( - "sentence_builder", {"sentence_builder": {"wrong_input_name": "blah blah"}} - ) - assert pipe._component_has_enough_inputs_to_run( - "sentence_builder", {"sentence_builder": {"words": ["blah blah"]}} - ) + assert not pipe._component_has_enough_inputs_to_run("sentence_builder", {"sentence_builder": {"wrong_input_name": "blah blah"}}) + assert pipe._component_has_enough_inputs_to_run("sentence_builder", {"sentence_builder": {"words": ["blah blah"]}}) def test__dequeue_components_that_received_no_input(self): - sentence_builder = component_class( - "SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"} - )() - document_builder = component_class( - "DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")} - )() + sentence_builder = component_class("SentenceBuilder", input_types={"words": List[str]}, output={"text": "some words"})() + document_builder = component_class("DocumentBuilder", input_types={"text": str}, output={"doc": Document(content="some words")})() pipe = Pipeline() pipe.add_component("sentence_builder", sentence_builder) @@ -1085,12 +1039,8 @@ class TestPipeline: assert waiting_for_input == [] def test__distribute_output(self): - document_builder = component_class( - "DocumentBuilder", input_types={"text": str}, output_types={"doc": Document, "another_doc": Document} - )() - document_cleaner = component_class( - "DocumentCleaner", input_types={"doc": Document}, output_types={"cleaned_doc": Document} - )() + document_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document, "another_doc": Document})() + document_cleaner = component_class("DocumentCleaner", input_types={"doc": Document}, output_types={"cleaned_doc": Document})() document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})() pipe = Pipeline() @@ -1119,3 +1069,93 @@ class TestPipeline: } assert to_run == [("document_cleaner", document_cleaner)] assert waiting_for_input == [("document_joiner", document_joiner)] + + def test__enqueue_next_runnable_component(self): + 
document_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document})() + pipe = Pipeline() + inputs_by_component = {"document_builder": {"text": "some text"}} + to_run = [] + waiting_for_input = [("document_builder", document_builder)] + pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input) + + assert to_run == [("document_builder", document_builder)] + assert waiting_for_input == [] + + def test__enqueue_next_runnable_component_without_component_inputs(self): + document_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document})() + pipe = Pipeline() + inputs_by_component = {} + to_run = [] + waiting_for_input = [("document_builder", document_builder)] + pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input) + + assert to_run == [("document_builder", document_builder)] + assert waiting_for_input == [] + + def test__enqueue_next_runnable_component_with_component_with_only_variadic_non_greedy_input(self): + document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})() + + pipe = Pipeline() + inputs_by_component = {} + to_run = [] + waiting_for_input = [("document_joiner", document_joiner)] + pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input) + + assert to_run == [("document_joiner", document_joiner)] + assert waiting_for_input == [] + + def test__enqueue_next_runnable_component_with_component_with_only_default_input(self): + prompt_builder = PromptBuilder(template="{{ questions | join('\n') }}") + + pipe = Pipeline() + inputs_by_component = {} + to_run = [] + waiting_for_input = [("prompt_builder", prompt_builder)] + pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input) + + assert to_run == [("prompt_builder", prompt_builder)] + assert waiting_for_input == [] + + def test__enqueue_next_runnable_component_with_component_with_variadic_non_greedy_and_default_input(self): + document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})() + prompt_builder = PromptBuilder(template="{{ questions | join('\n') }}") + + pipe = Pipeline() + inputs_by_component = {} + to_run = [] + waiting_for_input = [("prompt_builder", prompt_builder), ("document_joiner", document_joiner)] + pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input) + + assert to_run == [("document_joiner", document_joiner)] + assert waiting_for_input == [("prompt_builder", prompt_builder)] + + def test__enqueue_next_runnable_component_with_different_components_inputs(self): + document_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document})() + document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})() + prompt_builder = PromptBuilder(template="{{ questions | join('\n') }}") + + pipe = Pipeline() + inputs_by_component = {"document_builder": {"text": "some text"}} + to_run = [] + waiting_for_input = [("prompt_builder", prompt_builder), ("document_builder", document_builder), ("document_joiner", document_joiner)] + pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input) + + assert to_run == [("document_builder", document_builder)] + assert waiting_for_input == [("prompt_builder", prompt_builder), ("document_joiner", document_joiner)] + + def 
test__enqueue_next_runnable_component_with_different_components_without_any_input(self): + document_builder = component_class("DocumentBuilder", input_types={"text": str}, output_types={"doc": Document})() + document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})() + prompt_builder = PromptBuilder(template="{{ questions | join('\n') }}") + + pipe = Pipeline() + inputs_by_component = {} + to_run = [] + waiting_for_input = [("prompt_builder", prompt_builder), ("document_builder", document_builder), ("document_joiner", document_joiner)] + pipe._enqueue_next_runnable_component(inputs_by_component, to_run, waiting_for_input) + + assert to_run == [("document_builder", document_builder)] + assert waiting_for_input == [ + ("prompt_builder", prompt_builder), + ("document_joiner", document_joiner), + ]
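
Note for readers of this patch: the scheduling rule that the new `_enqueue_next_runnable_component` helper implements can be hard to follow straight from the hunks above. The sketch below restates it as a small, self-contained Python example; `FakeSocket`, `FakeComponent` and `pick_next` are illustrative stand-ins invented for this note (not Haystack APIs), and the fallback behaviour only approximates the helper.

# Illustrative sketch (assumed names, not Haystack APIs): the selection rule used by
# PipelineBase._enqueue_next_runnable_component, restated over plain stand-in objects.
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class FakeSocket:
    name: str
    is_variadic: bool = False
    is_mandatory: bool = True


@dataclass
class FakeComponent:
    sockets: List[FakeSocket] = field(default_factory=list)
    greedy: bool = False


def is_lazy_variadic(c: FakeComponent) -> bool:
    # A variadic input that is not greedy is deferred for as long as possible
    return any(s.is_variadic for s in c.sockets) and not c.greedy


def has_all_inputs_with_defaults(c: FakeComponent) -> bool:
    return all(not s.is_mandatory for s in c.sockets)


def pick_next(waiting: List[Tuple[str, FakeComponent]], inputs: Dict[str, Dict[str, object]]) -> str:
    # If everything still waiting is lazy variadic, or everything has only defaulted inputs,
    # simply take the first component in the queue.
    if all(is_lazy_variadic(c) for _, c in waiting) or all(has_all_inputs_with_defaults(c) for _, c in waiting):
        return waiting[0][0]
    # Otherwise prefer the first "normal" component whose mandatory inputs are all present.
    candidate = waiting[0][0]
    for name, comp in waiting:
        if is_lazy_variadic(comp) or has_all_inputs_with_defaults(comp):
            continue
        candidate = name
        if all(s.name in inputs.get(name, {}) for s in comp.sockets if s.is_mandatory):
            return name
    # Nothing is fully ready: fall back to the last non-deferred candidate inspected.
    return candidate


builder = FakeComponent([FakeSocket("text")])                   # mandatory input, provided below
joiner = FakeComponent([FakeSocket("docs", is_variadic=True)])  # lazy variadic, deferred
print(pick_next([("joiner", joiner), ("builder", builder)], {"builder": {"text": "hi"}}))  # -> "builder"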