# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import itertools
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
from enum import IntEnum
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, TextIO, Tuple, Type, TypeVar, Union

import networkx  # type:ignore

from haystack import logging
from haystack.core.component import Component, InputSocket, OutputSocket, component
from haystack.core.errors import (
    DeserializationError,
    PipelineConnectError,
    PipelineDrawingError,
    PipelineError,
    PipelineMaxComponentRuns,
    PipelineRuntimeError,
    PipelineUnmarshalError,
    PipelineValidationError,
)
from haystack.core.pipeline.component_checks import (
    _NO_OUTPUT_PRODUCED,
    all_predecessors_executed,
    are_all_lazy_variadic_sockets_resolved,
    are_all_sockets_ready,
    can_component_run,
    is_any_greedy_socket_ready,
    is_socket_lazy_variadic,
)
from haystack.core.pipeline.utils import FIFOPriorityQueue, parse_connect_string
from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict
from haystack.core.type_utils import _type_name, _types_are_compatible
from haystack.marshal import Marshaller, YamlMarshaller
from haystack.utils import is_in_jupyter, type_serialization

from .descriptions import find_pipeline_inputs, find_pipeline_outputs
from .draw import _to_mermaid_image
from .template import PipelineTemplate, PredefinedPipeline

DEFAULT_MARSHALLER = YamlMarshaller()

# We use a generic type to annotate the return value of class methods,
# so that static analyzers won't be confused when derived classes
# use those methods.
T = TypeVar("T", bound="PipelineBase")

logger = logging.getLogger(__name__)


class ComponentPriority(IntEnum):
    HIGHEST = 1
    READY = 2
    DEFER = 3
    DEFER_LAST = 4
    BLOCKED = 5


class PipelineBase:
    """
    Components orchestration engine.

    Builds a graph of components and orchestrates their execution according to the execution graph.
    """

    def __init__(self, metadata: Optional[Dict[str, Any]] = None, max_runs_per_component: int = 100):
        """
        Creates the Pipeline.

        :param metadata:
            Arbitrary dictionary to store metadata about this `Pipeline`. Make sure all the values contained in
            this dictionary can be serialized and deserialized if you wish to save this `Pipeline` to file.
        :param max_runs_per_component:
            How many times the `Pipeline` can run the same Component.
            If this limit is reached, a `PipelineMaxComponentRuns` exception is raised.
            If not set, defaults to 100 runs per Component.
        """
        self._telemetry_runs = 0
        self._last_telemetry_sent: Optional[datetime] = None
        self.metadata = metadata or {}
        self.graph = networkx.MultiDiGraph()
        self._max_runs_per_component = max_runs_per_component

    def __eq__(self, other) -> bool:
        """
        Pipeline equality is defined by the type and the equality of the serialized form.

        Pipelines of the same type share every metadata, node and edge, but they're not required to use
        the same node instances: this allows pipelines saved and then loaded back to be equal to themselves.
        """
        if not isinstance(self, type(other)):
            return False
        return self.to_dict() == other.to_dict()

    def __repr__(self) -> str:
        """
        Returns a text representation of the Pipeline.
        """
        res = f"{object.__repr__(self)}\n"
        if self.metadata:
            res += "🧱 Metadata\n"
            for k, v in self.metadata.items():
                res += f" - {k}: {v}\n"

        res += "🚅 Components\n"
        for name, instance in self.graph.nodes(data="instance"):  # type: ignore # type wrongly defined in networkx
            res += f" - {name}: {instance.__class__.__name__}\n"

        res += "🛤️ Connections\n"
        for sender, receiver, edge_data in self.graph.edges(data=True):
            sender_socket = edge_data["from_socket"].name
            receiver_socket = edge_data["to_socket"].name
            res += f" - {sender}.{sender_socket} -> {receiver}.{receiver_socket} ({edge_data['conn_type']})\n"

        return res

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the pipeline to a dictionary.

        This is meant to be an intermediate representation but it can also be used to save a pipeline to file.

        :returns:
            Dictionary with serialized data.
        """
        components = {}
        for name, instance in self.graph.nodes(data="instance"):  # type:ignore
            components[name] = component_to_dict(instance, name)

        connections = []
        for sender, receiver, edge_data in self.graph.edges.data():
            sender_socket = edge_data["from_socket"].name
            receiver_socket = edge_data["to_socket"].name
            connections.append({"sender": f"{sender}.{sender_socket}", "receiver": f"{receiver}.{receiver_socket}"})
        return {
            "metadata": self.metadata,
            "max_runs_per_component": self._max_runs_per_component,
            "components": components,
            "connections": connections,
        }

    @classmethod
    def from_dict(
        cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs
    ) -> T:
        """
        Deserializes the pipeline from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :param callbacks:
            Callbacks to invoke during deserialization.
        :param kwargs:
            `components`: a dictionary of {name: instance} to reuse instances of components instead of creating new
            ones.
        :returns:
            Deserialized pipeline.
        """
        data_copy = deepcopy(data)  # to prevent modification of original data
        metadata = data_copy.get("metadata", {})
        max_runs_per_component = data_copy.get("max_runs_per_component", 100)
        pipe = cls(metadata=metadata, max_runs_per_component=max_runs_per_component)
        components_to_reuse = kwargs.get("components", {})
        for name, component_data in data_copy.get("components", {}).items():
            if name in components_to_reuse:
                # Reuse an instance
                instance = components_to_reuse[name]
            else:
                if "type" not in component_data:
                    raise PipelineError(f"Missing 'type' in component '{name}'")

                if component_data["type"] not in component.registry:
                    try:
                        # Import the module first...
                        module, _ = component_data["type"].rsplit(".", 1)
                        logger.debug("Trying to import module {module_name}", module_name=module)
                        type_serialization.thread_safe_import(module)
                        # ...then try again
                        if component_data["type"] not in component.registry:
                            raise PipelineError(
                                f"Successfully imported module {module} but can't find it in the component registry. "
                                "This is unexpected and most likely a bug."
                            )
                    except (ImportError, PipelineError, ValueError) as e:
                        raise PipelineError(
                            f"Component '{component_data['type']}' (name: '{name}') not imported."
                        ) from e

                # Create a new one
                component_class = component.registry[component_data["type"]]

                try:
                    instance = component_from_dict(component_class, component_data, name, callbacks)
                except Exception as e:
                    msg = (
                        f"Couldn't deserialize component '{name}' of class '{component_class.__name__}' "
                        f"with the following data: {str(component_data)}. Possible reasons include "
                        "malformed serialized data, mismatch between the serialized component and the "
                        "loaded one (due to a breaking change, see "
                        "https://github.com/deepset-ai/haystack/releases), etc."
                    )
                    raise DeserializationError(msg) from e
            pipe.add_component(name=name, instance=instance)

        for connection in data.get("connections", []):
            if "sender" not in connection:
                raise PipelineError(f"Missing sender in connection: {connection}")
            if "receiver" not in connection:
                raise PipelineError(f"Missing receiver in connection: {connection}")
            pipe.connect(sender=connection["sender"], receiver=connection["receiver"])

        return pipe
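
    # A minimal round-trip sketch (illustrative only, not part of the original module).
    # It assumes a registered component class named `SomeComponent` exists; under that
    # assumption, `from_dict(pipe.to_dict())` is expected to reconstruct an equal pipeline:
    #
    #     pipe = PipelineBase(metadata={"owner": "docs"})
    #     pipe.add_component("comp", SomeComponent())
    #     restored = PipelineBase.from_dict(pipe.to_dict())
    #     assert restored == pipe  # equality compares the serialized forms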

    def dumps(self, marshaller: Marshaller = DEFAULT_MARSHALLER) -> str:
        """
        Returns the string representation of this pipeline according to the format dictated by the `Marshaller` in use.

        :param marshaller:
            The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
        :returns:
            A string representing the pipeline.
        """
        return marshaller.marshal(self.to_dict())

    def dump(self, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER):
        """
        Writes the string representation of this pipeline to the file-like object passed in the `fp` argument.

        :param fp:
            A file-like object ready to be written to.
        :param marshaller:
            The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
        """
        fp.write(marshaller.marshal(self.to_dict()))

    @classmethod
    def loads(
        cls: Type[T],
        data: Union[str, bytes, bytearray],
        marshaller: Marshaller = DEFAULT_MARSHALLER,
        callbacks: Optional[DeserializationCallbacks] = None,
    ) -> T:
        """
        Creates a `Pipeline` object from the string representation passed in the `data` argument.

        :param data:
            The string representation of the pipeline, can be `str`, `bytes` or `bytearray`.
        :param marshaller:
            The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
        :param callbacks:
            Callbacks to invoke during deserialization.
        :raises DeserializationError:
            If an error occurs during deserialization.
        :returns:
            A `Pipeline` object.
        """
        try:
            deserialized_data = marshaller.unmarshal(data)
        except Exception as e:
            raise DeserializationError(
                "Error while unmarshalling serialized pipeline data. This is usually "
                "caused by malformed or invalid syntax in the serialized representation."
            ) from e

        return cls.from_dict(deserialized_data, callbacks)

    @classmethod
    def load(
        cls: Type[T],
        fp: TextIO,
        marshaller: Marshaller = DEFAULT_MARSHALLER,
        callbacks: Optional[DeserializationCallbacks] = None,
    ) -> T:
        """
        Creates a `Pipeline` object from a string representation.

        The string representation is read from the file-like object passed in the `fp` argument.

        :param fp:
            A file-like object ready to be read from.
        :param marshaller:
            The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
        :param callbacks:
            Callbacks to invoke during deserialization.
        :raises DeserializationError:
            If an error occurs during deserialization.
        :returns:
            A `Pipeline` object.
        """
        return cls.loads(fp.read(), marshaller, callbacks)
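
    # A persistence sketch (illustrative only): `dumps`/`loads` work on strings while
    # `dump`/`load` work on file-like objects, using YAML by default. The file name
    # below is an assumption:
    #
    #     with open("pipeline.yaml", "w") as f:
    #         pipe.dump(f)
    #     with open("pipeline.yaml", "r") as f:
    #         restored = PipelineBase.load(f)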

    def add_component(self, name: str, instance: Component) -> None:
        """
        Add the given component to the pipeline.

        Components are not connected to anything by default: use `Pipeline.connect()` to connect components together.
        Component names must be unique, but component instances can be reused if needed.

        :param name:
            The name of the component to add.
        :param instance:
            The component instance to add.

        :raises ValueError:
            If a component with the same name already exists.
        :raises PipelineValidationError:
            If the given instance is not a Haystack component.
        """
        # Component names are unique
        if name in self.graph.nodes:
            raise ValueError(f"A component named '{name}' already exists in this pipeline: choose another name.")

        # Components can't be named `_debug`
        if name == "_debug":
            raise ValueError("'_debug' is a reserved name for debug output. Choose another name.")

        # Component instances must be components
        if not isinstance(instance, Component):
            raise PipelineValidationError(
                f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?"
            )

        if getattr(instance, "__haystack_added_to_pipeline__", None):
            msg = (
                "Component has already been added in another Pipeline. Components can't be shared between Pipelines. "
                "Create a new instance instead."
            )
            raise PipelineError(msg)

        setattr(instance, "__haystack_added_to_pipeline__", self)

        # Add component to the graph, disconnected
        logger.debug("Adding component '{component_name}' ({component})", component_name=name, component=instance)
        # We're completely sure the fields exist so we ignore the type error
        self.graph.add_node(
            name,
            instance=instance,
            input_sockets=instance.__haystack_input__._sockets_dict,  # type: ignore[attr-defined]
            output_sockets=instance.__haystack_output__._sockets_dict,  # type: ignore[attr-defined]
            visits=0,
        )
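
    # A usage sketch (illustrative only): the smallest component that can be added is a
    # class decorated with `@component` whose `run` method declares its output types.
    # The `Doubler` class below is a hypothetical example:
    #
    #     @component
    #     class Doubler:
    #         @component.output_types(result=int)
    #         def run(self, value: int):
    #             return {"result": value * 2}
    #
    #     pipe = PipelineBase()
    #     pipe.add_component("doubler", Doubler())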

    def remove_component(self, name: str) -> Component:
        """
        Removes and returns a component from the pipeline.

        Remove an existing component from the pipeline by providing its name.
        All edges that connect to the component will also be deleted.

        :param name:
            The name of the component to remove.
        :returns:
            The removed Component instance.

        :raises ValueError:
            If there is no component with that name already in the Pipeline.
        """

        # Check that a component with that name is in the Pipeline
        try:
            instance = self.get_component(name)
        except ValueError as exc:
            raise ValueError(
                f"There is no component named '{name}' in the pipeline. The valid component names are: ",
                ", ".join(n for n in self.graph.nodes),
            ) from exc

        # Delete component from the graph, deleting all its connections
        self.graph.remove_node(name)

        # Reset the Component sockets' senders and receivers
        input_sockets = instance.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
        for socket in input_sockets.values():
            socket.senders = []

        output_sockets = instance.__haystack_output__._sockets_dict  # type: ignore[attr-defined]
        for socket in output_sockets.values():
            socket.receivers = []

        # Reset the Component's pipeline reference
        setattr(instance, "__haystack_added_to_pipeline__", None)

        return instance

    def connect(self, sender: str, receiver: str) -> "PipelineBase":  # noqa: PLR0915 PLR0912
        """
        Connects two components together.

        All components to connect must exist in the pipeline.
        If connecting to a component that has several output connections, specify the input and output names as
        'component_name.connection_name'.

        :param sender:
            The component that delivers the value. This can be either just a component name or can be
            in the format `component_name.connection_name` if the component has multiple outputs.
        :param receiver:
            The component that receives the value. This can be either just a component name or can be
            in the format `component_name.connection_name` if the component has multiple inputs.
        :returns:
            The Pipeline instance.

        :raises PipelineConnectError:
            If the two components cannot be connected (for example if one of the components is
            not present in the pipeline, or the connections don't match by type, and so on).
        """
        # Edges may be named explicitly by passing 'node_name.edge_name' to connect().
        sender_component_name, sender_socket_name = parse_connect_string(sender)
        receiver_component_name, receiver_socket_name = parse_connect_string(receiver)

        if sender_component_name == receiver_component_name:
            raise PipelineConnectError("Connecting a Component to itself is not supported.")

        # Get the nodes data.
        try:
            from_sockets = self.graph.nodes[sender_component_name]["output_sockets"]
        except KeyError as exc:
            raise ValueError(f"Component named {sender_component_name} not found in the pipeline.") from exc
        try:
            to_sockets = self.graph.nodes[receiver_component_name]["input_sockets"]
        except KeyError as exc:
            raise ValueError(f"Component named {receiver_component_name} not found in the pipeline.") from exc

        # If the name of either socket is given, get the socket
        sender_socket: Optional[OutputSocket] = None
        if sender_socket_name:
            sender_socket = from_sockets.get(sender_socket_name)
            if not sender_socket:
                raise PipelineConnectError(
                    f"'{sender}' does not exist. "
                    f"Output connections of {sender_component_name} are: "
                    + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()])
                )

        receiver_socket: Optional[InputSocket] = None
        if receiver_socket_name:
            receiver_socket = to_sockets.get(receiver_socket_name)
            if not receiver_socket:
                raise PipelineConnectError(
                    f"'{receiver}' does not exist. "
                    f"Input connections of {receiver_component_name} are: "
                    + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()])
                )

        # Look for a matching connection among the possible ones.
        # Note that if there is more than one possible connection but two sockets match by name, they're paired.
        sender_socket_candidates: List[OutputSocket] = [sender_socket] if sender_socket else list(from_sockets.values())
        receiver_socket_candidates: List[InputSocket] = (
            [receiver_socket] if receiver_socket else list(to_sockets.values())
        )

        # Find all possible connections between these two components
        possible_connections = [
            (sender_sock, receiver_sock)
            for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates)
            if _types_are_compatible(sender_sock.type, receiver_sock.type)
        ]

        # We need this status for error messages, since we might need it in multiple places we calculate it here
        status = _connections_status(
            sender_node=sender_component_name,
            sender_sockets=sender_socket_candidates,
            receiver_node=receiver_component_name,
            receiver_sockets=receiver_socket_candidates,
        )

        if not possible_connections:
            # There's no possible connection between these two components
            if len(sender_socket_candidates) == len(receiver_socket_candidates) == 1:
                msg = (
                    f"Cannot connect '{sender_component_name}.{sender_socket_candidates[0].name}' with "
                    f"'{receiver_component_name}.{receiver_socket_candidates[0].name}': "
                    f"their declared input and output types do not match.\n{status}"
                )
            else:
                msg = (
                    f"Cannot connect '{sender_component_name}' with '{receiver_component_name}': "
                    f"no matching connections available.\n{status}"
                )
            raise PipelineConnectError(msg)

        if len(possible_connections) == 1:
            # There's only one possible connection, use it
            sender_socket = possible_connections[0][0]
            receiver_socket = possible_connections[0][1]

        if len(possible_connections) > 1:
            # There are multiple possible connections, let's try to match them by name
            name_matches = [
                (out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name
            ]
            if len(name_matches) != 1:
                # There are either no matches or more than one, we can't pick one reliably
                msg = (
                    f"Cannot connect '{sender_component_name}' with "
                    f"'{receiver_component_name}': more than one connection is possible "
                    "between these components. Please specify the connection name, like: "
                    f"pipeline.connect('{sender_component_name}.{possible_connections[0][0].name}', "
                    f"'{receiver_component_name}.{possible_connections[0][1].name}').\n{status}"
                )
                raise PipelineConnectError(msg)

            # Get the only possible match
            sender_socket = name_matches[0][0]
            receiver_socket = name_matches[0][1]

        # Connection must be valid on both sender/receiver sides
        if not sender_socket or not receiver_socket or not sender_component_name or not receiver_component_name:
            if sender_component_name and sender_socket:
                sender_repr = f"{sender_component_name}.{sender_socket.name} ({_type_name(sender_socket.type)})"
            else:
                sender_repr = "input needed"

            if receiver_component_name and receiver_socket:
                receiver_repr = f"({_type_name(receiver_socket.type)}) {receiver_component_name}.{receiver_socket.name}"
            else:
                receiver_repr = "output"
            msg = f"Connection must have both sender and receiver: {sender_repr} -> {receiver_repr}"
            raise PipelineConnectError(msg)

        logger.debug(
            "Connecting '{sender_component}.{sender_socket_name}' to '{receiver_component}.{receiver_socket_name}'",
            sender_component=sender_component_name,
            sender_socket_name=sender_socket.name,
            receiver_component=receiver_component_name,
            receiver_socket_name=receiver_socket.name,
        )

        if receiver_component_name in sender_socket.receivers and sender_component_name in receiver_socket.senders:
            # This is already connected, nothing to do
            return self

        if receiver_socket.senders and not receiver_socket.is_variadic:
            # Only variadic input sockets can receive from multiple senders
            msg = (
                f"Cannot connect '{sender_component_name}.{sender_socket.name}' with "
                f"'{receiver_component_name}.{receiver_socket.name}': "
                f"{receiver_component_name}.{receiver_socket.name} is already connected to {receiver_socket.senders}.\n"
            )
            raise PipelineConnectError(msg)

        # Update the sockets with the new connection
        sender_socket.receivers.append(receiver_component_name)
        receiver_socket.senders.append(sender_component_name)

        # Create the new connection
        self.graph.add_edge(
            sender_component_name,
            receiver_component_name,
            key=f"{sender_socket.name}/{receiver_socket.name}",
            conn_type=_type_name(sender_socket.type),
            from_socket=sender_socket,
            to_socket=receiver_socket,
            mandatory=receiver_socket.is_mandatory,
        )
        return self
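
    # A connection sketch (illustrative only; "fetcher", "converter" and their socket
    # names are hypothetical). A bare component name works when exactly one type-compatible
    # pair of sockets exists; otherwise the sockets must be named explicitly:
    #
    #     pipe.connect("fetcher", "converter")                  # unambiguous case
    #     pipe.connect("fetcher.streams", "converter.sources")  # explicit sockets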

    def get_component(self, name: str) -> Component:
        """
        Get the component with the specified name from the pipeline.

        :param name:
            The name of the component.
        :returns:
            The instance of that component.

        :raises ValueError:
            If a component with that name is not present in the pipeline.
        """
        try:
            return self.graph.nodes[name]["instance"]
        except KeyError as exc:
            raise ValueError(f"Component named {name} not found in the pipeline.") from exc

    def get_component_name(self, instance: Component) -> str:
        """
        Returns the name of the Component instance if it has been added to this Pipeline or an empty string otherwise.

        :param instance:
            The Component instance to look for.
        :returns:
            The name of the Component instance.
        """
        for name, inst in self.graph.nodes(data="instance"):  # type: ignore # type wrongly defined in networkx
            if inst == instance:
                return name
        return ""

    def inputs(self, include_components_with_connected_inputs: bool = False) -> Dict[str, Dict[str, Any]]:
        """
        Returns a dictionary containing the inputs of a pipeline.

        Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes
        the input sockets of that component, including their types and whether they are optional.

        :param include_components_with_connected_inputs:
            If `False`, only components that have disconnected input edges are
            included in the output.
        :returns:
            A dictionary where each key is a pipeline component name and each value is a dictionary of
            input sockets of that component.
        """
        inputs: Dict[str, Dict[str, Any]] = {}
        for component_name, data in find_pipeline_inputs(self.graph, include_components_with_connected_inputs).items():
            sockets_description = {}
            for socket in data:
                sockets_description[socket.name] = {"type": socket.type, "is_mandatory": socket.is_mandatory}
                if not socket.is_mandatory:
                    sockets_description[socket.name]["default_value"] = socket.default_value

            if sockets_description:
                inputs[component_name] = sockets_description
        return inputs

    def outputs(self, include_components_with_connected_outputs: bool = False) -> Dict[str, Dict[str, Any]]:
        """
        Returns a dictionary containing the outputs of a pipeline.

        Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes
        the output sockets of that component.

        :param include_components_with_connected_outputs:
            If `False`, only components that have disconnected output edges are
            included in the output.
        :returns:
            A dictionary where each key is a pipeline component name and each value is a dictionary of
            output sockets of that component.
        """
        outputs = {
            comp: {socket.name: {"type": socket.type} for socket in data}
            for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items()
            if data
        }
        return outputs
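
    # An inspection sketch (illustrative only; component and socket names are hypothetical).
    # `inputs()` lists dangling input sockets and `outputs()` dangling output sockets, which
    # is useful for discovering what a pipeline expects and produces:
    #
    #     pipe.inputs()   # e.g. {"fetcher": {"urls": {"type": list, "is_mandatory": True}}}
    #     pipe.outputs()  # e.g. {"converter": {"documents": {"type": list}}}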

    def show(self, server_url: str = "https://mermaid.ink", params: Optional[dict] = None) -> None:
        """
        Display an image representing this `Pipeline` in a Jupyter notebook.

        This function generates a diagram of the `Pipeline` using a Mermaid server and displays it directly in
        the notebook.

        :param server_url:
            The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
            See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
            info on how to set up your own Mermaid server.

        :param params:
            Dictionary of customization parameters to modify the output. Refer to the Mermaid documentation for more
            details.
            Supported keys:
                - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
                - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
                - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
                - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
                - width: Width of the output image (integer).
                - height: Height of the output image (integer).
                - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
                - fit: Whether to fit the diagram size to the page (PDF only, boolean).
                - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
                - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.

        :raises PipelineDrawingError:
            If the function is called outside of a Jupyter notebook or if there is an issue with rendering.
        """
        if is_in_jupyter():
            from IPython.display import Image, display  # type: ignore

            image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params)
            display(Image(image_data))
        else:
            msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
            raise PipelineDrawingError(msg)

    def draw(self, path: Path, server_url: str = "https://mermaid.ink", params: Optional[dict] = None) -> None:
        """
        Save an image representing this `Pipeline` to the specified file path.

        This function generates a diagram of the `Pipeline` using the Mermaid server and saves it to the provided path.

        :param path:
            The file path where the generated image will be saved.
        :param server_url:
            The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
            See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
            info on how to set up your own Mermaid server.
        :param params:
            Dictionary of customization parameters to modify the output. Refer to the Mermaid documentation for more
            details.
            Supported keys:
                - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
                - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
                - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
                - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
                - width: Width of the output image (integer).
                - height: Height of the output image (integer).
                - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
                - fit: Whether to fit the diagram size to the page (PDF only, boolean).
                - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
                - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.

        :raises PipelineDrawingError:
            If there is an issue with rendering or saving the image.
        """
        # Render the pipeline graph through the Mermaid server and write the image to disk.
        image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params)
        Path(path).write_bytes(image_data)

    def walk(self) -> Iterator[Tuple[str, Component]]:
        """
        Visits each component in the pipeline exactly once and yields its name and instance.

        No guarantees are provided on the visiting order.

        :returns:
            An iterator of tuples of component name and component instance.
        """
        for component_name, instance in self.graph.nodes(data="instance"):  # type: ignore # type is wrong in networkx
            yield component_name, instance

    def warm_up(self):
        """
        Make sure all nodes are warm.

        It's the node's responsibility to make sure this method can be called at every `Pipeline.run()`
        without re-initializing everything.
        """
        for node in self.graph.nodes:
            if hasattr(self.graph.nodes[node]["instance"], "warm_up"):
                logger.info("Warming up component {node}...", node=node)
                self.graph.nodes[node]["instance"].warm_up()

    def _validate_input(self, data: Dict[str, Any]):
        """
        Validates pipeline input data.

        Validates that data:
        * Each Component name actually exists in the Pipeline
        * Each Component is not missing any input
        * Each Component has only one input per input socket, if not variadic
        * Each Component doesn't receive inputs that are already sent by another Component

        :param data:
            A dictionary of inputs for the pipeline's components. Each key is a component name.

        :raises ValueError:
            If inputs are invalid according to the above.
        """
        for component_name, component_inputs in data.items():
            if component_name not in self.graph.nodes:
                raise ValueError(f"Component named {component_name} not found in the pipeline.")
            instance = self.graph.nodes[component_name]["instance"]
            for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
                if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
                    raise ValueError(f"Missing input for component {component_name}: {socket_name}")
            for input_name in component_inputs.keys():
                if input_name not in instance.__haystack_input__._sockets_dict:
                    raise ValueError(f"Input {input_name} not found in component {component_name}.")

        for component_name in self.graph.nodes:
            instance = self.graph.nodes[component_name]["instance"]
            for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
                component_inputs = data.get(component_name, {})
                if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
                    raise ValueError(f"Missing input for component {component_name}: {socket_name}")
                if socket.senders and socket_name in component_inputs and not socket.is_variadic:
                    raise ValueError(
                        f"Input {socket_name} for component {component_name} is already sent by {socket.senders}."
                    )

    def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        """
        Prepares input data for pipeline components.

        Organizes input data for pipeline components and identifies any inputs that are not matched to any
        component's input slots. Deep-copies data items to avoid sharing mutables across multiple components.

        This method processes a flat dictionary of input data, where each key-value pair represents an input name
        and its corresponding value. It distributes these inputs to the appropriate pipeline components based on
        their input requirements. Inputs that don't match any component's input slots are classified as unresolved.

        :param data:
            A dictionary potentially having input names as keys and input values as values.

        :returns:
            A dictionary mapping component names to their respective matched inputs.
        """
        # check whether the data is a nested dictionary of component inputs where each key is a component name
        # and each value is a dictionary of input parameters for that component
        is_nested_component_input = all(isinstance(value, dict) for value in data.values())
        if not is_nested_component_input:
            # flat input, a dict where keys are input names and values are the corresponding values
            # we need to convert it to a nested dictionary of component inputs and then run the pipeline
            # just like in the previous case
            pipeline_input_data: Dict[str, Dict[str, Any]] = defaultdict(dict)
            unresolved_kwargs = {}

            # Retrieve the input slots for each component in the pipeline
            available_inputs: Dict[str, Dict[str, Any]] = self.inputs()

            # Go through all provided inputs to distribute them to the appropriate component inputs
            for input_name, input_value in data.items():
                resolved_at_least_once = False

                # Check each component to see if it has a slot for the current kwarg
                for component_name, component_inputs in available_inputs.items():
                    if input_name in component_inputs:
                        # If a match is found, add the kwarg to the component's input data
                        pipeline_input_data[component_name][input_name] = input_value
                        resolved_at_least_once = True

                if not resolved_at_least_once:
                    unresolved_kwargs[input_name] = input_value

            if unresolved_kwargs:
                logger.warning(
                    "Inputs {input_keys} were not matched to any component inputs, please check your run parameters.",
                    input_keys=list(unresolved_kwargs.keys()),
                )

            data = dict(pipeline_input_data)

        # deepcopying the inputs prevents the Pipeline run logic from being altered unexpectedly
        # when the same input reference is passed to multiple components.
        for component_name, component_inputs in data.items():
            data[component_name] = {k: deepcopy(v) for k, v in component_inputs.items()}

        return data
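
    # A distribution sketch (illustrative only; component and input names are hypothetical).
    # A flat input dict is fanned out to every component exposing a matching input slot:
    #
    #     pipe._prepare_component_input_data({"query": "Who lives in Paris?"})
    #     # -> {"retriever": {"query": "Who lives in Paris?"},
    #     #     "prompt_builder": {"query": "Who lives in Paris?"}}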

    @classmethod
    def from_template(
        cls, predefined_pipeline: PredefinedPipeline, template_params: Optional[Dict[str, Any]] = None
    ) -> "PipelineBase":
        """
        Create a Pipeline from a predefined template. See `PredefinedPipeline` for available options.

        :param predefined_pipeline:
            The predefined pipeline to use.
        :param template_params:
            An optional dictionary of parameters to use when rendering the pipeline template.
        :returns:
            An instance of `Pipeline`.
        """
        tpl = PipelineTemplate.from_predefined(predefined_pipeline)
        # If tpl.render() fails, we let the original error bubble up
        rendered = tpl.render(template_params)

        # If there was a problem with the rendered version of the
        # template, we add it to the error stack for debugging
        try:
            return cls.loads(rendered)
        except Exception as e:
            msg = f"Error unmarshalling pipeline: {e}\n"
            msg += f"Source:\n{rendered}"
            raise PipelineUnmarshalError(msg)

    def _find_receivers_from(self, component_name: str) -> List[Tuple[str, OutputSocket, InputSocket]]:
        """
        Utility function to find all Components that receive input from `component_name`.

        :param component_name:
            Name of the sender Component

        :returns:
            List of tuples containing name of the receiver Component and sender OutputSocket
            and receiver InputSocket instances
        """
        res = []
        for _, receiver_name, connection in self.graph.edges(nbunch=component_name, data=True):
            sender_socket: OutputSocket = connection["from_socket"]
            receiver_socket: InputSocket = connection["to_socket"]
            res.append((receiver_name, sender_socket, receiver_socket))
        return res

    @staticmethod
    def _convert_to_internal_format(pipeline_inputs: Dict[str, Any]) -> Dict[str, Dict[str, List]]:
        """
        Converts the inputs to the pipeline to the format that is needed for the internal `Pipeline.run` logic.

        Example Input:
        {'prompt_builder': {'question': 'Who lives in Paris?'}, 'retriever': {'query': 'Who lives in Paris?'}}
        Example Output:
        {'prompt_builder': {'question': [{'sender': None, 'value': 'Who lives in Paris?'}]},
         'retriever': {'query': [{'sender': None, 'value': 'Who lives in Paris?'}]}}

        :param pipeline_inputs: Inputs to the pipeline.
        :returns: Converted inputs that can be used by the internal `Pipeline.run` logic.
        """
        inputs: Dict[str, Dict[str, List[Dict[str, Any]]]] = {}
        for component_name, socket_dict in pipeline_inputs.items():
            inputs[component_name] = {}
            for socket_name, value in socket_dict.items():
                inputs[component_name][socket_name] = [{"sender": None, "value": value}]

        return inputs

    @staticmethod
    def _consume_component_inputs(component_name: str, component: Dict, inputs: Dict) -> Dict[str, Any]:
        """
        Extracts the inputs the component needs to run and removes them from the global inputs state.

        :param component_name: The name of a component.
        :param component: Component with component metadata.
        :param inputs: Global inputs state.
        :returns: The inputs for the component.
        """
        component_inputs = inputs.get(component_name, {})
        consumed_inputs = {}
        greedy_inputs_to_remove = set()
        for socket_name, socket in component["input_sockets"].items():
            socket_inputs = component_inputs.get(socket_name, [])
            socket_inputs = [sock["value"] for sock in socket_inputs if sock["value"] != _NO_OUTPUT_PRODUCED]
            if socket_inputs:
                if not socket.is_variadic:
                    # We only care about the first input provided to the socket.
                    consumed_inputs[socket_name] = socket_inputs[0]
                elif socket.is_greedy:
                    # We need to keep track of greedy inputs because we always remove them, even if they come from
                    # outside the pipeline. Otherwise, a greedy input from the user would trigger a pipeline to run
                    # indefinitely.
                    greedy_inputs_to_remove.add(socket_name)
                    consumed_inputs[socket_name] = [socket_inputs[0]]
                elif is_socket_lazy_variadic(socket):
                    # We use all inputs provided to the socket on a lazy variadic socket.
                    consumed_inputs[socket_name] = socket_inputs

        # We prune all inputs except for those that were provided from outside the pipeline (e.g. user inputs).
        pruned_inputs = {
            socket_name: [
                sock for sock in socket if sock["sender"] is None and socket_name not in greedy_inputs_to_remove
            ]
            for socket_name, socket in component_inputs.items()
        }
        pruned_inputs = {socket_name: socket for socket_name, socket in pruned_inputs.items() if len(socket) > 0}

        inputs[component_name] = pruned_inputs

        return consumed_inputs

    def _fill_queue(
        self, component_names: List[str], inputs: Dict[str, Any], component_visits: Dict[str, int]
    ) -> FIFOPriorityQueue:
        """
        Calculates the execution priority for each component and inserts it into the priority queue.

        :param component_names: Names of the components to put into the queue.
        :param inputs: Inputs to the components.
        :param component_visits: Current state of component visits.
        :returns: A prioritized queue of component names.
        """
        priority_queue = FIFOPriorityQueue()
        for component_name in component_names:
            component = self._get_component_with_graph_metadata_and_visits(
                component_name, component_visits[component_name]
            )
            priority = self._calculate_priority(component, inputs.get(component_name, {}))
            priority_queue.push(component_name, priority)

        return priority_queue

    @staticmethod
    def _calculate_priority(component: Dict, inputs: Dict) -> ComponentPriority:
        """
        Calculates the execution priority for a component depending on the component's inputs.

        :param component: Component metadata and component instance.
        :param inputs: Inputs to the component.
        :returns: Priority value for the component.
        """
        if not can_component_run(component, inputs):
            return ComponentPriority.BLOCKED
        elif is_any_greedy_socket_ready(component, inputs) and are_all_sockets_ready(component, inputs):
            return ComponentPriority.HIGHEST
        elif all_predecessors_executed(component, inputs):
            return ComponentPriority.READY
        elif are_all_lazy_variadic_sockets_resolved(component, inputs):
            return ComponentPriority.DEFER
        else:
            return ComponentPriority.DEFER_LAST
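
    # A scheduling sketch (illustrative only): priorities map onto `ComponentPriority`,
    # where lower values run first. A component whose greedy variadic socket and all other
    # sockets are ready gets HIGHEST (1), one that cannot run at all is BLOCKED (5), and
    # since `ComponentPriority` is an IntEnum the queue pops them in numeric order:
    #
    #     assert ComponentPriority.HIGHEST < ComponentPriority.READY < ComponentPriority.BLOCKED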

    def _get_component_with_graph_metadata_and_visits(self, component_name: str, visits: int) -> Dict[str, Any]:
        """
        Returns the component instance alongside input/output-socket metadata from the graph and adds current visits.

        We can't store visits in the pipeline graph because this would prevent reentrance / thread-safe execution.

        :param component_name: The name of the component.
        :param visits: Number of visits for the component.
        :returns: Dict including component instance, input/output-sockets and visits.
        """
        comp_dict = self.graph.nodes[component_name]
        comp_dict = {**comp_dict, "visits": visits}
        return comp_dict

    def _get_next_runnable_component(
        self, priority_queue: FIFOPriorityQueue, component_visits: Dict[str, int]
    ) -> Union[Tuple[ComponentPriority, str, Dict[str, Any]], None]:
        """
        Returns the next runnable component alongside its metadata from the priority queue.

        :param priority_queue: Priority queue of component names.
        :param component_visits: Current state of component visits.
        :returns: The priority, the component name, and the component metadata,
            or None if no component in the queue can run.
        :raises: PipelineMaxComponentRuns if the next runnable component has exceeded the maximum number of runs.
        """
        priority_and_component_name: Union[Tuple[ComponentPriority, str], None] = (
            None if (item := priority_queue.get()) is None else (ComponentPriority(item[0]), str(item[1]))
        )

        if priority_and_component_name is not None and priority_and_component_name[0] != ComponentPriority.BLOCKED:
            priority, component_name = priority_and_component_name
            component = self._get_component_with_graph_metadata_and_visits(
                component_name, component_visits[component_name]
            )
            if component["visits"] > self._max_runs_per_component:
                msg = f"Maximum run count {self._max_runs_per_component} reached for component '{component_name}'"
                raise PipelineMaxComponentRuns(msg)

            return priority, component_name, component

        return None

    @staticmethod
    def _add_missing_input_defaults(component_inputs: Dict[str, Any], component_input_sockets: Dict[str, InputSocket]):
        """
        Updates the inputs with the default values for the inputs that are missing.

        :param component_inputs: Inputs for the component.
        :param component_input_sockets: Input sockets of the component.
        """
        for name, socket in component_input_sockets.items():
            if not socket.is_mandatory and name not in component_inputs:
                if socket.is_variadic:
                    component_inputs[name] = [socket.default_value]
                else:
                    component_inputs[name] = socket.default_value

        return component_inputs

    @staticmethod
    def _write_component_outputs(
        component_name, component_outputs, inputs, receivers, include_outputs_from
    ) -> Dict[str, Any]:
        """
        Distributes the outputs of a component to the input sockets that it is connected to.

        :param component_name: The name of the component.
        :param component_outputs: The outputs of the component.
        :param inputs: The current global input state.
        :param receivers: List of receiver_name, sender_socket, receiver_socket for connected components.
        :param include_outputs_from: List of component names that should always return an output from the pipeline.
        """
        for receiver_name, sender_socket, receiver_socket in receivers:
            # We either get the value that was produced by the actor or we use the _NO_OUTPUT_PRODUCED class to
            # indicate that the sender did not produce an output for this socket.
            # This allows us to track if a predecessor already ran but did not produce an output.
            value = component_outputs.get(sender_socket.name, _NO_OUTPUT_PRODUCED)
            if receiver_name not in inputs:
                inputs[receiver_name] = {}

            # If we have a non-variadic or a greedy variadic receiver socket, we can just overwrite any inputs
            # that might already exist (to be reconsidered but mirrors current behavior).
            if not is_socket_lazy_variadic(receiver_socket):
                inputs[receiver_name][receiver_socket.name] = [{"sender": component_name, "value": value}]

            # If the receiver socket is lazy variadic, and it already has an input, we need to append the new input.
            # Lazy variadic sockets can collect multiple inputs.
            else:
                if not inputs[receiver_name].get(receiver_socket.name):
                    inputs[receiver_name][receiver_socket.name] = []

                inputs[receiver_name][receiver_socket.name].append({"sender": component_name, "value": value})

        # If we want to include all outputs from this actor in the final outputs, we don't need to prune any
        # consumed outputs
        if component_name in include_outputs_from:
            return component_outputs

        # We prune outputs that were consumed by any receiving sockets.
        # All remaining outputs will be added to the final outputs of the pipeline.
        consumed_outputs = {sender_socket.name for _, sender_socket, __ in receivers}
        pruned_outputs = {key: value for key, value in component_outputs.items() if key not in consumed_outputs}

        return pruned_outputs

    @staticmethod
    def _is_queue_stale(priority_queue: FIFOPriorityQueue) -> bool:
        """
        Checks if the priority queue needs to be recomputed because the priorities might have changed.

        :param priority_queue: Priority queue of component names.
        """
        return len(priority_queue) == 0 or priority_queue.peek()[0] > ComponentPriority.READY

    @staticmethod
    def validate_pipeline(priority_queue: FIFOPriorityQueue) -> None:
        """
        Validate the pipeline to check if it is blocked or has no valid entry point.

        :param priority_queue: Priority queue of component names.
        """
        if len(priority_queue) == 0:
            return

        candidate = priority_queue.peek()
        if candidate is not None and candidate[0] == ComponentPriority.BLOCKED:
            raise PipelineRuntimeError(
                "Cannot run pipeline - all components are blocked. "
                "This typically happens when:\n"
                "1. There is no valid entry point for the pipeline\n"
                "2. There is a circular dependency preventing the pipeline from running\n"
                "Check the connections between these components and ensure all required inputs are provided."
            )


def _connections_status(
    sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
):
    """
    Lists the status of the sockets, for error messages.
    """
    sender_sockets_entries = []
    for sender_socket in sender_sockets:
        sender_sockets_entries.append(f" - {sender_socket.name}: {_type_name(sender_socket.type)}")
    sender_sockets_list = "\n".join(sender_sockets_entries)

    receiver_sockets_entries = []
    for receiver_socket in receiver_sockets:
        if receiver_socket.senders:
            sender_status = f"sent by {','.join(receiver_socket.senders)}"
        else:
            sender_status = "available"
        receiver_sockets_entries.append(
            f" - {receiver_socket.name}: {_type_name(receiver_socket.type)} ({sender_status})"
        )
    receiver_sockets_list = "\n".join(receiver_sockets_entries)

    return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}"