feat: draw/show SuperComponents in detail, expanding them and showing their internal components in the visualisation diagram (#9389)
* initial import
* small fixes
* adding tests
* adding tests
* refactoring merge graphs
* updating tests
* docstrings
* adding release notes
* adding SuperComponent name to extended components
* adding colours and legend to different super components
* adding missed docstring parameter
* fixing tests and type checking
* Update haystack/core/pipeline/base.py
Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
* forcing keyword arguments for draw() and show()
* adding wrapper function and a deprecation warning
* adding pylint disable - this will be removed soon
* wip
* adding a decorator function to test if another function is being called with positional arguments
* adding a decorator function to test if another function is being called with positional arguments
---------
Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
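In practice the new flag is passed as a keyword argument to `draw()`/`show()`. A minimal usage sketch (the SuperComponents come from the tests added in this commit; the output file names are illustrative):

    from pathlib import Path

    from haystack import Pipeline
    from haystack.components.converters import MultiFileConverter
    from haystack.components.preprocessors import DocumentPreprocessor

    # A pipeline containing two SuperComponents.
    pipeline = Pipeline()
    pipeline.add_component("converter", MultiFileConverter())
    pipeline.add_component("preprocessor", DocumentPreprocessor())
    pipeline.connect("converter", "preprocessor")

    # Default: each SuperComponent is drawn as a single "black-box" node.
    pipeline.draw(path=Path("pipeline.png"))

    # Expanded: the internal components are drawn as part of the main pipeline,
    # color-coded per SuperComponent and listed in a legend.
    pipeline.draw(path=Path("pipeline_expanded.png"), super_component_expansion=True)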
This commit is contained in:
parent ba41696bba
commit 3342f17f01
@@ -32,7 +32,12 @@ from haystack.core.pipeline.component_checks import (
     is_any_greedy_socket_ready,
     is_socket_lazy_variadic,
 )
-from haystack.core.pipeline.utils import FIFOPriorityQueue, _deepcopy_with_exceptions, parse_connect_string
+from haystack.core.pipeline.utils import (
+    FIFOPriorityQueue,
+    _deepcopy_with_exceptions,
+    args_deprecated,
+    parse_connect_string,
+)
 from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict
 from haystack.core.type_utils import _type_name, _types_are_compatible
 from haystack.marshal import Marshaller, YamlMarshaller
@@ -669,7 +674,14 @@ class PipelineBase:
         }
         return outputs

-    def show(self, server_url: str = "https://mermaid.ink", params: Optional[dict] = None, timeout: int = 30) -> None:
+    @args_deprecated
+    def show(
+        self,
+        server_url: str = "https://mermaid.ink",
+        params: Optional[dict] = None,
+        timeout: int = 30,
+        super_component_expansion: bool = False,
+    ) -> None:
         """
         Display an image representing this `Pipeline` in a Jupyter notebook.

@@ -698,20 +710,62 @@ class PipelineBase:
         :param timeout:
             Timeout in seconds for the request to the Mermaid server.

+        :param super_component_expansion:
+            If set to True and the pipeline contains SuperComponents, the diagram will show the internal structure
+            of the super-components as if their components were part of the pipeline, instead of a "black-box".
+            Otherwise, only the super-component itself is displayed.
+
         :raises PipelineDrawingError:
             If the function is called outside of a Jupyter notebook or if there is an issue with rendering.
         """
+        # Call the internal implementation with keyword arguments
+        self._show_internal(
+            server_url=server_url, params=params, timeout=timeout, super_component_expansion=super_component_expansion
+        )
+
+    def _show_internal(
+        self,
+        *,
+        server_url: str = "https://mermaid.ink",
+        params: Optional[dict] = None,
+        timeout: int = 30,
+        super_component_expansion: bool = False,
+    ) -> None:
+        """
+        Internal implementation of show() that uses keyword-only arguments.
+
+        ToDo: after the 2.14.0 release make this the main function and remove the old one.
+        """
         if is_in_jupyter():
             from IPython.display import Image, display  # type: ignore

-            image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params, timeout=timeout)
+            if super_component_expansion:
+                graph, super_component_mapping = self._merge_super_component_pipelines()
+            else:
+                graph = self.graph
+                super_component_mapping = None
+
+            image_data = _to_mermaid_image(
+                graph,
+                server_url=server_url,
+                params=params,
+                timeout=timeout,
+                super_component_mapping=super_component_mapping,
+            )
             display(Image(image_data))
         else:
             msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
             raise PipelineDrawingError(msg)

-    def draw(
-        self, path: Path, server_url: str = "https://mermaid.ink", params: Optional[dict] = None, timeout: int = 30
+    @args_deprecated
+    def draw(  # pylint: disable=too-many-positional-arguments
+        self,
+        path: Path,
+        server_url: str = "https://mermaid.ink",
+        params: Optional[dict] = None,
+        timeout: int = 30,
+        super_component_expansion: bool = False,
     ) -> None:
         """
         Save an image representing this `Pipeline` to the specified file path.
@@ -720,10 +774,12 @@ class PipelineBase:

         :param path:
             The file path where the generated image will be saved.
+
         :param server_url:
             The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
             See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
             info on how to set up your own Mermaid server.
+
         :param params:
             Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for details.
             Supported keys:
@@ -741,12 +797,53 @@ class PipelineBase:
         :param timeout:
             Timeout in seconds for the request to the Mermaid server.

+        :param super_component_expansion:
+            If set to True and the pipeline contains SuperComponents, the diagram will show the internal structure
+            of the super-components as if their components were part of the pipeline, instead of a "black-box".
+            Otherwise, only the super-component itself is displayed.
+
         :raises PipelineDrawingError:
             If there is an issue with rendering or saving the image.
         """
+        # Call the internal implementation with keyword arguments
+        self._draw_internal(
+            path=path,
+            server_url=server_url,
+            params=params,
+            timeout=timeout,
+            super_component_expansion=super_component_expansion,
+        )
+
+    def _draw_internal(
+        self,
+        *,
+        path: Path,
+        server_url: str = "https://mermaid.ink",
+        params: Optional[dict] = None,
+        timeout: int = 30,
+        super_component_expansion: bool = False,
+    ) -> None:
+        """
+        Internal implementation of draw() that uses keyword-only arguments.
+
+        ToDo: after the 2.14.0 release make this the main function and remove the old one.
+        """
         # Before drawing we edit the graph a bit; to avoid modifying the original that is
         # used for running the pipeline, we copy it.
-        image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params, timeout=timeout)
+        if super_component_expansion:
+            graph, super_component_mapping = self._merge_super_component_pipelines()
+        else:
+            graph = self.graph
+            super_component_mapping = None
+
+        image_data = _to_mermaid_image(
+            graph,
+            server_url=server_url,
+            params=params,
+            timeout=timeout,
+            super_component_mapping=super_component_mapping,
+        )
         Path(path).write_bytes(image_data)

     def walk(self) -> Iterator[Tuple[str, Component]]:
@@ -1175,7 +1272,7 @@ class PipelineBase:
         for receiver_name, sender_socket, receiver_socket in receivers:
             # We either get the value that was produced by the actor or we use the _NO_OUTPUT_PRODUCED class to indicate
             # that the sender did not produce an output for this socket.
-            # This allows us to track if a pre-decessor already ran but did not produce an output.
+            # This allows us to track if a predecessor already ran but did not produce an output.
             value = component_outputs.get(sender_socket.name, _NO_OUTPUT_PRODUCED)

             if receiver_name not in inputs:
@@ -1239,6 +1336,99 @@ class PipelineBase:
         if candidate is not None and candidate[0] == ComponentPriority.BLOCKED:
             raise PipelineComponentsBlockedError()

+    def _find_super_components(self) -> list[tuple[str, Component]]:
+        """
+        Find all SuperComponents in the pipeline.
+
+        :returns:
+            List of (component_name, component_instance) tuples, one per SuperComponent.
+        """
+        super_components = []
+        for comp_name, comp in self.walk():
+            # A SuperComponent has a "pipeline" attribute which is itself a Pipeline instance.
+            # We don't test against SuperComponent directly because doing so leads to circular imports.
+            if hasattr(comp, "pipeline") and isinstance(comp.pipeline, self.__class__):
+                super_components.append((comp_name, comp))
+        return super_components
+
+    def _merge_super_component_pipelines(self) -> Tuple["networkx.MultiDiGraph", Dict[str, str]]:
+        """
+        Merge the internal pipelines of SuperComponents into the main pipeline graph structure.
+
+        This creates a new networkx.MultiDiGraph containing all the components from both the main pipeline
+        and all the SuperComponents' internal pipelines. The SuperComponents are removed and their internal
+        components are connected to the corresponding input and output sockets of the main pipeline.
+
+        :returns:
+            A tuple containing:
+            - A networkx.MultiDiGraph with the expanded structure of the main pipeline and all its SuperComponents
+            - A dictionary mapping each internal component name to the name of the SuperComponent it belongs to
+        """
+        merged_graph = self.graph.copy()
+        super_component_mapping: Dict[str, str] = {}
+
+        for super_name, super_component in self._find_super_components():
+            internal_pipeline = super_component.pipeline  # type: ignore
+            internal_graph = internal_pipeline.graph.copy()
+
+            # Mark all components in the internal pipeline as being part of a SuperComponent
+            for node in internal_graph.nodes():
+                super_component_mapping[node] = super_name
+
+            # edges connected to the super component
+            incoming_edges = list(merged_graph.in_edges(super_name, data=True))
+            outgoing_edges = list(merged_graph.out_edges(super_name, data=True))
+
+            # merge the SuperComponent graph into the main graph and remove the super component node,
+            # since its components are now part of the main graph
+            merged_graph = networkx.compose(merged_graph, internal_graph)
+            merged_graph.remove_node(super_name)
+
+            # get the entry and exit points of the SuperComponent's internal pipeline
+            entry_points = [n for n in internal_graph.nodes() if internal_graph.in_degree(n) == 0]
+            exit_points = [n for n in internal_graph.nodes() if internal_graph.out_degree(n) == 0]
+
+            # connect the incoming edges to entry points
+            for sender, _, edge_data in incoming_edges:
+                sender_socket = edge_data["from_socket"]
+                for entry_point in entry_points:
+                    # find a matching input socket in the entry point
+                    entry_point_sockets = internal_graph.nodes[entry_point]["input_sockets"]
+                    for socket_name, socket in entry_point_sockets.items():
+                        if _types_are_compatible(sender_socket.type, socket.type, self._connection_type_validation):
+                            merged_graph.add_edge(
+                                sender,
+                                entry_point,
+                                key=f"{sender_socket.name}/{socket_name}",
+                                conn_type=_type_name(sender_socket.type),
+                                from_socket=sender_socket,
+                                to_socket=socket,
+                                mandatory=socket.is_mandatory,
+                            )
+
+            # connect outgoing edges from exit points
+            for _, receiver, edge_data in outgoing_edges:
+                receiver_socket = edge_data["to_socket"]
+                for exit_point in exit_points:
+                    # find a matching output socket in the exit point
+                    exit_point_sockets = internal_graph.nodes[exit_point]["output_sockets"]
+                    for socket_name, socket in exit_point_sockets.items():
+                        if _types_are_compatible(socket.type, receiver_socket.type, self._connection_type_validation):
+                            merged_graph.add_edge(
+                                exit_point,
+                                receiver,
+                                key=f"{socket_name}/{receiver_socket.name}",
+                                conn_type=_type_name(socket.type),
+                                from_socket=socket,
+                                to_socket=receiver_socket,
+                                mandatory=receiver_socket.is_mandatory,
+                            )
+
+        return merged_graph, super_component_mapping
+

 def _connections_status(
     sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
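The merge above relies on `networkx.compose`, which unions the nodes and edges of two graphs; removing the SuperComponent node afterwards also drops its now-dangling edges, and the entry/exit points (in-degree/out-degree zero) of the internal graph are then re-wired to the edges that previously touched the removed node. A minimal sketch of the compose-and-remove step (component names are illustrative):

    import networkx

    main = networkx.MultiDiGraph()
    main.add_edge("fetcher", "super", key="out/in")
    main.add_edge("super", "writer", key="docs/documents")

    internal = networkx.MultiDiGraph()
    internal.add_edge("splitter", "cleaner", key="docs/documents")

    # Union of both graphs, then drop the black-box node together with its edges.
    merged = networkx.compose(main, internal)
    merged.remove_node("super")

    assert sorted(merged.nodes) == ["cleaner", "fetcher", "splitter", "writer"]
    assert [(u, v) for u, v, _ in merged.edges] == [("splitter", "cleaner")]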
@@ -3,9 +3,11 @@
 # SPDX-License-Identifier: Apache-2.0

 import base64
+import colorsys
 import json
+import random
 import zlib
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional

 import networkx  # type:ignore
 import requests
@@ -18,6 +20,44 @@ from haystack.core.type_utils import _type_name
 logger = logging.getLogger(__name__)


+def generate_color_variations(n: int, base_color: Optional[str] = "#3498DB", variation_range=0.4) -> List[str]:
+    """
+    Generate n different variations of a base color.
+
+    :param n: Number of variations to generate
+    :param base_color: Hex color code, default is a shade of blue (#3498DB)
+    :param variation_range: Range for varying brightness and saturation (0-1)
+
+    :returns:
+        list: List of hex color codes representing variations of the base color
+    """
+    # convert hex to RGB
+    base_color = base_color.lstrip("#")  # type:ignore
+    r = int(base_color[0:2], 16) / 255.0
+    g = int(base_color[2:4], 16) / 255.0
+    b = int(base_color[4:6], 16) / 255.0
+
+    # convert RGB to HSV (Hue, Saturation, Value)
+    h, s, v = colorsys.rgb_to_hsv(r, g, b)
+
+    variations = []
+    for _ in range(n):
+        # vary saturation and brightness within the specified range
+        new_s = max(0, min(1, s + random.uniform(-variation_range, variation_range)))
+        new_v = max(0, min(1, v + random.uniform(-variation_range, variation_range)))
+
+        # keep hue the same for color consistency
+        new_h = h
+
+        # Convert back to RGB and then to hex
+        new_r, new_g, new_b = colorsys.hsv_to_rgb(new_h, new_s, new_v)
+        hex_color = "#{:02x}{:02x}{:02x}".format(int(new_r * 255), int(new_g * 255), int(new_b * 255))
+
+        variations.append(hex_color)
+
+    return variations
+
+
 def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
     """
     Add some extra nodes to show the inputs and outputs of the pipeline.
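A quick sketch of how `generate_color_variations` behaves (the exact colors differ on every call, since the function draws from `random.uniform` without seeding):

    colors = generate_color_variations(n=3)
    assert len(colors) == 3
    # Three "#rrggbb" strings with the same hue as the base color #3498DB,
    # saturation and brightness each nudged by up to ±0.4.
    assert all(c.startswith("#") and len(c) == 7 for c in colors)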
@@ -62,6 +102,7 @@ graph TD;
 {connections}

 classDef component text-align:center;
+{style_definitions}
 """

@@ -133,6 +174,7 @@ def _to_mermaid_image(
     server_url: str = "https://mermaid.ink",
     params: Optional[dict] = None,
     timeout: int = 30,
+    super_component_mapping: Optional[Dict[str, str]] = None,
 ) -> bytes:
     """
     Renders a pipeline using a Mermaid server.
@@ -162,7 +204,7 @@ def _to_mermaid_image(
     init_params = json.dumps({"theme": theme})

     # Copy the graph to avoid modifying the original
-    graph_styled = _to_mermaid_text(graph.copy(), init_params)
+    graph_styled = _to_mermaid_text(graph.copy(), init_params, super_component_mapping)
     json_string = json.dumps({"code": graph_styled})

     # Compress the JSON string with zlib (RFC 1950)
@@ -214,12 +256,18 @@ def _to_mermaid_image(
     return resp.content


-def _to_mermaid_text(graph: networkx.MultiDiGraph, init_params: str) -> str:
+def _to_mermaid_text(
+    graph: networkx.MultiDiGraph, init_params: str, super_component_mapping: Optional[Dict[str, str]] = None
+) -> str:
     """
     Converts a Networkx graph into Mermaid syntax.

     The output of this function can be used in the documentation with `mermaid` codeblocks and will be
     automatically rendered.
+
+    :param graph: The graph to convert to Mermaid syntax
+    :param init_params: Initialization parameters for Mermaid
+    :param super_component_mapping: Mapping of component names to super component names
     """
     # Copy the graph to avoid modifying the original
     graph = _prepare_for_drawing(graph.copy())
@@ -238,11 +286,34 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph, init_params: str) -> str:
         for comp, sockets in sockets.items()
     }

-    states = {
-        comp: f'{comp}["<b>{comp}</b><br><small><i>{type(data["instance"]).__name__}{optional_inputs[comp]}</i></small>"]:::component'  # noqa
-        for comp, data in graph.nodes(data=True)
-        if comp not in ["input", "output"]
-    }
+    # Create node definitions
+    states = {}
+    super_component_components = super_component_mapping.keys() if super_component_mapping else {}
+
+    # color variations for super components
+    super_component_colors = {}
+    if super_component_components:
+        unique_super_components = set(super_component_mapping.values())  # type:ignore
+        color_variations = generate_color_variations(n=len(unique_super_components))
+        super_component_colors = dict(zip(unique_super_components, color_variations))
+
+    # Generate style definitions for each super component
+    style_definitions = []
+    for super_comp, color in super_component_colors.items():
+        style_definitions.append(f"classDef {super_comp} fill:{color},color:white;")
+
+    for comp, data in graph.nodes(data=True):
+        if comp in ["input", "output"]:
+            continue
+
+        # styling based on whether the component is part of a SuperComponent
+        if comp in super_component_components:
+            super_component_name = super_component_mapping[comp]  # type:ignore
+            style = super_component_name
+        else:
+            style = "component"
+        node_def = f'{comp}["<b>{comp}</b><br><small><i>{type(data["instance"]).__name__}{optional_inputs[comp]}</i></small>"]:::{style}'  # noqa: E501
+        states[comp] = node_def

     connections_list = []
     for from_comp, to_comp, conn_data in graph.edges(data=True):
@@ -263,7 +334,20 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph, init_params: str) -> str:
     ]
     connections = "\n".join(connections_list + input_connections + output_connections)

-    graph_styled = MERMAID_STYLED_TEMPLATE.format(params=init_params, connections=connections)
+    # Create legend
+    legend_nodes = []
+    if super_component_colors:
+        legend_nodes.append("subgraph Legend")
+        for super_comp, color in super_component_colors.items():
+            legend_id = f"legend_{super_comp}"
+            legend_nodes.append(f'{legend_id}["{super_comp}"]:::{super_comp}')
+        legend_nodes.append("end")
+        connections += "\n" + "\n".join(legend_nodes)
+
+    # Add style definitions to the template
+    graph_styled = MERMAID_STYLED_TEMPLATE.format(
+        params=init_params, connections=connections, style_definitions="\n".join(style_definitions)
+    )
     logger.debug("Mermaid diagram:\n{diagram}", diagram=graph_styled)

     return graph_styled
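For intuition, with a single SuperComponent named `converter` the legend appended to the connections would look roughly like this (illustrative output; the fill color is generated at random):

    subgraph Legend
    legend_converter["converter"]:::converter
    end

with a matching `classDef converter fill:#2f86c4,color:white;` line substituted into the template's `{style_definitions}` slot.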
@@ -4,6 +4,7 @@

 import heapq
 from copy import deepcopy
+from functools import wraps
 from itertools import count
 from typing import Any, List, Optional, Tuple

@@ -163,3 +164,41 @@ class FIFOPriorityQueue:
             True if the queue contains items, False otherwise.
         """
         return bool(self._queue)
+
+
+def args_deprecated(func):
+    """
+    Decorator to warn about the use of positional arguments in a function.
+
+    Adapted from https://stackoverflow.com/questions/68432070/
+    :param func: the function to wrap
+    """
+
+    def _positional_arg_warning() -> None:
+        """
+        Triggers a warning message if positional arguments are used in a function
+        """
+        import warnings
+
+        msg = (
+            "Warning: In an upcoming release, this method will require keyword arguments for all parameters. "
+            "Please update your code to use keyword arguments to ensure future compatibility. "
+            "Example: pipeline.draw(path='output.png', server_url='https://custom-server.com')"
+        )
+        warnings.warn(msg, DeprecationWarning, stacklevel=2)
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # call the function first, to make sure the signature matches
+        ret_value = func(*args, **kwargs)
+
+        # A Pipeline instance is always the first argument - remove it from the args to check for positional arguments
+        # We check the class name as a string to avoid circular imports
+        if args and isinstance(args, tuple) and args[0].__class__.__name__ in ["Pipeline", "PipelineBase"]:
+            args = args[1:]
+
+        if args:
+            _positional_arg_warning()
+        return ret_value
+
+    return wrapper
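The decorator also works on plain functions, which is how the tests added in this commit exercise it; a minimal sketch mirroring that test fixture:

    @args_deprecated
    def sample_func(param1: str = "default1", param2: int = 42):
        return f"{param1}-{param2}"

    sample_func(param1="test", param2=123)  # no warning: keyword arguments only
    sample_func("test", 123)                # emits a DeprecationWarning

Note that the wrapper special-cases `Pipeline`/`PipelineBase` as the first positional argument so that the bound `self` of a method does not itself trigger the warning.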
@@ -0,0 +1,5 @@
+---
+enhancements:
+  - |
+    The `draw()` and `show()` methods of `Pipeline` now accept an extra boolean parameter, `super_component_expansion`. When set to `True` and the pipeline contains SuperComponents, the visualisation diagram
+    shows the internal structure of the super-components as if their components were part of the pipeline, instead of a "black-box" with the name of the SuperComponent.
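In a Jupyter notebook the same flag applies to `show()`; a one-line sketch:

    pipeline.show(super_component_expansion=True)  # renders the expanded diagram inline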
@@ -94,6 +94,7 @@ comp1["<b>comp1</b><br><small><i>AddFixedValue<br><br>Optional inputs:<ul style=
 comp2["<b>comp2</b><br><small><i>Double</i></small>"]:::component -- "value -> value<br><small><i>int</i></small>" --> comp1["<b>comp1</b><br><small><i>AddFixedValue<br><br>Optional inputs:<ul style='text-align:left;'><li>add (Optional[int])</li></ul></i></small>"]:::component

 classDef component text-align:center;
+
 """
 )

@@ -175,6 +175,110 @@ class TestPipelineBase:
         pipe.draw(path=image_path)
         assert image_path.read_bytes() == mock_to_mermaid_image.return_value

+    def test_find_super_components(self):
+        """
+        Test that the pipeline can find the SuperComponents in its graph.
+        """
+        from haystack import Pipeline
+        from haystack.components.converters import MultiFileConverter
+        from haystack.components.preprocessors import DocumentPreprocessor
+        from haystack.components.writers import DocumentWriter
+        from haystack.document_stores.in_memory import InMemoryDocumentStore
+
+        multi_file_converter = MultiFileConverter()
+        doc_processor = DocumentPreprocessor()
+
+        pipeline = Pipeline()
+        pipeline.add_component("converter", multi_file_converter)
+        pipeline.add_component("preprocessor", doc_processor)
+        pipeline.add_component("writer", DocumentWriter(document_store=InMemoryDocumentStore()))
+        pipeline.connect("converter", "preprocessor")
+        pipeline.connect("preprocessor", "writer")
+
+        result = pipeline._find_super_components()
+
+        assert len(result) == 2
+        assert [("converter", multi_file_converter), ("preprocessor", doc_processor)] == result
+
+    def test_merge_super_component_pipelines(self):
+        from haystack import Pipeline
+        from haystack.components.converters import MultiFileConverter
+        from haystack.components.preprocessors import DocumentPreprocessor
+        from haystack.components.writers import DocumentWriter
+        from haystack.document_stores.in_memory import InMemoryDocumentStore
+
+        multi_file_converter = MultiFileConverter()
+        doc_processor = DocumentPreprocessor()
+
+        pipeline = Pipeline()
+        pipeline.add_component("converter", multi_file_converter)
+        pipeline.add_component("preprocessor", doc_processor)
+        pipeline.add_component("writer", DocumentWriter(document_store=InMemoryDocumentStore()))
+        pipeline.connect("converter", "preprocessor")
+        pipeline.connect("preprocessor", "writer")
+
+        merged_graph, super_component_components = pipeline._merge_super_component_pipelines()
+
+        assert super_component_components == {
+            "router": "converter",
+            "docx": "converter",
+            "html": "converter",
+            "json": "converter",
+            "md": "converter",
+            "text": "converter",
+            "pdf": "converter",
+            "pptx": "converter",
+            "xlsx": "converter",
+            "joiner": "converter",
+            "csv": "converter",
+            "splitter": "preprocessor",
+            "cleaner": "preprocessor",
+        }
+
+        expected_nodes = [
+            "cleaner",
+            "csv",
+            "docx",
+            "html",
+            "joiner",
+            "json",
+            "md",
+            "pdf",
+            "pptx",
+            "router",
+            "splitter",
+            "text",
+            "writer",
+            "xlsx",
+        ]
+        assert sorted(merged_graph.nodes) == expected_nodes
+
+        expected_edges = [
+            ("cleaner", "writer"),
+            ("csv", "joiner"),
+            ("docx", "joiner"),
+            ("html", "joiner"),
+            ("joiner", "splitter"),
+            ("json", "joiner"),
+            ("md", "joiner"),
+            ("pdf", "joiner"),
+            ("pptx", "joiner"),
+            ("router", "csv"),
+            ("router", "docx"),
+            ("router", "html"),
+            ("router", "json"),
+            ("router", "md"),
+            ("router", "pdf"),
+            ("router", "pptx"),
+            ("router", "text"),
+            ("router", "xlsx"),
+            ("splitter", "cleaner"),
+            ("text", "joiner"),
+            ("xlsx", "joiner"),
+        ]
+        actual_edges = [(u, v) for u, v, _ in merged_graph.edges]
+        assert sorted(actual_edges) == expected_edges
+
     # UNIT
     def test_add_invalid_component_name(self):
         pipe = PipelineBase()
@@ -1681,3 +1785,84 @@ class TestPipelineBase:
         consumed = PipelineBase._consume_component_inputs("test_component", component, inputs)

         assert consumed["input1"].equals(DataFrame({"a": [1, 2], "b": [1, 2]}))
+
+    @patch("haystack.core.pipeline.draw.requests")
+    def test_pipeline_draw_called_with_positional_args_triggers_a_warning(self, mock_requests):
+        """
+        Test that calling the pipeline draw method with positional arguments raises a warning.
+        """
+        from pathlib import Path
+        import warnings
+
+        pipeline = PipelineBase()
+        mock_response = mock_requests.get.return_value
+        mock_response.status_code = 200
+        mock_response.content = b"image_data"
+        out_file = Path("original_pipeline.png")
+        with warnings.catch_warnings(record=True) as w:
+            pipeline.draw(out_file, server_url="http://localhost:3000")
+            assert len(w) == 1
+            assert issubclass(w[0].category, DeprecationWarning)
+            assert (
+                "Warning: In an upcoming release, this method will require keyword arguments for all parameters"
+                in str(w[0].message)
+            )
+
+    @patch("haystack.core.pipeline.draw.requests")
+    @patch("haystack.core.pipeline.base.is_in_jupyter")
+    def test_pipeline_show_called_with_positional_args_triggers_a_warning(self, mock_is_in_jupyter, mock_requests):
+        """
+        Test that calling the pipeline show method with positional arguments raises a warning.
+        """
+        import warnings
+
+        pipeline = PipelineBase()
+        mock_response = mock_requests.get.return_value
+        mock_response.status_code = 200
+        mock_response.content = b"image_data"
+        mock_is_in_jupyter.return_value = True
+
+        with warnings.catch_warnings(record=True) as w:
+            pipeline.show("http://localhost:3000")
+            assert len(w) == 1
+            assert issubclass(w[0].category, DeprecationWarning)
+            assert (
+                "Warning: In an upcoming release, this method will require keyword arguments for all parameters"
+                in str(w[0].message)
+            )
+
+    @patch("haystack.core.pipeline.draw.requests")
+    def test_pipeline_draw_called_with_keyword_args_triggers_no_warning(self, mock_requests):
+        """
+        Test that calling the pipeline draw method with keyword arguments does not raise a warning.
+        """
+        from pathlib import Path
+        import warnings
+
+        pipeline = PipelineBase()
+        mock_response = mock_requests.get.return_value
+        mock_response.status_code = 200
+        mock_response.content = b"image_data"
+        out_file = Path("original_pipeline.png")
+
+        with warnings.catch_warnings(record=True) as w:
+            pipeline.draw(path=out_file, server_url="http://localhost:3000")
+            assert len(w) == 0, "No warning should be triggered when using keyword arguments"
+
+    @patch("haystack.core.pipeline.draw.requests")
+    @patch("haystack.core.pipeline.base.is_in_jupyter")
+    def test_pipeline_show_called_with_keyword_args_triggers_no_warning(self, mock_is_in_jupyter, mock_requests):
+        """
+        Test that calling the pipeline show method with keyword arguments does not raise a warning.
+        """
+        import warnings
+
+        pipeline = PipelineBase()
+        mock_response = mock_requests.get.return_value
+        mock_response.status_code = 200
+        mock_response.content = b"image_data"
+        mock_is_in_jupyter.return_value = True
+
+        with warnings.catch_warnings(record=True) as w:
+            pipeline.show(server_url="http://localhost:3000")
+            assert len(w) == 0, "No warning should be triggered when using keyword arguments"
@@ -4,10 +4,16 @@

 import logging
 import pytest
+import warnings

 from haystack.components.builders.prompt_builder import PromptBuilder
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
-from haystack.core.pipeline.utils import parse_connect_string, FIFOPriorityQueue, _deepcopy_with_exceptions
+from haystack.core.pipeline.utils import (
+    parse_connect_string,
+    FIFOPriorityQueue,
+    _deepcopy_with_exceptions,
+    args_deprecated,
+)
 from haystack.tools import ComponentTool, Tool

@@ -247,3 +253,51 @@ class TestDeepcopyWithFallback:
         original = {"component": comp}
         res = _deepcopy_with_exceptions(original)
         assert res["component"] is original["component"]
+
+
+class TestArgsDeprecated:
+    @pytest.fixture
+    def sample_function(self):
+        @args_deprecated
+        def sample_func(param1: str = "default1", param2: int = 42):
+            return f"{param1}-{param2}"
+
+        return sample_func
+
+    def test_warning_with_positional_args(self, sample_function):
+        # using positional arguments only
+        with warnings.catch_warnings(record=True) as w:
+            result = sample_function("test", 123)
+            assert result == "test-123"
+            assert len(w) == 1
+            assert issubclass(w[0].category, DeprecationWarning)
+            assert (
+                "Warning: In an upcoming release, this method will require keyword arguments for all parameters"
+                in str(w[0].message)
+            )
+
+    def test_warning_with_mixed_args(self, sample_function):
+        # mixing positional and keyword arguments
+        with warnings.catch_warnings(record=True) as w:
+            result = sample_function("test", param2=123)
+            assert result == "test-123"
+            assert len(w) == 1
+            assert issubclass(w[0].category, DeprecationWarning)
+            assert (
+                "Warning: In an upcoming release, this method will require keyword arguments for all parameters"
+                in str(w[0].message)
+            )
+
+    def test_no_warning_with_default_args(self, sample_function):
+        # using default arguments
+        with warnings.catch_warnings(record=True) as w:
+            result = sample_function()
+            assert result == "default1-42"
+            assert len(w) == 0
+
+    def test_no_warning_with_keyword_args(self, sample_function):
+        # using keyword arguments
+        with warnings.catch_warnings(record=True) as w:
+            result = sample_function(param1="test", param2=123)
+            assert result == "test-123"
+            assert len(w) == 0