diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py
index d072b8541b..af67eaf761 100644
--- a/api/core/datasource/entities/datasource_entities.py
+++ b/api/core/datasource/entities/datasource_entities.py
@@ -213,10 +213,11 @@ class OnlineDocumentPage(BaseModel):
     """
 
     page_id: str = Field(..., description="The page id")
-    page_title: str = Field(..., description="The page title")
+    page_name: str = Field(..., description="The page name")
     page_icon: Optional[dict] = Field(None, description="The page icon")
    type: str = Field(..., description="The type of the page")
     last_edited_time: str = Field(..., description="The last edited time")
+    parent_id: Optional[str] = Field(None, description="The parent page id")
 
 
 class OnlineDocumentInfo(BaseModel):
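Note on the `OnlineDocumentPage` change: `page_title` becomes `page_name` (the field description is updated to match), and an optional `parent_id` is added so a page can reference its parent. A minimal usage sketch, with sample values that are not from the patch:

```python
from core.datasource.entities.datasource_entities import OnlineDocumentPage

# Constructing the entity after this patch; all values below are samples.
page = OnlineDocumentPage(
    page_id="pg_123",
    page_name="Quarterly Report",  # formerly `page_title`
    page_icon=None,
    type="page",
    last_edited_time="2025-01-01T00:00:00Z",
    parent_id=None,  # new optional field, defaults to None
)

# Callers that still pass only the old `page_title` keyword now fail:
# pydantic ignores unknown fields by default, so the required `page_name`
# is missing and a ValidationError is raised.
```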
diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py
index 66469b43b4..f2539de8f5 100644
--- a/api/core/plugin/impl/datasource.py
+++ b/api/core/plugin/impl/datasource.py
@@ -135,7 +135,7 @@ class PluginDatasourceManager(BasePluginClient):
 
         datasource_provider_id = GenericProviderID(datasource_provider)
 
-        response = self._request_with_plugin_daemon_response_stream(
+        return self._request_with_plugin_daemon_response_stream(
             "POST",
             f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages",
             OnlineDocumentPagesMessage,
@@ -153,7 +153,6 @@ class PluginDatasourceManager(BasePluginClient):
                 "Content-Type": "application/json",
             },
         )
-        yield from response
 
     def get_online_document_page_content(
         self,
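Note on the `get_online_document_pages` change: dropping the trailing `yield from` means the method is no longer a generator function; it returns the daemon's response stream directly instead of re-yielding each item through an extra generator frame. A standalone sketch of the two shapes (toy generators, not the plugin client API):

```python
from collections.abc import Generator


def _stream() -> Generator[int, None, None]:
    # Stands in for _request_with_plugin_daemon_response_stream(...).
    yield from range(3)


def delegating() -> Generator[int, None, None]:
    # Before: a generator function that re-yields every item.
    yield from _stream()


def direct() -> Generator[int, None, None]:
    # After: a plain function that hands the stream straight back.
    return _stream()


# The caller iterates both the same way.
assert list(delegating()) == list(direct()) == [0, 1, 2]
```

One behavioral nuance: with `return`, the statements before the call (e.g. building `GenericProviderID`) run eagerly at call time rather than on the first `next()`; the stream itself stays lazy, assuming the request helper is itself a generator function.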
diff --git a/api/core/rag/entities/event.py b/api/core/rag/entities/event.py
index 59a470c35c..4acb558531 100644
--- a/api/core/rag/entities/event.py
+++ b/api/core/rag/entities/event.py
@@ -11,11 +11,16 @@ class DatasourceStreamEvent(Enum):
     """
 
     PROCESSING = "datasource_processing"
     COMPLETED = "datasource_completed"
+    ERROR = "datasource_error"
 
 class BaseDatasourceEvent(BaseModel):
     pass
 
+class DatasourceErrorEvent(BaseDatasourceEvent):
+    event: str = DatasourceStreamEvent.ERROR.value
+    error: str = Field(..., description="error message")
+
 class DatasourceCompletedEvent(BaseDatasourceEvent):
     event: str = DatasourceStreamEvent.COMPLETED.value
     data: Mapping[str,Any] | list = Field(..., description="result")
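The new `DatasourceErrorEvent` mirrors the other event models: a fixed `event` discriminator plus an `error` payload. A quick sketch of what a consumer receives (the error message is a sample):

```python
from core.rag.entities.event import DatasourceErrorEvent

event = DatasourceErrorEvent(error="connection timed out")  # sample message
assert event.event == "datasource_error"

# model_dump() is the dict shape yielded downstream by the service layer:
# {'event': 'datasource_error', 'error': 'connection timed out'}
print(event.model_dump())
```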
logger.exception("Error in run_datasource_workflow_node.") + yield DatasourceErrorEvent( + error=str(e) + ).model_dump() def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]