diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
index 36fefb3a5c..49729d5499 100644
--- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
+++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
@@ -177,7 +177,8 @@ class DatasourceAuthUpdateDeleteApi(Resource):
             raise ValueError(str(ex))
 
         return {"result": "success"}, 201
-    
+
+
 class DatasourceAuthListApi(Resource):
     @setup_required
     @login_required
@@ -189,6 +190,7 @@ class DatasourceAuthListApi(Resource):
         )
         return {"result": datasources}, 200
 
+
 # Import Rag Pipeline
 api.add_resource(
     DatasourcePluginOauthApi,
@@ -211,4 +213,4 @@ api.add_resource(
 api.add_resource(
     DatasourceAuthListApi,
     "/auth/plugin/datasource/list",
-)
\ No newline at end of file
+)
diff --git a/api/core/app/apps/pipeline/pipeline_config_manager.py b/api/core/app/apps/pipeline/pipeline_config_manager.py
index a86cad78dc..72b7f4bef6 100644
--- a/api/core/app/apps/pipeline/pipeline_config_manager.py
+++ b/api/core/app/apps/pipeline/pipeline_config_manager.py
@@ -26,7 +26,9 @@ class PipelineConfigManager(BaseAppConfigManager):
             app_id=pipeline.id,
             app_mode=AppMode.RAG_PIPELINE,
             workflow_id=workflow.id,
-            rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow, start_node_id=start_node_id),
+            rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(
+                workflow=workflow, start_node_id=start_node_id
+            ),
         )
 
         return pipeline_config
diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py
index e2341ea391..8e98c67f12 100644
--- a/api/core/app/apps/pipeline/pipeline_generator.py
+++ b/api/core/app/apps/pipeline/pipeline_generator.py
@@ -7,7 +7,7 @@ import threading
 import time
 import uuid
 from collections.abc import Generator, Mapping
-from typing import Any, Literal, Optional, Union, overload
+from typing import Any, Literal, Optional, Union, cast, overload
 
 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -24,6 +24,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
 from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
 from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
+from core.datasource.entities.datasource_entities import (
+    DatasourceProviderType,
+    OnlineDriveBrowseFilesRequest,
+)
+from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
 from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.rag.index_processor.constant.built_in_field import BuiltInField
@@ -39,6 +44,7 @@ from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import AppMode
 from services.dataset_service import DocumentService
+from services.datasource_provider_service import DatasourceProviderService
 
 logger = logging.getLogger(__name__)
 
@@ -105,13 +111,13 @@ class PipelineGenerator(BaseAppGenerator):
         inputs: Mapping[str, Any] = args["inputs"]
         start_node_id: str = args["start_node_id"]
         datasource_type: str = args["datasource_type"]
-        datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
+        datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
+            datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
+        )
         batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
         # convert to app config
         pipeline_config = PipelineConfigManager.get_pipeline_config(
-            pipeline=pipeline,
-            workflow=workflow,
-            start_node_id=start_node_id
+            pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
         )
         documents = []
         if invoke_from == InvokeFrom.PUBLISHED:
@@ -353,9 +359,9 @@ class PipelineGenerator(BaseAppGenerator):
             raise ValueError("inputs is required")
 
         # convert to app config
-        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
-                                                                    workflow=workflow,
-                                                                    start_node_id=args.get("start_node_id","shared"))
+        pipeline_config = PipelineConfigManager.get_pipeline_config(
+            pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
+        )
 
         dataset = pipeline.dataset
         if not dataset:
@@ -440,9 +446,9 @@
             raise ValueError("Pipeline dataset is required")
 
         # convert to app config
-        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
-                                                                    workflow=workflow,
-                                                                    start_node_id=args.get("start_node_id","shared"))
+        pipeline_config = PipelineConfigManager.get_pipeline_config(
+            pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
+        )
 
         # init application generate entity
         application_generate_entity = RagPipelineGenerateEntity(
@@ -633,3 +639,107 @@ class PipelineGenerator(BaseAppGenerator):
         if doc_metadata:
             document.doc_metadata = doc_metadata
         return document
+
+    def _format_datasource_info_list(
+        self,
+        datasource_type: str,
+        datasource_info_list: list[Mapping[str, Any]],
+        pipeline: Pipeline,
+        workflow: Workflow,
+        start_node_id: str,
+        user: Union[Account, EndUser],
+    ) -> list[Mapping[str, Any]]:
+        """
+        Format datasource info list.
+        """
+        if datasource_type == "online_drive":
+            all_files = []
+            datasource_node_data = None
+            datasource_nodes = workflow.graph_dict.get("nodes", [])
+            for datasource_node in datasource_nodes:
+                if datasource_node.get("id") == start_node_id:
+                    datasource_node_data = datasource_node.get("data", {})
+                    break
+            if not datasource_node_data:
+                raise ValueError("Datasource node data not found")
+
+            from core.datasource.datasource_manager import DatasourceManager
+
+            datasource_runtime = DatasourceManager.get_datasource_runtime(
+                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
+                datasource_name=datasource_node_data.get("datasource_name"),
+                tenant_id=pipeline.tenant_id,
+                datasource_type=DatasourceProviderType(datasource_type),
+            )
+            datasource_provider_service = DatasourceProviderService()
+            credentials = datasource_provider_service.get_real_datasource_credentials(
+                tenant_id=pipeline.tenant_id,
+                provider=datasource_node_data.get("provider_name"),
+                plugin_id=datasource_node_data.get("plugin_id"),
+            )
+            if credentials:
+                datasource_runtime.runtime.credentials = credentials[0].get("credentials")
+            datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
+
+            for datasource_info in datasource_info_list:
+                if datasource_info.get("key") and datasource_info.get("key", "").endswith("/"):
+                    # get all files in the folder
+                    self._get_files_in_folder(
+                        datasource_runtime,
+                        datasource_info.get("key", ""),
+                        None,
+                        datasource_info.get("bucket", None),
+                        user.id,
+                        all_files,
+                        datasource_info,
+                    )
+            return all_files
+        else:
+            return datasource_info_list
+
+    def _get_files_in_folder(
+        self,
+        datasource_runtime: OnlineDriveDatasourcePlugin,
+        prefix: str,
+        start_after: Optional[str],
+        bucket: Optional[str],
+        user_id: str,
+        all_files: list,
+        datasource_info: Mapping[str, Any],
+    ):
+        """
+        Get files in a folder.
+        """
+        result_generator = datasource_runtime.online_drive_browse_files(
+            user_id=user_id,
+            request=OnlineDriveBrowseFilesRequest(
+                bucket=bucket,
+                prefix=prefix,
+                max_keys=20,
+                start_after=start_after,
+            ),
+            provider_type=datasource_runtime.datasource_provider_type(),
+        )
+        is_truncated = False
+        last_file_key = None
+        for result in result_generator:
+            for files in result.result:
+                for file in files.files:
+                    if file.key.endswith("/"):
+                        self._get_files_in_folder(
+                            datasource_runtime, file.key, None, bucket, user_id, all_files, datasource_info
+                        )
+                    else:
+                        all_files.append(
+                            {
+                                "key": file.key,
+                                "bucket": bucket,
+                            }
+                        )
+                        last_file_key = file.key
+                is_truncated = files.is_truncated
+
+        if is_truncated:
+            self._get_files_in_folder(
+                datasource_runtime, prefix, last_file_key, bucket, user_id, all_files, datasource_info
+            )
diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py
index 858bb79a9b..af50a58212 100644
--- a/api/core/datasource/datasource_file_manager.py
+++ b/api/core/datasource/datasource_file_manager.py
@@ -1,10 +1,10 @@
 import base64
-from datetime import datetime
 import hashlib
 import hmac
 import logging
 import os
 import time
+from datetime import datetime
 from mimetypes import guess_extension, guess_type
 from typing import Optional, Union
 from uuid import uuid4
diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py
index 12d2a71d69..6c93865264 100644
--- a/api/core/datasource/utils/message_transformer.py
+++ b/api/core/datasource/utils/message_transformer.py
@@ -63,7 +63,7 @@ class DatasourceFileMessageTransformer:
                 mimetype = meta.get("mime_type")
                 if not mimetype:
                     mimetype = guess_type(filename)[0] or "application/octet-stream"
-                
+
                 # if message is str, encode it to bytes
 
                 if not isinstance(message.message, DatasourceMessage.BlobMessage):
diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py
index ee58840c84..c2bc1ffbe3 100644
--- a/api/core/file/file_manager.py
+++ b/api/core/file/file_manager.py
@@ -72,9 +72,11 @@ def to_prompt_message_content(
 
 
 def download(f: File, /):
-    if f.transfer_method in (FileTransferMethod.TOOL_FILE,
-                             FileTransferMethod.LOCAL_FILE,
-                             FileTransferMethod.DATASOURCE_FILE):
+    if f.transfer_method in (
+        FileTransferMethod.TOOL_FILE,
+        FileTransferMethod.LOCAL_FILE,
+        FileTransferMethod.DATASOURCE_FILE,
+    ):
         return _download_file_content(f._storage_key)
     elif f.transfer_method == FileTransferMethod.REMOTE_URL:
         response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py
index 3549735483..e1c14df4e8 100644
--- a/api/core/plugin/impl/datasource.py
+++ b/api/core/plugin/impl/datasource.py
@@ -56,7 +56,7 @@ class PluginDatasourceManager(BasePluginClient):
                     tool.identity.provider = provider.declaration.identity.name
 
         return all_response
-    
+
     def fetch_installed_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
         """
         Fetch datasource providers for the given tenant.
diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py
index 7f44c6a211..e382ff6b54 100644
--- a/api/core/rag/models/document.py
+++ b/api/core/rag/models/document.py
@@ -69,7 +69,6 @@ class QAChunk(BaseModel):
     answer: str
 
 
-
 class QAStructureChunk(BaseModel):
     """
     QAStructureChunk.
diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py
index cbdc188caf..4b36b84c6a 100644
--- a/api/services/datasource_provider_service.py
+++ b/api/services/datasource_provider_service.py
@@ -131,7 +131,6 @@ class DatasourceProviderService:
             )
 
         return copy_credentials_list
-
 
     def get_all_datasource_credentials(self, tenant_id: str) -> list[dict]:
         """
@@ -144,19 +143,21 @@ class DatasourceProviderService:
         datasources = manager.fetch_installed_datasource_providers(tenant_id)
         datasource_credentials = []
         for datasource in datasources:
-            credentials = self.get_datasource_credentials(tenant_id=tenant_id,
-                                                          provider=datasource.provider,
-                                                          plugin_id=datasource.plugin_id)
-            datasource_credentials.append({
-                "provider": datasource.provider,
-                "plugin_id": datasource.plugin_id,
-                "plugin_unique_identifier": datasource.plugin_unique_identifier,
-                "icon": datasource.declaration.identity.icon,
-                "name": datasource.declaration.identity.name,
-                "description": datasource.declaration.identity.description.model_dump(),
-                "author": datasource.declaration.identity.author,
-                "credentials": credentials,
-            })
+            credentials = self.get_datasource_credentials(
+                tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
+            )
+            datasource_credentials.append(
+                {
+                    "provider": datasource.provider,
+                    "plugin_id": datasource.plugin_id,
+                    "plugin_unique_identifier": datasource.plugin_unique_identifier,
+                    "icon": datasource.declaration.identity.icon,
+                    "name": datasource.declaration.identity.name,
+                    "description": datasource.declaration.identity.description.model_dump(),
+                    "author": datasource.declaration.identity.author,
+                    "credentials": credentials,
+                }
+            )
         return datasource_credentials
 
     def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]: