Mirror of https://github.com/langgenius/dify.git (synced 2025-11-09 07:53:55 +00:00)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: jyong <718720800@qq.com> Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com> Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com> Co-authored-by: quicksand <quicksandzn@gmail.com> Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: Hanqing Zhao <sherry9277@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Harry <xh001x@hotmail.com>
347 lines · 12 KiB · Python
import logging
|
|
from collections import defaultdict
|
|
from collections.abc import Mapping, Sequence
|
|
from typing import Protocol, cast, final
|
|
|
|
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
|
from core.workflow.nodes.base.node import Node
|
|
from libs.typing import is_str, is_str_dict
|
|
|
|
from .edge import Edge
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class NodeFactory(Protocol):
    """
    Structural interface for objects that turn a raw node configuration into a Node.

    Graph only relies on this protocol, so it stays independent of any concrete
    node-mapping implementation; callers may supply whichever creation strategy
    they need while type checkers still verify the contract.
    """

    def create_node(self, node_config: dict[str, object]) -> Node:
        """
        Build a single Node from its configuration.

        :param node_config: configuration dictionary describing the node (its type and other data)
        :return: the fully initialized Node instance
        :raises ValueError: when the node type is unknown or the configuration is invalid
        """
        ...
|
|
|
|
|
|
@final
class Graph:
    """Graph representation with nodes and edges for workflow execution."""

    def __init__(
        self,
        *,
        nodes: dict[str, Node] | None = None,
        edges: dict[str, Edge] | None = None,
        in_edges: dict[str, list[str]] | None = None,
        out_edges: dict[str, list[str]] | None = None,
        root_node: Node,
    ):
        """
        Initialize Graph instance.

        :param nodes: graph nodes mapping (node id: node object)
        :param edges: graph edges mapping (edge id: edge object)
        :param in_edges: incoming edges mapping (node id: list of edge ids)
        :param out_edges: outgoing edges mapping (node id: list of edge ids)
        :param root_node: root node object
        """
        # Falsy arguments (None or empty) are replaced with fresh empty dicts.
        self.nodes = nodes or {}
        self.edges = edges or {}
        self.in_edges = in_edges or {}
        self.out_edges = out_edges or {}
        self.root_node = root_node

    @classmethod
    def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]:
        """
        Parse node configurations and build a mapping of node IDs to configs.

        Configs without a non-empty string ``id`` are ignored. If several configs
        share an ID, the last one wins.

        :param node_configs: list of node configuration dictionaries
        :return: mapping of node ID to node config
        """
        node_configs_map: dict[str, dict[str, object]] = {}

        for node_config in node_configs:
            node_id = node_config.get("id")
            if not node_id or not isinstance(node_id, str):
                continue

            node_configs_map[node_id] = node_config

        return node_configs_map

    @classmethod
    def _find_root_node_id(
        cls,
        node_configs_map: Mapping[str, Mapping[str, object]],
        edge_configs: Sequence[Mapping[str, object]],
        root_node_id: str | None = None,
    ) -> str:
        """
        Find the root node ID if not specified.

        When no ID is given, candidates are the nodes that no edge targets. The first
        candidate whose ``data.type`` is START or DATASOURCE is preferred; otherwise the
        first candidate (in ``node_configs_map`` order) is used.

        :param node_configs_map: mapping of node ID to node config
        :param edge_configs: list of edge configurations
        :param root_node_id: explicitly specified root node ID
        :return: determined root node ID
        :raises ValueError: if the explicit ID is unknown or no root can be determined
        """
        if root_node_id:
            if root_node_id not in node_configs_map:
                raise ValueError(f"Root node id {root_node_id} not found in the graph")
            return root_node_id

        # Find nodes with no incoming edges
        nodes_with_incoming: set[str] = set()
        for edge_config in edge_configs:
            target = edge_config.get("target")
            if isinstance(target, str):
                nodes_with_incoming.add(target)

        root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming]

        # Prefer START node if available
        start_node_id = None
        for nid in root_candidates:
            node_data = node_configs_map[nid].get("data")
            if not is_str_dict(node_data):
                continue
            node_type = node_data.get("type")
            if not isinstance(node_type, str):
                continue
            if node_type in [NodeType.START, NodeType.DATASOURCE]:
                start_node_id = nid
                break

        root_node_id = start_node_id or (root_candidates[0] if root_candidates else None)

        if not root_node_id:
            raise ValueError("Unable to determine root node ID")

        return root_node_id

    @classmethod
    def _build_edges(
        cls, edge_configs: list[dict[str, object]]
    ) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]:
        """
        Build edge objects and mappings from edge configurations.

        Edge configs whose ``source``/``target`` are not strings, or whose
        ``sourceHandle`` (default ``"source"``) is not a string, are ignored.
        Accepted edges get sequential IDs ``edge_0``, ``edge_1``, ...

        :param edge_configs: list of edge configurations
        :return: tuple of (edges dict, in_edges dict, out_edges dict)
        """
        edges: dict[str, Edge] = {}
        in_edges: dict[str, list[str]] = defaultdict(list)
        out_edges: dict[str, list[str]] = defaultdict(list)

        edge_counter = 0
        for edge_config in edge_configs:
            source = edge_config.get("source")
            target = edge_config.get("target")

            if not is_str(source) or not is_str(target):
                continue

            source_handle = edge_config.get("sourceHandle", "source")
            if not is_str(source_handle):
                continue

            # Allocate the ID only once the edge is known to be valid, so IDs stay contiguous.
            edge_id = f"edge_{edge_counter}"
            edge_counter += 1

            edge = Edge(
                id=edge_id,
                tail=source,
                head=target,
                source_handle=source_handle,
            )

            edges[edge_id] = edge
            out_edges[source].append(edge_id)
            in_edges[target].append(edge_id)

        # Convert to plain dicts so later lookups of unknown node IDs don't insert empty lists.
        return edges, dict(in_edges), dict(out_edges)

    @classmethod
    def _create_node_instances(
        cls,
        node_configs_map: dict[str, dict[str, object]],
        node_factory: "NodeFactory",
    ) -> dict[str, Node]:
        """
        Create node instances from configurations using the node factory.

        :param node_configs_map: mapping of node ID to node config
        :param node_factory: factory for creating node instances
        :return: mapping of node ID to node instance
        :raises Exception: whatever the factory raises, after logging the failing node ID
        """
        nodes: dict[str, Node] = {}

        for node_id, node_config in node_configs_map.items():
            try:
                node_instance = node_factory.create_node(node_config)
            except Exception:
                logger.exception("Failed to create node instance for node_id %s", node_id)
                raise
            nodes[node_id] = node_instance

        return nodes

    @classmethod
    def _mark_inactive_root_branches(
        cls,
        nodes: dict[str, Node],
        edges: dict[str, Edge],
        in_edges: dict[str, list[str]],
        out_edges: dict[str, list[str]],
        active_root_id: str,
    ) -> None:
        """
        Mark nodes and edges from inactive root branches as skipped.

        Algorithm:
        1. Mark inactive root nodes as skipped
        2. For skipped nodes, mark all their outgoing edges as skipped
        3. For each edge marked as skipped, check its target node:
           - If ALL incoming edges are skipped, mark the node as skipped
           - Otherwise, leave the node state unchanged

        The propagation is iterative (no recursion-depth limit) and processes each
        node at most once, so cycles that lead back into an inactive root terminate.
        Edges whose target node was never instantiated are marked skipped but not
        followed.

        :param nodes: mapping of node ID to node instance
        :param edges: mapping of edge ID to edge instance
        :param in_edges: mapping of node ID to incoming edge IDs
        :param out_edges: mapping of node ID to outgoing edge IDs
        :param active_root_id: ID of the active root node
        """
        # Candidate roots are all nodes whose execution type is ROOT (incoming edges are not checked).
        top_level_roots: list[str] = [
            node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT
        ]

        # If there's only one root or the active root is not a top-level root, no marking needed
        if len(top_level_roots) <= 1 or active_root_id not in top_level_roots:
            return

        # Mark inactive root nodes as skipped and seed the propagation with them.
        pending: list[str] = []
        for root_id in top_level_roots:
            if root_id == active_root_id or root_id not in nodes:
                continue
            nodes[root_id].state = NodeState.SKIPPED
            pending.append(root_id)

        # Each node is expanded once; re-expansion would only repeat idempotent writes,
        # and could loop forever if an edge points back into an inactive root.
        processed: set[str] = set()
        while pending:
            node_id = pending.pop()
            if node_id in processed:
                continue
            processed.add(node_id)

            # A skipped node skips all of its outgoing edges.
            for edge_id in out_edges.get(node_id, []):
                edge = edges[edge_id]
                edge.state = NodeState.SKIPPED

                target_node = nodes.get(edge.head)
                if target_node is None:
                    # Dangling edge (target config missing or filtered out): nothing to propagate to.
                    continue

                # The target is skipped only once every one of its incoming edges is skipped.
                in_edge_ids = in_edges.get(target_node.id, [])
                if all(edges[eid].state == NodeState.SKIPPED for eid in in_edge_ids):
                    target_node.state = NodeState.SKIPPED
                    pending.append(target_node.id)

    @classmethod
    def init(
        cls,
        *,
        graph_config: Mapping[str, object],
        node_factory: "NodeFactory",
        root_node_id: str | None = None,
    ) -> "Graph":
        """
        Initialize graph

        :param graph_config: graph config containing nodes and edges
        :param node_factory: factory for creating node instances from config data
        :param root_node_id: root node id
        :return: graph instance
        :raises ValueError: if the graph has no executable nodes or the root cannot be determined
        """
        # Parse configs; a missing or null "edges"/"nodes" entry is treated as empty.
        edge_configs = cast(list[dict[str, object]], graph_config.get("edges") or [])
        node_configs = cast(list[dict[str, object]], graph_config.get("nodes") or [])

        # Drop "custom-note" entries before validating, so a graph made only of notes
        # is reported as empty rather than failing later during root detection.
        node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]

        if not node_configs:
            raise ValueError("Graph must have at least one node")

        # Parse node configurations
        node_configs_map = cls._parse_node_configs(node_configs)

        # Find root node
        root_node_id = cls._find_root_node_id(node_configs_map, edge_configs, root_node_id)

        # Build edges
        edges, in_edges, out_edges = cls._build_edges(edge_configs)

        # Create node instances
        nodes = cls._create_node_instances(node_configs_map, node_factory)

        # Get root node instance
        root_node = nodes[root_node_id]

        # Mark inactive root branches as skipped
        cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)

        # Create and return the graph
        return cls(
            nodes=nodes,
            edges=edges,
            in_edges=in_edges,
            out_edges=out_edges,
            root_node=root_node,
        )

    @property
    def node_ids(self) -> list[str]:
        """
        Get list of node IDs (compatibility property for existing code)

        :return: list of node IDs, in insertion order of ``self.nodes``
        """
        return list(self.nodes.keys())

    def get_outgoing_edges(self, node_id: str) -> list[Edge]:
        """
        Get all outgoing edges from a node (V2 method)

        Edge IDs not present in ``self.edges`` are silently ignored.

        :param node_id: node id
        :return: list of outgoing edges
        """
        edge_ids = self.out_edges.get(node_id, [])
        return [self.edges[eid] for eid in edge_ids if eid in self.edges]

    def get_incoming_edges(self, node_id: str) -> list[Edge]:
        """
        Get all incoming edges to a node (V2 method)

        Edge IDs not present in ``self.edges`` are silently ignored.

        :param node_id: node id
        :return: list of incoming edges
        """
        edge_ids = self.in_edges.get(node_id, [])
        return [self.edges[eid] for eid in edge_ids if eid in self.edges]
|