# dify/api/services/rag_pipeline/rag_pipeline.py
import json
import logging
import re
import threading
import time
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, cast
from uuid import uuid4

from flask_login import current_user
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session

import contexts
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import (
    DatasourceMessage,
    DatasourceProviderType,
    GetOnlineDocumentPageContentRequest,
    OnlineDocumentPagesMessage,
    OnlineDriveBrowseFilesRequest,
    OnlineDriveBrowseFilesResponse,
    WebsiteCrawlMessage,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.rag.entities.event import (
    BaseDatasourceEvent,
    DatasourceCompletedEvent,
    DatasourceErrorEvent,
    DatasourceProcessingEvent,
)
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
)
from core.workflow.enums import SystemVariableKey
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Document, Pipeline, PipelineCustomizedTemplate # type: ignore
from models.enums import WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
    Workflow,
    WorkflowNodeExecutionModel,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowRun,
    WorkflowType,
)
from services.dataset_service import DatasetService
from services.datasource_provider_service import DatasourceProviderService
from services.entities.knowledge_entities.rag_pipeline_entities import (
    KnowledgeConfiguration,
    PipelineTemplateInfoEntity,
)
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory

logger = logging.getLogger(__name__)


class RagPipelineService:
    @classmethod
    def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            if not result.get("pipeline_templates") and language != "en-US":
                template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
                result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
            return result
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_templates(language)
            return result
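
    # Illustrative usage (sketch; built-in templates fall back to en-US when a locale has none):
    #     templates = RagPipelineService.get_pipeline_templates(type="built-in", language="en-US")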

    @classmethod
    def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]:
        """
        Get pipeline template detail.
        :param template_id: template id
        :param type: template source, "built-in" or "customized"
        :return:
        """
        if type == "built-in":
            mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
        else:
            mode = "customized"
            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
            result = retrieval_instance.get_pipeline_template_detail(template_id)
        return result

    @classmethod
    def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
        """
        Update pipeline template.
        :param template_id: template id
        :param template_info: template info
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        # check whether the template name already exists
        template_name = template_info.name
        if template_name:
            template = (
                db.session.query(PipelineCustomizedTemplate)
                .filter(
                    PipelineCustomizedTemplate.name == template_name,
                    PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
                    PipelineCustomizedTemplate.id != template_id,
                )
                .first()
            )
            if template:
                raise ValueError("Template name already exists")
        customized_template.name = template_info.name
        customized_template.description = template_info.description
        customized_template.icon = template_info.icon_info.model_dump()
        customized_template.updated_by = current_user.id
        db.session.commit()
        return customized_template

    @classmethod
    def delete_customized_pipeline_template(cls, template_id: str):
        """
        Delete customized pipeline template.
        """
        customized_template: PipelineCustomizedTemplate | None = (
            db.session.query(PipelineCustomizedTemplate)
            .filter(
                PipelineCustomizedTemplate.id == template_id,
                PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
            )
            .first()
        )
        if not customized_template:
            raise ValueError("Customized pipeline template not found.")
        db.session.delete(customized_template)
        db.session.commit()

    def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get draft workflow
        """
        # fetch draft workflow by rag pipeline
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.version == "draft",
            )
            .first()
        )
        # return draft workflow
        return workflow

    def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
        """
        Get published workflow
        """
        if not pipeline.workflow_id:
            return None
        # fetch published workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == pipeline.tenant_id,
                Workflow.app_id == pipeline.id,
                Workflow.id == pipeline.workflow_id,
            )
            .first()
        )
        return workflow

    def get_all_published_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        page: int,
        limit: int,
        user_id: str | None,
        named_only: bool = False,
    ) -> tuple[Sequence[Workflow], bool]:
        """
        Get published workflow with pagination
        """
        if not pipeline.workflow_id:
            return [], False
        stmt = (
            select(Workflow)
            .where(Workflow.app_id == pipeline.id)
            .order_by(Workflow.version.desc())
            .limit(limit + 1)
            .offset((page - 1) * limit)
        )
        if user_id:
            stmt = stmt.where(Workflow.created_by == user_id)
        if named_only:
            stmt = stmt.where(Workflow.marked_name != "")
        workflows = session.scalars(stmt).all()
        # one extra row is fetched above; its presence only signals that another page exists
        has_more = len(workflows) > limit
        if has_more:
            workflows = workflows[:-1]
        return workflows, has_more

    def sync_draft_workflow(
        self,
        *,
        pipeline: Pipeline,
        graph: dict,
        unique_hash: Optional[str],
        account: Account,
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable],
        rag_pipeline_variables: list,
    ) -> Workflow:
        """
        Sync draft workflow
        :raises WorkflowHashNotEqualError
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(pipeline=pipeline)
        if workflow and workflow.unique_hash != unique_hash:
            raise WorkflowHashNotEqualError()
        # create draft workflow if not found
        if not workflow:
            workflow = Workflow(
                tenant_id=pipeline.tenant_id,
                app_id=pipeline.id,
                features="{}",
                type=WorkflowType.RAG_PIPELINE.value,
                version="draft",
                graph=json.dumps(graph),
                created_by=account.id,
                environment_variables=environment_variables,
                conversation_variables=conversation_variables,
                rag_pipeline_variables=rag_pipeline_variables,
            )
            db.session.add(workflow)
            db.session.flush()
            pipeline.workflow_id = workflow.id
        # update draft workflow if found
        else:
            workflow.graph = json.dumps(graph)
            workflow.updated_by = account.id
            workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
            workflow.environment_variables = environment_variables
            workflow.conversation_variables = conversation_variables
            workflow.rag_pipeline_variables = rag_pipeline_variables
        # commit db session changes
        db.session.commit()
        # trigger workflow events TODO
        # app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow)
        # return draft workflow
        return workflow
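
    # Illustrative usage (sketch only; callers are responsible for supplying the previous hash):
    #     workflow = RagPipelineService().sync_draft_workflow(
    #         pipeline=pipeline,
    #         graph=graph_dict,
    #         unique_hash=previous_workflow.unique_hash if previous_workflow else None,
    #         account=current_user,
    #         environment_variables=[],
    #         conversation_variables=[],
    #         rag_pipeline_variables=[],
    #     )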

    def publish_workflow(
        self,
        *,
        session: Session,
        pipeline: Pipeline,
        account: Account,
    ) -> Workflow:
        draft_workflow_stmt = select(Workflow).where(
            Workflow.tenant_id == pipeline.tenant_id,
            Workflow.app_id == pipeline.id,
            Workflow.version == "draft",
        )
        draft_workflow = session.scalar(draft_workflow_stmt)
        if not draft_workflow:
            raise ValueError("No valid workflow found.")
        # create new workflow
        workflow = Workflow.new(
            tenant_id=pipeline.tenant_id,
            app_id=pipeline.id,
            type=draft_workflow.type,
            version=str(datetime.now(UTC).replace(tzinfo=None)),
            graph=draft_workflow.graph,
            features=draft_workflow.features,
            created_by=account.id,
            environment_variables=draft_workflow.environment_variables,
            conversation_variables=draft_workflow.conversation_variables,
            rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
            marked_name="",
            marked_comment="",
        )
        # commit db session changes
        session.add(workflow)
        graph = workflow.graph_dict
        nodes = graph.get("nodes", [])
        for node in nodes:
            if node.get("data", {}).get("type") == "knowledge-index":
                knowledge_configuration = node.get("data", {})
                knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
                # update dataset
                dataset = pipeline.dataset
                if not dataset:
                    raise ValueError("Dataset not found")
                DatasetService.update_rag_pipeline_dataset_settings(
                    session=session,
                    dataset=dataset,
                    knowledge_configuration=knowledge_configuration,
                    has_published=pipeline.is_published,
                )
        # return new workflow
        return workflow
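
    # Illustrative usage (sketch; assumes an externally managed SQLAlchemy session that the caller commits):
    #     with Session(db.engine) as session:
    #         workflow = rag_pipeline_service.publish_workflow(
    #             session=session, pipeline=pipeline, account=current_user
    #         )
    #         session.commit()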

    def get_default_block_configs(self) -> list[dict]:
        """
        Get default block configs
        """
        # return default block config
        default_block_configs = []
        for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
            node_class = node_class_mapping[LATEST_VERSION]
            default_config = node_class.get_default_config()
            if default_config:
                default_block_configs.append(default_config)
        return default_block_configs

    def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
        """
        Get default config of node.
        :param node_type: node type
        :param filters: filter by node config parameters.
        :return:
        """
        node_type_enum = NodeType(node_type)
        # return default block config
        if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
            return None
        node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
        default_config = node_class.get_default_config(filters=filters)
        if not default_config:
            return None
        return default_config

    def run_draft_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecution:
        """
        Run draft workflow node
        """
        # fetch draft workflow by app_model
        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")
        # run draft workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.workflow_id = draft_workflow.id
        db.session.add(workflow_node_execution)
        db.session.commit()
        return workflow_node_execution
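
    # Illustrative single-step debug call (sketch; node_id must exist in the draft graph):
    #     execution = rag_pipeline_service.run_draft_workflow_node(
    #         pipeline=pipeline, node_id="<node-id>", user_inputs={}, account=current_user
    #     )
    #     print(execution.status, execution.outputs)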

    def run_published_workflow_node(
        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
    ) -> WorkflowNodeExecution:
        """
        Run published workflow node
        """
        # fetch published workflow by app_model
        published_workflow = self.get_published_workflow(pipeline=pipeline)
        if not published_workflow:
            raise ValueError("Workflow not initialized")
        # run published workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.single_step_run(
                workflow=published_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            ),
            start_at=start_at,
            tenant_id=pipeline.tenant_id,
            node_id=node_id,
        )
        workflow_node_execution.workflow_id = published_workflow.id
        db.session.add(workflow_node_execution)
        db.session.commit()
        return workflow_node_execution

    def run_datasource_workflow_node(
        self,
        pipeline: Pipeline,
        node_id: str,
        user_inputs: dict,
        account: Account,
        datasource_type: str,
        is_published: bool,
    ) -> Generator[BaseDatasourceEvent, None, None]:
        """
        Run a datasource node of the draft or published workflow and stream datasource events
        """
        try:
            if is_published:
                # fetch published workflow by app_model
                workflow = self.get_published_workflow(pipeline=pipeline)
            else:
                workflow = self.get_draft_workflow(pipeline=pipeline)
            if not workflow:
                raise ValueError("Workflow not initialized")
            # locate the datasource node in the workflow graph
            datasource_node_data = None
            datasource_nodes = workflow.graph_dict.get("nodes", [])
            for datasource_node in datasource_nodes:
                if datasource_node.get("id") == node_id:
                    datasource_node_data = datasource_node.get("data", {})
                    break
            if not datasource_node_data:
                raise ValueError("Datasource node data not found")
            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
            for key, value in datasource_parameters.items():
                if not user_inputs.get(key):
                    user_inputs[key] = value["value"]

            from core.datasource.datasource_manager import DatasourceManager

            datasource_runtime = DatasourceManager.get_datasource_runtime(
                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
                datasource_name=datasource_node_data.get("datasource_name"),
                tenant_id=pipeline.tenant_id,
                datasource_type=DatasourceProviderType(datasource_type),
            )
            datasource_provider_service = DatasourceProviderService()
            credentials = datasource_provider_service.get_real_datasource_credentials(
                tenant_id=pipeline.tenant_id,
                provider=datasource_node_data.get("provider_name"),
                plugin_id=datasource_node_data.get("plugin_id"),
            )
            if credentials:
                datasource_runtime.runtime.credentials = credentials[0].get("credentials")
            match datasource_type:
                case DatasourceProviderType.ONLINE_DOCUMENT:
                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
                    online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
                        datasource_runtime.get_online_document_pages(
                            user_id=account.id,
                            datasource_parameters=user_inputs,
                            provider_type=datasource_runtime.datasource_provider_type(),
                        )
                    )
                    start_time = time.time()
                    start_event = DatasourceProcessingEvent(
                        total=0,
                        completed=0,
                    )
                    yield start_event.model_dump()
                    try:
                        for message in online_document_result:
                            end_time = time.time()
                            online_document_event = DatasourceCompletedEvent(
                                data=message.result, time_consuming=round(end_time - start_time, 2)
                            )
                            yield online_document_event.model_dump()
                    except Exception as e:
                        logger.exception("Error during online document.")
                        yield DatasourceErrorEvent(error=str(e)).model_dump()
                case DatasourceProviderType.ONLINE_DRIVE:
                    datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
                    online_drive_result: Generator[OnlineDriveBrowseFilesResponse, None, None] = (
                        datasource_runtime.online_drive_browse_files(
                            user_id=account.id,
                            request=OnlineDriveBrowseFilesRequest(
                                bucket=user_inputs.get("bucket"),
                                prefix=user_inputs.get("prefix"),
                                max_keys=user_inputs.get("max_keys", 20),
                                start_after=user_inputs.get("start_after"),
                            ),
                            provider_type=datasource_runtime.datasource_provider_type(),
                        )
                    )
                    start_time = time.time()
                    start_event = DatasourceProcessingEvent(
                        total=0,
                        completed=0,
                    )
                    yield start_event.model_dump()
                    for message in online_drive_result:
                        end_time = time.time()
                        online_drive_event = DatasourceCompletedEvent(
                            data=message.result,
                            time_consuming=round(end_time - start_time, 2),
                            total=None,
                            completed=None,
                        )
                        yield online_drive_event.model_dump()
                case DatasourceProviderType.WEBSITE_CRAWL:
                    datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
                    website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = (
                        datasource_runtime.get_website_crawl(
                            user_id=account.id,
                            datasource_parameters=user_inputs,
                            provider_type=datasource_runtime.datasource_provider_type(),
                        )
                    )
                    start_time = time.time()
                    try:
                        for message in website_crawl_result:
                            end_time = time.time()
                            if message.result.status == "completed":
                                crawl_event = DatasourceCompletedEvent(
                                    data=message.result.web_info_list,
                                    total=message.result.total,
                                    completed=message.result.completed,
                                    time_consuming=round(end_time - start_time, 2),
                                )
                            else:
                                crawl_event = DatasourceProcessingEvent(
                                    total=message.result.total,
                                    completed=message.result.completed,
                                )
                            yield crawl_event.model_dump()
                    except Exception as e:
                        logger.exception("Error during website crawl.")
                        yield DatasourceErrorEvent(error=str(e)).model_dump()
                case _:
                    raise ValueError(
                        f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}"
                    )
        except Exception as e:
            logger.exception("Error in run_datasource_workflow_node.")
            yield DatasourceErrorEvent(error=str(e)).model_dump()
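
    # Consuming the event stream (sketch; event payloads are plain dicts produced by model_dump()):
    #     for event in rag_pipeline_service.run_datasource_workflow_node(
    #         pipeline=pipeline,
    #         node_id="<datasource-node-id>",
    #         user_inputs={},
    #         account=current_user,
    #         datasource_type=DatasourceProviderType.WEBSITE_CRAWL.value,
    #         is_published=False,
    #     ):
    #         print(event)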

    def run_datasource_node_preview(
        self,
        pipeline: Pipeline,
        node_id: str,
        user_inputs: dict,
        account: Account,
        datasource_type: str,
        is_published: bool,
    ) -> Mapping[str, Any]:
        """
        Preview the content returned by a datasource node of the draft or published workflow
        """
        try:
            if is_published:
                # fetch published workflow by app_model
                workflow = self.get_published_workflow(pipeline=pipeline)
            else:
                workflow = self.get_draft_workflow(pipeline=pipeline)
            if not workflow:
                raise ValueError("Workflow not initialized")
            # locate the datasource node in the workflow graph
            datasource_node_data = None
            datasource_nodes = workflow.graph_dict.get("nodes", [])
            for datasource_node in datasource_nodes:
                if datasource_node.get("id") == node_id:
                    datasource_node_data = datasource_node.get("data", {})
                    break
            if not datasource_node_data:
                raise ValueError("Datasource node data not found")
            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
            for key, value in datasource_parameters.items():
                if not user_inputs.get(key):
                    user_inputs[key] = value["value"]

            from core.datasource.datasource_manager import DatasourceManager

            datasource_runtime = DatasourceManager.get_datasource_runtime(
                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
                datasource_name=datasource_node_data.get("datasource_name"),
                tenant_id=pipeline.tenant_id,
                datasource_type=DatasourceProviderType(datasource_type),
            )
            datasource_provider_service = DatasourceProviderService()
            credentials = datasource_provider_service.get_real_datasource_credentials(
                tenant_id=pipeline.tenant_id,
                provider=datasource_node_data.get("provider_name"),
                plugin_id=datasource_node_data.get("plugin_id"),
            )
            if credentials:
                datasource_runtime.runtime.credentials = credentials[0].get("credentials")
            match datasource_type:
                case DatasourceProviderType.ONLINE_DOCUMENT:
                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
                    online_document_result: Generator[DatasourceMessage, None, None] = (
                        datasource_runtime.get_online_document_page_content(
                            user_id=account.id,
                            datasource_parameters=GetOnlineDocumentPageContentRequest(
                                workspace_id=user_inputs.get("workspace_id"),
                                page_id=user_inputs.get("page_id"),
                                type=user_inputs.get("type"),
                            ),
                            provider_type=datasource_type,
                        )
                    )
                    try:
                        variables: dict[str, Any] = {}
                        for message in online_document_result:
                            if message.type == DatasourceMessage.MessageType.VARIABLE:
                                assert isinstance(message.message, DatasourceMessage.VariableMessage)
                                variable_name = message.message.variable_name
                                variable_value = message.message.variable_value
                                if message.message.stream:
                                    if not isinstance(variable_value, str):
                                        raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
                                    if variable_name not in variables:
                                        variables[variable_name] = ""
                                    variables[variable_name] += variable_value
                                else:
                                    variables[variable_name] = variable_value
                        return variables
                    except Exception as e:
                        logger.exception("Error during get online document content.")
                        raise RuntimeError(str(e))
                # TODO Online Drive
                case _:
                    raise ValueError(
                        f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}"
                    )
        except Exception as e:
            logger.exception("Error in run_datasource_node_preview.")
            raise RuntimeError(str(e))
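
    # Illustrative preview call (sketch; the expected user_inputs keys depend on the datasource plugin):
    #     variables = rag_pipeline_service.run_datasource_node_preview(
    #         pipeline=pipeline,
    #         node_id="<datasource-node-id>",
    #         user_inputs={"workspace_id": "<workspace>", "page_id": "<page>", "type": "<type>"},
    #         account=current_user,
    #         datasource_type=DatasourceProviderType.ONLINE_DOCUMENT.value,
    #         is_published=False,
    #     )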

    def run_free_workflow_node(
        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
    ) -> WorkflowNodeExecution:
        """
        Run a free workflow node that is not attached to a persisted workflow
        """
        # run free workflow node
        start_at = time.perf_counter()
        workflow_node_execution = self._handle_node_run_result(
            getter=lambda: WorkflowEntry.run_free_node(
                node_id=node_id,
                node_data=node_data,
                tenant_id=tenant_id,
                user_id=user_id,
                user_inputs=user_inputs,
            ),
            start_at=start_at,
            tenant_id=tenant_id,
            node_id=node_id,
        )
        return workflow_node_execution

    def _handle_node_run_result(
        self,
        getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
        start_at: float,
        tenant_id: str,
        node_id: str,
    ) -> WorkflowNodeExecution:
        """
        Handle node run result
        :param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]]
        :param start_at: float
        :param tenant_id: str
        :param node_id: str
        """
        try:
            node_instance, generator = getter()
            node_run_result: NodeRunResult | None = None
            for event in generator:
                if isinstance(event, RunCompletedEvent):
                    node_run_result = event.run_result
                    # sign output files
                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
                    break
            if not node_run_result:
                raise ValueError("Node run failed with no run result")
            # single step debug mode error handling return
            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
                node_error_args: dict[str, Any] = {
                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
                    "error": node_run_result.error,
                    "inputs": node_run_result.inputs,
                    "metadata": {"error_strategy": node_instance.node_data.error_strategy},
                }
                if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            **node_instance.node_data.default_value_dict,
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
                else:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
            run_succeeded = node_run_result.status in (
                WorkflowNodeExecutionStatus.SUCCEEDED,
                WorkflowNodeExecutionStatus.EXCEPTION,
            )
            error = node_run_result.error if not run_succeeded else None
        except WorkflowNodeRunFailedError as e:
            node_instance = e.node_instance
            run_succeeded = False
            node_run_result = None
            error = e.error
        workflow_node_execution = WorkflowNodeExecution(
            id=str(uuid4()),
            workflow_id=node_instance.workflow_id,
            index=1,
            node_id=node_id,
            node_type=node_instance.node_type,
            title=node_instance.node_data.title,
            elapsed_time=time.perf_counter() - start_at,
            finished_at=datetime.now(UTC).replace(tzinfo=None),
            created_at=datetime.now(UTC).replace(tzinfo=None),
        )
        if run_succeeded and node_run_result:
            # create workflow node execution
            inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
            process_data = (
                WorkflowEntry.handle_special_values(node_run_result.process_data)
                if node_run_result.process_data
                else None
            )
            outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
            workflow_node_execution.inputs = inputs
            workflow_node_execution.process_data = process_data
            workflow_node_execution.outputs = outputs
            workflow_node_execution.metadata = node_run_result.metadata
            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
                workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION
                workflow_node_execution.error = node_run_result.error
        else:
            # create workflow node execution
            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED
            workflow_node_execution.error = error
            # update document status
            variable_pool = node_instance.graph_runtime_state.variable_pool
            invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
            if invoke_from:
                if invoke_from.value == InvokeFrom.PUBLISHED.value:
                    document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
                    if document_id:
                        document = db.session.query(Document).filter(Document.id == document_id.value).first()
                        if document:
                            document.indexing_status = "error"
                            document.error = error
                            db.session.add(document)
                            db.session.commit()
        return workflow_node_execution

    def update_workflow(
        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
    ) -> Optional[Workflow]:
        """
        Update workflow attributes
        :param session: SQLAlchemy database session
        :param workflow_id: Workflow ID
        :param tenant_id: Tenant ID
        :param account_id: Account ID (for permission check)
        :param data: Dictionary containing fields to update
        :return: Updated workflow or None if not found
        """
        stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
        workflow = session.scalar(stmt)
        if not workflow:
            return None
        allowed_fields = ["marked_name", "marked_comment"]
        for field, value in data.items():
            if field in allowed_fields:
                setattr(workflow, field, value)
        workflow.updated_by = account_id
        workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
        return workflow

    def get_first_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
        """
        Get first step parameters of rag pipeline
        """
        workflow = (
            self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
        )
        if not workflow:
            raise ValueError("Workflow not initialized")
        datasource_node_data = None
        datasource_nodes = workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if not datasource_node_data:
            raise ValueError("Datasource node data not found")
        variables = workflow.rag_pipeline_variables
        if variables:
            variables_map = {item["variable"]: item for item in variables}
        else:
            return []
        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        user_input_variables = []
        for key, value in datasource_parameters.items():
            if value.get("value") and isinstance(value.get("value"), str):
                # match variable references of the form {{#node_id.variable#}}
                pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                match = re.match(pattern, value["value"])
                if match:
                    full_path = match.group(1)
                    last_part = full_path.split(".")[-1]
                    user_input_variables.append(variables_map.get(last_part, {}))
        return user_input_variables

    def get_second_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
        """
        Get second step parameters of rag pipeline
        """
        workflow = (
            self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
        )
        if not workflow:
            raise ValueError("Workflow not initialized")
        # get second step node
        rag_pipeline_variables = workflow.rag_pipeline_variables
        if not rag_pipeline_variables:
            return []
        variables_map = {item["variable"]: item for item in rag_pipeline_variables}
        # get datasource node data
        datasource_node_data = None
        datasource_nodes = workflow.graph_dict.get("nodes", [])
        for datasource_node in datasource_nodes:
            if datasource_node.get("id") == node_id:
                datasource_node_data = datasource_node.get("data", {})
                break
        if datasource_node_data:
            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
            for key, value in datasource_parameters.items():
                if value.get("value") and isinstance(value.get("value"), str):
                    # drop variables already consumed by the datasource node ({{#node_id.variable#}} references)
                    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
                    match = re.match(pattern, value["value"])
                    if match:
                        full_path = match.group(1)
                        last_part = full_path.split(".")[-1]
                        variables_map.pop(last_part)
        all_second_step_variables = list(variables_map.values())
        datasource_provider_variables = [
            item
            for item in all_second_step_variables
            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
        ]
        return datasource_provider_variables

    def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
        """
        Get RAG pipeline workflow run list
        Only returns runs triggered from rag_pipeline_run or rag_pipeline_debugging
        :param pipeline: pipeline model
        :param args: request args
        """
        limit = int(args.get("limit", 20))
        base_query = db.session.query(WorkflowRun).filter(
            WorkflowRun.tenant_id == pipeline.tenant_id,
            WorkflowRun.app_id == pipeline.id,
            or_(
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
                WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
            ),
        )
        if args.get("last_id"):
            last_workflow_run = base_query.filter(
                WorkflowRun.id == args.get("last_id"),
            ).first()
            if not last_workflow_run:
                raise ValueError("Last workflow run does not exist")
            workflow_runs = (
                base_query.filter(
                    WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
                )
                .order_by(WorkflowRun.created_at.desc())
                .limit(limit)
                .all()
            )
        else:
            workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
        has_more = False
        if len(workflow_runs) == limit:
            current_page_first_workflow_run = workflow_runs[-1]
            rest_count = base_query.filter(
                WorkflowRun.created_at < current_page_first_workflow_run.created_at,
                WorkflowRun.id != current_page_first_workflow_run.id,
            ).count()
            if rest_count > 0:
                has_more = True
        return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)

    def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
        """
        Get workflow run detail
        :param pipeline: pipeline model
        :param run_id: workflow run id
        """
        workflow_run = (
            db.session.query(WorkflowRun)
            .filter(
                WorkflowRun.tenant_id == pipeline.tenant_id,
                WorkflowRun.app_id == pipeline.id,
                WorkflowRun.id == run_id,
            )
            .first()
        )
        return workflow_run

    def get_rag_pipeline_workflow_run_node_executions(
        self,
        pipeline: Pipeline,
        run_id: str,
        user: Account | EndUser,
    ) -> list[WorkflowNodeExecutionModel]:
        """
        Get workflow run node execution list
        """
        workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id)
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())
        if not workflow_run:
            return []
        # Use the repository to get the node execution
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=db.engine, app_id=pipeline.id, user=user, triggered_from=None
        )
        # Use the repository to get the node executions with ordering
        order_config = OrderConfig(order_by=["index"], order_direction="desc")
        node_executions = repository.get_db_models_by_workflow_run(
            workflow_run_id=run_id,
            order_config=order_config,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
        )
        return list(node_executions)

    @classmethod
    def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
        """
        Publish customized pipeline template
        """
        pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
        if not pipeline:
            raise ValueError("Pipeline not found")
        if not pipeline.workflow_id:
            raise ValueError("Pipeline workflow not found")
        workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
        if not workflow:
            raise ValueError("Workflow not found")
        dataset = pipeline.dataset
        if not dataset:
            raise ValueError("Dataset not found")
        # check whether the template name already exists
        template_name = args.get("name")
        if template_name:
            template = (
                db.session.query(PipelineCustomizedTemplate)
                .filter(
                    PipelineCustomizedTemplate.name == template_name,
                    PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
                )
                .first()
            )
            if template:
                raise ValueError("Template name already exists")
        max_position = (
            db.session.query(func.max(PipelineCustomizedTemplate.position))
            .filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
            .scalar()
        )

        from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

        dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
        pipeline_customized_template = PipelineCustomizedTemplate(
            name=args.get("name"),
            description=args.get("description"),
            icon=args.get("icon_info"),
            tenant_id=pipeline.tenant_id,
            yaml_content=dsl,
            position=max_position + 1 if max_position else 1,
            chunk_structure=dataset.chunk_structure,
            language="en-US",
            created_by=current_user.id,
        )
        db.session.add(pipeline_customized_template)
        db.session.commit()
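
    # Illustrative call (sketch; args mirror the fields read above):
    #     RagPipelineService.publish_customized_pipeline_template(
    #         pipeline_id="<pipeline-id>",
    #         args={"name": "My template", "description": "...", "icon_info": {...}},
    #     )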