From ec0e22265a330e2f9b3b5f2f7f0a161bf2739d8e Mon Sep 17 00:00:00 2001 From: Madeesh Kannan Date: Wed, 24 Apr 2024 12:27:18 +0200 Subject: [PATCH] feat: Expand `Pipeline.inputs` and `Pipeline.outputs` to include connected sockets (#7586) --- haystack/core/pipeline/descriptions.py | 27 +++++++++++++------ haystack/core/pipeline/pipeline.py | 14 +++++++--- ...io-connected-sockets-db862d045944f788.yaml | 4 +++ test/core/pipeline/test_pipeline.py | 25 +++++++++++++++++ 4 files changed, 58 insertions(+), 12 deletions(-) create mode 100644 releasenotes/notes/pipeline-io-connected-sockets-db862d045944f788.yaml diff --git a/haystack/core/pipeline/descriptions.py b/haystack/core/pipeline/descriptions.py index 00a6bc25d..f74136f83 100644 --- a/haystack/core/pipeline/descriptions.py +++ b/haystack/core/pipeline/descriptions.py @@ -12,24 +12,35 @@ from haystack.core.type_utils import _type_name logger = logging.getLogger(__name__) -def find_pipeline_inputs(graph: networkx.MultiDiGraph) -> Dict[str, List[InputSocket]]: +def find_pipeline_inputs( + graph: networkx.MultiDiGraph, include_connected_sockets: bool = False +) -> Dict[str, List[InputSocket]]: """ - Collect components that have disconnected input sockets. - - Note that this method returns *ALL* disconnected input sockets, including all such sockets with default values. + Collect components that have disconnected/connected input sockets. Note that this method returns *ALL* + disconnected input sockets, including all such sockets with default values. """ return { - name: [socket for socket in data.get("input_sockets", {}).values() if not socket.senders or socket.is_variadic] + name: [ + socket + for socket in data.get("input_sockets", {}).values() + if socket.is_variadic or (include_connected_sockets or not socket.senders) + ] for name, data in graph.nodes(data=True) } -def find_pipeline_outputs(graph: networkx.MultiDiGraph) -> Dict[str, List[OutputSocket]]: +def find_pipeline_outputs( + graph: networkx.MultiDiGraph, include_connected_sockets: bool = False +) -> Dict[str, List[OutputSocket]]: """ - Collect components that have disconnected output sockets. They define the pipeline output. + Collect components that have disconnected/connected output sockets. They define the pipeline output. """ return { - name: [socket for socket in data.get("output_sockets", {}).values() if not socket.receivers] + name: [ + socket + for socket in data.get("output_sockets", {}).values() + if (include_connected_sockets or not socket.receivers) + ] for name, data in graph.nodes(data=True) } diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index b72914818..ae09add8c 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -505,19 +505,22 @@ class Pipeline: return name return "" - def inputs(self) -> Dict[str, Dict[str, Any]]: + def inputs(self, include_components_with_connected_inputs: bool = False) -> Dict[str, Dict[str, Any]]: """ Returns a dictionary containing the inputs of a pipeline. Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes the input sockets of that component, including their types and whether they are optional. + :param include_components_with_connected_inputs: + If `False`, only components that have disconnected input edges are + included in the output. :returns: A dictionary where each key is a pipeline component name and each value is a dictionary of inputs sockets of that component. """ inputs: Dict[str, Dict[str, Any]] = {} - for component_name, data in find_pipeline_inputs(self.graph).items(): + for component_name, data in find_pipeline_inputs(self.graph, include_components_with_connected_inputs).items(): sockets_description = {} for socket in data: sockets_description[socket.name] = {"type": socket.type, "is_mandatory": socket.is_mandatory} @@ -528,20 +531,23 @@ class Pipeline: inputs[component_name] = sockets_description return inputs - def outputs(self) -> Dict[str, Dict[str, Any]]: + def outputs(self, include_components_with_connected_outputs: bool = False) -> Dict[str, Dict[str, Any]]: """ Returns a dictionary containing the outputs of a pipeline. Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes the output sockets of that component. + :param include_components_with_connected_outputs: + If `False`, only components that have disconnected output edges are + included in the output. :returns: A dictionary where each key is a pipeline component name and each value is a dictionary of output sockets of that component. """ outputs = { comp: {socket.name: {"type": socket.type} for socket in data} - for comp, data in find_pipeline_outputs(self.graph).items() + for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items() if data } return outputs diff --git a/releasenotes/notes/pipeline-io-connected-sockets-db862d045944f788.yaml b/releasenotes/notes/pipeline-io-connected-sockets-db862d045944f788.yaml new file mode 100644 index 000000000..1ccfa6332 --- /dev/null +++ b/releasenotes/notes/pipeline-io-connected-sockets-db862d045944f788.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + `Pipeline.inputs` and `Pipeline.outputs` can optionally include components input/output sockets that are connected. diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index 1e34a9292..3e00b2faf 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -673,6 +673,9 @@ def test_describe_input_only_no_inputs_components(): p.connect("a.x", "c.x") p.connect("b.y", "c.y") assert p.inputs() == {} + assert p.inputs(include_components_with_connected_inputs=True) == { + "c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}} + } def test_describe_input_some_components_with_no_inputs(): @@ -686,6 +689,10 @@ def test_describe_input_some_components_with_no_inputs(): p.connect("a.x", "c.x") p.connect("b.y", "c.y") assert p.inputs() == {"b": {"y": {"type": int, "is_mandatory": True}}} + assert p.inputs(include_components_with_connected_inputs=True) == { + "b": {"y": {"type": int, "is_mandatory": True}}, + "c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}}, + } def test_describe_input_all_components_have_inputs(): @@ -702,6 +709,11 @@ def test_describe_input_all_components_have_inputs(): "a": {"x": {"type": Optional[int], "is_mandatory": True}}, "b": {"y": {"type": int, "is_mandatory": True}}, } + assert p.inputs(include_components_with_connected_inputs=True) == { + "a": {"x": {"type": Optional[int], "is_mandatory": True}}, + "b": {"y": {"type": int, "is_mandatory": True}}, + "c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}}, + } def test_describe_output_multiple_possible(): @@ -718,6 +730,10 @@ def test_describe_output_multiple_possible(): pipe.connect("a.output_b", "b.input_b") assert pipe.outputs() == {"b": {"output_b": {"type": str}}, "a": {"output_a": {"type": str}}} + assert pipe.outputs(include_components_with_connected_outputs=True) == { + "a": {"output_a": {"type": str}, "output_b": {"type": str}}, + "b": {"output_b": {"type": str}}, + } def test_describe_output_single(): @@ -736,6 +752,11 @@ def test_describe_output_single(): p.connect("b.y", "c.y") assert p.outputs() == {"c": {"z": {"type": int}}} + assert p.outputs(include_components_with_connected_outputs=True) == { + "a": {"x": {"type": int}}, + "b": {"y": {"type": int}}, + "c": {"z": {"type": int}}, + } def test_describe_no_outputs(): @@ -753,6 +774,10 @@ def test_describe_no_outputs(): p.connect("a.x", "c.x") p.connect("b.y", "c.y") assert p.outputs() == {} + assert p.outputs(include_components_with_connected_outputs=True) == { + "a": {"x": {"type": int}}, + "b": {"y": {"type": int}}, + } def test_from_template(monkeypatch):