mirror of
https://github.com/langgenius/dify.git
synced 2025-11-26 01:43:25 +00:00
Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com> Co-authored-by: Stream <Stream_2@qq.com> Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zhsama <torvalds@linux.do> Co-authored-by: Harry <xh001x@hotmail.com> Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com> Co-authored-by: yessenia <yessenia.contact@gmail.com> Co-authored-by: hjlarry <hjlarry@163.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: WTW0313 <twwu@dify.ai> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
162 lines
5.3 KiB
Python
162 lines
5.3 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Sequence
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Protocol
|
|
|
|
from core.workflow.enums import NodeExecutionType, NodeType
|
|
|
|
if TYPE_CHECKING:
|
|
from .graph import Graph
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class GraphValidationIssue:
|
|
"""Immutable value object describing a single validation issue."""
|
|
|
|
code: str
|
|
message: str
|
|
node_id: str | None = None
|
|
|
|
|
|
class GraphValidationError(ValueError):
    """Raised when graph validation fails.

    The error keeps every reported issue in ``issues``. Its message joins the
    issues with ``"; "``, formatting each one as ``[CODE] message``.
    """

    def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
        """Create the error from a non-empty collection of issues.

        Args:
            issues: The validation issues to report. They are copied into a
                tuple, so any iterable is accepted.

        Raises:
            ValueError: If ``issues`` contains no items.
        """
        # Materialise first: an empty generator/iterator is truthy, so checking
        # the raw argument would let an empty collection through.
        collected: tuple[GraphValidationIssue, ...] = tuple(issues)
        if not collected:
            raise ValueError("GraphValidationError requires at least one issue.")
        self.issues: tuple[GraphValidationIssue, ...] = collected
        message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues)
        super().__init__(message)

    def __reduce__(self) -> tuple[object, ...]:
        # BaseException's default reduce re-creates the error as
        # ``cls(*self.args)``, i.e. ``GraphValidationError(message)``. That
        # would iterate the message string as if it were issues and fail, which
        # breaks ``pickle`` and ``copy``. Rebuild from the issues instead, and
        # pass ``__dict__`` as state so extra attributes (e.g. ``__notes__``)
        # are preserved.
        return (type(self), (self.issues,), self.__dict__)
|
|
|
|
|
|
class GraphValidationRule(Protocol):
    """Structural interface for a single graph validation rule.

    Any object with a compatible ``validate`` method satisfies this protocol;
    inheriting from it is not required.
    """

    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
        """Inspect ``graph`` and return the issues found; an empty result means no issues."""
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class _EdgeEndpointValidator:
|
|
"""Ensures all edges reference existing nodes."""
|
|
|
|
missing_node_code: str = "MISSING_NODE"
|
|
|
|
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
|
issues: list[GraphValidationIssue] = []
|
|
for edge in graph.edges.values():
|
|
if edge.tail not in graph.nodes:
|
|
issues.append(
|
|
GraphValidationIssue(
|
|
code=self.missing_node_code,
|
|
message=f"Edge {edge.id} references unknown source node '{edge.tail}'.",
|
|
node_id=edge.tail,
|
|
)
|
|
)
|
|
if edge.head not in graph.nodes:
|
|
issues.append(
|
|
GraphValidationIssue(
|
|
code=self.missing_node_code,
|
|
message=f"Edge {edge.id} references unknown target node '{edge.head}'.",
|
|
node_id=edge.head,
|
|
)
|
|
)
|
|
return issues
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
class _RootNodeValidator:
    """Rule checking that the graph's root node is registered and allowed to act as a root.

    A registered root is accepted when it declares ``NodeExecutionType.ROOT``,
    or when its ``node_type`` is one of ``container_entry_types`` (by default
    the iteration-start and loop-start node types).
    """

    # Issue code used for every root-node finding.
    invalid_root_code: str = "INVALID_ROOT"
    # Node types that may serve as root without declaring the ROOT execution type.
    container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)

    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
        """Return at most one issue describing why the root node is invalid.

        A root that is missing from ``graph.nodes`` is reported immediately,
        and its execution type is not inspected.
        """
        root = graph.root_node

        if root.id not in graph.nodes:
            return [
                GraphValidationIssue(
                    code=self.invalid_root_code,
                    message=f"Root node '{root.id}' is missing from the node registry.",
                    node_id=root.id,
                )
            ]

        declared_type = getattr(root, "node_type", None)
        if root.execution_type == NodeExecutionType.ROOT or declared_type in self.container_entry_types:
            return []

        return [
            GraphValidationIssue(
                code=self.invalid_root_code,
                message=f"Root node '{root.id}' must declare execution type 'root'.",
                node_id=root.id,
            )
        ]
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class GraphValidator:
|
|
"""Coordinates execution of graph validation rules."""
|
|
|
|
rules: tuple[GraphValidationRule, ...]
|
|
|
|
def validate(self, graph: Graph) -> None:
|
|
"""Validate the graph against all configured rules."""
|
|
issues: list[GraphValidationIssue] = []
|
|
for rule in self.rules:
|
|
issues.extend(rule.validate(graph))
|
|
|
|
if issues:
|
|
raise GraphValidationError(issues)
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class _TriggerStartExclusivityValidator:
|
|
"""Ensures trigger nodes do not coexist with UserInput (start) nodes."""
|
|
|
|
conflict_code: str = "TRIGGER_START_NODE_CONFLICT"
|
|
|
|
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
|
start_node_id: str | None = None
|
|
trigger_node_ids: list[str] = []
|
|
|
|
for node in graph.nodes.values():
|
|
node_type = getattr(node, "node_type", None)
|
|
if not isinstance(node_type, NodeType):
|
|
continue
|
|
|
|
if node_type == NodeType.START:
|
|
start_node_id = node.id
|
|
elif node_type.is_trigger_node:
|
|
trigger_node_ids.append(node.id)
|
|
|
|
if start_node_id and trigger_node_ids:
|
|
trigger_list = ", ".join(trigger_node_ids)
|
|
return [
|
|
GraphValidationIssue(
|
|
code=self.conflict_code,
|
|
message=(
|
|
f"UserInput (start) node '{start_node_id}' cannot coexist with trigger nodes: {trigger_list}."
|
|
),
|
|
node_id=start_node_id,
|
|
)
|
|
]
|
|
|
|
return []
|
|
|
|
|
|
# Default rule set, in execution order: edge endpoints, root-node invariants,
# then trigger/start exclusivity. Each rule is a frozen dataclass with no
# mutable state, so one shared instance of each can be reused safely.
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = tuple(
    rule_type()
    for rule_type in (
        _EdgeEndpointValidator,
        _RootNodeValidator,
        _TriggerStartExclusivityValidator,
    )
)
|
|
|
|
|
|
def get_graph_validator() -> GraphValidator:
    """Return a new ``GraphValidator`` configured with the module's default rules."""
    return GraphValidator(rules=_DEFAULT_RULES)
|