feat: Highlight optional connections in Pipeline.draw() (#6724)

* highlight optional connections in Pipeline.draw()

* reno
This commit is contained in:
ZanSara 2024-01-15 12:18:51 +01:00 committed by GitHub
parent a5189dd035
commit 24afc2a7fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 20 additions and 7 deletions

View File

@ -76,7 +76,9 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph, style_map: Dict[str, str]
# Label the edges
for inp, outp, key, data in graph.edges(keys=True, data=True):
data["label"] = f"{data['from_socket'].name} -> {data['to_socket'].name}"
data[
"label"
] = f"{data['from_socket'].name} -> {data['to_socket'].name}{' (opt.)' if not data['mandatory'] else ''}"
graph.add_edge(inp, outp, key=key, **data)
# Draw the inputs

View File

@ -12,7 +12,10 @@ from haystack.core.type_utils import _type_name
logger = logging.getLogger(__name__)
ARROWTAIL_MANDATORY = "--"
ARROWTAIL_OPTIONAL = "-."
ARROWHEAD_MANDATORY = "-->"
ARROWHEAD_OPTIONAL = ".->"
MERMAID_STYLED_TEMPLATE = """
%%{{ init: {{'theme': 'neutral' }} }}%%
@ -81,11 +84,15 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
if comp not in ["input", "output"]
}
connections_list = [
f"{states[from_comp]} -- \"{conn_data['label']}<br><small><i>{conn_data['conn_type']}</i></small>\" --> {states[to_comp]}"
for from_comp, to_comp, conn_data in graph.edges(data=True)
if from_comp != "input" and to_comp != "output"
]
connections_list = []
for from_comp, to_comp, conn_data in graph.edges(data=True):
if from_comp != "input" and to_comp != "output":
arrowtail = ARROWTAIL_MANDATORY if conn_data["mandatory"] else ARROWTAIL_OPTIONAL
arrowhead = ARROWHEAD_MANDATORY if conn_data["mandatory"] else ARROWHEAD_OPTIONAL
label = f'"{conn_data["label"]}<br><small><i>{conn_data["conn_type"]}</i></small>"'
conn_string = f"{states[from_comp]} {arrowtail} {label} {arrowhead} {states[to_comp]}"
connections_list.append(conn_string)
input_connections = [
f"i{{*}} -- \"{conn_data['label']}<br><small><i>{conn_data['conn_type']}</i></small>\" --> {states[to_comp]}"
for _, to_comp, conn_data in graph.out_edges("input", data=True)

View File

@ -327,6 +327,7 @@ class Pipeline:
conn_type=_type_name(connection.sender_socket.type),
from_socket=connection.sender_socket,
to_socket=connection.receiver_socket,
mandatory=connection.is_mandatory,
)
self._connections.append(connection)

View File

@ -0,0 +1,3 @@
---
enhancements:
- Highlight optional connections in the `Pipeline.draw()` output.