dify/api/services/rag_pipeline/rag_pipeline.py

import json
import re
import threading
import time
from collections.abc import Callable, Generator, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, cast
from uuid import uuid4

from flask_login import current_user
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session

import contexts
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import (
    DatasourceProviderType,
    OnlineDocumentPagesMessage,
    WebsiteCrawlMessage,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.rag.entities.event import BaseDatasourceEvent, DatasourceCompletedEvent, DatasourceProcessingEvent
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
)
from core.workflow.enums import SystemVariableKey
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Document, Pipeline, PipelineCustomizedTemplate  # type: ignore
from models.enums import WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
    Workflow,
    WorkflowNodeExecutionModel,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowRun,
    WorkflowType,
)
from services.dataset_service import DatasetService
from services.datasource_provider_service import DatasourceProviderService
from services.entities.knowledge_entities.rag_pipeline_entities import (
    KnowledgeConfiguration,
    PipelineTemplateInfoEntity,
)
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory


class RagPipelineService:
    @classmethod
    def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            if not result.get("pipeline_templates") and language != "en-US":
                template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
                result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
            return result
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            return result

    @classmethod
    def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]:
        """
        Get pipeline template detail.
        :param template_id: template id
        :param type: template source, "built-in" or "customized"
        :return: template detail dict, or None if the template is not found
        """
        result: Optional[dict]
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_template_detail(template_id)
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_template_detail(template_id)
        return result
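
    # Example (illustrative sketch; assumes a request context with a current tenant):
    #
    #   templates = RagPipelineService.get_pipeline_templates(type="built-in", language="ja-JP")
    #   detail = RagPipelineService.get_pipeline_template_detail(template_id=some_template_id)
    #
    # "pipeline_templates" is the key checked above for the en-US fallback; the exact item shape
    # depends on the configured retrieval backend.
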
@classmethod
def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
"""
Update pipeline template.
:param template_id: template id
:param template_info: template info
"""
customized_template: PipelineCustomizedTemplate | None = (
db.session.query(PipelineCustomizedTemplate)
.filter(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
)
.first()
)
if not customized_template:
raise ValueError("Customized pipeline template not found.")
customized_template.name = template_info.name
customized_template.description = template_info.description
customized_template.icon = template_info.icon_info.model_dump()
customized_template.updated_by = current_user.id
db.session.commit()
return customized_template
@classmethod
def delete_customized_pipeline_template(cls, template_id: str):
"""
Delete customized pipeline template.
"""
customized_template: PipelineCustomizedTemplate | None = (
db.session.query(PipelineCustomizedTemplate)
.filter(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
)
.first()
)
if not customized_template:
raise ValueError("Customized pipeline template not found.")
db.session.delete(customized_template)
db.session.commit()
def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
"""
Get draft workflow
"""
# fetch draft workflow by rag pipeline
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
.first()
)
# return draft workflow
return workflow
def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
"""
Get published workflow
"""
if not pipeline.workflow_id:
return None
# fetch published workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.id == pipeline.workflow_id,
)
.first()
)
return workflow
def get_all_published_workflow(
self,
*,
session: Session,
pipeline: Pipeline,
page: int,
limit: int,
user_id: str | None,
named_only: bool = False,
) -> tuple[Sequence[Workflow], bool]:
"""
Get published workflow with pagination
"""
if not pipeline.workflow_id:
return [], False
stmt = (
select(Workflow)
.where(Workflow.app_id == pipeline.id)
.order_by(Workflow.version.desc())
.limit(limit + 1)
.offset((page - 1) * limit)
)
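        # Note: limit + 1 rows are fetched so the extra row can signal has_more; it is trimmed off
        # below before returning.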
if user_id:
stmt = stmt.where(Workflow.created_by == user_id)
if named_only:
stmt = stmt.where(Workflow.marked_name != "")
workflows = session.scalars(stmt).all()
has_more = len(workflows) > limit
if has_more:
workflows = workflows[:-1]
return workflows, has_more
def sync_draft_workflow(
self,
*,
pipeline: Pipeline,
graph: dict,
unique_hash: Optional[str],
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
rag_pipeline_variables: list,
) -> Workflow:
"""
Sync draft workflow
:raises WorkflowHashNotEqualError
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(pipeline=pipeline)
if workflow and workflow.unique_hash != unique_hash:
raise WorkflowHashNotEqualError()
# create draft workflow if not found
if not workflow:
workflow = Workflow(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
features="{}",
type=WorkflowType.RAG_PIPELINE.value,
version="draft",
graph=json.dumps(graph),
created_by=account.id,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
rag_pipeline_variables=rag_pipeline_variables,
)
db.session.add(workflow)
db.session.flush()
pipeline.workflow_id = workflow.id
# update draft workflow if found
else:
workflow.graph = json.dumps(graph)
workflow.updated_by = account.id
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables
workflow.rag_pipeline_variables = rag_pipeline_variables
# commit db session changes
db.session.commit()
# trigger workflow events TODO
# app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow)
# return draft workflow
return workflow
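
    # Example (illustrative sketch; assumes a Pipeline, an editing Account, and a graph dict):
    #
    #   service = RagPipelineService()
    #   draft = service.get_draft_workflow(pipeline=pipeline)
    #   service.sync_draft_workflow(
    #       pipeline=pipeline,
    #       graph=new_graph,
    #       unique_hash=draft.unique_hash if draft else None,
    #       account=account,
    #       environment_variables=[],
    #       conversation_variables=[],
    #       rag_pipeline_variables=[],
    #   )
    #
    # Passing the previously fetched unique_hash lets the call fail with WorkflowHashNotEqualError
    # instead of silently overwriting a draft someone else edited in the meantime.
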
def publish_workflow(
self,
*,
session: Session,
pipeline: Pipeline,
account: Account,
) -> Workflow:
draft_workflow_stmt = select(Workflow).where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
draft_workflow = session.scalar(draft_workflow_stmt)
if not draft_workflow:
raise ValueError("No valid workflow found.")
# create new workflow
workflow = Workflow.new(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
type=draft_workflow.type,
version=str(datetime.now(UTC).replace(tzinfo=None)),
graph=draft_workflow.graph,
features=draft_workflow.features,
created_by=account.id,
environment_variables=draft_workflow.environment_variables,
conversation_variables=draft_workflow.conversation_variables,
rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
marked_name="",
marked_comment="",
)
# commit db session changes
session.add(workflow)
graph = workflow.graph_dict
nodes = graph.get("nodes", [])
for node in nodes:
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
# update dataset
dataset = pipeline.dataset
if not dataset:
raise ValueError("Dataset not found")
DatasetService.update_rag_pipeline_dataset_settings(
session=session,
dataset=dataset,
knowledge_configuration=knowledge_configuration,
has_published=pipeline.is_published,
)
# return new workflow
return workflow
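
    # Note: publish_workflow only adds the new workflow version to the caller-provided session (and
    # syncs dataset settings when a knowledge-index node is present); committing the transaction is
    # left to the caller.
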
def get_default_block_configs(self) -> list[dict]:
"""
Get default block configs
"""
# return default block config
default_block_configs = []
for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
node_class = node_class_mapping[LATEST_VERSION]
default_config = node_class.get_default_config()
if default_config:
default_block_configs.append(default_config)
return default_block_configs
def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
"""
Get default config of node.
:param node_type: node type
:param filters: filter by node config parameters.
:return:
"""
node_type_enum = NodeType(node_type)
# return default block config
if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
return None
node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
default_config = node_class.get_default_config(filters=filters)
if not default_config:
return None
return default_config
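
    # Example (illustrative): get_default_block_config(node_type="code") returns the default config
    # registered for the latest version of that node type, or None when the type is unknown or has
    # no default. The literal "code" is only an example; valid values are the NodeType enum values.
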
def run_draft_workflow_node(
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
) -> WorkflowNodeExecution:
"""
Run draft workflow node
"""
# fetch draft workflow by app_model
draft_workflow = self.get_draft_workflow(pipeline=pipeline)
if not draft_workflow:
raise ValueError("Workflow not initialized")
# run draft workflow node
start_at = time.perf_counter()
workflow_node_execution = self._handle_node_run_result(
getter=lambda: WorkflowEntry.single_step_run(
workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
user_id=account.id,
),
start_at=start_at,
tenant_id=pipeline.tenant_id,
node_id=node_id,
)
workflow_node_execution.workflow_id = draft_workflow.id
db.session.add(workflow_node_execution)
db.session.commit()
return workflow_node_execution
def run_published_workflow_node(
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
) -> WorkflowNodeExecution:
"""
Run published workflow node
"""
# fetch published workflow by app_model
published_workflow = self.get_published_workflow(pipeline=pipeline)
if not published_workflow:
raise ValueError("Workflow not initialized")
        # run published workflow node
start_at = time.perf_counter()
workflow_node_execution = self._handle_node_run_result(
getter=lambda: WorkflowEntry.single_step_run(
workflow=published_workflow,
node_id=node_id,
user_inputs=user_inputs,
user_id=account.id,
),
start_at=start_at,
tenant_id=pipeline.tenant_id,
node_id=node_id,
)
workflow_node_execution.workflow_id = published_workflow.id
db.session.add(workflow_node_execution)
db.session.commit()
return workflow_node_execution
# def run_datasource_workflow_node_status(
# self, pipeline: Pipeline, node_id: str, job_id: str, account: Account,
# datasource_type: str, is_published: bool
# ) -> dict:
# """
# Run published workflow datasource
# """
# if is_published:
# # fetch published workflow by app_model
# workflow = self.get_published_workflow(pipeline=pipeline)
# else:
# workflow = self.get_draft_workflow(pipeline=pipeline)
# if not workflow:
# raise ValueError("Workflow not initialized")
#
# # run draft workflow node
# datasource_node_data = None
# start_at = time.perf_counter()
# datasource_nodes = workflow.graph_dict.get("nodes", [])
# for datasource_node in datasource_nodes:
# if datasource_node.get("id") == node_id:
# datasource_node_data = datasource_node.get("data", {})
# break
# if not datasource_node_data:
# raise ValueError("Datasource node data not found")
#
# from core.datasource.datasource_manager import DatasourceManager
#
# datasource_runtime = DatasourceManager.get_datasource_runtime(
# provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
# datasource_name=datasource_node_data.get("datasource_name"),
# tenant_id=pipeline.tenant_id,
# datasource_type=DatasourceProviderType(datasource_type),
# )
# datasource_provider_service = DatasourceProviderService()
# credentials = datasource_provider_service.get_real_datasource_credentials(
# tenant_id=pipeline.tenant_id,
# provider=datasource_node_data.get('provider_name'),
# plugin_id=datasource_node_data.get('plugin_id'),
# )
# if credentials:
# datasource_runtime.runtime.credentials = credentials[0].get("credentials")
# match datasource_type:
#
# case DatasourceProviderType.WEBSITE_CRAWL:
# datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
# website_crawl_results: list[WebsiteCrawlMessage] = []
# for website_message in datasource_runtime.get_website_crawl(
# user_id=account.id,
# datasource_parameters={"job_id": job_id},
# provider_type=datasource_runtime.datasource_provider_type(),
# ):
# website_crawl_results.append(website_message)
# return {
# "result": [result for result in website_crawl_results.result],
# "status": website_crawl_results.result.status,
# "provider_type": datasource_node_data.get("provider_type"),
# }
# case _:
# raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
def run_datasource_workflow_node(
        self,
        pipeline: Pipeline,
        node_id: str,
        user_inputs: dict,
        account: Account,
        datasource_type: str,
        is_published: bool,
    ) -> Generator[BaseDatasourceEvent, None, None]:
        """
        Run the datasource node of a draft or published rag pipeline workflow.
        """
if is_published:
# fetch published workflow by app_model
workflow = self.get_published_workflow(pipeline=pipeline)
else:
workflow = self.get_draft_workflow(pipeline=pipeline)
if not workflow:
raise ValueError("Workflow not initialized")
        # locate the datasource node in the workflow graph
        datasource_node_data = None
start_at = time.perf_counter()
datasource_nodes = workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes:
if datasource_node.get("id") == node_id:
datasource_node_data = datasource_node.get("data", {})
break
if not datasource_node_data:
raise ValueError("Datasource node data not found")
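        # Backfill missing inputs from the defaults stored on the datasource node itself, so the
        # caller only has to provide the values the end user actually typed in.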
datasource_parameters = datasource_node_data.get("datasource_parameters", {})
for key, value in datasource_parameters.items():
if not user_inputs.get(key):
user_inputs[key] = value["value"]
from core.datasource.datasource_manager import DatasourceManager
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
datasource_name=datasource_node_data.get("datasource_name"),
tenant_id=pipeline.tenant_id,
datasource_type=DatasourceProviderType(datasource_type),
)
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_real_datasource_credentials(
tenant_id=pipeline.tenant_id,
provider=datasource_node_data.get("provider_name"),
plugin_id=datasource_node_data.get("plugin_id"),
)
if credentials:
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
datasource_runtime.get_online_document_pages(
user_id=account.id,
datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(),
)
)
start_time = time.time()
for message in online_document_result:
end_time = time.time()
online_document_event = DatasourceCompletedEvent(
data=message.result,
time_consuming=round(end_time - start_time, 2)
)
yield online_document_event.model_dump()
case DatasourceProviderType.WEBSITE_CRAWL:
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = datasource_runtime.get_website_crawl(
user_id=account.id,
datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(),
)
start_time = time.time()
for message in website_crawl_result:
end_time = time.time()
if message.result.status == "completed":
crawl_event = DatasourceCompletedEvent(
data=message.result.web_info_list,
total=message.result.total,
completed=message.result.completed,
time_consuming=round(end_time - start_time, 2)
)
else:
crawl_event = DatasourceProcessingEvent(
total=message.result.total,
completed=message.result.completed,
)
yield crawl_event.model_dump()
case _:
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
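
    # Example (illustrative sketch): streaming the datasource events to the client.
    #
    #   def generate():
    #       for event in rag_pipeline_service.run_datasource_workflow_node(
    #           pipeline=pipeline,
    #           node_id=node_id,
    #           user_inputs=inputs,
    #           account=current_user,
    #           datasource_type=datasource_type,
    #           is_published=True,
    #       ):
    #           yield json.dumps(event) + "\n\n"
    #
    # The generator yields plain dicts produced by model_dump() on DatasourceProcessingEvent /
    # DatasourceCompletedEvent, so they can usually be serialized directly; datasource_type must be
    # a valid DatasourceProviderType value.
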
def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:
"""
        Run free workflow node
        """
        # run free workflow node
start_at = time.perf_counter()
workflow_node_execution = self._handle_node_run_result(
getter=lambda: WorkflowEntry.run_free_node(
node_id=node_id,
node_data=node_data,
tenant_id=tenant_id,
user_id=user_id,
user_inputs=user_inputs,
),
start_at=start_at,
tenant_id=tenant_id,
node_id=node_id,
)
return workflow_node_execution
def _handle_node_run_result(
self,
getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
start_at: float,
tenant_id: str,
node_id: str,
) -> WorkflowNodeExecution:
"""
Handle node run result
        :param getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]]
:param start_at: float
:param tenant_id: str
:param node_id: str
"""
try:
node_instance, generator = getter()
node_run_result: NodeRunResult | None = None
for event in generator:
if isinstance(event, RunCompletedEvent):
node_run_result = event.run_result
# sign output files
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
break
if not node_run_result:
raise ValueError("Node run failed with no run result")
            # single-step debug mode: when the node is allowed to continue on error, downgrade a
            # failed run to an EXCEPTION result instead of failing outright
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
node_error_args: dict[str, Any] = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error,
"inputs": node_run_result.inputs,
"metadata": {"error_strategy": node_instance.node_data.error_strategy},
}
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
**node_instance.node_data.default_value_dict,
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)
else:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)
run_succeeded = node_run_result.status in (
WorkflowNodeExecutionStatus.SUCCEEDED,
WorkflowNodeExecutionStatus.EXCEPTION,
)
error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e:
node_instance = e.node_instance
run_succeeded = False
node_run_result = None
error = e.error
workflow_node_execution = WorkflowNodeExecution(
id=str(uuid4()),
workflow_id=node_instance.workflow_id,
index=1,
node_id=node_id,
node_type=node_instance.node_type,
title=node_instance.node_data.title,
elapsed_time=time.perf_counter() - start_at,
finished_at=datetime.now(UTC).replace(tzinfo=None),
created_at=datetime.now(UTC).replace(tzinfo=None),
)
if run_succeeded and node_run_result:
# create workflow node execution
inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
process_data = (
WorkflowEntry.handle_special_values(node_run_result.process_data)
if node_run_result.process_data
else None
)
outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
workflow_node_execution.inputs = inputs
workflow_node_execution.process_data = process_data
workflow_node_execution.outputs = outputs
workflow_node_execution.metadata = node_run_result.metadata
if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION
workflow_node_execution.error = node_run_result.error
else:
            # node run failed
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED
workflow_node_execution.error = error
# update document status
variable_pool = node_instance.graph_runtime_state.variable_pool
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
if invoke_from:
if invoke_from.value == InvokeFrom.PUBLISHED.value:
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if document_id:
document = db.session.query(Document).filter(Document.id == document_id.value).first()
if document:
document.indexing_status = "error"
document.error = error
db.session.add(document)
db.session.commit()
return workflow_node_execution
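
    # Summary of the mapping above: a clean run is stored as SUCCEEDED; a failed node that is
    # allowed to continue on error is stored as EXCEPTION (with default values or the error details
    # as outputs, per its ErrorStrategy); everything else is stored as FAILED, and for published
    # pipeline runs the related dataset Document is additionally marked as "error".
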
def update_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
) -> Optional[Workflow]:
"""
Update workflow attributes
:param session: SQLAlchemy database session
:param workflow_id: Workflow ID
:param tenant_id: Tenant ID
:param account_id: Account ID (for permission check)
:param data: Dictionary containing fields to update
:return: Updated workflow or None if not found
"""
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
workflow = session.scalar(stmt)
if not workflow:
return None
allowed_fields = ["marked_name", "marked_comment"]
for field, value in data.items():
if field in allowed_fields:
setattr(workflow, field, value)
workflow.updated_by = account_id
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
return workflow
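
    # Note: update_workflow mutates the workflow inside the caller-provided session and returns it
    # without committing; the caller owns the transaction boundary.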

    def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """
        workflow = self.get_published_workflow(pipeline=pipeline)
        if not workflow:
            raise ValueError("Workflow not initialized")

        # get second step node
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []

        # get datasource provider
        datasource_provider_variables = [
            item
            for item in rag_pipeline_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables

    def get_published_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get first step parameters of rag pipeline
        """
        published_workflow = self.get_published_workflow(pipeline=pipeline)
        if not published_workflow:
            raise ValueError("Workflow not initialized")

        # get the datasource node for the first step
        datasource_node_data = None
        datasource_nodes = published_workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")

        variables = datasource_node_data.get("variables", {})
        if variables:
            variables_map = {item["variable"]: item for item in variables}
        else:
            return []

        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        user_input_variables = []
        for key, value in datasource_parameters.items():
            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
                user_input_variables.append(variables_map.get(key, {}))
        return user_input_variables
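
    # Note on the regex above: a datasource parameter whose value looks like "{{#node_id.field#}}"
    # is a reference to another node's output and is resolved by the workflow itself; only
    # parameters that are not such references are surfaced as first-step user inputs. The same rule
    # is applied in get_draft_first_step_parameters below.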

    def get_draft_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
        """
        Get first step parameters of rag pipeline
        """
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")

        # get the datasource node for the first step
        datasource_node_data = None
        datasource_nodes = draft_workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")

        variables = datasource_node_data.get("variables", {})
        if variables:
            variables_map = {item["variable"]: item for item in variables}
        else:
            return []

        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        user_input_variables = []
        for key, value in datasource_parameters.items():
            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
                user_input_variables.append(variables_map.get(key, {}))
        return user_input_variables

def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
"""
Get second step parameters of rag pipeline
"""
workflow = self.get_draft_workflow(pipeline=pipeline)
if not workflow:
raise ValueError("Workflow not initialized")
# get second step node
rag_pipeline_variables = workflow.rag_pipeline_variables
if not rag_pipeline_variables:
return []
# get datasource provider
datasource_provider_variables = [
item
for item in rag_pipeline_variables
if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
]
return datasource_provider_variables
    def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
        """
        Get rag pipeline workflow run list
        Only returns runs triggered from a rag pipeline run or rag pipeline debugging
        :param pipeline: pipeline model
        :param args: request args
        """
limit = int(args.get("limit", 20))
base_query = db.session.query(WorkflowRun).filter(
WorkflowRun.tenant_id == pipeline.tenant_id,
WorkflowRun.app_id == pipeline.id,
or_(
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
),
)
if args.get("last_id"):
last_workflow_run = base_query.filter(
WorkflowRun.id == args.get("last_id"),
).first()
if not last_workflow_run:
                raise ValueError("Last workflow run does not exist")
workflow_runs = (
base_query.filter(
WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
)
.order_by(WorkflowRun.created_at.desc())
.limit(limit)
.all()
)
else:
workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
has_more = False
if len(workflow_runs) == limit:
current_page_first_workflow_run = workflow_runs[-1]
rest_count = base_query.filter(
WorkflowRun.created_at < current_page_first_workflow_run.created_at,
WorkflowRun.id != current_page_first_workflow_run.id,
).count()
if rest_count > 0:
has_more = True
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
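
    # Example (illustrative): cursor-style pagination over pipeline runs.
    #
    #   page = service.get_rag_pipeline_paginate_workflow_runs(
    #       pipeline=pipeline,
    #       args={"limit": 20, "last_id": previous_page.data[-1].id},
    #   )
    #
    # "last_id" is optional and omitted for the first page; "limit" defaults to 20.
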
def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
"""
Get workflow run detail
        :param pipeline: pipeline model
:param run_id: workflow run id
"""
workflow_run = (
db.session.query(WorkflowRun)
.filter(
WorkflowRun.tenant_id == pipeline.tenant_id,
WorkflowRun.app_id == pipeline.id,
WorkflowRun.id == run_id,
)
.first()
)
return workflow_run
    def get_rag_pipeline_workflow_run_node_executions(
        self,
        pipeline: Pipeline,
        run_id: str,
        user: Account | EndUser,
    ) -> list[WorkflowNodeExecutionModel]:
        """
        Get workflow run node execution list
        """
        workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id)
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())
        if not workflow_run:
            return []

        # Use the repository to get the node execution
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
        )

        # Use the repository to get the node executions with ordering
        order_config = OrderConfig(order_by=["index"], order_direction="desc")
        node_executions = repository.get_db_models_by_workflow_run(
            workflow_run_id=run_id,
            order_config=order_config,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
        )

        return list(node_executions)
@classmethod
def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
"""
Publish customized pipeline template
"""
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
if not pipeline.workflow_id:
raise ValueError("Pipeline workflow not found")
workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError("Workflow not found")
dataset = pipeline.dataset
if not dataset:
raise ValueError("Dataset not found")
max_position = (
db.session.query(func.max(PipelineCustomizedTemplate.position))
.filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
.scalar()
)
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
pipeline_customized_template = PipelineCustomizedTemplate(
name=args.get("name"),
description=args.get("description"),
icon=args.get("icon_info"),
tenant_id=pipeline.tenant_id,
yaml_content=dsl,
position=max_position + 1 if max_position else 1,
chunk_structure=dataset.chunk_structure,
language="en-US",
created_by=current_user.id,
)
db.session.add(pipeline_customized_template)
db.session.commit()
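
    # Example (illustrative sketch): publishing the current pipeline as a reusable template.
    #
    #   RagPipelineService.publish_customized_pipeline_template(
    #       pipeline_id=pipeline.id,
    #       args={"name": "Docs ingestion", "description": "Crawl and index docs", "icon_info": {...}},
    #   )
    #
    # args maps directly onto the template fields above ("name", "description", "icon_info"); the
    # icon_info shape is stored as-is in PipelineCustomizedTemplate.icon.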