# pylint: disable=too-many-lines
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import itertools
from collections import defaultdict
from datetime import datetime
from enum import IntEnum
from pathlib import Path
from typing import (
Any,
ContextManager,
Dict,
Iterator,
List,
Mapping,
Optional,
Set,
TextIO,
Tuple,
Type,
TypeVar,
Union,
)
import networkx # type:ignore
from haystack import logging, tracing
from haystack.core.component import Component, InputSocket, OutputSocket, component
from haystack.core.errors import (
DeserializationError,
PipelineComponentsBlockedError,
PipelineConnectError,
PipelineDrawingError,
PipelineError,
PipelineMaxComponentRuns,
PipelineUnmarshalError,
PipelineValidationError,
)
from haystack.core.pipeline.component_checks import (
_NO_OUTPUT_PRODUCED,
all_predecessors_executed,
are_all_lazy_variadic_sockets_resolved,
are_all_sockets_ready,
can_component_run,
is_any_greedy_socket_ready,
is_socket_lazy_variadic,
)
from haystack.core.pipeline.utils import FIFOPriorityQueue, _deepcopy_with_exceptions, parse_connect_string
from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict
from haystack.core.type_utils import _type_name, _types_are_compatible
from haystack.marshal import Marshaller, YamlMarshaller
from haystack.utils import is_in_jupyter, type_serialization
from .descriptions import find_pipeline_inputs, find_pipeline_outputs
from .draw import _to_mermaid_image
from .template import PipelineTemplate, PredefinedPipeline
DEFAULT_MARSHALLER = YamlMarshaller()
# We use a generic type to annotate the return value of class methods,
# so that static analyzers won't be confused when derived classes
# use those methods.
T = TypeVar("T", bound="PipelineBase")
logger = logging.getLogger(__name__)
# Constants for tracing tags
_COMPONENT_INPUT = "haystack.component.input"
_COMPONENT_OUTPUT = "haystack.component.output"
_COMPONENT_VISITS = "haystack.component.visits"
class ComponentPriority(IntEnum):
    HIGHEST = 1  # a greedy variadic socket is ready and all other sockets are ready
    READY = 2  # can run and all predecessors have executed
    DEFER = 3  # can run, but not all predecessors have executed; all lazy variadic sockets are resolved
    DEFER_LAST = 4  # can run, but some lazy variadic sockets are still unresolved
    BLOCKED = 5  # cannot run with the inputs it currently has
class PipelineBase: # noqa: PLW1641
"""
Components orchestration engine.
Builds a graph of components and orchestrates their execution according to the execution graph.
"""
def __init__(
self,
metadata: Optional[Dict[str, Any]] = None,
max_runs_per_component: int = 100,
connection_type_validation: bool = True,
):
"""
Creates the Pipeline.
:param metadata:
Arbitrary dictionary to store metadata about this `Pipeline`. Make sure all the values contained in
this dictionary can be serialized and deserialized if you wish to save this `Pipeline` to file.
:param max_runs_per_component:
How many times the `Pipeline` can run the same Component.
            If this limit is reached, a `PipelineMaxComponentRuns` exception is raised.
            If not set, it defaults to 100 runs per Component.
:param connection_type_validation: Whether the pipeline will validate the types of the connections.
Defaults to True.
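        Usage example (a minimal sketch; the metadata keys are illustrative):
            from haystack import Pipeline
            pipe = Pipeline(metadata={"author": "me"}, max_runs_per_component=10)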
"""
self._telemetry_runs = 0
self._last_telemetry_sent: Optional[datetime] = None
self.metadata = metadata or {}
self.graph = networkx.MultiDiGraph()
self._max_runs_per_component = max_runs_per_component
self._connection_type_validation = connection_type_validation
def __eq__(self, other: object) -> bool:
"""
Pipeline equality is defined by their type and the equality of their serialized form.
        Pipelines of the same type share the same metadata, nodes, and edges, but they're not required to use
        the same component instances: this allows a pipeline that was saved and then loaded back to be equal to itself.
"""
if not isinstance(self, type(other)):
return False
assert isinstance(other, PipelineBase)
return self.to_dict() == other.to_dict()
def __repr__(self) -> str:
"""
Returns a text representation of the Pipeline.
"""
res = f"{object.__repr__(self)}\n"
if self.metadata:
res += "🧱 Metadata\n"
for k, v in self.metadata.items():
res += f" - {k}: {v}\n"
res += "🚅 Components\n"
for name, instance in self.graph.nodes(data="instance"): # type: ignore # type wrongly defined in networkx
res += f" - {name}: {instance.__class__.__name__}\n"
res += "🛤️ Connections\n"
for sender, receiver, edge_data in self.graph.edges(data=True):
sender_socket = edge_data["from_socket"].name
receiver_socket = edge_data["to_socket"].name
res += f" - {sender}.{sender_socket} -> {receiver}.{receiver_socket} ({edge_data['conn_type']})\n"
return res
def to_dict(self) -> Dict[str, Any]:
"""
Serializes the pipeline to a dictionary.
        This is meant to be an intermediate representation, but it can also be used to save a pipeline to file.
:returns:
Dictionary with serialized data.
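        The returned dictionary has this shape (values shortened for illustration):
            {
                "metadata": {...},
                "max_runs_per_component": 100,
                "components": {"component_name": {...}},
                "connections": [{"sender": "comp_a.socket", "receiver": "comp_b.socket"}],
                "connection_type_validation": True,
            }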
"""
components = {}
for name, instance in self.graph.nodes(data="instance"): # type:ignore
components[name] = component_to_dict(instance, name)
connections = []
for sender, receiver, edge_data in self.graph.edges.data():
sender_socket = edge_data["from_socket"].name
receiver_socket = edge_data["to_socket"].name
connections.append({"sender": f"{sender}.{sender_socket}", "receiver": f"{receiver}.{receiver_socket}"})
return {
"metadata": self.metadata,
"max_runs_per_component": self._max_runs_per_component,
"components": components,
"connections": connections,
"connection_type_validation": self._connection_type_validation,
}
@classmethod
def from_dict(
cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs: Any
) -> T:
"""
Deserializes the pipeline from a dictionary.
:param data:
Dictionary to deserialize from.
:param callbacks:
Callbacks to invoke during deserialization.
:param kwargs:
`components`: a dictionary of `{name: instance}` to reuse instances of components instead of creating new
ones.
:returns:
Deserialized component.
"""
data_copy = _deepcopy_with_exceptions(data) # to prevent modification of original data
metadata = data_copy.get("metadata", {})
max_runs_per_component = data_copy.get("max_runs_per_component", 100)
connection_type_validation = data_copy.get("connection_type_validation", True)
pipe = cls(
metadata=metadata,
max_runs_per_component=max_runs_per_component,
connection_type_validation=connection_type_validation,
)
components_to_reuse = kwargs.get("components", {})
for name, component_data in data_copy.get("components", {}).items():
if name in components_to_reuse:
# Reuse an instance
instance = components_to_reuse[name]
else:
if "type" not in component_data:
raise PipelineError(f"Missing 'type' in component '{name}'")
if component_data["type"] not in component.registry:
try:
# Import the module first...
module, _ = component_data["type"].rsplit(".", 1)
logger.debug("Trying to import module {module_name}", module_name=module)
type_serialization.thread_safe_import(module)
# ...then try again
if component_data["type"] not in component.registry:
raise PipelineError(
f"Successfully imported module '{module}' but couldn't find "
f"'{component_data['type']}' in the component registry.\n"
f"The component might be registered under a different path. "
f"Here are the registered components:\n {list(component.registry.keys())}\n"
)
except (ImportError, PipelineError, ValueError) as e:
raise PipelineError(
f"Component '{component_data['type']}' (name: '{name}') not imported. Please "
f"check that the package is installed and the component path is correct."
) from e
# Create a new one
component_class = component.registry[component_data["type"]]
try:
instance = component_from_dict(component_class, component_data, name, callbacks)
except Exception as e:
msg = (
f"Couldn't deserialize component '{name}' of class '{component_class.__name__}' "
f"with the following data: {str(component_data)}. Possible reasons include "
"malformed serialized data, mismatch between the serialized component and the "
"loaded one (due to a breaking change, see "
"https://github.com/deepset-ai/haystack/releases), etc."
)
raise DeserializationError(msg) from e
pipe.add_component(name=name, instance=instance)
for connection in data.get("connections", []):
if "sender" not in connection:
raise PipelineError(f"Missing sender in connection: {connection}")
if "receiver" not in connection:
raise PipelineError(f"Missing receiver in connection: {connection}")
pipe.connect(sender=connection["sender"], receiver=connection["receiver"])
return pipe
def dumps(self, marshaller: Marshaller = DEFAULT_MARSHALLER) -> str:
"""
Returns the string representation of this pipeline according to the format dictated by the `Marshaller` in use.
:param marshaller:
The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
:returns:
A string representing the pipeline.
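        Usage example (a minimal sketch):
            yaml_repr = pipe.dumps()  # YAML by default
            restored = Pipeline.loads(yaml_repr)
            assert restored == pipe  # equality compares the serialized forms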
"""
return marshaller.marshal(self.to_dict())
def dump(self, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER) -> None:
"""
Writes the string representation of this pipeline to the file-like object passed in the `fp` argument.
:param fp:
A file-like object ready to be written to.
:param marshaller:
The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
"""
fp.write(marshaller.marshal(self.to_dict()))
@classmethod
def loads(
cls: Type[T],
data: Union[str, bytes, bytearray],
marshaller: Marshaller = DEFAULT_MARSHALLER,
callbacks: Optional[DeserializationCallbacks] = None,
) -> T:
"""
Creates a `Pipeline` object from the string representation passed in the `data` argument.
:param data:
The string representation of the pipeline, can be `str`, `bytes` or `bytearray`.
:param marshaller:
The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
:param callbacks:
Callbacks to invoke during deserialization.
:raises DeserializationError:
If an error occurs during deserialization.
:returns:
A `Pipeline` object.
"""
try:
deserialized_data = marshaller.unmarshal(data)
except Exception as e:
raise DeserializationError(
"Error while unmarshalling serialized pipeline data. This is usually "
"caused by malformed or invalid syntax in the serialized representation."
) from e
return cls.from_dict(deserialized_data, callbacks)
@classmethod
def load(
cls: Type[T],
fp: TextIO,
marshaller: Marshaller = DEFAULT_MARSHALLER,
callbacks: Optional[DeserializationCallbacks] = None,
) -> T:
"""
        Creates a `Pipeline` object from a string representation.
The string representation is read from the file-like object passed in the `fp` argument.
:param fp:
A file-like object ready to be read from.
:param marshaller:
The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
:param callbacks:
Callbacks to invoke during deserialization.
:raises DeserializationError:
If an error occurs during deserialization.
:returns:
A `Pipeline` object.
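        Usage example (the file name is illustrative):
            with open("pipeline.yaml") as fp:
                pipe = Pipeline.load(fp)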
"""
return cls.loads(fp.read(), marshaller, callbacks)
def add_component(self, name: str, instance: Component) -> None:
"""
Add the given component to the pipeline.
Components are not connected to anything by default: use `Pipeline.connect()` to connect components together.
Component names must be unique, but component instances can be reused if needed.
:param name:
The name of the component to add.
:param instance:
The component instance to add.
:raises ValueError:
If a component with the same name already exists.
:raises PipelineValidationError:
If the given instance is not a component.
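        Usage example (the component class and template are illustrative):
            from haystack.components.builders import PromptBuilder
            pipe = Pipeline()
            pipe.add_component("prompt_builder", PromptBuilder(template="Answer: {{question}}"))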
"""
# Component names are unique
if name in self.graph.nodes:
raise ValueError(f"A component named '{name}' already exists in this pipeline: choose another name.")
# Components can't be named `_debug`
if name == "_debug":
raise ValueError("'_debug' is a reserved name for debug output. Choose another name.")
# Component names can't have "."
if "." in name:
raise ValueError(f"{name} is an invalid component name, cannot contain '.' (dot) characters.")
# Component instances must be components
if not isinstance(instance, Component):
raise PipelineValidationError(
f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?"
)
if getattr(instance, "__haystack_added_to_pipeline__", None):
msg = (
"Component has already been added in another Pipeline. Components can't be shared between Pipelines. "
"Create a new instance instead."
)
raise PipelineError(msg)
setattr(instance, "__haystack_added_to_pipeline__", self)
setattr(instance, "__component_name__", name)
# Add component to the graph, disconnected
logger.debug("Adding component '{component_name}' ({component})", component_name=name, component=instance)
# We're completely sure the fields exist so we ignore the type error
self.graph.add_node(
name,
instance=instance,
input_sockets=instance.__haystack_input__._sockets_dict, # type: ignore[attr-defined]
output_sockets=instance.__haystack_output__._sockets_dict, # type: ignore[attr-defined]
visits=0,
)
def remove_component(self, name: str) -> Component:
"""
        Removes and returns a component from the pipeline.
        Removes an existing component from the pipeline by providing its name.
All edges that connect to the component will also be deleted.
:param name:
The name of the component to remove.
:returns:
The removed Component instance.
:raises ValueError:
If there is no component with that name already in the Pipeline.
"""
# Check that a component with that name is in the Pipeline
try:
instance = self.get_component(name)
except ValueError as exc:
            raise ValueError(
                f"There is no component named '{name}' in the pipeline. "
                f"The valid component names are: {', '.join(self.graph.nodes)}"
            ) from exc
# Delete component from the graph, deleting all its connections
self.graph.remove_node(name)
# Reset the Component sockets' senders and receivers
input_sockets = instance.__haystack_input__._sockets_dict # type: ignore[attr-defined]
for socket in input_sockets.values():
socket.senders = []
output_sockets = instance.__haystack_output__._sockets_dict # type: ignore[attr-defined]
for socket in output_sockets.values():
socket.receivers = []
# Reset the Component's pipeline reference
setattr(instance, "__haystack_added_to_pipeline__", None)
return instance
def connect(self, sender: str, receiver: str) -> "PipelineBase": # noqa: PLR0915 PLR0912
"""
Connects two components together.
All components to connect must exist in the pipeline.
        If connecting to a component that has several input or output connections, specify the socket names in the
        format 'component_name.connection_name'.
:param sender:
The component that delivers the value. This can be either just a component name or can be
in the format `component_name.connection_name` if the component has multiple outputs.
:param receiver:
The component that receives the value. This can be either just a component name or can be
in the format `component_name.connection_name` if the component has multiple inputs.
:returns:
The Pipeline instance.
:raises PipelineConnectError:
If the two components cannot be connected (for example if one of the components is
not present in the pipeline, or the connections don't match by type, and so on).
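        Usage example (component and socket names are illustrative):
            pipe.connect("retriever.documents", "prompt_builder.documents")
            # If only one connection is possible, the socket names can be omitted:
            pipe.connect("prompt_builder", "llm")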
"""
# Edges may be named explicitly by passing 'node_name.edge_name' to connect().
sender_component_name, sender_socket_name = parse_connect_string(sender)
receiver_component_name, receiver_socket_name = parse_connect_string(receiver)
if sender_component_name == receiver_component_name:
raise PipelineConnectError("Connecting a Component to itself is not supported.")
# Get the nodes data.
try:
sender_sockets = self.graph.nodes[sender_component_name]["output_sockets"]
except KeyError as exc:
raise ValueError(f"Component named {sender_component_name} not found in the pipeline.") from exc
try:
receiver_sockets = self.graph.nodes[receiver_component_name]["input_sockets"]
except KeyError as exc:
raise ValueError(f"Component named {receiver_component_name} not found in the pipeline.") from exc
# If the name of either socket is given, get the socket
sender_socket: Optional[OutputSocket] = None
if sender_socket_name:
sender_socket = sender_sockets.get(sender_socket_name)
if not sender_socket:
raise PipelineConnectError(
f"'{sender} does not exist. "
f"Output connections of {sender_component_name} are: "
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in sender_sockets.items()])
)
receiver_socket: Optional[InputSocket] = None
if receiver_socket_name:
receiver_socket = receiver_sockets.get(receiver_socket_name)
if not receiver_socket:
raise PipelineConnectError(
f"'{receiver} does not exist. "
f"Input connections of {receiver_component_name} are: "
+ ", ".join(
[f"{name} (type {_type_name(socket.type)})" for name, socket in receiver_sockets.items()]
)
)
# Look for a matching connection among the possible ones.
# Note that if there is more than one possible connection but two sockets match by name, they're paired.
sender_socket_candidates: List[OutputSocket] = (
[sender_socket] if sender_socket else list(sender_sockets.values())
)
receiver_socket_candidates: List[InputSocket] = (
[receiver_socket] if receiver_socket else list(receiver_sockets.values())
)
# Find all possible connections between these two components
possible_connections = []
for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates):
if _types_are_compatible(sender_sock.type, receiver_sock.type, self._connection_type_validation):
possible_connections.append((sender_sock, receiver_sock))
        # We need this status for error messages; since we might need it in multiple places, we calculate it here
status = _connections_status(
sender_node=sender_component_name,
sender_sockets=sender_socket_candidates,
receiver_node=receiver_component_name,
receiver_sockets=receiver_socket_candidates,
)
if not possible_connections:
# There's no possible connection between these two components
if len(sender_socket_candidates) == len(receiver_socket_candidates) == 1:
msg = (
f"Cannot connect '{sender_component_name}.{sender_socket_candidates[0].name}' with "
f"'{receiver_component_name}.{receiver_socket_candidates[0].name}': "
f"their declared input and output types do not match.\n{status}"
)
else:
msg = (
f"Cannot connect '{sender_component_name}' with '{receiver_component_name}': "
f"no matching connections available.\n{status}"
)
raise PipelineConnectError(msg)
if len(possible_connections) == 1:
# There's only one possible connection, use it
sender_socket = possible_connections[0][0]
receiver_socket = possible_connections[0][1]
        if len(possible_connections) > 1:
            # There are multiple possible connections, let's try to match them by name
name_matches = [
(out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name
]
if len(name_matches) != 1:
                # There are either no matches or more than one; we can't pick one reliably
msg = (
f"Cannot connect '{sender_component_name}' with "
f"'{receiver_component_name}': more than one connection is possible "
"between these components. Please specify the connection name, like: "
f"pipeline.connect('{sender_component_name}.{possible_connections[0][0].name}', "
f"'{receiver_component_name}.{possible_connections[0][1].name}').\n{status}"
)
raise PipelineConnectError(msg)
# Get the only possible match
sender_socket = name_matches[0][0]
receiver_socket = name_matches[0][1]
# Connection must be valid on both sender/receiver sides
if not sender_socket or not receiver_socket or not sender_component_name or not receiver_component_name:
if sender_component_name and sender_socket:
sender_repr = f"{sender_component_name}.{sender_socket.name} ({_type_name(sender_socket.type)})"
else:
sender_repr = "input needed"
if receiver_component_name and receiver_socket:
receiver_repr = f"({_type_name(receiver_socket.type)}) {receiver_component_name}.{receiver_socket.name}"
else:
receiver_repr = "output"
msg = f"Connection must have both sender and receiver: {sender_repr} -> {receiver_repr}"
raise PipelineConnectError(msg)
logger.debug(
"Connecting '{sender_component}.{sender_socket_name}' to '{receiver_component}.{receiver_socket_name}'",
sender_component=sender_component_name,
sender_socket_name=sender_socket.name,
receiver_component=receiver_component_name,
receiver_socket_name=receiver_socket.name,
)
if receiver_component_name in sender_socket.receivers and sender_component_name in receiver_socket.senders:
# This is already connected, nothing to do
return self
if receiver_socket.senders and not receiver_socket.is_variadic:
# Only variadic input sockets can receive from multiple senders
msg = (
f"Cannot connect '{sender_component_name}.{sender_socket.name}' with "
f"'{receiver_component_name}.{receiver_socket.name}': "
f"{receiver_component_name}.{receiver_socket.name} is already connected to {receiver_socket.senders}.\n"
)
raise PipelineConnectError(msg)
# Update the sockets with the new connection
sender_socket.receivers.append(receiver_component_name)
receiver_socket.senders.append(sender_component_name)
# Create the new connection
self.graph.add_edge(
sender_component_name,
receiver_component_name,
key=f"{sender_socket.name}/{receiver_socket.name}",
conn_type=_type_name(sender_socket.type),
from_socket=sender_socket,
to_socket=receiver_socket,
mandatory=receiver_socket.is_mandatory,
)
return self
def get_component(self, name: str) -> Component:
"""
Get the component with the specified name from the pipeline.
:param name:
The name of the component.
:returns:
The instance of that component.
:raises ValueError:
If a component with that name is not present in the pipeline.
"""
try:
return self.graph.nodes[name]["instance"]
except KeyError as exc:
raise ValueError(f"Component named {name} not found in the pipeline.") from exc
def get_component_name(self, instance: Component) -> str:
"""
Returns the name of the Component instance if it has been added to this Pipeline or an empty string otherwise.
:param instance:
The Component instance to look for.
:returns:
The name of the Component instance.
"""
for name, inst in self.graph.nodes(data="instance"): # type: ignore # type wrongly defined in networkx
if inst == instance:
return name
return ""
def inputs(self, include_components_with_connected_inputs: bool = False) -> Dict[str, Dict[str, Any]]:
"""
Returns a dictionary containing the inputs of a pipeline.
Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes
the input sockets of that component, including their types and whether they are optional.
:param include_components_with_connected_inputs:
If `False`, only components that have disconnected input edges are
included in the output.
:returns:
A dictionary where each key is a pipeline component name and each value is a dictionary of
            input sockets of that component.
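        The returned dictionary looks like this (component and socket names are illustrative):
            {"prompt_builder": {"question": {"type": str, "is_mandatory": True}}}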
"""
inputs: Dict[str, Dict[str, Any]] = {}
for component_name, data in find_pipeline_inputs(self.graph, include_components_with_connected_inputs).items():
sockets_description = {}
for socket in data:
sockets_description[socket.name] = {"type": socket.type, "is_mandatory": socket.is_mandatory}
if not socket.is_mandatory:
sockets_description[socket.name]["default_value"] = socket.default_value
if sockets_description:
inputs[component_name] = sockets_description
return inputs
def outputs(self, include_components_with_connected_outputs: bool = False) -> Dict[str, Dict[str, Any]]:
"""
Returns a dictionary containing the outputs of a pipeline.
Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes
the output sockets of that component.
:param include_components_with_connected_outputs:
If `False`, only components that have disconnected output edges are
included in the output.
:returns:
A dictionary where each key is a pipeline component name and each value is a dictionary of
output sockets of that component.
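        The returned dictionary looks like this (component and socket names are illustrative):
            {"llm": {"replies": {"type": typing.List[str]}}}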
"""
outputs = {
comp: {socket.name: {"type": socket.type} for socket in data}
for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items()
if data
}
return outputs
def show(
self,
*,
server_url: str = "https://mermaid.ink",
params: Optional[dict] = None,
timeout: int = 30,
super_component_expansion: bool = False,
) -> None:
"""
Display an image representing this `Pipeline` in a Jupyter notebook.
This function generates a diagram of the `Pipeline` using a Mermaid server and displays it directly in
the notebook.
:param server_url:
The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
info on how to set up your own Mermaid server.
:param params:
            Dictionary of customization parameters to modify the output. Refer to the Mermaid documentation for
            more details. Supported keys:
- format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
- type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
- theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
- bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
- width: Width of the output image (integer).
- height: Height of the output image (integer).
            - scale: Scaling factor (1-3). Only applicable if 'width' or 'height' is specified.
- fit: Whether to fit the diagram size to the page (PDF only, boolean).
- paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
- landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.
:param timeout:
Timeout in seconds for the request to the Mermaid server.
:param super_component_expansion:
            If set to True and the pipeline contains SuperComponents, the diagram will show the internal structure
            of SuperComponents as if their components were part of the pipeline, instead of showing them as a
            "black box". Otherwise, only the SuperComponent itself is displayed.
:raises PipelineDrawingError:
If the function is called outside of a Jupyter notebook or if there is an issue with rendering.
"""
if is_in_jupyter():
from IPython.display import Image, display # type: ignore
if super_component_expansion:
graph, super_component_mapping = self._merge_super_component_pipelines()
else:
graph = self.graph
super_component_mapping = None
image_data = _to_mermaid_image(
graph,
server_url=server_url,
params=params,
timeout=timeout,
super_component_mapping=super_component_mapping,
)
display(Image(image_data))
else:
msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
raise PipelineDrawingError(msg)
def draw(
self,
*,
path: Path,
server_url: str = "https://mermaid.ink",
params: Optional[dict] = None,
timeout: int = 30,
super_component_expansion: bool = False,
) -> None:
"""
Save an image representing this `Pipeline` to the specified file path.
This function generates a diagram of the `Pipeline` using the Mermaid server and saves it to the provided path.
:param path:
The file path where the generated image will be saved.
:param server_url:
The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
info on how to set up your own Mermaid server.
:param params:
            Dictionary of customization parameters to modify the output. Refer to the Mermaid documentation for
            more details. Supported keys:
- format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
- type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
- theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
- bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
- width: Width of the output image (integer).
- height: Height of the output image (integer).
            - scale: Scaling factor (1-3). Only applicable if 'width' or 'height' is specified.
- fit: Whether to fit the diagram size to the page (PDF only, boolean).
- paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
- landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.
:param timeout:
Timeout in seconds for the request to the Mermaid server.
:param super_component_expansion:
            If set to True and the pipeline contains SuperComponents, the diagram will show the internal structure
            of SuperComponents as if their components were part of the pipeline, instead of showing them as a
            "black box". Otherwise, only the SuperComponent itself is displayed.
:raises PipelineDrawingError:
If there is an issue with rendering or saving the image.
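        Usage example (path and parameter values are illustrative):
            pipe.draw(path=Path("pipeline.png"), params={"format": "img", "type": "png"})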
"""
        # Before drawing, the graph may need some editing (e.g. expanding SuperComponents). To avoid modifying
        # the original graph that is used for running the pipeline, any such edits are made on a copy.
if super_component_expansion:
graph, super_component_mapping = self._merge_super_component_pipelines()
else:
graph = self.graph
super_component_mapping = None
image_data = _to_mermaid_image(
graph,
server_url=server_url,
params=params,
timeout=timeout,
super_component_mapping=super_component_mapping,
)
Path(path).write_bytes(image_data)
def walk(self) -> Iterator[Tuple[str, Component]]:
"""
Visits each component in the pipeline exactly once and yields its name and instance.
No guarantees are provided on the visiting order.
:returns:
An iterator of tuples of component name and component instance.
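        Usage example:
            for name, instance in pipe.walk():
                print(name, instance.__class__.__name__)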
"""
for component_name, instance in self.graph.nodes(data="instance"): # type: ignore # type is wrong in networkx
yield component_name, instance
def warm_up(self) -> None:
"""
Make sure all nodes are warm.
It's the node's responsibility to make sure this method can be called at every `Pipeline.run()`
without re-initializing everything.
"""
for node in self.graph.nodes:
if hasattr(self.graph.nodes[node]["instance"], "warm_up"):
logger.info("Warming up component {node}...", node=node)
self.graph.nodes[node]["instance"].warm_up()
@staticmethod
def _create_component_span(
component_name: str, instance: Component, inputs: Dict[str, Any], parent_span: Optional[tracing.Span] = None
) -> ContextManager[tracing.Span]:
return tracing.tracer.trace(
"haystack.component.run",
tags={
"haystack.component.name": component_name,
"haystack.component.type": instance.__class__.__name__,
"haystack.component.input_types": {k: type(v).__name__ for k, v in inputs.items()},
"haystack.component.input_spec": {
key: {
"type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)),
"senders": value.senders,
}
for key, value in instance.__haystack_input__._sockets_dict.items() # type: ignore
},
"haystack.component.output_spec": {
key: {
"type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)),
"receivers": value.receivers,
}
for key, value in instance.__haystack_output__._sockets_dict.items() # type: ignore
},
},
parent_span=parent_span,
)
def validate_input(self, data: Dict[str, Any]) -> None:
"""
Validates pipeline input data.
        Validates that:
* Each Component name actually exists in the Pipeline
* Each Component is not missing any input
* Each Component has only one input per input socket, if not variadic
* Each Component doesn't receive inputs that are already sent by another Component
:param data:
A dictionary of inputs for the pipeline's components. Each key is a component name.
:raises ValueError:
If inputs are invalid according to the above.
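        Expected shape of `data` (component and socket names are illustrative):
            {"prompt_builder": {"question": "Who lives in Paris?"}, "retriever": {"query": "Who lives in Paris?"}}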
"""
for component_name, component_inputs in data.items():
if component_name not in self.graph.nodes:
raise ValueError(f"Component named {component_name} not found in the pipeline.")
instance = self.graph.nodes[component_name]["instance"]
for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
raise ValueError(f"Missing input for component {component_name}: {socket_name}")
for input_name in component_inputs.keys():
if input_name not in instance.__haystack_input__._sockets_dict:
raise ValueError(f"Input {input_name} not found in component {component_name}.")
for component_name in self.graph.nodes:
instance = self.graph.nodes[component_name]["instance"]
for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
component_inputs = data.get(component_name, {})
if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
raise ValueError(f"Missing input for component {component_name}: {socket_name}")
if socket.senders and socket_name in component_inputs and not socket.is_variadic:
raise ValueError(
f"Input {socket_name} for component {component_name} is already sent by {socket.senders}."
)
def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
"""
Prepares input data for pipeline components.
Organizes input data for pipeline components and identifies any inputs that are not matched to any
component's input slots. Deep-copies data items to avoid sharing mutables across multiple components.
This method processes a flat dictionary of input data, where each key-value pair represents an input name
and its corresponding value. It distributes these inputs to the appropriate pipeline components based on
their input requirements. Inputs that don't match any component's input slots are classified as unresolved.
:param data:
A dictionary potentially having input names as keys and input values as values.
:returns:
A dictionary mapping component names to their respective matched inputs.
"""
# check whether the data is a nested dictionary of component inputs where each key is a component name
# and each value is a dictionary of input parameters for that component
is_nested_component_input = all(isinstance(value, dict) for value in data.values())
if not is_nested_component_input:
# flat input, a dict where keys are input names and values are the corresponding values
# we need to convert it to a nested dictionary of component inputs and then run the pipeline
# just like in the previous case
pipeline_input_data: Dict[str, Dict[str, Any]] = defaultdict(dict)
unresolved_kwargs = {}
# Retrieve the input slots for each component in the pipeline
available_inputs: Dict[str, Dict[str, Any]] = self.inputs()
            # Go through all provided inputs to distribute them to the appropriate component inputs
for input_name, input_value in data.items():
resolved_at_least_once = False
# Check each component to see if it has a slot for the current kwarg
for component_name, component_inputs in available_inputs.items():
if input_name in component_inputs:
# If a match is found, add the kwarg to the component's input data
pipeline_input_data[component_name][input_name] = input_value
resolved_at_least_once = True
if not resolved_at_least_once:
unresolved_kwargs[input_name] = input_value
if unresolved_kwargs:
logger.warning(
"Inputs {input_keys} were not matched to any component inputs, please check your run parameters.",
input_keys=list(unresolved_kwargs.keys()),
)
data = dict(pipeline_input_data)
# deepcopying the inputs prevents the Pipeline run logic from being altered unexpectedly
# when the same input reference is passed to multiple components.
for component_name, component_inputs in data.items():
data[component_name] = {k: _deepcopy_with_exceptions(v) for k, v in component_inputs.items()}
return data
@classmethod
def from_template(
cls, predefined_pipeline: PredefinedPipeline, template_params: Optional[Dict[str, Any]] = None
) -> "PipelineBase":
"""
Create a Pipeline from a predefined template. See `PredefinedPipeline` for available options.
:param predefined_pipeline:
The predefined pipeline to use.
:param template_params:
An optional dictionary of parameters to use when rendering the pipeline template.
:returns:
An instance of `Pipeline`.
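        Usage example (the template member is illustrative):
            pipe = Pipeline.from_template(PredefinedPipeline.INDEXING)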
"""
tpl = PipelineTemplate.from_predefined(predefined_pipeline)
        # If tpl.render() fails, we let the original error bubble up
rendered = tpl.render(template_params)
# If there was a problem with the rendered version of the
# template, we add it to the error stack for debugging
try:
return cls.loads(rendered)
except Exception as e:
msg = f"Error unmarshalling pipeline: {e}\n"
msg += f"Source:\n{rendered}"
raise PipelineUnmarshalError(msg)
def _find_receivers_from(self, component_name: str) -> List[Tuple[str, OutputSocket, InputSocket]]:
"""
Utility function to find all Components that receive input from `component_name`.
:param component_name:
Name of the sender Component
:returns:
            List of tuples, each containing the name of the receiver Component, the sender OutputSocket instance,
            and the receiver InputSocket instance
"""
res = []
for _, receiver_name, connection in self.graph.edges(nbunch=component_name, data=True):
sender_socket: OutputSocket = connection["from_socket"]
receiver_socket: InputSocket = connection["to_socket"]
res.append((receiver_name, sender_socket, receiver_socket))
return res
@staticmethod
def _convert_to_internal_format(pipeline_inputs: Dict[str, Any]) -> Dict[str, Dict[str, List]]:
"""
Converts the inputs to the pipeline to the format that is needed for the internal `Pipeline.run` logic.
Example Input:
{'prompt_builder': {'question': 'Who lives in Paris?'}, 'retriever': {'query': 'Who lives in Paris?'}}
Example Output:
{'prompt_builder': {'question': [{'sender': None, 'value': 'Who lives in Paris?'}]},
'retriever': {'query': [{'sender': None, 'value': 'Who lives in Paris?'}]}}
:param pipeline_inputs: Inputs to the pipeline.
:returns: Converted inputs that can be used by the internal `Pipeline.run` logic.
"""
inputs: Dict[str, Dict[str, List[Dict[str, Any]]]] = {}
for component_name, socket_dict in pipeline_inputs.items():
inputs[component_name] = {}
for socket_name, value in socket_dict.items():
inputs[component_name][socket_name] = [{"sender": None, "value": value}]
return inputs
@staticmethod
def _consume_component_inputs(
component_name: str, component: Dict, inputs: Dict, is_resume: bool = False
) -> Dict[str, Any]:
"""
        Extracts the inputs needed for the component to run and removes them from the global inputs state.
:param component_name: The name of a component.
:param component: Component with component metadata.
        :param inputs: Global inputs state.
        :param is_resume: Whether the component is being resumed; if so, its inputs were already consumed and the
            first input provided to each socket is returned as-is.
        :returns: The inputs for the component.
"""
component_inputs = inputs.get(component_name, {})
consumed_inputs = {}
greedy_inputs_to_remove = set()
for socket_name, socket in component["input_sockets"].items():
socket_inputs = component_inputs.get(socket_name, [])
socket_inputs = [sock["value"] for sock in socket_inputs if sock["value"] is not _NO_OUTPUT_PRODUCED]
# if we are resuming a component, the inputs are already consumed, so we just return the first input
if is_resume:
consumed_inputs[socket_name] = socket_inputs[0]
continue
if socket_inputs:
if not socket.is_variadic:
# We only care about the first input provided to the socket.
consumed_inputs[socket_name] = socket_inputs[0]
elif socket.is_greedy:
# We need to keep track of greedy inputs because we always remove them, even if they come from
# outside the pipeline. Otherwise, a greedy input from the user would trigger a pipeline to run
# indefinitely.
greedy_inputs_to_remove.add(socket_name)
consumed_inputs[socket_name] = [socket_inputs[0]]
elif is_socket_lazy_variadic(socket):
# We use all inputs provided to the socket on a lazy variadic socket.
consumed_inputs[socket_name] = socket_inputs
# We prune all inputs except for those that were provided from outside the pipeline (e.g. user inputs).
pruned_inputs = {
socket_name: [
sock for sock in socket if sock["sender"] is None and not socket_name in greedy_inputs_to_remove
]
for socket_name, socket in component_inputs.items()
}
pruned_inputs = {socket_name: socket for socket_name, socket in pruned_inputs.items() if len(socket) > 0}
inputs[component_name] = pruned_inputs
return consumed_inputs
def _fill_queue(
self, component_names: List[str], inputs: Dict[str, Any], component_visits: Dict[str, int]
) -> FIFOPriorityQueue:
"""
Calculates the execution priority for each component and inserts it into the priority queue.
:param component_names: Names of the components to put into the queue.
:param inputs: Inputs to the components.
:param component_visits: Current state of component visits.
:returns: A prioritized queue of component names.
"""
priority_queue = FIFOPriorityQueue()
for component_name in component_names:
component = self._get_component_with_graph_metadata_and_visits(
component_name, component_visits[component_name]
)
priority = self._calculate_priority(component, inputs.get(component_name, {}))
priority_queue.push(component_name, priority)
return priority_queue
@staticmethod
def _calculate_priority(component: Dict, inputs: Dict) -> ComponentPriority:
"""
Calculates the execution priority for a component depending on the component's inputs.
:param component: Component metadata and component instance.
:param inputs: Inputs to the component.
:returns: Priority value for the component.
"""
if not can_component_run(component, inputs):
return ComponentPriority.BLOCKED
elif is_any_greedy_socket_ready(component, inputs) and are_all_sockets_ready(component, inputs):
return ComponentPriority.HIGHEST
elif all_predecessors_executed(component, inputs):
return ComponentPriority.READY
elif are_all_lazy_variadic_sockets_resolved(component, inputs):
return ComponentPriority.DEFER
else:
return ComponentPriority.DEFER_LAST
def _get_component_with_graph_metadata_and_visits(self, component_name: str, visits: int) -> Dict[str, Any]:
"""
Returns the component instance alongside input/output-socket metadata from the graph and adds current visits.
        We can't store visits in the pipeline graph because that would prevent reentrant / thread-safe execution.
:param component_name: The name of the component.
:param visits: Number of visits for the component.
:returns: Dict including component instance, input/output-sockets and visits.
"""
comp_dict = self.graph.nodes[component_name]
comp_dict = {**comp_dict, "visits": visits}
return comp_dict
def _get_next_runnable_component(
self, priority_queue: FIFOPriorityQueue, component_visits: Dict[str, int]
) -> Union[Tuple[ComponentPriority, str, Dict[str, Any]], None]:
"""
Returns the next runnable component alongside its metadata from the priority queue.
:param priority_queue: Priority queue of component names.
:param component_visits: Current state of component visits.
        :returns: The priority, the component name, and the component metadata,
            or None if no component in the queue can run.
:raises: PipelineMaxComponentRuns if the next runnable component has exceeded the maximum number of runs.
"""
priority_and_component_name: Union[Tuple[ComponentPriority, str], None] = (
None if (item := priority_queue.get()) is None else (ComponentPriority(item[0]), str(item[1]))
)
if priority_and_component_name is None:
return None
priority, component_name = priority_and_component_name
comp = self._get_component_with_graph_metadata_and_visits(component_name, component_visits[component_name])
if comp["visits"] > self._max_runs_per_component:
msg = f"Maximum run count {self._max_runs_per_component} reached for component '{component_name}'"
raise PipelineMaxComponentRuns(msg)
return priority, component_name, comp
@staticmethod
def _add_missing_input_defaults(
component_inputs: Dict[str, Any], component_input_sockets: Dict[str, InputSocket]
) -> Dict[str, Any]:
"""
        Updates the inputs with the default values for the inputs that are missing.
        :param component_inputs: Inputs for the component.
        :param component_input_sockets: Input sockets of the component.
        :returns: The updated inputs.
"""
for name, socket in component_input_sockets.items():
if not socket.is_mandatory and name not in component_inputs:
if socket.is_variadic:
component_inputs[name] = [socket.default_value]
else:
component_inputs[name] = socket.default_value
return component_inputs
def _tiebreak_waiting_components(
self,
component_name: str,
priority: ComponentPriority,
priority_queue: FIFOPriorityQueue,
topological_sort: Union[Dict[str, int], None],
) -> Tuple[str, Union[Dict[str, int], None]]:
"""
Decides which component to run when multiple components are waiting for inputs with the same priority.
:param component_name: The name of the component.
:param priority: Priority of the component.
:param priority_queue: Priority queue of component names.
        :param topological_sort: Cached topological sort of all components in the pipeline.
        :returns: The name of the component that should run next and the (possibly newly computed) topological sort.
"""
components_with_same_priority = [component_name]
while len(priority_queue) > 0:
next_priority, next_component_name = priority_queue.peek()
if next_priority == priority:
priority_queue.pop() # actually remove the component
components_with_same_priority.append(next_component_name)
else:
break
if len(components_with_same_priority) > 1:
if topological_sort is None:
if networkx.is_directed_acyclic_graph(self.graph):
topological_sort = networkx.lexicographical_topological_sort(self.graph)
topological_sort = {node: idx for idx, node in enumerate(topological_sort)}
else:
condensed = networkx.condensation(self.graph)
condensed_sorted = {node: idx for idx, node in enumerate(networkx.topological_sort(condensed))}
topological_sort = {
component_name: condensed_sorted[node]
for component_name, node in condensed.graph["mapping"].items()
}
components_with_same_priority = sorted(
components_with_same_priority, key=lambda comp_name: (topological_sort[comp_name], comp_name.lower())
)
component_name = components_with_same_priority[0]
return component_name, topological_sort
@staticmethod
def _write_component_outputs(
component_name: str,
component_outputs: Mapping[str, Any],
inputs: Dict[str, Any],
receivers: List[Tuple],
include_outputs_from: Set[str],
) -> Mapping[str, Any]:
"""
Distributes the outputs of a component to the input sockets that it is connected to.
:param component_name: The name of the component.
:param component_outputs: The outputs of the component.
:param inputs: The current global input state.
:param receivers: List of components that receive inputs from the component.
        :param include_outputs_from: Set of component names whose outputs should always be included in the
            pipeline's final outputs.
        :returns: The component outputs that should be included in the pipeline's final outputs.
"""
for receiver_name, sender_socket, receiver_socket in receivers:
# We either get the value that was produced by the actor or we use the _NO_OUTPUT_PRODUCED class to indicate
# that the sender did not produce an output for this socket.
# This allows us to track if a predecessor already ran but did not produce an output.
value = component_outputs.get(sender_socket.name, _NO_OUTPUT_PRODUCED)
if receiver_name not in inputs:
inputs[receiver_name] = {}
if is_socket_lazy_variadic(receiver_socket):
# If the receiver socket is lazy variadic, we append the new input.
# Lazy variadic sockets can collect multiple inputs.
_write_to_lazy_variadic_socket(
inputs=inputs,
receiver_name=receiver_name,
receiver_socket_name=receiver_socket.name,
component_name=component_name,
value=value,
)
else:
# If the receiver socket is not lazy variadic, it is greedy variadic or non-variadic.
# We overwrite with the new input if it's not _NO_OUTPUT_PRODUCED or if the current value is None.
_write_to_standard_socket(
inputs=inputs,
receiver_name=receiver_name,
receiver_socket_name=receiver_socket.name,
component_name=component_name,
value=value,
)
# If we want to include all outputs from this actor in the final outputs, we don't need to prune any consumed
# outputs
if component_name in include_outputs_from:
return component_outputs
# We prune outputs that were consumed by any receiving sockets.
# All remaining outputs will be added to the final outputs of the pipeline.
consumed_outputs = {sender_socket.name for _, sender_socket, __ in receivers}
pruned_outputs = {key: value for key, value in component_outputs.items() if key not in consumed_outputs}
return pruned_outputs
@staticmethod
def _is_queue_stale(priority_queue: FIFOPriorityQueue) -> bool:
"""
Checks if the priority queue needs to be recomputed because the priorities might have changed.
:param priority_queue: Priority queue of component names.
"""
return len(priority_queue) == 0 or priority_queue.peek()[0] > ComponentPriority.READY
@staticmethod
def validate_pipeline(priority_queue: FIFOPriorityQueue) -> None:
"""
Validate the pipeline to check if it is blocked or has no valid entry point.
:param priority_queue: Priority queue of component names.
        :raises PipelineComponentsBlockedError:
If the pipeline is blocked or has no valid entry point.
"""
if len(priority_queue) == 0:
return
candidate = priority_queue.peek()
if candidate is not None and candidate[0] == ComponentPriority.BLOCKED:
raise PipelineComponentsBlockedError()
def _find_super_components(self) -> list[tuple[str, Component]]:
"""
Find all SuperComponents in the pipeline.
:returns:
List of tuples containing (component_name, component_instance) representing a SuperComponent.
"""
super_components = []
for comp_name, comp in self.walk():
            # a SuperComponent has a "pipeline" attribute which is itself a Pipeline instance;
            # we don't test against SuperComponent because doing so always leads to circular imports
if hasattr(comp, "pipeline") and isinstance(comp.pipeline, self.__class__):
super_components.append((comp_name, comp))
return super_components
def _merge_super_component_pipelines(self) -> Tuple["networkx.MultiDiGraph", Dict[str, str]]:
"""
Merge the internal pipelines of SuperComponents into the main pipeline graph structure.
This creates a new networkx.MultiDiGraph containing all the components from both the main pipeline
and all the internal SuperComponents' pipelines. The SuperComponents are removed and their internal
components are connected to corresponding input and output sockets of the main pipeline.
        :returns:
            A tuple containing:
            - A networkx.MultiDiGraph with the expanded structure of the main pipeline and all its SuperComponents
            - A dictionary mapping each internal component name to the name of the SuperComponent it came from
"""
merged_graph = self.graph.copy()
super_component_mapping: Dict[str, str] = {}
for super_name, super_component in self._find_super_components():
internal_pipeline = super_component.pipeline # type: ignore
internal_graph = internal_pipeline.graph.copy()
# Mark all components in the internal pipeline as being part of a SuperComponent
for node in internal_graph.nodes():
super_component_mapping[node] = super_name
# edges connected to the super component
incoming_edges = list(merged_graph.in_edges(super_name, data=True))
outgoing_edges = list(merged_graph.out_edges(super_name, data=True))
# merge the SuperComponent graph into the main graph and remove the super component node
# since its components are now part of the main graph
merged_graph = networkx.compose(merged_graph, internal_graph)
merged_graph.remove_node(super_name)
# get the entry and exit points of the SuperComponent internal pipeline
entry_points = [n for n in internal_graph.nodes() if internal_graph.in_degree(n) == 0]
exit_points = [n for n in internal_graph.nodes() if internal_graph.out_degree(n) == 0]
# connect the incoming edges to entry points
for sender, _, edge_data in incoming_edges:
sender_socket = edge_data["from_socket"]
for entry_point in entry_points:
# find a matching input socket in the entry point
entry_point_sockets = internal_graph.nodes[entry_point]["input_sockets"]
for socket_name, socket in entry_point_sockets.items():
if _types_are_compatible(sender_socket.type, socket.type, self._connection_type_validation):
merged_graph.add_edge(
sender,
entry_point,
key=f"{sender_socket.name}/{socket_name}",
conn_type=_type_name(sender_socket.type),
from_socket=sender_socket,
to_socket=socket,
mandatory=socket.is_mandatory,
)
# connect outgoing edges from exit points
for _, receiver, edge_data in outgoing_edges:
receiver_socket = edge_data["to_socket"]
for exit_point in exit_points:
# find a matching output socket in the exit point
exit_point_sockets = internal_graph.nodes[exit_point]["output_sockets"]
for socket_name, socket in exit_point_sockets.items():
if _types_are_compatible(socket.type, receiver_socket.type, self._connection_type_validation):
merged_graph.add_edge(
exit_point,
receiver,
key=f"{socket_name}/{receiver_socket.name}",
conn_type=_type_name(socket.type),
from_socket=socket,
to_socket=receiver_socket,
mandatory=receiver_socket.is_mandatory,
)
return merged_graph, super_component_mapping
def _is_pipeline_possibly_blocked(self, current_pipeline_outputs: Dict[str, Any]) -> bool:
"""
Heuristically determines whether the pipeline is possibly blocked based on its current outputs.
This method checks if the pipeline has produced any of the expected outputs.
        - If no outputs are expected (i.e., `self.outputs()` returns an empty dictionary), the method assumes the
          pipeline is not blocked.
- If at least one expected output is present in `current_pipeline_outputs`, the pipeline is also assumed to not
be blocked.
- If none of the expected outputs are present, the pipeline is considered to be possibly blocked.
Note: This check is not definitive—it is intended as a best-effort guess to detect a stalled or misconfigured
pipeline when there are no more runnable components.
:param current_pipeline_outputs: A dictionary of outputs currently produced by the pipeline.
:returns:
bool: True if the pipeline is possibly blocked (i.e., expected outputs are missing), False otherwise.
"""
expected_outputs = self.outputs()
return bool(expected_outputs) and not any(k in current_pipeline_outputs for k in expected_outputs)
def _connections_status(
sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
) -> str:
"""
Lists the status of the sockets, for error messages.
"""
sender_sockets_entries = []
for sender_socket in sender_sockets:
sender_sockets_entries.append(f" - {sender_socket.name}: {_type_name(sender_socket.type)}")
sender_sockets_list = "\n".join(sender_sockets_entries)
receiver_sockets_entries = []
for receiver_socket in receiver_sockets:
if receiver_socket.senders:
sender_status = f"sent by {','.join(receiver_socket.senders)}"
else:
sender_status = "available"
receiver_sockets_entries.append(
f" - {receiver_socket.name}: {_type_name(receiver_socket.type)} ({sender_status})"
)
receiver_sockets_list = "\n".join(receiver_sockets_entries)
return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}"
# Utility functions for writing to sockets
def _write_to_lazy_variadic_socket(
inputs: Dict[str, Any], receiver_name: str, receiver_socket_name: str, component_name: str, value: Any
) -> None:
"""
Write to a lazy variadic socket.
Mutates inputs in place.
"""
if not inputs[receiver_name].get(receiver_socket_name):
inputs[receiver_name][receiver_socket_name] = []
inputs[receiver_name][receiver_socket_name].append({"sender": component_name, "value": value})
def _write_to_standard_socket(
inputs: Dict[str, Any], receiver_name: str, receiver_socket_name: str, component_name: str, value: Any
) -> None:
"""
Write to a greedy variadic or non-variadic socket.
Mutates inputs in place.
"""
current_value = inputs[receiver_name].get(receiver_socket_name)
# Only overwrite if there's no existing value, or we have a new value to provide
if current_value is None or value is not _NO_OUTPUT_PRODUCED:
inputs[receiver_name][receiver_socket_name] = [{"sender": component_name, "value": value}]