feat: Expand Pipeline.inputs and Pipeline.outputs to include connected sockets (#7586)

This commit is contained in:
Madeesh Kannan 2024-04-24 12:27:18 +02:00 committed by GitHub
parent 19a46af9da
commit ec0e22265a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 58 additions and 12 deletions

View File

@ -12,24 +12,35 @@ from haystack.core.type_utils import _type_name
logger = logging.getLogger(__name__)
def find_pipeline_inputs(graph: networkx.MultiDiGraph) -> Dict[str, List[InputSocket]]:
def find_pipeline_inputs(
graph: networkx.MultiDiGraph, include_connected_sockets: bool = False
) -> Dict[str, List[InputSocket]]:
"""
Collect components that have disconnected input sockets.
Note that this method returns *ALL* disconnected input sockets, including all such sockets with default values.
Collect components that have disconnected/connected input sockets. Note that this method returns *ALL*
disconnected input sockets, including all such sockets with default values.
"""
return {
name: [socket for socket in data.get("input_sockets", {}).values() if not socket.senders or socket.is_variadic]
name: [
socket
for socket in data.get("input_sockets", {}).values()
if socket.is_variadic or (include_connected_sockets or not socket.senders)
]
for name, data in graph.nodes(data=True)
}
def find_pipeline_outputs(graph: networkx.MultiDiGraph) -> Dict[str, List[OutputSocket]]:
def find_pipeline_outputs(
graph: networkx.MultiDiGraph, include_connected_sockets: bool = False
) -> Dict[str, List[OutputSocket]]:
"""
Collect components that have disconnected output sockets. They define the pipeline output.
Collect components that have disconnected/connected output sockets. They define the pipeline output.
"""
return {
name: [socket for socket in data.get("output_sockets", {}).values() if not socket.receivers]
name: [
socket
for socket in data.get("output_sockets", {}).values()
if (include_connected_sockets or not socket.receivers)
]
for name, data in graph.nodes(data=True)
}

View File

@ -505,19 +505,22 @@ class Pipeline:
return name
return ""
def inputs(self) -> Dict[str, Dict[str, Any]]:
def inputs(self, include_components_with_connected_inputs: bool = False) -> Dict[str, Dict[str, Any]]:
"""
Returns a dictionary containing the inputs of a pipeline.
Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes
the input sockets of that component, including their types and whether they are optional.
:param include_components_with_connected_inputs:
If `False`, only components that have disconnected input edges are
included in the output.
:returns:
A dictionary where each key is a pipeline component name and each value is a dictionary of
inputs sockets of that component.
"""
inputs: Dict[str, Dict[str, Any]] = {}
for component_name, data in find_pipeline_inputs(self.graph).items():
for component_name, data in find_pipeline_inputs(self.graph, include_components_with_connected_inputs).items():
sockets_description = {}
for socket in data:
sockets_description[socket.name] = {"type": socket.type, "is_mandatory": socket.is_mandatory}
@ -528,20 +531,23 @@ class Pipeline:
inputs[component_name] = sockets_description
return inputs
def outputs(self) -> Dict[str, Dict[str, Any]]:
def outputs(self, include_components_with_connected_outputs: bool = False) -> Dict[str, Dict[str, Any]]:
"""
Returns a dictionary containing the outputs of a pipeline.
Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes
the output sockets of that component.
:param include_components_with_connected_outputs:
If `False`, only components that have disconnected output edges are
included in the output.
:returns:
A dictionary where each key is a pipeline component name and each value is a dictionary of
output sockets of that component.
"""
outputs = {
comp: {socket.name: {"type": socket.type} for socket in data}
for comp, data in find_pipeline_outputs(self.graph).items()
for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items()
if data
}
return outputs

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
`Pipeline.inputs` and `Pipeline.outputs` can optionally include components input/output sockets that are connected.

View File

@ -673,6 +673,9 @@ def test_describe_input_only_no_inputs_components():
p.connect("a.x", "c.x")
p.connect("b.y", "c.y")
assert p.inputs() == {}
assert p.inputs(include_components_with_connected_inputs=True) == {
"c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}}
}
def test_describe_input_some_components_with_no_inputs():
@ -686,6 +689,10 @@ def test_describe_input_some_components_with_no_inputs():
p.connect("a.x", "c.x")
p.connect("b.y", "c.y")
assert p.inputs() == {"b": {"y": {"type": int, "is_mandatory": True}}}
assert p.inputs(include_components_with_connected_inputs=True) == {
"b": {"y": {"type": int, "is_mandatory": True}},
"c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}},
}
def test_describe_input_all_components_have_inputs():
@ -702,6 +709,11 @@ def test_describe_input_all_components_have_inputs():
"a": {"x": {"type": Optional[int], "is_mandatory": True}},
"b": {"y": {"type": int, "is_mandatory": True}},
}
assert p.inputs(include_components_with_connected_inputs=True) == {
"a": {"x": {"type": Optional[int], "is_mandatory": True}},
"b": {"y": {"type": int, "is_mandatory": True}},
"c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}},
}
def test_describe_output_multiple_possible():
@ -718,6 +730,10 @@ def test_describe_output_multiple_possible():
pipe.connect("a.output_b", "b.input_b")
assert pipe.outputs() == {"b": {"output_b": {"type": str}}, "a": {"output_a": {"type": str}}}
assert pipe.outputs(include_components_with_connected_outputs=True) == {
"a": {"output_a": {"type": str}, "output_b": {"type": str}},
"b": {"output_b": {"type": str}},
}
def test_describe_output_single():
@ -736,6 +752,11 @@ def test_describe_output_single():
p.connect("b.y", "c.y")
assert p.outputs() == {"c": {"z": {"type": int}}}
assert p.outputs(include_components_with_connected_outputs=True) == {
"a": {"x": {"type": int}},
"b": {"y": {"type": int}},
"c": {"z": {"type": int}},
}
def test_describe_no_outputs():
@ -753,6 +774,10 @@ def test_describe_no_outputs():
p.connect("a.x", "c.x")
p.connect("b.y", "c.y")
assert p.outputs() == {}
assert p.outputs(include_components_with_connected_outputs=True) == {
"a": {"x": {"type": int}},
"b": {"y": {"type": int}},
}
def test_from_template(monkeypatch):