diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 23d402f914..93976bd6f5 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -1,6 +1,5 @@ import logging -import yaml from flask import request from flask_restful import Resource, reqparse from sqlalchemy.orm import Session diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 8d68c80c81..b9a0c1f150 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -15,7 +15,7 @@ from core.plugin.entities.parameters import ( init_frontend_parameter, ) from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolLabelEnum, ToolInvokeMessage +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum class DatasourceProviderType(enum.StrEnum): @@ -207,12 +207,6 @@ class DatasourceInvokeFrom(Enum): RAG_PIPELINE = "rag_pipeline" -class GetOnlineDocumentPagesRequest(BaseModel): - """ - Get online document pages request - """ - - class OnlineDocumentPage(BaseModel): """ Online document page @@ -237,7 +231,7 @@ class OnlineDocumentInfo(BaseModel): pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") -class GetOnlineDocumentPagesResponse(BaseModel): +class OnlineDocumentPagesMessage(BaseModel): """ Get online document pages response """ @@ -290,17 +284,19 @@ class WebSiteInfo(BaseModel): """ Website info """ - job_id: str = Field(..., description="The job id") - status: str = Field(..., description="The status of the job") + status: Optional[str] = Field(..., description="crawl job status") web_info_list: Optional[list[WebSiteInfoDetail]] = [] + total: Optional[int] = Field(default=0, description="The 
total number of websites") + completed: Optional[int] = Field(default=0, description="The number of completed websites") - -class GetWebsiteCrawlResponse(BaseModel): +class WebsiteCrawlMessage(BaseModel): """ Get website crawl response """ + result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0) - result: WebSiteInfo = WebSiteInfo(job_id="", status="", web_info_list=[]) +class DatasourceMessage(ToolInvokeMessage): + pass class DatasourceInvokeMessage(ToolInvokeMessage): @@ -326,4 +322,4 @@ class DatasourceInvokeMessage(ToolInvokeMessage): workspace_name: str = Field(..., description="The workspace name") workspace_icon: str = Field(..., description="The workspace icon") total: int = Field(..., description="The total number of documents") - pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") \ No newline at end of file + pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index 2ab60cae1e..db73d9a64b 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -1,5 +1,5 @@ -from collections.abc import Mapping -from typing import Any, Generator +from collections.abc import Generator, Mapping +from typing import Any from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime @@ -8,8 +8,7 @@ from core.datasource.entities.datasource_entities import ( DatasourceInvokeMessage, DatasourceProviderType, GetOnlineDocumentPageContentRequest, - GetOnlineDocumentPageContentResponse, - GetOnlineDocumentPagesResponse, + OnlineDocumentPagesMessage, ) from core.plugin.impl.datasource import PluginDatasourceManager @@ -39,7 +38,7 @@ class 
OnlineDocumentDatasourcePlugin(DatasourcePlugin): user_id: str, datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> Generator[DatasourceInvokeMessage, None, None]: + ) -> Generator[OnlineDocumentPagesMessage, None, None]: manager = PluginDatasourceManager() return manager.get_online_document_pages( diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py index 87612fff44..1625670165 100644 --- a/api/core/datasource/website_crawl/website_crawl_plugin.py +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -1,5 +1,5 @@ -from collections.abc import Mapping -from typing import Any, Generator +from collections.abc import Generator, Mapping +from typing import Any from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime @@ -7,7 +7,7 @@ from core.datasource.entities.datasource_entities import ( DatasourceEntity, DatasourceInvokeMessage, DatasourceProviderType, - GetWebsiteCrawlResponse, + WebsiteCrawlMessage, ) from core.plugin.impl.datasource import PluginDatasourceManager @@ -32,12 +32,12 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier - def _get_website_crawl( + def get_website_crawl( self, user_id: str, datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> Generator[DatasourceInvokeMessage, None, None]: + ) -> Generator[WebsiteCrawlMessage, None, None]: manager = PluginDatasourceManager() return manager.get_website_crawl( diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py index a65efb750e..0567f1a480 100644 --- a/api/core/datasource/website_crawl/website_crawl_provider.py +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -1,4 +1,3 @@ -from core.datasource.__base import 
datasource_provider from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 83b1a5760b..06ee00c688 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -1,12 +1,11 @@ -from collections.abc import Mapping -from typing import Any, Generator +from collections.abc import Generator, Mapping +from typing import Any from core.datasource.entities.datasource_entities import ( DatasourceInvokeMessage, GetOnlineDocumentPageContentRequest, - GetOnlineDocumentPageContentResponse, - GetOnlineDocumentPagesResponse, - GetWebsiteCrawlResponse, + OnlineDocumentPagesMessage, + WebsiteCrawlMessage, ) from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import ( @@ -94,17 +93,17 @@ class PluginDatasourceManager(BasePluginClient): credentials: dict[str, Any], datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> Generator[DatasourceInvokeMessage, None, None]: + ) -> Generator[WebsiteCrawlMessage, None, None]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. 
""" datasource_provider_id = GenericProviderID(datasource_provider) - response = self._request_with_plugin_daemon_response_stream( + return self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl", - DatasourceInvokeMessage, + WebsiteCrawlMessage, data={ "user_id": user_id, "data": { @@ -119,7 +118,6 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - yield from response def get_online_document_pages( self, @@ -130,7 +128,7 @@ class PluginDatasourceManager(BasePluginClient): credentials: dict[str, Any], datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> Generator[DatasourceInvokeMessage, None, None]: + ) -> Generator[OnlineDocumentPagesMessage, None, None]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. """ @@ -140,7 +138,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages", - DatasourceInvokeMessage, + OnlineDocumentPagesMessage, data={ "user_id": user_id, "data": { diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index bd4a6e3a56..240eeeb725 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,5 +1,5 @@ -from collections.abc import Mapping, Sequence -from typing import Any, Generator, cast +from collections.abc import Generator, Mapping, Sequence +from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -9,7 +9,6 @@ from core.datasource.entities.datasource_entities import ( DatasourceParameter, DatasourceProviderType, GetOnlineDocumentPageContentRequest, - GetOnlineDocumentPageContentResponse, ) from core.datasource.online_document.online_document_plugin 
import OnlineDocumentDatasourcePlugin from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index c63d837106..49c8ec1e69 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -1,7 +1,7 @@ import datetime import logging -from collections.abc import Mapping import time +from collections.abc import Mapping from typing import Any, cast from sqlalchemy import func diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index a3978c9a5a..1d61677bea 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -17,12 +17,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.datasource.entities.datasource_entities import ( DatasourceInvokeMessage, DatasourceProviderType, - GetOnlineDocumentPagesResponse, - GetWebsiteCrawlResponse, + OnlineDocumentPagesMessage, + WebsiteCrawlMessage, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin -from core.model_runtime.utils.encoders import jsonable_encoder from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.variables.variables import Variable from core.workflow.entities.node_entities import NodeRunResult @@ -44,14 +43,14 @@ from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account from models.dataset import Document, Pipeline, PipelineCustomizedTemplate # type: ignore -from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.enums import 
WorkflowRunTriggeredFrom from models.model import EndUser -from models.oauth import DatasourceProvider from models.workflow import ( Workflow, + WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun, - WorkflowType, WorkflowNodeExecutionModel, + WorkflowType, ) from services.dataset_service import DatasetService from services.datasource_provider_service import DatasourceProviderService @@ -424,6 +423,65 @@ class RagPipelineService: return workflow_node_execution + def run_datasource_workflow_node_status( + self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, datasource_type: str, is_published: bool + ) -> dict: + """ + Run published workflow datasource + """ + if is_published: + # fetch published workflow by app_model + workflow = self.get_published_workflow(pipeline=pipeline) + else: + workflow = self.get_draft_workflow(pipeline=pipeline) + if not workflow: + raise ValueError("Workflow not initialized") + + # run draft workflow node + datasource_node_data = None + start_at = time.perf_counter() + datasource_nodes = workflow.graph_dict.get("nodes", []) + for datasource_node in datasource_nodes: + if datasource_node.get("id") == node_id: + datasource_node_data = datasource_node.get("data", {}) + break + if not datasource_node_data: + raise ValueError("Datasource node data not found") + + from core.datasource.datasource_manager import DatasourceManager + + datasource_runtime = DatasourceManager.get_datasource_runtime( + provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", + datasource_name=datasource_node_data.get("datasource_name"), + tenant_id=pipeline.tenant_id, + datasource_type=DatasourceProviderType(datasource_type), + ) + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_real_datasource_credentials( + tenant_id=pipeline.tenant_id, + provider=datasource_node_data.get('provider_name'), + 
plugin_id=datasource_node_data.get('plugin_id'), +         ) +         if credentials: +             datasource_runtime.runtime.credentials = credentials[0].get("credentials") +         match datasource_type: + +             case DatasourceProviderType.WEBSITE_CRAWL: +                 datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) +                 website_crawl_results: list[WebsiteCrawlMessage] = [] +                 for website_message in datasource_runtime.get_website_crawl( +                     user_id=account.id, +                     datasource_parameters={"job_id": job_id}, +                     provider_type=datasource_runtime.datasource_provider_type(), +                 ): +                     website_crawl_results.append(website_message) +                 return { +                     "result": website_crawl_results[-1].result.web_info_list if website_crawl_results else [], +                     "status": website_crawl_results[-1].result.status if website_crawl_results else "", +                     "provider_type": datasource_node_data.get("provider_type"), +                 } +             case _: +                 raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}")      def run_datasource_workflow_node(          self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str,