This commit is contained in:
jyong 2025-07-09 14:27:49 +08:00
parent 4a8061d14c
commit b5e4ce6c68
9 changed files with 151 additions and 35 deletions

View File

@@ -178,6 +178,7 @@ class DatasourceAuthUpdateDeleteApi(Resource):
         return {"result": "success"}, 201

 class DatasourceAuthListApi(Resource):
     @setup_required
     @login_required
@@ -189,6 +190,7 @@ class DatasourceAuthListApi(Resource):
         )
         return {"result": datasources}, 200

 # Import Rag Pipeline
 api.add_resource(
     DatasourcePluginOauthApi,

View File

@@ -26,7 +26,9 @@ class PipelineConfigManager(BaseAppConfigManager):
             app_id=pipeline.id,
             app_mode=AppMode.RAG_PIPELINE,
             workflow_id=workflow.id,
-            rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow, start_node_id=start_node_id),
+            rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(
+                workflow=workflow, start_node_id=start_node_id
+            ),
         )

         return pipeline_config

View File

@@ -7,7 +7,7 @@ import threading
 import time
 import uuid
 from collections.abc import Generator, Mapping
-from typing import Any, Literal, Optional, Union, overload
+from typing import Any, Literal, Optional, Union, cast, overload

 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -24,6 +24,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
 from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
 from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
+from core.datasource.entities.datasource_entities import (
+    DatasourceProviderType,
+    OnlineDriveBrowseFilesRequest,
+)
+from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
 from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.rag.index_processor.constant.built_in_field import BuiltInField
@@ -39,6 +44,7 @@ from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import AppMode
 from services.dataset_service import DocumentService
+from services.datasource_provider_service import DatasourceProviderService


 logger = logging.getLogger(__name__)
@@ -105,13 +111,13 @@ class PipelineGenerator(BaseAppGenerator):
         inputs: Mapping[str, Any] = args["inputs"]
         start_node_id: str = args["start_node_id"]
         datasource_type: str = args["datasource_type"]
-        datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
+        datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
+            datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
+        )
         batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
         # convert to app config
         pipeline_config = PipelineConfigManager.get_pipeline_config(
-            pipeline=pipeline,
-            workflow=workflow,
-            start_node_id=start_node_id
+            pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
         )
         documents = []
         if invoke_from == InvokeFrom.PUBLISHED:
@@ -353,9 +359,9 @@ class PipelineGenerator(BaseAppGenerator):
             raise ValueError("inputs is required")

         # convert to app config
-        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
-                                                                    workflow=workflow,
-                                                                    start_node_id=args.get("start_node_id","shared"))
+        pipeline_config = PipelineConfigManager.get_pipeline_config(
+            pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
+        )

         dataset = pipeline.dataset
         if not dataset:
@@ -440,9 +446,9 @@ class PipelineGenerator(BaseAppGenerator):
             raise ValueError("Pipeline dataset is required")

         # convert to app config
-        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
-                                                                    workflow=workflow,
-                                                                    start_node_id=args.get("start_node_id","shared"))
+        pipeline_config = PipelineConfigManager.get_pipeline_config(
+            pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
+        )

         # init application generate entity
         application_generate_entity = RagPipelineGenerateEntity(
@@ -633,3 +639,107 @@ class PipelineGenerator(BaseAppGenerator):
         if doc_metadata:
             document.doc_metadata = doc_metadata
         return document
+
+    def _format_datasource_info_list(
+        self,
+        datasource_type: str,
+        datasource_info_list: list[Mapping[str, Any]],
+        pipeline: Pipeline,
+        workflow: Workflow,
+        start_node_id: str,
+        user: Union[Account, EndUser],
+    ) -> list[Mapping[str, Any]]:
+        """
+        Format datasource info list.
+        """
+        if datasource_type == "online_drive":
+            all_files = []
+            datasource_node_data = None
+            datasource_nodes = workflow.graph_dict.get("nodes", [])
+            for datasource_node in datasource_nodes:
+                if datasource_node.get("id") == start_node_id:
+                    datasource_node_data = datasource_node.get("data", {})
+                    break
+            if not datasource_node_data:
+                raise ValueError("Datasource node data not found")
+
+            from core.datasource.datasource_manager import DatasourceManager
+
+            datasource_runtime = DatasourceManager.get_datasource_runtime(
+                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
+                datasource_name=datasource_node_data.get("datasource_name"),
+                tenant_id=pipeline.tenant_id,
+                datasource_type=DatasourceProviderType(datasource_type),
+            )
+            datasource_provider_service = DatasourceProviderService()
+            credentials = datasource_provider_service.get_real_datasource_credentials(
+                tenant_id=pipeline.tenant_id,
+                provider=datasource_node_data.get("provider_name"),
+                plugin_id=datasource_node_data.get("plugin_id"),
+            )
+            if credentials:
+                datasource_runtime.runtime.credentials = credentials[0].get("credentials")
+            datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
+            for datasource_info in datasource_info_list:
+                if datasource_info.get("key") and datasource_info.get("key", "").endswith("/"):
+                    # get all files in the folder
+                    self._get_files_in_folder(
+                        datasource_runtime,
+                        datasource_info.get("key", ""),
+                        None,
+                        datasource_info.get("bucket", None),
+                        user.id,
+                        all_files,
+                        datasource_info,
+                    )
+            return all_files
+        else:
+            return datasource_info_list
+
+    def _get_files_in_folder(
+        self,
+        datasource_runtime: OnlineDriveDatasourcePlugin,
+        prefix: str,
+        start_after: Optional[str],
+        bucket: Optional[str],
+        user_id: str,
+        all_files: list,
+        datasource_info: Mapping[str, Any],
+    ):
+        """
+        Get files in a folder.
+        """
+        result_generator = datasource_runtime.online_drive_browse_files(
+            user_id=user_id,
+            request=OnlineDriveBrowseFilesRequest(
+                bucket=bucket,
+                prefix=prefix,
+                max_keys=20,
+                start_after=start_after,
+            ),
+            provider_type=datasource_runtime.datasource_provider_type(),
+        )
+        is_truncated = False
+        last_file_key = None
+        for result in result_generator:
+            for files in result.result:
+                for file in files.files:
+                    if file.key.endswith("/"):
+                        self._get_files_in_folder(
+                            datasource_runtime, file.key, None, bucket, user_id, all_files, datasource_info
+                        )
+                    else:
+                        all_files.append(
+                            {
+                                "key": file.key,
+                                "bucket": bucket,
+                            }
+                        )
+                        last_file_key = file.key
+                is_truncated = files.is_truncated
+        if is_truncated:
+            self._get_files_in_folder(
+                datasource_runtime, prefix, last_file_key, bucket, user_id, all_files, datasource_info
+            )
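
Note: the two methods added above implement a depth-first, paginated folder expansion. Keys ending in "/" are treated as folders and recursed into; when a page comes back with is_truncated set, the same prefix is requested again with start_after pointing at the last file key seen. A minimal standalone sketch of that pattern, assuming a hypothetical list_page(prefix, start_after) helper in place of online_drive_browse_files that returns one page of keys plus a truncation flag:

    def collect_files(list_page, prefix, all_files, start_after=None):
        # list_page is a stand-in for online_drive_browse_files: it returns
        # (keys, is_truncated) for a single page of at most max_keys results.
        keys, is_truncated = list_page(prefix, start_after)
        last_key = None
        for key in keys:
            if key.endswith("/"):
                # folder entry: descend into it, starting from its first page
                collect_files(list_page, key, all_files)
            else:
                all_files.append(key)
                last_key = key
        if is_truncated:
            # the listing was cut off: continue this prefix after the last file key
            collect_files(list_page, prefix, all_files, start_after=last_key)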

View File

@@ -1,10 +1,10 @@
 import base64
-from datetime import datetime
 import hashlib
 import hmac
 import logging
 import os
 import time
+from datetime import datetime
 from mimetypes import guess_extension, guess_type
 from typing import Optional, Union
 from uuid import uuid4

View File

@@ -72,9 +72,11 @@ def to_prompt_message_content(


 def download(f: File, /):
-    if f.transfer_method in (FileTransferMethod.TOOL_FILE,
-                             FileTransferMethod.LOCAL_FILE,
-                             FileTransferMethod.DATASOURCE_FILE):
+    if f.transfer_method in (
+        FileTransferMethod.TOOL_FILE,
+        FileTransferMethod.LOCAL_FILE,
+        FileTransferMethod.DATASOURCE_FILE,
+    ):
         return _download_file_content(f._storage_key)
     elif f.transfer_method == FileTransferMethod.REMOTE_URL:
         response = ssrf_proxy.get(f.remote_url, follow_redirects=True)

View File

@@ -69,7 +69,6 @@ class QAChunk(BaseModel):
     answer: str


 class QAStructureChunk(BaseModel):
     """
     QAStructureChunk.

View File

@@ -132,7 +132,6 @@ class DatasourceProviderService:
         return copy_credentials_list

     def get_all_datasource_credentials(self, tenant_id: str) -> list[dict]:
         """
         get datasource credentials.
@@ -144,10 +143,11 @@ class DatasourceProviderService:
         datasources = manager.fetch_installed_datasource_providers(tenant_id)
         datasource_credentials = []
         for datasource in datasources:
-            credentials = self.get_datasource_credentials(tenant_id=tenant_id,
-                                                          provider=datasource.provider,
-                                                          plugin_id=datasource.plugin_id)
-            datasource_credentials.append({
+            credentials = self.get_datasource_credentials(
+                tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
+            )
+            datasource_credentials.append(
+                {
                     "provider": datasource.provider,
                     "plugin_id": datasource.plugin_id,
                     "plugin_unique_identifier": datasource.plugin_unique_identifier,
@@ -156,7 +156,8 @@ class DatasourceProviderService:
                     "description": datasource.declaration.identity.description.model_dump(),
                     "author": datasource.declaration.identity.author,
                     "credentials": credentials,
-            })
+                }
+            )
         return datasource_credentials

     def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
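
Note: get_all_datasource_credentials assembles one entry per installed datasource provider. A rough usage sketch, assuming DatasourceProviderService can be constructed directly and a tenant_id is in scope; the field names match the dict built in the diff above:

    service = DatasourceProviderService()
    for entry in service.get_all_datasource_credentials(tenant_id=tenant_id):
        # each entry carries provider metadata plus the credentials list
        # fetched via get_datasource_credentials for that provider
        print(entry["provider"], entry["plugin_id"], len(entry["credentials"]))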