jyong 2025-06-03 19:02:57 +08:00
parent 309fffd1e4
commit 9cdd2cbb27
35 changed files with 229 additions and 300 deletions

View File

@ -1,4 +1,3 @@
import os
import sys
@ -18,19 +17,19 @@ else:
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
#from gevent import monkey
#
# # gevent
# monkey.patch_all()
#
# from grpc.experimental import gevent as grpc_gevent # type: ignore
#
# # grpc gevent
# grpc_gevent.init_gevent()
# from gevent import monkey
#
# # gevent
# monkey.patch_all()
#
# from grpc.experimental import gevent as grpc_gevent # type: ignore
#
# # grpc gevent
# grpc_gevent.init_gevent()
# import psycogreen.gevent # type: ignore
#
# psycogreen.gevent.patch_psycopg()
# import psycogreen.gevent # type: ignore
#
# psycogreen.gevent.patch_psycopg()
from app_factory import create_app
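The commented-out block above (shown here in both its old and reformatted form) describes when gevent patching is meant to apply: only when Flask debugging is off, with debugpy users able to set GEVENT_SUPPORT=True to debug with gevent. A minimal standalone sketch of that guard, assuming the same environment variables; this illustrates the pattern and is not the project's app.py:

import os

flask_debug = os.environ.get("FLASK_DEBUG", "0")

if flask_debug.lower() in {"false", "0", "no"}:
    # gevent must patch the standard library before other modules import it
    from gevent import monkey

    monkey.patch_all()

    # make grpc cooperative under gevent
    from grpc.experimental import gevent as grpc_gevent  # type: ignore

    grpc_gevent.init_gevent()

    # make psycopg2 cooperative under gevent
    import psycogreen.gevent  # type: ignore

    psycogreen.gevent.patch_psycopg()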

View File

@ -109,8 +109,6 @@ class OAuthDataSourceSync(Resource):
return {"result": "success"}, 200
api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")

View File

@ -4,24 +4,19 @@ from typing import Optional
import requests
from flask import current_app, redirect, request
from flask_login import current_user
from flask_restful import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from werkzeug.exceptions import Unauthorized
from configs import dify_config
from constants.languages import languages
from controllers.console.wraps import account_initialization_required, setup_required
from core.plugin.impl.oauth import OAuthHandler
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import extract_remote_ip
from libs.login import login_required
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account
from models.account import AccountStatus
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountNotFoundError, AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
@ -186,6 +181,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
return account
api.add_resource(OAuthLogin, "/oauth/login/<provider>")
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")

View File

@ -1,12 +1,9 @@
from flask import redirect, request
from flask_login import current_user # type: ignore
from flask_restful import ( # type: ignore
Resource, # type: ignore
marshal_with,
reqparse,
)
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
@ -16,7 +13,6 @@ from controllers.console.wraps import (
setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.impl.oauth import OAuthHandler
from extensions.ext_database import db
from libs.login import login_required
@ -33,10 +29,9 @@ class DatasourcePluginOauthApi(Resource):
if not current_user.is_editor:
raise Forbidden()
# get all plugin oauth configs
plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by(
provider=provider,
plugin_id=plugin_id
).first()
plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
)
if not plugin_oauth_config:
raise NotFound()
oauth_handler = OAuthHandler()
@ -45,24 +40,20 @@ class DatasourcePluginOauthApi(Resource):
if system_credentials:
system_credentials["redirect_url"] = redirect_url
response = oauth_handler.get_authorization_url(
current_user.current_tenant.id,
current_user.id,
plugin_id,
provider,
system_credentials=system_credentials
current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials
)
return response.model_dump()
class DatasourceOauthCallback(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, plugin_id):
oauth_handler = OAuthHandler()
plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by(
provider=provider,
plugin_id=plugin_id
).first()
plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
)
if not plugin_oauth_config:
raise NotFound()
credentials = oauth_handler.get_credentials(
@ -71,18 +62,16 @@ class DatasourceOauthCallback(Resource):
plugin_id,
provider,
system_credentials=plugin_oauth_config.system_credentials,
request=request
request=request,
)
datasource_provider = DatasourceProvider(
plugin_id=plugin_id,
provider=provider,
auth_type="oauth",
encrypted_credentials=credentials
plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials
)
db.session.add(datasource_provider)
db.session.commit()
return redirect(f"{dify_config.CONSOLE_WEB_URL}")
class DatasourceAuth(Resource):
@setup_required
@login_required
@ -99,28 +88,27 @@ class DatasourceAuth(Resource):
try:
datasource_provider_service.datasource_provider_credentials_validate(
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id,
credentials=args["credentials"]
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id,
credentials=args["credentials"],
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}, 201
@setup_required
@login_required
@account_initialization_required
def get(self, provider, plugin_id):
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id
tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
)
return {"result": datasources}, 200
class DatasourceAuthDeleteApi(Resource):
@setup_required
@login_required
@ -130,12 +118,11 @@ class DatasourceAuthDeleteApi(Resource):
raise Forbidden()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id
tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
)
return {"result": "success"}, 200
# Import Rag Pipeline
api.add_resource(
DatasourcePluginOauthApi,
@ -149,4 +136,3 @@ api.add_resource(
DatasourceAuth,
"/auth/datasource/provider/<string:provider>/plugin/<string:plugin_id>",
)

View File

@ -110,6 +110,7 @@ class CustomizedPipelineTemplateApi(Resource):
dsl = yaml.safe_load(template.yaml_content)
return {"data": dsl}, 200
class CustomizedPipelineTemplateApi(Resource):
@setup_required
@login_required
@ -142,6 +143,7 @@ class CustomizedPipelineTemplateApi(Resource):
RagPipelineService.publish_customized_pipeline_template(pipeline_id, args)
return 200
api.add_resource(
PipelineTemplateListApi,
"/rag/pipeline/templates",

View File

@ -540,7 +540,6 @@ class RagPipelineConfigApi(Resource):
@login_required
@account_initialization_required
def get(self, pipeline_id):
return {
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
}

View File

@ -32,7 +32,6 @@ from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchem
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from fields.document_fields import dataset_and_document_fields
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, Pipeline
from models.enums import WorkflowRunTriggeredFrom
@ -55,8 +54,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None:
...
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...
@overload
def generate(
@ -70,8 +68,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any]:
...
) -> Mapping[str, Any]: ...
@overload
def generate(
@ -85,8 +82,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
...
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
def generate(
self,
@ -233,17 +229,19 @@ class PipelineGenerator(BaseAppGenerator):
description=dataset.description,
chunk_structure=dataset.chunk_structure,
).model_dump(),
"documents": [PipelineDocument(
id=document.id,
position=document.position,
data_source_type=document.data_source_type,
data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
name=document.name,
indexing_status=document.indexing_status,
error=document.error,
enabled=document.enabled,
).model_dump() for document in documents
]
"documents": [
PipelineDocument(
id=document.id,
position=document.position,
data_source_type=document.data_source_type,
data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
name=document.name,
indexing_status=document.indexing_status,
error=document.error,
enabled=document.enabled,
).model_dump()
for document in documents
],
}
def _generate(
@ -316,9 +314,7 @@ class PipelineGenerator(BaseAppGenerator):
)
# new thread
worker_thread = threading.Thread(
target=worker_with_context
)
worker_thread = threading.Thread(target=worker_with_context)
worker_thread.start()
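The worker_with_context target passed to threading.Thread above is not shown in this hunk; the usual reason for such a wrapper in a Flask app is to re-enter the application context inside the new thread. A minimal sketch of that pattern, assuming Flask's current_app is available at the call site; the helper name and signature are illustrative, not the project's API:

import threading

from flask import Flask, current_app


def run_in_thread_with_app_context(worker, *args, **kwargs):
    # Capture the real app object now; current_app is only valid in the current context.
    app: Flask = current_app._get_current_object()  # type: ignore[attr-defined]

    def worker_with_context():
        # Re-enter an application context in the new thread so extensions that
        # rely on it (e.g. the database session) keep working.
        with app.app_context():
            worker(*args, **kwargs)

    worker_thread = threading.Thread(target=worker_with_context)
    worker_thread.start()
    return worker_thread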

View File

@ -111,7 +111,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
if rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id and rag_pipeline_variable.variable in inputs:
if (
rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id
and rag_pipeline_variable.variable in inputs
):
rag_pipeline_variables[rag_pipeline_variable.variable] = inputs[rag_pipeline_variable.variable]
variable_pool = VariablePool(
@ -195,7 +198,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
continue
real_run_nodes.append(node)
for edge in edges:
if edge.get("source") in exclude_node_ids :
if edge.get("source") in exclude_node_ids:
continue
real_edges.append(edge)
graph_config = dict(graph_config)

View File

@ -232,6 +232,7 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
"""
RAG Pipeline Application Generate Entity.
"""
# pipeline config
pipeline_config: WorkflowUIBasedAppConfig
datasource_type: str

View File

@ -5,7 +5,6 @@ from pydantic import Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
from core.tools.entities.tool_entities import ToolInvokeFrom
class DatasourceRuntime(BaseModel):

View File

@ -46,7 +46,7 @@ class DatasourceManager:
if not provider_entity:
raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")
match (datasource_type):
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
controller = OnlineDocumentDatasourcePluginProviderController(
entity=provider_entity.declaration,
@ -98,5 +98,3 @@ class DatasourceManager:
tenant_id,
datasource_type,
).get_datasource(datasource_name)

View File

@ -215,7 +215,6 @@ class PluginDatasourceManager(BasePluginClient):
"X-Plugin-ID": datasource_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
@ -233,41 +232,23 @@ class PluginDatasourceManager(BasePluginClient):
"identity": {
"author": "langgenius",
"name": "langgenius/file/file",
"label": {
"zh_Hans": "File",
"en_US": "File",
"pt_BR": "File",
"ja_JP": "File"
},
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
"icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg",
"description": {
"zh_Hans": "File",
"en_US": "File",
"pt_BR": "File",
"ja_JP": "File"
}
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
},
"credentials_schema": [],
"provider_type": "local_file",
"datasources": [{
"identity": {
"author": "langgenius",
"name": "upload-file",
"provider": "langgenius",
"label": {
"zh_Hans": "File",
"en_US": "File",
"pt_BR": "File",
"ja_JP": "File"
}
},
"parameters": [],
"description": {
"zh_Hans": "File",
"en_US": "File",
"pt_BR": "File",
"ja_JP": "File"
"datasources": [
{
"identity": {
"author": "langgenius",
"name": "upload-file",
"provider": "langgenius",
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
},
"parameters": [],
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
}
}]
}
],
},
}

View File

@ -28,12 +28,12 @@ class Jieba(BaseKeyword):
with redis_client.lock(lock_name, timeout=600):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
keyword_number = (
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
)
for text in texts:
keywords = keyword_table_handler.extract_keywords(
text.page_content, keyword_number
)
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
if text.metadata is not None:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(
@ -51,19 +51,17 @@ class Jieba(BaseKeyword):
keyword_table = self._get_dataset_keyword_table()
keywords_list = kwargs.get("keywords_list")
keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
keyword_number = (
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
)
for i in range(len(texts)):
text = texts[i]
if keywords_list:
keywords = keywords_list[i]
if not keywords:
keywords = keyword_table_handler.extract_keywords(
text.page_content, keyword_number
)
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
else:
keywords = keyword_table_handler.extract_keywords(
text.page_content, keyword_number
)
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
if text.metadata is not None:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(
@ -242,7 +240,9 @@ class Jieba(BaseKeyword):
keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
)
else:
keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
keyword_number = (
self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
)
keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
segment.keywords = list(keywords)
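JiebaKeywordTableHandler.extract_keywords is not shown in this diff; for context, a minimal standalone sketch of TF-IDF keyword extraction with the jieba package, which the handler presumably wraps (an assumption), capped at keyword_number keywords per chunk as above:

import jieba.analyse


def extract_keywords(text: str, keyword_number: int) -> set[str]:
    # jieba's TF-IDF extractor; topK caps the number of keywords per chunk
    return set(jieba.analyse.extract_tags(text, topK=keyword_number))


keywords = extract_keywords("Dify is an open-source platform for building LLM applications.", 10)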

View File

@ -38,7 +38,7 @@ class BaseIndexProcessor(ABC):
@abstractmethod
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
raise NotImplementedError
@abstractmethod
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
raise NotImplementedError

View File

@ -15,7 +15,8 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document, GeneralStructureChunk
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset, Document as DatasetDocument, DatasetProcessRule
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
@ -152,13 +153,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword = Keyword(dataset)
keyword.add_texts(documents)
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
paragraph = GeneralStructureChunk(**chunks)
preview = []
for content in paragraph.general_chunks:
preview.append({"content": content})
return {
"preview": preview,
"total_segments": len(paragraph.general_chunks)
}
return {"preview": preview, "total_segments": len(paragraph.general_chunks)}

View File

@ -16,7 +16,8 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
from extensions.ext_database import db
from libs import helper
from models.dataset import ChildChunk, Dataset, Document as DatasetDocument, DocumentSegment
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
@ -239,14 +240,5 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
parent_childs = ParentChildStructureChunk(**chunks)
preview = []
for parent_child in parent_childs.parent_child_chunks:
preview.append(
{
"content": parent_child.parent_content,
"child_chunks": parent_child.child_contents
}
)
return {
"preview": preview,
"total_segments": len(parent_childs.parent_child_chunks)
}
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
return {"preview": preview, "total_segments": len(parent_childs.parent_child_chunks)}

View File

@ -4,7 +4,8 @@ import logging
import re
import threading
import uuid
from typing import Any, Mapping, Optional
from collections.abc import Mapping
from typing import Any, Optional
import pandas as pd
from flask import Flask, current_app
@ -20,7 +21,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset, Document as DatasetDocument
from models.dataset import Dataset
from services.entities.knowledge_entities.knowledge_entities import Rule
@ -160,10 +161,10 @@ class QAIndexProcessor(BaseIndexProcessor):
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]):
pass
def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
return {"preview": chunks}

View File

@ -94,19 +94,26 @@ class FileVariable(FileSegment, Variable):
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass
class RAGPipelineVariable(BaseModel):
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
label: str = Field(description="label")
description: str | None = Field(description="description", default="")
variable: str = Field(description="variable key", default="")
max_length: int | None = Field(description="max length, applicable to text-input, paragraph, and file-list", default=0)
max_length: int | None = Field(
description="max length, applicable to text-input, paragraph, and file-list", default=0
)
default_value: str | None = Field(description="default value", default="")
placeholder: str | None = Field(description="placeholder", default="")
unit: str | None = Field(description="unit, applicable to Number", default="")
tooltips: str | None = Field(description="helpful text", default="")
allowed_file_types: list[str] | None = Field(description="image, document, audio, video, custom.", default_factory=list)
allowed_file_types: list[str] | None = Field(
description="image, document, audio, video, custom.", default_factory=list
)
allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list)
allowed_file_upload_methods: list[str] | None = Field(description="remote_url, local_file, tool_file.", default_factory=list)
allowed_file_upload_methods: list[str] | None = Field(
description="remote_url, local_file, tool_file.", default_factory=list
)
required: bool = Field(description="optional, default false", default=False)
options: list[str] | None = Field(default_factory=list)
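The fields above declare their list-typed defaults with default_factory=list. A minimal sketch of the same declaration style on a trimmed-down model (field subset only, for illustration):

from pydantic import BaseModel, Field


class ExampleRAGPipelineVariable(BaseModel):
    # default_factory builds a fresh list for each instance; scalar defaults stay as plain defaults
    allowed_file_types: list[str] | None = Field(
        description="image, document, audio, video, custom.", default_factory=list
    )
    required: bool = Field(description="optional, default false", default=False)


v = ExampleRAGPipelineVariable()
assert v.allowed_file_types == [] and v.required is False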

View File

@ -49,7 +49,7 @@ class VariablePool(BaseModel):
)
rag_pipeline_variables: Mapping[str, Any] = Field(
description="RAG pipeline variables.",
default_factory=dict,
default_factory=dict,
)
def __init__(

View File

@ -28,6 +28,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
AGENT_LOG = "agent_log"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
DATASOURCE_INFO = "datasource_info"
LOOP_ID = "loop_id"
LOOP_INDEX = "loop_index"
PARALLEL_ID = "parallel_id"
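The new DATASOURCE_INFO member is added to a StrEnum, so it can be used anywhere a plain string key is expected. A minimal sketch, assuming Python 3.11+ and showing only a subset of the members:

from enum import StrEnum


class WorkflowNodeExecutionMetadataKey(StrEnum):
    # illustrative subset of the real enum
    DATASOURCE_INFO = "datasource_info"
    ITERATION_ID = "iteration_id"


metadata = {WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: {"provider": "example"}}

# StrEnum members are str instances, so they compare and serialize as plain strings
assert WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO == "datasource_info"
assert "datasource_info" in {str(k) for k in metadata}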

View File

@ -122,7 +122,6 @@ class Graph(BaseModel):
root_node_configs = []
all_node_id_config_mapping: dict[str, dict] = {}
for node_config in node_configs:
node_id = node_config.get("id")
if not node_id:
@ -142,7 +141,7 @@ class Graph(BaseModel):
(
node_config.get("id")
for node_config in root_node_configs
if node_config.get("data", {}).get("type", "") == NodeType.START.value
if node_config.get("data", {}).get("type", "") == NodeType.START.value
or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value
),
None,

View File

@ -317,10 +317,10 @@ class GraphEngine:
raise e
# It may not be necessary, but it is necessary. :)
if (
self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower()
in [NodeType.END.value, NodeType.KNOWLEDGE_INDEX.value]
):
if self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() in [
NodeType.END.value,
NodeType.KNOWLEDGE_INDEX.value,
]:
break
previous_route_node_state = route_node_state

View File

@ -11,18 +11,19 @@ from core.datasource.online_document.online_document_plugin import OnlineDocumen
from core.file import File
from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment, FileSegment
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import UploadFile
from models.workflow import WorkflowNodeExecutionStatus
from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from .entities import DatasourceNodeData
from .exc import DatasourceNodeError, DatasourceParameterError
@ -54,7 +55,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
try:
from core.datasource.datasource_manager import DatasourceManager
if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set")
@ -66,13 +66,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
)
except DatasourceNodeError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to get datasource runtime: {str(e)}",
error_type=type(e).__name__,
)
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to get datasource runtime: {str(e)}",
error_type=type(e).__name__,
)
# get parameters
datasource_parameters = datasource_runtime.entity.parameters
@ -102,7 +101,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"online_document": online_document_result.result.model_dump(),
"datasource_type": datasource_type,
@ -112,18 +111,16 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"website": datasource_info,
"datasource_type": datasource_type,
"website": datasource_info,
"datasource_type": datasource_type,
},
)
case DatasourceProviderType.LOCAL_FILE:
related_id = datasource_info.get("related_id")
if not related_id:
raise DatasourceNodeError(
"File is not exist"
)
raise DatasourceNodeError("File is not exist")
upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
if not upload_file:
raise ValueError("Invalid upload file Info")
@ -146,26 +143,27 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
# construct new key list
new_key_list = ["file", key]
self._append_variables_recursively(
variable_pool=variable_pool, node_id=self.node_id, variable_key_list=new_key_list, variable_value=value
variable_pool=variable_pool,
node_id=self.node_id,
variable_key_list=new_key_list,
variable_value=value,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file_info": datasource_info,
"datasource_type": datasource_type,
},
)
case _:
raise DatasourceNodeError(
f"Unsupported datasource provider: {datasource_type}"
"file_info": datasource_info,
"datasource_type": datasource_type,
},
)
case _:
raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
except PluginDaemonClientSideError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to transform datasource message: {str(e)}",
error_type=type(e).__name__,
)
@ -173,7 +171,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to invoke datasource: {str(e)}",
error_type=type(e).__name__,
)
@ -227,8 +225,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
def _append_variables_recursively(
self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue
):
"""
Append variables recursively
:param node_id: node id

View File

@ -6,7 +6,6 @@ from typing import Any, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables.segments import ObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
@ -72,8 +71,9 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
process_data=None,
outputs=outputs,
)
results = self._invoke_knowledge_index(dataset=dataset, node_data=node_data, chunks=chunks,
variable_pool=variable_pool)
results = self._invoke_knowledge_index(
dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results
)
@ -96,8 +96,11 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
)
def _invoke_knowledge_index(
self, dataset: Dataset, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any],
variable_pool: VariablePool
self,
dataset: Dataset,
node_data: KnowledgeIndexNodeData,
chunks: Mapping[str, Any],
variable_pool: VariablePool,
) -> Any:
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if not document_id:
@ -116,7 +119,7 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
document.indexing_status = "completed"
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.add(document)
#update document segment status
# update document segment status
db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == document.id,
DocumentSegment.dataset_id == dataset.id,

View File

@ -208,6 +208,7 @@ class Dataset(Base):
"external_knowledge_api_name": external_knowledge_api.name,
"external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
}
@property
def is_published(self):
if self.pipeline_id:
@ -1177,7 +1178,6 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
__tablename__ = "pipeline_customized_templates"
__table_args__ = (

View File

@ -1,4 +1,3 @@
from datetime import datetime
from sqlalchemy.dialects.postgresql import JSONB
@ -21,6 +20,7 @@ class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
provider: Mapped[str] = db.Column(db.String(255), nullable=False)
system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
class DatasourceProvider(Base):
__tablename__ = "datasource_providers"
__table_args__ = (

View File

@ -1,4 +1,3 @@
from calendar import day_abbr
import copy
import datetime
import json
@ -7,7 +6,7 @@ import random
import time
import uuid
from collections import Counter
from typing import Any, Optional, cast
from typing import Any, Optional
from flask_login import current_user
from sqlalchemy import func, select
@ -282,7 +281,6 @@ class DatasetService:
db.session.commit()
return dataset
@staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]:
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
@ -494,10 +492,9 @@ class DatasetService:
return dataset
@staticmethod
def update_rag_pipeline_dataset_settings(session: Session,
dataset: Dataset,
knowledge_configuration: KnowledgeConfiguration,
has_published: bool = False):
def update_rag_pipeline_dataset_settings(
session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
):
dataset = session.merge(dataset)
if not has_published:
dataset.chunk_structure = knowledge_configuration.chunk_structure
@ -616,7 +613,6 @@ class DatasetService:
if action:
deal_dataset_index_update_task.delay(dataset.id, action)
@staticmethod
def delete_dataset(dataset_id, user):
dataset = DatasetService.get_dataset(dataset_id)

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
from flask_login import current_user
@ -22,11 +21,9 @@ class DatasourceProviderService:
def __init__(self) -> None:
self.provider_manager = PluginDatasourceManager()
def datasource_provider_credentials_validate(self,
tenant_id: str,
provider: str,
plugin_id: str,
credentials: dict) -> None:
def datasource_provider_credentials_validate(
self, tenant_id: str, provider: str, plugin_id: str, credentials: dict
) -> None:
"""
validate datasource provider credentials.
@ -34,29 +31,30 @@ class DatasourceProviderService:
:param provider:
:param credentials:
"""
credential_valid = self.provider_manager.validate_provider_credentials(tenant_id=tenant_id,
user_id=current_user.id,
provider=provider,
credentials=credentials)
credential_valid = self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id, user_id=current_user.id, provider=provider, credentials=credentials
)
if credential_valid:
# Get all provider configurations of the current workspace
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id).first()
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.first()
)
provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id,
provider=provider
)
provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
if not datasource_provider:
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider = DatasourceProvider(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
auth_type="api_key",
encrypted_credentials=credentials)
datasource_provider = DatasourceProvider(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
auth_type="api_key",
encrypted_credentials=credentials,
)
db.session.add(datasource_provider)
db.session.commit()
else:
@ -101,11 +99,15 @@ class DatasourceProviderService:
:return:
"""
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = db.session.query(DatasourceProvider).filter(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id
).all()
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
.filter(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.all()
)
if not datasource_providers:
return []
copy_credentials_list = []
@ -128,10 +130,7 @@ class DatasourceProviderService:
return copy_credentials_list
def remove_datasource_credentials(self,
tenant_id: str,
provider: str,
plugin_id: str) -> None:
def remove_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> None:
"""
remove datasource credentials.
@ -140,9 +139,11 @@ class DatasourceProviderService:
:param plugin_id: plugin id
:return:
"""
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id).first()
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.first()
)
if datasource_provider:
db.session.delete(datasource_provider)
db.session.commit()
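The [__HIDDEN__] comment earlier in this file describes the usual secret round-trip: the console echoes a placeholder instead of the stored secret, and the service keeps the original ciphertext when it sees that placeholder. A minimal sketch of that merge step, with a hypothetical helper and an encrypt_token callable standing in for the service's encrypter.encrypt_token(tenant_id, value) call:

HIDDEN_VALUE = "[__HIDDEN__]"


def merge_secret_credentials(incoming: dict, stored: dict, secret_keys: set[str], encrypt_token) -> dict:
    # Keep the previously stored ciphertext when the client echoes the placeholder,
    # otherwise encrypt the newly supplied secret; non-secret fields pass through unchanged.
    merged = {}
    for key, value in incoming.items():
        if key in secret_keys and value == HIDDEN_VALUE and key in stored:
            merged[key] = stored[key]
        elif key in secret_keys:
            merged[key] = encrypt_token(value)
        else:
            merged[key] = value
    return merged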

View File

@ -107,6 +107,7 @@ class KnowledgeConfiguration(BaseModel):
"""
Knowledge Base Configuration.
"""
chunk_structure: str
indexing_technique: Literal["high_quality", "economy"]
embedding_model_provider: Optional[str] = ""

View File

@ -3,7 +3,6 @@ from typing import Any, Union
from configs import dify_config
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from models.dataset import Pipeline
from models.model import Account, App, EndUser

View File

@ -1,13 +1,12 @@
from typing import Optional
from flask_login import current_user
import yaml
from flask_login import current_user
from extensions.ext_database import db
from models.dataset import PipelineCustomizedTemplate
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
@ -43,7 +42,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
)
recommended_pipelines_results = []
for pipeline_customized_template in pipeline_customized_templates:
recommended_pipeline_result = {
"id": pipeline_customized_template.id,
"name": pipeline_customized_template.name,
@ -56,7 +54,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return {"pipeline_templates": recommended_pipelines_results}
@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
"""

View File

@ -38,7 +38,6 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
recommended_pipelines_results = []
for pipeline_built_in_template in pipeline_built_in_templates:
recommended_pipeline_result = {
"id": pipeline_built_in_template.id,
"name": pipeline_built_in_template.name,

View File

@ -35,7 +35,7 @@ from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore
from models.dataset import Pipeline, PipelineCustomizedTemplate # type: ignore
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
@ -57,9 +57,7 @@ from services.rag_pipeline.pipeline_template.pipeline_template_factory import Pi
class RagPipelineService:
@classmethod
def get_pipeline_templates(
cls, type: str = "built-in", language: str = "en-US"
) -> dict:
def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
@ -308,7 +306,7 @@ class RagPipelineService:
session=session,
dataset=dataset,
knowledge_configuration=knowledge_configuration,
has_published=pipeline.is_published
has_published=pipeline.is_published,
)
# return new workflow
return workflow
@ -444,12 +442,10 @@ class RagPipelineService:
)
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: GetOnlineDocumentPagesResponse = (
datasource_runtime._get_online_document_pages(
user_id=account.id,
datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(),
)
online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
user_id=account.id,
datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(),
)
return {
"result": [page.model_dump() for page in online_document_result.result],
@ -470,7 +466,6 @@ class RagPipelineService:
else:
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:
@ -689,8 +684,8 @@ class RagPipelineService:
WorkflowRun.app_id == pipeline.id,
or_(
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value
)
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
),
)
if args.get("last_id"):
@ -763,18 +758,17 @@ class RagPipelineService:
# Use the repository to get the node execution
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=db.engine,
app_id=pipeline.id,
user=user,
triggered_from=None
session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
)
# Use the repository to get the node executions with ordering
order_config = OrderConfig(order_by=["index"], order_direction="desc")
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id,
order_config=order_config,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN)
# Convert domain models to database models
node_executions = repository.get_by_workflow_run(
workflow_run_id=run_id,
order_config=order_config,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
# Convert domain models to database models
workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]
return workflow_node_executions

View File

@ -279,7 +279,11 @@ class RagPipelineDslService:
if node.get("data", {}).get("type") == "knowledge_index":
knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
if dataset and pipeline.is_published and dataset.chunk_structure != knowledge_configuration.chunk_structure:
if (
dataset
and pipeline.is_published
and dataset.chunk_structure != knowledge_configuration.chunk_structure
):
raise ValueError("Chunk structure is not compatible with the published pipeline")
else:
dataset = Dataset(
@ -304,8 +308,7 @@ class RagPipelineDslService:
.filter(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name
== knowledge_configuration.embedding_model,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
)
.order_by(DatasetCollectionBinding.created_at)
@ -323,12 +326,8 @@ class RagPipelineDslService:
db.session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = (
knowledge_configuration.embedding_model
)
dataset.embedding_model_provider = (
knowledge_configuration.embedding_model_provider
)
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.pipeline_id = pipeline.id
@ -443,8 +442,7 @@ class RagPipelineDslService:
.filter(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name
== knowledge_configuration.embedding_model,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
)
.order_by(DatasetCollectionBinding.created_at)
@ -462,12 +460,8 @@ class RagPipelineDslService:
db.session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = (
knowledge_configuration.embedding_model
)
dataset.embedding_model_provider = (
knowledge_configuration.embedding_model_provider
)
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.pipeline_id = pipeline.id
@ -538,7 +532,6 @@ class RagPipelineDslService:
icon_type = "emoji"
icon = str(pipeline_data.get("icon", ""))
# Initialize pipeline based on mode
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
@ -554,7 +547,6 @@ class RagPipelineDslService:
]
rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
@ -576,7 +568,6 @@ class RagPipelineDslService:
pipeline.description = pipeline_data.get("description", pipeline.description)
pipeline.updated_by = account.id
else:
if account.current_tenant_id is None:
raise ValueError("Current tenant is not set")
@ -636,7 +627,6 @@ class RagPipelineDslService:
# commit db session changes
db.session.commit()
return pipeline
@classmethod
@ -874,7 +864,6 @@ class RagPipelineDslService:
except Exception:
return None
@staticmethod
def create_rag_pipeline_dataset(
tenant_id: str,
@ -886,9 +875,7 @@ class RagPipelineDslService:
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
):
raise ValueError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)
raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)

View File

@ -12,12 +12,12 @@ class RagPipelineManageService:
# get all builtin providers
manager = PluginDatasourceManager()
datasources = manager.fetch_datasource_providers(tenant_id)
datasources = manager.fetch_datasource_providers(tenant_id)
for datasource in datasources:
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_datasource_credentials(tenant_id=tenant_id,
provider=datasource.provider,
plugin_id=datasource.plugin_id)
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
if credentials:
datasource.is_authorized = True
return datasources