From 5fc2bc58a9fba17c812837e208b948f8bd322e2d Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Tue, 27 May 2025 00:01:23 +0800 Subject: [PATCH] r2 --- api/controllers/console/auth/oauth.py | 58 ------- .../datasets/rag_pipeline/datasource_auth.py | 140 ++++++++++++++++ .../datasets/rag_pipeline/datasource_oauth.py | 47 ------ api/core/plugin/impl/datasource.py | 33 +--- .../index_processor/index_processor_base.py | 4 +- .../processor/paragraph_index_processor.py | 4 +- .../processor/parent_child_index_processor.py | 4 +- .../knowledge_index/knowledge_index_node.py | 6 - api/models/oauth.py | 2 - api/services/datasource_provider_service.py | 150 ++++++++++++++++++ 10 files changed, 301 insertions(+), 147 deletions(-) create mode 100644 api/controllers/console/datasets/rag_pipeline/datasource_auth.py delete mode 100644 api/controllers/console/datasets/rag_pipeline/datasource_oauth.py create mode 100644 api/services/datasource_provider_service.py diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index d5e13525d6..ed595f5d3d 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -186,64 +186,6 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): return account -class PluginOauthApi(Resource): - @setup_required - @login_required - @account_initialization_required - def get(self, provider, plugin_id): - # Check user role first - if not current_user.is_editor: - raise Forbidden() - # get all plugin oauth configs - plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by( - provider=provider, - plugin_id=plugin_id - ).first() - if not plugin_oauth_config: - raise NotFound() - oauth_handler = OAuthHandler() - response = oauth_handler.get_authorization_url( - current_user.current_tenant.id, - current_user.id, - plugin_id, - provider, - system_credentials=plugin_oauth_config.system_credentials - ) - return response.model_dump() - -class PluginOauthCallback(Resource): - @setup_required - @login_required - @account_initialization_required - def get(self, provider, plugin_id): - oauth_handler = OAuthHandler() - plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by( - provider=provider, - plugin_id=plugin_id - ).first() - if not plugin_oauth_config: - raise NotFound() - credentials = oauth_handler.get_credentials( - current_user.current_tenant.id, - current_user.id, - plugin_id, - provider, - system_credentials=plugin_oauth_config.system_credentials, - request=request - ) - datasource_provider = DatasourceProvider( - datasource_name=plugin_oauth_config.datasource_name, - plugin_id=plugin_id, - provider=provider, - auth_type="oauth", - encrypted_credentials=credentials - ) - db.session.add(datasource_provider) - db.session.commit() - return redirect(f"{dify_config.CONSOLE_WEB_URL}") - api.add_resource(OAuthLogin, "/oauth/login/") api.add_resource(OAuthCallback, "/oauth/authorize/") -api.add_resource(PluginOauthApi, "/oauth/plugin/provider//plugin/") -api.add_resource(PluginOauthCallback, "/oauth/plugin/callback//plugin/") diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py new file mode 100644 index 0000000000..8894babcf7 --- /dev/null +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -0,0 +1,140 @@ + +from flask import redirect, request +from flask_login import current_user # type: ignore +from flask_restful import ( # type: ignore + Resource, # type: ignore + marshal_with, + reqparse, +) +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden, NotFound + +from configs import dify_config +from controllers.console import api +from controllers.console.wraps import ( + account_initialization_required, + setup_required, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.plugin.impl.datasource import PluginDatasourceManager +from core.plugin.impl.oauth import OAuthHandler +from extensions.ext_database import db +from libs.login import login_required +from models.oauth import DatasourceOauthParamConfig, DatasourceProvider +from services.datasource_provider_service import DatasourceProviderService + + +class DatasourcePluginOauthApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider, plugin_id): + # Check user role first + if not current_user.is_editor: + raise Forbidden() + # get all plugin oauth configs + plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by( + provider=provider, + plugin_id=plugin_id + ).first() + if not plugin_oauth_config: + raise NotFound() + oauth_handler = OAuthHandler() + redirect_url = f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/provider/{provider}/plugin/{plugin_id}/callback" + system_credentials = plugin_oauth_config.system_credentials + if system_credentials: + system_credentials["redirect_url"] = redirect_url + response = oauth_handler.get_authorization_url( + current_user.current_tenant.id, + current_user.id, + plugin_id, + provider, + system_credentials=system_credentials + ) + return response.model_dump() + +class DatasourceOauthCallback(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider, plugin_id): + oauth_handler = OAuthHandler() + plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by( + provider=provider, + plugin_id=plugin_id + ).first() + if not plugin_oauth_config: + raise NotFound() + credentials = oauth_handler.get_credentials( + current_user.current_tenant.id, + current_user.id, + plugin_id, + provider, + system_credentials=plugin_oauth_config.system_credentials, + request=request + ) + datasource_provider = DatasourceProvider( + plugin_id=plugin_id, + provider=provider, + auth_type="oauth", + encrypted_credentials=credentials + ) + db.session.add(datasource_provider) + db.session.commit() + return redirect(f"{dify_config.CONSOLE_WEB_URL}") + +class DatasourceAuth(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider, plugin_id): + if not current_user.is_editor: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + + datasource_provider_service = DatasourceProviderService() + + try: + datasource_provider_service.datasource_provider_credentials_validate( + tenant_id=current_user.current_tenant_id, + provider=provider, + plugin_id=plugin_id, + credentials=args["credentials"] + ) + except CredentialsValidateFailedError as ex: + raise ValueError(str(ex)) + + return {"result": "success"}, 201 + +class DatasourceAuthDeleteApi(Resource): + @setup_required + @login_required + @account_initialization_required + def delete(self, provider, plugin_id): + if not current_user.is_editor: + raise Forbidden() + datasource_provider_service = DatasourceProviderService() + datasource_provider_service.remove_datasource_credentials( + tenant_id=current_user.current_tenant_id, + provider=provider, + plugin_id=plugin_id + ) + return {"result": "success"}, 200 + +# Import Rag Pipeline +api.add_resource( + DatasourcePluginOauthApi, + "/oauth/datasource/provider//plugin/", +) +api.add_resource( + DatasourceOauthCallback, + "/oauth/datasource/provider//plugin//callback", +) +api.add_resource( + DatasourceAuth, + "/auth/datasource/provider//plugin/", +) + diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_oauth.py b/api/controllers/console/datasets/rag_pipeline/datasource_oauth.py deleted file mode 100644 index f4164dea7b..0000000000 --- a/api/controllers/console/datasets/rag_pipeline/datasource_oauth.py +++ /dev/null @@ -1,47 +0,0 @@ - -from flask_login import current_user # type: ignore -from flask_restful import Resource # type: ignore -from werkzeug.exceptions import Forbidden - -from controllers.console import api -from controllers.console.wraps import ( - account_initialization_required, - setup_required, -) -from core.plugin.impl.datasource import PluginDatasourceManager -from libs.login import login_required - - -class DatasourcePluginOauthApi(Resource): - @setup_required - @login_required - @account_initialization_required - def get(self, datasource_type, datasource_name): - # Check user role first - if not current_user.is_editor: - raise Forbidden() - # get all builtin providers - manager = PluginDatasourceManager() - providers = manager.get_provider_oauth_url() - return providers - - - - -# Import Rag Pipeline -api.add_resource( - DatasourcePluginOauthApi, - "/datasource///oauth", -) -api.add_resource( - RagPipelineImportConfirmApi, - "/rag/pipelines/imports//confirm", -) -api.add_resource( - RagPipelineImportCheckDependenciesApi, - "/rag/pipelines/imports//check-dependencies", -) -api.add_resource( - RagPipelineExportApi, - "/rag/pipelines//exports", -) diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 004cf7f9c3..b5212eb719 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -203,7 +203,7 @@ class PluginDatasourceManager(BasePluginClient): """ validate the credentials of the provider """ - tool_provider_id = GenericProviderID(provider) + datasource_provider_id = GenericProviderID(provider) response = self._request_with_plugin_daemon_response_stream( "POST", @@ -212,12 +212,12 @@ class PluginDatasourceManager(BasePluginClient): data={ "user_id": user_id, "data": { - "provider": tool_provider_id.provider_name, + "provider": datasource_provider_id.provider_name, "credentials": credentials, }, }, headers={ - "X-Plugin-ID": tool_provider_id.plugin_id, + "X-Plugin-ID": datasource_provider_id.plugin_id, "Content-Type": "application/json", }, @@ -227,34 +227,11 @@ class PluginDatasourceManager(BasePluginClient): return resp.result return False - - def get_provider_oauth_url(self, datasource_type: str, datasource_name: str, provider: str) -> str: - """ - get the oauth url of the provider - """ - tool_provider_id = GenericProviderID(provider) - response = self._request_with_plugin_daemon_response_stream( - "GET", - "plugin/datasource/oauth", - PluginBasicBooleanResponse, - params={"page": 1, "page_size": 256}, - headers={ - "X-Plugin-ID": tool_provider_id.plugin_id, - "Content-Type": "application/json", - }, - - ) - - for resp in response: - return resp.result - - return False - def _get_local_file_datasource_provider(self) -> dict[str, Any]: return { "id": "langgenius/file/file", - "plugin_id": "langgenius/file", + "plugin_id": "langgenius/file/file", "provider": "langgenius", "plugin_unique_identifier": "langgenius/file:0.0.1@dify", "declaration": { @@ -280,7 +257,7 @@ class PluginDatasourceManager(BasePluginClient): "datasources": [{ "identity": { "author": "langgenius", - "name": "local_file", + "name": "upload-file", "provider": "langgenius", "label": { "zh_Hans": "File", diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index d796c9fd24..50511de16f 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -13,7 +13,7 @@ from core.rag.splitter.fixed_text_splitter import ( FixedRecursiveCharacterTextSplitter, ) from core.rag.splitter.text_splitter import TextSplitter -from models.dataset import Dataset, DatasetProcessRule +from models.dataset import Dataset, Document as DatasetDocument, DatasetProcessRule class BaseIndexProcessor(ABC): @@ -35,7 +35,7 @@ class BaseIndexProcessor(ABC): raise NotImplementedError @abstractmethod - def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]): + def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]): raise NotImplementedError @abstractmethod diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 155aae61d4..5eab77d4f8 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -15,7 +15,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import Document, GeneralStructureChunk from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper -from models.dataset import Dataset, DatasetProcessRule +from models.dataset import Dataset, Document as DatasetDocument, DatasetProcessRule from services.entities.knowledge_entities.knowledge_entities import Rule @@ -128,7 +128,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): docs.append(doc) return docs - def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]): + def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]): paragraph = GeneralStructureChunk(**chunks) documents = [] for content in paragraph.general_chunk: diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 5279864441..6300d05707 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -16,7 +16,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk from extensions.ext_database import db from libs import helper -from models.dataset import ChildChunk, Dataset, DocumentSegment +from models.dataset import ChildChunk, Dataset, Document as DatasetDocument, DocumentSegment from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule @@ -205,7 +205,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): child_nodes.append(child_document) return child_nodes - def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]): + def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]): parent_childs = ParentChildStructureChunk(**chunks) documents = [] for parent_child in parent_childs.parent_child_chunks: diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index d883200c94..25a4112998 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -42,12 +42,6 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]): # extract variables variable = variable_pool.get(node_data.index_chunk_variable_selector) is_preview = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) == InvokeFrom.DEBUGGER - if not isinstance(variable, ObjectSegment): - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs={}, - error="Index chunk variable is not object type.", - ) chunks = variable.value variables = {"chunks": chunks} if not chunks: diff --git a/api/models/oauth.py b/api/models/oauth.py index aee45d7c41..fefe743195 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -17,7 +17,6 @@ class DatasourceOauthParamConfig(Base): # type: ignore[name-defined] ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - datasource_name: Mapped[str] = db.Column(db.String(255), nullable=False) plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False) provider: Mapped[str] = db.Column(db.String(255), nullable=False) system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False) @@ -29,7 +28,6 @@ class DatasourceProvider(Base): db.UniqueConstraint("plugin_id", "provider", name="datasource_provider_plugin_id_provider_idx"), ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - datasource_name: Mapped[str] = db.Column(db.String(255), nullable=False) plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False) provider: Mapped[str] = db.Column(db.String(255), nullable=False) auth_type: Mapped[str] = db.Column(db.String(255), nullable=False) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py new file mode 100644 index 0000000000..fbb9b25a75 --- /dev/null +++ b/api/services/datasource_provider_service.py @@ -0,0 +1,150 @@ +import logging +from typing import Optional + +from flask_login import current_user + +from constants import HIDDEN_VALUE +from core import datasource +from core.datasource.__base import datasource_provider +from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity +from core.helper import encrypter +from core.model_runtime.entities.model_entities import ModelType, ParameterRule +from core.model_runtime.entities.provider_entities import FormType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.datasource import PluginDatasourceManager +from core.provider_manager import ProviderManager +from models.oauth import DatasourceProvider +from models.provider import ProviderType +from services.entities.model_provider_entities import ( + CustomConfigurationResponse, + CustomConfigurationStatus, + DefaultModelResponse, + ModelWithProviderEntityResponse, + ProviderResponse, + ProviderWithModelsResponse, + SimpleProviderEntityResponse, + SystemConfigurationResponse, +) +from extensions.database import db + +logger = logging.getLogger(__name__) + + +class DatasourceProviderService: + """ + Model Provider Service + """ + + def __init__(self) -> None: + self.provider_manager = PluginDatasourceManager() + + def datasource_provider_credentials_validate(self, + tenant_id: str, + provider: str, + plugin_id: str, + credentials: dict) -> None: + """ + validate datasource provider credentials. + + :param tenant_id: + :param provider: + :param credentials: + """ + credential_valid = self.provider_manager.validate_provider_credentials(tenant_id=tenant_id, + user_id=current_user.id, + provider=provider, + credentials=credentials) + if credential_valid: + # Get all provider configurations of the current workspace + datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id).first() + + provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, + provider=provider + ) + if not datasource_provider: + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + credentials[key] = encrypter.encrypt_token(tenant_id, value) + datasource_provider = DatasourceProvider(tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id, + auth_type="api_key", + encrypted_credentials=credentials) + db.session.add(datasource_provider) + db.session.commit() + else: + original_credentials = datasource_provider.encrypted_credentials + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + if value == HIDDEN_VALUE and key in original_credentials: + original_value = encrypter.encrypt_token(tenant_id, original_credentials[key]) + credentials[key] = encrypter.encrypt_token(tenant_id, original_value) + else: + credentials[key] = encrypter.encrypt_token(tenant_id, value) + + datasource_provider.encrypted_credentials = credentials + db.session.commit() + else: + raise CredentialsValidateFailedError() + + def extract_secret_variables(self, tenant_id: str, provider: str) -> list[str]: + """ + Extract secret input form variables. + + :param credential_form_schemas: + :return: + """ + datasource_provider = self.provider_manager.fetch_datasource_provider(tenant_id=tenant_id, provider=provider) + credential_form_schemas = datasource_provider.declaration.credentials_schema + secret_input_form_variables = [] + for credential_form_schema in credential_form_schemas: + if credential_form_schema.type == FormType.SECRET_INPUT: + secret_input_form_variables.append(credential_form_schema.name) + + return secret_input_form_variables + + + def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> Optional[dict]: + """ + get datasource credentials. + + :param tenant_id: workspace id + :param provider: provider name + :param datasource_name: datasource name + :param plugin_id: plugin id + :return: + """ + # Get all provider configurations of the current workspace + datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id).first() + + + + + def remove_datasource_credentials(self, + tenant_id: str, + provider: str, + plugin_id: str) -> None: + """ + remove datasource credentials. + + :param tenant_id: workspace id + :param provider: provider name + :param datasource_name: datasource name + :param plugin_id: plugin id + :return: + """ + # Get all provider configurations of the current workspace + datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id).first() + if datasource_provider: + db.session.delete(datasource_provider) + db.session.commit() +