This commit is contained in:
jyong 2025-06-25 17:32:26 +08:00
parent 540096a8d8
commit efccbe4039
11 changed files with 74 additions and 20 deletions

View File

@ -1,3 +1,4 @@
from ast import Str
from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Any, Literal, Optional
@ -113,9 +114,9 @@ class VariableEntity(BaseModel):
hide: bool = False
max_length: Optional[int] = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list)
allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list)
allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list)
@field_validator("description", mode="before")
@classmethod
@ -127,6 +128,13 @@ class VariableEntity(BaseModel):
def convert_none_options(cls, v: Any) -> Sequence[str]:
return v or []
class RagPipelineVariableEntity(VariableEntity):
    """
    RAG pipeline variable entity.

    Extends VariableEntity with UI presentation hints and the id of the
    pipeline node the variable belongs to.
    """

    # Optional hover/help text shown alongside the input control.
    tooltips: Optional[str] = None
    # Optional placeholder text for the input control.
    placeholder: Optional[str] = None
    # Id of the workflow node this variable is attached to.
    belong_to_node_id: str
class ExternalDataVariableEntity(BaseModel):
"""

View File

@ -1,4 +1,6 @@
from core.app.app_config.entities import VariableEntity
from typing import Any
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
from models.workflow import Workflow
@ -20,3 +22,19 @@ class WorkflowVariablesConfigManager:
variables.append(VariableEntity.model_validate(variable))
return variables
@classmethod
def convert_rag_pipeline_variable(cls, workflow: Workflow) -> list[RagPipelineVariableEntity]:
    """
    Convert the workflow's RAG pipeline user-input form entries into
    validated RagPipelineVariableEntity objects.

    :param workflow: workflow instance
    :return: one validated entity per user-input form entry
    """
    return [
        RagPipelineVariableEntity.model_validate(entry)
        for entry in workflow.rag_pipeline_user_input_form()
    ]

View File

@ -1,6 +1,6 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
@ -13,7 +13,7 @@ class PipelineConfig(WorkflowUIBasedAppConfig):
"""
Pipeline Config Entity.
"""
rag_pipeline_variables: list[RagPipelineVariableEntity] = []
pass
@ -25,7 +25,7 @@ class PipelineConfigManager(BaseAppConfigManager):
app_id=pipeline.id,
app_mode=AppMode.RAG_PIPELINE,
workflow_id=workflow.id,
variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow),
)
return pipeline_config

View File

@ -160,7 +160,7 @@ class PipelineGenerator(BaseAppGenerator):
document_id=document_id,
inputs=self._prepare_user_inputs(
user_inputs=inputs,
variables=pipeline_config.variables,
variables=pipeline_config.rag_pipeline_variables,
tenant_id=pipeline.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
),

View File

@ -10,7 +10,7 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
RagPipelineGenerateEntity,
)
from core.variables.variables import RAGPipelineVariable
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
@ -45,6 +45,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id
def _get_app_id(self) -> str:
    """Return the id of the app this runner is generating for."""
    generate_entity = self.application_generate_entity
    return generate_entity.app_config.app_id
def run(self) -> None:
"""
Run application
@ -107,15 +109,20 @@ class PipelineRunner(WorkflowBasedAppRunner):
SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value,
}
rag_pipeline_variables = {}
rag_pipeline_variables = []
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
if (
rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id
(rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id or rag_pipeline_variable.belong_to_node_id == "shared")
and rag_pipeline_variable.variable in inputs
):
rag_pipeline_variables[rag_pipeline_variable.variable] = inputs[rag_pipeline_variable.variable]
rag_pipeline_variables.append(
RAGPipelineVariableInput(
variable=rag_pipeline_variable,
value=inputs[rag_pipeline_variable.variable],
)
)
variable_pool = VariablePool(
system_variables=system_inputs,

View File

@ -117,3 +117,8 @@ class RAGPipelineVariable(BaseModel):
)
required: bool = Field(description="optional, default false", default=False)
options: list[str] | None = Field(default_factory=list)
class RAGPipelineVariableInput(BaseModel):
    """A runtime value paired with the RAG pipeline variable it fills in."""

    # The variable definition this value corresponds to.
    variable: RAGPipelineVariable
    # The supplied value; its runtime type depends on the variable definition.
    value: Any

View File

@ -9,7 +9,9 @@ from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import FileSegment, NoneSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.variables.variables import RAGPipelineVariableInput
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, \
SYSTEM_VARIABLE_NODE_ID, RAG_PIPELINE_VARIABLE_NODE_ID
from core.workflow.enums import SystemVariableKey
from factories import variable_factory
@ -44,9 +46,9 @@ class VariablePool(BaseModel):
description="Conversation variables.",
default_factory=list,
)
rag_pipeline_variables: Mapping[str, Any] = Field(
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
description="RAG pipeline variables.",
default_factory=dict,
default_factory=list,
)
def model_post_init(self, context: Any, /) -> None:
@ -59,8 +61,8 @@ class VariablePool(BaseModel):
for var in self.conversation_variables:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
# Add rag pipeline variables to the variable pool
for var, value in self.rag_pipeline_variables.items():
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, var), value)
for var in self.rag_pipeline_variables:
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, var.variable.belong_to_node_id, var.variable.variable), var.value)
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""

View File

@ -436,3 +436,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
inputs=parameters_for_log,
)
)
@classmethod
def version(cls) -> str:
    """Return this node implementation's version identifier (currently "1")."""
    return "1"

View File

@ -159,3 +159,7 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
def _get_preview_output(self, chunk_structure: str, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
    """Format *chunks* for preview via the index processor for *chunk_structure*."""
    return (
        IndexProcessorFactory(chunk_structure)
        .init_index_processor()
        .format_preview(chunks)
    )
@classmethod
def version(cls) -> str:
    """Return this node implementation's version identifier (currently "1")."""
    return "1"

View File

@ -322,6 +322,14 @@ class Workflow(Base):
return variables
def rag_pipeline_user_input_form(self) -> list:
    """Return the user input form for this workflow's RAG pipeline.

    The form is the list stored in ``rag_pipeline_variables``; it is
    returned as-is (not copied).
    """
    # The start node's input form is exactly the stored pipeline variables.
    return self.rag_pipeline_variables
@property
def unique_hash(self) -> str:
"""

View File

@ -344,8 +344,7 @@ class DatasetService:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise ValueError("Dataset not found")
# check if dataset name already exists
# check if dataset name already exists
if (
db.session.query(Dataset)
.filter(
@ -471,7 +470,7 @@ class DatasetService:
filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
# update Retrieval model
filtered_data["retrieval_model"] = data["retrieval_model"]
# update icon info
# update icon info
if data.get("icon_info"):
filtered_data["icon_info"] = data.get("icon_info")