feat: draw/show SuperComponents in detail, expand them and show their internal components in the visualisation diagram (#9389)

* initial import

* small fixes

* adding tests

* adding tests

* refactoring merge graphs

* updating tests

* docstrings

* adding release notes

* adding SuperComponent name to extended components

* adding colours and legend to different super components

* adding missed docstring parameter

* fixing tests and type checking

* Update haystack/core/pipeline/base.py

Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>

* forcing keyword arguments for draw() and show()

* adding wrapper function and a deprecation warning

* adding pylint disable - this will be removed soon

* wip

* adding a decorator function to test if another function is being called with positional arguments

* adding a decorator function to test if another function is being called with positional arguments

---------

Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
David S. Batista 2025-05-23 09:21:44 +01:00 committed by GitHub
parent ba41696bba
commit 3342f17f01
7 changed files with 575 additions and 17 deletions

View File

@@ -32,7 +32,12 @@ from haystack.core.pipeline.component_checks import (
is_any_greedy_socket_ready,
is_socket_lazy_variadic,
)
from haystack.core.pipeline.utils import FIFOPriorityQueue, _deepcopy_with_exceptions, parse_connect_string
from haystack.core.pipeline.utils import (
FIFOPriorityQueue,
_deepcopy_with_exceptions,
args_deprecated,
parse_connect_string,
)
from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict
from haystack.core.type_utils import _type_name, _types_are_compatible
from haystack.marshal import Marshaller, YamlMarshaller
@@ -669,7 +674,14 @@ class PipelineBase:
}
return outputs
def show(self, server_url: str = "https://mermaid.ink", params: Optional[dict] = None, timeout: int = 30) -> None:
@args_deprecated
def show(
self,
server_url: str = "https://mermaid.ink",
params: Optional[dict] = None,
timeout: int = 30,
super_component_expansion: bool = False,
) -> None:
"""
Display an image representing this `Pipeline` in a Jupyter notebook.
@@ -698,20 +710,62 @@
:param timeout:
Timeout in seconds for the request to the Mermaid server.
:param super_component_expansion:
If set to True and the pipeline contains SuperComponents, the diagram shows the internal structure of
the SuperComponents as if their components were part of the pipeline, instead of a single "black-box" node.
Otherwise, only the SuperComponent itself is displayed.
:raises PipelineDrawingError:
If the function is called outside of a Jupyter notebook or if there is an issue with rendering.
"""
# Call the internal implementation with keyword arguments
self._show_internal(
server_url=server_url, params=params, timeout=timeout, super_component_expansion=super_component_expansion
)
def _show_internal(
self,
*,
server_url: str = "https://mermaid.ink",
params: Optional[dict] = None,
timeout: int = 30,
super_component_expansion: bool = False,
) -> None:
"""
Internal implementation of show() that uses keyword-only arguments.
ToDo: after 2.14.0 release make this the main function and remove the old one.
"""
if is_in_jupyter():
from IPython.display import Image, display # type: ignore
image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params, timeout=timeout)
if super_component_expansion:
graph, super_component_mapping = self._merge_super_component_pipelines()
else:
graph = self.graph
super_component_mapping = None
image_data = _to_mermaid_image(
graph,
server_url=server_url,
params=params,
timeout=timeout,
super_component_mapping=super_component_mapping,
)
display(Image(image_data))
else:
msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
raise PipelineDrawingError(msg)
def draw(
self, path: Path, server_url: str = "https://mermaid.ink", params: Optional[dict] = None, timeout: int = 30
@args_deprecated
def draw( # pylint: disable=too-many-positional-arguments
self,
path: Path,
server_url: str = "https://mermaid.ink",
params: Optional[dict] = None,
timeout: int = 30,
super_component_expansion: bool = False,
) -> None:
"""
Save an image representing this `Pipeline` to the specified file path.
@@ -720,10 +774,12 @@
:param path:
The file path where the generated image will be saved.
:param server_url:
The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
info on how to set up your own Mermaid server.
:param params:
Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details
Supported keys:
@@ -741,12 +797,53 @@
:param timeout:
Timeout in seconds for the request to the Mermaid server.
:param super_component_expansion:
If set to True and the pipeline contains SuperComponents, the diagram shows the internal structure of
the SuperComponents as if their components were part of the pipeline, instead of a single "black-box" node.
Otherwise, only the SuperComponent itself is displayed.
:raises PipelineDrawingError:
If there is an issue with rendering or saving the image.
"""
# Call the internal implementation with keyword arguments
self._draw_internal(
path=path,
server_url=server_url,
params=params,
timeout=timeout,
super_component_expansion=super_component_expansion,
)
def _draw_internal(
self,
*,
path: Path,
server_url: str = "https://mermaid.ink",
params: Optional[dict] = None,
timeout: int = 30,
super_component_expansion: bool = False,
) -> None:
"""
Internal implementation of draw() that uses keyword-only arguments.
ToDo: after 2.14.0 release make this the main function and remove the old one.
"""
# Before drawing we edit the graph a bit. To avoid modifying the original graph
# that is used for running the pipeline, we work on a copy.
image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params, timeout=timeout)
if super_component_expansion:
graph, super_component_mapping = self._merge_super_component_pipelines()
else:
graph = self.graph
super_component_mapping = None
image_data = _to_mermaid_image(
graph,
server_url=server_url,
params=params,
timeout=timeout,
super_component_mapping=super_component_mapping,
)
Path(path).write_bytes(image_data)
def walk(self) -> Iterator[Tuple[str, Component]]:
@@ -1175,7 +1272,7 @@ class PipelineBase:
for receiver_name, sender_socket, receiver_socket in receivers:
# We either get the value that was produced by the actor or we use the _NO_OUTPUT_PRODUCED class to indicate
# that the sender did not produce an output for this socket.
# This allows us to track if a pre-decessor already ran but did not produce an output.
# This allows us to track if a predecessor already ran but did not produce an output.
value = component_outputs.get(sender_socket.name, _NO_OUTPUT_PRODUCED)
if receiver_name not in inputs:
@@ -1239,6 +1336,99 @@
if candidate is not None and candidate[0] == ComponentPriority.BLOCKED:
raise PipelineComponentsBlockedError()
def _find_super_components(self) -> list[tuple[str, Component]]:
"""
Find all SuperComponents in the pipeline.
:returns:
List of tuples containing (component_name, component_instance) representing a SuperComponent.
"""
super_components = []
for comp_name, comp in self.walk():
# a SuperComponent has a "pipeline" attribute which is itself a Pipeline instance
# we don't test against SuperComponent because doing so always leads to circular imports
if hasattr(comp, "pipeline") and isinstance(comp.pipeline, self.__class__):
super_components.append((comp_name, comp))
return super_components
def _merge_super_component_pipelines(self) -> Tuple["networkx.MultiDiGraph", Dict[str, str]]:
"""
Merge the internal pipelines of SuperComponents into the main pipeline graph structure.
This creates a new networkx.MultiDiGraph containing all the components from both the main pipeline
and all the internal SuperComponents' pipelines. The SuperComponents are removed and their internal
components are connected to corresponding input and output sockets of the main pipeline.
:returns:
A tuple containing:
- A networkx.MultiDiGraph with the expanded structure of the main pipeline and all its SuperComponents
- A dictionary mapping each internal component name to the name of the SuperComponent it belongs to
"""
merged_graph = self.graph.copy()
super_component_mapping: Dict[str, str] = {}
for super_name, super_component in self._find_super_components():
internal_pipeline = super_component.pipeline # type: ignore
internal_graph = internal_pipeline.graph.copy()
# Mark all components in the internal pipeline as being part of a SuperComponent
for node in internal_graph.nodes():
super_component_mapping[node] = super_name
# edges connected to the super component
incoming_edges = list(merged_graph.in_edges(super_name, data=True))
outgoing_edges = list(merged_graph.out_edges(super_name, data=True))
# merge the SuperComponent graph into the main graph and remove the super component node
# since its components are now part of the main graph
merged_graph = networkx.compose(merged_graph, internal_graph)
merged_graph.remove_node(super_name)
# get the entry and exit points of the SuperComponent internal pipeline
entry_points = [n for n in internal_graph.nodes() if internal_graph.in_degree(n) == 0]
exit_points = [n for n in internal_graph.nodes() if internal_graph.out_degree(n) == 0]
# connect the incoming edges to entry points
for sender, _, edge_data in incoming_edges:
sender_socket = edge_data["from_socket"]
for entry_point in entry_points:
# find a matching input socket in the entry point
entry_point_sockets = internal_graph.nodes[entry_point]["input_sockets"]
for socket_name, socket in entry_point_sockets.items():
if _types_are_compatible(sender_socket.type, socket.type, self._connection_type_validation):
merged_graph.add_edge(
sender,
entry_point,
key=f"{sender_socket.name}/{socket_name}",
conn_type=_type_name(sender_socket.type),
from_socket=sender_socket,
to_socket=socket,
mandatory=socket.is_mandatory,
)
# connect outgoing edges from exit points
for _, receiver, edge_data in outgoing_edges:
receiver_socket = edge_data["to_socket"]
for exit_point in exit_points:
# find a matching output socket in the exit point
exit_point_sockets = internal_graph.nodes[exit_point]["output_sockets"]
for socket_name, socket in exit_point_sockets.items():
if _types_are_compatible(socket.type, receiver_socket.type, self._connection_type_validation):
merged_graph.add_edge(
exit_point,
receiver,
key=f"{socket_name}/{receiver_socket.name}",
conn_type=_type_name(socket.type),
from_socket=socket,
to_socket=receiver_socket,
mandatory=receiver_socket.is_mandatory,
)
return merged_graph, super_component_mapping
def _connections_status(
sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
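Editor's note: a minimal, self-contained sketch of the merge technique used in _merge_super_component_pipelines above, assuming only networkx; the node names ("fetcher", "super", "splitter", "cleaner", "writer") and socket keys are illustrative, not from the commit.

import networkx

# Outer pipeline graph: fetcher -> super -> writer, where "super" is a SuperComponent.
outer = networkx.MultiDiGraph()
outer.add_edge("fetcher", "super", key="docs/docs")
outer.add_edge("super", "writer", key="docs/docs")

# Internal pipeline of the SuperComponent: splitter -> cleaner.
inner = networkx.MultiDiGraph()
inner.add_edge("splitter", "cleaner", key="docs/docs")

# Compose the two graphs, drop the SuperComponent node, then re-wire its old
# incoming/outgoing edges to the internal entry point (in_degree == 0) and
# exit point (out_degree == 0), as the method does with socket-type matching.
merged = networkx.compose(outer, inner)
merged.remove_node("super")
merged.add_edge("fetcher", "splitter", key="docs/docs")
merged.add_edge("cleaner", "writer", key="docs/docs")

assert sorted(merged.nodes) == ["cleaner", "fetcher", "splitter", "writer"]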

View File

@@ -3,9 +3,11 @@
# SPDX-License-Identifier: Apache-2.0
import base64
import colorsys
import json
import random
import zlib
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
import networkx # type:ignore
import requests
@@ -18,6 +20,44 @@ from haystack.core.type_utils import _type_name
logger = logging.getLogger(__name__)
def generate_color_variations(n: int, base_color: Optional[str] = "#3498DB", variation_range: float = 0.4) -> List[str]:
"""
Generate n different variations of a base color.
:param n: Number of variations to generate
:param base_color: Hex color code, default is a shade of blue (#3498DB)
:param variation_range: Range for varying brightness and saturation (0-1)
:returns:
A list of hex color codes representing variations of the base color
"""
# convert hex to RGB
base_color = base_color.lstrip("#") # type:ignore
r = int(base_color[0:2], 16) / 255.0
g = int(base_color[2:4], 16) / 255.0
b = int(base_color[4:6], 16) / 255.0
# convert RGB to HSV (Hue, Saturation, Value)
h, s, v = colorsys.rgb_to_hsv(r, g, b)
variations = []
for _ in range(n):
# vary saturation and brightness within the specified range
new_s = max(0, min(1, s + random.uniform(-variation_range, variation_range)))
new_v = max(0, min(1, v + random.uniform(-variation_range, variation_range)))
# keep hue the same for color consistency
new_h = h
# Convert back to RGB and then to hex
new_r, new_g, new_b = colorsys.hsv_to_rgb(new_h, new_s, new_v)
hex_color = "#{:02x}{:02x}{:02x}".format(int(new_r * 255), int(new_g * 255), int(new_b * 255))
variations.append(hex_color)
return variations
def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
"""
Add some extra nodes to show the inputs and outputs of the pipeline.
@@ -62,6 +102,7 @@ graph TD;
{connections}
classDef component text-align:center;
{style_definitions}
"""
@@ -133,6 +174,7 @@ def _to_mermaid_image(
server_url: str = "https://mermaid.ink",
params: Optional[dict] = None,
timeout: int = 30,
super_component_mapping: Optional[Dict[str, str]] = None,
) -> bytes:
"""
Renders a pipeline using a Mermaid server.
@@ -162,7 +204,7 @@ def _to_mermaid_image(
init_params = json.dumps({"theme": theme})
# Copy the graph to avoid modifying the original
graph_styled = _to_mermaid_text(graph.copy(), init_params)
graph_styled = _to_mermaid_text(graph.copy(), init_params, super_component_mapping)
json_string = json.dumps({"code": graph_styled})
# Compress the JSON string with zlib (RFC 1950)
@@ -214,12 +256,18 @@ def _to_mermaid_image(
return resp.content
def _to_mermaid_text(graph: networkx.MultiDiGraph, init_params: str) -> str:
def _to_mermaid_text(
graph: networkx.MultiDiGraph, init_params: str, super_component_mapping: Optional[Dict[str, str]] = None
) -> str:
"""
Converts a Networkx graph into Mermaid syntax.
The output of this function can be used in the documentation with `mermaid` codeblocks and will be
automatically rendered.
:param graph: The graph to convert to Mermaid syntax
:param init_params: Initialization parameters for Mermaid
:param super_component_mapping: Mapping of component names to super component names
"""
# Copy the graph to avoid modifying the original
graph = _prepare_for_drawing(graph.copy())
@@ -238,11 +286,34 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph, init_params: str) -> str:
for comp, sockets in sockets.items()
}
states = {
comp: f'{comp}["<b>{comp}</b><br><small><i>{type(data["instance"]).__name__}{optional_inputs[comp]}</i></small>"]:::component' # noqa
for comp, data in graph.nodes(data=True)
if comp not in ["input", "output"]
}
# Create node definitions
states = {}
super_component_components = super_component_mapping.keys() if super_component_mapping else {}
# color variations for super components
super_component_colors = {}
if super_component_components:
unique_super_components = set(super_component_mapping.values()) # type:ignore
color_variations = generate_color_variations(n=len(unique_super_components))
super_component_colors = dict(zip(unique_super_components, color_variations))
# Generate style definitions for each super component
style_definitions = []
for super_comp, color in super_component_colors.items():
style_definitions.append(f"classDef {super_comp} fill:{color},color:white;")
for comp, data in graph.nodes(data=True):
if comp in ["input", "output"]:
continue
# styling based on whether the component is a SuperComponent
if comp in super_component_components:
super_component_name = super_component_mapping[comp] # type:ignore
style = super_component_name
else:
style = "component"
node_def = f'{comp}["<b>{comp}</b><br><small><i>{type(data["instance"]).__name__}{optional_inputs[comp]}</i></small>"]:::{style}' # noqa: E501
states[comp] = node_def
connections_list = []
for from_comp, to_comp, conn_data in graph.edges(data=True):
@@ -263,7 +334,20 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph, init_params: str) -> str:
]
connections = "\n".join(connections_list + input_connections + output_connections)
graph_styled = MERMAID_STYLED_TEMPLATE.format(params=init_params, connections=connections)
# Create legend
legend_nodes = []
if super_component_colors:
legend_nodes.append("subgraph Legend")
for super_comp, color in super_component_colors.items():
legend_id = f"legend_{super_comp}"
legend_nodes.append(f'{legend_id}["{super_comp}"]:::{super_comp}')
legend_nodes.append("end")
connections += "\n" + "\n".join(legend_nodes)
# Add style definitions to the template
graph_styled = MERMAID_STYLED_TEMPLATE.format(
params=init_params, connections=connections, style_definitions="\n".join(style_definitions)
)
logger.debug("Mermaid diagram:\n{diagram}", diagram=graph_styled)
return graph_styled
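Editor's note: a hedged usage sketch for generate_color_variations as added above; the import path follows the module patched in the tests (haystack.core.pipeline.draw), and the SuperComponent names are made up.

from haystack.core.pipeline.draw import generate_color_variations

# One colour variation per SuperComponent; hue stays fixed while saturation and
# value are jittered randomly, so the exact hex values differ between runs.
names = ["converter", "preprocessor"]
colors = generate_color_variations(n=len(names))
style_definitions = [f"classDef {name} fill:{color},color:white;" for name, color in zip(names, colors)]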

View File

@@ -4,6 +4,7 @@
import heapq
from copy import deepcopy
from functools import wraps
from itertools import count
from typing import Any, List, Optional, Tuple
@@ -163,3 +164,41 @@ class FIFOPriorityQueue:
True if the queue contains items, False otherwise.
"""
return bool(self._queue)
def args_deprecated(func):
"""
Decorator to warn about the use of positional arguments in a function.
Adapted from https://stackoverflow.com/questions/68432070/
:param func:
"""
def _positional_arg_warning() -> None:
"""
Triggers a warning message if positional arguments are used in a function
"""
import warnings
msg = (
"Warning: In an upcoming release, this method will require keyword arguments for all parameters. "
"Please update your code to use keyword arguments to ensure future compatibility. "
"Example: pipeline.draw(path='output.png', server_url='https://custom-server.com')"
)
warnings.warn(msg, DeprecationWarning, stacklevel=2)
@wraps(func)
def wrapper(*args, **kwargs):
# call the function first, to make sure the signature matches
ret_value = func(*args, **kwargs)
# A Pipeline instance is always the first argument - remove it from the args to check for positional arguments
# We check the class name as strings to avoid circular imports
if args and isinstance(args, tuple) and args[0].__class__.__name__ in ["Pipeline", "PipelineBase"]:
args = args[1:]
if args:
_positional_arg_warning()
return ret_value
return wrapper
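Editor's note: a minimal sketch of the decorator in action on a plain function (the function name and parameters are illustrative); for Pipeline methods the wrapper first strips the instance argument before checking for positionals.

from haystack.core.pipeline.utils import args_deprecated

@args_deprecated
def render(path: str = "out.png", timeout: int = 30) -> str:
    return path

render("out.png")       # emits a DeprecationWarning (positional argument used;
                        # may be hidden by default warning filters outside tests)
render(path="out.png")  # no warning: keyword arguments only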

View File

@@ -0,0 +1,5 @@
---
enhancements:
- |
The `draw()` and `show()` methods of `Pipeline` now accept an extra boolean parameter, `super_component_expansion`. When set to `True` and the pipeline contains SuperComponents, the visualisation diagram
shows the internal structure of the SuperComponents as if their components were part of the pipeline, instead of a "black-box" carrying only the name of the SuperComponent.
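An editorial usage sketch for the note above (hedged; `pipeline` stands for any Pipeline that wraps SuperComponents):

pipeline.draw(path="pipeline.png", super_component_expansion=True)
pipeline.show(super_component_expansion=True)  # inside a Jupyter notebook only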

View File

@@ -94,6 +94,7 @@ comp1["<b>comp1</b><br><small><i>AddFixedValue<br><br>Optional inputs:<ul style=
comp2["<b>comp2</b><br><small><i>Double</i></small>"]:::component -- "value -> value<br><small><i>int</i></small>" --> comp1["<b>comp1</b><br><small><i>AddFixedValue<br><br>Optional inputs:<ul style='text-align:left;'><li>add (Optional[int])</li></ul></i></small>"]:::component
classDef component text-align:center;
"""
)

View File

@@ -175,6 +175,110 @@ class TestPipelineBase:
pipe.draw(path=image_path)
assert image_path.read_bytes() == mock_to_mermaid_image.return_value
def test_find_super_components(self):
"""
Test that the pipeline can find the SuperComponents it contains.
"""
from haystack import Pipeline
from haystack.components.converters import MultiFileConverter
from haystack.components.preprocessors import DocumentPreprocessor
from haystack.components.writers import DocumentWriter
from haystack.document_stores.in_memory import InMemoryDocumentStore
multi_file_converter = MultiFileConverter()
doc_processor = DocumentPreprocessor()
pipeline = Pipeline()
pipeline.add_component("converter", multi_file_converter)
pipeline.add_component("preprocessor", doc_processor)
pipeline.add_component("writer", DocumentWriter(document_store=InMemoryDocumentStore()))
pipeline.connect("converter", "preprocessor")
pipeline.connect("preprocessor", "writer")
result = pipeline._find_super_components()
assert len(result) == 2
assert [("converter", multi_file_converter), ("preprocessor", doc_processor)] == result
def test_merge_super_component_pipelines(self):
from haystack import Pipeline
from haystack.components.converters import MultiFileConverter
from haystack.components.preprocessors import DocumentPreprocessor
from haystack.components.writers import DocumentWriter
from haystack.document_stores.in_memory import InMemoryDocumentStore
multi_file_converter = MultiFileConverter()
doc_processor = DocumentPreprocessor()
pipeline = Pipeline()
pipeline.add_component("converter", multi_file_converter)
pipeline.add_component("preprocessor", doc_processor)
pipeline.add_component("writer", DocumentWriter(document_store=InMemoryDocumentStore()))
pipeline.connect("converter", "preprocessor")
pipeline.connect("preprocessor", "writer")
merged_graph, super_component_components = pipeline._merge_super_component_pipelines()
assert super_component_components == {
"router": "converter",
"docx": "converter",
"html": "converter",
"json": "converter",
"md": "converter",
"text": "converter",
"pdf": "converter",
"pptx": "converter",
"xlsx": "converter",
"joiner": "converter",
"csv": "converter",
"splitter": "preprocessor",
"cleaner": "preprocessor",
}
expected_nodes = [
"cleaner",
"csv",
"docx",
"html",
"joiner",
"json",
"md",
"pdf",
"pptx",
"router",
"splitter",
"text",
"writer",
"xlsx",
]
assert sorted(merged_graph.nodes) == expected_nodes
expected_edges = [
("cleaner", "writer"),
("csv", "joiner"),
("docx", "joiner"),
("html", "joiner"),
("joiner", "splitter"),
("json", "joiner"),
("md", "joiner"),
("pdf", "joiner"),
("pptx", "joiner"),
("router", "csv"),
("router", "docx"),
("router", "html"),
("router", "json"),
("router", "md"),
("router", "pdf"),
("router", "pptx"),
("router", "text"),
("router", "xlsx"),
("splitter", "cleaner"),
("text", "joiner"),
("xlsx", "joiner"),
]
actual_edges = [(u, v) for u, v, _ in merged_graph.edges]
assert sorted(actual_edges) == expected_edges
# UNIT
def test_add_invalid_component_name(self):
pipe = PipelineBase()
@@ -1681,3 +1785,84 @@ class TestPipelineBase:
consumed = PipelineBase._consume_component_inputs("test_component", component, inputs)
assert consumed["input1"].equals(DataFrame({"a": [1, 2], "b": [1, 2]}))
@patch("haystack.core.pipeline.draw.requests")
def test_pipeline_draw_called_with_positional_args_triggers_a_warning(self, mock_requests):
"""
Test that calling the pipeline draw method with positional arguments raises a warning.
"""
from pathlib import Path
import warnings
pipeline = PipelineBase()
mock_response = mock_requests.get.return_value
mock_response.status_code = 200
mock_response.content = b"image_data"
out_file = Path("original_pipeline.png")
with warnings.catch_warnings(record=True) as w:
pipeline.draw(out_file, server_url="http://localhost:3000")
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert (
"Warning: In an upcoming release, this method will require keyword arguments for all parameters"
in str(w[0].message)
)
@patch("haystack.core.pipeline.draw.requests")
@patch("haystack.core.pipeline.base.is_in_jupyter")
def test_pipeline_show_called_with_positional_args_triggers_a_warning(self, mock_is_in_jupyter, mock_requests):
"""
Test that calling the pipeline show method with positional arguments raises a warning.
"""
import warnings
pipeline = PipelineBase()
mock_response = mock_requests.get.return_value
mock_response.status_code = 200
mock_response.content = b"image_data"
mock_is_in_jupyter.return_value = True
with warnings.catch_warnings(record=True) as w:
pipeline.show("http://localhost:3000")
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert (
"Warning: In an upcoming release, this method will require keyword arguments for all parameters"
in str(w[0].message)
)
@patch("haystack.core.pipeline.draw.requests")
def test_pipeline_draw_called_with_keyword_args_triggers_no_warning(self, mock_requests):
"""
Test that calling the pipeline draw method with keyword arguments does not raise a warning.
"""
from pathlib import Path
import warnings
pipeline = PipelineBase()
mock_response = mock_requests.get.return_value
mock_response.status_code = 200
mock_response.content = b"image_data"
out_file = Path("original_pipeline.png")
with warnings.catch_warnings(record=True) as w:
pipeline.draw(path=out_file, server_url="http://localhost:3000")
assert len(w) == 0, "No warning should be triggered when using keyword arguments"
@patch("haystack.core.pipeline.draw.requests")
@patch("haystack.core.pipeline.base.is_in_jupyter")
def test_pipeline_show_called_with_keyword_args_triggers_no_warning(self, mock_is_in_jupyter, mock_requests):
"""
Test that calling the pipeline show method with keyword arguments does not raise a warning.
"""
import warnings
pipeline = PipelineBase()
mock_response = mock_requests.get.return_value
mock_response.status_code = 200
mock_response.content = b"image_data"
mock_is_in_jupyter.return_value = True
with warnings.catch_warnings(record=True) as w:
pipeline.show(server_url="http://localhost:3000")
assert len(w) == 0, "No warning should be triggered when using keyword arguments"

View File

@@ -4,10 +4,16 @@
import logging
import pytest
import warnings
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.core.pipeline.utils import parse_connect_string, FIFOPriorityQueue, _deepcopy_with_exceptions
from haystack.core.pipeline.utils import (
parse_connect_string,
FIFOPriorityQueue,
_deepcopy_with_exceptions,
args_deprecated,
)
from haystack.tools import ComponentTool, Tool
@@ -247,3 +253,51 @@ class TestDeepcopyWithFallback:
original = {"component": comp}
res = _deepcopy_with_exceptions(original)
assert res["component"] is original["component"]
class TestArgsDeprecated:
@pytest.fixture
def sample_function(self):
@args_deprecated
def sample_func(param1: str = "default1", param2: int = 42):
return f"{param1}-{param2}"
return sample_func
def test_warning_with_positional_args(self, sample_function):
# using positional arguments only
with warnings.catch_warnings(record=True) as w:
result = sample_function("test", 123)
assert result == "test-123"
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert (
"Warning: In an upcoming release, this method will require keyword arguments for all parameters"
in str(w[0].message)
)
def test_warning_with_mixed_args(self, sample_function):
# mixing positional and keyword arguments
with warnings.catch_warnings(record=True) as w:
result = sample_function("test", param2=123)
assert result == "test-123"
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert (
"Warning: In an upcoming release, this method will require keyword arguments for all parameters"
in str(w[0].message)
)
def test_no_warning_with_default_args(self, sample_function):
# using default arguments
with warnings.catch_warnings(record=True) as w:
result = sample_function()
assert result == "default1-42"
assert len(w) == 0
def test_no_warning_with_keyword_args(self, sample_function):
# using keyword arguments
with warnings.catch_warnings(record=True) as w:
result = sample_function(param1="test", param2=123)
assert result == "test-123"
assert len(w) == 0