Mirror of https://github.com/langgenius/dify.git
Synced 2025-11-24 17:05:14 +00:00
Python · 126 lines · 4.1 KiB
from __future__ import annotations
|
|
|
|
from collections.abc import Sequence
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Protocol
|
|
|
|
from core.workflow.enums import NodeExecutionType, NodeType
|
|
|
|
if TYPE_CHECKING:
|
|
from .graph import Graph
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class GraphValidationIssue:
|
|
"""Immutable value object describing a single validation issue."""
|
|
|
|
code: str
|
|
message: str
|
|
node_id: str | None = None
|
|
|
|
|
|
class GraphValidationError(ValueError):
    """Raised when graph validation fails.

    Aggregates every issue reported by the validation rules, so callers can
    inspect all problems at once rather than only the first one.

    Attributes:
        issues: The reported issues, frozen into a tuple in reporting order.
    """

    def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
        """Build the error from a non-empty collection of issues.

        Args:
            issues: Issues discovered during validation. They are copied into a
                tuple before the emptiness check, so a one-shot iterable that
                turns out to be empty is rejected as well.

        Raises:
            ValueError: If no issues are supplied. An error with nothing to
                report would be misleading.
        """
        self.issues: tuple[GraphValidationIssue, ...] = tuple(issues)
        # Check the materialized tuple rather than the argument itself. An
        # exhausted generator is truthy, but it would yield zero issues.
        if not self.issues:
            raise ValueError("GraphValidationError requires at least one issue.")
        # The exception message lists every issue as "[CODE] message", joined by "; ".
        message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues)
        super().__init__(message)

    def __reduce__(self) -> tuple[object, ...]:
        """Support ``copy`` and ``pickle`` by rebuilding from ``self.issues``.

        ``BaseException.__reduce__`` re-creates the exception from ``self.args``,
        which holds only the joined message string. ``__init__`` would misread
        that string as a sequence of issues and fail. Passing the issues tuple
        keeps the round trip intact. The instance ``__dict__`` is forwarded as
        state so extra attributes (e.g. ``__notes__``) survive too.
        """
        return (type(self), (self.issues,), self.__dict__)
|
|
|
|
|
|
class GraphValidationRule(Protocol):
    """Structural interface that every graph validation rule implements.

    Any object exposing a compatible ``validate`` method satisfies this
    protocol; no inheritance is required.
    """

    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
        """Inspect ``graph`` and report the problems found.

        Args:
            graph: The graph under validation.

        Returns:
            The issues discovered; an empty sequence means the rule found
            nothing to report.
        """
        ...
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class _EdgeEndpointValidator:
|
|
"""Ensures all edges reference existing nodes."""
|
|
|
|
missing_node_code: str = "MISSING_NODE"
|
|
|
|
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
|
issues: list[GraphValidationIssue] = []
|
|
for edge in graph.edges.values():
|
|
if edge.tail not in graph.nodes:
|
|
issues.append(
|
|
GraphValidationIssue(
|
|
code=self.missing_node_code,
|
|
message=f"Edge {edge.id} references unknown source node '{edge.tail}'.",
|
|
node_id=edge.tail,
|
|
)
|
|
)
|
|
if edge.head not in graph.nodes:
|
|
issues.append(
|
|
GraphValidationIssue(
|
|
code=self.missing_node_code,
|
|
message=f"Edge {edge.id} references unknown target node '{edge.head}'.",
|
|
node_id=edge.head,
|
|
)
|
|
)
|
|
return issues
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
class _RootNodeValidator:
    """Rule that checks the graph's designated root node.

    Two checks run in order:

    1. ``graph.root_node`` must be registered in ``graph.nodes``. If it is not,
       that single issue is returned and the second check is skipped.
    2. The root must have ``execution_type == NodeExecutionType.ROOT``, unless
       its ``node_type`` is listed in ``container_entry_types`` (iteration and
       loop start nodes by default). NOTE(review): presumably these act as the
       entry points of iteration/loop sub-graphs; confirm.
    """

    invalid_root_code: str = "INVALID_ROOT"
    container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)

    def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
        """Return at most one issue describing why the root node is invalid."""
        root = graph.root_node

        # Guard clause: an unregistered root is reported on its own. The
        # execution-type check only runs for roots found in the registry.
        if root.id not in graph.nodes:
            return [
                GraphValidationIssue(
                    code=self.invalid_root_code,
                    message=f"Root node '{root.id}' is missing from the node registry.",
                    node_id=root.id,
                )
            ]

        # ``node_type`` is read defensively. A node without that attribute is
        # never treated as a container entry node.
        node_type = getattr(root, "node_type", None)
        if root.execution_type == NodeExecutionType.ROOT or node_type in self.container_entry_types:
            return []

        return [
            GraphValidationIssue(
                code=self.invalid_root_code,
                message=f"Root node '{root.id}' must declare execution type 'root'.",
                node_id=root.id,
            )
        ]
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class GraphValidator:
|
|
"""Coordinates execution of graph validation rules."""
|
|
|
|
rules: tuple[GraphValidationRule, ...]
|
|
|
|
def validate(self, graph: Graph) -> None:
|
|
"""Validate the graph against all configured rules."""
|
|
issues: list[GraphValidationIssue] = []
|
|
for rule in self.rules:
|
|
issues.extend(rule.validate(graph))
|
|
|
|
if issues:
|
|
raise GraphValidationError(issues)
|
|
|
|
|
|
# Rules used by ``get_graph_validator``, in execution order: edge endpoint
# checks first, then root node checks. Both rule classes are frozen
# dataclasses, so one shared instance of each is reused by every validator.
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (_EdgeEndpointValidator(), _RootNodeValidator())
|
|
|
|
|
|
def get_graph_validator() -> GraphValidator:
    """Return a new ``GraphValidator`` configured with the default rules.

    A fresh validator object is created on every call. The rule instances
    themselves come from the module-level ``_DEFAULT_RULES`` tuple and are
    shared between validators.
    """
    validator = GraphValidator(rules=_DEFAULT_RULES)
    return validator
|