mirror of
https://github.com/langgenius/dify.git
synced 2025-11-02 03:43:12 +00:00
r2
This commit is contained in:
parent
a49942b949
commit
64d997fdb0
@ -8,6 +8,7 @@ from flask_restful.inputs import int_range # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
from models.model import EndUser
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
@ -44,7 +45,6 @@ from services.errors.llm import InvokeRateLimitError
|
||||
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -243,6 +243,7 @@ class DraftRagPipelineRunApi(Resource):
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
parser.add_argument("datasource_info", type=list, required=True, location="json")
|
||||
parser.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -313,13 +314,20 @@ class RagPipelineDatasourceNodeRunApi(Resource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs == None:
|
||||
raise ValueError("missing inputs")
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
result = rag_pipeline_service.run_datasource_workflow_node(
|
||||
pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=inputs,
|
||||
account=current_user,
|
||||
datasource_type=args.get("datasource_type"),
|
||||
)
|
||||
|
||||
return result
|
||||
@ -648,40 +656,6 @@ class RagPipelineByIdApi(Resource):
|
||||
|
||||
return workflow
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def delete(self, pipeline: Pipeline, workflow_id: str):
|
||||
"""
|
||||
Delete workflow
|
||||
"""
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
||||
# Create a session and manage the transaction
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
rag_pipeline_service.delete_workflow(
|
||||
session=session, workflow_id=workflow_id, tenant_id=pipeline.tenant_id
|
||||
)
|
||||
# Commit the transaction in the controller
|
||||
session.commit()
|
||||
except WorkflowInUseError as e:
|
||||
abort(400, description=str(e))
|
||||
except DraftWorkflowDeletionError as e:
|
||||
abort(400, description=str(e))
|
||||
except ValueError as e:
|
||||
raise NotFound(str(e))
|
||||
|
||||
return None, 204
|
||||
|
||||
|
||||
class PublishedRagPipelineSecondStepApi(Resource):
|
||||
@setup_required
|
||||
@ -695,8 +669,12 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
node_id = request.args.get("node_id", required=True, type=str)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
variables = rag_pipeline_service.get_published_second_step_parameters(pipeline=pipeline, node_id=node_id)
|
||||
return {
|
||||
@ -716,7 +694,12 @@ class DraftRagPipelineSecondStepApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
node_id = request.args.get("node_id", required=True, type=str)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
variables = rag_pipeline_service.get_draft_second_step_parameters(pipeline=pipeline, node_id=node_id)
|
||||
@ -777,9 +760,11 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
|
||||
run_id = str(run_id)
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
user = cast("Account | EndUser", current_user)
|
||||
node_executions = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions(
|
||||
pipeline=pipeline,
|
||||
run_id=run_id,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return {"data": node_executions}
|
||||
@ -875,9 +860,9 @@ api.add_resource(
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedRagPipelineSecondStepApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/paramters",
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftRagPipelineSecondStepApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/paramters",
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters",
|
||||
)
|
||||
|
||||
@ -99,6 +99,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
)
|
||||
|
||||
inputs: Mapping[str, Any] = args["inputs"]
|
||||
start_node_id: str = args["start_node_id"]
|
||||
datasource_type: str = args["datasource_type"]
|
||||
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
|
||||
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
|
||||
@ -118,7 +119,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
position=position,
|
||||
account=user,
|
||||
batch=batch,
|
||||
document_form=pipeline.dataset.doc_form,
|
||||
document_form=pipeline.dataset.chunk_structure,
|
||||
)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
@ -231,7 +232,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
|
||||
def single_iteration_generate(
|
||||
self,
|
||||
app_model: App,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
@ -255,7 +256,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||
app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
|
||||
@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
@ -11,9 +10,11 @@ from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
class DatasourcePluginProviderController(ABC):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
tenant_id: str
|
||||
|
||||
def __init__(self, entity: DatasourceProviderEntityWithPlugin) -> None:
|
||||
def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
|
||||
self.entity = entity
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
@property
|
||||
def need_credentials(self) -> bool:
|
||||
@ -51,21 +52,6 @@ class DatasourcePluginProviderController(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_datasources(self) -> list[DatasourcePlugin]: # type: ignore
|
||||
"""
|
||||
get all datasources
|
||||
"""
|
||||
return [
|
||||
DatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
for datasource_entity in self.entity.datasources
|
||||
]
|
||||
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the format of the credentials of the provider and set the default value if needed
|
||||
|
||||
@ -6,7 +6,11 @@ import contexts
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.entities.common_entities import I18nObject
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.datasource.errors import DatasourceProviderNotFoundError
|
||||
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
|
||||
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
|
||||
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -19,7 +23,9 @@ class DatasourceManager:
|
||||
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
||||
|
||||
@classmethod
|
||||
def get_datasource_plugin_provider(cls, provider: str, tenant_id: str) -> DatasourcePluginProviderController:
|
||||
def get_datasource_plugin_provider(
|
||||
cls, provider: str, tenant_id: str, datasource_type: DatasourceProviderType
|
||||
) -> DatasourcePluginProviderController:
|
||||
"""
|
||||
get the datasource plugin provider
|
||||
"""
|
||||
@ -40,12 +46,30 @@ class DatasourceManager:
|
||||
if not provider_entity:
|
||||
raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")
|
||||
|
||||
controller = DatasourcePluginProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
plugin_id=provider_entity.plugin_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
match (datasource_type):
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
controller = OnlineDocumentDatasourcePluginProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
plugin_id=provider_entity.plugin_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||
controller = WebsiteCrawlDatasourcePluginProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
plugin_id=provider_entity.plugin_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
case DatasourceProviderType.LOCAL_FILE:
|
||||
controller = LocalFileDatasourcePluginProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
plugin_id=provider_entity.plugin_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||
|
||||
datasource_plugin_providers[provider] = controller
|
||||
|
||||
@ -57,6 +81,7 @@ class DatasourceManager:
|
||||
provider_id: str,
|
||||
datasource_name: str,
|
||||
tenant_id: str,
|
||||
datasource_type: DatasourceProviderType,
|
||||
) -> DatasourcePlugin:
|
||||
"""
|
||||
get the datasource runtime
|
||||
@ -68,21 +93,10 @@ class DatasourceManager:
|
||||
|
||||
:return: the datasource plugin
|
||||
"""
|
||||
return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name)
|
||||
return cls.get_datasource_plugin_provider(
|
||||
provider_id,
|
||||
tenant_id,
|
||||
datasource_type,
|
||||
).get_datasource(datasource_name)
|
||||
|
||||
|
||||
@classmethod
|
||||
def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]:
|
||||
"""
|
||||
list all the datasource providers
|
||||
"""
|
||||
manager = PluginDatasourceManager()
|
||||
provider_entities = manager.fetch_datasource_providers(tenant_id)
|
||||
return [
|
||||
DatasourcePluginProviderController(
|
||||
entity=provider.declaration,
|
||||
plugin_id=provider.plugin_id,
|
||||
plugin_unique_identifier=provider.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for provider in provider_entities
|
||||
]
|
||||
|
||||
@ -251,7 +251,7 @@ class GetOnlineDocumentPageContentRequest(BaseModel):
|
||||
Get online document page content request
|
||||
"""
|
||||
|
||||
online_document_info_list: list[OnlineDocumentInfo]
|
||||
online_document_info: OnlineDocumentInfo
|
||||
|
||||
|
||||
class OnlineDocumentPageContent(BaseModel):
|
||||
@ -259,6 +259,7 @@ class OnlineDocumentPageContent(BaseModel):
|
||||
Online document page content
|
||||
"""
|
||||
|
||||
workspace_id: str = Field(..., description="The workspace id")
|
||||
page_id: str = Field(..., description="The page id")
|
||||
content: str = Field(..., description="The content of the page")
|
||||
|
||||
@ -268,7 +269,7 @@ class GetOnlineDocumentPageContentResponse(BaseModel):
|
||||
Get online document page content response
|
||||
"""
|
||||
|
||||
result: list[OnlineDocumentPageContent]
|
||||
result: OnlineDocumentPageContent
|
||||
|
||||
|
||||
class GetWebsiteCrawlRequest(BaseModel):
|
||||
@ -286,7 +287,7 @@ class WebSiteInfo(BaseModel):
|
||||
"""
|
||||
|
||||
source_url: str = Field(..., description="The url of the website")
|
||||
markdown: str = Field(..., description="The markdown of the website")
|
||||
content: str = Field(..., description="The content of the website")
|
||||
title: str = Field(..., description="The title of the website")
|
||||
description: str = Field(..., description="The description of the website")
|
||||
|
||||
@ -296,4 +297,4 @@ class GetWebsiteCrawlResponse(BaseModel):
|
||||
Get website crawl response
|
||||
"""
|
||||
|
||||
result: list[WebSiteInfo]
|
||||
result: WebSiteInfo
|
||||
|
||||
@ -26,12 +26,3 @@ class LocalFileDatasourcePlugin(DatasourcePlugin):
|
||||
|
||||
def datasource_provider_type(self) -> DatasourceProviderType:
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
||||
return DatasourcePlugin(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
@ -8,15 +8,13 @@ from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlug
|
||||
|
||||
class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
super().__init__(entity)
|
||||
self.tenant_id = tenant_id
|
||||
super().__init__(entity, tenant_id)
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
|
||||
@ -69,12 +69,3 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
|
||||
|
||||
def datasource_provider_type(self) -> DatasourceProviderType:
|
||||
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||
|
||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
||||
return DatasourcePlugin(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
@ -1,20 +1,18 @@
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
|
||||
|
||||
class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
super().__init__(entity)
|
||||
self.tenant_id = tenant_id
|
||||
super().__init__(entity, tenant_id)
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@ -25,7 +23,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC
|
||||
"""
|
||||
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore
|
||||
def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
@ -41,7 +39,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return DatasourcePlugin(
|
||||
return OnlineDocumentDatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
|
||||
@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import (
|
||||
GetWebsiteCrawlResponse,
|
||||
)
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
|
||||
|
||||
class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
||||
@ -38,9 +37,7 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
||||
) -> GetWebsiteCrawlResponse:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
|
||||
|
||||
return manager.invoke_first_step(
|
||||
return manager.get_website_crawl(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
@ -52,12 +49,3 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
||||
|
||||
def datasource_provider_type(self) -> DatasourceProviderType:
|
||||
return DatasourceProviderType.WEBSITE_CRAWL
|
||||
|
||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
||||
return DatasourcePlugin(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
@ -1,20 +1,18 @@
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
|
||||
|
||||
|
||||
class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
super().__init__(entity)
|
||||
self.tenant_id = tenant_id
|
||||
super().__init__(entity, tenant_id)
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@ -25,7 +23,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
|
||||
"""
|
||||
return DatasourceProviderType.WEBSITE_CRAWL
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore
|
||||
def get_datasource(self, datasource_name: str) -> WebsiteCrawlDatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
@ -41,7 +39,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return DatasourcePlugin(
|
||||
return WebsiteCrawlDatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
|
||||
@ -7,7 +7,6 @@ from typing import Any, Optional, Union
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.oauth import OAuthSchema
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameter,
|
||||
PluginParameterOption,
|
||||
@ -350,7 +349,6 @@ class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
plugin_id: Optional[str] = None
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||
oauth_schema: Optional[OAuthSchema] = Field(default=None, description="The oauth schema of the tool provider")
|
||||
|
||||
|
||||
class ToolProviderEntityWithPlugin(ToolProviderEntity):
|
||||
|
||||
@ -4,6 +4,9 @@ from typing import Any, cast
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceParameter,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
GetOnlineDocumentPageContentResponse,
|
||||
GetWebsiteCrawlRequest,
|
||||
GetWebsiteCrawlResponse,
|
||||
)
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
@ -54,6 +57,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
provider_id=node_data.provider_id,
|
||||
datasource_name=node_data.datasource_name,
|
||||
tenant_id=self.tenant_id,
|
||||
datasource_type=DatasourceProviderType(node_data.provider_type),
|
||||
)
|
||||
except DatasourceNodeError as e:
|
||||
yield RunCompletedEvent(
|
||||
@ -82,38 +86,43 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
)
|
||||
|
||||
try:
|
||||
# TODO: handle result
|
||||
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
result = datasource_runtime._get_online_document_page_content(
|
||||
user_id=self.user_id,
|
||||
datasource_parameters=parameters,
|
||||
provider_type=node_data.provider_type,
|
||||
online_document_result: GetOnlineDocumentPageContentResponse = (
|
||||
datasource_runtime._get_online_document_page_content(
|
||||
user_id=self.user_id,
|
||||
datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters),
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"result": result.result.model_dump(),
|
||||
"datasource_type": datasource_runtime.datasource_provider_type,
|
||||
},
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"online_document": online_document_result.result.model_dump(),
|
||||
"datasource_type": datasource_runtime.datasource_provider_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL:
|
||||
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
|
||||
result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
|
||||
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
|
||||
user_id=self.user_id,
|
||||
datasource_parameters=parameters,
|
||||
provider_type=node_data.provider_type,
|
||||
datasource_parameters=GetWebsiteCrawlRequest(**parameters),
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"result": result.result.model_dump(),
|
||||
"datasource_type": datasource_runtime.datasource_provider_type,
|
||||
},
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"website": website_crawl_result.result.model_dump(),
|
||||
"datasource_type": datasource_runtime.datasource_provider_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise DatasourceNodeError(
|
||||
|
||||
@ -360,7 +360,7 @@ class Workflow(Base):
|
||||
)
|
||||
|
||||
@property
|
||||
def rag_pipeline_variables(self) -> Sequence[Variable]:
|
||||
def rag_pipeline_variables(self) -> list[dict]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._rag_pipeline_variables is None:
|
||||
self._rag_pipeline_variables = "{}"
|
||||
|
||||
@ -2,12 +2,11 @@ from collections.abc import Mapping
|
||||
from typing import Any, Union
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.dataset import Pipeline
|
||||
from models.model import Account, App, AppMode, EndUser
|
||||
from models.model import Account, App, EndUser
|
||||
from models.workflow import Workflow
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
@ -57,23 +56,15 @@ class PipelineGenerateService:
|
||||
return max_active_requests
|
||||
|
||||
@classmethod
|
||||
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().single_iteration_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
def generate_single_iteration(
|
||||
cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True
|
||||
):
|
||||
workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER)
|
||||
return PipelineGenerator.convert_to_event_stream(
|
||||
PipelineGenerator().single_iteration_generate(
|
||||
pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_iteration_generate(
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||
|
||||
@ -3,7 +3,7 @@ import threading
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from flask_login import current_user
|
||||
@ -12,6 +12,9 @@ from sqlalchemy.orm import Session
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType, GetOnlineDocumentPagesRequest, GetOnlineDocumentPagesResponse, GetWebsiteCrawlRequest, GetWebsiteCrawlResponse
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.variables.variables import Variable
|
||||
@ -30,6 +33,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
from models.model import EndUser
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
@ -394,8 +398,8 @@ class RagPipelineService:
|
||||
return workflow_node_execution
|
||||
|
||||
def run_datasource_workflow_node(
|
||||
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
|
||||
) -> WorkflowNodeExecution:
|
||||
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str
|
||||
) -> dict:
|
||||
"""
|
||||
Run published workflow datasource
|
||||
"""
|
||||
@ -416,17 +420,36 @@ class RagPipelineService:
|
||||
provider_id=datasource_node_data.get("provider_id"),
|
||||
datasource_name=datasource_node_data.get("datasource_name"),
|
||||
tenant_id=pipeline.tenant_id,
|
||||
datasource_type=DatasourceProviderType(datasource_type),
|
||||
)
|
||||
result = datasource_runtime._invoke_first_step(
|
||||
inputs=user_inputs,
|
||||
provider_type=datasource_node_data.get("provider_type"),
|
||||
user_id=account.id,
|
||||
)
|
||||
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
online_document_result: GetOnlineDocumentPagesResponse = (
|
||||
datasource_runtime._get_online_document_pages(
|
||||
user_id=account.id,
|
||||
datasource_parameters=GetOnlineDocumentPagesRequest(tenant_id=pipeline.tenant_id),
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
)
|
||||
return {
|
||||
"result": [page.model_dump() for page in online_document_result.result],
|
||||
"provider_type": datasource_node_data.get("provider_type"),
|
||||
}
|
||||
|
||||
elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL:
|
||||
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
|
||||
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
|
||||
user_id=account.id,
|
||||
datasource_parameters=GetWebsiteCrawlRequest(**user_inputs),
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
return {
|
||||
"result": website_crawl_result.result.model_dump(),
|
||||
"provider_type": datasource_node_data.get("provider_type"),
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
|
||||
|
||||
return {
|
||||
"result": result,
|
||||
"provider_type": datasource_node_data.get("provider_type"),
|
||||
}
|
||||
|
||||
def run_free_workflow_node(
|
||||
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
|
||||
@ -587,7 +610,7 @@ class RagPipelineService:
|
||||
|
||||
return workflow
|
||||
|
||||
def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict:
|
||||
def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
|
||||
"""
|
||||
Get second step parameters of rag pipeline
|
||||
"""
|
||||
@ -599,7 +622,7 @@ class RagPipelineService:
|
||||
# get second step node
|
||||
rag_pipeline_variables = workflow.rag_pipeline_variables
|
||||
if not rag_pipeline_variables:
|
||||
return {}
|
||||
return []
|
||||
|
||||
# get datasource provider
|
||||
datasource_provider_variables = [
|
||||
@ -609,7 +632,7 @@ class RagPipelineService:
|
||||
]
|
||||
return datasource_provider_variables
|
||||
|
||||
def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict:
|
||||
def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
|
||||
"""
|
||||
Get second step parameters of rag pipeline
|
||||
"""
|
||||
@ -621,7 +644,7 @@ class RagPipelineService:
|
||||
# get second step node
|
||||
rag_pipeline_variables = workflow.rag_pipeline_variables
|
||||
if not rag_pipeline_variables:
|
||||
return {}
|
||||
return []
|
||||
|
||||
# get datasource provider
|
||||
datasource_provider_variables = [
|
||||
@ -702,6 +725,7 @@ class RagPipelineService:
|
||||
self,
|
||||
pipeline: Pipeline,
|
||||
run_id: str,
|
||||
user: Account | EndUser,
|
||||
) -> list[WorkflowNodeExecution]:
|
||||
"""
|
||||
Get workflow run node execution list
|
||||
@ -716,11 +740,16 @@ class RagPipelineService:
|
||||
|
||||
# Use the repository to get the node execution
|
||||
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=db.engine, tenant_id=pipeline.tenant_id, app_id=pipeline.id
|
||||
session_factory=db.engine,
|
||||
app_id=pipeline.id,
|
||||
user=user,
|
||||
triggered_from=None
|
||||
)
|
||||
|
||||
# Use the repository to get the node executions with ordering
|
||||
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
|
||||
# Convert domain models to database models
|
||||
workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]
|
||||
|
||||
return list(node_executions)
|
||||
return workflow_node_executions
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user