diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index b6d64d95da..e8088e17c1 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -10,20 +10,17 @@ from sqlalchemy.orm import Session, sessionmaker from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.app.entities.app_invoke_entities import InvokeFrom from core.file import File from core.repositories import DifyCoreRepositoryFactory from core.variables import Variable from core.variables.variables import VariableUnion -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool, WorkflowNodeExecution +from core.workflow.entities import VariablePool, WorkflowNodeExecution from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.graph.graph import Graph from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent from core.workflow.node_events import NodeRunResult from core.workflow.nodes import NodeType from core.workflow.nodes.base.node import Node -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.start.entities import StartNodeData from core.workflow.system_variable import SystemVariable @@ -34,7 +31,6 @@ from extensions.ext_storage import storage from factories.file_factory import build_from_mapping, build_from_mappings from libs.datetime_utils import naive_utc_now from models import Account -from models.enums import UserFrom from models.model import App, AppMode from models.tools import WorkflowToolProvider from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType @@ -215,7 +211,7 @@ class WorkflowService: self.validate_features_structure(app_model=app_model, features=features) # validate graph structure - self.validate_graph_structure(user_id=account.id, app_model=app_model, graph=graph) + self.validate_graph_structure(graph=graph) # create draft workflow if not found if not workflow: @@ -274,7 +270,7 @@ class WorkflowService: self._validate_workflow_credentials(draft_workflow) # validate graph structure - self.validate_graph_structure(user_id=account.id, app_model=app_model, graph=draft_workflow.graph_dict) + self.validate_graph_structure(graph=draft_workflow.graph_dict) # create new workflow workflow = Workflow.new( @@ -905,42 +901,30 @@ class WorkflowService: return new_app - def validate_graph_structure(self, user_id: str, app_model: App, graph: Mapping[str, Any]): + def validate_graph_structure(self, graph: Mapping[str, Any]): """ - Validate workflow graph structure by instantiating the Graph object. + Validate workflow graph structure. - This leverages the built-in graph validators (including trigger/UserInput exclusivity) - and raises any structural errors before persisting the workflow. + This performs a lightweight validation on the graph, checking for structural + inconsistencies such as the coexistence of start and trigger nodes. """ node_configs = graph.get("nodes", []) - node_configs = cast(list[dict[str, object]], node_configs) + node_configs = cast(list[dict[str, Any]], node_configs) # is empty graph if not node_configs: return - workflow_id = app_model.workflow_id or "UNKNOWN" - Graph.init( - graph_config=graph, - # TODO(Mairuis): Add root node id - root_node_id=None, - node_factory=DifyNodeFactory( - graph_init_params=GraphInitParams( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - workflow_id=workflow_id, - graph_config=graph, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.VALIDATION, - call_depth=0, - ), - graph_runtime_state=GraphRuntimeState( - variable_pool=VariablePool(), - start_at=time.perf_counter(), - ), - ), - ) + node_types: set[NodeType] = set() + for node in node_configs: + node_type = node.get("data", {}).get("type") + if node_type: + node_types.add(NodeType(node_type)) + + # start node and trigger node cannot coexist + if NodeType.START in node_types: + if any(nt.is_trigger_node for nt in node_types): + raise ValueError("Start node and trigger nodes cannot coexist in the same workflow") def validate_features_structure(self, app_model: App, features: dict): if app_model.mode == AppMode.ADVANCED_CHAT: