mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-08 13:06:29 +00:00
feat: Expand Pipeline.inputs and Pipeline.outputs to include connected sockets (#7586)
This commit is contained in:
parent
19a46af9da
commit
ec0e22265a
@ -12,24 +12,35 @@ from haystack.core.type_utils import _type_name
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_pipeline_inputs(graph: networkx.MultiDiGraph) -> Dict[str, List[InputSocket]]:
|
||||
def find_pipeline_inputs(
|
||||
graph: networkx.MultiDiGraph, include_connected_sockets: bool = False
|
||||
) -> Dict[str, List[InputSocket]]:
|
||||
"""
|
||||
Collect components that have disconnected input sockets.
|
||||
|
||||
Note that this method returns *ALL* disconnected input sockets, including all such sockets with default values.
|
||||
Collect components that have disconnected/connected input sockets. Note that this method returns *ALL*
|
||||
disconnected input sockets, including all such sockets with default values.
|
||||
"""
|
||||
return {
|
||||
name: [socket for socket in data.get("input_sockets", {}).values() if not socket.senders or socket.is_variadic]
|
||||
name: [
|
||||
socket
|
||||
for socket in data.get("input_sockets", {}).values()
|
||||
if socket.is_variadic or (include_connected_sockets or not socket.senders)
|
||||
]
|
||||
for name, data in graph.nodes(data=True)
|
||||
}
|
||||
|
||||
|
||||
def find_pipeline_outputs(graph: networkx.MultiDiGraph) -> Dict[str, List[OutputSocket]]:
|
||||
def find_pipeline_outputs(
|
||||
graph: networkx.MultiDiGraph, include_connected_sockets: bool = False
|
||||
) -> Dict[str, List[OutputSocket]]:
|
||||
"""
|
||||
Collect components that have disconnected output sockets. They define the pipeline output.
|
||||
Collect components that have disconnected/connected output sockets. They define the pipeline output.
|
||||
"""
|
||||
return {
|
||||
name: [socket for socket in data.get("output_sockets", {}).values() if not socket.receivers]
|
||||
name: [
|
||||
socket
|
||||
for socket in data.get("output_sockets", {}).values()
|
||||
if (include_connected_sockets or not socket.receivers)
|
||||
]
|
||||
for name, data in graph.nodes(data=True)
|
||||
}
|
||||
|
||||
|
||||
@ -505,19 +505,22 @@ class Pipeline:
|
||||
return name
|
||||
return ""
|
||||
|
||||
def inputs(self) -> Dict[str, Dict[str, Any]]:
|
||||
def inputs(self, include_components_with_connected_inputs: bool = False) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Returns a dictionary containing the inputs of a pipeline.
|
||||
|
||||
Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes
|
||||
the input sockets of that component, including their types and whether they are optional.
|
||||
|
||||
:param include_components_with_connected_inputs:
|
||||
If `False`, only components that have disconnected input edges are
|
||||
included in the output.
|
||||
:returns:
|
||||
A dictionary where each key is a pipeline component name and each value is a dictionary of
|
||||
inputs sockets of that component.
|
||||
"""
|
||||
inputs: Dict[str, Dict[str, Any]] = {}
|
||||
for component_name, data in find_pipeline_inputs(self.graph).items():
|
||||
for component_name, data in find_pipeline_inputs(self.graph, include_components_with_connected_inputs).items():
|
||||
sockets_description = {}
|
||||
for socket in data:
|
||||
sockets_description[socket.name] = {"type": socket.type, "is_mandatory": socket.is_mandatory}
|
||||
@ -528,20 +531,23 @@ class Pipeline:
|
||||
inputs[component_name] = sockets_description
|
||||
return inputs
|
||||
|
||||
def outputs(self) -> Dict[str, Dict[str, Any]]:
|
||||
def outputs(self, include_components_with_connected_outputs: bool = False) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Returns a dictionary containing the outputs of a pipeline.
|
||||
|
||||
Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes
|
||||
the output sockets of that component.
|
||||
|
||||
:param include_components_with_connected_outputs:
|
||||
If `False`, only components that have disconnected output edges are
|
||||
included in the output.
|
||||
:returns:
|
||||
A dictionary where each key is a pipeline component name and each value is a dictionary of
|
||||
output sockets of that component.
|
||||
"""
|
||||
outputs = {
|
||||
comp: {socket.name: {"type": socket.type} for socket in data}
|
||||
for comp, data in find_pipeline_outputs(self.graph).items()
|
||||
for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items()
|
||||
if data
|
||||
}
|
||||
return outputs
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
`Pipeline.inputs` and `Pipeline.outputs` can optionally include components input/output sockets that are connected.
|
||||
@ -673,6 +673,9 @@ def test_describe_input_only_no_inputs_components():
|
||||
p.connect("a.x", "c.x")
|
||||
p.connect("b.y", "c.y")
|
||||
assert p.inputs() == {}
|
||||
assert p.inputs(include_components_with_connected_inputs=True) == {
|
||||
"c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}}
|
||||
}
|
||||
|
||||
|
||||
def test_describe_input_some_components_with_no_inputs():
|
||||
@ -686,6 +689,10 @@ def test_describe_input_some_components_with_no_inputs():
|
||||
p.connect("a.x", "c.x")
|
||||
p.connect("b.y", "c.y")
|
||||
assert p.inputs() == {"b": {"y": {"type": int, "is_mandatory": True}}}
|
||||
assert p.inputs(include_components_with_connected_inputs=True) == {
|
||||
"b": {"y": {"type": int, "is_mandatory": True}},
|
||||
"c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}},
|
||||
}
|
||||
|
||||
|
||||
def test_describe_input_all_components_have_inputs():
|
||||
@ -702,6 +709,11 @@ def test_describe_input_all_components_have_inputs():
|
||||
"a": {"x": {"type": Optional[int], "is_mandatory": True}},
|
||||
"b": {"y": {"type": int, "is_mandatory": True}},
|
||||
}
|
||||
assert p.inputs(include_components_with_connected_inputs=True) == {
|
||||
"a": {"x": {"type": Optional[int], "is_mandatory": True}},
|
||||
"b": {"y": {"type": int, "is_mandatory": True}},
|
||||
"c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}},
|
||||
}
|
||||
|
||||
|
||||
def test_describe_output_multiple_possible():
|
||||
@ -718,6 +730,10 @@ def test_describe_output_multiple_possible():
|
||||
pipe.connect("a.output_b", "b.input_b")
|
||||
|
||||
assert pipe.outputs() == {"b": {"output_b": {"type": str}}, "a": {"output_a": {"type": str}}}
|
||||
assert pipe.outputs(include_components_with_connected_outputs=True) == {
|
||||
"a": {"output_a": {"type": str}, "output_b": {"type": str}},
|
||||
"b": {"output_b": {"type": str}},
|
||||
}
|
||||
|
||||
|
||||
def test_describe_output_single():
|
||||
@ -736,6 +752,11 @@ def test_describe_output_single():
|
||||
p.connect("b.y", "c.y")
|
||||
|
||||
assert p.outputs() == {"c": {"z": {"type": int}}}
|
||||
assert p.outputs(include_components_with_connected_outputs=True) == {
|
||||
"a": {"x": {"type": int}},
|
||||
"b": {"y": {"type": int}},
|
||||
"c": {"z": {"type": int}},
|
||||
}
|
||||
|
||||
|
||||
def test_describe_no_outputs():
|
||||
@ -753,6 +774,10 @@ def test_describe_no_outputs():
|
||||
p.connect("a.x", "c.x")
|
||||
p.connect("b.y", "c.y")
|
||||
assert p.outputs() == {}
|
||||
assert p.outputs(include_components_with_connected_outputs=True) == {
|
||||
"a": {"x": {"type": int}},
|
||||
"b": {"y": {"type": int}},
|
||||
}
|
||||
|
||||
|
||||
def test_from_template(monkeypatch):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user