jyong 2025-05-23 15:55:41 +08:00
parent a49942b949
commit 64d997fdb0
16 changed files with 176 additions and 198 deletions

View File

@ -8,6 +8,7 @@ from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from models.model import EndUser
import services
from configs import dify_config
from controllers.console import api
@ -44,7 +45,6 @@ from services.errors.llm import InvokeRateLimitError
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
logger = logging.getLogger(__name__)
@ -243,6 +243,7 @@ class DraftRagPipelineRunApi(Resource):
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_info", type=list, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json")
args = parser.parse_args()
try:
@ -313,13 +314,20 @@ class RagPipelineDatasourceNodeRunApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
args = parser.parse_args()
inputs = args.get("inputs")
if inputs is None:
raise ValueError("missing inputs")
rag_pipeline_service = RagPipelineService()
result = rag_pipeline_service.run_datasource_workflow_node(
pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user
pipeline=pipeline,
node_id=node_id,
user_inputs=inputs,
account=current_user,
datasource_type=args.get("datasource_type"),
)
return result
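Since the node-run endpoint now requires datasource_type alongside inputs, client request bodies gain one field. A minimal client sketch, assuming a console route and bearer auth (both illustrative; the exact path is not part of this diff):

import requests  # hypothetical client call against the updated endpoint

resp = requests.post(
    "https://dify.example.com/console/api/rag/pipelines/<pipeline_id>/workflows/draft/datasource/nodes/<node_id>/run",  # path assumed
    headers={"Authorization": "Bearer <console-token>"},
    json={
        "inputs": {},                        # required dict, may be empty
        "datasource_type": "website_crawl",  # required; enum string value assumed
    },
)
print(resp.json())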
@ -648,40 +656,6 @@ class RagPipelineByIdApi(Resource):
return workflow
@setup_required
@login_required
@account_initialization_required
@get_rag_pipeline
def delete(self, pipeline: Pipeline, workflow_id: str):
"""
Delete workflow
"""
# Check permission
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
rag_pipeline_service = RagPipelineService()
# Create a session and manage the transaction
with Session(db.engine) as session:
try:
rag_pipeline_service.delete_workflow(
session=session, workflow_id=workflow_id, tenant_id=pipeline.tenant_id
)
# Commit the transaction in the controller
session.commit()
except WorkflowInUseError as e:
abort(400, description=str(e))
except DraftWorkflowDeletionError as e:
abort(400, description=str(e))
except ValueError as e:
raise NotFound(str(e))
return None, 204
class PublishedRagPipelineSecondStepApi(Resource):
@setup_required
@ -695,8 +669,12 @@ class PublishedRagPipelineSecondStepApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
node_id = request.args.get("node_id", required=True, type=str)
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_published_second_step_parameters(pipeline=pipeline, node_id=node_id)
return {
@ -716,7 +694,12 @@ class DraftRagPipelineSecondStepApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
node_id = request.args.get("node_id", required=True, type=str)
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_draft_second_step_parameters(pipeline=pipeline, node_id=node_id)
@ -777,9 +760,11 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
run_id = str(run_id)
rag_pipeline_service = RagPipelineService()
user = cast("Account | EndUser", current_user)
node_executions = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions(
pipeline=pipeline,
run_id=run_id,
user=user,
)
return {"data": node_executions}
@ -875,9 +860,9 @@ api.add_resource(
)
api.add_resource(
PublishedRagPipelineSecondStepApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/paramters",
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters",
)
api.add_resource(
DraftRagPipelineSecondStepApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/paramters",
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters",
)
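Note on the two parameter hunks above: werkzeug's request.args.get() accepts only default and type keyword arguments, so the removed request.args.get("node_id", required=True, type=str) calls raised TypeError at runtime; RequestParser is the flask_restful way to enforce a required query parameter. A standalone sketch of the pattern (app and route are illustrative):

from flask import Flask
from flask_restful import reqparse

app = Flask(__name__)

@app.route("/processing/parameters")
def processing_parameters():
    parser = reqparse.RequestParser()
    parser.add_argument("node_id", type=str, required=True, location="args")
    args = parser.parse_args()  # aborts with 400 automatically when node_id is missing
    return {"node_id": args["node_id"]}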

View File

@ -99,6 +99,7 @@ class PipelineGenerator(BaseAppGenerator):
)
inputs: Mapping[str, Any] = args["inputs"]
start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
@ -118,7 +119,7 @@ class PipelineGenerator(BaseAppGenerator):
position=position,
account=user,
batch=batch,
document_form=pipeline.dataset.doc_form,
document_form=pipeline.dataset.chunk_structure,
)
db.session.add(document)
db.session.commit()
@ -231,7 +232,7 @@ class PipelineGenerator(BaseAppGenerator):
def single_iteration_generate(
self,
app_model: App,
pipeline: Pipeline,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
@ -255,7 +256,7 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("inputs is required")
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
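The batch identifier built earlier in this file concatenates a second-resolution timestamp with a six-digit random suffix; a quick illustration of the format it produces:

import random
import time

# e.g. "20250523155541" + "483920" -> a 20-character batch id
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
print(batch)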

View File

@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.entities.provider_entities import ProviderConfig
from core.plugin.impl.tool import PluginToolManager
@ -11,9 +10,11 @@ from core.tools.errors import ToolProviderCredentialValidationError
class DatasourcePluginProviderController(ABC):
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
def __init__(self, entity: DatasourceProviderEntityWithPlugin) -> None:
def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
self.entity = entity
self.tenant_id = tenant_id
@property
def need_credentials(self) -> bool:
@ -51,21 +52,6 @@ class DatasourcePluginProviderController(ABC):
"""
pass
def get_datasources(self) -> list[DatasourcePlugin]: # type: ignore
"""
get all datasources
"""
return [
DatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)
for datasource_entity in self.entity.datasources
]
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
"""
validate the format of the credentials of the provider and set the default value if needed
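The base controller now owns tenant_id, so each concrete provider passes it through super().__init__ instead of assigning it locally (see the three provider files below). A minimal sketch of the new constructor contract, with the entity type reduced to a placeholder:

from abc import ABC
from dataclasses import dataclass

@dataclass
class ProviderEntity:  # stand-in for DatasourceProviderEntityWithPlugin
    name: str

class ProviderController(ABC):  # stand-in for DatasourcePluginProviderController
    def __init__(self, entity: ProviderEntity, tenant_id: str) -> None:
        self.entity = entity
        self.tenant_id = tenant_id

class LocalFileController(ProviderController):  # mirrors LocalFileDatasourcePluginProviderController
    def __init__(self, entity: ProviderEntity, plugin_id: str, plugin_unique_identifier: str, tenant_id: str) -> None:
        super().__init__(entity, tenant_id)  # tenant_id is now handled by the base class
        self.plugin_id = plugin_id
        self.plugin_unique_identifier = plugin_unique_identifier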

View File

@ -6,7 +6,11 @@ import contexts
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.common_entities import I18nObject
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.datasource.errors import DatasourceProviderNotFoundError
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.plugin.impl.datasource import PluginDatasourceManager
logger = logging.getLogger(__name__)
@ -19,7 +23,9 @@ class DatasourceManager:
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
@classmethod
def get_datasource_plugin_provider(cls, provider: str, tenant_id: str) -> DatasourcePluginProviderController:
def get_datasource_plugin_provider(
cls, provider: str, tenant_id: str, datasource_type: DatasourceProviderType
) -> DatasourcePluginProviderController:
"""
get the datasource plugin provider
"""
@ -40,12 +46,30 @@ class DatasourceManager:
if not provider_entity:
raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")
controller = DatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
controller = OnlineDocumentDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.WEBSITE_CRAWL:
controller = WebsiteCrawlDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.LOCAL_FILE:
controller = LocalFileDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case _:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
datasource_plugin_providers[provider] = controller
@ -57,6 +81,7 @@ class DatasourceManager:
provider_id: str,
datasource_name: str,
tenant_id: str,
datasource_type: DatasourceProviderType,
) -> DatasourcePlugin:
"""
get the datasource runtime
@ -68,21 +93,10 @@ class DatasourceManager:
:return: the datasource plugin
"""
return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name)
return cls.get_datasource_plugin_provider(
provider_id,
tenant_id,
datasource_type,
).get_datasource(datasource_name)
@classmethod
def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]:
"""
list all the datasource providers
"""
manager = PluginDatasourceManager()
provider_entities = manager.fetch_datasource_providers(tenant_id)
return [
DatasourcePluginProviderController(
entity=provider.declaration,
plugin_id=provider.plugin_id,
plugin_unique_identifier=provider.plugin_unique_identifier,
tenant_id=tenant_id,
)
for provider in provider_entities
]
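With the match-based dispatch, callers must now say which provider flavor they want. A hypothetical call site (module path and identifiers are assumptions; only the signature comes from this diff):

from core.datasource.datasource_manager import DatasourceManager  # path assumed
from core.datasource.entities.datasource_entities import DatasourceProviderType

datasource = DatasourceManager.get_datasource_runtime(
    provider_id="langgenius/notion_datasource",  # illustrative provider id
    datasource_name="notion",                    # illustrative datasource name
    tenant_id="tenant-123",
    datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
)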

View File

@ -251,7 +251,7 @@ class GetOnlineDocumentPageContentRequest(BaseModel):
Get online document page content request
"""
online_document_info_list: list[OnlineDocumentInfo]
online_document_info: OnlineDocumentInfo
class OnlineDocumentPageContent(BaseModel):
@ -259,6 +259,7 @@ class OnlineDocumentPageContent(BaseModel):
Online document page content
"""
workspace_id: str = Field(..., description="The workspace id")
page_id: str = Field(..., description="The page id")
content: str = Field(..., description="The content of the page")
@ -268,7 +269,7 @@ class GetOnlineDocumentPageContentResponse(BaseModel):
Get online document page content response
"""
result: list[OnlineDocumentPageContent]
result: OnlineDocumentPageContent
class GetWebsiteCrawlRequest(BaseModel):
@ -286,7 +287,7 @@ class WebSiteInfo(BaseModel):
"""
source_url: str = Field(..., description="The url of the website")
markdown: str = Field(..., description="The markdown of the website")
content: str = Field(..., description="The content of the website")
title: str = Field(..., description="The title of the website")
description: str = Field(..., description="The description of the website")
@ -296,4 +297,4 @@ class GetWebsiteCrawlResponse(BaseModel):
Get website crawl response
"""
result: list[WebSiteInfo]
result: WebSiteInfo
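Both response models are narrowed from a list to a single object, so downstream code reads result directly instead of indexing. A self-contained sketch of the new shape (fields as declared above):

from pydantic import BaseModel, Field

class WebSiteInfo(BaseModel):
    source_url: str = Field(..., description="The url of the website")
    content: str = Field(..., description="The content of the website")
    title: str = Field(..., description="The title of the website")
    description: str = Field(..., description="The description of the website")

class GetWebsiteCrawlResponse(BaseModel):
    result: WebSiteInfo  # previously list[WebSiteInfo]

resp = GetWebsiteCrawlResponse(
    result=WebSiteInfo(source_url="https://example.com", content="...", title="Example", description="...")
)
print(resp.result.model_dump())  # one page object, no longer a list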

View File

@ -26,12 +26,3 @@ class LocalFileDatasourcePlugin(DatasourcePlugin):
def datasource_provider_type(self) -> DatasourceProviderType:
return DatasourceProviderType.LOCAL_FILE
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
return DatasourcePlugin(
entity=self.entity,
runtime=runtime,
tenant_id=self.tenant_id,
icon=self.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@ -8,15 +8,13 @@ from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlug
class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity)
self.tenant_id = tenant_id
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier

View File

@ -69,12 +69,3 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
def datasource_provider_type(self) -> DatasourceProviderType:
return DatasourceProviderType.ONLINE_DOCUMENT
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
return DatasourcePlugin(
entity=self.entity,
runtime=runtime,
tenant_id=self.tenant_id,
icon=self.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@ -1,20 +1,18 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity)
self.tenant_id = tenant_id
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@ -25,7 +23,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC
"""
return DatasourceProviderType.ONLINE_DOCUMENT
def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore
def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
@ -41,7 +39,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return DatasourcePlugin(
return OnlineDocumentDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,

View File

@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import (
GetWebsiteCrawlResponse,
)
from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format
class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
@ -38,9 +37,7 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
) -> GetWebsiteCrawlResponse:
manager = PluginDatasourceManager()
datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
return manager.invoke_first_step(
return manager.get_website_crawl(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
@ -52,12 +49,3 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
def datasource_provider_type(self) -> DatasourceProviderType:
return DatasourceProviderType.WEBSITE_CRAWL
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
return DatasourcePlugin(
entity=self.entity,
runtime=runtime,
tenant_id=self.tenant_id,
icon=self.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@ -1,20 +1,18 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity)
self.tenant_id = tenant_id
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@ -25,7 +23,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
"""
return DatasourceProviderType.WEBSITE_CRAWL
def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore
def get_datasource(self, datasource_name: str) -> WebsiteCrawlDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
@ -41,7 +39,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return DatasourcePlugin(
return WebsiteCrawlDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,

View File

@ -7,7 +7,6 @@ from typing import Any, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.oauth import OAuthSchema
from core.plugin.entities.parameters import (
PluginParameter,
PluginParameterOption,
@ -350,7 +349,6 @@ class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
plugin_id: Optional[str] = None
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
oauth_schema: Optional[OAuthSchema] = Field(default=None, description="The oauth schema of the tool provider")
class ToolProviderEntityWithPlugin(ToolProviderEntity):

View File

@ -4,6 +4,9 @@ from typing import Any, cast
from core.datasource.entities.datasource_entities import (
DatasourceParameter,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
GetOnlineDocumentPageContentResponse,
GetWebsiteCrawlRequest,
GetWebsiteCrawlResponse,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
@ -54,6 +57,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
provider_id=node_data.provider_id,
datasource_name=node_data.datasource_name,
tenant_id=self.tenant_id,
datasource_type=DatasourceProviderType(node_data.provider_type),
)
except DatasourceNodeError as e:
yield RunCompletedEvent(
@ -82,38 +86,43 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
)
try:
# TODO: handle result
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
result = datasource_runtime._get_online_document_page_content(
user_id=self.user_id,
datasource_parameters=parameters,
provider_type=node_data.provider_type,
online_document_result: GetOnlineDocumentPageContentResponse = (
datasource_runtime._get_online_document_page_content(
user_id=self.user_id,
datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters),
provider_type=datasource_runtime.datasource_provider_type(),
)
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"result": result.result.model_dump(),
"datasource_type": datasource_runtime.datasource_provider_type,
},
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"online_document": online_document_result.result.model_dump(),
"datasource_type": datasource_runtime.datasource_provider_type,
},
)
)
elif datasource_runtime.datasource_provider_type() == DatasourceProviderType.WEBSITE_CRAWL:
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
user_id=self.user_id,
datasource_parameters=parameters,
provider_type=node_data.provider_type,
datasource_parameters=GetWebsiteCrawlRequest(**parameters),
provider_type=datasource_runtime.datasource_provider_type(),
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"result": result.result.model_dump(),
"datasource_type": datasource_runtime.datasource_provider_type,
},
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"website": website_crawl_result.result.model_dump(),
"datasource_type": datasource_runtime.datasource_provider_type,
},
)
)
else:
raise DatasourceNodeError(

View File

@ -360,7 +360,7 @@ class Workflow(Base):
)
@property
def rag_pipeline_variables(self) -> Sequence[Variable]:
def rag_pipeline_variables(self) -> list[dict]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._rag_pipeline_variables is None:
self._rag_pipeline_variables = "{}"

View File

@ -2,12 +2,11 @@ from collections.abc import Mapping
from typing import Any, Union
from configs import dify_config
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from models.dataset import Pipeline
from models.model import Account, App, AppMode, EndUser
from models.model import Account, App, EndUser
from models.workflow import Workflow
from services.rag_pipeline.rag_pipeline import RagPipelineService
@ -57,23 +56,15 @@ class PipelineGenerateService:
return max_active_requests
@classmethod
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_iteration_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
def generate_single_iteration(
cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True
):
workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER)
return PipelineGenerator.convert_to_event_stream(
PipelineGenerator().single_iteration_generate(
pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
elif app_model.mode == AppMode.WORKFLOW.value:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_iteration_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
else:
raise ValueError(f"Invalid app mode {app_model.mode}")
)
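generate_single_iteration now takes a Pipeline instead of an App and always routes through PipelineGenerator, dropping the per-app-mode branching. A hypothetical call site (assumes a Pipeline instance and an Account are already in scope):

from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService

stream = PipelineGenerateService.generate_single_iteration(
    pipeline=pipeline,           # models.dataset.Pipeline, no longer an App
    user=current_user,           # Account
    node_id="iteration-node-1",  # illustrative node id
    args={"inputs": {}},
    streaming=True,
)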
@classmethod
def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True):

View File

@ -3,7 +3,7 @@ import threading
import time
from collections.abc import Callable, Generator, Sequence
from datetime import UTC, datetime
from typing import Any, Optional
from typing import Any, Optional, cast
from uuid import uuid4
from flask_login import current_user
@ -12,6 +12,9 @@ from sqlalchemy.orm import Session
import contexts
from configs import dify_config
from core.datasource.entities.datasource_entities import (
    DatasourceProviderType,
    GetOnlineDocumentPagesRequest,
    GetOnlineDocumentPagesResponse,
    GetWebsiteCrawlRequest,
    GetWebsiteCrawlResponse,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
@ -30,6 +33,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
Workflow,
WorkflowNodeExecution,
@ -394,8 +398,8 @@ class RagPipelineService:
return workflow_node_execution
def run_datasource_workflow_node(
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
) -> WorkflowNodeExecution:
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str
) -> dict:
"""
Run published workflow datasource
"""
@ -416,17 +420,36 @@ class RagPipelineService:
provider_id=datasource_node_data.get("provider_id"),
datasource_name=datasource_node_data.get("datasource_name"),
tenant_id=pipeline.tenant_id,
datasource_type=DatasourceProviderType(datasource_type),
)
result = datasource_runtime._invoke_first_step(
inputs=user_inputs,
provider_type=datasource_node_data.get("provider_type"),
user_id=account.id,
)
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: GetOnlineDocumentPagesResponse = (
datasource_runtime._get_online_document_pages(
user_id=account.id,
datasource_parameters=GetOnlineDocumentPagesRequest(tenant_id=pipeline.tenant_id),
provider_type=datasource_runtime.datasource_provider_type(),
)
)
return {
"result": [page.model_dump() for page in online_document_result.result],
"provider_type": datasource_node_data.get("provider_type"),
}
elif datasource_runtime.datasource_provider_type() == DatasourceProviderType.WEBSITE_CRAWL:
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
user_id=account.id,
datasource_parameters=GetWebsiteCrawlRequest(**user_inputs),
provider_type=datasource_runtime.datasource_provider_type(),
)
return {
"result": website_crawl_result.result.model_dump(),
"provider_type": datasource_node_data.get("provider_type"),
}
else:
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
return {
"result": result,
"provider_type": datasource_node_data.get("provider_type"),
}
def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
@ -587,7 +610,7 @@ class RagPipelineService:
return workflow
def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict:
def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
"""
Get second step parameters of rag pipeline
"""
@ -599,7 +622,7 @@ class RagPipelineService:
# get second step node
rag_pipeline_variables = workflow.rag_pipeline_variables
if not rag_pipeline_variables:
return {}
return []
# get datasource provider
datasource_provider_variables = [
@ -609,7 +632,7 @@ class RagPipelineService:
]
return datasource_provider_variables
def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict:
def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
"""
Get second step parameters of rag pipeline
"""
@ -621,7 +644,7 @@ class RagPipelineService:
# get second step node
rag_pipeline_variables = workflow.rag_pipeline_variables
if not rag_pipeline_variables:
return {}
return []
# get datasource provider
datasource_provider_variables = [
@ -702,6 +725,7 @@ class RagPipelineService:
self,
pipeline: Pipeline,
run_id: str,
user: Account | EndUser,
) -> list[WorkflowNodeExecution]:
"""
Get workflow run node execution list
@ -716,11 +740,16 @@ class RagPipelineService:
# Use the repository to get the node execution
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=db.engine, tenant_id=pipeline.tenant_id, app_id=pipeline.id
session_factory=db.engine,
app_id=pipeline.id,
user=user,
triggered_from=None,
)
# Use the repository to get the node executions with ordering
order_config = OrderConfig(order_by=["index"], order_direction="desc")
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
# Convert domain models to database models
workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]
return list(node_executions)
return workflow_node_executions
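The execution-listing service now threads the calling user into the repository and converts the returned domain models back into DB models before handing them to the controller. A hypothetical call site (pipeline, run id, and user are placeholders assumed in scope):

rag_pipeline_service = RagPipelineService()
executions = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions(
    pipeline=pipeline,   # models.dataset.Pipeline
    run_id="run-uuid",   # workflow run id, stringified by the controller
    user=current_user,   # Account | EndUser, forwarded to the repository
)
# `executions` are WorkflowNodeExecution DB models, ordered by index descending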