mirror of
https://github.com/langgenius/dify.git
synced 2025-11-14 10:20:05 +00:00
r2
This commit is contained in:
parent
b2b95412b9
commit
41fef8a21f
@ -8,6 +8,7 @@ from flask_restful.inputs import int_range # type: ignore
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
|
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||||
import services
|
import services
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
@ -453,7 +454,7 @@ class RagPipelineDrafDatasourceNodeRunApi(Resource):
|
|||||||
raise ValueError("missing datasource_type")
|
raise ValueError("missing datasource_type")
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
result = rag_pipeline_service.run_datasource_workflow_node(
|
return helper.compact_generate_response(rag_pipeline_service.run_datasource_workflow_node(
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
user_inputs=inputs,
|
user_inputs=inputs,
|
||||||
@ -461,8 +462,7 @@ class RagPipelineDrafDatasourceNodeRunApi(Resource):
|
|||||||
datasource_type=datasource_type,
|
datasource_type=datasource_type,
|
||||||
is_published=False
|
is_published=False
|
||||||
)
|
)
|
||||||
|
)
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class RagPipelinePublishedNodeRunApi(Resource):
|
class RagPipelinePublishedNodeRunApi(Resource):
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any
|
from typing import Any, Generator
|
||||||
|
|
||||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
from core.datasource.entities.datasource_entities import (
|
from core.datasource.entities.datasource_entities import (
|
||||||
DatasourceEntity,
|
DatasourceEntity,
|
||||||
|
DatasourceInvokeMessage,
|
||||||
DatasourceProviderType,
|
DatasourceProviderType,
|
||||||
GetWebsiteCrawlResponse,
|
GetWebsiteCrawlResponse,
|
||||||
)
|
)
|
||||||
@ -36,7 +37,7 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
datasource_parameters: Mapping[str, Any],
|
datasource_parameters: Mapping[str, Any],
|
||||||
provider_type: str,
|
provider_type: str,
|
||||||
) -> GetWebsiteCrawlResponse:
|
) -> Generator[DatasourceInvokeMessage, None, None]:
|
||||||
manager = PluginDatasourceManager()
|
manager = PluginDatasourceManager()
|
||||||
|
|
||||||
return manager.get_website_crawl(
|
return manager.get_website_crawl(
|
||||||
|
|||||||
@ -188,6 +188,8 @@ class ToolInvokeMessage(BaseModel):
|
|||||||
FILE = "file"
|
FILE = "file"
|
||||||
LOG = "log"
|
LOG = "log"
|
||||||
BLOB_CHUNK = "blob_chunk"
|
BLOB_CHUNK = "blob_chunk"
|
||||||
|
WEBSITE_CRAWL = "website_crawl"
|
||||||
|
ONLINE_DOCUMENT = "online_document"
|
||||||
|
|
||||||
type: MessageType = MessageType.TEXT
|
type: MessageType = MessageType.TEXT
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -273,3 +273,8 @@ class AgentLogEvent(BaseAgentEvent):
|
|||||||
|
|
||||||
|
|
||||||
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent
|
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceRunEvent(BaseModel):
|
||||||
|
status: str = Field(..., description="status")
|
||||||
|
result: dict[str, Any] = Field(..., description="result")
|
||||||
|
|||||||
@ -127,7 +127,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||||||
inputs=parameters_for_log,
|
inputs=parameters_for_log,
|
||||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||||
error=f"Failed to transform tool message: {str(e)}",
|
error=f"Failed to transform tool message: {str(e)}",
|
||||||
error_type=type(e).__name__,
|
error_type=type(e).__name__, PipelineGenerator.convert_to_event_strea
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -15,6 +15,7 @@ import contexts
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.datasource.entities.datasource_entities import (
|
from core.datasource.entities.datasource_entities import (
|
||||||
|
DatasourceInvokeMessage,
|
||||||
DatasourceProviderType,
|
DatasourceProviderType,
|
||||||
GetOnlineDocumentPagesResponse,
|
GetOnlineDocumentPagesResponse,
|
||||||
GetWebsiteCrawlResponse,
|
GetWebsiteCrawlResponse,
|
||||||
@ -31,7 +32,7 @@ from core.workflow.entities.workflow_node_execution import (
|
|||||||
)
|
)
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
from core.workflow.graph_engine.entities.event import DatasourceRunEvent, InNodeEvent
|
||||||
from core.workflow.nodes.base.node import BaseNode
|
from core.workflow.nodes.base.node import BaseNode
|
||||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||||
from core.workflow.nodes.event.event import RunCompletedEvent
|
from core.workflow.nodes.event.event import RunCompletedEvent
|
||||||
@ -423,69 +424,11 @@ class RagPipelineService:
|
|||||||
|
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
def run_datasource_workflow_node_status(
|
|
||||||
self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, datasource_type: str, is_published: bool
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Run published workflow datasource
|
|
||||||
"""
|
|
||||||
if is_published:
|
|
||||||
# fetch published workflow by app_model
|
|
||||||
workflow = self.get_published_workflow(pipeline=pipeline)
|
|
||||||
else:
|
|
||||||
workflow = self.get_draft_workflow(pipeline=pipeline)
|
|
||||||
if not workflow:
|
|
||||||
raise ValueError("Workflow not initialized")
|
|
||||||
|
|
||||||
# run draft workflow node
|
|
||||||
datasource_node_data = None
|
|
||||||
start_at = time.perf_counter()
|
|
||||||
datasource_nodes = workflow.graph_dict.get("nodes", [])
|
|
||||||
for datasource_node in datasource_nodes:
|
|
||||||
if datasource_node.get("id") == node_id:
|
|
||||||
datasource_node_data = datasource_node.get("data", {})
|
|
||||||
break
|
|
||||||
if not datasource_node_data:
|
|
||||||
raise ValueError("Datasource node data not found")
|
|
||||||
|
|
||||||
from core.datasource.datasource_manager import DatasourceManager
|
|
||||||
|
|
||||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
|
||||||
provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
|
|
||||||
datasource_name=datasource_node_data.get("datasource_name"),
|
|
||||||
tenant_id=pipeline.tenant_id,
|
|
||||||
datasource_type=DatasourceProviderType(datasource_type),
|
|
||||||
)
|
|
||||||
datasource_provider_service = DatasourceProviderService()
|
|
||||||
credentials = datasource_provider_service.get_real_datasource_credentials(
|
|
||||||
tenant_id=pipeline.tenant_id,
|
|
||||||
provider=datasource_node_data.get('provider_name'),
|
|
||||||
plugin_id=datasource_node_data.get('plugin_id'),
|
|
||||||
)
|
|
||||||
if credentials:
|
|
||||||
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
|
|
||||||
match datasource_type:
|
|
||||||
|
|
||||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
|
||||||
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
|
|
||||||
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
|
|
||||||
user_id=account.id,
|
|
||||||
datasource_parameters={"job_id": job_id},
|
|
||||||
provider_type=datasource_runtime.datasource_provider_type(),
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"result": [result for result in website_crawl_result.result],
|
|
||||||
"job_id": website_crawl_result.result.job_id,
|
|
||||||
"status": website_crawl_result.result.status,
|
|
||||||
"provider_type": datasource_node_data.get("provider_type"),
|
|
||||||
}
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
|
|
||||||
|
|
||||||
def run_datasource_workflow_node(
|
def run_datasource_workflow_node(
|
||||||
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str,
|
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str,
|
||||||
is_published: bool
|
is_published: bool
|
||||||
) -> dict:
|
) -> Generator[DatasourceRunEvent, None, None]:
|
||||||
"""
|
"""
|
||||||
Run published workflow datasource
|
Run published workflow datasource
|
||||||
"""
|
"""
|
||||||
@ -532,29 +475,25 @@ class RagPipelineService:
|
|||||||
match datasource_type:
|
match datasource_type:
|
||||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||||
online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
|
online_document_result: Generator[DatasourceInvokeMessage, None, None] = datasource_runtime._get_online_document_pages(
|
||||||
user_id=account.id,
|
user_id=account.id,
|
||||||
datasource_parameters=user_inputs,
|
datasource_parameters=user_inputs,
|
||||||
provider_type=datasource_runtime.datasource_provider_type(),
|
provider_type=datasource_runtime.datasource_provider_type(),
|
||||||
)
|
)
|
||||||
return {
|
for message in online_document_result:
|
||||||
"result": [page.model_dump() for page in online_document_result.result],
|
yield DatasourceRunEvent(
|
||||||
"provider_type": datasource_node_data.get("provider_type"),
|
status="success",
|
||||||
}
|
result=message.model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||||
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
|
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
|
||||||
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
|
website_crawl_result: Generator[DatasourceInvokeMessage, None, None] = datasource_runtime._get_website_crawl(
|
||||||
user_id=account.id,
|
user_id=account.id,
|
||||||
datasource_parameters=user_inputs,
|
datasource_parameters=user_inputs,
|
||||||
provider_type=datasource_runtime.datasource_provider_type(),
|
provider_type=datasource_runtime.datasource_provider_type(),
|
||||||
)
|
)
|
||||||
return {
|
yield from website_crawl_result
|
||||||
"result": [result.model_dump() for result in website_crawl_result.result.web_info_list] if website_crawl_result.result.web_info_list else [],
|
|
||||||
"job_id": website_crawl_result.result.job_id,
|
|
||||||
"status": website_crawl_result.result.status,
|
|
||||||
"provider_type": datasource_node_data.get("provider_type"),
|
|
||||||
}
|
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
|
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user