jyong 2025-07-09 14:27:49 +08:00
parent 4a8061d14c
commit b5e4ce6c68
9 changed files with 151 additions and 35 deletions

View File

@@ -177,7 +177,8 @@ class DatasourceAuthUpdateDeleteApi(Resource):
raise ValueError(str(ex))
return {"result": "success"}, 201
class DatasourceAuthListApi(Resource):
@setup_required
@login_required
@@ -189,6 +190,7 @@ class DatasourceAuthListApi(Resource):
)
return {"result": datasources}, 200
# Import Rag Pipeline
api.add_resource(
DatasourcePluginOauthApi,
@@ -211,4 +213,4 @@ api.add_resource(
api.add_resource(
DatasourceAuthListApi,
"/auth/plugin/datasource/list",
)
)

View File

@@ -26,7 +26,9 @@ class PipelineConfigManager(BaseAppConfigManager):
app_id=pipeline.id,
app_mode=AppMode.RAG_PIPELINE,
workflow_id=workflow.id,
rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow, start_node_id=start_node_id),
rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(
workflow=workflow, start_node_id=start_node_id
),
)
return pipeline_config

View File

@@ -7,7 +7,7 @@ import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload
from typing import Any, Literal, Optional, Union, cast, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -24,6 +24,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.datasource.entities.datasource_entities import (
DatasourceProviderType,
OnlineDriveBrowseFilesRequest,
)
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField
@@ -39,6 +44,7 @@ from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.dataset_service import DocumentService
from services.datasource_provider_service import DatasourceProviderService
logger = logging.getLogger(__name__)
@@ -105,13 +111,13 @@ class PipelineGenerator(BaseAppGenerator):
inputs: Mapping[str, Any] = args["inputs"]
start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
)
batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline,
workflow=workflow,
start_node_id=start_node_id
pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
)
documents = []
if invoke_from == InvokeFrom.PUBLISHED:
@@ -353,9 +359,9 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("inputs is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
workflow=workflow,
start_node_id=args.get("start_node_id","shared"))
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
)
dataset = pipeline.dataset
if not dataset:
@@ -440,9 +446,9 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("Pipeline dataset is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
workflow=workflow,
start_node_id=args.get("start_node_id","shared"))
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
)
# init application generate entity
application_generate_entity = RagPipelineGenerateEntity(
@@ -633,3 +639,107 @@ class PipelineGenerator(BaseAppGenerator):
if doc_metadata:
document.doc_metadata = doc_metadata
return document
def _format_datasource_info_list(
self,
datasource_type: str,
datasource_info_list: list[Mapping[str, Any]],
pipeline: Pipeline,
workflow: Workflow,
start_node_id: str,
user: Union[Account, EndUser],
) -> list[Mapping[str, Any]]:
"""
Format datasource info list.
"""
if datasource_type == "online_drive":
all_files = []
datasource_node_data = None
datasource_nodes = workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes:
if datasource_node.get("id") == start_node_id:
datasource_node_data = datasource_node.get("data", {})
break
if not datasource_node_data:
raise ValueError("Datasource node data not found")
from core.datasource.datasource_manager import DatasourceManager
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
datasource_name=datasource_node_data.get("datasource_name"),
tenant_id=pipeline.tenant_id,
datasource_type=DatasourceProviderType(datasource_type),
)
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_real_datasource_credentials(
tenant_id=pipeline.tenant_id,
provider=datasource_node_data.get("provider_name"),
plugin_id=datasource_node_data.get("plugin_id"),
)
if credentials:
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
for datasource_info in datasource_info_list:
if datasource_info.get("key") and datasource_info.get("key", "").endswith("/"):
# get all files in the folder
self._get_files_in_folder(
datasource_runtime,
datasource_info.get("key", ""),
None,
datasource_info.get("bucket", None),
user.id,
all_files,
datasource_info,
)
return all_files
else:
return datasource_info_list
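For context, a minimal self-contained sketch (illustrative only, not part of this change) of what the folder expansion above produces when datasource_type == "online_drive": an entry whose key ends with "/" is replaced by one {"key", "bucket"} entry per file under that prefix. The hard-coded listing stands in for the files the online drive plugin would return.
# Illustrative sketch: folder entries expand into one entry per contained file.
from collections.abc import Mapping
from typing import Any

def expand_folder_entries(
    datasource_info_list: list[Mapping[str, Any]],
    listing: Mapping[str, list[str]],
) -> list[Mapping[str, Any]]:
    all_files: list[Mapping[str, Any]] = []
    for info in datasource_info_list:
        key = info.get("key", "")
        if key.endswith("/"):  # folder prefix: collect the files beneath it
            for file_key in listing.get(key, []):
                all_files.append({"key": file_key, "bucket": info.get("bucket")})
    return all_files

print(expand_folder_entries(
    [{"key": "reports/", "bucket": "demo-bucket"}],
    {"reports/": ["reports/q1.pdf", "reports/q2.pdf"]},
))
# [{'key': 'reports/q1.pdf', 'bucket': 'demo-bucket'}, {'key': 'reports/q2.pdf', 'bucket': 'demo-bucket'}]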
def _get_files_in_folder(
self,
datasource_runtime: OnlineDriveDatasourcePlugin,
prefix: str,
start_after: Optional[str],
bucket: Optional[str],
user_id: str,
all_files: list,
datasource_info: Mapping[str, Any],
):
"""
Get files in a folder.
"""
result_generator = datasource_runtime.online_drive_browse_files(
user_id=user_id,
request=OnlineDriveBrowseFilesRequest(
bucket=bucket,
prefix=prefix,
max_keys=20,
start_after=start_after,
),
provider_type=datasource_runtime.datasource_provider_type(),
)
is_truncated = False
last_file_key = None
for result in result_generator:
for files in result.result:
for file in files.files:
if file.key.endswith("/"):
self._get_files_in_folder(
datasource_runtime, file.key, None, bucket, user_id, all_files, datasource_info
)
else:
all_files.append(
{
"key": file.key,
"bucket": bucket,
}
)
last_file_key = file.key
is_truncated = files.is_truncated
if is_truncated:
self._get_files_in_folder(
datasource_runtime, prefix, last_file_key, bucket, user_id, all_files, datasource_info
)
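The browse call above is paginated (max_keys=20), and the commit follows the is_truncated flag by recursing with start_after set to the last key seen. Below is a standalone sketch of the same pattern, written iteratively for brevity, with a fake list_page() standing in for online_drive_browse_files; all names in it are illustrative.
# Illustrative sketch: keep fetching pages while the listing reports truncation.
from typing import Optional

def list_page(prefix: str, start_after: Optional[str], page_size: int = 20) -> tuple[list[str], bool]:
    # Fake listing source: 45 files under the prefix, served page_size at a time.
    keys = [f"{prefix}file_{i:02d}.txt" for i in range(45)]
    start = keys.index(start_after) + 1 if start_after else 0
    page = keys[start:start + page_size]
    return page, (start + page_size) < len(keys)

def collect_all(prefix: str) -> list[str]:
    collected: list[str] = []
    start_after: Optional[str] = None
    is_truncated = True
    while is_truncated:
        page, is_truncated = list_page(prefix, start_after)
        collected.extend(page)
        start_after = page[-1] if page else None  # resume after the last key seen
    return collected

print(len(collect_all("reports/")))  # 45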

View File

@@ -1,10 +1,10 @@
import base64
from datetime import datetime
import hashlib
import hmac
import logging
import os
import time
from datetime import datetime
from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4

View File

@@ -63,7 +63,7 @@ class DatasourceFileMessageTransformer:
mimetype = meta.get("mime_type")
if not mimetype:
mimetype = guess_type(filename)[0] or "application/octet-stream"
# if message is str, encode it to bytes
if not isinstance(message.message, DatasourceMessage.BlobMessage):

View File

@@ -72,9 +72,11 @@ def to_prompt_message_content(
def download(f: File, /):
if f.transfer_method in (FileTransferMethod.TOOL_FILE,
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.DATASOURCE_FILE):
if f.transfer_method in (
FileTransferMethod.TOOL_FILE,
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.DATASOURCE_FILE,
):
return _download_file_content(f._storage_key)
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)

View File

@@ -56,7 +56,7 @@ class PluginDatasourceManager(BasePluginClient):
tool.identity.provider = provider.declaration.identity.name
return all_response
def fetch_installed_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
"""
Fetch datasource providers for the given tenant.

View File

@@ -69,7 +69,6 @@ class QAChunk(BaseModel):
answer: str
class QAStructureChunk(BaseModel):
"""
QAStructureChunk.

View File

@@ -131,7 +131,6 @@ class DatasourceProviderService:
)
return copy_credentials_list
def get_all_datasource_credentials(self, tenant_id: str) -> list[dict]:
"""
@@ -144,19 +143,21 @@
datasources = manager.fetch_installed_datasource_providers(tenant_id)
datasource_credentials = []
for datasource in datasources:
credentials = self.get_datasource_credentials(tenant_id=tenant_id,
provider=datasource.provider,
plugin_id=datasource.plugin_id)
datasource_credentials.append({
"provider": datasource.provider,
"plugin_id": datasource.plugin_id,
"plugin_unique_identifier": datasource.plugin_unique_identifier,
"icon": datasource.declaration.identity.icon,
"name": datasource.declaration.identity.name,
"description": datasource.declaration.identity.description.model_dump(),
"author": datasource.declaration.identity.author,
"credentials": credentials,
})
credentials = self.get_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
datasource_credentials.append(
{
"provider": datasource.provider,
"plugin_id": datasource.plugin_id,
"plugin_unique_identifier": datasource.plugin_unique_identifier,
"icon": datasource.declaration.identity.icon,
"name": datasource.declaration.identity.name,
"description": datasource.declaration.identity.description.model_dump(),
"author": datasource.declaration.identity.author,
"credentials": credentials,
}
)
return datasource_credentials
def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]: