jyong 2025-06-03 19:02:57 +08:00
parent 309fffd1e4
commit 9cdd2cbb27
35 changed files with 229 additions and 300 deletions

View File

@@ -1,4 +1,3 @@
-import os
 import sys
@@ -18,7 +17,7 @@ else:
 # so we need to disable gevent in debug mode.
 # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
 # if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
-#from gevent import monkey
+# from gevent import monkey
 #
 # # gevent
 # monkey.patch_all()

View File

@@ -109,8 +109,6 @@ class OAuthDataSourceSync(Resource):
 return {"result": "success"}, 200
 api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
 api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
 api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")

View File

@@ -4,24 +4,19 @@ from typing import Optional
 import requests
 from flask import current_app, redirect, request
-from flask_login import current_user
 from flask_restful import Resource
 from sqlalchemy import select
 from sqlalchemy.orm import Session
-from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
+from werkzeug.exceptions import Unauthorized
 from configs import dify_config
 from constants.languages import languages
-from controllers.console.wraps import account_initialization_required, setup_required
-from core.plugin.impl.oauth import OAuthHandler
 from events.tenant_event import tenant_was_created
 from extensions.ext_database import db
 from libs.helper import extract_remote_ip
-from libs.login import login_required
 from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
 from models import Account
 from models.account import AccountStatus
-from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
 from services.account_service import AccountService, RegisterService, TenantService
 from services.errors.account import AccountNotFoundError, AccountRegisterError
 from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
@@ -186,6 +181,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
 return account
 api.add_resource(OAuthLogin, "/oauth/login/<provider>")
 api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")

View File

@@ -1,12 +1,9 @@
 from flask import redirect, request
 from flask_login import current_user # type: ignore
 from flask_restful import ( # type: ignore
 Resource, # type: ignore
-marshal_with,
 reqparse,
 )
-from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, NotFound
 from configs import dify_config
@@ -16,7 +13,6 @@ from controllers.console.wraps import (
 setup_required,
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.plugin.impl.datasource import PluginDatasourceManager
 from core.plugin.impl.oauth import OAuthHandler
 from extensions.ext_database import db
 from libs.login import login_required
@@ -33,10 +29,9 @@ class DatasourcePluginOauthApi(Resource):
 if not current_user.is_editor:
 raise Forbidden()
 # get all plugin oauth configs
-plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by(
-provider=provider,
-plugin_id=plugin_id
-).first()
+plugin_oauth_config = (
+db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
+)
 if not plugin_oauth_config:
 raise NotFound()
 oauth_handler = OAuthHandler()
@@ -45,24 +40,20 @@ class DatasourcePluginOauthApi(Resource):
 if system_credentials:
 system_credentials["redirect_url"] = redirect_url
 response = oauth_handler.get_authorization_url(
-current_user.current_tenant.id,
-current_user.id,
-plugin_id,
-provider,
-system_credentials=system_credentials
+current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials
 )
 return response.model_dump()
 class DatasourceOauthCallback(Resource):
 @setup_required
 @login_required
 @account_initialization_required
 def get(self, provider, plugin_id):
 oauth_handler = OAuthHandler()
-plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by(
-provider=provider,
-plugin_id=plugin_id
-).first()
+plugin_oauth_config = (
+db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
+)
 if not plugin_oauth_config:
 raise NotFound()
 credentials = oauth_handler.get_credentials(
@@ -71,18 +62,16 @@ class DatasourceOauthCallback(Resource):
 plugin_id,
 provider,
 system_credentials=plugin_oauth_config.system_credentials,
-request=request
+request=request,
 )
 datasource_provider = DatasourceProvider(
-plugin_id=plugin_id,
-provider=provider,
-auth_type="oauth",
-encrypted_credentials=credentials
+plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials
 )
 db.session.add(datasource_provider)
 db.session.commit()
 return redirect(f"{dify_config.CONSOLE_WEB_URL}")
 class DatasourceAuth(Resource):
 @setup_required
 @login_required
@@ -102,7 +91,7 @@ class DatasourceAuth(Resource):
 tenant_id=current_user.current_tenant_id,
 provider=provider,
 plugin_id=plugin_id,
-credentials=args["credentials"]
+credentials=args["credentials"],
 )
 except CredentialsValidateFailedError as ex:
 raise ValueError(str(ex))
@@ -115,12 +104,11 @@ class DatasourceAuth(Resource):
 def get(self, provider, plugin_id):
 datasource_provider_service = DatasourceProviderService()
 datasources = datasource_provider_service.get_datasource_credentials(
-tenant_id=current_user.current_tenant_id,
-provider=provider,
-plugin_id=plugin_id
+tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
 )
 return {"result": datasources}, 200
 class DatasourceAuthDeleteApi(Resource):
 @setup_required
 @login_required
@@ -130,12 +118,11 @@ class DatasourceAuthDeleteApi(Resource):
 raise Forbidden()
 datasource_provider_service = DatasourceProviderService()
 datasource_provider_service.remove_datasource_credentials(
-tenant_id=current_user.current_tenant_id,
-provider=provider,
-plugin_id=plugin_id
+tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
 )
 return {"result": "success"}, 200
 # Import Rag Pipeline
 api.add_resource(
 DatasourcePluginOauthApi,
@@ -149,4 +136,3 @@ api.add_resource(
 DatasourceAuth,
 "/auth/datasource/provider/<string:provider>/plugin/<string:plugin_id>",
 )

View File

@@ -110,6 +110,7 @@ class CustomizedPipelineTemplateApi(Resource):
 dsl = yaml.safe_load(template.yaml_content)
 return {"data": dsl}, 200
 class CustomizedPipelineTemplateApi(Resource):
 @setup_required
 @login_required
@@ -142,6 +143,7 @@ class CustomizedPipelineTemplateApi(Resource):
 RagPipelineService.publish_customized_pipeline_template(pipeline_id, args)
 return 200
 api.add_resource(
 PipelineTemplateListApi,
 "/rag/pipeline/templates",

View File

@@ -540,7 +540,6 @@ class RagPipelineConfigApi(Resource):
 @login_required
 @account_initialization_required
 def get(self, pipeline_id):
 return {
 "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
 }

View File

@@ -32,7 +32,6 @@ from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchem
 from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
 from extensions.ext_database import db
-from fields.document_fields import dataset_and_document_fields
 from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
 from models.dataset import Document, Pipeline
 from models.enums import WorkflowRunTriggeredFrom
@@ -55,8 +54,7 @@ class PipelineGenerator(BaseAppGenerator):
 streaming: Literal[True],
 call_depth: int,
 workflow_thread_pool_id: Optional[str],
-) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None:
-...
+) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...
 @overload
 def generate(
@@ -70,8 +68,7 @@ class PipelineGenerator(BaseAppGenerator):
 streaming: Literal[False],
 call_depth: int,
 workflow_thread_pool_id: Optional[str],
-) -> Mapping[str, Any]:
-...
+) -> Mapping[str, Any]: ...
 @overload
 def generate(
@@ -85,8 +82,7 @@ class PipelineGenerator(BaseAppGenerator):
 streaming: bool,
 call_depth: int,
 workflow_thread_pool_id: Optional[str],
-) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
-...
+) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
 def generate(
 self,
@@ -233,7 +229,8 @@ class PipelineGenerator(BaseAppGenerator):
 description=dataset.description,
 chunk_structure=dataset.chunk_structure,
 ).model_dump(),
-"documents": [PipelineDocument(
+"documents": [
+PipelineDocument(
 id=document.id,
 position=document.position,
 data_source_type=document.data_source_type,
@@ -242,8 +239,9 @@ class PipelineGenerator(BaseAppGenerator):
 indexing_status=document.indexing_status,
 error=document.error,
 enabled=document.enabled,
-).model_dump() for document in documents
-]
+).model_dump()
+for document in documents
+],
 }
 def _generate(
@@ -316,9 +314,7 @@ class PipelineGenerator(BaseAppGenerator):
 )
 # new thread
-worker_thread = threading.Thread(
-target=worker_with_context
-)
+worker_thread = threading.Thread(target=worker_with_context)
 worker_thread.start()

View File

@@ -111,7 +111,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
 if workflow.rag_pipeline_variables:
 for v in workflow.rag_pipeline_variables:
 rag_pipeline_variable = RAGPipelineVariable(**v)
-if rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id and rag_pipeline_variable.variable in inputs:
+if (
+rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id
+and rag_pipeline_variable.variable in inputs
+):
 rag_pipeline_variables[rag_pipeline_variable.variable] = inputs[rag_pipeline_variable.variable]
 variable_pool = VariablePool(
@@ -195,7 +198,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
 continue
 real_run_nodes.append(node)
 for edge in edges:
-if edge.get("source") in exclude_node_ids :
+if edge.get("source") in exclude_node_ids:
 continue
 real_edges.append(edge)
 graph_config = dict(graph_config)

View File

@@ -232,6 +232,7 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
 """
 RAG Pipeline Application Generate Entity.
 """
 # pipeline config
 pipeline_config: WorkflowUIBasedAppConfig
 datasource_type: str

View File

@@ -5,7 +5,6 @@ from pydantic import Field
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
-from core.tools.entities.tool_entities import ToolInvokeFrom
 class DatasourceRuntime(BaseModel):

View File

@@ -46,7 +46,7 @@ class DatasourceManager:
 if not provider_entity:
 raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")
-match (datasource_type):
+match datasource_type:
 case DatasourceProviderType.ONLINE_DOCUMENT:
 controller = OnlineDocumentDatasourcePluginProviderController(
 entity=provider_entity.declaration,
@@ -98,5 +98,3 @@ class DatasourceManager:
 tenant_id,
 datasource_type,
 ).get_datasource(datasource_name)

View File

@@ -215,7 +215,6 @@ class PluginDatasourceManager(BasePluginClient):
 "X-Plugin-ID": datasource_provider_id.plugin_id,
 "Content-Type": "application/json",
 },
 )
 for resp in response:
@@ -233,41 +232,23 @@ class PluginDatasourceManager(BasePluginClient):
 "identity": {
 "author": "langgenius",
 "name": "langgenius/file/file",
-"label": {
-"zh_Hans": "File",
-"en_US": "File",
-"pt_BR": "File",
-"ja_JP": "File"
-},
+"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
 "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg",
-"description": {
-"zh_Hans": "File",
-"en_US": "File",
-"pt_BR": "File",
-"ja_JP": "File"
-}
+"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
 },
 "credentials_schema": [],
 "provider_type": "local_file",
-"datasources": [{
+"datasources": [
+{
 "identity": {
 "author": "langgenius",
 "name": "upload-file",
 "provider": "langgenius",
-"label": {
-"zh_Hans": "File",
-"en_US": "File",
-"pt_BR": "File",
-"ja_JP": "File"
-}
+"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
 },
 "parameters": [],
-"description": {
-"zh_Hans": "File",
-"en_US": "File",
-"pt_BR": "File",
-"ja_JP": "File"
-}
-}]
+"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
 }
+],
+},
 }

View File

@@ -28,12 +28,12 @@ class Jieba(BaseKeyword):
 with redis_client.lock(lock_name, timeout=600):
 keyword_table_handler = JiebaKeywordTableHandler()
 keyword_table = self._get_dataset_keyword_table()
-keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
+keyword_number = (
+self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
+)
 for text in texts:
-keywords = keyword_table_handler.extract_keywords(
-text.page_content, keyword_number
-)
+keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
 if text.metadata is not None:
 self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
 keyword_table = self._add_text_to_keyword_table(
@@ -51,19 +51,17 @@ class Jieba(BaseKeyword):
 keyword_table = self._get_dataset_keyword_table()
 keywords_list = kwargs.get("keywords_list")
-keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
+keyword_number = (
+self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
+)
 for i in range(len(texts)):
 text = texts[i]
 if keywords_list:
 keywords = keywords_list[i]
 if not keywords:
-keywords = keyword_table_handler.extract_keywords(
-text.page_content, keyword_number
-)
+keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
 else:
-keywords = keyword_table_handler.extract_keywords(
-text.page_content, keyword_number
-)
+keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
 if text.metadata is not None:
 self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
 keyword_table = self._add_text_to_keyword_table(
@@ -242,7 +240,9 @@ class Jieba(BaseKeyword):
 keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
 )
 else:
-keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
+keyword_number = (
+self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
+)
 keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
 segment.keywords = list(keywords)

View File

@@ -15,7 +15,8 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import Document, GeneralStructureChunk
 from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
-from models.dataset import Dataset, Document as DatasetDocument, DatasetProcessRule
+from models.dataset import Dataset, DatasetProcessRule
+from models.dataset import Document as DatasetDocument
 from services.entities.knowledge_entities.knowledge_entities import Rule
@@ -152,13 +153,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
 keyword = Keyword(dataset)
 keyword.add_texts(documents)
 def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
 paragraph = GeneralStructureChunk(**chunks)
 preview = []
 for content in paragraph.general_chunks:
 preview.append({"content": content})
-return {
-"preview": preview,
-"total_segments": len(paragraph.general_chunks)
-}
+return {"preview": preview, "total_segments": len(paragraph.general_chunks)}

View File

@@ -16,7 +16,8 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
 from extensions.ext_database import db
 from libs import helper
-from models.dataset import ChildChunk, Dataset, Document as DatasetDocument, DocumentSegment
+from models.dataset import ChildChunk, Dataset, DocumentSegment
+from models.dataset import Document as DatasetDocument
 from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
@@ -239,14 +240,5 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
 parent_childs = ParentChildStructureChunk(**chunks)
 preview = []
 for parent_child in parent_childs.parent_child_chunks:
-preview.append(
-{
-"content": parent_child.parent_content,
-"child_chunks": parent_child.child_contents
-}
-)
-return {
-"preview": preview,
-"total_segments": len(parent_childs.parent_child_chunks)
-}
+preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
+return {"preview": preview, "total_segments": len(parent_childs.parent_child_chunks)}

View File

@@ -4,7 +4,8 @@ import logging
 import re
 import threading
 import uuid
-from typing import Any, Mapping, Optional
+from collections.abc import Mapping
+from typing import Any, Optional
 import pandas as pd
 from flask import Flask, current_app
@@ -20,7 +21,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import Document
 from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
-from models.dataset import Dataset, Document as DatasetDocument
+from models.dataset import Dataset
 from services.entities.knowledge_entities.knowledge_entities import Rule

View File

@@ -94,19 +94,26 @@ class FileVariable(FileSegment, Variable):
 class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
 pass
 class RAGPipelineVariable(BaseModel):
 belong_to_node_id: str = Field(description="belong to which node id, shared means public")
 type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
 label: str = Field(description="label")
 description: str | None = Field(description="description", default="")
 variable: str = Field(description="variable key", default="")
-max_length: int | None = Field(description="max length, applicable to text-input, paragraph, and file-list", default=0)
+max_length: int | None = Field(
+description="max length, applicable to text-input, paragraph, and file-list", default=0
+)
 default_value: str | None = Field(description="default value", default="")
 placeholder: str | None = Field(description="placeholder", default="")
 unit: str | None = Field(description="unit, applicable to Number", default="")
 tooltips: str | None = Field(description="helpful text", default="")
-allowed_file_types: list[str] | None = Field(description="image, document, audio, video, custom.", default_factory=list)
+allowed_file_types: list[str] | None = Field(
+description="image, document, audio, video, custom.", default_factory=list
+)
 allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list)
-allowed_file_upload_methods: list[str] | None = Field(description="remote_url, local_file, tool_file.", default_factory=list)
+allowed_file_upload_methods: list[str] | None = Field(
+description="remote_url, local_file, tool_file.", default_factory=list
+)
 required: bool = Field(description="optional, default false", default=False)
 options: list[str] | None = Field(default_factory=list)

View File

@@ -28,6 +28,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
 AGENT_LOG = "agent_log"
 ITERATION_ID = "iteration_id"
 ITERATION_INDEX = "iteration_index"
+DATASOURCE_INFO = "datasource_info"
 LOOP_ID = "loop_id"
 LOOP_INDEX = "loop_index"
 PARALLEL_ID = "parallel_id"

View File

@@ -122,7 +122,6 @@ class Graph(BaseModel):
 root_node_configs = []
 all_node_id_config_mapping: dict[str, dict] = {}
 for node_config in node_configs:
 node_id = node_config.get("id")
 if not node_id:

View File

@@ -317,10 +317,10 @@ class GraphEngine:
 raise e
 # It may not be necessary, but it is necessary. :)
-if (
-self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower()
-in [NodeType.END.value, NodeType.KNOWLEDGE_INDEX.value]
-):
+if self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() in [
+NodeType.END.value,
+NodeType.KNOWLEDGE_INDEX.value,
+]:
 break
 previous_route_node_state = route_node_state

View File

@@ -11,18 +11,19 @@ from core.datasource.online_document.online_document_plugin import OnlineDocumen
 from core.file import File
 from core.file.enums import FileTransferMethod, FileType
 from core.plugin.impl.exc import PluginDaemonClientSideError
-from core.variables.segments import ArrayAnySegment, FileSegment
+from core.variables.segments import ArrayAnySegment
 from core.variables.variables import ArrayAnyVariable
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
+from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool, VariableValue
+from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from core.workflow.enums import SystemVariableKey
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.enums import NodeType
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from models.model import UploadFile
-from models.workflow import WorkflowNodeExecutionStatus
+from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
 from .entities import DatasourceNodeData
 from .exc import DatasourceNodeError, DatasourceParameterError
@@ -54,7 +55,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
 try:
 from core.datasource.datasource_manager import DatasourceManager
 if datasource_type is None:
 raise DatasourceNodeError("Datasource type is not set")
@@ -68,12 +68,11 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
 return NodeRunResult(
 status=WorkflowNodeExecutionStatus.FAILED,
 inputs={},
-metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
 error=f"Failed to get datasource runtime: {str(e)}",
 error_type=type(e).__name__,
 )
 # get parameters
 datasource_parameters = datasource_runtime.entity.parameters
 parameters = self._generate_parameters(
@@ -102,7 +101,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
 return NodeRunResult(
 status=WorkflowNodeExecutionStatus.SUCCEEDED,
 inputs=parameters_for_log,
-metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
 outputs={
 "online_document": online_document_result.result.model_dump(),
 "datasource_type": datasource_type,
@@ -112,7 +111,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
 return NodeRunResult(
 status=WorkflowNodeExecutionStatus.SUCCEEDED,
 inputs=parameters_for_log,
-metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
 outputs={
 "website": datasource_info,
 "datasource_type": datasource_type,
@@ -121,9 +120,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
 case DatasourceProviderType.LOCAL_FILE:
 related_id = datasource_info.get("related_id")
 if not related_id:
-raise DatasourceNodeError(
-"File is not exist"
-)
+raise DatasourceNodeError("File is not exist")
 upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
 if not upload_file:
 raise ValueError("Invalid upload file Info")
@@ -146,26 +143,27 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
 # construct new key list
 new_key_list = ["file", key]
 self._append_variables_recursively(
-variable_pool=variable_pool, node_id=self.node_id, variable_key_list=new_key_list, variable_value=value
+variable_pool=variable_pool,
+node_id=self.node_id,
+variable_key_list=new_key_list,
+variable_value=value,
 )
 return NodeRunResult(
 status=WorkflowNodeExecutionStatus.SUCCEEDED,
 inputs=parameters_for_log,
-metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
 outputs={
 "file_info": datasource_info,
 "datasource_type": datasource_type,
 },
 )
 case _:
-raise DatasourceNodeError(
-f"Unsupported datasource provider: {datasource_type}"
-)
+raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
 except PluginDaemonClientSideError as e:
 return NodeRunResult(
 status=WorkflowNodeExecutionStatus.FAILED,
 inputs=parameters_for_log,
-metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
 error=f"Failed to transform datasource message: {str(e)}",
 error_type=type(e).__name__,
 )
@@ -173,7 +171,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
 return NodeRunResult(
 status=WorkflowNodeExecutionStatus.FAILED,
 inputs=parameters_for_log,
-metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
 error=f"Failed to invoke datasource: {str(e)}",
 error_type=type(e).__name__,
 )
@@ -227,8 +225,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
 assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
 return list(variable.value) if variable else []
-def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
+def _append_variables_recursively(
+self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue
+):
 """
 Append variables recursively
 :param node_id: node id

View File

@@ -6,7 +6,6 @@ from typing import Any, cast
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
-from core.variables.segments import ObjectSegment
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
@@ -72,8 +71,9 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
 process_data=None,
 outputs=outputs,
 )
-results = self._invoke_knowledge_index(dataset=dataset, node_data=node_data, chunks=chunks,
-variable_pool=variable_pool)
+results = self._invoke_knowledge_index(
+dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
+)
 return NodeRunResult(
 status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results
 )
@@ -96,8 +96,11 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
 )
 def _invoke_knowledge_index(
-self, dataset: Dataset, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any],
-variable_pool: VariablePool
+self,
+dataset: Dataset,
+node_data: KnowledgeIndexNodeData,
+chunks: Mapping[str, Any],
+variable_pool: VariablePool,
 ) -> Any:
 document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
 if not document_id:
@@ -116,7 +119,7 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
 document.indexing_status = "completed"
 document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
 db.session.add(document)
-#update document segment status
+# update document segment status
 db.session.query(DocumentSegment).filter(
 DocumentSegment.document_id == document.id,
 DocumentSegment.dataset_id == dataset.id,

View File

@@ -208,6 +208,7 @@ class Dataset(Base):
 "external_knowledge_api_name": external_knowledge_api.name,
 "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
 }
 @property
 def is_published(self):
 if self.pipeline_id:
@@ -1177,7 +1178,6 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
 updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
 __tablename__ = "pipeline_customized_templates"
 __table_args__ = (

View File

@@ -1,4 +1,3 @@
 from datetime import datetime
 from sqlalchemy.dialects.postgresql import JSONB
@@ -21,6 +20,7 @@ class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
 provider: Mapped[str] = db.Column(db.String(255), nullable=False)
 system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
 class DatasourceProvider(Base):
 __tablename__ = "datasource_providers"
 __table_args__ = (

View File

@@ -1,4 +1,3 @@
-from calendar import day_abbr
 import copy
 import datetime
 import json
@@ -7,7 +6,7 @@ import random
 import time
 import uuid
 from collections import Counter
-from typing import Any, Optional, cast
+from typing import Any, Optional
 from flask_login import current_user
 from sqlalchemy import func, select
@@ -282,7 +281,6 @@ class DatasetService:
 db.session.commit()
 return dataset
 @staticmethod
 def get_dataset(dataset_id) -> Optional[Dataset]:
 dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
@@ -494,10 +492,9 @@ class DatasetService:
 return dataset
 @staticmethod
-def update_rag_pipeline_dataset_settings(session: Session,
-dataset: Dataset,
-knowledge_configuration: KnowledgeConfiguration,
-has_published: bool = False):
+def update_rag_pipeline_dataset_settings(
+session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
+):
 dataset = session.merge(dataset)
 if not has_published:
 dataset.chunk_structure = knowledge_configuration.chunk_structure
@@ -616,7 +613,6 @@ class DatasetService:
 if action:
 deal_dataset_index_update_task.delay(dataset.id, action)
 @staticmethod
 def delete_dataset(dataset_id, user):
 dataset = DatasetService.get_dataset(dataset_id)

View File

@@ -1,5 +1,4 @@
 import logging
-from typing import Optional
 from flask_login import current_user
@@ -22,11 +21,9 @@ class DatasourceProviderService:
 def __init__(self) -> None:
 self.provider_manager = PluginDatasourceManager()
-def datasource_provider_credentials_validate(self,
-tenant_id: str,
-provider: str,
-plugin_id: str,
-credentials: dict) -> None:
+def datasource_provider_credentials_validate(
+self, tenant_id: str, provider: str, plugin_id: str, credentials: dict
+) -> None:
 """
 validate datasource provider credentials.
@@ -34,29 +31,30 @@ class DatasourceProviderService:
 :param provider:
 :param credentials:
 """
-credential_valid = self.provider_manager.validate_provider_credentials(tenant_id=tenant_id,
-user_id=current_user.id,
-provider=provider,
-credentials=credentials)
+credential_valid = self.provider_manager.validate_provider_credentials(
+tenant_id=tenant_id, user_id=current_user.id, provider=provider, credentials=credentials
+)
 if credential_valid:
 # Get all provider configurations of the current workspace
-datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
-provider=provider,
-plugin_id=plugin_id).first()
-provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id,
-provider=provider
-)
+datasource_provider = (
+db.session.query(DatasourceProvider)
+.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
+.first()
+)
+provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
 if not datasource_provider:
 for key, value in credentials.items():
 if key in provider_credential_secret_variables:
 # if send [__HIDDEN__] in secret input, it will be same as original value
 credentials[key] = encrypter.encrypt_token(tenant_id, value)
-datasource_provider = DatasourceProvider(tenant_id=tenant_id,
+datasource_provider = DatasourceProvider(
+tenant_id=tenant_id,
 provider=provider,
 plugin_id=plugin_id,
 auth_type="api_key",
-encrypted_credentials=credentials)
+encrypted_credentials=credentials,
+)
 db.session.add(datasource_provider)
 db.session.commit()
 else:
@@ -101,11 +99,15 @@ class DatasourceProviderService:
 :return:
 """
 # Get all provider configurations of the current workspace
-datasource_providers: list[DatasourceProvider] = db.session.query(DatasourceProvider).filter(
+datasource_providers: list[DatasourceProvider] = (
+db.session.query(DatasourceProvider)
+.filter(
 DatasourceProvider.tenant_id == tenant_id,
 DatasourceProvider.provider == provider,
-DatasourceProvider.plugin_id == plugin_id
-).all()
+DatasourceProvider.plugin_id == plugin_id,
+)
+.all()
+)
 if not datasource_providers:
 return []
 copy_credentials_list = []
@@ -128,10 +130,7 @@ class DatasourceProviderService:
 return copy_credentials_list
-def remove_datasource_credentials(self,
-tenant_id: str,
-provider: str,
-plugin_id: str) -> None:
+def remove_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> None:
 """
 remove datasource credentials.
@@ -140,9 +139,11 @@ class DatasourceProviderService:
 :param plugin_id: plugin id
 :return:
 """
-datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
-provider=provider,
-plugin_id=plugin_id).first()
+datasource_provider = (
+db.session.query(DatasourceProvider)
+.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
+.first()
+)
 if datasource_provider:
 db.session.delete(datasource_provider)
 db.session.commit()

View File

@@ -107,6 +107,7 @@ class KnowledgeConfiguration(BaseModel):
 """
 Knowledge Base Configuration.
 """
 chunk_structure: str
 indexing_technique: Literal["high_quality", "economy"]
 embedding_model_provider: Optional[str] = ""

View File

@@ -3,7 +3,6 @@ from typing import Any, Union
 from configs import dify_config
 from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
-from core.app.apps.workflow.app_generator import WorkflowAppGenerator
 from core.app.entities.app_invoke_entities import InvokeFrom
 from models.dataset import Pipeline
 from models.model import Account, App, EndUser

View File

@@ -1,13 +1,12 @@
 from typing import Optional
-from flask_login import current_user
 import yaml
+from flask_login import current_user
 from extensions.ext_database import db
 from models.dataset import PipelineCustomizedTemplate
 from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
 from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
-from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
 class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
@@ -43,7 +42,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
 )
 recommended_pipelines_results = []
 for pipeline_customized_template in pipeline_customized_templates:
 recommended_pipeline_result = {
 "id": pipeline_customized_template.id,
 "name": pipeline_customized_template.name,
@@ -56,7 +54,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
 return {"pipeline_templates": recommended_pipelines_results}
 @classmethod
 def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
 """

View File

@@ -38,7 +38,6 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
 recommended_pipelines_results = []
 for pipeline_built_in_template in pipeline_built_in_templates:
 recommended_pipeline_result = {
 "id": pipeline_built_in_template.id,
 "name": pipeline_built_in_template.name,

View File

@@ -35,7 +35,7 @@ from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models.account import Account
-from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore
+from models.dataset import Pipeline, PipelineCustomizedTemplate # type: ignore
 from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
 from models.model import EndUser
 from models.workflow import (
@@ -57,9 +57,7 @@ from services.rag_pipeline.pipeline_template.pipeline_template_factory import Pi
 class RagPipelineService:
 @classmethod
-def get_pipeline_templates(
-cls, type: str = "built-in", language: str = "en-US"
-) -> dict:
+def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
 if type == "built-in":
 mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
 retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
@@ -308,7 +306,7 @@ class RagPipelineService:
 session=session,
 dataset=dataset,
 knowledge_configuration=knowledge_configuration,
-has_published=pipeline.is_published
+has_published=pipeline.is_published,
 )
 # return new workflow
 return workflow
@@ -444,13 +442,11 @@ class RagPipelineService:
 )
 if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
 datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
-online_document_result: GetOnlineDocumentPagesResponse = (
-datasource_runtime._get_online_document_pages(
+online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
 user_id=account.id,
 datasource_parameters=user_inputs,
 provider_type=datasource_runtime.datasource_provider_type(),
 )
-)
 return {
 "result": [page.model_dump() for page in online_document_result.result],
 "provider_type": datasource_node_data.get("provider_type"),
@@ -470,7 +466,6 @@ class RagPipelineService:
 else:
 raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
 def run_free_workflow_node(
 self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
 ) -> WorkflowNodeExecution:
@@ -689,8 +684,8 @@ class RagPipelineService:
 WorkflowRun.app_id == pipeline.id,
 or_(
 WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
-WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value
-)
+WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
+),
 )
 if args.get("last_id"):
@@ -763,17 +758,16 @@ class RagPipelineService:
 # Use the repository to get the node execution
 repository = SQLAlchemyWorkflowNodeExecutionRepository(
-session_factory=db.engine,
-app_id=pipeline.id,
-user=user,
-triggered_from=None
+session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
 )
 # Use the repository to get the node executions with ordering
 order_config = OrderConfig(order_by=["index"], order_direction="desc")
-node_executions = repository.get_by_workflow_run(workflow_run_id=run_id,
-order_config=order_config,
-triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN)
+node_executions = repository.get_by_workflow_run(
+workflow_run_id=run_id,
+order_config=order_config,
+triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
+)
 # Convert domain models to database models
 workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]

View File

@@ -279,7 +279,11 @@ class RagPipelineDslService:
                 if node.get("data", {}).get("type") == "knowledge_index":
                     knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
                     knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
-                    if dataset and pipeline.is_published and dataset.chunk_structure != knowledge_configuration.chunk_structure:
+                    if (
+                        dataset
+                        and pipeline.is_published
+                        and dataset.chunk_structure != knowledge_configuration.chunk_structure
+                    ):
                         raise ValueError("Chunk structure is not compatible with the published pipeline")
         else:
             dataset = Dataset(
@@ -304,8 +308,7 @@ class RagPipelineDslService:
                     .filter(
                         DatasetCollectionBinding.provider_name
                         == knowledge_configuration.embedding_model_provider,
-                        DatasetCollectionBinding.model_name
-                        == knowledge_configuration.embedding_model,
+                        DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
                         DatasetCollectionBinding.type == "dataset",
                     )
                     .order_by(DatasetCollectionBinding.created_at)
@@ -323,12 +326,8 @@ class RagPipelineDslService:
                         db.session.commit()
                         dataset_collection_binding_id = dataset_collection_binding.id
                     dataset.collection_binding_id = dataset_collection_binding_id
-                    dataset.embedding_model = (
-                        knowledge_configuration.embedding_model
-                    )
-                    dataset.embedding_model_provider = (
-                        knowledge_configuration.embedding_model_provider
-                    )
+                    dataset.embedding_model = knowledge_configuration.embedding_model
+                    dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
                 elif knowledge_configuration.indexing_technique == "economy":
                     dataset.keyword_number = knowledge_configuration.keyword_number
                 dataset.pipeline_id = pipeline.id
@@ -443,8 +442,7 @@ class RagPipelineDslService:
                     .filter(
                         DatasetCollectionBinding.provider_name
                         == knowledge_configuration.embedding_model_provider,
-                        DatasetCollectionBinding.model_name
-                        == knowledge_configuration.embedding_model,
+                        DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
                         DatasetCollectionBinding.type == "dataset",
                     )
                     .order_by(DatasetCollectionBinding.created_at)
@@ -462,12 +460,8 @@ class RagPipelineDslService:
                         db.session.commit()
                         dataset_collection_binding_id = dataset_collection_binding.id
                     dataset.collection_binding_id = dataset_collection_binding_id
-                    dataset.embedding_model = (
-                        knowledge_configuration.embedding_model
-                    )
-                    dataset.embedding_model_provider = (
-                        knowledge_configuration.embedding_model_provider
-                    )
+                    dataset.embedding_model = knowledge_configuration.embedding_model
+                    dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
                 elif knowledge_configuration.indexing_technique == "economy":
                     dataset.keyword_number = knowledge_configuration.keyword_number
                 dataset.pipeline_id = pipeline.id
@@ -538,7 +532,6 @@ class RagPipelineDslService:
             icon_type = "emoji"
         icon = str(pipeline_data.get("icon", ""))
-
         # Initialize pipeline based on mode
         workflow_data = data.get("workflow")
         if not workflow_data or not isinstance(workflow_data, dict):
@@ -554,7 +547,6 @@ class RagPipelineDslService:
             ]
             rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
             graph = workflow_data.get("graph", {})
-
             for node in graph.get("nodes", []):
                 if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
@@ -576,7 +568,6 @@ class RagPipelineDslService:
             pipeline.description = pipeline_data.get("description", pipeline.description)
             pipeline.updated_by = account.id
-
         else:
             if account.current_tenant_id is None:
                 raise ValueError("Current tenant is not set")
@@ -636,7 +627,6 @@ class RagPipelineDslService:
         # commit db session changes
         db.session.commit()
         return pipeline

-
     @classmethod
@@ -874,7 +864,6 @@ class RagPipelineDslService:
         except Exception:
             return None

-
     @staticmethod
     def create_rag_pipeline_dataset(
         tenant_id: str,
@@ -886,9 +875,7 @@ class RagPipelineDslService:
            .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
            .first()
         ):
-            raise ValueError(
-                f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
-            )
+            raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")

         with Session(db.engine) as session:
             rag_pipeline_dsl_service = RagPipelineDslService(session)

View File

@@ -15,9 +15,9 @@ class RagPipelineManageService:
         datasources = manager.fetch_datasource_providers(tenant_id)
         for datasource in datasources:
             datasource_provider_service = DatasourceProviderService()
-            credentials = datasource_provider_service.get_datasource_credentials(tenant_id=tenant_id,
-                                                                                 provider=datasource.provider,
-                                                                                 plugin_id=datasource.plugin_id)
+            credentials = datasource_provider_service.get_datasource_credentials(
+                tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
+            )
             if credentials:
                 datasource.is_authorized = True
         return datasources