This commit is contained in:
jyong 2025-06-25 17:32:26 +08:00
parent 540096a8d8
commit efccbe4039
11 changed files with 74 additions and 20 deletions

View File

@ -1,3 +1,4 @@
from ast import Str
from collections.abc import Sequence from collections.abc import Sequence
from enum import Enum, StrEnum from enum import Enum, StrEnum
from typing import Any, Literal, Optional from typing import Any, Literal, Optional
@ -113,9 +114,9 @@ class VariableEntity(BaseModel):
hide: bool = False hide: bool = False
max_length: Optional[int] = None max_length: Optional[int] = None
options: Sequence[str] = Field(default_factory=list) options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] = Field(default_factory=list) allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list)
allowed_file_extensions: Sequence[str] = Field(default_factory=list) allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list)
@field_validator("description", mode="before") @field_validator("description", mode="before")
@classmethod @classmethod
@ -127,6 +128,13 @@ class VariableEntity(BaseModel):
def convert_none_options(cls, v: Any) -> Sequence[str]: def convert_none_options(cls, v: Any) -> Sequence[str]:
return v or [] return v or []
class RagPipelineVariableEntity(VariableEntity):
"""
Rag Pipeline Variable Entity.
"""
tooltips: Optional[str] = None
placeholder: Optional[str] = None
belong_to_node_id: str
class ExternalDataVariableEntity(BaseModel): class ExternalDataVariableEntity(BaseModel):
""" """

View File

@ -1,4 +1,6 @@
from core.app.app_config.entities import VariableEntity from typing import Any
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
from models.workflow import Workflow from models.workflow import Workflow
@ -20,3 +22,19 @@ class WorkflowVariablesConfigManager:
variables.append(VariableEntity.model_validate(variable)) variables.append(VariableEntity.model_validate(variable))
return variables return variables
@classmethod
def convert_rag_pipeline_variable(cls, workflow: Workflow) -> list[RagPipelineVariableEntity]:
"""
Convert workflow start variables to variables
:param workflow: workflow instance
"""
variables = []
user_input_form = workflow.rag_pipeline_user_input_form()
# variables
for variable in user_input_form:
variables.append(RagPipelineVariableEntity.model_validate(variable))
return variables

View File

@ -1,6 +1,6 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
@ -13,7 +13,7 @@ class PipelineConfig(WorkflowUIBasedAppConfig):
""" """
Pipeline Config Entity. Pipeline Config Entity.
""" """
rag_pipeline_variables: list[RagPipelineVariableEntity] = []
pass pass
@ -25,7 +25,7 @@ class PipelineConfigManager(BaseAppConfigManager):
app_id=pipeline.id, app_id=pipeline.id,
app_mode=AppMode.RAG_PIPELINE, app_mode=AppMode.RAG_PIPELINE,
workflow_id=workflow.id, workflow_id=workflow.id,
variables=WorkflowVariablesConfigManager.convert(workflow=workflow), rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow),
) )
return pipeline_config return pipeline_config

View File

@ -160,7 +160,7 @@ class PipelineGenerator(BaseAppGenerator):
document_id=document_id, document_id=document_id,
inputs=self._prepare_user_inputs( inputs=self._prepare_user_inputs(
user_inputs=inputs, user_inputs=inputs,
variables=pipeline_config.variables, variables=pipeline_config.rag_pipeline_variables,
tenant_id=pipeline.tenant_id, tenant_id=pipeline.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
), ),

View File

@ -10,7 +10,7 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom, InvokeFrom,
RagPipelineGenerateEntity, RagPipelineGenerateEntity,
) )
from core.variables.variables import RAGPipelineVariable from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
@ -45,6 +45,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
self.queue_manager = queue_manager self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id self.workflow_thread_pool_id = workflow_thread_pool_id
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
def run(self) -> None: def run(self) -> None:
""" """
Run application Run application
@ -107,15 +109,20 @@ class PipelineRunner(WorkflowBasedAppRunner):
SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info, SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value, SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value,
} }
rag_pipeline_variables = {} rag_pipeline_variables = []
if workflow.rag_pipeline_variables: if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables: for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v) rag_pipeline_variable = RAGPipelineVariable(**v)
if ( if (
rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id (rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id or rag_pipeline_variable.belong_to_node_id == "shared")
and rag_pipeline_variable.variable in inputs and rag_pipeline_variable.variable in inputs
): ):
rag_pipeline_variables[rag_pipeline_variable.variable] = inputs[rag_pipeline_variable.variable] rag_pipeline_variables.append(
RAGPipelineVariableInput(
variable=rag_pipeline_variable,
value=inputs[rag_pipeline_variable.variable],
)
)
variable_pool = VariablePool( variable_pool = VariablePool(
system_variables=system_inputs, system_variables=system_inputs,

View File

@ -117,3 +117,8 @@ class RAGPipelineVariable(BaseModel):
) )
required: bool = Field(description="optional, default false", default=False) required: bool = Field(description="optional, default false", default=False)
options: list[str] | None = Field(default_factory=list) options: list[str] | None = Field(default_factory=list)
class RAGPipelineVariableInput(BaseModel):
variable: RAGPipelineVariable
value: Any

View File

@ -9,7 +9,9 @@ from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import FileSegment, NoneSegment from core.variables.segments import FileSegment, NoneSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.variables.variables import RAGPipelineVariableInput
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, \
SYSTEM_VARIABLE_NODE_ID, RAG_PIPELINE_VARIABLE_NODE_ID
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from factories import variable_factory from factories import variable_factory
@ -44,9 +46,9 @@ class VariablePool(BaseModel):
description="Conversation variables.", description="Conversation variables.",
default_factory=list, default_factory=list,
) )
rag_pipeline_variables: Mapping[str, Any] = Field( rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
description="RAG pipeline variables.", description="RAG pipeline variables.",
default_factory=dict, default_factory=list,
) )
def model_post_init(self, context: Any, /) -> None: def model_post_init(self, context: Any, /) -> None:
@ -59,8 +61,8 @@ class VariablePool(BaseModel):
for var in self.conversation_variables: for var in self.conversation_variables:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
# Add rag pipeline variables to the variable pool # Add rag pipeline variables to the variable pool
for var, value in self.rag_pipeline_variables.items(): for var in self.rag_pipeline_variables:
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, var), value) self.add((RAG_PIPELINE_VARIABLE_NODE_ID, var.variable.belong_to_node_id, var.variable.variable), var.value)
def add(self, selector: Sequence[str], value: Any, /) -> None: def add(self, selector: Sequence[str], value: Any, /) -> None:
""" """

View File

@ -436,3 +436,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
inputs=parameters_for_log, inputs=parameters_for_log,
) )
) )
@classmethod
def version(cls) -> str:
return "1"

View File

@ -159,3 +159,7 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
def _get_preview_output(self, chunk_structure: str, chunks: Mapping[str, Any]) -> Mapping[str, Any]: def _get_preview_output(self, chunk_structure: str, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
return index_processor.format_preview(chunks) return index_processor.format_preview(chunks)
@classmethod
def version(cls) -> str:
return "1"

View File

@ -322,6 +322,14 @@ class Workflow(Base):
return variables return variables
def rag_pipeline_user_input_form(self) -> list:
# get user_input_form from start node
variables: list[Any] = self.rag_pipeline_variables
return variables
@property @property
def unique_hash(self) -> str: def unique_hash(self) -> str:
""" """

View File

@ -344,8 +344,7 @@ class DatasetService:
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise ValueError("Dataset not found") raise ValueError("Dataset not found")
# check if dataset name is exists
# check if dataset name is exists
if ( if (
db.session.query(Dataset) db.session.query(Dataset)
.filter( .filter(
@ -471,7 +470,7 @@ class DatasetService:
filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
# update Retrieval model # update Retrieval model
filtered_data["retrieval_model"] = data["retrieval_model"] filtered_data["retrieval_model"] = data["retrieval_model"]
# update icon info # update icon info
if data.get("icon_info"): if data.get("icon_info"):
filtered_data["icon_info"] = data.get("icon_info") filtered_data["icon_info"] = data.get("icon_info")