diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index ceaa9ec4fa..644bcbddb1 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -283,7 +283,7 @@ class DatasetApi(Resource):
             location="json",
             help="Invalid external knowledge api id.",
         )
-        
+
         parser.add_argument(
             "icon_info",
             type=dict,
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 60fa1731ca..e5fde58d04 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -52,6 +52,7 @@ from fields.document_fields import (
 )
 from libs.login import login_required
 from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
+from models.dataset import DocumentPipelineExecutionLog
 from services.dataset_service import DatasetService, DocumentService
 from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
 from tasks.add_document_to_index_task import add_document_to_index_task
@@ -1092,6 +1093,35 @@ class WebsiteDocumentSyncApi(DocumentResource):
         return {"result": "success"}, 200


+class DocumentPipelineExecutionLogApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id, document_id):
+        dataset_id = str(dataset_id)
+        document_id = str(document_id)
+
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        document = DocumentService.get_document(dataset.id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+        log = (
+            db.session.query(DocumentPipelineExecutionLog)
+            .filter_by(document_id=document_id)
+            .order_by(DocumentPipelineExecutionLog.created_at.desc())
+            .first()
+        )
+        if not log:
+            return {"datasource_info": None, "datasource_type": None, "input_data": None}, 200
+        return {
+            "datasource_info": log.datasource_info,
+            "datasource_type": log.datasource_type,
+            "input_data": log.input_data,
+        }, 200
+
+
 api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
 api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
 api.add_resource(DatasetInitApi, "/datasets/init")
diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
index d2136f771b..21a7b998f0 100644
--- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
+++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
@@ -41,8 +41,9 @@ class DatasourcePluginOauthApi(Resource):
         if not plugin_oauth_config:
             raise NotFound()
         oauth_handler = OAuthHandler()
-        redirect_url = (f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?"
-                        f"provider={provider}&plugin_id={plugin_id}")
+        redirect_url = (
+            f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}"
+        )
         system_credentials = plugin_oauth_config.system_credentials
         if system_credentials:
             system_credentials["redirect_url"] = redirect_url
@@ -123,9 +124,7 @@ class DatasourceAuth(Resource):
         args = parser.parse_args()
         datasource_provider_service = DatasourceProviderService()
         datasources = datasource_provider_service.get_datasource_credentials(
-            tenant_id=current_user.current_tenant_id,
-            provider=args["provider"],
-            plugin_id=args["plugin_id"]
+            tenant_id=current_user.current_tenant_id, provider=args["provider"], plugin_id=args["plugin_id"]
         )
         return {"result": datasources}, 200

@@ -146,7 +145,7 @@ class DatasourceAuthUpdateDeleteApi(Resource):
             tenant_id=current_user.current_tenant_id,
             auth_id=auth_id,
             provider=args["provider"],
-            plugin_id=args["plugin_id"]
+            plugin_id=args["plugin_id"],
         )
         return {"result": "success"}, 200

diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index 616803247c..00cd36b649 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -384,6 +384,7 @@ class PublishedRagPipelineRunApi(Resource):
     #         return result
     #

+
 class RagPipelinePublishedDatasourceNodeRunApi(Resource):
     @setup_required
     @login_required
@@ -419,7 +420,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
                 user_inputs=inputs,
                 account=current_user,
                 datasource_type=datasource_type,
-                is_published=True
+                is_published=True,
             )
         return result

@@ -458,12 +459,12 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
         return helper.compact_generate_response(
             PipelineGenerator.convert_to_event_stream(
                 rag_pipeline_service.run_datasource_workflow_node(
-                        pipeline=pipeline,
-                        node_id=node_id,
-                        user_inputs=inputs,
-                        account=current_user,
-                        datasource_type=datasource_type,
-                        is_published=False
+                    pipeline=pipeline,
+                    node_id=node_id,
+                    user_inputs=inputs,
+                    account=current_user,
+                    datasource_type=datasource_type,
+                    is_published=False,
                 )
             )
         )
diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py
index f170d0ee3f..aa74f8c318 100644
--- a/api/core/app/apps/common/workflow_response_converter.py
+++ b/api/core/app/apps/common/workflow_response_converter.py
@@ -188,7 +188,7 @@ class WorkflowResponseConverter:
                 manager = PluginDatasourceManager()
                 provider_entity = manager.fetch_datasource_provider(
                     self._application_generate_entity.app_config.tenant_id,
-                    f"{node_data.plugin_id}/{node_data.provider_name}"
+                    f"{node_data.plugin_id}/{node_data.provider_name}",
                 )
                 response.data.extras["icon"] = provider_entity.declaration.identity.icon

diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py
index a2123fdc49..ec565fe2e5 100644
--- a/api/core/app/apps/pipeline/pipeline_generator.py
+++ b/api/core/app/apps/pipeline/pipeline_generator.py
@@ -33,7 +33,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
 from extensions.ext_database import db
 from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
-from models.dataset import Document, Pipeline
+from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import AppMode
 from services.dataset_service import DocumentService
@@ -136,6 +136,16 @@ class PipelineGenerator(BaseAppGenerator):
             document_id = None
             if invoke_from == InvokeFrom.PUBLISHED:
                 document_id = documents[i].id
+            document_pipeline_execution_log = DocumentPipelineExecutionLog(
+                document_id=document_id,
+                datasource_type=datasource_type,
+                datasource_info=datasource_info,
+                input_data=inputs,
+                pipeline_id=pipeline.id,
+                created_by=user.id,
+            )
+            db.session.add(document_pipeline_execution_log)
+            db.session.commit()
             application_generate_entity = RagPipelineGenerateEntity(
                 task_id=str(uuid.uuid4()),
                 app_config=pipeline_config,
diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py
index 9b72966b50..d072b8541b 100644
--- a/api/core/datasource/entities/datasource_entities.py
+++ b/api/core/datasource/entities/datasource_entities.py
@@ -284,17 +284,20 @@ class WebSiteInfo(BaseModel):
     """
     Website info
     """
+
     status: Optional[str] = Field(..., description="crawl job status")
     web_info_list: Optional[list[WebSiteInfoDetail]] = []
     total: Optional[int] = Field(default=0, description="The total number of websites")
     completed: Optional[int] = Field(default=0, description="The number of completed websites")

+
 class WebsiteCrawlMessage(BaseModel):
     """
     Get website crawl response
     """
+
     result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0)

+
 class DatasourceMessage(ToolInvokeMessage):
     pass
-
diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py
index 0567f1a480..8c0f20ce2d 100644
--- a/api/core/datasource/website_crawl/website_crawl_provider.py
+++ b/api/core/datasource/website_crawl/website_crawl_provider.py
@@ -43,7 +43,6 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon

         if not datasource_entity:
             raise ValueError(f"Datasource with name {datasource_name} not found")
-
         return WebsiteCrawlDatasourcePlugin(
             entity=datasource_entity,
             runtime=DatasourceRuntime(tenant_id=self.tenant_id),
diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py
index fbf591eb8f..061a69e009 100644
--- a/api/core/workflow/graph_engine/entities/event.py
+++ b/api/core/workflow/graph_engine/entities/event.py
@@ -277,8 +277,7 @@ InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | Bas

 class DatasourceRunEvent(BaseModel):
     status: str = Field(..., description="status")
-    data: Mapping[str,Any] | list = Field(..., description="result")
+    data: Mapping[str, Any] | list = Field(..., description="result")
     total: Optional[int] = Field(..., description="total")
     completed: Optional[int] = Field(..., description="completed")
     time_consuming: Optional[float] = Field(..., description="time consuming")
-
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
index 0e3decc7b4..ab4477f538 100644
--- a/api/core/workflow/nodes/datasource/datasource_node.py
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -74,12 +74,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         except DatasourceNodeError as e:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
-                status=WorkflowNodeExecutionStatus.FAILED,
-                inputs={},
-                metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
-                error=f"Failed to get datasource runtime: {str(e)}",
-                error_type=type(e).__name__,
-            )
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    inputs={},
+                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
+                    error=f"Failed to get datasource runtime: {str(e)}",
+                    error_type=type(e).__name__,
+                )
             )

         # get parameters
@@ -114,16 +114,17 @@
                     )

                 case DatasourceProviderType.WEBSITE_CRAWL:
-
-                    yield RunCompletedEvent(run_result=NodeRunResult(
-                        status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                        inputs=parameters_for_log,
-                        metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
-                        outputs={
-                            **datasource_info,
-                            "datasource_type": datasource_type,
-                        },
-                    ))
+                    yield RunCompletedEvent(
+                        run_result=NodeRunResult(
+                            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                            inputs=parameters_for_log,
+                            metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
+                            outputs={
+                                **datasource_info,
+                                "datasource_type": datasource_type,
+                            },
+                        )
+                    )
                 case DatasourceProviderType.LOCAL_FILE:
                     related_id = datasource_info.get("related_id")
                     if not related_id:
@@ -155,33 +156,39 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                                 variable_key_list=new_key_list,
                                 variable_value=value,
                             )
-                    yield RunCompletedEvent(run_result=NodeRunResult(
-                        status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                        inputs=parameters_for_log,
-                        metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
-                        outputs={
-                            "file_info": datasource_info,
-                            "datasource_type": datasource_type,
-                        },
-                    ))
+                    yield RunCompletedEvent(
+                        run_result=NodeRunResult(
+                            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                            inputs=parameters_for_log,
+                            metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
+                            outputs={
+                                "file_info": datasource_info,
+                                "datasource_type": datasource_type,
+                            },
+                        )
+                    )
                 case _:
                     raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
         except PluginDaemonClientSideError as e:
-            yield RunCompletedEvent(run_result=NodeRunResult(
-                status=WorkflowNodeExecutionStatus.FAILED,
-                inputs=parameters_for_log,
-                metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
-                error=f"Failed to transform datasource message: {str(e)}",
-                error_type=type(e).__name__,
-            ))
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    inputs=parameters_for_log,
+                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
+                    error=f"Failed to transform datasource message: {str(e)}",
+                    error_type=type(e).__name__,
+                )
+            )
         except DatasourceNodeError as e:
-            yield RunCompletedEvent(run_result=NodeRunResult(
-                status=WorkflowNodeExecutionStatus.FAILED,
-                inputs=parameters_for_log,
-                metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
-                error=f"Failed to invoke datasource: {str(e)}",
-                error_type=type(e).__name__,
-            ))
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    inputs=parameters_for_log,
+                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
+                    error=f"Failed to invoke datasource: {str(e)}",
+                    error_type=type(e).__name__,
+                )
+            )

     def _generate_parameters(
         self,
@@ -286,8 +293,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):

         return result

-
-
     def _transform_message(
         self,
         messages: Generator[DatasourceMessage, None, None],
diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
index 49c8ec1e69..2c45bf4073 100644
--- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
+++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
@@ -123,10 +123,14 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
         # update document status
         document.indexing_status = "completed"
         document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
-        document.word_count = db.session.query(func.sum(DocumentSegment.word_count)).filter(
-            DocumentSegment.document_id == document.id,
-            DocumentSegment.dataset_id == dataset.id,
-        ).scalar()
+        document.word_count = (
+            db.session.query(func.sum(DocumentSegment.word_count))
+            .filter(
+                DocumentSegment.document_id == document.id,
+                DocumentSegment.dataset_id == dataset.id,
+            )
+            .scalar()
+        )
         db.session.add(document)
         # update document segment status
         db.session.query(DocumentSegment).filter(
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py
index 128041a27d..81606594e0 100644
--- a/api/factories/file_factory.py
+++ b/api/factories/file_factory.py
@@ -349,6 +349,7 @@ def _build_from_datasource_file(
         storage_key=datasource_file.key,
     )

+
 def _is_file_valid_with_config(
     *,
     input_file_type: str,
diff --git a/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py
index 4d726cecb1..503842b797 100644
--- a/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py
+++ b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py
@@ -12,7 +12,7 @@ from sqlalchemy.dialects import postgresql

 # revision identifiers, used by Alembic.
 revision = 'b35c3db83d09'
-down_revision = '2adcbe1f5dfb'
+down_revision = '4474872b0ee6'
 branch_labels = None
 depends_on = None

diff --git a/api/migrations/versions/2025_06_17_1905-70a0fc0c013f_add_pipeline_info_7.py b/api/migrations/versions/2025_06_17_1905-70a0fc0c013f_add_pipeline_info_7.py
new file mode 100644
index 0000000000..a695adc74a
--- /dev/null
+++ b/api/migrations/versions/2025_06_17_1905-70a0fc0c013f_add_pipeline_info_7.py
@@ -0,0 +1,45 @@
+"""add_pipeline_info_7
+
+Revision ID: 70a0fc0c013f
+Revises: 224fba149d48
+Create Date: 2025-06-17 19:05:39.920953
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '70a0fc0c013f'
+down_revision = '224fba149d48'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('document_pipeline_execution_logs',
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('pipeline_id', models.types.StringUUID(), nullable=False),
+    sa.Column('document_id', models.types.StringUUID(), nullable=False),
+    sa.Column('datasource_type', sa.String(length=255), nullable=False),
+    sa.Column('datasource_info', sa.Text(), nullable=False),
+    sa.Column('input_data', sa.JSON(), nullable=False),
+    sa.Column('created_by', models.types.StringUUID(), nullable=True),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='document_pipeline_execution_log_pkey')
+    )
+    with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op:
+        batch_op.create_index('document_pipeline_execution_logs_document_id_idx', ['document_id'], unique=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op:
+        batch_op.drop_index('document_pipeline_execution_logs_document_id_idx')
+
+    op.drop_table('document_pipeline_execution_logs')
+    # ### end Alembic commands ###
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 5d18eaff49..16d1865a83 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -75,12 +75,16 @@ class Dataset(Base):

     @property
     def total_available_documents(self):
-        return db.session.query(func.count(Document.id)).filter(
-            Document.dataset_id == self.id,
-            Document.indexing_status == "completed",
-            Document.enabled == True,
-            Document.archived == False,
-        ).scalar()
+        return (
+            db.session.query(func.count(Document.id))
+            .filter(
+                Document.dataset_id == self.id,
+                Document.indexing_status == "completed",
+                Document.enabled == True,
+                Document.archived == False,
+            )
+            .scalar()
+        )

     @property
     def dataset_keyword_table(self):
@@ -325,6 +329,7 @@ class DatasetProcessRule(Base):
         except JSONDecodeError:
             return None

+
 class Document(Base):
     __tablename__ = "documents"
     __table_args__ = (
@@ -1248,3 +1253,20 @@ class Pipeline(Base):  # type: ignore[name-defined]
     @property
     def dataset(self):
         return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()
+
+
+class DocumentPipelineExecutionLog(Base):
+    __tablename__ = "document_pipeline_execution_logs"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"),
+        db.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
+    )
+
+    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    pipeline_id = db.Column(StringUUID, nullable=False)
+    document_id = db.Column(StringUUID, nullable=False)
+    datasource_type = db.Column(db.String(255), nullable=False)
+    datasource_info = db.Column(db.Text, nullable=False)
+    input_data = db.Column(db.JSON, nullable=False)
+    created_by = db.Column(StringUUID, nullable=True)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 8719eb3be4..8c88a51ed7 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -334,11 +334,15 @@ class DatasetService:
             dataset.retrieval_model = external_retrieval_model
             dataset.name = data.get("name", dataset.name)
             # check if dataset name is exists
-            if db.session.query(Dataset).filter(
-                Dataset.id != dataset_id,
-                Dataset.name == dataset.name,
-                Dataset.tenant_id == dataset.tenant_id,
-            ).first():
+            if (
+                db.session.query(Dataset)
+                .filter(
+                    Dataset.id != dataset_id,
+                    Dataset.name == dataset.name,
+                    Dataset.tenant_id == dataset.tenant_id,
+                )
+                .first()
+            ):
                 raise ValueError("Dataset name already exists")
             dataset.description = data.get("description", "")
             permission = data.get("permission")
diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py
index 80e903bd46..fa01fe0afe 100644
--- a/api/services/datasource_provider_service.py
+++ b/api/services/datasource_provider_service.py
@@ -36,7 +36,7 @@ class DatasourceProviderService:
             user_id=current_user.id,
             provider=provider,
             plugin_id=plugin_id,
-            credentials=credentials
+            credentials=credentials,
         )
         if credential_valid:
             # Get all provider configurations of the current workspace
@@ -47,9 +47,8 @@ class DatasourceProviderService:
             )

             provider_credential_secret_variables = self.extract_secret_variables(
-                tenant_id=tenant_id,
-                provider_id=f"{plugin_id}/{provider}"
-                )
+                tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
+            )
             for key, value in credentials.items():
                 if key in provider_credential_secret_variables:
                     # if send [__HIDDEN__] in secret input, it will be same as original value
@@ -73,9 +72,9 @@ class DatasourceProviderService:
        :param credential_form_schemas:
        :return:
        """
-        datasource_provider = self.provider_manager.fetch_datasource_provider(tenant_id=tenant_id,
-                                                                               provider_id=provider_id
-                                                                               )
+        datasource_provider = self.provider_manager.fetch_datasource_provider(
+            tenant_id=tenant_id, provider_id=provider_id
+        )
         credential_form_schemas = datasource_provider.declaration.credentials_schema
         secret_input_form_variables = []
         for credential_form_schema in credential_form_schemas:
@@ -108,8 +107,9 @@ class DatasourceProviderService:
         for datasource_provider in datasource_providers:
             encrypted_credentials = datasource_provider.encrypted_credentials
             # Get provider credential secret variables
-            credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id,
-                                                                        provider_id=f"{plugin_id}/{provider}")
+            credential_secret_variables = self.extract_secret_variables(
+                tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
+            )

             # Obfuscate provider credentials
             copy_credentials = encrypted_credentials.copy()
@@ -149,8 +149,9 @@ class DatasourceProviderService:
         for datasource_provider in datasource_providers:
             encrypted_credentials = datasource_provider.encrypted_credentials
             # Get provider credential secret variables
-            credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id,
-                                                                        provider_id=f"{plugin_id}/{provider}")
+            credential_secret_variables = self.extract_secret_variables(
+                tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
+            )

             # Obfuscate provider credentials
             copy_credentials = encrypted_credentials.copy()
@@ -166,18 +167,18 @@ class DatasourceProviderService:

         return copy_credentials_list

-
-    def update_datasource_credentials(self,
-                                      tenant_id: str,
-                                      auth_id: str,
-                                      provider: str,
-                                      plugin_id: str,
-                                      credentials: dict) -> None:
+    def update_datasource_credentials(
+        self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict
+    ) -> None:
         """
         update datasource credentials.
         """
         credential_valid = self.provider_manager.validate_provider_credentials(
-            tenant_id=tenant_id, user_id=current_user.id, provider=provider,plugin_id=plugin_id, credentials=credentials
+            tenant_id=tenant_id,
+            user_id=current_user.id,
+            provider=provider,
+            plugin_id=plugin_id,
+            credentials=credentials,
         )
         if credential_valid:
             # Get all provider configurations of the current workspace
@@ -188,9 +189,8 @@ class DatasourceProviderService:
             )

             provider_credential_secret_variables = self.extract_secret_variables(
-                tenant_id=tenant_id,
-                provider_id=f"{plugin_id}/{provider}"
-                )
+                tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
+            )
             if not datasource_provider:
                 raise ValueError("Datasource provider not found")
             else:
diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
index ca94f7f47a..7280408889 100644
--- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
+++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
@@ -66,7 +66,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         )
         if not pipeline_template:
             return None
-        
+
         dsl_data = yaml.safe_load(pipeline_template.yaml_content)
         graph_data = dsl_data.get("workflow", {}).get("graph", {})

diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index a5f2135100..87b13ba98d 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -484,8 +484,13 @@ class RagPipelineService:
         # raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")

     def run_datasource_workflow_node(
-        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str,
-        is_published: bool
+        self,
+        pipeline: Pipeline,
+        node_id: str,
+        user_inputs: dict,
+        account: Account,
+        datasource_type: str,
+        is_published: bool,
     ) -> Generator[str, None, None]:
         """
         Run published workflow datasource
@@ -525,27 +530,26 @@ class RagPipelineService:
         datasource_provider_service = DatasourceProviderService()
         credentials = datasource_provider_service.get_real_datasource_credentials(
             tenant_id=pipeline.tenant_id,
-            provider=datasource_node_data.get('provider_name'),
-            plugin_id=datasource_node_data.get('plugin_id'),
+            provider=datasource_node_data.get("provider_name"),
+            plugin_id=datasource_node_data.get("plugin_id"),
         )
         if credentials:
             datasource_runtime.runtime.credentials = credentials[0].get("credentials")
         match datasource_type:
             case DatasourceProviderType.ONLINE_DOCUMENT:
                 datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
-                online_document_result: Generator[OnlineDocumentPagesMessage, None, None] =\
+                online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
                     datasource_runtime.get_online_document_pages(
                         user_id=account.id,
                         datasource_parameters=user_inputs,
                         provider_type=datasource_runtime.datasource_provider_type(),
                     )
+                )
                 start_time = time.time()
                 for message in online_document_result:
                     end_time = time.time()
                     online_document_event = DatasourceRunEvent(
-                        status="completed",
-                        data=message.result,
-                        time_consuming=round(end_time - start_time, 2)
+                        status="completed", data=message.result, time_consuming=round(end_time - start_time, 2)
                     )
                     yield json.dumps(online_document_event.model_dump())

@@ -564,7 +568,7 @@ class RagPipelineService:
                             data=message.result.web_info_list,
                             total=message.result.total,
                             completed=message.result.completed,
-                            time_consuming = round(end_time - start_time, 2)
+                            time_consuming=round(end_time - start_time, 2),
                         )
                         yield json.dumps(crawl_event.model_dump())
                 case _:
@@ -781,9 +785,7 @@ class RagPipelineService:
             raise ValueError("Datasource node data not found")
         variables = datasource_node_data.get("variables", {})
         if variables:
-            variables_map = {
-                item["variable"]: item for item in variables
-            }
+            variables_map = {item["variable"]: item for item in variables}
         else:
             return []
         datasource_parameters = datasource_node_data.get("datasource_parameters", {})
@@ -813,9 +815,7 @@ class RagPipelineService:
             raise ValueError("Datasource node data not found")
         variables = datasource_node_data.get("variables", {})
         if variables:
-            variables_map = {
-                item["variable"]: item for item in variables
-            }
+            variables_map = {item["variable"]: item for item in variables}
         else:
             return []
         datasource_parameters = datasource_node_data.get("datasource_parameters", {})
@@ -967,11 +967,14 @@ class RagPipelineService:
         if not dataset:
             raise ValueError("Dataset not found")

-        max_position = db.session.query(
-            func.max(PipelineCustomizedTemplate.position)).filter(
-            PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id).scalar()
+        max_position = (
+            db.session.query(func.max(PipelineCustomizedTemplate.position))
+            .filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
+            .scalar()
+        )
         from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
+
         dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
         pipeline_customized_template = PipelineCustomizedTemplate(