diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index 5be87a3bb6..76139c4ebe 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -1,3 +1,4 @@ +import re from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity from models.workflow import Workflow @@ -30,10 +31,34 @@ class WorkflowVariablesConfigManager: """ variables = [] - user_input_form = workflow.rag_pipeline_user_input_form() - # filter variables by start_node_id - for variable in user_input_form: - if variable.get("belong_to_node_id") == start_node_id or variable.get("belong_to_node_id") == "shared": - variables.append(RagPipelineVariableEntity.model_validate(variable)) + # get second step node + rag_pipeline_variables = workflow.rag_pipeline_variables + if not rag_pipeline_variables: + return [] + variables_map = {item["variable"]: item for item in rag_pipeline_variables} + + # get datasource node data + datasource_node_data = None + datasource_nodes = workflow.graph_dict.get("nodes", []) + for datasource_node in datasource_nodes: + if datasource_node.get("id") == start_node_id: + datasource_node_data = datasource_node.get("data", {}) + break + if datasource_node_data: + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) + + for key, value in datasource_parameters.items(): + if value.get("value") and isinstance(value.get("value"), str): + pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}" + match = re.match(pattern, value["value"]) + if match: + full_path = match.group(1) + last_part = full_path.split(".")[-1] + variables_map.pop(last_part) + all_second_step_variables = list(variables_map.values()) + + for item in all_second_step_variables: + if item.get("belong_to_node_id") == start_node_id or item.get("belong_to_node_id") == "shared": + variables.append(RagPipelineVariableEntity.model_validate(item)) return variables