diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index c0406940a7..bdd40fcabe 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -8,6 +8,7 @@ from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound +from models.model import EndUser import services from configs import dify_config from controllers.console import api @@ -44,7 +45,6 @@ from services.errors.llm import InvokeRateLimitError from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService -from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError logger = logging.getLogger(__name__) @@ -243,6 +243,7 @@ class DraftRagPipelineRunApi(Resource): parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("datasource_type", type=str, required=True, location="json") parser.add_argument("datasource_info", type=list, required=True, location="json") + parser.add_argument("start_node_id", type=str, required=True, location="json") args = parser.parse_args() try: @@ -313,13 +314,20 @@ class RagPipelineDatasourceNodeRunApi(Resource): parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") args = parser.parse_args() inputs = args.get("inputs") + if inputs is None: + raise ValueError("missing inputs") rag_pipeline_service = RagPipelineService() result = rag_pipeline_service.run_datasource_workflow_node( - 
pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user + pipeline=pipeline, + node_id=node_id, + user_inputs=inputs, + account=current_user, + datasource_type=args.get("datasource_type"), ) return result @@ -648,40 +656,6 @@ class RagPipelineByIdApi(Resource): return workflow - @setup_required - @login_required - @account_initialization_required - @get_rag_pipeline - def delete(self, pipeline: Pipeline, workflow_id: str): - """ - Delete workflow - """ - # Check permission - if not current_user.is_editor: - raise Forbidden() - - if not isinstance(current_user, Account): - raise Forbidden() - - rag_pipeline_service = RagPipelineService() - - # Create a session and manage the transaction - with Session(db.engine) as session: - try: - rag_pipeline_service.delete_workflow( - session=session, workflow_id=workflow_id, tenant_id=pipeline.tenant_id - ) - # Commit the transaction in the controller - session.commit() - except WorkflowInUseError as e: - abort(400, description=str(e)) - except DraftWorkflowDeletionError as e: - abort(400, description=str(e)) - except ValueError as e: - raise NotFound(str(e)) - - return None, 204 - class PublishedRagPipelineSecondStepApi(Resource): @setup_required @@ -695,8 +669,12 @@ class PublishedRagPipelineSecondStepApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - node_id = request.args.get("node_id", required=True, type=str) - + parser = reqparse.RequestParser() + parser.add_argument("node_id", type=str, required=True, location="args") + args = parser.parse_args() + node_id = args.get("node_id") + if not node_id: + raise ValueError("Node ID is required") rag_pipeline_service = RagPipelineService() variables = rag_pipeline_service.get_published_second_step_parameters(pipeline=pipeline, node_id=node_id) return { @@ -716,7 +694,12 @@ class DraftRagPipelineSecondStepApi(Resource): # The role of the current user in the ta table 
must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - node_id = request.args.get("node_id", required=True, type=str) + parser = reqparse.RequestParser() + parser.add_argument("node_id", type=str, required=True, location="args") + args = parser.parse_args() + node_id = args.get("node_id") + if not node_id: + raise ValueError("Node ID is required") rag_pipeline_service = RagPipelineService() variables = rag_pipeline_service.get_draft_second_step_parameters(pipeline=pipeline, node_id=node_id) @@ -777,9 +760,11 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource): run_id = str(run_id) rag_pipeline_service = RagPipelineService() + user = cast("Account | EndUser", current_user) node_executions = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions( pipeline=pipeline, run_id=run_id, + user=user, ) return {"data": node_executions} @@ -875,9 +860,9 @@ api.add_resource( ) api.add_resource( PublishedRagPipelineSecondStepApi, - "/rag/pipelines//workflows/published/processing/paramters", + "/rag/pipelines//workflows/published/processing/parameters", ) api.add_resource( DraftRagPipelineSecondStepApi, - "/rag/pipelines//workflows/draft/processing/paramters", + "/rag/pipelines//workflows/draft/processing/parameters", ) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 1e880c700c..c1aa9747d2 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -99,6 +99,7 @@ class PipelineGenerator(BaseAppGenerator): ) inputs: Mapping[str, Any] = args["inputs"] + start_node_id: str = args["start_node_id"] datasource_type: str = args["datasource_type"] datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"] batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) @@ -118,7 +119,7 @@ class PipelineGenerator(BaseAppGenerator): position=position, account=user, batch=batch, - 
document_form=pipeline.dataset.doc_form, + document_form=pipeline.dataset.chunk_structure, ) db.session.add(document) db.session.commit() @@ -231,7 +232,7 @@ class PipelineGenerator(BaseAppGenerator): def single_iteration_generate( self, - app_model: App, + pipeline: Pipeline, workflow: Workflow, node_id: str, user: Account | EndUser, @@ -255,7 +256,7 @@ class PipelineGenerator(BaseAppGenerator): raise ValueError("inputs is required") # convert to app config - app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) # init application generate entity application_generate_entity = WorkflowAppGenerateEntity( diff --git a/api/core/datasource/__base/datasource_provider.py b/api/core/datasource/__base/datasource_provider.py index 1544270d7a..bae39dc8c7 100644 --- a/api/core/datasource/__base/datasource_provider.py +++ b/api/core/datasource/__base/datasource_provider.py @@ -2,7 +2,6 @@ from abc import ABC, abstractmethod from typing import Any from core.datasource.__base.datasource_plugin import DatasourcePlugin -from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType from core.entities.provider_entities import ProviderConfig from core.plugin.impl.tool import PluginToolManager @@ -11,9 +10,11 @@ from core.tools.errors import ToolProviderCredentialValidationError class DatasourcePluginProviderController(ABC): entity: DatasourceProviderEntityWithPlugin + tenant_id: str - def __init__(self, entity: DatasourceProviderEntityWithPlugin) -> None: + def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None: self.entity = entity + self.tenant_id = tenant_id @property def need_credentials(self) -> bool: @@ -51,21 +52,6 @@ class DatasourcePluginProviderController(ABC): """ pass - def 
get_datasources(self) -> list[DatasourcePlugin]: # type: ignore - """ - get all datasources - """ - return [ - DatasourcePlugin( - entity=datasource_entity, - runtime=DatasourceRuntime(tenant_id=self.tenant_id), - tenant_id=self.tenant_id, - icon=self.entity.identity.icon, - plugin_unique_identifier=self.plugin_unique_identifier, - ) - for datasource_entity in self.entity.datasources - ] - def validate_credentials_format(self, credentials: dict[str, Any]) -> None: """ validate the format of the credentials of the provider and set the default value if needed diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index c865b557f9..8c74aeb320 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -6,7 +6,11 @@ import contexts from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.entities.common_entities import I18nObject +from core.datasource.entities.datasource_entities import DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError +from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController +from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController +from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController from core.plugin.impl.datasource import PluginDatasourceManager logger = logging.getLogger(__name__) @@ -19,7 +23,9 @@ class DatasourceManager: _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} @classmethod - def get_datasource_plugin_provider(cls, provider: str, tenant_id: str) -> DatasourcePluginProviderController: + def get_datasource_plugin_provider( + cls, provider: str, tenant_id: str, datasource_type: DatasourceProviderType + ) -> 
DatasourcePluginProviderController: """ get the datasource plugin provider """ @@ -40,12 +46,30 @@ class DatasourceManager: if not provider_entity: raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found") - controller = DatasourcePluginProviderController( - entity=provider_entity.declaration, - plugin_id=provider_entity.plugin_id, - plugin_unique_identifier=provider_entity.plugin_unique_identifier, - tenant_id=tenant_id, - ) + match (datasource_type): + case DatasourceProviderType.ONLINE_DOCUMENT: + controller = OnlineDocumentDatasourcePluginProviderController( + entity=provider_entity.declaration, + plugin_id=provider_entity.plugin_id, + plugin_unique_identifier=provider_entity.plugin_unique_identifier, + tenant_id=tenant_id, + ) + case DatasourceProviderType.WEBSITE_CRAWL: + controller = WebsiteCrawlDatasourcePluginProviderController( + entity=provider_entity.declaration, + plugin_id=provider_entity.plugin_id, + plugin_unique_identifier=provider_entity.plugin_unique_identifier, + tenant_id=tenant_id, + ) + case DatasourceProviderType.LOCAL_FILE: + controller = LocalFileDatasourcePluginProviderController( + entity=provider_entity.declaration, + plugin_id=provider_entity.plugin_id, + plugin_unique_identifier=provider_entity.plugin_unique_identifier, + tenant_id=tenant_id, + ) + case _: + raise ValueError(f"Unsupported datasource type: {datasource_type}") datasource_plugin_providers[provider] = controller @@ -57,6 +81,7 @@ class DatasourceManager: provider_id: str, datasource_name: str, tenant_id: str, + datasource_type: DatasourceProviderType, ) -> DatasourcePlugin: """ get the datasource runtime @@ -68,21 +93,10 @@ class DatasourceManager: :return: the datasource plugin """ - return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name) + return cls.get_datasource_plugin_provider( + provider_id, + tenant_id, + datasource_type, + ).get_datasource(datasource_name) + - @classmethod - def 
list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]: - """ - list all the datasource providers - """ - manager = PluginDatasourceManager() - provider_entities = manager.fetch_datasource_providers(tenant_id) - return [ - DatasourcePluginProviderController( - entity=provider.declaration, - plugin_id=provider.plugin_id, - plugin_unique_identifier=provider.plugin_unique_identifier, - tenant_id=tenant_id, - ) - for provider in provider_entities - ] diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 7b3fadfee8..e9f73d3c18 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -251,7 +251,7 @@ class GetOnlineDocumentPageContentRequest(BaseModel): Get online document page content request """ - online_document_info_list: list[OnlineDocumentInfo] + online_document_info: OnlineDocumentInfo class OnlineDocumentPageContent(BaseModel): @@ -259,6 +259,7 @@ class OnlineDocumentPageContent(BaseModel): Online document page content """ + workspace_id: str = Field(..., description="The workspace id") page_id: str = Field(..., description="The page id") content: str = Field(..., description="The content of the page") @@ -268,7 +269,7 @@ class GetOnlineDocumentPageContentResponse(BaseModel): Get online document page content response """ - result: list[OnlineDocumentPageContent] + result: OnlineDocumentPageContent class GetWebsiteCrawlRequest(BaseModel): @@ -286,7 +287,7 @@ class WebSiteInfo(BaseModel): """ source_url: str = Field(..., description="The url of the website") - markdown: str = Field(..., description="The markdown of the website") + content: str = Field(..., description="The content of the website") title: str = Field(..., description="The title of the website") description: str = Field(..., description="The description of the website") @@ -296,4 +297,4 @@ class 
GetWebsiteCrawlResponse(BaseModel): Get website crawl response """ - result: list[WebSiteInfo] + result: WebSiteInfo diff --git a/api/core/datasource/local_file/local_file_plugin.py b/api/core/datasource/local_file/local_file_plugin.py index a9dced1186..45f4777f44 100644 --- a/api/core/datasource/local_file/local_file_plugin.py +++ b/api/core/datasource/local_file/local_file_plugin.py @@ -26,12 +26,3 @@ class LocalFileDatasourcePlugin(DatasourcePlugin): def datasource_provider_type(self) -> DatasourceProviderType: return DatasourceProviderType.LOCAL_FILE - - def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": - return DatasourcePlugin( - entity=self.entity, - runtime=runtime, - tenant_id=self.tenant_id, - icon=self.icon, - plugin_unique_identifier=self.plugin_unique_identifier, - ) diff --git a/api/core/datasource/local_file/local_file_provider.py b/api/core/datasource/local_file/local_file_provider.py index 79f885dda5..b2b6f51dd3 100644 --- a/api/core/datasource/local_file/local_file_provider.py +++ b/api/core/datasource/local_file/local_file_provider.py @@ -8,15 +8,13 @@ from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlug class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController): entity: DatasourceProviderEntityWithPlugin - tenant_id: str plugin_id: str plugin_unique_identifier: str def __init__( self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str ) -> None: - super().__init__(entity) - self.tenant_id = tenant_id + super().__init__(entity, tenant_id) self.plugin_id = plugin_id self.plugin_unique_identifier = plugin_unique_identifier diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index 197d85ef59..07d7a25160 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ 
b/api/core/datasource/online_document/online_document_plugin.py @@ -69,12 +69,3 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): def datasource_provider_type(self) -> DatasourceProviderType: return DatasourceProviderType.ONLINE_DOCUMENT - - def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": - return DatasourcePlugin( - entity=self.entity, - runtime=runtime, - tenant_id=self.tenant_id, - icon=self.icon, - plugin_unique_identifier=self.plugin_unique_identifier, - ) diff --git a/api/core/datasource/online_document/online_document_provider.py b/api/core/datasource/online_document/online_document_provider.py index 06572880b8..a128b479f4 100644 --- a/api/core/datasource/online_document/online_document_provider.py +++ b/api/core/datasource/online_document/online_document_provider.py @@ -1,20 +1,18 @@ -from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController): entity: DatasourceProviderEntityWithPlugin - tenant_id: str plugin_id: str plugin_unique_identifier: str def __init__( self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str ) -> None: - super().__init__(entity) - self.tenant_id = tenant_id + super().__init__(entity, tenant_id) self.plugin_id = plugin_id self.plugin_unique_identifier = plugin_unique_identifier @@ -25,7 +23,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC """ return DatasourceProviderType.ONLINE_DOCUMENT - def get_datasource(self, 
datasource_name: str) -> DatasourcePlugin: # type: ignore + def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin: # type: ignore """ return datasource with given name """ @@ -41,7 +39,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC if not datasource_entity: raise ValueError(f"Datasource with name {datasource_name} not found") - return DatasourcePlugin( + return OnlineDocumentDatasourcePlugin( entity=datasource_entity, runtime=DatasourceRuntime(tenant_id=self.tenant_id), tenant_id=self.tenant_id, diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py index 8454d1636e..5f92551198 100644 --- a/api/core/datasource/website_crawl/website_crawl_plugin.py +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import ( GetWebsiteCrawlResponse, ) from core.plugin.impl.datasource import PluginDatasourceManager -from core.plugin.utils.converter import convert_parameters_to_plugin_format class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): @@ -38,9 +37,7 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): ) -> GetWebsiteCrawlResponse: manager = PluginDatasourceManager() - datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) - - return manager.invoke_first_step( + return manager.get_website_crawl( tenant_id=self.tenant_id, user_id=user_id, datasource_provider=self.entity.identity.provider, @@ -52,12 +49,3 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): def datasource_provider_type(self) -> DatasourceProviderType: return DatasourceProviderType.WEBSITE_CRAWL - - def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": - return DatasourcePlugin( - entity=self.entity, - runtime=runtime, - tenant_id=self.tenant_id, - icon=self.icon, - plugin_unique_identifier=self.plugin_unique_identifier, - ) diff 
--git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py index 9c6bcdb7c2..95f05fcee0 100644 --- a/api/core/datasource/website_crawl/website_crawl_provider.py +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -1,20 +1,18 @@ -from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController): entity: DatasourceProviderEntityWithPlugin - tenant_id: str plugin_id: str plugin_unique_identifier: str def __init__( self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str ) -> None: - super().__init__(entity) - self.tenant_id = tenant_id + super().__init__(entity, tenant_id) self.plugin_id = plugin_id self.plugin_unique_identifier = plugin_unique_identifier @@ -25,7 +23,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon """ return DatasourceProviderType.WEBSITE_CRAWL - def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore + def get_datasource(self, datasource_name: str) -> WebsiteCrawlDatasourcePlugin: # type: ignore """ return datasource with given name """ @@ -41,7 +39,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon if not datasource_entity: raise ValueError(f"Datasource with name {datasource_name} not found") - return DatasourcePlugin( + return WebsiteCrawlDatasourcePlugin( entity=datasource_entity, runtime=DatasourceRuntime(tenant_id=self.tenant_id), 
tenant_id=self.tenant_id, diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 9884d93e9d..37375f4a71 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -7,7 +7,6 @@ from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator from core.entities.provider_entities import ProviderConfig -from core.plugin.entities.oauth import OAuthSchema from core.plugin.entities.parameters import ( PluginParameter, PluginParameterOption, @@ -350,7 +349,6 @@ class ToolProviderEntity(BaseModel): identity: ToolProviderIdentity plugin_id: Optional[str] = None credentials_schema: list[ProviderConfig] = Field(default_factory=list) - oauth_schema: Optional[OAuthSchema] = Field(default=None, description="The oauth schema of the tool provider") class ToolProviderEntityWithPlugin(ToolProviderEntity): diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index d25784b781..612c5a5a74 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -4,6 +4,9 @@ from typing import Any, cast from core.datasource.entities.datasource_entities import ( DatasourceParameter, DatasourceProviderType, + GetOnlineDocumentPageContentRequest, + GetOnlineDocumentPageContentResponse, + GetWebsiteCrawlRequest, GetWebsiteCrawlResponse, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin @@ -54,6 +57,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): provider_id=node_data.provider_id, datasource_name=node_data.datasource_name, tenant_id=self.tenant_id, + datasource_type=DatasourceProviderType(node_data.provider_type), ) except DatasourceNodeError as e: yield RunCompletedEvent( @@ -82,38 +86,43 @@ class 
DatasourceNode(BaseNode[DatasourceNodeData]): ) try: - # TODO: handle result if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - result = datasource_runtime._get_online_document_page_content( - user_id=self.user_id, - datasource_parameters=parameters, - provider_type=node_data.provider_type, + online_document_result: GetOnlineDocumentPageContentResponse = ( + datasource_runtime._get_online_document_page_content( + user_id=self.user_id, + datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters), + provider_type=datasource_runtime.datasource_provider_type(), + ) ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, - outputs={ - "result": result.result.model_dump(), - "datasource_type": datasource_runtime.datasource_provider_type, - }, + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + "online_document": online_document_result.result.model_dump(), + "datasource_type": datasource_runtime.datasource_provider_type(), + }, + ) ) elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( + website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( user_id=self.user_id, - datasource_parameters=parameters, - provider_type=node_data.provider_type, + datasource_parameters=GetWebsiteCrawlRequest(**parameters), + provider_type=datasource_runtime.datasource_provider_type(), ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - 
inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, - outputs={ - "result": result.result.model_dump(), - "datasource_type": datasource_runtime.datasource_provider_type, - }, + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + "website": website_crawl_result.result.model_dump(), + "datasource_type": datasource_runtime.datasource_provider_type(), + }, + ) ) else: raise DatasourceNodeError( diff --git a/api/models/workflow.py b/api/models/workflow.py index 13ef16442c..b428b1e5db 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -360,7 +360,7 @@ class Workflow(Base): ) @property - def rag_pipeline_variables(self) -> Sequence[Variable]: + def rag_pipeline_variables(self) -> list[dict]: # TODO: find some way to init `self._conversation_variables` when instance created. if self._rag_pipeline_variables is None: self._rag_pipeline_variables = "{}" diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index 089519dd0d..14594be351 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -2,12 +2,11 @@ from collections.abc import Mapping from typing import Any, Union from configs import dify_config -from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom from models.dataset import Pipeline -from models.model import Account, App, AppMode, EndUser +from models.model import Account, App, EndUser from models.workflow import Workflow from services.rag_pipeline.rag_pipeline import RagPipelineService @@ 
-57,23 +56,15 @@ class PipelineGenerateService: return max_active_requests @classmethod - def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): - if app_model.mode == AppMode.ADVANCED_CHAT.value: - workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) - return AdvancedChatAppGenerator.convert_to_event_stream( - AdvancedChatAppGenerator().single_iteration_generate( - app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming - ) + def generate_single_iteration( + cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True + ): + workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER) + return PipelineGenerator.convert_to_event_stream( + PipelineGenerator().single_iteration_generate( + pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming ) - elif app_model.mode == AppMode.WORKFLOW.value: - workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) - return AdvancedChatAppGenerator.convert_to_event_stream( - WorkflowAppGenerator().single_iteration_generate( - app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming - ) - ) - else: - raise ValueError(f"Invalid app mode {app_model.mode}") + ) @classmethod def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True): diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index a0a890aee7..bf582b9d27 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -3,7 +3,7 @@ import threading import time from collections.abc import Callable, Generator, Sequence from datetime import UTC, datetime -from typing import Any, Optional +from typing import Any, Optional, cast from uuid import uuid4 from flask_login import current_user @@ -12,6 +12,9 @@ from 
sqlalchemy.orm import Session import contexts from configs import dify_config +from core.datasource.entities.datasource_entities import DatasourceProviderType, GetOnlineDocumentPagesRequest, GetOnlineDocumentPagesResponse, GetWebsiteCrawlRequest, GetWebsiteCrawlResponse +from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin +from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin from core.model_runtime.utils.encoders import jsonable_encoder from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.variables.variables import Variable @@ -30,6 +33,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.model import EndUser from models.workflow import ( Workflow, WorkflowNodeExecution, @@ -394,8 +398,8 @@ class RagPipelineService: return workflow_node_execution def run_datasource_workflow_node( - self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account - ) -> WorkflowNodeExecution: + self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str + ) -> dict: """ Run published workflow datasource """ @@ -416,17 +420,36 @@ class RagPipelineService: provider_id=datasource_node_data.get("provider_id"), datasource_name=datasource_node_data.get("datasource_name"), tenant_id=pipeline.tenant_id, + datasource_type=DatasourceProviderType(datasource_type), ) - result = datasource_runtime._invoke_first_step( - inputs=user_inputs, - provider_type=datasource_node_data.get("provider_type"), - user_id=account.id, - ) + if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT: + datasource_runtime = 
cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + online_document_result: GetOnlineDocumentPagesResponse = ( + datasource_runtime._get_online_document_pages( + user_id=account.id, + datasource_parameters=GetOnlineDocumentPagesRequest(tenant_id=pipeline.tenant_id), + provider_type=datasource_runtime.datasource_provider_type(), + ) + ) + return { + "result": [page.model_dump() for page in online_document_result.result], + "provider_type": datasource_node_data.get("provider_type"), + } + + elif datasource_runtime.datasource_provider_type() == DatasourceProviderType.WEBSITE_CRAWL: + datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) + website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( + user_id=account.id, + datasource_parameters=GetWebsiteCrawlRequest(**user_inputs), + provider_type=datasource_runtime.datasource_provider_type(), + ) + return { + "result": website_crawl_result.result.model_dump(), + "provider_type": datasource_node_data.get("provider_type"), + } + else: + raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}") - return { - "result": result, - "provider_type": datasource_node_data.get("provider_type"), - } def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] @@ -587,7 +610,7 @@ class RagPipelineService: return workflow - def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict: + def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: """ Get second step parameters of rag pipeline """ @@ -599,7 +622,7 @@ class RagPipelineService: # get second step node rag_pipeline_variables = workflow.rag_pipeline_variables if not rag_pipeline_variables: - return {} + return [] # get datasource provider datasource_provider_variables = [ @@ -609,7 +632,7 @@ class RagPipelineService: ] return datasource_provider_variables 
- def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict: + def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: """ Get second step parameters of rag pipeline """ @@ -621,7 +644,7 @@ class RagPipelineService: # get second step node rag_pipeline_variables = workflow.rag_pipeline_variables if not rag_pipeline_variables: - return {} + return [] # get datasource provider datasource_provider_variables = [ @@ -702,6 +725,7 @@ class RagPipelineService: self, pipeline: Pipeline, run_id: str, + user: Account | EndUser, ) -> list[WorkflowNodeExecution]: """ Get workflow run node execution list @@ -716,11 +740,16 @@ class RagPipelineService: # Use the repository to get the node execution repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=db.engine, tenant_id=pipeline.tenant_id, app_id=pipeline.id + session_factory=db.engine, + app_id=pipeline.id, + user=user, + triggered_from=None ) # Use the repository to get the node executions with ordering order_config = OrderConfig(order_by=["index"], order_direction="desc") node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config) + # Convert domain models to database models + workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions] - return list(node_executions) + return workflow_node_executions