From 41fef8a21fe9463e98f7b7ad29cad2cd7b5c45fd Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Mon, 16 Jun 2025 13:48:43 +0800
Subject: [PATCH] r2

---
 .../rag_pipeline/rag_pipeline_workflow.py     |  6 +-
 .../website_crawl/website_crawl_plugin.py     |  5 +-
 api/core/tools/entities/tool_entities.py      |  2 +
 .../workflow/graph_engine/entities/event.py   |  5 ++
 api/core/workflow/nodes/tool/tool_node.py     |  2 +-
 api/services/rag_pipeline/rag_pipeline.py     | 83 +++----------------
 6 files changed, 25 insertions(+), 78 deletions(-)

diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index 7b8adfe560..c97b3b1d92 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -8,6 +8,7 @@ from flask_restful.inputs import int_range  # type: ignore
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
+from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
 import services
 from configs import dify_config
 from controllers.console import api
@@ -453,7 +454,7 @@ class RagPipelineDrafDatasourceNodeRunApi(Resource):
             raise ValueError("missing datasource_type")
 
         rag_pipeline_service = RagPipelineService()
-        result = rag_pipeline_service.run_datasource_workflow_node(
+        return helper.compact_generate_response(rag_pipeline_service.run_datasource_workflow_node(
             pipeline=pipeline,
             node_id=node_id,
             user_inputs=inputs,
@@ -461,8 +462,7 @@ class RagPipelineDrafDatasourceNodeRunApi(Resource):
             datasource_type=datasource_type,
             is_published=False
         )
-
-        return result
+        )
 
 
 class RagPipelinePublishedNodeRunApi(Resource):
diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py
index e8256b3282..87612fff44 100644
--- a/api/core/datasource/website_crawl/website_crawl_plugin.py
+++ b/api/core/datasource/website_crawl/website_crawl_plugin.py
@@ -1,10 +1,11 @@
 from collections.abc import Mapping
-from typing import Any
+from typing import Any, Generator
 
 from core.datasource.__base.datasource_plugin import DatasourcePlugin
 from core.datasource.__base.datasource_runtime import DatasourceRuntime
 from core.datasource.entities.datasource_entities import (
     DatasourceEntity,
+    DatasourceInvokeMessage,
     DatasourceProviderType,
     GetWebsiteCrawlResponse,
 )
@@ -36,7 +37,7 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
         user_id: str,
         datasource_parameters: Mapping[str, Any],
         provider_type: str,
-    ) -> GetWebsiteCrawlResponse:
+    ) -> Generator[DatasourceInvokeMessage, None, None]:
         manager = PluginDatasourceManager()
 
         return manager.get_website_crawl(
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index 03047c0545..34a86555f7 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -188,6 +188,8 @@ class ToolInvokeMessage(BaseModel):
         FILE = "file"
         LOG = "log"
         BLOB_CHUNK = "blob_chunk"
+        WEBSITE_CRAWL = "website_crawl"
+        ONLINE_DOCUMENT = "online_document"
 
     type: MessageType = MessageType.TEXT
     """
diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py
index 9a4939502e..0d8a4ee821 100644
--- a/api/core/workflow/graph_engine/entities/event.py
+++ b/api/core/workflow/graph_engine/entities/event.py
@@ -273,3 +273,8 @@ class AgentLogEvent(BaseAgentEvent):
 
 
 InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent
+
+
+class DatasourceRunEvent(BaseModel):
+    status: str = Field(..., description="status")
+    result: dict[str, Any] = Field(..., description="result")
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index aaecc7b989..9a37f0e51c 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -127,7 +127,7 @@ class ToolNode(BaseNode[ToolNodeData]):
                     inputs=parameters_for_log,
                     metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
                     error=f"Failed to transform tool message: {str(e)}",
-                    error_type=type(e).__name__,
+                    error_type=type(e).__name__,
                 )
             )
 
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index df9fea805c..a3978c9a5a 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -15,6 +15,7 @@ import contexts
 from configs import dify_config
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.datasource.entities.datasource_entities import (
+    DatasourceInvokeMessage,
     DatasourceProviderType,
     GetOnlineDocumentPagesResponse,
     GetWebsiteCrawlResponse,
@@ -31,7 +32,7 @@ from core.workflow.entities.workflow_node_execution import (
 )
 from core.workflow.enums import SystemVariableKey
 from core.workflow.errors import WorkflowNodeRunFailedError
-from core.workflow.graph_engine.entities.event import InNodeEvent
+from core.workflow.graph_engine.entities.event import DatasourceRunEvent, InNodeEvent
 from core.workflow.nodes.base.node import BaseNode
 from core.workflow.nodes.enums import ErrorStrategy, NodeType
 from core.workflow.nodes.event.event import RunCompletedEvent
@@ -423,69 +424,11 @@ class RagPipelineService:
 
         return workflow_node_execution
 
-    def run_datasource_workflow_node_status(
-        self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, datasource_type: str, is_published: bool
-    ) -> dict:
-        """
-        Run published workflow datasource
-        """
-        if is_published:
-            # fetch published workflow by app_model
-            workflow = self.get_published_workflow(pipeline=pipeline)
-        else:
-            workflow = self.get_draft_workflow(pipeline=pipeline)
-        if not workflow:
-            raise ValueError("Workflow not initialized")
-
-        # run draft workflow node
-        datasource_node_data = None
-        start_at = time.perf_counter()
-        datasource_nodes = workflow.graph_dict.get("nodes", [])
-        for datasource_node in datasource_nodes:
-            if datasource_node.get("id") == node_id:
-                datasource_node_data = datasource_node.get("data", {})
-                break
-        if not datasource_node_data:
-            raise ValueError("Datasource node data not found")
-
-        from core.datasource.datasource_manager import DatasourceManager
-
-        datasource_runtime = DatasourceManager.get_datasource_runtime(
-            provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
-            datasource_name=datasource_node_data.get("datasource_name"),
-            tenant_id=pipeline.tenant_id,
-            datasource_type=DatasourceProviderType(datasource_type),
-        )
-        datasource_provider_service = DatasourceProviderService()
-        credentials = datasource_provider_service.get_real_datasource_credentials(
-            tenant_id=pipeline.tenant_id,
-            provider=datasource_node_data.get('provider_name'),
-            plugin_id=datasource_node_data.get('plugin_id'),
-        )
-        if credentials:
-            datasource_runtime.runtime.credentials = credentials[0].get("credentials")
-        match datasource_type:
-
-            case DatasourceProviderType.WEBSITE_CRAWL:
-                datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
-                website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
-                    user_id=account.id,
-                    datasource_parameters={"job_id": job_id},
-                    provider_type=datasource_runtime.datasource_provider_type(),
-                )
-                return {
-                    "result": [result for result in website_crawl_result.result],
-                    "job_id": website_crawl_result.result.job_id,
-                    "status": website_crawl_result.result.status,
-                    "provider_type": datasource_node_data.get("provider_type"),
-                }
-            case _:
-                raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
 
     def run_datasource_workflow_node(
         self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, is_published: bool
-    ) -> dict:
+    ) -> Generator[DatasourceRunEvent, None, None]:
         """
         Run published workflow datasource
         """
 
@@ -532,29 +475,29 @@ class RagPipelineService:
         match datasource_type:
             case DatasourceProviderType.ONLINE_DOCUMENT:
                 datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
-                online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
+                online_document_result: Generator[DatasourceInvokeMessage, None, None] = datasource_runtime._get_online_document_pages(
                     user_id=account.id,
                     datasource_parameters=user_inputs,
                     provider_type=datasource_runtime.datasource_provider_type(),
                 )
-                return {
-                    "result": [page.model_dump() for page in online_document_result.result],
-                    "provider_type": datasource_node_data.get("provider_type"),
-                }
+                for message in online_document_result:
+                    yield DatasourceRunEvent(
+                        status="success",
+                        result=message.model_dump(),
+                    )
             case DatasourceProviderType.WEBSITE_CRAWL:
                 datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
-                website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
+                website_crawl_result: Generator[DatasourceInvokeMessage, None, None] = datasource_runtime._get_website_crawl(
                     user_id=account.id,
                     datasource_parameters=user_inputs,
                     provider_type=datasource_runtime.datasource_provider_type(),
                 )
-                return {
-                    "result": [result.model_dump() for result in website_crawl_result.result.web_info_list] if website_crawl_result.result.web_info_list else [],
-                    "job_id": website_crawl_result.result.job_id,
-                    "status": website_crawl_result.result.status,
-                    "provider_type": datasource_node_data.get("provider_type"),
-                }
+                for message in website_crawl_result:
+                    yield DatasourceRunEvent(
+                        status="success",
+                        result=message.model_dump(),
+                    )
             case _:
                 raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
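
Illustrative sketch (not part of the patch): after this change, run_datasource_workflow_node streams DatasourceRunEvent objects instead of returning a single dict, and the controller hands that generator to helper.compact_generate_response. A minimal, self-contained Python sketch of that contract, assuming pydantic v2; to_event_stream is a hypothetical stand-in for whatever serialization the response helper applies to a generator.

from collections.abc import Generator
from typing import Any

from pydantic import BaseModel, Field


class DatasourceRunEvent(BaseModel):
    # Mirrors the event added in api/core/workflow/graph_engine/entities/event.py.
    status: str = Field(..., description="status")
    result: dict[str, Any] = Field(..., description="result")


def run_datasource(pages: list[dict[str, Any]]) -> Generator[DatasourceRunEvent, None, None]:
    # Each crawled page or online document becomes one event, as in run_datasource_workflow_node.
    for page in pages:
        yield DatasourceRunEvent(status="success", result=page)


def to_event_stream(events: Generator[DatasourceRunEvent, None, None]) -> Generator[str, None, None]:
    # Hypothetical serialization: one SSE-style "data:" chunk per event.
    for event in events:
        yield f"data: {event.model_dump_json()}\n\n"


if __name__ == "__main__":
    for chunk in to_event_stream(run_datasource([{"url": "https://example.com", "title": "Example"}])):
        print(chunk, end="")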