diff --git a/api/app.py b/api/app.py
index 14807f5031..e0a903b10d 100644
--- a/api/app.py
+++ b/api/app.py
@@ -1,4 +1,3 @@
-import os
 import sys
 
@@ -18,19 +17,19 @@ else:
     # so we need to disable gevent in debug mode.
     # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
    # if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
-    #from gevent import monkey
-    #
-    # # gevent
-    # monkey.patch_all()
-    #
-    # from grpc.experimental import gevent as grpc_gevent  # type: ignore
-    #
-    # # grpc gevent
-    # grpc_gevent.init_gevent()
+    # from gevent import monkey
+    #
+    # # gevent
+    # monkey.patch_all()
+    #
+    # from grpc.experimental import gevent as grpc_gevent  # type: ignore
+    #
+    # # grpc gevent
+    # grpc_gevent.init_gevent()
 
-    # import psycogreen.gevent  # type: ignore
-    #
-    # psycogreen.gevent.patch_psycopg()
+    # import psycogreen.gevent  # type: ignore
+    #
+    # psycogreen.gevent.patch_psycopg()
 
 from app_factory import create_app
diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py
index 5299064e17..1049f864c3 100644
--- a/api/controllers/console/auth/data_source_oauth.py
+++ b/api/controllers/console/auth/data_source_oauth.py
@@ -109,8 +109,6 @@ class OAuthDataSourceSync(Resource):
         return {"result": "success"}, 200
 
 
-
-
 api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
 api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
 api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index ed595f5d3d..395367c9e2 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -4,24 +4,19 @@ from typing import Optional
 
 import requests
 from flask import current_app, redirect, request
-from flask_login import current_user
 from flask_restful import Resource
 from sqlalchemy import select
 from sqlalchemy.orm import Session
-from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
+from werkzeug.exceptions import Unauthorized
 
 from configs import dify_config
 from constants.languages import languages
-from controllers.console.wraps import account_initialization_required, setup_required
-from core.plugin.impl.oauth import OAuthHandler
 from events.tenant_event import tenant_was_created
 from extensions.ext_database import db
 from libs.helper import extract_remote_ip
-from libs.login import login_required
 from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
 from models import Account
 from models.account import AccountStatus
-from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
 from services.account_service import AccountService, RegisterService, TenantService
 from services.errors.account import AccountNotFoundError, AccountRegisterError
 from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
@@ -186,6 +181,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
 
     return account
 
-
 api.add_resource(OAuthLogin, "/oauth/login/<provider>")
 api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")
diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
index ceb7a277e4..96cb3f5602 100644
--- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
+++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
@@ -1,12 +1,9 @@
-
 from flask import redirect, request
 from flask_login import current_user  # type: ignore
 from flask_restful import (  # type: ignore
     Resource,  # type: ignore
-    marshal_with,
     reqparse,
 )
-from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, NotFound
 
 from configs import dify_config
@@ -16,7 +13,6 @@ from controllers.console.wraps import (
     setup_required,
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.plugin.impl.datasource import PluginDatasourceManager
 from core.plugin.impl.oauth import OAuthHandler
 from extensions.ext_database import db
 from libs.login import login_required
@@ -33,10 +29,9 @@ class DatasourcePluginOauthApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()
         # get all plugin oauth configs
-        plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by(
-            provider=provider,
-            plugin_id=plugin_id
-        ).first()
+        plugin_oauth_config = (
+            db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
+        )
         if not plugin_oauth_config:
             raise NotFound()
         oauth_handler = OAuthHandler()
@@ -45,24 +40,20 @@ class DatasourcePluginOauthApi(Resource):
         if system_credentials:
             system_credentials["redirect_url"] = redirect_url
         response = oauth_handler.get_authorization_url(
-            current_user.current_tenant.id,
-            current_user.id,
-            plugin_id,
-            provider,
-            system_credentials=system_credentials
+            current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials
         )
         return response.model_dump()
 
+
 class DatasourceOauthCallback(Resource):
     @setup_required
     @login_required
     @account_initialization_required
     def get(self, provider, plugin_id):
         oauth_handler = OAuthHandler()
-        plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by(
-            provider=provider,
-            plugin_id=plugin_id
-        ).first()
+        plugin_oauth_config = (
+            db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
+        )
         if not plugin_oauth_config:
             raise NotFound()
         credentials = oauth_handler.get_credentials(
@@ -71,18 +62,16 @@ class DatasourceOauthCallback(Resource):
             plugin_id,
             provider,
             system_credentials=plugin_oauth_config.system_credentials,
-            request=request
+            request=request,
         )
         datasource_provider = DatasourceProvider(
-            plugin_id=plugin_id,
-            provider=provider,
-            auth_type="oauth",
-            encrypted_credentials=credentials
+            plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials
         )
         db.session.add(datasource_provider)
         db.session.commit()
         return redirect(f"{dify_config.CONSOLE_WEB_URL}")
 
+
 class DatasourceAuth(Resource):
     @setup_required
     @login_required
@@ -99,28 +88,27 @@ class DatasourceAuth(Resource):
 
         try:
             datasource_provider_service.datasource_provider_credentials_validate(
-                tenant_id=current_user.current_tenant_id,
-                provider=provider,
-                plugin_id=plugin_id,
-                credentials=args["credentials"]
+                tenant_id=current_user.current_tenant_id,
+                provider=provider,
+                plugin_id=plugin_id,
+                credentials=args["credentials"],
             )
         except CredentialsValidateFailedError as ex:
             raise ValueError(str(ex))
 
         return {"result": "success"}, 201
-    
+
     @setup_required
     @login_required
     @account_initialization_required
     def get(self, provider, plugin_id):
         datasource_provider_service = DatasourceProviderService()
         datasources = datasource_provider_service.get_datasource_credentials(
-            tenant_id=current_user.current_tenant_id,
-            provider=provider,
-            plugin_id=plugin_id
+            tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
         )
         return {"result": datasources}, 200
-    
+
+
 class DatasourceAuthDeleteApi(Resource):
     @setup_required
     @login_required
@@ -130,12 +118,11 @@ class DatasourceAuthDeleteApi(Resource):
             raise Forbidden()
         datasource_provider_service = DatasourceProviderService()
         datasource_provider_service.remove_datasource_credentials(
-            tenant_id=current_user.current_tenant_id,
-            provider=provider,
-            plugin_id=plugin_id
+            tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
         )
         return {"result": "success"}, 200
 
+
 # Import Rag Pipeline
 api.add_resource(
     DatasourcePluginOauthApi,
@@ -149,4 +136,3 @@ api.add_resource(
     DatasourceAuth,
     "/auth/datasource/provider/<string:provider>/plugin/<string:plugin_id>",
 )
-
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
index 471ecbf070..1b869d9847 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
@@ -110,6 +110,7 @@ class CustomizedPipelineTemplateApi(Resource):
         dsl = yaml.safe_load(template.yaml_content)
         return {"data": dsl}, 200
 
+
 class CustomizedPipelineTemplateApi(Resource):
     @setup_required
     @login_required
@@ -142,6 +143,7 @@ class CustomizedPipelineTemplateApi(Resource):
         RagPipelineService.publish_customized_pipeline_template(pipeline_id, args)
         return 200
 
+
 api.add_resource(
     PipelineTemplateListApi,
     "/rag/pipeline/templates",
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index fe91f01af6..d7ed5d475d 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -540,7 +540,6 @@ class RagPipelineConfigApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, pipeline_id):
-
         return {
             "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
         }
diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py
index 61d4b723e1..5fb5bff2a9 100644
--- a/api/core/app/apps/pipeline/pipeline_generator.py
+++ b/api/core/app/apps/pipeline/pipeline_generator.py
@@ -32,7 +32,6 @@ from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
 from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
 from extensions.ext_database import db
-from fields.document_fields import dataset_and_document_fields
 from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
 from models.dataset import Document, Pipeline
 from models.enums import WorkflowRunTriggeredFrom
@@ -55,8 +54,7 @@ class PipelineGenerator(BaseAppGenerator):
         streaming: Literal[True],
         call_depth: int,
         workflow_thread_pool_id: Optional[str],
-    ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None:
-        ...
+    ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...
 
     @overload
     def generate(
@@ -70,8 +68,7 @@ class PipelineGenerator(BaseAppGenerator):
         streaming: Literal[False],
         call_depth: int,
         workflow_thread_pool_id: Optional[str],
-    ) -> Mapping[str, Any]:
-        ...
+    ) -> Mapping[str, Any]: ...
 
     @overload
     def generate(
@@ -85,8 +82,7 @@ class PipelineGenerator(BaseAppGenerator):
         streaming: bool,
         call_depth: int,
         workflow_thread_pool_id: Optional[str],
-    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
-        ...
+    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
 
     def generate(
         self,
@@ -233,17 +229,19 @@ class PipelineGenerator(BaseAppGenerator):
                 description=dataset.description,
                 chunk_structure=dataset.chunk_structure,
             ).model_dump(),
-            "documents": [PipelineDocument(
-                id=document.id,
-                position=document.position,
-                data_source_type=document.data_source_type,
-                data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
-                name=document.name,
-                indexing_status=document.indexing_status,
-                error=document.error,
-                enabled=document.enabled,
-            ).model_dump() for document in documents
-            ]
+            "documents": [
+                PipelineDocument(
+                    id=document.id,
+                    position=document.position,
+                    data_source_type=document.data_source_type,
+                    data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
+                    name=document.name,
+                    indexing_status=document.indexing_status,
+                    error=document.error,
+                    enabled=document.enabled,
+                ).model_dump()
+                for document in documents
+            ],
         }
 
     def _generate(
@@ -316,9 +314,7 @@ class PipelineGenerator(BaseAppGenerator):
         )
 
         # new thread
-        worker_thread = threading.Thread(
-            target=worker_with_context
-        )
+        worker_thread = threading.Thread(target=worker_with_context)
 
         worker_thread.start()
diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py
index 8d90e7ee3e..4582dcbb0d 100644
--- a/api/core/app/apps/pipeline/pipeline_runner.py
+++ b/api/core/app/apps/pipeline/pipeline_runner.py
@@ -111,7 +111,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
         if workflow.rag_pipeline_variables:
             for v in workflow.rag_pipeline_variables:
                 rag_pipeline_variable = RAGPipelineVariable(**v)
-                if rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id and rag_pipeline_variable.variable in inputs:
+                if (
+                    rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id
+                    and rag_pipeline_variable.variable in inputs
+                ):
                     rag_pipeline_variables[rag_pipeline_variable.variable] = inputs[rag_pipeline_variable.variable]
 
         variable_pool = VariablePool(
@@ -195,7 +198,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
                 continue
             real_run_nodes.append(node)
         for edge in edges:
-            if edge.get("source") in exclude_node_ids :
+            if edge.get("source") in exclude_node_ids:
                 continue
             real_edges.append(edge)
         graph_config = dict(graph_config)
diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py
index f346994b30..75693be5ea 100644
--- a/api/core/app/entities/app_invoke_entities.py
+++ b/api/core/app/entities/app_invoke_entities.py
@@ -232,6 +232,7 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
     """
     RAG Pipeline Application Generate Entity.
""" + # pipeline config pipeline_config: WorkflowUIBasedAppConfig datasource_type: str diff --git a/api/core/datasource/__base/datasource_runtime.py b/api/core/datasource/__base/datasource_runtime.py index 9ddc25a637..264145d261 100644 --- a/api/core/datasource/__base/datasource_runtime.py +++ b/api/core/datasource/__base/datasource_runtime.py @@ -5,7 +5,6 @@ from pydantic import Field from core.app.entities.app_invoke_entities import InvokeFrom from core.datasource.entities.datasource_entities import DatasourceInvokeFrom -from core.tools.entities.tool_entities import ToolInvokeFrom class DatasourceRuntime(BaseModel): diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 8c74aeb320..46b36d8349 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -46,7 +46,7 @@ class DatasourceManager: if not provider_entity: raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found") - match (datasource_type): + match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: controller = OnlineDocumentDatasourcePluginProviderController( entity=provider_entity.declaration, @@ -98,5 +98,3 @@ class DatasourceManager: tenant_id, datasource_type, ).get_datasource(datasource_name) - - diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 51d5489c4c..ea357d85b2 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -215,7 +215,6 @@ class PluginDatasourceManager(BasePluginClient): "X-Plugin-ID": datasource_provider_id.plugin_id, "Content-Type": "application/json", }, - ) for resp in response: @@ -233,41 +232,23 @@ class PluginDatasourceManager(BasePluginClient): "identity": { "author": "langgenius", "name": "langgenius/file/file", - "label": { - "zh_Hans": "File", - "en_US": "File", - "pt_BR": "File", - "ja_JP": "File" - }, + "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg", - "description": { - "zh_Hans": "File", - "en_US": "File", - "pt_BR": "File", - "ja_JP": "File" - } + "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, }, "credentials_schema": [], "provider_type": "local_file", - "datasources": [{ - "identity": { - "author": "langgenius", - "name": "upload-file", - "provider": "langgenius", - "label": { - "zh_Hans": "File", - "en_US": "File", - "pt_BR": "File", - "ja_JP": "File" - } - }, - "parameters": [], - "description": { - "zh_Hans": "File", - "en_US": "File", - "pt_BR": "File", - "ja_JP": "File" + "datasources": [ + { + "identity": { + "author": "langgenius", + "name": "upload-file", + "provider": "langgenius", + "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, + }, + "parameters": [], + "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, } - }] - } + ], + }, } diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index ca54290796..be1765feee 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -28,12 +28,12 @@ class Jieba(BaseKeyword): with redis_client.lock(lock_name, timeout=600): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = 
-            keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
+            keyword_number = (
+                self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
+            )
             for text in texts:
-                keywords = keyword_table_handler.extract_keywords(
-                    text.page_content, keyword_number
-                )
+                keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
                 if text.metadata is not None:
                     self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
                 keyword_table = self._add_text_to_keyword_table(
@@ -51,19 +51,17 @@ class Jieba(BaseKeyword):
             keyword_table = self._get_dataset_keyword_table()
             keywords_list = kwargs.get("keywords_list")
-            keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
+            keyword_number = (
+                self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
+            )
             for i in range(len(texts)):
                 text = texts[i]
                 if keywords_list:
                     keywords = keywords_list[i]
                     if not keywords:
-                        keywords = keyword_table_handler.extract_keywords(
-                            text.page_content, keyword_number
-                        )
+                        keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
                 else:
-                    keywords = keyword_table_handler.extract_keywords(
-                        text.page_content, keyword_number
-                    )
+                    keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
                 if text.metadata is not None:
                     self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
                 keyword_table = self._add_text_to_keyword_table(
@@ -242,7 +240,9 @@ class Jieba(BaseKeyword):
                     keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
                 )
             else:
-                keyword_number = self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
+                keyword_number = (
+                    self.dataset.keyword_number if self.dataset.keyword_number else self._config.max_keywords_per_chunk
+                )
                 keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
                 segment.keywords = list(keywords)
diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py
index 72e4923b58..ff6f843a28 100644
--- a/api/core/rag/index_processor/index_processor_base.py
+++ b/api/core/rag/index_processor/index_processor_base.py
@@ -38,7 +38,7 @@ class BaseIndexProcessor(ABC):
     @abstractmethod
     def index(self, dataset: Dataset, document: DatasetDocument, chunks: Mapping[str, Any]):
         raise NotImplementedError
-    
+
     @abstractmethod
     def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
         raise NotImplementedError
diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py
index 559bc5d59b..eee8353214 100644
--- a/api/core/rag/index_processor/processor/paragraph_index_processor.py
+++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py
@@ -15,7 +15,8 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import Document, GeneralStructureChunk
 from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
-from models.dataset import Dataset, Document as DatasetDocument, DatasetProcessRule
+from models.dataset import Dataset, DatasetProcessRule
+from models.dataset import Document as DatasetDocument
 from services.entities.knowledge_entities.knowledge_entities import Rule
@@ -152,13 +153,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
             keyword = Keyword(dataset)
             keyword.add_texts(documents)
 
-
     def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
         paragraph = GeneralStructureChunk(**chunks)
         preview = []
         for content in paragraph.general_chunks:
             preview.append({"content": content})
-        return {
-            "preview": preview,
-            "total_segments": len(paragraph.general_chunks)
-        }
\ No newline at end of file
+        return {"preview": preview, "total_segments": len(paragraph.general_chunks)}
diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py
index 7a3f8f1c63..158fc819ee 100644
--- a/api/core/rag/index_processor/processor/parent_child_index_processor.py
+++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py
@@ -16,7 +16,8 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
 from extensions.ext_database import db
 from libs import helper
-from models.dataset import ChildChunk, Dataset, Document as DatasetDocument, DocumentSegment
+from models.dataset import ChildChunk, Dataset, DocumentSegment
+from models.dataset import Document as DatasetDocument
 from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
 
@@ -239,14 +240,5 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
         parent_childs = ParentChildStructureChunk(**chunks)
         preview = []
         for parent_child in parent_childs.parent_child_chunks:
-            preview.append(
-                {
-                    "content": parent_child.parent_content,
-                    "child_chunks": parent_child.child_contents
-
-                }
-            )
-        return {
-            "preview": preview,
-            "total_segments": len(parent_childs.parent_child_chunks)
-        }
\ No newline at end of file
+            preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
+        return {"preview": preview, "total_segments": len(parent_childs.parent_child_chunks)}
diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py
index b415596254..407f1b6f6d 100644
--- a/api/core/rag/index_processor/processor/qa_index_processor.py
+++ b/api/core/rag/index_processor/processor/qa_index_processor.py
@@ -4,7 +4,8 @@ import logging
 import re
 import threading
 import uuid
-from typing import Any, Mapping, Optional
+from collections.abc import Mapping
+from typing import Any, Optional
 
 import pandas as pd
 from flask import Flask, current_app
@@ -20,7 +21,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import Document
 from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
-from models.dataset import Dataset, Document as DatasetDocument
+from models.dataset import Dataset
 from services.entities.knowledge_entities.knowledge_entities import Rule
 
@@ -160,10 +161,10 @@ class QAIndexProcessor(BaseIndexProcessor):
                 doc = Document(page_content=result.page_content, metadata=metadata)
                 docs.append(doc)
         return docs
-    
+
     def index(self, dataset: Dataset, document: Document, chunks: Mapping[str, Any]):
         pass
-    
+
     def format_preview(self, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
         return {"preview": chunks}
diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py
index c0952383a9..1fe0e36a47 100644
--- a/api/core/variables/variables.py
+++ b/api/core/variables/variables.py
@@ -94,19 +94,26 @@ class FileVariable(FileSegment, Variable):
 class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
     pass
 
+
 class RAGPipelineVariable(BaseModel):
     belong_to_node_id: str = Field(description="belong to which node id, shared means public")
     type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
     label: str = Field(description="label")
     description: str | None = Field(description="description", default="")
     variable: str = Field(description="variable key", default="")
-    max_length: int | None = Field(description="max length, applicable to text-input, paragraph, and file-list", default=0)
+    max_length: int | None = Field(
+        description="max length, applicable to text-input, paragraph, and file-list", default=0
+    )
     default_value: str | None = Field(description="default value", default="")
     placeholder: str | None = Field(description="placeholder", default="")
     unit: str | None = Field(description="unit, applicable to Number", default="")
     tooltips: str | None = Field(description="helpful text", default="")
-    allowed_file_types: list[str] | None = Field(description="image, document, audio, video, custom.", default_factory=list)
+    allowed_file_types: list[str] | None = Field(
+        description="image, document, audio, video, custom.", default_factory=list
+    )
     allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list)
-    allowed_file_upload_methods: list[str] | None = Field(description="remote_url, local_file, tool_file.", default_factory=list)
+    allowed_file_upload_methods: list[str] | None = Field(
+        description="remote_url, local_file, tool_file.", default_factory=list
+    )
     required: bool = Field(description="optional, default false", default=False)
     options: list[str] | None = Field(default_factory=list)
diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py
index 319833145e..21ea26862a 100644
--- a/api/core/workflow/entities/variable_pool.py
+++ b/api/core/workflow/entities/variable_pool.py
@@ -49,7 +49,7 @@ class VariablePool(BaseModel):
     )
     rag_pipeline_variables: Mapping[str, Any] = Field(
         description="RAG pipeline variables.",
-        default_factory=dict, 
+        default_factory=dict,
     )
 
     def __init__(
diff --git a/api/core/workflow/entities/workflow_node_execution.py b/api/core/workflow/entities/workflow_node_execution.py
index 773f5b777b..10271f6062 100644
--- a/api/core/workflow/entities/workflow_node_execution.py
+++ b/api/core/workflow/entities/workflow_node_execution.py
@@ -28,6 +28,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
     AGENT_LOG = "agent_log"
     ITERATION_ID = "iteration_id"
     ITERATION_INDEX = "iteration_index"
+    DATASOURCE_INFO = "datasource_info"
     LOOP_ID = "loop_id"
     LOOP_INDEX = "loop_index"
     PARALLEL_ID = "parallel_id"
diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py
index 7062fc4565..16bf847189 100644
--- a/api/core/workflow/graph_engine/entities/graph.py
+++ b/api/core/workflow/graph_engine/entities/graph.py
@@ -122,7 +122,6 @@ class Graph(BaseModel):
         root_node_configs = []
         all_node_id_config_mapping: dict[str, dict] = {}
 
-
         for node_config in node_configs:
             node_id = node_config.get("id")
             if not node_id:
@@ -142,7 +141,7 @@ class Graph(BaseModel):
             (
                 node_config.get("id")
                 for node_config in root_node_configs
-                if node_config.get("data", {}).get("type", "") == NodeType.START.value
node_config.get("data", {}).get("type", "") == NodeType.START.value or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value ), None, diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index ced1acfdd2..86654e6fac 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -317,10 +317,10 @@ class GraphEngine: raise e # It may not be necessary, but it is necessary. :) - if ( - self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() - in [NodeType.END.value, NodeType.KNOWLEDGE_INDEX.value] - ): + if self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() in [ + NodeType.END.value, + NodeType.KNOWLEDGE_INDEX.value, + ]: break previous_route_node_state = route_node_state diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index b44039298c..92b2daea54 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -11,18 +11,19 @@ from core.datasource.online_document.online_document_plugin import OnlineDocumen from core.file import File from core.file.enums import FileTransferMethod, FileType from core.plugin.impl.exc import PluginDaemonClientSideError -from core.variables.segments import ArrayAnySegment, FileSegment +from core.variables.segments import ArrayAnySegment from core.variables.variables import ArrayAnyVariable -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool, VariableValue +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from models.model import UploadFile -from models.workflow import WorkflowNodeExecutionStatus +from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from .entities import DatasourceNodeData from .exc import DatasourceNodeError, DatasourceParameterError @@ -54,7 +55,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): try: from core.datasource.datasource_manager import DatasourceManager - if datasource_type is None: raise DatasourceNodeError("Datasource type is not set") @@ -66,13 +66,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): ) except DatasourceNodeError as e: return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs={}, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, - error=f"Failed to get datasource runtime: {str(e)}", - error_type=type(e).__name__, - ) - + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, + error=f"Failed to get datasource runtime: {str(e)}", + error_type=type(e).__name__, + ) # get parameters datasource_parameters = datasource_runtime.entity.parameters @@ -102,7 +101,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=parameters_for_log, - metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + 
+                        metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                         outputs={
                             "online_document": online_document_result.result.model_dump(),
                             "datasource_type": datasource_type,
@@ -112,18 +111,16 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                     return NodeRunResult(
                         status=WorkflowNodeExecutionStatus.SUCCEEDED,
                         inputs=parameters_for_log,
-                        metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                        metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                         outputs={
-                                "website": datasource_info,
-                                "datasource_type": datasource_type,
+                            "website": datasource_info,
+                            "datasource_type": datasource_type,
                         },
                     )
                 case DatasourceProviderType.LOCAL_FILE:
                     related_id = datasource_info.get("related_id")
                     if not related_id:
-                        raise DatasourceNodeError(
-                            "File is not exist"
-                        )
+                        raise DatasourceNodeError("File is not exist")
                     upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
                     if not upload_file:
                         raise ValueError("Invalid upload file Info")
@@ -146,26 +143,27 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                             # construct new key list
                             new_key_list = ["file", key]
                             self._append_variables_recursively(
-                                variable_pool=variable_pool, node_id=self.node_id, variable_key_list=new_key_list, variable_value=value
+                                variable_pool=variable_pool,
+                                node_id=self.node_id,
+                                variable_key_list=new_key_list,
+                                variable_value=value,
                             )
                     return NodeRunResult(
                         status=WorkflowNodeExecutionStatus.SUCCEEDED,
                         inputs=parameters_for_log,
-                        metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                        metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                         outputs={
-                                "file_info": datasource_info,
-                                "datasource_type": datasource_type,
-                            },
-                        )
-                case _:
-                    raise DatasourceNodeError(
-                        f"Unsupported datasource provider: {datasource_type}"
+                            "file_info": datasource_info,
+                            "datasource_type": datasource_type,
+                        },
                     )
+                case _:
+                    raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
         except PluginDaemonClientSideError as e:
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 inputs=parameters_for_log,
-                metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                 error=f"Failed to transform datasource message: {str(e)}",
                 error_type=type(e).__name__,
             )
@@ -173,7 +171,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 inputs=parameters_for_log,
-                metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                 error=f"Failed to invoke datasource: {str(e)}",
                 error_type=type(e).__name__,
             )
@@ -227,8 +225,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
         return list(variable.value) if variable else []
 
-
-    def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
+    def _append_variables_recursively(
+        self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue
+    ):
         """
         Append variables recursively
         :param node_id: node id
diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
index 41a6c6141e..a1ee3aa823 100644
--- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
+++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
@@ -6,7 +6,6 @@ from typing import Any, cast
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
-from core.variables.segments import ObjectSegment
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
@@ -72,8 +71,9 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
                 process_data=None,
                 outputs=outputs,
             )
-        results = self._invoke_knowledge_index(dataset=dataset, node_data=node_data, chunks=chunks,
-                                               variable_pool=variable_pool)
+        results = self._invoke_knowledge_index(
+            dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
+        )
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results
         )
@@ -96,8 +96,11 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
         )
 
     def _invoke_knowledge_index(
-        self, dataset: Dataset, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any],
-        variable_pool: VariablePool
+        self,
+        dataset: Dataset,
+        node_data: KnowledgeIndexNodeData,
+        chunks: Mapping[str, Any],
+        variable_pool: VariablePool,
     ) -> Any:
         document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
         if not document_id:
@@ -116,7 +119,7 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
         document.indexing_status = "completed"
         document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         db.session.add(document)
-        #update document segment status
+        # update document segment status
         db.session.query(DocumentSegment).filter(
             DocumentSegment.document_id == document.id,
             DocumentSegment.dataset_id == dataset.id,
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 86216ffe98..d2fdd5e900 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -208,6 +208,7 @@ class Dataset(Base):
             "external_knowledge_api_name": external_knowledge_api.name,
             "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
         }
+
     @property
     def is_published(self):
         if self.pipeline_id:
@@ -1177,7 +1178,6 @@ class PipelineBuiltInTemplate(Base):  # type: ignore[name-defined]
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
 
-
 class PipelineCustomizedTemplate(Base):  # type: ignore[name-defined]
     __tablename__ = "pipeline_customized_templates"
     __table_args__ = (
diff --git a/api/models/oauth.py b/api/models/oauth.py
index 9a070c2fbe..2fb34f0ac9 100644
--- a/api/models/oauth.py
+++ b/api/models/oauth.py
@@ -1,4 +1,3 @@
-
 from datetime import datetime
 
 from sqlalchemy.dialects.postgresql import JSONB
@@ -21,6 +20,7 @@ class DatasourceOauthParamConfig(Base):  # type: ignore[name-defined]
     provider: Mapped[str] = db.Column(db.String(255), nullable=False)
     system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
 
+
 class DatasourceProvider(Base):
     __tablename__ = "datasource_providers"
     __table_args__ = (
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 6d0f8ec6a9..133e3765f7 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -1,4 +1,3 @@
-from calendar import day_abbr
 import copy
 import datetime
 import json
@@ -7,7 +6,7 @@ import random
 import time
 import uuid
 from collections import Counter
-from typing import Any, Optional, cast
+from typing import Any, Optional
 
 from flask_login import current_user
 from sqlalchemy import func, select
@@ -282,7 +281,6 @@ class DatasetService:
         db.session.commit()
         return dataset
 
-
     @staticmethod
     def get_dataset(dataset_id) -> Optional[Dataset]:
         dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
@@ -494,10 +492,9 @@ class DatasetService:
         return dataset
 
     @staticmethod
-    def update_rag_pipeline_dataset_settings(session: Session,
-                                             dataset: Dataset,
-                                             knowledge_configuration: KnowledgeConfiguration,
-                                             has_published: bool = False):
+    def update_rag_pipeline_dataset_settings(
+        session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
+    ):
         dataset = session.merge(dataset)
         if not has_published:
             dataset.chunk_structure = knowledge_configuration.chunk_structure
@@ -616,7 +613,6 @@ class DatasetService:
             if action:
                 deal_dataset_index_update_task.delay(dataset.id, action)
 
-
     @staticmethod
     def delete_dataset(dataset_id, user):
         dataset = DatasetService.get_dataset(dataset_id)
diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py
index 09c4cca706..ccafc5555c 100644
--- a/api/services/datasource_provider_service.py
+++ b/api/services/datasource_provider_service.py
@@ -1,5 +1,4 @@
 import logging
-from typing import Optional
 
 from flask_login import current_user
 
@@ -22,11 +21,9 @@ class DatasourceProviderService:
     def __init__(self) -> None:
         self.provider_manager = PluginDatasourceManager()
 
-
-    def datasource_provider_credentials_validate(self,
-                                                 tenant_id: str,
-                                                 provider: str,
-                                                 plugin_id: str,
-                                                 credentials: dict) -> None:
+    def datasource_provider_credentials_validate(
+        self, tenant_id: str, provider: str, plugin_id: str, credentials: dict
+    ) -> None:
         """
         validate datasource provider credentials.
@@ -34,29 +31,30 @@ class DatasourceProviderService:
         :param provider:
         :param credentials:
         """
-        credential_valid = self.provider_manager.validate_provider_credentials(tenant_id=tenant_id,
-                                                                               user_id=current_user.id,
-                                                                               provider=provider,
-                                                                               credentials=credentials)
+        credential_valid = self.provider_manager.validate_provider_credentials(
+            tenant_id=tenant_id, user_id=current_user.id, provider=provider, credentials=credentials
+        )
         if credential_valid:
             # Get all provider configurations of the current workspace
-            datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
-                                                                                 provider=provider,
-                                                                                 plugin_id=plugin_id).first()
+            datasource_provider = (
+                db.session.query(DatasourceProvider)
+                .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
+                .first()
+            )
 
-            provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id,
-                                                                                 provider=provider
-                                                                                 )
+            provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
 
             if not datasource_provider:
                 for key, value in credentials.items():
                     if key in provider_credential_secret_variables:
                         # if send [__HIDDEN__] in secret input, it will be same as original value
                         credentials[key] = encrypter.encrypt_token(tenant_id, value)
-                datasource_provider = DatasourceProvider(tenant_id=tenant_id,
-                                                         provider=provider,
-                                                         plugin_id=plugin_id,
-                                                         auth_type="api_key",
-                                                         encrypted_credentials=credentials)
+                datasource_provider = DatasourceProvider(
+                    tenant_id=tenant_id,
+                    provider=provider,
+                    plugin_id=plugin_id,
+                    auth_type="api_key",
+                    encrypted_credentials=credentials,
+                )
                 db.session.add(datasource_provider)
                 db.session.commit()
             else:
@@ -101,11 +99,15 @@ class DatasourceProviderService:
         :return:
         """
         # Get all provider configurations of the current workspace
-        datasource_providers: list[DatasourceProvider] = db.session.query(DatasourceProvider).filter(
-            DatasourceProvider.tenant_id == tenant_id,
-            DatasourceProvider.provider == provider,
-            DatasourceProvider.plugin_id == plugin_id
-        ).all()
+        datasource_providers: list[DatasourceProvider] = (
+            db.session.query(DatasourceProvider)
+            .filter(
+                DatasourceProvider.tenant_id == tenant_id,
+                DatasourceProvider.provider == provider,
+                DatasourceProvider.plugin_id == plugin_id,
+            )
+            .all()
+        )
         if not datasource_providers:
             return []
         copy_credentials_list = []
@@ -128,10 +130,7 @@ class DatasourceProviderService:
 
         return copy_credentials_list
 
-    def remove_datasource_credentials(self,
-                                      tenant_id: str,
-                                      provider: str,
-                                      plugin_id: str) -> None:
+    def remove_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> None:
         """
         remove datasource credentials.
@@ -140,9 +139,11 @@ class DatasourceProviderService:
         :param plugin_id: plugin id
         :return:
         """
-        datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
-                                                                             provider=provider,
-                                                                             plugin_id=plugin_id).first()
+        datasource_provider = (
+            db.session.query(DatasourceProvider)
+            .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
+            .first()
+        )
         if datasource_provider:
             db.session.delete(datasource_provider)
             db.session.commit()
diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py
index 8da2e4aade..620fb2426a 100644
--- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py
+++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py
@@ -107,6 +107,7 @@ class KnowledgeConfiguration(BaseModel):
     """
     Knowledge Base Configuration.
     """
+
     chunk_structure: str
     indexing_technique: Literal["high_quality", "economy"]
     embedding_model_provider: Optional[str] = ""
diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py
index 911086066a..da67801877 100644
--- a/api/services/rag_pipeline/pipeline_generate_service.py
+++ b/api/services/rag_pipeline/pipeline_generate_service.py
@@ -3,7 +3,6 @@ from typing import Any, Union
 
 from configs import dify_config
 from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
-from core.app.apps.workflow.app_generator import WorkflowAppGenerator
 from core.app.entities.app_invoke_entities import InvokeFrom
 from models.dataset import Pipeline
 from models.model import Account, App, EndUser
diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
index b6670b70cd..3ede75309d 100644
--- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
+++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
@@ -1,13 +1,12 @@
 from typing import Optional
 
-from flask_login import current_user
 import yaml
+from flask_login import current_user
 
 from extensions.ext_database import db
 from models.dataset import PipelineCustomizedTemplate
 from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
 from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
-from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
 
 
 class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
@@ -43,7 +42,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         )
         recommended_pipelines_results = []
         for pipeline_customized_template in pipeline_customized_templates:
-
             recommended_pipeline_result = {
                 "id": pipeline_customized_template.id,
                 "name": pipeline_customized_template.name,
@@ -56,7 +54,6 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
 
         return {"pipeline_templates": recommended_pipelines_results}
 
-
     @classmethod
     def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
         """
diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py
index 8019dac0a8..741384afc2 100644
--- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py
+++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py
@@ -38,7 +38,6 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         recommended_pipelines_results = []
         for pipeline_built_in_template in pipeline_built_in_templates:
-
             recommended_pipeline_result = {
                 "id": pipeline_built_in_template.id,
                 "name": pipeline_built_in_template.name,
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 43451528db..07697c9851 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -35,7 +35,7 @@ from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models.account import Account
-from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate  # type: ignore
+from models.dataset import Pipeline, PipelineCustomizedTemplate  # type: ignore
 from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
 from models.model import EndUser
 from models.workflow import (
@@ -57,9 +57,7 @@ from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
 
 
 class RagPipelineService:
     @classmethod
-    def get_pipeline_templates(
-        cls, type: str = "built-in", language: str = "en-US"
-    ) -> dict:
+    def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
         if type == "built-in":
             mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
             retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
@@ -308,7 +306,7 @@ class RagPipelineService:
                     session=session,
                     dataset=dataset,
                     knowledge_configuration=knowledge_configuration,
-                    has_published=pipeline.is_published
+                    has_published=pipeline.is_published,
                 )
             # return new workflow
             return workflow
@@ -444,12 +442,10 @@ class RagPipelineService:
             )
             if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
                 datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
-                online_document_result: GetOnlineDocumentPagesResponse = (
-                    datasource_runtime._get_online_document_pages(
-                        user_id=account.id,
-                        datasource_parameters=user_inputs,
-                        provider_type=datasource_runtime.datasource_provider_type(),
-                    )
+                online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
+                    user_id=account.id,
+                    datasource_parameters=user_inputs,
+                    provider_type=datasource_runtime.datasource_provider_type(),
                 )
                 return {
                     "result": [page.model_dump() for page in online_document_result.result],
@@ -470,7 +466,6 @@ class RagPipelineService:
             else:
                 raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
 
-
     def run_free_workflow_node(
         self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
     ) -> WorkflowNodeExecution:
@@ -689,8 +684,8 @@ class RagPipelineService:
             WorkflowRun.app_id == pipeline.id,
             or_(
                 WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
-                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value
-            )
+                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
+            ),
         )
 
         if args.get("last_id"):
@@ -763,18 +758,17 @@ class RagPipelineService:
 
         # Use the repository to get the node execution
         repository = SQLAlchemyWorkflowNodeExecutionRepository(
-            session_factory=db.engine,
-            app_id=pipeline.id,
-            user=user,
-            triggered_from=None
+            session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
         )
 
         # Use the repository to get the node executions with ordering
         order_config = OrderConfig(order_by=["index"], order_direction="desc")
-        node_executions = repository.get_by_workflow_run(workflow_run_id=run_id,
-                                                         order_config=order_config,
-                                                         triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN)
-        # Convert domain models to database models
+        node_executions = repository.get_by_workflow_run(
+            workflow_run_id=run_id,
+            order_config=order_config,
+            triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
+        )
+        # Convert domain models to database models
         workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]
 
         return workflow_node_executions
diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
index 189ba0973f..2e1ed57908 100644
--- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
@@ -279,7 +279,11 @@ class RagPipelineDslService:
             if node.get("data", {}).get("type") == "knowledge_index":
                 knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
                 knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
-                if dataset and pipeline.is_published and dataset.chunk_structure != knowledge_configuration.chunk_structure:
+                if (
+                    dataset
+                    and pipeline.is_published
+                    and dataset.chunk_structure != knowledge_configuration.chunk_structure
+                ):
                     raise ValueError("Chunk structure is not compatible with the published pipeline")
         else:
             dataset = Dataset(
@@ -304,8 +308,7 @@ class RagPipelineDslService:
                         .filter(
                             DatasetCollectionBinding.provider_name
                             == knowledge_configuration.embedding_model_provider,
-                            DatasetCollectionBinding.model_name
-                            == knowledge_configuration.embedding_model,
+                            DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
                             DatasetCollectionBinding.type == "dataset",
                         )
                         .order_by(DatasetCollectionBinding.created_at)
@@ -323,12 +326,8 @@ class RagPipelineDslService:
                         db.session.commit()
                     dataset_collection_binding_id = dataset_collection_binding.id
                 dataset.collection_binding_id = dataset_collection_binding_id
-                dataset.embedding_model = (
-                    knowledge_configuration.embedding_model
-                )
-                dataset.embedding_model_provider = (
-                    knowledge_configuration.embedding_model_provider
-                )
+                dataset.embedding_model = knowledge_configuration.embedding_model
+                dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
             elif knowledge_configuration.indexing_technique == "economy":
                 dataset.keyword_number = knowledge_configuration.keyword_number
             dataset.pipeline_id = pipeline.id
@@ -443,8 +442,7 @@ class RagPipelineDslService:
                         .filter(
                             DatasetCollectionBinding.provider_name
                             == knowledge_configuration.embedding_model_provider,
-                            DatasetCollectionBinding.model_name
-                            == knowledge_configuration.embedding_model,
+                            DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
                             DatasetCollectionBinding.type == "dataset",
                         )
                         .order_by(DatasetCollectionBinding.created_at)
@@ -462,12 +460,8 @@ class RagPipelineDslService:
                         db.session.commit()
                     dataset_collection_binding_id = dataset_collection_binding.id
                 dataset.collection_binding_id = dataset_collection_binding_id
-                dataset.embedding_model = (
-                    knowledge_configuration.embedding_model
-                )
-                dataset.embedding_model_provider = (
-                    knowledge_configuration.embedding_model_provider
-                )
+                dataset.embedding_model = knowledge_configuration.embedding_model
+                dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
             elif knowledge_configuration.indexing_technique == "economy":
                 dataset.keyword_number = knowledge_configuration.keyword_number
             dataset.pipeline_id = pipeline.id
@@ -538,7 +532,6 @@ class RagPipelineDslService:
             icon_type = "emoji"
         icon = str(pipeline_data.get("icon", ""))
 
-
         # Initialize pipeline based on mode
         workflow_data = data.get("workflow")
         if not workflow_data or not isinstance(workflow_data, dict):
@@ -554,7 +547,6 @@ class RagPipelineDslService:
             ]
             rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
 
-
             graph = workflow_data.get("graph", {})
             for node in graph.get("nodes", []):
                 if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
@@ -576,7 +568,6 @@ class RagPipelineDslService:
             pipeline.description = pipeline_data.get("description", pipeline.description)
             pipeline.updated_by = account.id
 
-
         else:
             if account.current_tenant_id is None:
                 raise ValueError("Current tenant is not set")
@@ -636,7 +627,6 @@ class RagPipelineDslService:
 
         # commit db session changes
         db.session.commit()
-
 
         return pipeline
 
     @classmethod
@@ -874,7 +864,6 @@ class RagPipelineDslService:
         except Exception:
             return None
 
-
     @staticmethod
     def create_rag_pipeline_dataset(
         tenant_id: str,
@@ -886,9 +875,7 @@ class RagPipelineDslService:
             .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
             .first()
         ):
-            raise ValueError(
-                f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
-            )
+            raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")
 
         with Session(db.engine) as session:
             rag_pipeline_dsl_service = RagPipelineDslService(session)
diff --git a/api/services/rag_pipeline/rag_pipeline_manage_service.py b/api/services/rag_pipeline/rag_pipeline_manage_service.py
index df6085fafa..0908d30c12 100644
--- a/api/services/rag_pipeline/rag_pipeline_manage_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_manage_service.py
@@ -12,12 +12,12 @@ class RagPipelineManageService:
         # get all builtin providers
         manager = PluginDatasourceManager()
 
-        datasources = manager.fetch_datasource_providers(tenant_id)   
+        datasources = manager.fetch_datasource_providers(tenant_id)
        for datasource in datasources:
             datasource_provider_service = DatasourceProviderService()
-            credentials = datasource_provider_service.get_datasource_credentials(tenant_id=tenant_id,
-                                                                                 provider=datasource.provider,
-                                                                                 plugin_id=datasource.plugin_id)
+            credentials = datasource_provider_service.get_datasource_credentials(
+                tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
+            )
             if credentials:
                 datasource.is_authorized = True
         return datasources