2024-05-10 11:35:15 +02:00
|
|
|
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
|
|
|
#
|
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
|
|
|
|
|
import importlib
|
|
|
|
|
import itertools
|
|
|
|
|
from collections import defaultdict
|
2024-05-14 23:25:46 +02:00
|
|
|
from copy import copy, deepcopy
|
2024-05-10 11:35:15 +02:00
|
|
|
from datetime import datetime
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Any, Dict, Iterator, List, Optional, TextIO, Tuple, Type, TypeVar, Union
|
|
|
|
|
|
|
|
|
|
import networkx # type:ignore
|
|
|
|
|
|
|
|
|
|
from haystack import logging
|
|
|
|
|
from haystack.core.component import Component, InputSocket, OutputSocket, component
|
|
|
|
|
from haystack.core.errors import (
|
|
|
|
|
PipelineConnectError,
|
|
|
|
|
PipelineDrawingError,
|
|
|
|
|
PipelineError,
|
|
|
|
|
PipelineUnmarshalError,
|
|
|
|
|
PipelineValidationError,
|
|
|
|
|
)
|
|
|
|
|
from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict
|
|
|
|
|
from haystack.core.type_utils import _type_name, _types_are_compatible
|
|
|
|
|
from haystack.marshal import Marshaller, YamlMarshaller
|
|
|
|
|
from haystack.utils import is_in_jupyter
|
|
|
|
|
|
|
|
|
|
from .descriptions import find_pipeline_inputs, find_pipeline_outputs
|
|
|
|
|
from .draw import _to_mermaid_image
|
|
|
|
|
from .template import PipelineTemplate, PredefinedPipeline
|
|
|
|
|
from .utils import parse_connect_string
|
|
|
|
|
|
|
|
|
|
# Default marshaller used by dumps()/dump()/loads()/load(): YAML serialization.
DEFAULT_MARSHALLER = YamlMarshaller()

# We use a generic type to annotate the return value of classmethods,
# so that static analyzers won't be confused when derived classes
# use those methods.
T = TypeVar("T", bound="PipelineBase")

# Module-level logger, following the project-wide logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PipelineBase:
|
|
|
|
|
"""
|
|
|
|
|
Components orchestration engine.
|
|
|
|
|
|
|
|
|
|
Builds a graph of components and orchestrates their execution according to the execution graph.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
metadata: Optional[Dict[str, Any]] = None,
|
|
|
|
|
max_loops_allowed: int = 100,
|
|
|
|
|
debug_path: Union[Path, str] = Path(".haystack_debug/"),
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
Creates the Pipeline.
|
|
|
|
|
|
|
|
|
|
:param metadata:
|
|
|
|
|
Arbitrary dictionary to store metadata about this pipeline. Make sure all the values contained in
|
|
|
|
|
this dictionary can be serialized and deserialized if you wish to save this pipeline to file with
|
|
|
|
|
`save_pipelines()/load_pipelines()`.
|
|
|
|
|
:param max_loops_allowed:
|
|
|
|
|
How many times the pipeline can run the same node before throwing an exception.
|
|
|
|
|
:param debug_path:
|
|
|
|
|
When debug is enabled in `run()`, where to save the debug data.
|
|
|
|
|
"""
|
|
|
|
|
self._telemetry_runs = 0
|
|
|
|
|
self._last_telemetry_sent: Optional[datetime] = None
|
|
|
|
|
self.metadata = metadata or {}
|
|
|
|
|
self.max_loops_allowed = max_loops_allowed
|
|
|
|
|
self.graph = networkx.MultiDiGraph()
|
|
|
|
|
self._debug: Dict[int, Dict[str, Any]] = {}
|
|
|
|
|
self._debug_path = Path(debug_path)
|
|
|
|
|
|
|
|
|
|
def __eq__(self, other) -> bool:
|
|
|
|
|
"""
|
|
|
|
|
Pipeline equality is defined by their type and the equality of their serialized form.
|
|
|
|
|
|
|
|
|
|
Pipelines of the same type share every metadata, node and edge, but they're not required to use
|
|
|
|
|
the same node instances: this allows pipeline saved and then loaded back to be equal to themselves.
|
|
|
|
|
"""
|
|
|
|
|
if not isinstance(self, type(other)):
|
|
|
|
|
return False
|
|
|
|
|
return self.to_dict() == other.to_dict()
|
|
|
|
|
|
|
|
|
|
def __repr__(self) -> str:
|
|
|
|
|
"""
|
|
|
|
|
Returns a text representation of the Pipeline.
|
|
|
|
|
"""
|
|
|
|
|
res = f"{object.__repr__(self)}\n"
|
|
|
|
|
if self.metadata:
|
|
|
|
|
res += "🧱 Metadata\n"
|
|
|
|
|
for k, v in self.metadata.items():
|
|
|
|
|
res += f" - {k}: {v}\n"
|
|
|
|
|
|
|
|
|
|
res += "🚅 Components\n"
|
|
|
|
|
for name, instance in self.graph.nodes(data="instance"): # type: ignore # type wrongly defined in networkx
|
|
|
|
|
res += f" - {name}: {instance.__class__.__name__}\n"
|
|
|
|
|
|
|
|
|
|
res += "🛤️ Connections\n"
|
|
|
|
|
for sender, receiver, edge_data in self.graph.edges(data=True):
|
|
|
|
|
sender_socket = edge_data["from_socket"].name
|
|
|
|
|
receiver_socket = edge_data["to_socket"].name
|
|
|
|
|
res += f" - {sender}.{sender_socket} -> {receiver}.{receiver_socket} ({edge_data['conn_type']})\n"
|
|
|
|
|
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
def to_dict(self) -> Dict[str, Any]:
|
|
|
|
|
"""
|
|
|
|
|
Serializes the pipeline to a dictionary.
|
|
|
|
|
|
|
|
|
|
This is meant to be an intermediate representation but it can be also used to save a pipeline to file.
|
|
|
|
|
|
|
|
|
|
:returns:
|
|
|
|
|
Dictionary with serialized data.
|
|
|
|
|
"""
|
|
|
|
|
components = {}
|
|
|
|
|
for name, instance in self.graph.nodes(data="instance"): # type:ignore
|
|
|
|
|
components[name] = component_to_dict(instance)
|
|
|
|
|
|
|
|
|
|
connections = []
|
|
|
|
|
for sender, receiver, edge_data in self.graph.edges.data():
|
|
|
|
|
sender_socket = edge_data["from_socket"].name
|
|
|
|
|
receiver_socket = edge_data["to_socket"].name
|
|
|
|
|
connections.append({"sender": f"{sender}.{sender_socket}", "receiver": f"{receiver}.{receiver_socket}"})
|
|
|
|
|
return {
|
|
|
|
|
"metadata": self.metadata,
|
|
|
|
|
"max_loops_allowed": self.max_loops_allowed,
|
|
|
|
|
"components": components,
|
|
|
|
|
"connections": connections,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_dict(
|
|
|
|
|
cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs
|
|
|
|
|
) -> T:
|
|
|
|
|
"""
|
|
|
|
|
Deserializes the pipeline from a dictionary.
|
|
|
|
|
|
|
|
|
|
:param data:
|
|
|
|
|
Dictionary to deserialize from.
|
|
|
|
|
:param callbacks:
|
|
|
|
|
Callbacks to invoke during deserialization.
|
|
|
|
|
:param kwargs:
|
|
|
|
|
`components`: a dictionary of {name: instance} to reuse instances of components instead of creating new ones.
|
|
|
|
|
:returns:
|
|
|
|
|
Deserialized component.
|
|
|
|
|
"""
|
|
|
|
|
metadata = data.get("metadata", {})
|
|
|
|
|
max_loops_allowed = data.get("max_loops_allowed", 100)
|
|
|
|
|
debug_path = Path(data.get("debug_path", ".haystack_debug/"))
|
|
|
|
|
pipe = cls(metadata=metadata, max_loops_allowed=max_loops_allowed, debug_path=debug_path)
|
|
|
|
|
components_to_reuse = kwargs.get("components", {})
|
|
|
|
|
for name, component_data in data.get("components", {}).items():
|
|
|
|
|
if name in components_to_reuse:
|
|
|
|
|
# Reuse an instance
|
|
|
|
|
instance = components_to_reuse[name]
|
|
|
|
|
else:
|
|
|
|
|
if "type" not in component_data:
|
|
|
|
|
raise PipelineError(f"Missing 'type' in component '{name}'")
|
|
|
|
|
|
|
|
|
|
if component_data["type"] not in component.registry:
|
|
|
|
|
try:
|
|
|
|
|
# Import the module first...
|
|
|
|
|
module, _ = component_data["type"].rsplit(".", 1)
|
|
|
|
|
logger.debug("Trying to import module {module_name}", module_name=module)
|
|
|
|
|
importlib.import_module(module)
|
|
|
|
|
# ...then try again
|
|
|
|
|
if component_data["type"] not in component.registry:
|
|
|
|
|
raise PipelineError(
|
|
|
|
|
f"Successfully imported module {module} but can't find it in the component registry."
|
|
|
|
|
"This is unexpected and most likely a bug."
|
|
|
|
|
)
|
|
|
|
|
except (ImportError, PipelineError) as e:
|
|
|
|
|
raise PipelineError(f"Component '{component_data['type']}' not imported.") from e
|
|
|
|
|
|
|
|
|
|
# Create a new one
|
|
|
|
|
component_class = component.registry[component_data["type"]]
|
|
|
|
|
instance = component_from_dict(component_class, component_data, name, callbacks)
|
|
|
|
|
pipe.add_component(name=name, instance=instance)
|
|
|
|
|
|
|
|
|
|
for connection in data.get("connections", []):
|
|
|
|
|
if "sender" not in connection:
|
|
|
|
|
raise PipelineError(f"Missing sender in connection: {connection}")
|
|
|
|
|
if "receiver" not in connection:
|
|
|
|
|
raise PipelineError(f"Missing receiver in connection: {connection}")
|
|
|
|
|
pipe.connect(sender=connection["sender"], receiver=connection["receiver"])
|
|
|
|
|
|
|
|
|
|
return pipe
|
|
|
|
|
|
|
|
|
|
def dumps(self, marshaller: Marshaller = DEFAULT_MARSHALLER) -> str:
|
|
|
|
|
"""
|
|
|
|
|
Returns the string representation of this pipeline according to the format dictated by the `Marshaller` in use.
|
|
|
|
|
|
|
|
|
|
:param marshaller:
|
|
|
|
|
The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
|
|
|
|
|
:returns:
|
|
|
|
|
A string representing the pipeline.
|
|
|
|
|
"""
|
|
|
|
|
return marshaller.marshal(self.to_dict())
|
|
|
|
|
|
|
|
|
|
def dump(self, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER):
|
|
|
|
|
"""
|
|
|
|
|
Writes the string representation of this pipeline to the file-like object passed in the `fp` argument.
|
|
|
|
|
|
|
|
|
|
:param fp:
|
|
|
|
|
A file-like object ready to be written to.
|
|
|
|
|
:param marshaller:
|
|
|
|
|
The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
|
|
|
|
|
"""
|
|
|
|
|
fp.write(marshaller.marshal(self.to_dict()))
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def loads(
|
|
|
|
|
cls: Type[T],
|
|
|
|
|
data: Union[str, bytes, bytearray],
|
|
|
|
|
marshaller: Marshaller = DEFAULT_MARSHALLER,
|
|
|
|
|
callbacks: Optional[DeserializationCallbacks] = None,
|
|
|
|
|
) -> T:
|
|
|
|
|
"""
|
|
|
|
|
Creates a `Pipeline` object from the string representation passed in the `data` argument.
|
|
|
|
|
|
|
|
|
|
:param data:
|
|
|
|
|
The string representation of the pipeline, can be `str`, `bytes` or `bytearray`.
|
|
|
|
|
:param marshaller:
|
|
|
|
|
The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
|
|
|
|
|
:param callbacks:
|
|
|
|
|
Callbacks to invoke during deserialization.
|
|
|
|
|
:returns:
|
|
|
|
|
A `Pipeline` object.
|
|
|
|
|
"""
|
|
|
|
|
return cls.from_dict(marshaller.unmarshal(data), callbacks)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def load(
|
|
|
|
|
cls: Type[T],
|
|
|
|
|
fp: TextIO,
|
|
|
|
|
marshaller: Marshaller = DEFAULT_MARSHALLER,
|
|
|
|
|
callbacks: Optional[DeserializationCallbacks] = None,
|
|
|
|
|
) -> T:
|
|
|
|
|
"""
|
|
|
|
|
Creates a `Pipeline` object a string representation.
|
|
|
|
|
|
|
|
|
|
The string representation is read from the file-like object passed in the `fp` argument.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:param fp:
|
|
|
|
|
A file-like object ready to be read from.
|
|
|
|
|
:param marshaller:
|
|
|
|
|
The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
|
|
|
|
|
:param callbacks:
|
|
|
|
|
Callbacks to invoke during deserialization.
|
|
|
|
|
:returns:
|
|
|
|
|
A `Pipeline` object.
|
|
|
|
|
"""
|
|
|
|
|
return cls.from_dict(marshaller.unmarshal(fp.read()), callbacks)
|
|
|
|
|
|
|
|
|
|
def add_component(self, name: str, instance: Component) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Add the given component to the pipeline.
|
|
|
|
|
|
|
|
|
|
Components are not connected to anything by default: use `Pipeline.connect()` to connect components together.
|
|
|
|
|
Component names must be unique, but component instances can be reused if needed.
|
|
|
|
|
|
|
|
|
|
:param name:
|
|
|
|
|
The name of the component to add.
|
|
|
|
|
:param instance:
|
|
|
|
|
The component instance to add.
|
|
|
|
|
|
|
|
|
|
:raises ValueError:
|
|
|
|
|
If a component with the same name already exists.
|
|
|
|
|
:raises PipelineValidationError:
|
|
|
|
|
If the given instance is not a Canals component.
|
|
|
|
|
"""
|
|
|
|
|
# Component names are unique
|
|
|
|
|
if name in self.graph.nodes:
|
|
|
|
|
raise ValueError(f"A component named '{name}' already exists in this pipeline: choose another name.")
|
|
|
|
|
|
|
|
|
|
# Components can't be named `_debug`
|
|
|
|
|
if name == "_debug":
|
|
|
|
|
raise ValueError("'_debug' is a reserved name for debug output. Choose another name.")
|
|
|
|
|
|
|
|
|
|
# Component instances must be components
|
|
|
|
|
if not isinstance(instance, Component):
|
|
|
|
|
raise PipelineValidationError(
|
|
|
|
|
f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if getattr(instance, "__haystack_added_to_pipeline__", None):
|
|
|
|
|
msg = (
|
|
|
|
|
"Component has already been added in another Pipeline. "
|
|
|
|
|
"Components can't be shared between Pipelines. Create a new instance instead."
|
|
|
|
|
)
|
|
|
|
|
raise PipelineError(msg)
|
|
|
|
|
|
|
|
|
|
setattr(instance, "__haystack_added_to_pipeline__", self)
|
|
|
|
|
|
|
|
|
|
# Add component to the graph, disconnected
|
|
|
|
|
logger.debug("Adding component '{component_name}' ({component})", component_name=name, component=instance)
|
|
|
|
|
# We're completely sure the fields exist so we ignore the type error
|
|
|
|
|
self.graph.add_node(
|
|
|
|
|
name,
|
|
|
|
|
instance=instance,
|
|
|
|
|
input_sockets=instance.__haystack_input__._sockets_dict, # type: ignore[attr-defined]
|
|
|
|
|
output_sockets=instance.__haystack_output__._sockets_dict, # type: ignore[attr-defined]
|
|
|
|
|
visits=0,
|
|
|
|
|
)
|
|
|
|
|
|
2024-06-10 14:54:07 +02:00
|
|
|
def remove_component(self, name: str) -> Component:
|
|
|
|
|
"""
|
|
|
|
|
Remove and returns component from the pipeline.
|
|
|
|
|
|
|
|
|
|
Remove an existing component from the pipeline by providing its name.
|
|
|
|
|
All edges that connect to the component will also be deleted.
|
|
|
|
|
|
|
|
|
|
:param name:
|
|
|
|
|
The name of the component to remove.
|
|
|
|
|
:returns:
|
|
|
|
|
The removed Component instance.
|
|
|
|
|
|
|
|
|
|
:raises ValueError:
|
|
|
|
|
If there is no component with that name already in the Pipeline.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# Check that a component with that name is in the Pipeline
|
|
|
|
|
try:
|
|
|
|
|
instance = self.get_component(name)
|
|
|
|
|
except ValueError as exc:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"There is no component named '{name}' in the pipeline. The valid component names are: ",
|
|
|
|
|
", ".join(n for n in self.graph.nodes),
|
|
|
|
|
) from exc
|
|
|
|
|
|
|
|
|
|
# Delete component from the graph, deleting all its connections
|
|
|
|
|
self.graph.remove_node(name)
|
|
|
|
|
|
|
|
|
|
# Reset the Component sockets' senders and receivers
|
|
|
|
|
input_sockets = instance.__haystack_input__._sockets_dict # type: ignore[attr-defined]
|
|
|
|
|
for socket in input_sockets.values():
|
|
|
|
|
socket.senders = []
|
|
|
|
|
|
|
|
|
|
output_sockets = instance.__haystack_output__._sockets_dict # type: ignore[attr-defined]
|
|
|
|
|
for socket in output_sockets.values():
|
|
|
|
|
socket.receivers = []
|
|
|
|
|
|
|
|
|
|
# Reset the Component's pipeline reference
|
|
|
|
|
setattr(instance, "__haystack_added_to_pipeline__", None)
|
|
|
|
|
|
|
|
|
|
return instance
|
|
|
|
|
|
2024-05-10 11:35:15 +02:00
|
|
|
def connect(self, sender: str, receiver: str) -> "PipelineBase":
|
|
|
|
|
"""
|
|
|
|
|
Connects two components together.
|
|
|
|
|
|
|
|
|
|
All components to connect must exist in the pipeline.
|
|
|
|
|
If connecting to a component that has several output connections, specify the inputs and output names as
|
|
|
|
|
'component_name.connections_name'.
|
|
|
|
|
|
|
|
|
|
:param sender:
|
|
|
|
|
The component that delivers the value. This can be either just a component name or can be
|
|
|
|
|
in the format `component_name.connection_name` if the component has multiple outputs.
|
|
|
|
|
:param receiver:
|
|
|
|
|
The component that receives the value. This can be either just a component name or can be
|
|
|
|
|
in the format `component_name.connection_name` if the component has multiple inputs.
|
|
|
|
|
:returns:
|
|
|
|
|
The Pipeline instance.
|
|
|
|
|
|
|
|
|
|
:raises PipelineConnectError:
|
|
|
|
|
If the two components cannot be connected (for example if one of the components is
|
|
|
|
|
not present in the pipeline, or the connections don't match by type, and so on).
|
|
|
|
|
"""
|
|
|
|
|
# Edges may be named explicitly by passing 'node_name.edge_name' to connect().
|
|
|
|
|
sender_component_name, sender_socket_name = parse_connect_string(sender)
|
|
|
|
|
receiver_component_name, receiver_socket_name = parse_connect_string(receiver)
|
|
|
|
|
|
|
|
|
|
# Get the nodes data.
|
|
|
|
|
try:
|
|
|
|
|
from_sockets = self.graph.nodes[sender_component_name]["output_sockets"]
|
|
|
|
|
except KeyError as exc:
|
|
|
|
|
raise ValueError(f"Component named {sender_component_name} not found in the pipeline.") from exc
|
|
|
|
|
try:
|
|
|
|
|
to_sockets = self.graph.nodes[receiver_component_name]["input_sockets"]
|
|
|
|
|
except KeyError as exc:
|
|
|
|
|
raise ValueError(f"Component named {receiver_component_name} not found in the pipeline.") from exc
|
|
|
|
|
|
|
|
|
|
# If the name of either socket is given, get the socket
|
|
|
|
|
sender_socket: Optional[OutputSocket] = None
|
|
|
|
|
if sender_socket_name:
|
|
|
|
|
sender_socket = from_sockets.get(sender_socket_name)
|
|
|
|
|
if not sender_socket:
|
|
|
|
|
raise PipelineConnectError(
|
|
|
|
|
f"'{sender} does not exist. "
|
|
|
|
|
f"Output connections of {sender_component_name} are: "
|
|
|
|
|
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()])
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
receiver_socket: Optional[InputSocket] = None
|
|
|
|
|
if receiver_socket_name:
|
|
|
|
|
receiver_socket = to_sockets.get(receiver_socket_name)
|
|
|
|
|
if not receiver_socket:
|
|
|
|
|
raise PipelineConnectError(
|
|
|
|
|
f"'{receiver} does not exist. "
|
|
|
|
|
f"Input connections of {receiver_component_name} are: "
|
|
|
|
|
+ ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()])
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Look for a matching connection among the possible ones.
|
|
|
|
|
# Note that if there is more than one possible connection but two sockets match by name, they're paired.
|
|
|
|
|
sender_socket_candidates: List[OutputSocket] = [sender_socket] if sender_socket else list(from_sockets.values())
|
|
|
|
|
receiver_socket_candidates: List[InputSocket] = (
|
|
|
|
|
[receiver_socket] if receiver_socket else list(to_sockets.values())
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Find all possible connections between these two components
|
|
|
|
|
possible_connections = [
|
|
|
|
|
(sender_sock, receiver_sock)
|
|
|
|
|
for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates)
|
|
|
|
|
if _types_are_compatible(sender_sock.type, receiver_sock.type)
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# We need this status for error messages, since we might need it in multiple places we calculate it here
|
|
|
|
|
status = _connections_status(
|
|
|
|
|
sender_node=sender_component_name,
|
|
|
|
|
sender_sockets=sender_socket_candidates,
|
|
|
|
|
receiver_node=receiver_component_name,
|
|
|
|
|
receiver_sockets=receiver_socket_candidates,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if not possible_connections:
|
|
|
|
|
# There's no possible connection between these two components
|
|
|
|
|
if len(sender_socket_candidates) == len(receiver_socket_candidates) == 1:
|
|
|
|
|
msg = (
|
|
|
|
|
f"Cannot connect '{sender_component_name}.{sender_socket_candidates[0].name}' with '{receiver_component_name}.{receiver_socket_candidates[0].name}': "
|
|
|
|
|
f"their declared input and output types do not match.\n{status}"
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
msg = (
|
|
|
|
|
f"Cannot connect '{sender_component_name}' with '{receiver_component_name}': "
|
|
|
|
|
f"no matching connections available.\n{status}"
|
|
|
|
|
)
|
|
|
|
|
raise PipelineConnectError(msg)
|
|
|
|
|
|
|
|
|
|
if len(possible_connections) == 1:
|
|
|
|
|
# There's only one possible connection, use it
|
|
|
|
|
sender_socket = possible_connections[0][0]
|
|
|
|
|
receiver_socket = possible_connections[0][1]
|
|
|
|
|
|
|
|
|
|
if len(possible_connections) > 1:
|
|
|
|
|
# There are multiple possible connection, let's try to match them by name
|
|
|
|
|
name_matches = [
|
|
|
|
|
(out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name
|
|
|
|
|
]
|
|
|
|
|
if len(name_matches) != 1:
|
|
|
|
|
# There's are either no matches or more than one, we can't pick one reliably
|
|
|
|
|
msg = (
|
|
|
|
|
f"Cannot connect '{sender_component_name}' with '{receiver_component_name}': more than one connection is possible "
|
|
|
|
|
"between these components. Please specify the connection name, like: "
|
|
|
|
|
f"pipeline.connect('{sender_component_name}.{possible_connections[0][0].name}', "
|
|
|
|
|
f"'{receiver_component_name}.{possible_connections[0][1].name}').\n{status}"
|
|
|
|
|
)
|
|
|
|
|
raise PipelineConnectError(msg)
|
|
|
|
|
|
|
|
|
|
# Get the only possible match
|
|
|
|
|
sender_socket = name_matches[0][0]
|
|
|
|
|
receiver_socket = name_matches[0][1]
|
|
|
|
|
|
|
|
|
|
# Connection must be valid on both sender/receiver sides
|
|
|
|
|
if not sender_socket or not receiver_socket or not sender_component_name or not receiver_component_name:
|
|
|
|
|
if sender_component_name and sender_socket:
|
|
|
|
|
sender_repr = f"{sender_component_name}.{sender_socket.name} ({_type_name(sender_socket.type)})"
|
|
|
|
|
else:
|
|
|
|
|
sender_repr = "input needed"
|
|
|
|
|
|
|
|
|
|
if receiver_component_name and receiver_socket:
|
|
|
|
|
receiver_repr = f"({_type_name(receiver_socket.type)}) {receiver_component_name}.{receiver_socket.name}"
|
|
|
|
|
else:
|
|
|
|
|
receiver_repr = "output"
|
|
|
|
|
msg = f"Connection must have both sender and receiver: {sender_repr} -> {receiver_repr}"
|
|
|
|
|
raise PipelineConnectError(msg)
|
|
|
|
|
|
|
|
|
|
logger.debug(
|
|
|
|
|
"Connecting '{sender_component}.{sender_socket_name}' to '{receiver_component}.{receiver_socket_name}'",
|
|
|
|
|
sender_component=sender_component_name,
|
|
|
|
|
sender_socket_name=sender_socket.name,
|
|
|
|
|
receiver_component=receiver_component_name,
|
|
|
|
|
receiver_socket_name=receiver_socket.name,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if receiver_component_name in sender_socket.receivers and sender_component_name in receiver_socket.senders:
|
|
|
|
|
# This is already connected, nothing to do
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
if receiver_socket.senders and not receiver_socket.is_variadic:
|
|
|
|
|
# Only variadic input sockets can receive from multiple senders
|
|
|
|
|
msg = (
|
|
|
|
|
f"Cannot connect '{sender_component_name}.{sender_socket.name}' with '{receiver_component_name}.{receiver_socket.name}': "
|
|
|
|
|
f"{receiver_component_name}.{receiver_socket.name} is already connected to {receiver_socket.senders}.\n"
|
|
|
|
|
)
|
|
|
|
|
raise PipelineConnectError(msg)
|
|
|
|
|
|
|
|
|
|
# Update the sockets with the new connection
|
|
|
|
|
sender_socket.receivers.append(receiver_component_name)
|
|
|
|
|
receiver_socket.senders.append(sender_component_name)
|
|
|
|
|
|
|
|
|
|
# Create the new connection
|
|
|
|
|
self.graph.add_edge(
|
|
|
|
|
sender_component_name,
|
|
|
|
|
receiver_component_name,
|
|
|
|
|
key=f"{sender_socket.name}/{receiver_socket.name}",
|
|
|
|
|
conn_type=_type_name(sender_socket.type),
|
|
|
|
|
from_socket=sender_socket,
|
|
|
|
|
to_socket=receiver_socket,
|
|
|
|
|
mandatory=receiver_socket.is_mandatory,
|
|
|
|
|
)
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def get_component(self, name: str) -> Component:
|
|
|
|
|
"""
|
|
|
|
|
Get the component with the specified name from the pipeline.
|
|
|
|
|
|
|
|
|
|
:param name:
|
|
|
|
|
The name of the component.
|
|
|
|
|
:returns:
|
|
|
|
|
The instance of that component.
|
|
|
|
|
|
|
|
|
|
:raises ValueError:
|
|
|
|
|
If a component with that name is not present in the pipeline.
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
return self.graph.nodes[name]["instance"]
|
|
|
|
|
except KeyError as exc:
|
|
|
|
|
raise ValueError(f"Component named {name} not found in the pipeline.") from exc
|
|
|
|
|
|
|
|
|
|
def get_component_name(self, instance: Component) -> str:
|
|
|
|
|
"""
|
|
|
|
|
Returns the name of the Component instance if it has been added to this Pipeline or an empty string otherwise.
|
|
|
|
|
|
|
|
|
|
:param instance:
|
|
|
|
|
The Component instance to look for.
|
|
|
|
|
:returns:
|
|
|
|
|
The name of the Component instance.
|
|
|
|
|
"""
|
|
|
|
|
for name, inst in self.graph.nodes(data="instance"): # type: ignore # type wrongly defined in networkx
|
|
|
|
|
if inst == instance:
|
|
|
|
|
return name
|
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
|
def inputs(self, include_components_with_connected_inputs: bool = False) -> Dict[str, Dict[str, Any]]:
|
|
|
|
|
"""
|
|
|
|
|
Returns a dictionary containing the inputs of a pipeline.
|
|
|
|
|
|
|
|
|
|
Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes
|
|
|
|
|
the input sockets of that component, including their types and whether they are optional.
|
|
|
|
|
|
|
|
|
|
:param include_components_with_connected_inputs:
|
|
|
|
|
If `False`, only components that have disconnected input edges are
|
|
|
|
|
included in the output.
|
|
|
|
|
:returns:
|
|
|
|
|
A dictionary where each key is a pipeline component name and each value is a dictionary of
|
|
|
|
|
inputs sockets of that component.
|
|
|
|
|
"""
|
|
|
|
|
inputs: Dict[str, Dict[str, Any]] = {}
|
|
|
|
|
for component_name, data in find_pipeline_inputs(self.graph, include_components_with_connected_inputs).items():
|
|
|
|
|
sockets_description = {}
|
|
|
|
|
for socket in data:
|
|
|
|
|
sockets_description[socket.name] = {"type": socket.type, "is_mandatory": socket.is_mandatory}
|
|
|
|
|
if not socket.is_mandatory:
|
|
|
|
|
sockets_description[socket.name]["default_value"] = socket.default_value
|
|
|
|
|
|
|
|
|
|
if sockets_description:
|
|
|
|
|
inputs[component_name] = sockets_description
|
|
|
|
|
return inputs
|
|
|
|
|
|
|
|
|
|
def outputs(self, include_components_with_connected_outputs: bool = False) -> Dict[str, Dict[str, Any]]:
|
|
|
|
|
"""
|
|
|
|
|
Returns a dictionary containing the outputs of a pipeline.
|
|
|
|
|
|
|
|
|
|
Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes
|
|
|
|
|
the output sockets of that component.
|
|
|
|
|
|
|
|
|
|
:param include_components_with_connected_outputs:
|
|
|
|
|
If `False`, only components that have disconnected output edges are
|
|
|
|
|
included in the output.
|
|
|
|
|
:returns:
|
|
|
|
|
A dictionary where each key is a pipeline component name and each value is a dictionary of
|
|
|
|
|
output sockets of that component.
|
|
|
|
|
"""
|
|
|
|
|
outputs = {
|
|
|
|
|
comp: {socket.name: {"type": socket.type} for socket in data}
|
|
|
|
|
for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items()
|
|
|
|
|
if data
|
|
|
|
|
}
|
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
def show(self) -> None:
|
|
|
|
|
"""
|
|
|
|
|
If running in a Jupyter notebook, display an image representing this `Pipeline`.
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
if is_in_jupyter():
|
|
|
|
|
from IPython.display import Image, display # type: ignore
|
|
|
|
|
|
|
|
|
|
image_data = _to_mermaid_image(self.graph)
|
|
|
|
|
|
|
|
|
|
display(Image(image_data))
|
|
|
|
|
else:
|
|
|
|
|
msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
|
|
|
|
|
raise PipelineDrawingError(msg)
|
|
|
|
|
|
|
|
|
|
def draw(self, path: Path) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Save an image representing this `Pipeline` to `path`.
|
|
|
|
|
|
|
|
|
|
:param path:
|
|
|
|
|
The path to save the image to.
|
|
|
|
|
"""
|
|
|
|
|
# Before drawing we edit a bit the graph, to avoid modifying the original that is
|
|
|
|
|
# used for running the pipeline we copy it.
|
|
|
|
|
image_data = _to_mermaid_image(self.graph)
|
|
|
|
|
Path(path).write_bytes(image_data)
|
|
|
|
|
|
|
|
|
|
def walk(self) -> Iterator[Tuple[str, Component]]:
|
|
|
|
|
"""
|
|
|
|
|
Visits each component in the pipeline exactly once and yields its name and instance.
|
|
|
|
|
|
|
|
|
|
No guarantees are provided on the visiting order.
|
|
|
|
|
|
|
|
|
|
:returns:
|
|
|
|
|
An iterator of tuples of component name and component instance.
|
|
|
|
|
"""
|
|
|
|
|
for component_name, instance in self.graph.nodes(data="instance"): # type: ignore # type is wrong in networkx
|
|
|
|
|
yield component_name, instance
|
|
|
|
|
|
|
|
|
|
def warm_up(self):
|
|
|
|
|
"""
|
|
|
|
|
Make sure all nodes are warm.
|
|
|
|
|
|
|
|
|
|
It's the node's responsibility to make sure this method can be called at every `Pipeline.run()`
|
|
|
|
|
without re-initializing everything.
|
|
|
|
|
"""
|
|
|
|
|
for node in self.graph.nodes:
|
|
|
|
|
if hasattr(self.graph.nodes[node]["instance"], "warm_up"):
|
|
|
|
|
logger.info("Warming up component {node}...", node=node)
|
|
|
|
|
self.graph.nodes[node]["instance"].warm_up()
|
|
|
|
|
|
|
|
|
|
    def _validate_input(self, data: Dict[str, Any]):
        """
        Validates pipeline input data.

        Validates that data:
        * Each Component name actually exists in the Pipeline
        * Each Component is not missing any input
        * Each Component has only one input per input socket, if not variadic
        * Each Component doesn't receive inputs that are already sent by another Component

        :param data:
            A dictionary of inputs for the pipeline's components. Each key is a component name.

        :raises ValueError:
            If inputs are invalid according to the above.
        """
        # First pass: validate only the components the caller explicitly provided inputs for.
        for component_name, component_inputs in data.items():
            if component_name not in self.graph.nodes:
                raise ValueError(f"Component named {component_name} not found in the pipeline.")
            instance = self.graph.nodes[component_name]["instance"]
            # A mandatory socket with no upstream senders must be fed directly by the caller.
            for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
                if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
                    raise ValueError(f"Missing input for component {component_name}: {socket_name}")
            # Reject input keys that don't correspond to any input socket of the component.
            for input_name in component_inputs.keys():
                if input_name not in instance.__haystack_input__._sockets_dict:
                    raise ValueError(f"Input {input_name} not found in component {component_name}.")

        # Second pass: walk every component in the graph, including those the caller didn't mention.
        for component_name in self.graph.nodes:
            instance = self.graph.nodes[component_name]["instance"]
            for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
                component_inputs = data.get(component_name, {})
                # Mandatory sockets must come from either a connection or the caller's data.
                if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
                    raise ValueError(f"Missing input for component {component_name}: {socket_name}")
                # Non-variadic sockets can't receive both a connection value and a caller value.
                if socket.senders and socket_name in component_inputs and not socket.is_variadic:
                    raise ValueError(
                        f"Input {socket_name} for component {component_name} is already sent by {socket.senders}."
                    )
|
|
|
|
|
|
2024-05-14 23:25:46 +02:00
|
|
|
def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
|
2024-05-10 11:35:15 +02:00
|
|
|
"""
|
|
|
|
|
Prepares input data for pipeline components.
|
|
|
|
|
|
|
|
|
|
Organizes input data for pipeline components and identifies any inputs that are not matched to any
|
2024-05-14 23:25:46 +02:00
|
|
|
component's input slots. Deep-copies data items to avoid sharing mutables across multiple components.
|
2024-05-10 11:35:15 +02:00
|
|
|
|
|
|
|
|
This method processes a flat dictionary of input data, where each key-value pair represents an input name
|
|
|
|
|
and its corresponding value. It distributes these inputs to the appropriate pipeline components based on
|
|
|
|
|
their input requirements. Inputs that don't match any component's input slots are classified as unresolved.
|
|
|
|
|
|
|
|
|
|
:param data:
|
2024-05-14 23:25:46 +02:00
|
|
|
A dictionary potentially having input names as keys and input values as values.
|
2024-05-10 11:35:15 +02:00
|
|
|
|
2024-05-14 23:25:46 +02:00
|
|
|
:returns:
|
|
|
|
|
A dictionary mapping component names to their respective matched inputs.
|
|
|
|
|
"""
|
|
|
|
|
# check whether the data is a nested dictionary of component inputs where each key is a component name
|
|
|
|
|
# and each value is a dictionary of input parameters for that component
|
|
|
|
|
is_nested_component_input = all(isinstance(value, dict) for value in data.values())
|
|
|
|
|
if not is_nested_component_input:
|
|
|
|
|
# flat input, a dict where keys are input names and values are the corresponding values
|
|
|
|
|
# we need to convert it to a nested dictionary of component inputs and then run the pipeline
|
|
|
|
|
# just like in the previous case
|
|
|
|
|
pipeline_input_data: Dict[str, Dict[str, Any]] = defaultdict(dict)
|
|
|
|
|
unresolved_kwargs = {}
|
|
|
|
|
|
|
|
|
|
# Retrieve the input slots for each component in the pipeline
|
|
|
|
|
available_inputs: Dict[str, Dict[str, Any]] = self.inputs()
|
|
|
|
|
|
|
|
|
|
# Go through all provided to distribute them to the appropriate component inputs
|
|
|
|
|
for input_name, input_value in data.items():
|
|
|
|
|
resolved_at_least_once = False
|
|
|
|
|
|
|
|
|
|
# Check each component to see if it has a slot for the current kwarg
|
|
|
|
|
for component_name, component_inputs in available_inputs.items():
|
|
|
|
|
if input_name in component_inputs:
|
|
|
|
|
# If a match is found, add the kwarg to the component's input data
|
|
|
|
|
pipeline_input_data[component_name][input_name] = input_value
|
|
|
|
|
resolved_at_least_once = True
|
|
|
|
|
|
|
|
|
|
if not resolved_at_least_once:
|
|
|
|
|
unresolved_kwargs[input_name] = input_value
|
|
|
|
|
|
|
|
|
|
if unresolved_kwargs:
|
|
|
|
|
logger.warning(
|
|
|
|
|
"Inputs {input_keys} were not matched to any component inputs, please check your run parameters.",
|
|
|
|
|
input_keys=list(unresolved_kwargs.keys()),
|
|
|
|
|
)
|
2024-05-10 11:35:15 +02:00
|
|
|
|
2024-05-14 23:25:46 +02:00
|
|
|
data = dict(pipeline_input_data)
|
2024-05-10 11:35:15 +02:00
|
|
|
|
2024-05-14 23:25:46 +02:00
|
|
|
# deepcopying the inputs prevents the Pipeline run logic from being altered unexpectedly
|
|
|
|
|
# when the same input reference is passed to multiple components.
|
|
|
|
|
for component_name, component_inputs in data.items():
|
|
|
|
|
data[component_name] = {k: deepcopy(v) for k, v in component_inputs.items()}
|
2024-05-10 11:35:15 +02:00
|
|
|
|
2024-05-14 23:25:46 +02:00
|
|
|
return data
|
2024-05-10 11:35:15 +02:00
|
|
|
|
2024-05-14 23:25:46 +02:00
|
|
|
def _init_inputs_state(self, data: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
|
|
|
|
for component_name, component_inputs in data.items():
|
|
|
|
|
if component_name not in self.graph.nodes:
|
|
|
|
|
# This is not a component name, it must be the name of one or more input sockets.
|
|
|
|
|
# Those are handled in a different way, so we skip them here.
|
|
|
|
|
continue
|
|
|
|
|
instance = self.graph.nodes[component_name]["instance"]
|
|
|
|
|
for component_input, input_value in component_inputs.items():
|
|
|
|
|
# Handle mutable input data
|
|
|
|
|
data[component_name][component_input] = copy(input_value)
|
|
|
|
|
if instance.__haystack_input__._sockets_dict[component_input].is_variadic:
|
|
|
|
|
# Components that have variadic inputs need to receive lists as input.
|
|
|
|
|
# We don't want to force the user to always pass lists, so we convert single values to lists here.
|
|
|
|
|
# If it's already a list we assume the component takes a variadic input of lists, so we
|
|
|
|
|
# convert it in any case.
|
|
|
|
|
data[component_name][component_input] = [input_value]
|
|
|
|
|
|
|
|
|
|
return {**data}
|
|
|
|
|
|
2024-06-06 15:19:07 +02:00
|
|
|
def _init_to_run(self, pipeline_inputs: Dict[str, Any]) -> List[Tuple[str, Component]]:
    """
    Selects the components that are able to run as soon as the Pipeline starts.

    A component qualifies when it declares no input sockets at all, when the user supplied
    inputs for it directly, or when at least one of its sockets is unconnected or variadic.
    """

    def _runnable_now(name: str, instance) -> bool:
        # One-line purpose: decide whether `instance` can be scheduled immediately.
        sockets = instance.__haystack_input__._sockets_dict
        if not sockets:
            # Component has no input, can run right away
            return True
        if name in pipeline_inputs:
            # This component is in the input data, if it has enough inputs it can run right away
            return True
        # Component has at least one input not connected or is variadic, can run right away.
        return any(not sock.senders or sock.is_variadic for sock in sockets.values())

    to_run: List[Tuple[str, Component]] = []
    for name in self.graph.nodes:
        instance = self.graph.nodes[name]["instance"]
        if _runnable_now(name, instance):
            to_run.append((name, instance))
    return to_run
|
2024-05-10 11:35:15 +02:00
|
|
|
|
|
|
|
|
@classmethod
def from_template(
    cls, predefined_pipeline: PredefinedPipeline, template_params: Optional[Dict[str, Any]] = None
) -> "PipelineBase":
    """
    Create a Pipeline from a predefined template. See `PredefinedPipeline` for available options.

    :param predefined_pipeline:
        The predefined pipeline to use.
    :param template_params:
        An optional dictionary of parameters to use when rendering the pipeline template.
    :returns:
        An instance of `Pipeline`.
    :raises PipelineUnmarshalError:
        If the rendered template cannot be deserialized into a Pipeline.
    """
    tpl = PipelineTemplate.from_predefined(predefined_pipeline)
    # If tpl.render() fails, we let bubble up the original error
    rendered = tpl.render(template_params)

    # If there was a problem with the rendered version of the
    # template, we add it to the error stack for debugging
    try:
        return cls.loads(rendered)
    except Exception as e:
        msg = f"Error unmarshalling pipeline: {e}\n"
        msg += f"Source:\n{rendered}"
        # Chain the original exception explicitly (PEP 3134) so the root cause is preserved
        # in tracebacks; the original code dropped the explicit `from e`.
        raise PipelineUnmarshalError(msg) from e
|
|
|
|
|
|
2024-05-14 23:25:46 +02:00
|
|
|
def _init_graph(self):
|
|
|
|
|
"""Resets the visits count for each component"""
|
|
|
|
|
for node in self.graph.nodes:
|
|
|
|
|
self.graph.nodes[node]["visits"] = 0
|
|
|
|
|
|
2024-05-10 11:35:15 +02:00
|
|
|
|
|
|
|
|
def _connections_status(
    sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
):
    """
    Lists the status of the sockets, for error messages.
    """
    # One entry per sender socket: name and type.
    sender_sockets_list = "\n".join(
        f" - {socket.name}: {_type_name(socket.type)}" for socket in sender_sockets
    )

    # One entry per receiver socket: name, type, and whether it is already fed by other components.
    receiver_entries = []
    for socket in receiver_sockets:
        sender_status = f"sent by {','.join(socket.senders)}" if socket.senders else "available"
        receiver_entries.append(f" - {socket.name}: {_type_name(socket.type)} ({sender_status})")
    receiver_sockets_list = "\n".join(receiver_entries)

    return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}"
|