jyong 2025-06-16 13:48:43 +08:00
parent b2b95412b9
commit 41fef8a21f
6 changed files with 25 additions and 78 deletions

View File

@@ -8,6 +8,7 @@ from flask_restful.inputs import int_range  # type: ignore
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
+from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
 import services
 from configs import dify_config
 from controllers.console import api
@@ -453,7 +454,7 @@ class RagPipelineDrafDatasourceNodeRunApi(Resource):
             raise ValueError("missing datasource_type")
         rag_pipeline_service = RagPipelineService()
-        result = rag_pipeline_service.run_datasource_workflow_node(
+        return helper.compact_generate_response(rag_pipeline_service.run_datasource_workflow_node(
             pipeline=pipeline,
             node_id=node_id,
             user_inputs=inputs,
@@ -461,8 +462,7 @@ class RagPipelineDrafDatasourceNodeRunApi(Resource):
             datasource_type=datasource_type,
             is_published=False
         )
-        return result
+        )


 class RagPipelinePublishedNodeRunApi(Resource):
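The change above swaps a buffered dict return for a streamed response by passing the service generator through helper.compact_generate_response. As a rough, self-contained sketch of that streaming pattern (the route path and helper body here are hypothetical; Dify's actual compact_generate_response may differ in detail):

import json
from collections.abc import Generator

from flask import Flask, Response

app = Flask(__name__)


def fake_node_run_events() -> Generator[dict, None, None]:
    # stand-in for rag_pipeline_service.run_datasource_workflow_node(...)
    yield {"status": "success", "result": {"page": 1}}
    yield {"status": "success", "result": {"page": 2}}


@app.route("/rag/pipelines/<pipeline_id>/nodes/<node_id>/run")  # hypothetical path
def run_node(pipeline_id: str, node_id: str) -> Response:
    # each yielded event becomes one server-sent-events chunk
    chunks = (f"data: {json.dumps(event)}\n\n" for event in fake_node_run_events())
    return Response(chunks, mimetype="text/event-stream")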

View File

@@ -1,10 +1,11 @@
 from collections.abc import Mapping
-from typing import Any
+from typing import Any, Generator
 from core.datasource.__base.datasource_plugin import DatasourcePlugin
 from core.datasource.__base.datasource_runtime import DatasourceRuntime
 from core.datasource.entities.datasource_entities import (
     DatasourceEntity,
+    DatasourceInvokeMessage,
     DatasourceProviderType,
     GetWebsiteCrawlResponse,
 )
@@ -36,7 +37,7 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
         user_id: str,
         datasource_parameters: Mapping[str, Any],
         provider_type: str,
-    ) -> GetWebsiteCrawlResponse:
+    ) -> Generator[DatasourceInvokeMessage, None, None]:
         manager = PluginDatasourceManager()
         return manager.get_website_crawl(
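Note that the body still uses a plain return even though the annotation is now a generator type: assuming manager.get_website_crawl itself returns a generator, forwarding it unchanged satisfies Generator[DatasourceInvokeMessage, None, None] without turning the method into a generator function. A minimal illustration of that pass-through pattern:

from collections.abc import Generator


def produce() -> Generator[str, None, None]:
    # a real generator function: it contains yield
    yield "a"
    yield "b"


def passthrough() -> Generator[str, None, None]:
    # not a generator function itself; it just forwards the generator object
    return produce()


print(list(passthrough()))  # ['a', 'b']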

View File

@@ -188,6 +188,8 @@ class ToolInvokeMessage(BaseModel):
         FILE = "file"
         LOG = "log"
         BLOB_CHUNK = "blob_chunk"
+        WEBSITE_CRAWL = "website_crawl"
+        ONLINE_DOCUMENT = "online_document"

     type: MessageType = MessageType.TEXT
     """

View File

@@ -273,3 +273,8 @@ class AgentLogEvent(BaseAgentEvent):
 InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent
+
+
+class DatasourceRunEvent(BaseModel):
+    status: str = Field(..., description="status")
+    result: dict[str, Any] = Field(..., description="result")
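DatasourceRunEvent is a plain Pydantic model, so producers construct and serialize it directly; a minimal usage sketch with illustrative field values:

from typing import Any

from pydantic import BaseModel, Field


class DatasourceRunEvent(BaseModel):
    status: str = Field(..., description="status")
    result: dict[str, Any] = Field(..., description="result")


event = DatasourceRunEvent(status="success", result={"job_id": "j-1", "web_info_list": []})
print(event.model_dump())  # {'status': 'success', 'result': {'job_id': 'j-1', 'web_info_list': []}}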

View File

@@ -127,7 +127,7 @@ class ToolNode(BaseNode[ToolNodeData]):
                     inputs=parameters_for_log,
                     metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
                     error=f"Failed to transform tool message: {str(e)}",
                     error_type=type(e).__name__,
                 )
             )

View File

@@ -15,6 +15,7 @@ import contexts
 from configs import dify_config
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.datasource.entities.datasource_entities import (
+    DatasourceInvokeMessage,
     DatasourceProviderType,
     GetOnlineDocumentPagesResponse,
     GetWebsiteCrawlResponse,
@@ -31,7 +32,7 @@ from core.workflow.entities.workflow_node_execution import (
 )
 from core.workflow.enums import SystemVariableKey
 from core.workflow.errors import WorkflowNodeRunFailedError
-from core.workflow.graph_engine.entities.event import InNodeEvent
+from core.workflow.graph_engine.entities.event import DatasourceRunEvent, InNodeEvent
 from core.workflow.nodes.base.node import BaseNode
 from core.workflow.nodes.enums import ErrorStrategy, NodeType
 from core.workflow.nodes.event.event import RunCompletedEvent
@@ -423,69 +424,11 @@ class RagPipelineService:
         return workflow_node_execution

-    def run_datasource_workflow_node_status(
-        self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, datasource_type: str, is_published: bool
-    ) -> dict:
-        """
-        Run published workflow datasource
-        """
-        if is_published:
-            # fetch published workflow by app_model
-            workflow = self.get_published_workflow(pipeline=pipeline)
-        else:
-            workflow = self.get_draft_workflow(pipeline=pipeline)
-        if not workflow:
-            raise ValueError("Workflow not initialized")
-
-        # run draft workflow node
-        datasource_node_data = None
-        start_at = time.perf_counter()
-        datasource_nodes = workflow.graph_dict.get("nodes", [])
-        for datasource_node in datasource_nodes:
-            if datasource_node.get("id") == node_id:
-                datasource_node_data = datasource_node.get("data", {})
-                break
-        if not datasource_node_data:
-            raise ValueError("Datasource node data not found")
-
-        from core.datasource.datasource_manager import DatasourceManager
-
-        datasource_runtime = DatasourceManager.get_datasource_runtime(
-            provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
-            datasource_name=datasource_node_data.get("datasource_name"),
-            tenant_id=pipeline.tenant_id,
-            datasource_type=DatasourceProviderType(datasource_type),
-        )
-        datasource_provider_service = DatasourceProviderService()
-        credentials = datasource_provider_service.get_real_datasource_credentials(
-            tenant_id=pipeline.tenant_id,
-            provider=datasource_node_data.get('provider_name'),
-            plugin_id=datasource_node_data.get('plugin_id'),
-        )
-        if credentials:
-            datasource_runtime.runtime.credentials = credentials[0].get("credentials")
-        match datasource_type:
-            case DatasourceProviderType.WEBSITE_CRAWL:
-                datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
-                website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
-                    user_id=account.id,
-                    datasource_parameters={"job_id": job_id},
-                    provider_type=datasource_runtime.datasource_provider_type(),
-                )
-                return {
-                    "result": [result for result in website_crawl_result.result],
-                    "job_id": website_crawl_result.result.job_id,
-                    "status": website_crawl_result.result.status,
-                    "provider_type": datasource_node_data.get("provider_type"),
-                }
-            case _:
-                raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
-
     def run_datasource_workflow_node(
         self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str,
         is_published: bool
-    ) -> dict:
+    ) -> Generator[DatasourceRunEvent, None, None]:
         """
         Run published workflow datasource
         """
@@ -532,29 +475,25 @@ class RagPipelineService:
         match datasource_type:
             case DatasourceProviderType.ONLINE_DOCUMENT:
                 datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
-                online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
+                online_document_result: Generator[DatasourceInvokeMessage, None, None] = datasource_runtime._get_online_document_pages(
                     user_id=account.id,
                     datasource_parameters=user_inputs,
                     provider_type=datasource_runtime.datasource_provider_type(),
                 )
-                return {
-                    "result": [page.model_dump() for page in online_document_result.result],
-                    "provider_type": datasource_node_data.get("provider_type"),
-                }
+                for message in online_document_result:
+                    yield DatasourceRunEvent(
+                        status="success",
+                        result=message.model_dump(),
+                    )
             case DatasourceProviderType.WEBSITE_CRAWL:
                 datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
-                website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
+                website_crawl_result: Generator[DatasourceInvokeMessage, None, None] = datasource_runtime._get_website_crawl(
                     user_id=account.id,
                     datasource_parameters=user_inputs,
                     provider_type=datasource_runtime.datasource_provider_type(),
                 )
-                return {
-                    "result": [result.model_dump() for result in website_crawl_result.result.web_info_list] if website_crawl_result.result.web_info_list else [],
-                    "job_id": website_crawl_result.result.job_id,
-                    "status": website_crawl_result.result.status,
-                    "provider_type": datasource_node_data.get("provider_type"),
-                }
+                yield from website_crawl_result
             case _:
                 raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")