Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-12-13 15:57:24 +00:00
feat: Rework Pipeline.run() to better handle cycles (#8431)
* draft
* Enhance
* Almost works
* Simplify some parts and handle intermediate outputs
* Handle connections with default
* Handle cycles with multiple connections from two components
* Update distributed outputs at the correct time
* Remove Component inputs after it runs
* Add agent pipeline test case
* Fix infinite loop test
* Handle some corner cases with loops checking and inputs deletion
* Fix tests
* Add new behavioural test
* Remove unused code in behavioural test
* Fix behavioural test
* Fix max run check
* Simplify outputs distribution
* Simplify subgraph run check
* Remove unused _init_run_queue function
* Remove commented code
* Add some missing type hints
* Simplify cycles breaking
* Fix _distribute_output test
* Fix _find_components_that_will_receive_no_input test
* Fix validation test
* Fix tracer losing Component inputs
* Fix some linting issues
* Remove ignore pylint rule
* Rename method that breaks cycles and make it raise
* Add docstring to _run_subgraph
* Update Pipeline.run() docstring
* Update comment to clarify cycles execution
* Remove SelfLoop sample Component
* Add behavioural test for unsupported cycles
* Rename behavioural test to be more specific
* Add new behavioural test
* Add release notes
* Remove commented out code and random pass
* Use more efficient function to find cycles
* Simplify _break_supported_cycles_in_graph by using defaultdict
* Stop breaking edges as soon as we make the graph acyclic
* Fix docstring and add some more comments
* Fix _distribute_output docstring
* Fix _find_receivers_from docstring
* More detailed release notes
* Minimize calls to networkx.is_directed_acyclic_graph
* Add some more info on edges keys
* Adjust components_in_cycles comment
* Add new Pipeline behavioural test
* Enhance _find_components_that_will_receive_no_input to cover more cases
* Explain why run_queue is reset after running a subgraph cycle
* Rename _init_inputs_state to _normalize_input_data
* Better explain the subgraph output distribution
* Remove for else
* Fix some comments and docstrings
* Fix linting
* Add missing return type
* Fix typo
* Rename _normalize_input_data to _normalize_varidiac_input_data and add more documentation
* Remove unused import

Co-authored-by: Sebastian Husch Lee <sjrl423@gmail.com>
This commit is contained in: parent d430833f8f, commit 8205724395
@@ -5,7 +5,7 @@
import importlib
import itertools
from collections import defaultdict
from copy import copy, deepcopy
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Set, TextIO, Tuple, Type, TypeVar, Union
@@ -19,6 +19,7 @@ from haystack.core.errors import (
    PipelineConnectError,
    PipelineDrawingError,
    PipelineError,
    PipelineRuntimeError,
    PipelineUnmarshalError,
    PipelineValidationError,
)
@@ -765,7 +766,10 @@ class PipelineBase:

        return data

    def _init_inputs_state(self, data: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    def _normalize_varidiac_input_data(self, data: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
        """
        Variadic inputs expect their value to be a list, this utility method creates that list from the user's input.
        """
        for component_name, component_inputs in data.items():
            if component_name not in self.graph.nodes:
                # This is not a component name, it must be the name of one or more input sockets.
@@ -773,8 +777,6 @@
                continue
            instance = self.graph.nodes[component_name]["instance"]
            for component_input, input_value in component_inputs.items():
                # Handle mutable input data
                data[component_name][component_input] = copy(input_value)
                if instance.__haystack_input__._sockets_dict[component_input].is_variadic:
                    # Components that have variadic inputs need to receive lists as input.
                    # We don't want to force the user to always pass lists, so we convert single values to lists here.
@@ -784,41 +786,6 @@

        return {**data}

    def _init_run_queue(self, pipeline_inputs: Dict[str, Any]) -> List[Tuple[str, Component]]:
        run_queue: List[Tuple[str, Component]] = []

        # HACK: Quick workaround for the issue of execution order not being
        # well-defined (NB - https://github.com/deepset-ai/haystack/issues/7985).
        # We should fix the original execution logic instead.
        if networkx.is_directed_acyclic_graph(self.graph):
            # If the Pipeline is linear we can easily determine the order of execution with
            # a topological sort.
            # So use that to get the run order.
            for node in networkx.topological_sort(self.graph):
                run_queue.append((node, self.graph.nodes[node]["instance"]))
            return run_queue

        for node_name in self.graph.nodes:
            component = self.graph.nodes[node_name]["instance"]

            if len(component.__haystack_input__._sockets_dict) == 0:
                # Component has no input, can run right away
                run_queue.append((node_name, component))
                continue

            if node_name in pipeline_inputs:
                # This component is in the input data, if it has enough inputs it can run right away
                run_queue.append((node_name, component))
                continue

            for socket in component.__haystack_input__._sockets_dict.values():
                if not socket.senders or socket.is_variadic:
                    # Component has at least one input not connected or is variadic, can run right away.
                    run_queue.append((node_name, component))
                    break

        return run_queue

    @classmethod
    def from_template(
        cls, predefined_pipeline: PredefinedPipeline, template_params: Optional[Dict[str, Any]] = None
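A note on the renamed `_normalize_varidiac_input_data` above: variadic sockets always consume lists, so a bare user value gets wrapped before execution starts. A minimal standalone sketch of that idea — the `variadic_sockets` map here is a hypothetical stand-in for the real `__haystack_input__._sockets_dict` metadata:

```python
from typing import Any, Dict, Set

def normalize_variadic_inputs(
    data: Dict[str, Dict[str, Any]], variadic_sockets: Dict[str, Set[str]]
) -> Dict[str, Dict[str, Any]]:
    """Wrap bare values into single-element lists for sockets marked as variadic."""
    for component_name, component_inputs in data.items():
        for socket_name, value in component_inputs.items():
            if socket_name in variadic_sockets.get(component_name, set()) and not isinstance(value, list):
                # Users may pass a bare value; variadic sockets always receive lists.
                component_inputs[socket_name] = [value]
    return data

# "sum.values" is variadic, so 7 becomes [7]; "greet.word" is left untouched.
print(normalize_variadic_inputs({"sum": {"values": 7}, "greet": {"word": "world"}}, {"sum": {"values"}}))
# {'sum': {'values': [7]}, 'greet': {'word': 'world'}}
```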
@@ -851,9 +818,27 @@ class PipelineBase:
        for node in self.graph.nodes:
            self.graph.nodes[node]["visits"] = 0

    def _distribute_output(
    def _find_receivers_from(self, component_name: str) -> List[Tuple[str, OutputSocket, InputSocket]]:
        """
        Utility function to find all Components that receive input from `component_name`.

        :param component_name:
            Name of the sender Component

        :returns:
            List of tuples containing name of the receiver Component and sender OutputSocket
            and receiver InputSocket instances
        """
        res = []
        for _, receiver_name, connection in self.graph.edges(nbunch=component_name, data=True):
            sender_socket: OutputSocket = connection["from_socket"]
            receiver_socket: InputSocket = connection["to_socket"]
            res.append((receiver_name, sender_socket, receiver_socket))
        return res

    def _distribute_output(  # pylint: disable=too-many-positional-arguments
        self,
        component_name: str,
        receiver_components: List[Tuple[str, OutputSocket, InputSocket]],
        component_result: Dict[str, Any],
        components_inputs: Dict[str, Dict[str, Any]],
        run_queue: List[Tuple[str, Component]],
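`_find_receivers_from` leans on networkx's MultiDiGraph edge API: `graph.edges(nbunch=node, data=True)` yields only the edges leaving `node`, together with the attribute dict that `Pipeline.connect()` stored on each edge. A self-contained sketch of the same pattern, with plain strings standing in for the socket objects:

```python
import networkx

graph = networkx.MultiDiGraph()
# Pipeline.connect() stores the two sockets as edge attributes; strings stand in here.
graph.add_edge("retriever", "prompt_builder", from_socket="documents", to_socket="documents")
graph.add_edge("prompt_builder", "llm", from_socket="prompt", to_socket="prompt")

def find_receivers_from(component_name: str):
    # edges(nbunch=...) restricts iteration to edges leaving component_name only
    return [
        (receiver, data["from_socket"], data["to_socket"])
        for _, receiver, data in graph.edges(nbunch=component_name, data=True)
    ]

print(find_receivers_from("retriever"))
# [('prompt_builder', 'documents', 'documents')]
```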
@@ -865,23 +850,27 @@
        This also updates the queues that keep track of which Components are ready to run and which are waiting for
        input.

        :param component_name: Name of the Component that created the output
        :param component_result: The output of the Component
        :paramt components_inputs: The current state of the inputs divided by Component name
        :param run_queue: Queue of Components to run
        :param waiting_queue: Queue of Components waiting for input
        :param receiver_components:
            List of tuples containing name of receiver Components and relative sender OutputSocket
            and receiver InputSocket instances
        :param component_result:
            The output of the Component
        :param components_inputs:
            The current state of the inputs divided by Component name
        :param run_queue:
            Queue of Components to run
        :param waiting_queue:
            Queue of Components waiting for input

        :returns: The updated output of the Component without the keys that were distributed to other Components
        :returns:
            The updated output of the Component without the keys that were distributed to other Components
        """
        # We keep track of which keys to remove from component_result at the end of the loop.
        # This is done after the output has been distributed to the next components, so that
        # we're sure all components that need this output have received it.
        to_remove_from_component_result = set()

        for _, receiver_name, connection in self.graph.edges(nbunch=component_name, data=True):
            sender_socket: OutputSocket = connection["from_socket"]
            receiver_socket: InputSocket = connection["to_socket"]

        for receiver_name, sender_socket, receiver_socket in receiver_components:
            if sender_socket.name not in component_result:
                # This output wasn't created by the sender, nothing we can do.
                #
@@ -929,7 +918,7 @@
                    run_queue.remove(pair)
                if pair in waiting_queue:
                    waiting_queue.remove(pair)
                run_queue.append(pair)
                run_queue.insert(0, pair)
            else:
                # If the receiver Component has a variadic input that is not greedy
                # we put it in the waiting queue.
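The `append` → `insert(0, ...)` change above is the behavioral heart of this hunk: a greedy-variadic receiver that just became ready now jumps to the front of the run queue instead of the back, so it consumes its input as soon as possible. In miniature:

```python
run_queue = [("llm", object()), ("joiner", object())]
pair = ("greedy_joiner", object())

# Old behavior: the greedy component waits behind everything already queued.
run_queue.append(pair)

# New behavior: it is pulled out and re-inserted at the front, so it runs next.
run_queue.remove(pair)
run_queue.insert(0, pair)
print([name for name, _ in run_queue])  # ['greedy_joiner', 'llm', 'joiner']
```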
@@ -1048,16 +1037,33 @@
        """

        # Simplifies the check if a Component is Variadic and received some input from other Components.
        def is_variadic_with_existing_inputs(comp: Component) -> bool:
            for receiver_socket in comp.__haystack_input__._sockets_dict.values():  # type: ignore
                if component_name not in receiver_socket.senders:
        def has_variadic_socket_with_existing_inputs(
            component: Component, component_name: str, sender_name: str, components_inputs: Dict[str, Dict[str, Any]]
        ) -> bool:
            for socket in component.__haystack_input__._sockets_dict.values():  # type: ignore
                if sender_name not in socket.senders:
                    continue
                if (
                    receiver_socket.is_variadic
                    and len(components_inputs.get(receiver, {}).get(receiver_socket.name, [])) > 0
                ):
                    # This Component already received some input to its Variadic socket from other Components.
                    # It should be able to run even if it doesn't receive any input from component_name.
                if socket.is_variadic and len(components_inputs.get(component_name, {}).get(socket.name, [])) > 0:
                    return True
            return False

        # Makes it easier to verify if all connections between two Components are optional
        def all_connections_are_optional(sender_name: str, receiver: Component) -> bool:
            for socket in receiver.__haystack_input__._sockets_dict.values():  # type: ignore
                if sender_name not in socket.senders:
                    continue
                if socket.is_mandatory:
                    return False
            return True

        # Eases checking if other connections that are not between sender_name and receiver_name
        # already received inputs
        def other_connections_received_input(sender_name: str, receiver_name: str) -> bool:
            receiver: Component = self.graph.nodes[receiver_name]["instance"]
            for receiver_socket in receiver.__haystack_input__._sockets_dict.values():  # type: ignore
                if sender_name in receiver_socket.senders:
                    continue
                if components_inputs.get(receiver_name, {}).get(receiver_socket.name) is not None:
                    return True
            return False

@@ -1069,7 +1075,21 @@
            for receiver in socket.receivers:
                receiver_instance: Component = self.graph.nodes[receiver]["instance"]

                if is_variadic_with_existing_inputs(receiver_instance):
                if has_variadic_socket_with_existing_inputs(
                    receiver_instance, receiver, component_name, components_inputs
                ):
                    # Components with Variadic input that already received some input
                    # can still run, even if branch is skipped.
                    # If we remove them they won't run.
                    continue

                if all_connections_are_optional(component_name, receiver_instance) and other_connections_received_input(
                    component_name, receiver
                ):
                    # If all the connections between component_name and receiver are optional
                    # and receiver received other inputs already it still has enough inputs to run.
                    # Even if it didn't receive input from component_name, so we can't remove it or its
                    # descendants.
                    continue

                components.add((receiver, receiver_instance))
@@ -1078,7 +1098,18 @@
                # This is fine even if the Pipeline will merge back into a single Component
                # at a certain point. The merging Component will be put back into the run
                # queue at a later stage.
                components |= {(d, self.graph.nodes[d]["instance"]) for d in networkx.descendants(self.graph, receiver)}
                for descendant_name in networkx.descendants(self.graph, receiver):
                    descendant = self.graph.nodes[descendant_name]["instance"]

                    # Components with Variadic input that already received some input
                    # can still run, even if branch is skipped.
                    # If we remove them they won't run.
                    if has_variadic_socket_with_existing_inputs(
                        descendant, descendant_name, receiver, components_inputs
                    ):
                        continue

                    components.add((descendant_name, descendant))

        return components
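The descendant pruning above uses `networkx.descendants`, which returns every node reachable from a given node — exactly the set of components that can no longer receive input once a branch is skipped. A standalone illustration with made-up component names:

```python
import networkx

graph = networkx.MultiDiGraph()
graph.add_edges_from([("router", "web_search"), ("web_search", "summarizer"), ("router", "fallback")])

# Everything reachable from "web_search" is dead if "web_search" gets no input...
print(networkx.descendants(graph, "web_search"))  # {'summarizer'}
# ...while the "fallback" branch, reachable only from "router", is unaffected.
```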
@@ -1127,6 +1158,90 @@
        current_inputs = inputs[name].keys()
        return expected_inputs == current_inputs

    def _break_supported_cycles_in_graph(self) -> Tuple[networkx.MultiDiGraph, Dict[str, List[List[str]]]]:
        """
        Utility function to remove supported cycles in the Pipeline's graph.

        Given that the Pipeline execution would wait to run a Component until it has received
        all its mandatory inputs, it doesn't make sense for us to try and break cycles by
        removing a connection to a mandatory input. The Pipeline would just get stuck at a later time.

        So we can only break connections in cycles that have a Variadic or GreedyVariadic type or a default value.

        This will raise a PipelineRuntimeError if there are cycles that can't be broken.
        That is bound to happen when at least one of the inputs in a cycle is mandatory.

        If the Pipeline's graph doesn't have any cycle it will just return that graph and an empty dictionary.

        :returns:
            A tuple containing:
            * A copy of the Pipeline's graph without cycles
            * A dictionary of Component's names and a list of all the cycles they were part of.
              The cycles are a list of Component's names that create that cycle.
        """
        if networkx.is_directed_acyclic_graph(self.graph):
            return self.graph, {}

        temp_graph: networkx.MultiDiGraph = self.graph.copy()
        # A list of all the cycles that are found in the graph, each inner list contains
        # the Component names that create that cycle.
        cycles: List[List[str]] = list(networkx.simple_cycles(self.graph))
        # Maps a Component name to a list of its output socket names that have been broken
        edges_removed: Dict[str, List[str]] = defaultdict(list)
        # This keeps track of all the cycles that a component is part of.
        # Maps a Component name to a list of cycles, each inner list contains
        # the Component names that create that cycle (the key will also be
        # an element in each list). The last Component in each list is implicitly
        # connected to the first.
        components_in_cycles: Dict[str, List[List[str]]] = defaultdict(list)

        # Used to minimize the number of times we check whether the graph has any more
        # cycles left to break or not.
        graph_has_cycles = True

        # Iterate all the cycles to find the least amount of connections that we can remove
        # to make the Pipeline graph acyclic.
        # As soon as the graph is acyclic we stop breaking connections and return.
        for cycle in cycles:
            for comp in cycle:
                components_in_cycles[comp].append(cycle)

            # Iterate this cycle, we zip the cycle with itself so that at the last iteration
            # sender_comp will be the last element of cycle and receiver_comp will be the first.
            # So if cycle is [1, 2, 3, 4] we would call zip([1, 2, 3, 4], [2, 3, 4, 1]).
            for sender_comp, receiver_comp in zip(cycle, cycle[1:] + cycle[:1]):
                # We get the keys and iterate those as we want to edit the graph data while
                # iterating the edges and that would raise.
                # Even though the connection key set in Pipeline.connect() uses only the
                # sockets name we don't have clashes since it's only used to differentiate
                # multiple edges between two nodes.
                edge_keys = list(temp_graph.get_edge_data(sender_comp, receiver_comp).keys())
                for edge_key in edge_keys:
                    edge_data = temp_graph.get_edge_data(sender_comp, receiver_comp)[edge_key]
                    receiver_socket = edge_data["to_socket"]
                    if not receiver_socket.is_variadic and receiver_socket.is_mandatory:
                        continue

                    # We found a breakable edge
                    sender_socket = edge_data["from_socket"]
                    edges_removed[sender_comp].append(sender_socket.name)
                    temp_graph.remove_edge(sender_comp, receiver_comp, edge_key)

                    graph_has_cycles = not networkx.is_directed_acyclic_graph(temp_graph)
                    if not graph_has_cycles:
                        # We removed all the cycles, we can stop
                        break

                if not graph_has_cycles:
                    # We removed all the cycles, nice
                    break

        if graph_has_cycles:
            msg = "Pipeline contains a cycle that we can't execute"
            raise PipelineRuntimeError(msg)

        return temp_graph, components_in_cycles
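The zip-rotation comment above is worth unpacking: `zip(cycle, cycle[1:] + cycle[:1])` pairs each node with its successor, including the wrap-around edge from the last node back to the first. A minimal sketch of the whole cycle-breaking idea, with a plain `breakable` edge attribute standing in for the variadic/default socket check:

```python
import networkx

graph = networkx.MultiDiGraph()
graph.add_edge("joiner", "llm", breakable=False)  # e.g. llm.messages is mandatory
graph.add_edge("llm", "joiner", breakable=True)   # e.g. the joiner input is variadic

for cycle in list(networkx.simple_cycles(graph)):
    # zip(cycle, cycle[1:] + cycle[:1]) pairs each node with its successor,
    # including the wrap-around edge from the last node back to the first.
    for sender, receiver in zip(cycle, cycle[1:] + cycle[:1]):
        edge_keys = list((graph.get_edge_data(sender, receiver) or {}).keys())
        for key in edge_keys:
            if graph.get_edge_data(sender, receiver)[key]["breakable"]:
                graph.remove_edge(sender, receiver, key)
        if networkx.is_directed_acyclic_graph(graph):
            break

print(networkx.is_directed_acyclic_graph(graph))  # True
```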

def _connections_status(
    sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]

@@ -6,6 +6,8 @@ from copy import deepcopy
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple
from warnings import warn

import networkx as nx

from haystack import logging, tracing
from haystack.core.component import Component
from haystack.core.errors import PipelineMaxComponentRuns, PipelineRuntimeError
@@ -62,7 +64,9 @@ class Pipeline(PipelineBase):
                },
            },
        ) as span:
            span.set_content_tag("haystack.component.input", inputs)
            # We deepcopy the inputs otherwise we might lose that information
            # when we delete them in case they're sent to other Components
            span.set_content_tag("haystack.component.input", deepcopy(inputs))
            logger.info("Running component {component_name}", component_name=name)
            res: Dict[str, Any] = instance.run(**inputs)
            self.graph.nodes[name]["visits"] += 1
@@ -84,11 +88,225 @@

        return res

    def run(  # noqa: PLR0915
    def _run_subgraph(  # noqa: PLR0915
        self,
        cycle: List[str],
        component_name: str,
        components_inputs: Dict[str, Dict[str, Any]],
        include_outputs_from: Optional[Set[str]] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Runs a `cycle` in the Pipeline starting from `component_name`.

        This will return once there are no inputs for the Components in `cycle`.

        This is an internal method meant to be used in `Pipeline.run()` only.

        :param cycle:
            List of Components that are part of the cycle being run
        :param component_name:
            Name of the Component that will start execution of the cycle
        :param components_inputs:
            Components inputs, this might include inputs for Components that are not part
            of the cycle but part of the wider Pipeline's graph
        :param include_outputs_from:
            Set of component names whose individual outputs are to be
            included in the cycle's output. In case a Component is executed multiple times
            only the last-produced output is included.
        :returns:
            Outputs of all the Components that are not connected to other Components in `cycle`.
            If `include_outputs_from` is set those Components' outputs will be included.
        :raises PipelineMaxComponentRuns:
            If a Component reaches the maximum number of times it can be run in this Pipeline
        """
        waiting_queue: List[Tuple[str, Component]] = []
        run_queue: List[Tuple[str, Component]] = []

        # Create the run queue starting with the component that needs to run first
        start_index = cycle.index(component_name)
        for node in cycle[start_index:]:
            run_queue.append((node, self.graph.nodes[node]["instance"]))

        include_outputs_from = set() if include_outputs_from is None else include_outputs_from

        before_last_waiting_queue: Optional[Set[str]] = None
        last_waiting_queue: Optional[Set[str]] = None

        subgraph_outputs = {}
        # These are outputs that are sent to other Components but the user explicitly
        # asked to include them in the final output.
        extra_outputs = {}

        # This variable is used to keep track of whether we still need to run the cycle or not.
        # When a Component doesn't send outputs to another Component
        # that's inside the subgraph, we stop running this subgraph.
        cycle_received_inputs = False

        while not cycle_received_inputs:
            # Here we run the Components
            name, comp = run_queue.pop(0)
            if _is_lazy_variadic(comp) and not all(_is_lazy_variadic(comp) for _, comp in run_queue):
                # We run Components with lazy variadic inputs only if there are only Components with
                # lazy variadic inputs left to run
                _enqueue_waiting_component((name, comp), waiting_queue)
                continue

            # As soon as a Component returns only output that is not part of the cycle, we can stop
            if self._component_has_enough_inputs_to_run(name, components_inputs):
                if self.graph.nodes[name]["visits"] > self._max_runs_per_component:
                    msg = f"Maximum run count {self._max_runs_per_component} reached for component '{name}'"
                    raise PipelineMaxComponentRuns(msg)

                res: Dict[str, Any] = self._run_component(name, components_inputs[name])

                # Delete the inputs that were consumed by the Component and are not received from
                # the user or from Components that are part of this cycle
                sockets = list(components_inputs[name].keys())
                for socket_name in sockets:
                    senders = comp.__haystack_input__._sockets_dict[socket_name].senders  # type: ignore
                    if not senders:
                        # We keep inputs that came from the user
                        continue
                    all_senders_in_cycle = all(sender in cycle for sender in senders)
                    if all_senders_in_cycle:
                        # All senders are in the cycle, we can remove the input.
                        # We'll receive it later at a certain point.
                        del components_inputs[name][socket_name]

                if name in include_outputs_from:
                    # Deepcopy the outputs to prevent downstream nodes from modifying them
                    # We don't care about loops - Always store the last output.
                    extra_outputs[name] = deepcopy(res)

                # Reset the waiting for input previous states, we managed to run a component
                before_last_waiting_queue = None
                last_waiting_queue = None

                # Check if a component doesn't send any output to components that are part of the cycle
                final_output_reached = False
                for output_socket in res.keys():
                    for receiver in comp.__haystack_output__._sockets_dict[output_socket].receivers:  # type: ignore
                        if receiver in cycle:
                            final_output_reached = True
                            break
                    if final_output_reached:
                        break

                if not final_output_reached:
                    # We stop only if the Component we just ran doesn't send any output to sockets that
                    # are part of the cycle
                    cycle_received_inputs = True

                # We managed to run this component that was in the waiting list, we can remove it.
                # This happens when a component was put in the waiting list but we reached it from another edge.
                _dequeue_waiting_component((name, comp), waiting_queue)
                for pair in self._find_components_that_will_receive_no_input(name, res, components_inputs):
                    _dequeue_component(pair, run_queue, waiting_queue)

                receivers = [item for item in self._find_receivers_from(name) if item[0] in cycle]

                res = self._distribute_output(receivers, res, components_inputs, run_queue, waiting_queue)

                # We treat a cycle as a completely independent graph, so we keep track of output
                # that is not sent inside the cycle.
                # This output is going to get distributed to the wider graph after we finish running
                # a cycle.
                # All values that are left at this point go outside the cycle.
                if len(res) > 0:
                    subgraph_outputs[name] = res
            else:
                # This component doesn't have enough inputs so we can't run it yet
                _enqueue_waiting_component((name, comp), waiting_queue)

            if len(run_queue) == 0 and len(waiting_queue) > 0:
                # Check if we're stuck in a loop.
                # It's important to check whether previous waitings are None as it could be that no
                # Component has actually been run yet.
                if (
                    before_last_waiting_queue is not None
                    and last_waiting_queue is not None
                    and before_last_waiting_queue == last_waiting_queue
                ):
                    if self._is_stuck_in_a_loop(waiting_queue):
                        # We're stuck! We can't make any progress.
                        msg = (
                            "Pipeline is stuck running in a loop. Partial outputs will be returned. "
                            "Check the Pipeline graph for possible issues."
                        )
                        warn(RuntimeWarning(msg))
                        break

                    (name, comp) = self._find_next_runnable_lazy_variadic_or_default_component(waiting_queue)
                    _add_missing_input_defaults(name, comp, components_inputs)
                    _enqueue_component((name, comp), run_queue, waiting_queue)
                    continue

                before_last_waiting_queue = last_waiting_queue.copy() if last_waiting_queue is not None else None
                last_waiting_queue = {item[0] for item in waiting_queue}

                (name, comp) = self._find_next_runnable_component(components_inputs, waiting_queue)
                _add_missing_input_defaults(name, comp, components_inputs)
                _enqueue_component((name, comp), run_queue, waiting_queue)

        return subgraph_outputs, extra_outputs
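The stop condition of `_run_subgraph` scans the sockets a result was written to: if none of their receivers sits inside the cycle, the subgraph's work is done and control returns to the outer `run()`. A condensed sketch of that check, with plain dicts standing in for the socket objects:

```python
def output_stays_in_cycle(result: dict, receivers_by_socket: dict, cycle: set) -> bool:
    """True if any produced output socket feeds a component inside the cycle."""
    return any(
        receiver in cycle
        for socket_name in result
        for receiver in receivers_by_socket.get(socket_name, [])
    )

# "answer" only feeds a component outside the cycle: the subgraph run can stop.
print(output_stays_in_cycle(
    {"answer": "42"}, {"answer": ["answer_builder"], "retry": ["joiner"]}, {"joiner", "llm"}
))  # False
```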

    def run(  # noqa: PLR0915, PLR0912
        self, data: Dict[str, Any], include_outputs_from: Optional[Set[str]] = None
    ) -> Dict[str, Any]:
        """
        Runs the pipeline with given input data.
        Runs the Pipeline with given input data.

        Usage:
        ```python
        from haystack import Pipeline, Document
        from haystack.utils import Secret
        from haystack.document_stores.in_memory import InMemoryDocumentStore
        from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
        from haystack.components.generators import OpenAIGenerator
        from haystack.components.builders.answer_builder import AnswerBuilder
        from haystack.components.builders.prompt_builder import PromptBuilder

        # Write documents to InMemoryDocumentStore
        document_store = InMemoryDocumentStore()
        document_store.write_documents([
            Document(content="My name is Jean and I live in Paris."),
            Document(content="My name is Mark and I live in Berlin."),
            Document(content="My name is Giorgio and I live in Rome.")
        ])

        prompt_template = \"\"\"
        Given these documents, answer the question.
        Documents:
        {% for doc in documents %}
            {{ doc.content }}
        {% endfor %}
        Question: {{question}}
        Answer:
        \"\"\"

        retriever = InMemoryBM25Retriever(document_store=document_store)
        prompt_builder = PromptBuilder(template=prompt_template)
        llm = OpenAIGenerator(api_key=Secret.from_token(api_key))

        rag_pipeline = Pipeline()
        rag_pipeline.add_component("retriever", retriever)
        rag_pipeline.add_component("prompt_builder", prompt_builder)
        rag_pipeline.add_component("llm", llm)
        rag_pipeline.connect("retriever", "prompt_builder.documents")
        rag_pipeline.connect("prompt_builder", "llm")

        # Ask a question
        question = "Who lives in Paris?"
        results = rag_pipeline.run(
            {
                "retriever": {"query": question},
                "prompt_builder": {"question": question},
            }
        )

        print(results["llm"]["replies"])
        # Jean lives in Paris
        ```

        :param data:
            A dictionary of inputs for the pipeline's components. Each key is a component name
@@ -104,7 +322,6 @@
                "input1": 1, "input2": 2,
            }
            ```

        :param include_outputs_from:
            Set of component names whose individual outputs are to be
            included in the pipeline's output. For components that are
@@ -117,41 +334,11 @@
            without outgoing connections.

        :raises PipelineRuntimeError:
            If a component fails or returns unexpected output.

            Example a - Using named components:
            Consider a 'Hello' component that takes a 'word' input and outputs a greeting.

            ```python
            @component
            class Hello:
                @component.output_types(output=str)
                def run(self, word: str):
                    return {"output": f"Hello, {word}!"}
            ```

            Create a pipeline with two 'Hello' components connected together:

            ```python
            pipeline = Pipeline()
            pipeline.add_component("hello", Hello())
            pipeline.add_component("hello2", Hello())
            pipeline.connect("hello.output", "hello2.word")
            result = pipeline.run(data={"hello": {"word": "world"}})
            ```

            This runs the pipeline with the specified input for 'hello', yielding
            {'hello2': {'output': 'Hello, Hello, world!!'}}.

            Example b - Using flat inputs:
            You can also pass inputs directly without specifying component names:

            ```python
            result = pipeline.run(data={"word": "world"})
            ```

            The pipeline resolves inputs to the correct components, returning
            {'hello2': {'output': 'Hello, Hello, world!!'}}.
            If the Pipeline contains cycles with unsupported connections that would cause
            it to get stuck and fail running.
            Or if a Component fails or returns output in an unsupported type.
        :raises PipelineMaxComponentRuns:
            If a Component reaches the maximum number of times it can be run in this Pipeline.
        """
        pipeline_running(self)
@@ -168,15 +355,8 @@
        # Raise if input is malformed in some way
        self._validate_input(data)

        # Initialize the inputs state
        components_inputs: Dict[str, Dict[str, Any]] = self._init_inputs_state(data)

        # Take all components that:
        # - have no inputs
        # - receive input from the user
        # - have at least one input not connected
        # - have at least one input that is variadic
        run_queue: List[Tuple[str, Component]] = self._init_run_queue(data)
        # Normalize the input data
        components_inputs: Dict[str, Dict[str, Any]] = self._normalize_varidiac_input_data(data)

        # These variables are used to detect when we're stuck in a loop.
        # Stuck loops can happen when one or more components are waiting for input but
@@ -199,6 +379,31 @@
        # This is what we'll return at the end
        final_outputs: Dict[Any, Any] = {}

        # Break cycles in case there are any, this is a noop if no cycle is found.
        # This will raise if a cycle can't be broken.
        graph_without_cycles, components_in_cycles = self._break_supported_cycles_in_graph()

        run_queue: List[Tuple[str, Component]] = []
        for node in nx.topological_sort(graph_without_cycles):
            run_queue.append((node, self.graph.nodes[node]["instance"]))

        # Set default inputs for those sockets that don't receive input either from the user
        # or from other Components.
        # If they have no default nothing is done.
        # This is important to ensure correct execution order, otherwise some variadic
        # Components that receive input from the user might be run before they should.
        for name, comp in self.graph.nodes(data="instance"):
            if name not in components_inputs:
                components_inputs[name] = {}
            for socket_name, socket in comp.__haystack_input__._sockets_dict.items():
                if socket_name in components_inputs[name]:
                    continue
                if not socket.senders:
                    value = socket.default_value
                    if socket.is_variadic:
                        value = [value]
                    components_inputs[name][socket_name] = value
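Since `_break_supported_cycles_in_graph` guarantees an acyclic graph, the run queue above can always be seeded with a topological sort, which gives a deterministic order where every sender runs before its receivers. A tiny illustration with made-up component names:

```python
import networkx

graph = networkx.MultiDiGraph()
graph.add_edges_from([("retriever", "prompt_builder"), ("prompt_builder", "llm")])

# A topological order guarantees every sender is queued before its receivers.
print(list(networkx.topological_sort(graph)))
# ['retriever', 'prompt_builder', 'llm']
```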

        with tracing.tracer.trace(
            "haystack.pipeline.run",
            tags={
@@ -219,14 +424,56 @@
                    # lazy variadic inputs left to run
                    _enqueue_waiting_component((name, comp), waiting_queue)
                    continue
                if self._component_has_enough_inputs_to_run(name, components_inputs) and components_in_cycles.get(
                    name, []
                ):
                    cycles = components_in_cycles.get(name, [])

                if self._component_has_enough_inputs_to_run(name, components_inputs):
                    # This component is part of one or more cycles, let's get the first one and run it.
                    # We can reliably pick any of the cycles if there are multiple ones, the way cycles
                    # are run doesn't make a difference whether we pick the first or any of the others a
                    # Component is part of.
                    subgraph_output, subgraph_extra_output = self._run_subgraph(
                        cycles[0], name, components_inputs, include_outputs_from
                    )

                    # After a cycle is run the previous run_queue can't be correct anymore because it's
                    # not modified when running the subgraph.
                    # So we reset it given the output returned by the subgraph.
                    run_queue = []

                    # Reset the waiting for input previous states, we managed to run at least one component
                    before_last_waiting_queue = None
                    last_waiting_queue = None

                    # Merge the extra outputs
                    extra_outputs.update(subgraph_extra_output)

                    for component_name, component_output in subgraph_output.items():
                        receivers = self._find_receivers_from(component_name)
                        component_output = self._distribute_output(
                            receivers, component_output, components_inputs, run_queue, waiting_queue
                        )

                        if len(component_output) > 0:
                            final_outputs[component_name] = component_output

                elif self._component_has_enough_inputs_to_run(name, components_inputs):
                    if self.graph.nodes[name]["visits"] > self._max_runs_per_component:
                        msg = f"Maximum run count {self._max_runs_per_component} reached for component '{name}'"
                        raise PipelineMaxComponentRuns(msg)

                    res: Dict[str, Any] = self._run_component(name, components_inputs[name])

                    # Delete the inputs that were consumed by the Component and are not received from the user
                    sockets = list(components_inputs[name].keys())
                    for socket_name in sockets:
                        senders = comp.__haystack_input__._sockets_dict[socket_name].senders
                        if senders:
                            # Delete all inputs that are received from other Components
                            del components_inputs[name][socket_name]
                            # We keep inputs that came from the user

                    if name in include_outputs_from:
                        # Deepcopy the outputs to prevent downstream nodes from modifying them
                        # We don't care about loops - Always store the last output.
@@ -242,7 +489,8 @@

                    for pair in self._find_components_that_will_receive_no_input(name, res, components_inputs):
                        _dequeue_component(pair, run_queue, waiting_queue)
                    res = self._distribute_output(name, res, components_inputs, run_queue, waiting_queue)
                    receivers = self._find_receivers_from(name)
                    res = self._distribute_output(receivers, res, components_inputs, run_queue, waiting_queue)

                    if len(res) > 0:
                        final_outputs[name] = res
@@ -13,7 +13,6 @@ from haystack.testing.sample_components.joiner import StringJoiner, StringListJoiner
from haystack.testing.sample_components.parity import Parity
from haystack.testing.sample_components.remainder import Remainder
from haystack.testing.sample_components.repeat import Repeat
from haystack.testing.sample_components.self_loop import SelfLoop
from haystack.testing.sample_components.subtract import Subtract
from haystack.testing.sample_components.sum import Sum
from haystack.testing.sample_components.text_splitter import TextSplitter
@@ -35,6 +34,5 @@ __all__ = [
    "Hello",
    "TextSplitter",
    "StringListJoiner",
    "SelfLoop",
    "FString",
]
@@ -1,27 +0,0 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from haystack.core.component import component
from haystack.core.component.types import Variadic


@component
class SelfLoop:
    """
    Decreases the initial value in steps of 1 until the target value is reached.

    For no good reason it uses a self-loop to do so :)
    """

    def __init__(self, target: int = 0):
        self.target = target

    @component.output_types(current_value=int, final_result=int)
    def run(self, values: Variadic[int]):
        """Decreases the input value in steps of 1 until the target value is reached."""
        value = values[0]  # type: ignore
        value -= 1
        if value == self.target:
            return {"final_result": value}
        return {"current_value": value}
releasenotes/notes/pipeline-run-rework-23a972d83b792db2.yaml (new file, 12 lines)
@@ -0,0 +1,12 @@
---
highlights: >
  `Pipeline.run()` internal logic has been heavily reworked to be more robust and reliable
  than before.
  This new implementation makes it easier to run `Pipeline`s that have cycles in their graph.
  It also fixes some corner cases in `Pipeline`s that don't have any cycle.
features:
  - |
    Fundamentally rework the internal logic of `Pipeline.run()`.
    The rework makes it more reliable and covers more use cases.
    We fixed some issues that made `Pipeline`s with cycles unpredictable
    and with unclear Components execution order.
@@ -26,7 +26,7 @@ Feature: Pipeline running
      | that has a greedy and variadic component after a component with default input |
      | that has components added in a different order from the order of execution |
      | that has a component with only default inputs |
      | that has a component with only default inputs as first to run |
      | that has a component with only default inputs as first to run and receives inputs from a loop |
      | that has multiple branches that merge into a component with a single variadic input |
      | that has multiple branches of different lengths that merge into a component with a single variadic input |
      | that is linear and returns intermediate outputs |
@@ -37,8 +37,12 @@
      | that has a loop and a component with default inputs that doesn't receive anything from its sender but receives input from user |
      | that has multiple components with only default inputs and are added in a different order from the order of execution |
      | that is linear with conditional branching and multiple joins |
      | that is a simple agent |
      | that has a variadic component that receives partial inputs |
      | that has an answer joiner variadic component |
      | that is linear and a component in the middle receives optional input from other components and input from the user |
      | that has a loop in the middle |
      | that has variadic component that receives a conditional input |

  Scenario Outline: Running a bad Pipeline
    Given a pipeline <kind>
@@ -49,3 +53,4 @@
      | kind | exception |
      | that has an infinite loop | PipelineMaxComponentRuns |
      | that has a component that doesn't return a dictionary | PipelineRuntimeError |
      | that has a cycle that would get it stuck | PipelineRuntimeError |
@@ -1,12 +1,16 @@
import json
from typing import List, Optional, Dict, Any
import re

from pytest_bdd import scenarios, given
import pytest

from haystack import Pipeline, Document, component
from haystack.document_stores.types import DuplicatePolicy
from haystack.dataclasses import ChatMessage, GeneratedAnswer
from haystack.components.routers import ConditionalRouter
from haystack.components.builders import PromptBuilder, AnswerBuilder
from haystack.components.builders import PromptBuilder, AnswerBuilder, ChatPromptBuilder
from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.joiners import BranchJoiner, DocumentJoiner, AnswerJoiner
@@ -25,7 +29,6 @@ from haystack.testing.sample_components import (
    Hello,
    TextSplitter,
    StringListJoiner,
    SelfLoop,
)
from haystack.testing.factory import component_class
@@ -67,18 +70,25 @@ def pipeline_that_is_linear():


@given("a pipeline that has an infinite loop", target_fixture="pipeline_data")
def pipeline_that_has_an_infinite_loop():
    def custom_init(self):
        component.set_input_type(self, "x", int)
        component.set_input_type(self, "y", int, 1)
        component.set_output_types(self, a=int, b=int)
    routes = [
        {"condition": "{{number > 2}}", "output": "{{number}}", "output_name": "big_number", "output_type": int},
        {"condition": "{{number <= 2}}", "output": "{{number + 2}}", "output_name": "small_number", "output_type": int},
    ]

    main_input = BranchJoiner(int)
    first_router = ConditionalRouter(routes=routes)
    second_router = ConditionalRouter(routes=routes)

    FakeComponent = component_class("FakeComponent", output={"a": 1, "b": 1}, extra_fields={"__init__": custom_init})
    pipe = Pipeline(max_runs_per_component=1)
    pipe.add_component("first", FakeComponent())
    pipe.add_component("second", FakeComponent())
    pipe.connect("first.a", "second.x")
    pipe.connect("second.b", "first.y")
    return pipe, [PipelineRunData({"first": {"x": 1}})]
    pipe.add_component("main_input", main_input)
    pipe.add_component("first_router", first_router)
    pipe.add_component("second_router", second_router)

    pipe.connect("main_input", "first_router.number")
    pipe.connect("first_router.big_number", "second_router.number")
    pipe.connect("second_router.big_number", "main_input")

    return pipe, [PipelineRunData({"main_input": {"value": 3}})]

@given("a pipeline that is really complex with lots of components, forks, and loops", target_fixture="pipeline_data")
@@ -146,8 +156,11 @@ def pipeline_complex():
            expected_outputs={"accumulate_3": {"value": -7}, "add_five": {"result": -6}},
            expected_run_order=[
                "greet_first",
                "greet_enumerator",
                "accumulate_1",
                "enumerate",
                "add_two",
                "add_three",
                "parity",
                "add_one",
                "branch_joiner",
@@ -159,9 +172,6 @@
                "branch_joiner",
                "below_10",
                "accumulate_2",
                "greet_enumerator",
                "enumerate",
                "add_three",
                "sum",
                "diff",
                "greet_one_last_time",
@@ -837,8 +847,11 @@ def pipeline_that_has_a_component_with_only_default_inputs():
    )


@given("a pipeline that has a component with only default inputs as first to run", target_fixture="pipeline_data")
def pipeline_that_has_a_component_with_only_default_inputs_as_first_to_run():
@given(
    "a pipeline that has a component with only default inputs as first to run and receives inputs from a loop",
    target_fixture="pipeline_data",
)
def pipeline_that_has_a_component_with_only_default_inputs_as_first_to_run_and_receives_inputs_from_a_loop():
    """
    This test verifies that a Pipeline doesn't get stuck running in a loop if
    it has all the following characteristics:
@@ -1529,6 +1542,217 @@ def that_is_linear_with_conditional_branching_and_multiple_joins():
    )


@given("a pipeline that is a simple agent", target_fixture="pipeline_data")
def that_is_a_simple_agent():
    search_message_template = """
    Given these web search results:

    {% for doc in documents %}
        {{ doc.content }}
    {% endfor %}

    Be as brief as possible, max one sentence.
    Answer the question: {{search_query}}
    """

    react_message_template = """
    Solve a question answering task with interleaving Thought, Action, Observation steps.

    Thought reasons about the current situation

    Action can be:
    google_search - Searches Google for the exact concept/entity (given in square brackets) and returns the results for you to use
    finish - Returns the final answer (given in square brackets) and finishes the task

    Observation summarizes the Action outcome and helps in formulating the next
    Thought in Thought, Action, Observation interleaving triplet of steps.

    After each Observation, provide the next Thought and next Action.
    Don't execute multiple steps even though you know the answer.
    Only generate Thought and Action, never Observation, you'll get Observation from Action.
    Follow the pattern in the example below.

    Example:
    ###########################
    Question: Which magazine was started first Arthur’s Magazine or First for Women?
    Thought: I need to search Arthur’s Magazine and First for Women, and find which was started
    first.
    Action: google_search[When was 'Arthur’s Magazine' started?]
    Observation: Arthur’s Magazine was an American literary periodical
    published in Philadelphia and founded in 1844. Edited by Timothy Shay Arthur, it featured work by
    Edgar A. Poe, J.H. Ingraham, Sarah Josepha Hale, Thomas G. Spear, and others. In May 1846
    it was merged into Godey’s Lady’s Book.
    Thought: Arthur’s Magazine was started in 1844. I need to search First for Women founding date next
    Action: google_search[When was 'First for Women' magazine started?]
    Observation: First for Women is a woman’s magazine published by Bauer Media Group in the
    USA. The magazine was started in 1989. It is based in Englewood Cliffs, New Jersey. In 2011
    the circulation of the magazine was 1,310,696 copies.
    Thought: First for Women was started in 1989. 1844 (Arthur’s Magazine) < 1989 (First for
    Women), so Arthur’s Magazine was started first.
    Action: finish[Arthur’s Magazine]
    ############################

    Let's start, the question is: {{query}}

    Thought:
    """

    routes = [
        {
            "condition": "{{'search' in tool_id_and_param[0]}}",
            "output": "{{tool_id_and_param[1]}}",
            "output_name": "search",
            "output_type": str,
        },
        {
            "condition": "{{'finish' in tool_id_and_param[0]}}",
            "output": "{{tool_id_and_param[1]}}",
            "output_name": "finish",
            "output_type": str,
        },
    ]

    @component
    class FakeThoughtActionOpenAIChatGenerator:
        run_counter = 0

        @component.output_types(replies=List[ChatMessage])
        def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
            if self.run_counter == 0:
                self.run_counter += 1
                return {
                    "replies": [
                        ChatMessage.from_assistant(
                            "thinking\n Action: google_search[What is taller, Eiffel Tower or Leaning Tower of Pisa]\n"
                        )
                    ]
                }

            return {"replies": [ChatMessage.from_assistant("thinking\n Action: finish[Eiffel Tower]\n")]}

    @component
    class FakeConclusionOpenAIChatGenerator:
        @component.output_types(replies=List[ChatMessage])
        def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
            return {"replies": [ChatMessage.from_assistant("Tower of Pisa is 55 meters tall\n")]}

    @component
    class FakeSerperDevWebSearch:
        @component.output_types(documents=List[Document])
        def run(self, query: str):
            return {
                "documents": [
                    Document(content="Eiffel Tower is 300 meters tall"),
                    Document(content="Tower of Pisa is 55 meters tall"),
                ]
            }

    # main part
    pipeline = Pipeline()
    pipeline.add_component("main_input", BranchJoiner(List[ChatMessage]))
    pipeline.add_component("prompt_builder", ChatPromptBuilder(variables=["query"]))
    pipeline.add_component("llm", FakeThoughtActionOpenAIChatGenerator())

    @component
    class ToolExtractor:
        @component.output_types(output=List[str])
        def run(self, messages: List[ChatMessage]):
            prompt: str = messages[-1].content
            lines = prompt.strip().split("\n")
            for line in reversed(lines):
                pattern = r"Action:\s*(\w+)\[(.*?)\]"

                match = re.search(pattern, line)
                if match:
                    action_name = match.group(1)
                    parameter = match.group(2)
                    return {"output": [action_name, parameter]}
            return {"output": [None, None]}

    pipeline.add_component("tool_extractor", ToolExtractor())

    @component
    class PromptConcatenator:
        def __init__(self, suffix: str = ""):
            self._suffix = suffix

        @component.output_types(output=List[ChatMessage])
        def run(self, replies: List[ChatMessage], current_prompt: List[ChatMessage]):
            content = current_prompt[-1].content + replies[-1].content + self._suffix
            return {"output": [ChatMessage.from_user(content)]}

    @component
    class SearchOutputAdapter:
        @component.output_types(output=List[ChatMessage])
        def run(self, replies: List[ChatMessage]):
            content = f"Observation: {replies[-1].content}\n"
            return {"output": [ChatMessage.from_assistant(content)]}

    pipeline.add_component("prompt_concatenator_after_action", PromptConcatenator())

    pipeline.add_component("router", ConditionalRouter(routes))
    pipeline.add_component("router_search", FakeSerperDevWebSearch())
    pipeline.add_component("search_prompt_builder", ChatPromptBuilder(variables=["documents", "search_query"]))
    pipeline.add_component("search_llm", FakeConclusionOpenAIChatGenerator())

    pipeline.add_component("search_output_adapter", SearchOutputAdapter())
    pipeline.add_component("prompt_concatenator_after_observation", PromptConcatenator(suffix="\nThought: "))

    # main
    pipeline.connect("main_input", "prompt_builder.template")
    pipeline.connect("prompt_builder.prompt", "llm.messages")
    pipeline.connect("llm.replies", "prompt_concatenator_after_action.replies")

    # tools
    pipeline.connect("prompt_builder.prompt", "prompt_concatenator_after_action.current_prompt")
    pipeline.connect("prompt_concatenator_after_action", "tool_extractor.messages")

    pipeline.connect("tool_extractor", "router")
    pipeline.connect("router.search", "router_search.query")
    pipeline.connect("router_search.documents", "search_prompt_builder.documents")
    pipeline.connect("router.search", "search_prompt_builder.search_query")
    pipeline.connect("search_prompt_builder.prompt", "search_llm.messages")

    pipeline.connect("search_llm.replies", "search_output_adapter.replies")
    pipeline.connect("search_output_adapter", "prompt_concatenator_after_observation.replies")
    pipeline.connect("prompt_concatenator_after_action", "prompt_concatenator_after_observation.current_prompt")
    pipeline.connect("prompt_concatenator_after_observation", "main_input")

    search_message = [ChatMessage.from_user(search_message_template)]
    messages = [ChatMessage.from_user(react_message_template)]
    question = "which tower is taller: eiffel tower or tower of pisa?"

    return pipeline, [
        PipelineRunData(
            inputs={
                "main_input": {"value": messages},
                "prompt_builder": {"query": question},
                "search_prompt_builder": {"template": search_message},
            },
            expected_outputs={"router": {"finish": "Eiffel Tower"}},
            expected_run_order=[
                "main_input",
                "prompt_builder",
                "llm",
                "prompt_concatenator_after_action",
                "tool_extractor",
                "router",
                "router_search",
                "search_prompt_builder",
                "search_llm",
                "search_output_adapter",
                "prompt_concatenator_after_observation",
                "main_input",
                "prompt_builder",
                "llm",
                "prompt_concatenator_after_action",
                "tool_extractor",
                "router",
            ],
        )
    ]
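The `ToolExtractor` in the agent fixture above parses the fake LLM reply with the regex `r"Action:\s*(\w+)\[(.*?)\]"`. A quick standalone check of what the two capture groups return:

```python
import re

pattern = r"Action:\s*(\w+)\[(.*?)\]"
line = "Thought: search first\nAction: google_search[Eiffel Tower height]"
match = re.search(pattern, line)
# group(1) is the tool name, group(2) the parameter inside the brackets
print(match.group(1), "|", match.group(2))  # google_search | Eiffel Tower height
```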


@given("a pipeline that has a variadic component that receives partial inputs", target_fixture="pipeline_data")
def that_has_a_variadic_component_that_receives_partial_inputs():
    @component
@@ -1566,7 +1790,7 @@ def that_has_a_variadic_component_that_receives_partial_inputs():
                    ]
                },
            },
            expected_run_order=["first_creator", "third_creator", "second_creator", "documents_joiner"],
            expected_run_order=["first_creator", "second_creator", "third_creator", "documents_joiner"],
        ),
        PipelineRunData(
            inputs={"first_creator": {"create_document": True}, "second_creator": {"create_document": True}},
@@ -1627,3 +1851,347 @@ def that_has_an_answer_joiner_variadic_component():
        )
    ],
)
|
||||
|
||||
|
||||
@given(
|
||||
"a pipeline that is linear and a component in the middle receives optional input from other components and input from the user",
|
||||
target_fixture="pipeline_data",
|
||||
)
|
||||
def that_is_linear_and_a_component_in_the_middle_receives_optional_input_from_other_components_and_input_from_the_user():
|
||||
@component
|
||||
class QueryMetadataExtractor:
|
||||
@component.output_types(filters=Dict[str, str])
|
||||
def run(self, prompt: str):
|
||||
metadata = json.loads(prompt)
|
||||
filters = []
|
||||
for key, value in metadata.items():
|
||||
filters.append({"field": f"meta.{key}", "operator": "==", "value": value})
|
||||
|
||||
return {"filters": {"operator": "AND", "conditions": filters}}
|
||||
|
||||
    documents = [
        Document(
            content="some publication about Alzheimer prevention research done over 2023 patients study",
            meta={"year": 2022, "disease": "Alzheimer", "author": "Michael Butter"},
            id="doc1",
        ),
        Document(
            content="some text about investigation and treatment of Alzheimer disease",
            meta={"year": 2023, "disease": "Alzheimer", "author": "John Bread"},
            id="doc2",
        ),
        Document(
            content="A study on the effectiveness of new therapies for Parkinson's disease",
            meta={"year": 2022, "disease": "Parkinson", "author": "Alice Smith"},
            id="doc3",
        ),
        Document(
            content="An overview of the latest research on the genetics of Parkinson's disease and its implications for treatment",
            meta={"year": 2023, "disease": "Parkinson", "author": "David Jones"},
            id="doc4",
        ),
    ]
    document_store = InMemoryDocumentStore(bm25_algorithm="BM25Plus")
    document_store.write_documents(documents=documents, policy=DuplicatePolicy.OVERWRITE)

    pipeline = Pipeline()
    pipeline.add_component(instance=PromptBuilder('{"disease": "Alzheimer", "year": 2023}'), name="builder")
    pipeline.add_component(instance=QueryMetadataExtractor(), name="metadata_extractor")
    pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="retriever")
    pipeline.add_component(instance=DocumentJoiner(), name="document_joiner")

    pipeline.connect("builder.prompt", "metadata_extractor.prompt")
    pipeline.connect("metadata_extractor.filters", "retriever.filters")
    pipeline.connect("retriever.documents", "document_joiner.documents")

    query = "publications 2023 Alzheimer's disease"

    return (
        pipeline,
        [
            PipelineRunData(
                inputs={"retriever": {"query": query}},
                expected_outputs={
                    "document_joiner": {
                        "documents": [
                            Document(
                                content="some text about investigation and treatment of Alzheimer disease",
                                meta={"year": 2023, "disease": "Alzheimer", "author": "John Bread"},
                                id="doc2",
                                score=3.324112496100923,
                            )
                        ]
                    }
                },
                expected_run_order=["builder", "metadata_extractor", "retriever", "document_joiner"],
            )
        ],
    )

@given("a pipeline that has a cycle that would get it stuck", target_fixture="pipeline_data")
|
||||
def that_has_a_cycle_that_would_get_it_stuck():
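    # The validator below routes `invalid_replies` and `error_message` back into the
    # prompt builder, forming a cycle. With max_runs_per_component=1 the second pass
    # through that cycle exceeds the limit, so this fixture is expected to fail
    # rather than loop forever.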
    template = """
    You are an experienced and accurate Turkish CX specialist that classifies customer comments into pre-defined categories below:\n
    Negative experience labels:
    - Late delivery
    - Rotten/spoilt item
    - Bad Courier behavior

    Positive experience labels:
    - Good courier behavior
    - Thanks & appreciation
    - Love message to courier
    - Fast delivery
    - Quality of products

    Create a JSON object as a response. The fields are: 'positive_experience', 'negative_experience'.
    Assign at least one of the pre-defined labels to the given customer comment under positive and negative experience fields.
    If the comment has a positive experience, list the label under the 'positive_experience' field.
    If the comment has a negative experience, list it under the 'negative_experience' field.
    Here is the comment:\n{{ comment }}\n. Just return the category names in the list. If there aren't any, return an empty list.

    {% if invalid_replies and error_message %}
    You already created the following output in a previous attempt: {{ invalid_replies }}
    However, this doesn't comply with the format requirements from above and triggered this Python exception: {{ error_message }}
    Correct the output and try again. Just return the corrected output without any extra explanations.
    {% endif %}
    """
    prompt_builder = PromptBuilder(
        template=template, required_variables=["comment", "invalid_replies", "error_message"]
    )

    @component
    class FakeOutputValidator:
        @component.output_types(
            valid_replies=List[str], invalid_replies=Optional[List[str]], error_message=Optional[str]
        )
        def run(self, replies: List[str]):
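            # Fail validation on the first call only, so the reply makes exactly one
            # extra trip around the cycle before being accepted.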
            if not getattr(self, "called", False):
                self.called = True
                return {"invalid_replies": ["This is an invalid reply"], "error_message": "this is an error message"}
            return {"valid_replies": replies}

    @component
    class FakeGenerator:
        @component.output_types(replies=List[str])
        def run(self, prompt: str):
            return {"replies": ["This is a valid reply"]}

    llm = FakeGenerator()
    validator = FakeOutputValidator()

    pipeline = Pipeline(max_runs_per_component=1)
    pipeline.add_component("prompt_builder", prompt_builder)
    pipeline.add_component("llm", llm)
    pipeline.add_component("output_validator", validator)

    pipeline.connect("prompt_builder.prompt", "llm.prompt")
    pipeline.connect("llm.replies", "output_validator.replies")
    pipeline.connect("output_validator.invalid_replies", "prompt_builder.invalid_replies")
    pipeline.connect("output_validator.error_message", "prompt_builder.error_message")

    comment = "I loved the quality of the meal but the courier was rude"
    return (pipeline, [PipelineRunData(inputs={"prompt_builder": {"comment": comment}})])


@given("a pipeline that has a loop in the middle", target_fixture="pipeline_data")
def that_has_a_loop_in_the_middle():
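    # prompt_builder -> llm -> answer_validator forms the loop: replies containing
    # "No answer" are routed back to the prompt builder, valid replies continue on
    # to answer_builder.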
    @component
    class FakeGenerator:
        @component.output_types(replies=List[str])
        def run(self, prompt: str):
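            # Reply "No answer" on the first run to force one trip around the loop,
            # then "42" so the pipeline can exit through the valid_replies route.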
            replies = []
            if getattr(self, "first_run", True):
                self.first_run = False
                replies.append("No answer")
            else:
                replies.append("42")
            return {"replies": replies}

    @component
    class PromptCleaner:
        @component.output_types(clean_prompt=str)
        def run(self, prompt: str):
            return {"clean_prompt": prompt.strip()}

    routes = [
        {
            "condition": "{{ 'No answer' in replies }}",
            "output": "{{ replies }}",
            "output_name": "invalid_replies",
            "output_type": List[str],
        },
        {
            "condition": "{{ 'No answer' not in replies }}",
            "output": "{{ replies }}",
            "output_name": "valid_replies",
            "output_type": List[str],
        },
    ]

    pipeline = Pipeline(max_runs_per_component=20)
    pipeline.add_component("prompt_cleaner", PromptCleaner())
    pipeline.add_component("prompt_builder", PromptBuilder(template="", variables=["question", "invalid_replies"]))
    pipeline.add_component("llm", FakeGenerator())
    pipeline.add_component("answer_validator", ConditionalRouter(routes=routes))
    pipeline.add_component("answer_builder", AnswerBuilder())

    pipeline.connect("prompt_cleaner.clean_prompt", "prompt_builder.template")
    pipeline.connect("prompt_builder.prompt", "llm.prompt")
    pipeline.connect("llm.replies", "answer_validator.replies")
    pipeline.connect("answer_validator.invalid_replies", "prompt_builder.invalid_replies")
    pipeline.connect("answer_validator.valid_replies", "answer_builder.replies")

    question = "What is the answer?"
    return (
        pipeline,
        [
            PipelineRunData(
                inputs={
                    "prompt_cleaner": {"prompt": "Random template"},
                    "prompt_builder": {"question": question},
                    "answer_builder": {"query": question},
                },
                expected_outputs={
                    "answer_builder": {"answers": [GeneratedAnswer(data="42", query=question, documents=[])]}
                },
                expected_run_order=[
                    "prompt_cleaner",
                    "prompt_builder",
                    "llm",
                    "answer_validator",
                    "prompt_builder",
                    "llm",
                    "answer_validator",
                    "answer_builder",
                ],
            )
        ],
    )


@given("a pipeline that has variadic component that receives a conditional input", target_fixture="pipeline_data")
def that_has_variadic_component_that_receives_a_conditional_input():
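    # document_joiner's variadic input has two unconditional senders (comma_splitter
    # and document_cleaner) and one conditional sender (empty_lines_cleaner, fed only
    # by the router's "long" branch), so the second run below exercises the case
    # where all three actually produce output.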
    pipe = Pipeline(max_runs_per_component=1)
    routes = [
        {
            "condition": "{{ documents|length > 1 }}",
            "output": "{{ documents }}",
            "output_name": "long",
            "output_type": List[Document],
        },
        {
            "condition": "{{ documents|length <= 1 }}",
            "output": "{{ documents }}",
            "output_name": "short",
            "output_type": List[Document],
        },
    ]

    @component
    class NoOp:
        @component.output_types(documents=List[Document])
        def run(self, documents: List[Document]):
            return {"documents": documents}

    @component
    class CommaSplitter:
        @component.output_types(documents=List[Document])
        def run(self, documents: List[Document]):
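            # Split every document on commas, giving each fragment a fresh sequential id.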
            res = []
            current_id = 0
            for doc in documents:
                for split in doc.content.split(","):
                    res.append(Document(content=split, id=str(current_id)))
                    current_id += 1
            return {"documents": res}

    pipe.add_component("conditional_router", ConditionalRouter(routes, unsafe=True))
    pipe.add_component(
        "empty_lines_cleaner", DocumentCleaner(remove_empty_lines=True, remove_extra_whitespaces=False, keep_id=True)
    )
    pipe.add_component("comma_splitter", CommaSplitter())
    pipe.add_component("document_cleaner", DocumentCleaner(keep_id=True))
    pipe.add_component("document_joiner", DocumentJoiner())

    pipe.add_component("noop2", NoOp())
    pipe.add_component("noop3", NoOp())

    pipe.connect("noop2", "noop3")
    pipe.connect("noop3", "conditional_router")

    pipe.connect("conditional_router.long", "empty_lines_cleaner")
    pipe.connect("empty_lines_cleaner", "document_joiner")

    pipe.connect("comma_splitter", "document_cleaner")
    pipe.connect("document_cleaner", "document_joiner")
    pipe.connect("comma_splitter", "document_joiner")

    document = Document(
        id="1000", content="This document has so many, sentences. Like this one, or this one. Or even this other one."
    )

    return pipe, [
        PipelineRunData(
            inputs={"noop2": {"documents": [document]}, "comma_splitter": {"documents": [document]}},
            expected_outputs={
                "conditional_router": {
                    "short": [
                        Document(
                            id="1000",
                            content="This document has so many, sentences. Like this one, or this one. Or even this other one.",
                        )
                    ]
                },
                "document_joiner": {
                    "documents": [
                        Document(id="0", content="This document has so many"),
                        Document(id="1", content=" sentences. Like this one"),
                        Document(id="2", content=" or this one. Or even this other one."),
                    ]
                },
            },
            expected_run_order=[
                "comma_splitter",
                "noop2",
                "document_cleaner",
                "noop3",
                "conditional_router",
                "document_joiner",
            ],
        ),
        PipelineRunData(
            inputs={
                "noop2": {"documents": [document, document]},
                "comma_splitter": {"documents": [document, document]},
            },
            expected_outputs={
                "document_joiner": {
                    "documents": [
                        Document(id="0", content="This document has so many"),
                        Document(id="1", content=" sentences. Like this one"),
                        Document(id="2", content=" or this one. Or even this other one."),
                        Document(id="3", content="This document has so many"),
                        Document(id="4", content=" sentences. Like this one"),
                        Document(id="5", content=" or this one. Or even this other one."),
                        Document(
                            id="1000",
                            content="This document has so many, sentences. Like this one, or this one. Or even this other one.",
                        ),
                    ]
                }
            },
            expected_run_order=[
                "comma_splitter",
                "noop2",
                "document_cleaner",
                "noop3",
                "conditional_router",
                "empty_lines_cleaner",
                "document_joiner",
            ],
        ),
    ]

@ -11,7 +11,7 @@ from haystack import Document
from haystack.components.builders import PromptBuilder, AnswerBuilder
from haystack.components.joiners import BranchJoiner
from haystack.core.component import component
from haystack.core.component.types import InputSocket, OutputSocket, Variadic, GreedyVariadic
from haystack.core.component.types import InputSocket, OutputSocket, Variadic, GreedyVariadic, _empty
from haystack.core.errors import DeserializationError, PipelineConnectError, PipelineDrawingError, PipelineError
from haystack.core.pipeline import Pipeline, PredefinedPipeline
from haystack.core.pipeline.base import (
@ -788,43 +788,7 @@ class TestPipeline:
        for node in pipe.graph.nodes:
            assert pipe.graph.nodes[node]["visits"] == 0

    def test__init_run_queue(self):
        ComponentWithVariadic = component_class(
            "ComponentWithVariadic", input_types={"in": Variadic[int]}, output_types={"out": int}
        )
        ComponentWithNoInputs = component_class("ComponentWithNoInputs", input_types={}, output_types={"out": int})
        ComponentWithSingleInput = component_class(
            "ComponentWithSingleInput", input_types={"in": int}, output_types={"out": int}
        )
        ComponentWithMultipleInputs = component_class(
            "ComponentWithMultipleInputs", input_types={"in1": int, "in2": int}, output_types={"out": int}
        )

        pipe = Pipeline()
        pipe.add_component("with_variadic", ComponentWithVariadic())
        pipe.add_component("with_no_inputs", ComponentWithNoInputs())
        pipe.add_component("with_single_input", ComponentWithSingleInput())
        pipe.add_component("another_with_single_input", ComponentWithSingleInput())
        pipe.add_component("yet_another_with_single_input", ComponentWithSingleInput())
        pipe.add_component("with_multiple_inputs", ComponentWithMultipleInputs())

        pipe.connect("yet_another_with_single_input.out", "with_variadic.in")
        pipe.connect("with_no_inputs.out", "with_variadic.in")
        pipe.connect("with_single_input.out", "another_with_single_input.in")
        pipe.connect("another_with_single_input.out", "with_multiple_inputs.in1")
        pipe.connect("with_multiple_inputs.out", "with_variadic.in")

        data = {"yet_another_with_single_input": {"in": 1}}
        run_queue = pipe._init_run_queue(data)
        assert len(run_queue) == 6
        assert run_queue[0][0] == "with_no_inputs"
        assert run_queue[1][0] == "with_single_input"
        assert run_queue[2][0] == "yet_another_with_single_input"
        assert run_queue[3][0] == "another_with_single_input"
        assert run_queue[4][0] == "with_multiple_inputs"
        assert run_queue[5][0] == "with_variadic"

    def test__init_inputs_state(self):
    def test__normalize_varidiac_input_data(self):
        pipe = Pipeline()
        template = """
        Answer the following questions:
@ -838,13 +802,12 @@ class TestPipeline:
            "branch_joiner": {"value": 1},
            "not_a_component": "some input data",
        }
        res = pipe._init_inputs_state(data)
        res = pipe._normalize_varidiac_input_data(data)
        assert res == {
            "prompt_builder": {"questions": ["What is the capital of Italy?", "What is the capital of France?"]},
            "branch_joiner": {"value": [1]},
            "not_a_component": "some input data",
        }
        assert id(questions) != id(res["prompt_builder"]["questions"])

    def test__prepare_component_input_data(self):
        MockComponent = component_class("MockComponent", input_types={"x": List[str], "y": str})
@ -1165,6 +1128,30 @@ class TestPipeline:
        )
        assert res == set()

        multiple_outputs = component_class("MultipleOutputs", output_types={"first": int, "second": int})()

        def custom_init(self):
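            # Both inputs are optional and have defaults, so the component can still
            # run when a sender produces nothing for one of them.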
            component.set_input_type(self, "first", Optional[int], 1)
            component.set_input_type(self, "second", Optional[int], 2)

        multiple_optional_inputs = component_class("MultipleOptionalInputs", extra_fields={"__init__": custom_init})()

        pipe = Pipeline()
        pipe.add_component("multiple_outputs", multiple_outputs)
        pipe.add_component("multiple_optional_inputs", multiple_optional_inputs)
        pipe.connect("multiple_outputs.second", "multiple_optional_inputs.first")

        res = pipe._find_components_that_will_receive_no_input("multiple_outputs", {"first": 1}, {})
        assert res == {("multiple_optional_inputs", multiple_optional_inputs)}

        res = pipe._find_components_that_will_receive_no_input(
            "multiple_outputs", {"first": 1}, {"multiple_optional_inputs": {"second": 200}}
        )
        assert res == set()

        res = pipe._find_components_that_will_receive_no_input("multiple_outputs", {"second": 1}, {})
        assert res == set()

    def test__distribute_output(self):
        document_builder = component_class(
            "DocumentBuilder", input_types={"text": str}, output_types={"doc": Document, "another_doc": Document}
@ -1184,12 +1171,20 @@ class TestPipeline:
        inputs = {"document_builder": {"text": "some text"}}
        run_queue = []
        waiting_queue = [("document_joiner", document_joiner)]
        receivers = [
            (
                "document_cleaner",
                OutputSocket("doc", Document, ["document_cleaner"]),
                InputSocket("doc", Document, _empty, ["document_builder"]),
            ),
            (
                "document_joiner",
                OutputSocket("another_doc", Document, ["document_joiner"]),
                InputSocket("docs", Variadic[Document], _empty, ["document_builder"]),
            ),
        ]
        res = pipe._distribute_output(
            "document_builder",
            {"doc": Document("some text"), "another_doc": Document()},
            inputs,
            run_queue,
            waiting_queue,
            receivers, {"doc": Document("some text"), "another_doc": Document()}, inputs, run_queue, waiting_queue
        )

        assert res == {}
@ -1524,3 +1519,65 @@ class TestPipeline:
        assert not _is_lazy_variadic(NonVariadic())
        assert _is_lazy_variadic(VariadicNonGreedyVariadic())
        assert not _is_lazy_variadic(NonVariadicAndGreedyVariadic())

    def test__find_receivers_from(self):
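        # _find_receivers_from returns one (receiver_name, output_socket, input_socket)
        # triple for every outgoing connection of the queried component.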
        sentence_builder = component_class(
            "SentenceBuilder", input_types={"words": List[str]}, output_types={"text": str}
        )()
        document_builder = component_class(
            "DocumentBuilder", input_types={"text": str}, output_types={"doc": Document}
        )()
        conditional_document_builder = component_class(
            "ConditionalDocumentBuilder", output_types={"doc": Document, "noop": None}
        )()

        document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()

        pipe = Pipeline()
        pipe.add_component("sentence_builder", sentence_builder)
        pipe.add_component("document_builder", document_builder)
        pipe.add_component("document_joiner", document_joiner)
        pipe.add_component("conditional_document_builder", conditional_document_builder)
        pipe.connect("sentence_builder.text", "document_builder.text")
        pipe.connect("document_builder.doc", "document_joiner.docs")
        pipe.connect("conditional_document_builder.doc", "document_joiner.docs")

        res = pipe._find_receivers_from("sentence_builder")
        assert res == [
            (
                "document_builder",
                OutputSocket(name="text", type=str, receivers=["document_builder"]),
                InputSocket(name="text", type=str, default_value=_empty, senders=["sentence_builder"]),
            )
        ]

        res = pipe._find_receivers_from("document_builder")
        assert res == [
            (
                "document_joiner",
                OutputSocket(name="doc", type=Document, receivers=["document_joiner"]),
                InputSocket(
                    name="docs",
                    type=Variadic[Document],
                    default_value=_empty,
                    senders=["document_builder", "conditional_document_builder"],
                ),
            )
        ]

        res = pipe._find_receivers_from("document_joiner")
        assert res == []

        res = pipe._find_receivers_from("conditional_document_builder")
        assert res == [
            (
                "document_joiner",
                OutputSocket(name="doc", type=Document, receivers=["document_joiner"]),
                InputSocket(
                    name="docs",
                    type=Variadic[Document],
                    default_value=_empty,
                    senders=["document_builder", "conditional_document_builder"],
                ),
            )
        ]

@ -6,9 +6,9 @@ from typing import Optional
import pytest

from haystack.core.component.types import InputSocket, OutputSocket, Variadic
from haystack.core.errors import PipelineValidationError
from haystack.core.pipeline import Pipeline
from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
from haystack.testing.factory import component_class
from haystack.testing.sample_components import AddFixedValue, Double, Parity, Sum


@ -119,10 +119,16 @@ def test_find_pipeline_some_outputs_different_components():

def test_validate_pipeline_input_pipeline_with_no_inputs():
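    # A component that takes no inputs at all should still be run, and its fixed
    # output returned in the pipeline result.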
    pipe = Pipeline()
    pipe.add_component("comp1", Double())
    pipe.add_component("comp2", Double())
    pipe.connect("comp1", "comp2")
    pipe.connect("comp2", "comp1")
    NoInputs = component_class("NoInputs", input_types={}, output={"value": 10})
    pipe.add_component("no_inputs", NoInputs())
    res = pipe.run({})
    assert res == {"no_inputs": {"value": 10}}


def test_validate_pipeline_input_pipeline_with_no_inputs_no_outputs():
    pipe = Pipeline()
    NoIO = component_class("NoIO", input_types={}, output={})
    pipe.add_component("no_inputs", NoIO())
    res = pipe.run({})
    assert res == {}