mirror of https://github.com/deepset-ai/haystack.git
synced 2025-12-16 17:48:19 +00:00
feat: draw/show SuperComponents in detail, expanding them to show their internal components in the visualisation diagram (#9389)
* initial import
* small fixes
* adding tests
* refactoring merge graphs
* updating tests
* docstrings
* adding release notes
* adding SuperComponent name to extended components
* adding colours and legend to different super components
* adding missed docstring parameter
* fixing tests and type checking
* Update haystack/core/pipeline/base.py
* forcing keyword arguments for draw() and show()
* adding wrapper function and a deprecation warning
* adding pylint disable - this will be removed soon
* wip
* adding a decorator function to test if another function is being called with positional arguments

Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
This commit is contained in:
parent ba41696bba
commit 3342f17f01
@@ -32,7 +32,12 @@ from haystack.core.pipeline.component_checks import (
     is_any_greedy_socket_ready,
     is_socket_lazy_variadic,
 )
-from haystack.core.pipeline.utils import FIFOPriorityQueue, _deepcopy_with_exceptions, parse_connect_string
+from haystack.core.pipeline.utils import (
+    FIFOPriorityQueue,
+    _deepcopy_with_exceptions,
+    args_deprecated,
+    parse_connect_string,
+)
 from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict
 from haystack.core.type_utils import _type_name, _types_are_compatible
 from haystack.marshal import Marshaller, YamlMarshaller
@@ -669,7 +674,14 @@ class PipelineBase:
         }
         return outputs

-    def show(self, server_url: str = "https://mermaid.ink", params: Optional[dict] = None, timeout: int = 30) -> None:
+    @args_deprecated
+    def show(
+        self,
+        server_url: str = "https://mermaid.ink",
+        params: Optional[dict] = None,
+        timeout: int = 30,
+        super_component_expansion: bool = False,
+    ) -> None:
         """
         Display an image representing this `Pipeline` in a Jupyter notebook.

@@ -698,20 +710,62 @@ class PipelineBase:
         :param timeout:
             Timeout in seconds for the request to the Mermaid server.
+
+        :param super_component_expansion:
+            If set to True and the pipeline contains SuperComponents, the diagram shows the internal structure of
+            the super-components as if their components were part of the pipeline, instead of a "black-box".
+            Otherwise, only the super-component itself is displayed.
+
         :raises PipelineDrawingError:
             If the function is called outside of a Jupyter notebook or if there is an issue with rendering.
         """
+
+        # Call the internal implementation with keyword arguments
+        self._show_internal(
+            server_url=server_url, params=params, timeout=timeout, super_component_expansion=super_component_expansion
+        )
+
+    def _show_internal(
+        self,
+        *,
+        server_url: str = "https://mermaid.ink",
+        params: Optional[dict] = None,
+        timeout: int = 30,
+        super_component_expansion: bool = False,
+    ) -> None:
+        """
+        Internal implementation of show() that uses keyword-only arguments.
+
+        ToDo: after the 2.14.0 release, make this the main function and remove the old one.
+        """
         if is_in_jupyter():
             from IPython.display import Image, display  # type: ignore

-            image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params, timeout=timeout)
+            if super_component_expansion:
+                graph, super_component_mapping = self._merge_super_component_pipelines()
+            else:
+                graph = self.graph
+                super_component_mapping = None
+
+            image_data = _to_mermaid_image(
+                graph,
+                server_url=server_url,
+                params=params,
+                timeout=timeout,
+                super_component_mapping=super_component_mapping,
+            )
             display(Image(image_data))
         else:
             msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
             raise PipelineDrawingError(msg)

-    def draw(
-        self, path: Path, server_url: str = "https://mermaid.ink", params: Optional[dict] = None, timeout: int = 30
+    @args_deprecated
+    def draw(  # pylint: disable=too-many-positional-arguments
+        self,
+        path: Path,
+        server_url: str = "https://mermaid.ink",
+        params: Optional[dict] = None,
+        timeout: int = 30,
+        super_component_expansion: bool = False,
     ) -> None:
         """
         Save an image representing this `Pipeline` to the specified file path.
@@ -720,10 +774,12 @@ class PipelineBase:

         :param path:
             The file path where the generated image will be saved.
+
         :param server_url:
             The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
             See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
             info on how to set up your own Mermaid server.
+
         :param params:
             Dictionary of customization parameters to modify the output. Refer to the Mermaid documentation for more
             details. Supported keys:
@@ -741,12 +797,53 @@ class PipelineBase:
         :param timeout:
             Timeout in seconds for the request to the Mermaid server.
+
+        :param super_component_expansion:
+            If set to True and the pipeline contains SuperComponents, the diagram shows the internal structure of
+            the super-components as if their components were part of the pipeline, instead of a "black-box".
+            Otherwise, only the super-component itself is displayed.
+
         :raises PipelineDrawingError:
             If there is an issue with rendering or saving the image.
         """
+
+        # Call the internal implementation with keyword arguments
+        self._draw_internal(
+            path=path,
+            server_url=server_url,
+            params=params,
+            timeout=timeout,
+            super_component_expansion=super_component_expansion,
+        )
+
+    def _draw_internal(
+        self,
+        *,
+        path: Path,
+        server_url: str = "https://mermaid.ink",
+        params: Optional[dict] = None,
+        timeout: int = 30,
+        super_component_expansion: bool = False,
+    ) -> None:
+        """
+        Internal implementation of draw() that uses keyword-only arguments.
+
+        ToDo: after the 2.14.0 release, make this the main function and remove the old one.
+        """
         # Before drawing we edit the graph a bit; to avoid modifying the original graph,
         # which is used for running the pipeline, we work on a copy.
-        image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params, timeout=timeout)
+        if super_component_expansion:
+            graph, super_component_mapping = self._merge_super_component_pipelines()
+        else:
+            graph = self.graph
+            super_component_mapping = None
+
+        image_data = _to_mermaid_image(
+            graph,
+            server_url=server_url,
+            params=params,
+            timeout=timeout,
+            super_component_mapping=super_component_mapping,
+        )
         Path(path).write_bytes(image_data)

     def walk(self) -> Iterator[Tuple[str, Component]]:
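With this change, expanded rendering is opt-in via the new keyword. A minimal usage sketch (the pipeline setup mirrors the tests added below; file names are illustrative, and `draw()` needs a reachable Mermaid server, mermaid.ink by default):

```python
from pathlib import Path

from haystack import Pipeline
from haystack.components.preprocessors import DocumentPreprocessor  # a SuperComponent

pipeline = Pipeline()
pipeline.add_component("preprocessor", DocumentPreprocessor())

# Default behaviour: the SuperComponent is rendered as a single "black-box" node.
pipeline.draw(path=Path("pipeline.png"))

# Expanded: the SuperComponent's internal components (splitter, cleaner) are drawn
# as part of the main pipeline, colour-coded and listed in a legend.
pipeline.draw(path=Path("pipeline_expanded.png"), super_component_expansion=True)
```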
@@ -1175,7 +1272,7 @@ class PipelineBase:
         for receiver_name, sender_socket, receiver_socket in receivers:
             # We either get the value that was produced by the actor or we use the _NO_OUTPUT_PRODUCED class to indicate
             # that the sender did not produce an output for this socket.
-            # This allows us to track if a pre-decessor already ran but did not produce an output.
+            # This allows us to track if a predecessor already ran but did not produce an output.
             value = component_outputs.get(sender_socket.name, _NO_OUTPUT_PRODUCED)

             if receiver_name not in inputs:
@@ -1239,6 +1336,99 @@ class PipelineBase:
         if candidate is not None and candidate[0] == ComponentPriority.BLOCKED:
             raise PipelineComponentsBlockedError()

+    def _find_super_components(self) -> list[tuple[str, Component]]:
+        """
+        Find all SuperComponents in the pipeline.
+
+        :returns:
+            A list of (component_name, component_instance) tuples, one per SuperComponent.
+        """
+        super_components = []
+        for comp_name, comp in self.walk():
+            # A SuperComponent has a "pipeline" attribute which is itself a Pipeline instance.
+            # We don't test against SuperComponent directly because doing so leads to circular imports.
+            if hasattr(comp, "pipeline") and isinstance(comp.pipeline, self.__class__):
+                super_components.append((comp_name, comp))
+        return super_components
+
+    def _merge_super_component_pipelines(self) -> Tuple["networkx.MultiDiGraph", Dict[str, str]]:
+        """
+        Merge the internal pipelines of SuperComponents into the main pipeline graph structure.
+
+        This creates a new networkx.MultiDiGraph containing all the components from both the main pipeline
+        and all the SuperComponents' internal pipelines. The SuperComponents are removed and their internal
+        components are connected to the corresponding input and output sockets of the main pipeline.
+
+        :returns:
+            A tuple containing:
+            - A networkx.MultiDiGraph with the expanded structure of the main pipeline and all its SuperComponents
+            - A dictionary mapping each internal component name to the name of the SuperComponent it belongs to
+        """
+        merged_graph = self.graph.copy()
+        super_component_mapping: Dict[str, str] = {}
+
+        for super_name, super_component in self._find_super_components():
+            internal_pipeline = super_component.pipeline  # type: ignore
+            internal_graph = internal_pipeline.graph.copy()
+
+            # Mark all components in the internal pipeline as being part of a SuperComponent
+            for node in internal_graph.nodes():
+                super_component_mapping[node] = super_name
+
+            # edges connected to the super component
+            incoming_edges = list(merged_graph.in_edges(super_name, data=True))
+            outgoing_edges = list(merged_graph.out_edges(super_name, data=True))
+
+            # merge the SuperComponent graph into the main graph and remove the super component node,
+            # since its components are now part of the main graph
+            merged_graph = networkx.compose(merged_graph, internal_graph)
+            merged_graph.remove_node(super_name)
+
+            # get the entry and exit points of the SuperComponent's internal pipeline
+            entry_points = [n for n in internal_graph.nodes() if internal_graph.in_degree(n) == 0]
+            exit_points = [n for n in internal_graph.nodes() if internal_graph.out_degree(n) == 0]
+
+            # connect the incoming edges to entry points
+            for sender, _, edge_data in incoming_edges:
+                sender_socket = edge_data["from_socket"]
+                for entry_point in entry_points:
+                    # find a matching input socket in the entry point
+                    entry_point_sockets = internal_graph.nodes[entry_point]["input_sockets"]
+                    for socket_name, socket in entry_point_sockets.items():
+                        if _types_are_compatible(sender_socket.type, socket.type, self._connection_type_validation):
+                            merged_graph.add_edge(
+                                sender,
+                                entry_point,
+                                key=f"{sender_socket.name}/{socket_name}",
+                                conn_type=_type_name(sender_socket.type),
+                                from_socket=sender_socket,
+                                to_socket=socket,
+                                mandatory=socket.is_mandatory,
+                            )
+
+            # connect outgoing edges from exit points
+            for _, receiver, edge_data in outgoing_edges:
+                receiver_socket = edge_data["to_socket"]
+                for exit_point in exit_points:
+                    # find a matching output socket in the exit point
+                    exit_point_sockets = internal_graph.nodes[exit_point]["output_sockets"]
+                    for socket_name, socket in exit_point_sockets.items():
+                        if _types_are_compatible(socket.type, receiver_socket.type, self._connection_type_validation):
+                            merged_graph.add_edge(
+                                exit_point,
+                                receiver,
+                                key=f"{socket_name}/{receiver_socket.name}",
+                                conn_type=_type_name(socket.type),
+                                from_socket=socket,
+                                to_socket=receiver_socket,
+                                mandatory=receiver_socket.is_mandatory,
+                            )
+
+        return merged_graph, super_component_mapping
+
+
 def _connections_status(
     sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
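The merge relies on `networkx.compose`, which unions the nodes and edges of two graphs; the black-box node is then dropped and its former neighbours are re-wired to the internal pipeline's entry (in-degree 0) and exit (out-degree 0) points. A standalone toy sketch of the pattern (illustrative names; the real method additionally matches socket types before adding an edge):

```python
import networkx

# Main graph: A -> super -> B, where "super" hides an internal pipeline x -> y.
main = networkx.MultiDiGraph()
main.add_edges_from([("A", "super"), ("super", "B")])

internal = networkx.MultiDiGraph()
internal.add_edge("x", "y")

# Remember the edges that touched the black-box node before removing it.
incoming = list(main.in_edges("super"))
outgoing = list(main.out_edges("super"))

# Union the graphs, drop the black-box node, and re-wire its neighbours
# to the internal pipeline's entry and exit points.
merged = networkx.compose(main, internal)
merged.remove_node("super")
entry_points = [n for n in internal.nodes() if internal.in_degree(n) == 0]  # ["x"]
exit_points = [n for n in internal.nodes() if internal.out_degree(n) == 0]  # ["y"]
for sender, _ in incoming:
    merged.add_edge(sender, entry_points[0])
for _, receiver in outgoing:
    merged.add_edge(exit_points[0], receiver)

print(sorted(merged.edges()))  # [('A', 'x'), ('x', 'y'), ('y', 'B')]
```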
@@ -3,9 +3,11 @@
 # SPDX-License-Identifier: Apache-2.0

 import base64
+import colorsys
 import json
+import random
 import zlib
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional

 import networkx  # type:ignore
 import requests
@@ -18,6 +20,44 @@ from haystack.core.type_utils import _type_name
 logger = logging.getLogger(__name__)


+def generate_color_variations(n: int, base_color: Optional[str] = "#3498DB", variation_range=0.4) -> List[str]:
+    """
+    Generate n different variations of a base color.
+
+    :param n: Number of variations to generate
+    :param base_color: Hex color code; the default is a shade of blue (#3498DB)
+    :param variation_range: Range for varying brightness and saturation (0-1)
+
+    :returns:
+        List of hex color codes representing variations of the base color
+    """
+    # convert hex to RGB
+    base_color = base_color.lstrip("#")  # type:ignore
+    r = int(base_color[0:2], 16) / 255.0
+    g = int(base_color[2:4], 16) / 255.0
+    b = int(base_color[4:6], 16) / 255.0
+
+    # convert RGB to HSV (hue, saturation, value)
+    h, s, v = colorsys.rgb_to_hsv(r, g, b)
+
+    variations = []
+    for _ in range(n):
+        # vary saturation and brightness within the specified range
+        new_s = max(0, min(1, s + random.uniform(-variation_range, variation_range)))
+        new_v = max(0, min(1, v + random.uniform(-variation_range, variation_range)))
+
+        # keep hue the same for color consistency
+        new_h = h
+
+        # convert back to RGB and then to hex
+        new_r, new_g, new_b = colorsys.hsv_to_rgb(new_h, new_s, new_v)
+        hex_color = "#{:02x}{:02x}{:02x}".format(int(new_r * 255), int(new_g * 255), int(new_b * 255))
+
+        variations.append(hex_color)
+
+    return variations
+
+
 def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
     """
     Add some extra nodes to show the inputs and outputs of the pipeline.
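A quick illustration of the new helper (the import path matches the module patched in the tests; exact colours vary because the helper draws from `random`):

```python
import random

from haystack.core.pipeline.draw import generate_color_variations

random.seed(0)  # only to make the illustration repeatable
print(generate_color_variations(n=3))
# three hex strings sharing the same blue hue but with varied saturation/brightness
```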
@@ -62,6 +102,7 @@ graph TD;
 {connections}

 classDef component text-align:center;
+{style_definitions}
 """
@@ -133,6 +174,7 @@ def _to_mermaid_image(
     server_url: str = "https://mermaid.ink",
     params: Optional[dict] = None,
     timeout: int = 30,
+    super_component_mapping: Optional[Dict[str, str]] = None,
 ) -> bytes:
     """
     Renders a pipeline using a Mermaid server.
@@ -162,7 +204,7 @@ def _to_mermaid_image(
     init_params = json.dumps({"theme": theme})

     # Copy the graph to avoid modifying the original
-    graph_styled = _to_mermaid_text(graph.copy(), init_params)
+    graph_styled = _to_mermaid_text(graph.copy(), init_params, super_component_mapping)
     json_string = json.dumps({"code": graph_styled})

     # Compress the JSON string with zlib (RFC 1950)
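The surrounding context (the `base64` and `zlib` imports and the RFC 1950 comment) points at how the request is built: the Mermaid code is wrapped in JSON, deflate-compressed, and base64-encoded into the URL, following the "pako" format used by mermaid.ink and the Mermaid live editor. A hedged sketch of that encoding scheme, not the exact code from this file:

```python
import base64
import json
import zlib

import requests

graph_styled = "graph TD;\nA --> B"
json_string = json.dumps({"code": graph_styled})

# Compress the JSON string with zlib (RFC 1950), then make it URL-safe
compressed = zlib.compress(json_string.encode("utf-8"), level=9)
encoded = base64.urlsafe_b64encode(compressed).decode("utf-8")

url = f"https://mermaid.ink/img/pako:{encoded}"
resp = requests.get(url, timeout=30)  # resp.content holds the rendered image on success
```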
@@ -214,12 +256,18 @@ def _to_mermaid_image(
     return resp.content


-def _to_mermaid_text(graph: networkx.MultiDiGraph, init_params: str) -> str:
+def _to_mermaid_text(
+    graph: networkx.MultiDiGraph, init_params: str, super_component_mapping: Optional[Dict[str, str]] = None
+) -> str:
     """
     Converts a Networkx graph into Mermaid syntax.

     The output of this function can be used in the documentation with `mermaid` codeblocks and will be
     automatically rendered.
+
+    :param graph: The graph to convert to Mermaid syntax
+    :param init_params: Initialization parameters for Mermaid
+    :param super_component_mapping: Mapping of component names to super component names
     """
     # Copy the graph to avoid modifying the original
     graph = _prepare_for_drawing(graph.copy())
@@ -238,11 +286,34 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph, init_params: str) -> str:
         for comp, sockets in sockets.items()
     }

-    states = {
-        comp: f'{comp}["<b>{comp}</b><br><small><i>{type(data["instance"]).__name__}{optional_inputs[comp]}</i></small>"]:::component'  # noqa
-        for comp, data in graph.nodes(data=True)
-        if comp not in ["input", "output"]
-    }
+    # Create node definitions
+    states = {}
+    super_component_components = super_component_mapping.keys() if super_component_mapping else {}
+
+    # color variations for super components
+    super_component_colors = {}
+    if super_component_components:
+        unique_super_components = set(super_component_mapping.values())  # type:ignore
+        color_variations = generate_color_variations(n=len(unique_super_components))
+        super_component_colors = dict(zip(unique_super_components, color_variations))
+
+    # Generate style definitions for each super component
+    style_definitions = []
+    for super_comp, color in super_component_colors.items():
+        style_definitions.append(f"classDef {super_comp} fill:{color},color:white;")
+
+    for comp, data in graph.nodes(data=True):
+        if comp in ["input", "output"]:
+            continue
+
+        # styling based on whether the component is part of a SuperComponent
+        if comp in super_component_components:
+            super_component_name = super_component_mapping[comp]  # type:ignore
+            style = super_component_name
+        else:
+            style = "component"
+        node_def = f'{comp}["<b>{comp}</b><br><small><i>{type(data["instance"]).__name__}{optional_inputs[comp]}</i></small>"]:::{style}'  # noqa: E501
+        states[comp] = node_def

     connections_list = []
     for from_comp, to_comp, conn_data in graph.edges(data=True):
@@ -263,7 +334,20 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph, init_params: str) -> str:
     ]
     connections = "\n".join(connections_list + input_connections + output_connections)

-    graph_styled = MERMAID_STYLED_TEMPLATE.format(params=init_params, connections=connections)
+    # Create legend
+    legend_nodes = []
+    if super_component_colors:
+        legend_nodes.append("subgraph Legend")
+        for super_comp, color in super_component_colors.items():
+            legend_id = f"legend_{super_comp}"
+            legend_nodes.append(f'{legend_id}["{super_comp}"]:::{super_comp}')
+        legend_nodes.append("end")
+        connections += "\n" + "\n".join(legend_nodes)
+
+    # Add style definitions to the template
+    graph_styled = MERMAID_STYLED_TEMPLATE.format(
+        params=init_params, connections=connections, style_definitions="\n".join(style_definitions)
+    )
     logger.debug("Mermaid diagram:\n{diagram}", diagram=graph_styled)

     return graph_styled
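Taken together, the styling path emits one `classDef` per SuperComponent and a `Legend` subgraph whose nodes reuse those classes. A standalone sketch of the generated Mermaid lines (toy names and colours; in the real code the colours come from `generate_color_variations`):

```python
super_component_colors = {"converter": "#3498db", "preprocessor": "#2c80b4"}

# One classDef per SuperComponent, mirroring the loop in _to_mermaid_text
style_definitions = [f"classDef {name} fill:{color},color:white;" for name, color in super_component_colors.items()]

# Legend subgraph whose nodes are styled with the same classes
legend_nodes = ["subgraph Legend"]
for super_comp in super_component_colors:
    legend_nodes.append(f'legend_{super_comp}["{super_comp}"]:::{super_comp}')
legend_nodes.append("end")

print("\n".join(style_definitions + legend_nodes))
# classDef converter fill:#3498db,color:white;
# classDef preprocessor fill:#2c80b4,color:white;
# subgraph Legend
# legend_converter["converter"]:::converter
# legend_preprocessor["preprocessor"]:::preprocessor
# end
```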
@@ -4,6 +4,7 @@

 import heapq
 from copy import deepcopy
+from functools import wraps
 from itertools import count
 from typing import Any, List, Optional, Tuple

@@ -163,3 +164,41 @@ class FIFOPriorityQueue:
             True if the queue contains items, False otherwise.
         """
         return bool(self._queue)
+
+
+def args_deprecated(func):
+    """
+    Decorator to warn about the use of positional arguments in a function.
+
+    Adapted from https://stackoverflow.com/questions/68432070/
+    :param func: The function to wrap.
+    """
+
+    def _positional_arg_warning() -> None:
+        """
+        Triggers a warning message if positional arguments are used in a function.
+        """
+        import warnings
+
+        msg = (
+            "Warning: In an upcoming release, this method will require keyword arguments for all parameters. "
+            "Please update your code to use keyword arguments to ensure future compatibility. "
+            "Example: pipeline.draw(path='output.png', server_url='https://custom-server.com')"
+        )
+        warnings.warn(msg, DeprecationWarning, stacklevel=2)
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # call the function first, to make sure the signature matches
+        ret_value = func(*args, **kwargs)
+
+        # A Pipeline instance is always the first argument; remove it from the args before checking
+        # for positional arguments. We compare class names as strings to avoid circular imports.
+        if args and isinstance(args, tuple) and args[0].__class__.__name__ in ["Pipeline", "PipelineBase"]:
+            args = args[1:]
+
+        if args:
+            _positional_arg_warning()
+        return ret_value
+
+    return wrapper
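A minimal sketch of the decorator applied to an ordinary function (illustrative names; on `Pipeline`/`PipelineBase` methods the wrapper additionally strips the `self` argument before checking):

```python
import warnings

from haystack.core.pipeline.utils import args_deprecated


@args_deprecated
def greet(name: str = "world", punctuation: str = "!"):
    return f"hello {name}{punctuation}"


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    greet("reader")       # positional -> DeprecationWarning
    greet(name="reader")  # keyword-only -> no warning

print(len(caught))  # 1
```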
@@ -0,0 +1,5 @@
+---
+enhancements:
+  - |
+    The `draw()` and `show()` methods of `Pipeline` now accept an extra boolean parameter, `super_component_expansion`. When set to `True` and the pipeline contains SuperComponents, the visualisation diagram
+    shows the internal structure of the super-components as if their components were part of the pipeline, rather than a "black-box" labelled with the SuperComponent's name.
@@ -94,6 +94,7 @@ comp1["<b>comp1</b><br><small><i>AddFixedValue<br><br>Optional inputs:<ul style=
 comp2["<b>comp2</b><br><small><i>Double</i></small>"]:::component -- "value -> value<br><small><i>int</i></small>" --> comp1["<b>comp1</b><br><small><i>AddFixedValue<br><br>Optional inputs:<ul style='text-align:left;'><li>add (Optional[int])</li></ul></i></small>"]:::component

 classDef component text-align:center;
+
 """
 )
@@ -175,6 +175,110 @@ class TestPipelineBase:
             pipe.draw(path=image_path)
         assert image_path.read_bytes() == mock_to_mermaid_image.return_value

+    def test_find_super_components(self):
+        """
+        Test that the pipeline can find the SuperComponents in its graph.
+        """
+        from haystack import Pipeline
+        from haystack.components.converters import MultiFileConverter
+        from haystack.components.preprocessors import DocumentPreprocessor
+        from haystack.components.writers import DocumentWriter
+        from haystack.document_stores.in_memory import InMemoryDocumentStore
+
+        multi_file_converter = MultiFileConverter()
+        doc_processor = DocumentPreprocessor()
+
+        pipeline = Pipeline()
+        pipeline.add_component("converter", multi_file_converter)
+        pipeline.add_component("preprocessor", doc_processor)
+        pipeline.add_component("writer", DocumentWriter(document_store=InMemoryDocumentStore()))
+        pipeline.connect("converter", "preprocessor")
+        pipeline.connect("preprocessor", "writer")
+
+        result = pipeline._find_super_components()
+
+        assert len(result) == 2
+        assert [("converter", multi_file_converter), ("preprocessor", doc_processor)] == result
+
+    def test_merge_super_component_pipelines(self):
+        from haystack import Pipeline
+        from haystack.components.converters import MultiFileConverter
+        from haystack.components.preprocessors import DocumentPreprocessor
+        from haystack.components.writers import DocumentWriter
+        from haystack.document_stores.in_memory import InMemoryDocumentStore
+
+        multi_file_converter = MultiFileConverter()
+        doc_processor = DocumentPreprocessor()
+
+        pipeline = Pipeline()
+        pipeline.add_component("converter", multi_file_converter)
+        pipeline.add_component("preprocessor", doc_processor)
+        pipeline.add_component("writer", DocumentWriter(document_store=InMemoryDocumentStore()))
+        pipeline.connect("converter", "preprocessor")
+        pipeline.connect("preprocessor", "writer")
+
+        merged_graph, super_component_components = pipeline._merge_super_component_pipelines()
+
+        assert super_component_components == {
+            "router": "converter",
+            "docx": "converter",
+            "html": "converter",
+            "json": "converter",
+            "md": "converter",
+            "text": "converter",
+            "pdf": "converter",
+            "pptx": "converter",
+            "xlsx": "converter",
+            "joiner": "converter",
+            "csv": "converter",
+            "splitter": "preprocessor",
+            "cleaner": "preprocessor",
+        }
+
+        expected_nodes = [
+            "cleaner",
+            "csv",
+            "docx",
+            "html",
+            "joiner",
+            "json",
+            "md",
+            "pdf",
+            "pptx",
+            "router",
+            "splitter",
+            "text",
+            "writer",
+            "xlsx",
+        ]
+        assert sorted(merged_graph.nodes) == expected_nodes
+
+        expected_edges = [
+            ("cleaner", "writer"),
+            ("csv", "joiner"),
+            ("docx", "joiner"),
+            ("html", "joiner"),
+            ("joiner", "splitter"),
+            ("json", "joiner"),
+            ("md", "joiner"),
+            ("pdf", "joiner"),
+            ("pptx", "joiner"),
+            ("router", "csv"),
+            ("router", "docx"),
+            ("router", "html"),
+            ("router", "json"),
+            ("router", "md"),
+            ("router", "pdf"),
+            ("router", "pptx"),
+            ("router", "text"),
+            ("router", "xlsx"),
+            ("splitter", "cleaner"),
+            ("text", "joiner"),
+            ("xlsx", "joiner"),
+        ]
+        actual_edges = [(u, v) for u, v, _ in merged_graph.edges]
+        assert sorted(actual_edges) == expected_edges
+
     # UNIT
     def test_add_invalid_component_name(self):
         pipe = PipelineBase()
@@ -1681,3 +1785,84 @@ class TestPipelineBase:
         consumed = PipelineBase._consume_component_inputs("test_component", component, inputs)

         assert consumed["input1"].equals(DataFrame({"a": [1, 2], "b": [1, 2]}))
+
+    @patch("haystack.core.pipeline.draw.requests")
+    def test_pipeline_draw_called_with_positional_args_triggers_a_warning(self, mock_requests):
+        """
+        Test that calling the pipeline draw method with positional arguments raises a warning.
+        """
+        from pathlib import Path
+        import warnings
+
+        pipeline = PipelineBase()
+        mock_response = mock_requests.get.return_value
+        mock_response.status_code = 200
+        mock_response.content = b"image_data"
+        out_file = Path("original_pipeline.png")
+        with warnings.catch_warnings(record=True) as w:
+            pipeline.draw(out_file, server_url="http://localhost:3000")
+            assert len(w) == 1
+            assert issubclass(w[0].category, DeprecationWarning)
+            assert (
+                "Warning: In an upcoming release, this method will require keyword arguments for all parameters"
+                in str(w[0].message)
+            )
+
+    @patch("haystack.core.pipeline.draw.requests")
+    @patch("haystack.core.pipeline.base.is_in_jupyter")
+    def test_pipeline_show_called_with_positional_args_triggers_a_warning(self, mock_is_in_jupyter, mock_requests):
+        """
+        Test that calling the pipeline show method with positional arguments raises a warning.
+        """
+        import warnings
+
+        pipeline = PipelineBase()
+        mock_response = mock_requests.get.return_value
+        mock_response.status_code = 200
+        mock_response.content = b"image_data"
+        mock_is_in_jupyter.return_value = True
+
+        with warnings.catch_warnings(record=True) as w:
+            pipeline.show("http://localhost:3000")
+            assert len(w) == 1
+            assert issubclass(w[0].category, DeprecationWarning)
+            assert (
+                "Warning: In an upcoming release, this method will require keyword arguments for all parameters"
+                in str(w[0].message)
+            )
+
+    @patch("haystack.core.pipeline.draw.requests")
+    def test_pipeline_draw_called_with_keyword_args_triggers_no_warning(self, mock_requests):
+        """
+        Test that calling the pipeline draw method with keyword arguments does not raise a warning.
+        """
+        from pathlib import Path
+        import warnings
+
+        pipeline = PipelineBase()
+        mock_response = mock_requests.get.return_value
+        mock_response.status_code = 200
+        mock_response.content = b"image_data"
+        out_file = Path("original_pipeline.png")
+
+        with warnings.catch_warnings(record=True) as w:
+            pipeline.draw(path=out_file, server_url="http://localhost:3000")
+            assert len(w) == 0, "No warning should be triggered when using keyword arguments"
+
+    @patch("haystack.core.pipeline.draw.requests")
+    @patch("haystack.core.pipeline.base.is_in_jupyter")
+    def test_pipeline_show_called_with_keyword_args_triggers_no_warning(self, mock_is_in_jupyter, mock_requests):
+        """
+        Test that calling the pipeline show method with keyword arguments does not raise a warning.
+        """
+        import warnings
+
+        pipeline = PipelineBase()
+        mock_response = mock_requests.get.return_value
+        mock_response.status_code = 200
+        mock_response.content = b"image_data"
+        mock_is_in_jupyter.return_value = True
+
+        with warnings.catch_warnings(record=True) as w:
+            pipeline.show(server_url="http://localhost:3000")
+            assert len(w) == 0, "No warning should be triggered when using keyword arguments"
@@ -4,10 +4,16 @@

 import logging
 import pytest
+import warnings

 from haystack.components.builders.prompt_builder import PromptBuilder
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
-from haystack.core.pipeline.utils import parse_connect_string, FIFOPriorityQueue, _deepcopy_with_exceptions
+from haystack.core.pipeline.utils import (
+    parse_connect_string,
+    FIFOPriorityQueue,
+    _deepcopy_with_exceptions,
+    args_deprecated,
+)
 from haystack.tools import ComponentTool, Tool

@@ -247,3 +253,51 @@ class TestDeepcopyWithFallback:
         original = {"component": comp}
         res = _deepcopy_with_exceptions(original)
         assert res["component"] is original["component"]
+
+
+class TestArgsDeprecated:
+    @pytest.fixture
+    def sample_function(self):
+        @args_deprecated
+        def sample_func(param1: str = "default1", param2: int = 42):
+            return f"{param1}-{param2}"
+
+        return sample_func
+
+    def test_warning_with_positional_args(self, sample_function):
+        # using positional arguments only
+        with warnings.catch_warnings(record=True) as w:
+            result = sample_function("test", 123)
+            assert result == "test-123"
+            assert len(w) == 1
+            assert issubclass(w[0].category, DeprecationWarning)
+            assert (
+                "Warning: In an upcoming release, this method will require keyword arguments for all parameters"
+                in str(w[0].message)
+            )
+
+    def test_warning_with_mixed_args(self, sample_function):
+        # mixing positional and keyword arguments
+        with warnings.catch_warnings(record=True) as w:
+            result = sample_function("test", param2=123)
+            assert result == "test-123"
+            assert len(w) == 1
+            assert issubclass(w[0].category, DeprecationWarning)
+            assert (
+                "Warning: In an upcoming release, this method will require keyword arguments for all parameters"
+                in str(w[0].message)
+            )
+
+    def test_no_warning_with_default_args(self, sample_function):
+        # using default arguments
+        with warnings.catch_warnings(record=True) as w:
+            result = sample_function()
+            assert result == "default1-42"
+            assert len(w) == 0
+
+    def test_no_warning_with_keyword_args(self, sample_function):
+        # using keyword arguments
+        with warnings.catch_warnings(record=True) as w:
+            result = sample_function(param1="test", param2=123)
+            assert result == "test-123"
+            assert len(w) == 0