mirror of https://github.com/langgenius/dify.git
synced 2025-11-01 11:23:29 +00:00
r2
This commit is contained in:
parent 4a8061d14c
commit b5e4ce6c68
@@ -177,7 +177,8 @@ class DatasourceAuthUpdateDeleteApi(Resource):
             raise ValueError(str(ex))
 
         return {"result": "success"}, 201
 
+
 class DatasourceAuthListApi(Resource):
     @setup_required
     @login_required
@@ -189,6 +190,7 @@ class DatasourceAuthListApi(Resource):
         )
         return {"result": datasources}, 200
 
 
+# Import Rag Pipeline
 api.add_resource(
     DatasourcePluginOauthApi,
@@ -211,4 +213,4 @@ api.add_resource(
 api.add_resource(
     DatasourceAuthListApi,
     "/auth/plugin/datasource/list",
-)
+)

@@ -26,7 +26,9 @@ class PipelineConfigManager(BaseAppConfigManager):
             app_id=pipeline.id,
             app_mode=AppMode.RAG_PIPELINE,
             workflow_id=workflow.id,
-            rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow, start_node_id=start_node_id),
+            rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(
+                workflow=workflow, start_node_id=start_node_id
+            ),
         )
 
         return pipeline_config

@@ -7,7 +7,7 @@ import threading
 import time
 import uuid
 from collections.abc import Generator, Mapping
-from typing import Any, Literal, Optional, Union, overload
+from typing import Any, Literal, Optional, Union, cast, overload
 
 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -24,6 +24,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
 from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
 from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
+from core.datasource.entities.datasource_entities import (
+    DatasourceProviderType,
+    OnlineDriveBrowseFilesRequest,
+)
+from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
 from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.rag.index_processor.constant.built_in_field import BuiltInField
@@ -39,6 +44,7 @@ from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import AppMode
 from services.dataset_service import DocumentService
+from services.datasource_provider_service import DatasourceProviderService
 
 logger = logging.getLogger(__name__)
 
@@ -105,13 +111,13 @@ class PipelineGenerator(BaseAppGenerator):
         inputs: Mapping[str, Any] = args["inputs"]
         start_node_id: str = args["start_node_id"]
         datasource_type: str = args["datasource_type"]
-        datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
+        datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
+            datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
+        )
         batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
         # convert to app config
         pipeline_config = PipelineConfigManager.get_pipeline_config(
-            pipeline=pipeline,
-            workflow=workflow,
-            start_node_id=start_node_id
+            pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
         )
         documents = []
         if invoke_from == InvokeFrom.PUBLISHED:
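Note on the hunk above: for the `online_drive` datasource type, the new `_format_datasource_info_list` call expands folder entries (keys ending in `/`) into one entry per contained file before the batch is created; every other datasource type passes through unchanged. A minimal sketch of the intended shape change, with invented keys and bucket names:

```python
# Hypothetical illustration only; the keys and bucket below are invented.
# Input as received from the API for an online_drive source:
before = [{"key": "reports/2024/", "bucket": "team-drive"}]  # a folder

# Output after _format_datasource_info_list walks the folder:
after = [
    {"key": "reports/2024/q1.pdf", "bucket": "team-drive"},
    {"key": "reports/2024/q2.pdf", "bucket": "team-drive"},
]
```
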
@@ -353,9 +359,9 @@ class PipelineGenerator(BaseAppGenerator):
             raise ValueError("inputs is required")
 
         # convert to app config
-        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
-                                                                    workflow=workflow,
-                                                                    start_node_id=args.get("start_node_id","shared"))
+        pipeline_config = PipelineConfigManager.get_pipeline_config(
+            pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
+        )
 
         dataset = pipeline.dataset
         if not dataset:
@@ -440,9 +446,9 @@ class PipelineGenerator(BaseAppGenerator):
             raise ValueError("Pipeline dataset is required")
 
         # convert to app config
-        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
-                                                                    workflow=workflow,
-                                                                    start_node_id=args.get("start_node_id","shared"))
+        pipeline_config = PipelineConfigManager.get_pipeline_config(
+            pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
+        )
 
         # init application generate entity
         application_generate_entity = RagPipelineGenerateEntity(
@@ -633,3 +639,107 @@ class PipelineGenerator(BaseAppGenerator):
         if doc_metadata:
             document.doc_metadata = doc_metadata
         return document
+
+    def _format_datasource_info_list(
+        self,
+        datasource_type: str,
+        datasource_info_list: list[Mapping[str, Any]],
+        pipeline: Pipeline,
+        workflow: Workflow,
+        start_node_id: str,
+        user: Union[Account, EndUser],
+    ) -> list[Mapping[str, Any]]:
+        """
+        Format datasource info list.
+        """
+        if datasource_type == "online_drive":
+            all_files = []
+            datasource_node_data = None
+            datasource_nodes = workflow.graph_dict.get("nodes", [])
+            for datasource_node in datasource_nodes:
+                if datasource_node.get("id") == start_node_id:
+                    datasource_node_data = datasource_node.get("data", {})
+                    break
+            if not datasource_node_data:
+                raise ValueError("Datasource node data not found")
+
+            from core.datasource.datasource_manager import DatasourceManager
+
+            datasource_runtime = DatasourceManager.get_datasource_runtime(
+                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
+                datasource_name=datasource_node_data.get("datasource_name"),
+                tenant_id=pipeline.tenant_id,
+                datasource_type=DatasourceProviderType(datasource_type),
+            )
+            datasource_provider_service = DatasourceProviderService()
+            credentials = datasource_provider_service.get_real_datasource_credentials(
+                tenant_id=pipeline.tenant_id,
+                provider=datasource_node_data.get("provider_name"),
+                plugin_id=datasource_node_data.get("plugin_id"),
+            )
+            if credentials:
+                datasource_runtime.runtime.credentials = credentials[0].get("credentials")
+            datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
+
+            for datasource_info in datasource_info_list:
+                if datasource_info.get("key") and datasource_info.get("key", "").endswith("/"):
+                    # get all files in the folder
+                    self._get_files_in_folder(
+                        datasource_runtime,
+                        datasource_info.get("key", ""),
+                        None,
+                        datasource_info.get("bucket", None),
+                        user.id,
+                        all_files,
+                        datasource_info,
+                    )
+            return all_files
+        else:
+            return datasource_info_list
+
+    def _get_files_in_folder(
+        self,
+        datasource_runtime: OnlineDriveDatasourcePlugin,
+        prefix: str,
+        start_after: Optional[str],
+        bucket: Optional[str],
+        user_id: str,
+        all_files: list,
+        datasource_info: Mapping[str, Any],
+    ):
+        """
+        Get files in a folder.
+        """
+        result_generator = datasource_runtime.online_drive_browse_files(
+            user_id=user_id,
+            request=OnlineDriveBrowseFilesRequest(
+                bucket=bucket,
+                prefix=prefix,
+                max_keys=20,
+                start_after=start_after,
+            ),
+            provider_type=datasource_runtime.datasource_provider_type(),
+        )
+        is_truncated = False
+        last_file_key = None
+        for result in result_generator:
+            for files in result.result:
+                for file in files.files:
+                    if file.key.endswith("/"):
+                        self._get_files_in_folder(
+                            datasource_runtime, file.key, None, bucket, user_id, all_files, datasource_info
+                        )
+                    else:
+                        all_files.append(
+                            {
+                                "key": file.key,
+                                "bucket": bucket,
+                            }
+                        )
+                    last_file_key = file.key
+                is_truncated = files.is_truncated
+
+        if is_truncated:
+            self._get_files_in_folder(
+                datasource_runtime, prefix, last_file_key, bucket, user_id, all_files, datasource_info
+            )
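The `_get_files_in_folder` helper added above combines two traversals: recursion into sub-folders (keys ending in `/`) and continuation paging via `start_after` whenever a page reports `is_truncated`. A self-contained sketch of that pattern, with an invented in-memory `browse` standing in for `online_drive_browse_files` (all names and data below are hypothetical):

```python
from typing import Optional

# Invented flat key space; keys ending in "/" are folders.
DRIVE = sorted(["docs/", "docs/a.md", "docs/sub/", "docs/sub/b.md", "docs/z.md"])


def browse(prefix: str, start_after: Optional[str], max_keys: int = 2) -> tuple[list[str], bool]:
    """Return one page of the immediate children of prefix, plus a truncation flag."""
    children = []
    for key in DRIVE:
        if not key.startswith(prefix) or key == prefix:
            continue
        rest = key[len(prefix):]
        # keep immediate children only: plain files and first-level sub-folders
        if "/" not in rest or (rest.endswith("/") and rest.count("/") == 1):
            children.append(key)
    if start_after:
        children = [key for key in children if key > start_after]
    return children[:max_keys], len(children) > max_keys


def collect(prefix: str, start_after: Optional[str], out: list[str]) -> None:
    page, truncated = browse(prefix, start_after)
    last_key = None
    for key in page:
        if key.endswith("/"):
            collect(key, None, out)  # recurse into the sub-folder
        else:
            out.append(key)
        last_key = key
    if truncated and last_key:
        collect(prefix, last_key, out)  # fetch the next page of this folder


files: list[str] = []
collect("docs/", None, files)
print(files)  # ['docs/a.md', 'docs/sub/b.md', 'docs/z.md']
```

One detail worth flagging in review: in the committed code, `is_truncated` is overwritten on every `files` page, so only the truncation state of the last page seen drives the follow-up call.
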
@@ -1,10 +1,10 @@
 import base64
-from datetime import datetime
 import hashlib
 import hmac
 import logging
 import os
 import time
+from datetime import datetime
 from mimetypes import guess_extension, guess_type
 from typing import Optional, Union
 from uuid import uuid4
@@ -63,7 +63,7 @@ class DatasourceFileMessageTransformer:
                     mimetype = meta.get("mime_type")
                     if not mimetype:
                         mimetype = guess_type(filename)[0] or "application/octet-stream"
-
+
                     # if message is str, encode it to bytes
 
                     if not isinstance(message.message, DatasourceMessage.BlobMessage):

@@ -72,9 +72,11 @@ def to_prompt_message_content(
 
 
 def download(f: File, /):
-    if f.transfer_method in (FileTransferMethod.TOOL_FILE,
-                             FileTransferMethod.LOCAL_FILE,
-                             FileTransferMethod.DATASOURCE_FILE):
+    if f.transfer_method in (
+        FileTransferMethod.TOOL_FILE,
+        FileTransferMethod.LOCAL_FILE,
+        FileTransferMethod.DATASOURCE_FILE,
+    ):
         return _download_file_content(f._storage_key)
     elif f.transfer_method == FileTransferMethod.REMOTE_URL:
         response = ssrf_proxy.get(f.remote_url, follow_redirects=True)

@@ -56,7 +56,7 @@ class PluginDatasourceManager(BasePluginClient):
                     tool.identity.provider = provider.declaration.identity.name
 
             return all_response
-
+
     def fetch_installed_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
         """
         Fetch datasource providers for the given tenant.

@@ -69,7 +69,6 @@ class QAChunk(BaseModel):
     answer: str
 
 
-
 class QAStructureChunk(BaseModel):
     """
     QAStructureChunk.

@@ -131,7 +131,6 @@ class DatasourceProviderService:
         )
 
         return copy_credentials_list
-
 
     def get_all_datasource_credentials(self, tenant_id: str) -> list[dict]:
         """
@@ -144,19 +143,21 @@ class DatasourceProviderService:
         datasources = manager.fetch_installed_datasource_providers(tenant_id)
         datasource_credentials = []
         for datasource in datasources:
-            credentials = self.get_datasource_credentials(tenant_id=tenant_id,
-                                                          provider=datasource.provider,
-                                                          plugin_id=datasource.plugin_id)
-            datasource_credentials.append({
-                "provider": datasource.provider,
-                "plugin_id": datasource.plugin_id,
-                "plugin_unique_identifier": datasource.plugin_unique_identifier,
-                "icon": datasource.declaration.identity.icon,
-                "name": datasource.declaration.identity.name,
-                "description": datasource.declaration.identity.description.model_dump(),
-                "author": datasource.declaration.identity.author,
-                "credentials": credentials,
-            })
+            credentials = self.get_datasource_credentials(
+                tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
+            )
+            datasource_credentials.append(
+                {
+                    "provider": datasource.provider,
+                    "plugin_id": datasource.plugin_id,
+                    "plugin_unique_identifier": datasource.plugin_unique_identifier,
+                    "icon": datasource.declaration.identity.icon,
+                    "name": datasource.declaration.identity.name,
+                    "description": datasource.declaration.identity.description.model_dump(),
+                    "author": datasource.declaration.identity.author,
+                    "credentials": credentials,
+                }
+            )
         return datasource_credentials
 
     def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
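For readers of the reformatted loop above: `get_all_datasource_credentials` returns one dict per installed provider. A hypothetical element, with every value invented for illustration (only the keys mirror the diff):

```python
# Hypothetical example element; all values below are invented.
{
    "provider": "example_drive",
    "plugin_id": "some_org/example_drive_datasource",
    "plugin_unique_identifier": "some_org/example_drive_datasource:0.0.1",
    "icon": "icon.svg",
    "name": "example_drive",
    "description": {"en_US": "Example online drive datasource"},
    "author": "some_org",
    "credentials": [],  # per-tenant list from get_datasource_credentials
}
```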