from collections.abc import Mapping, Sequence from typing import Any, cast from uuid import uuid4 from configs import dify_config from core.file import File from core.variables.exc import VariableError from core.variables.segments import ( ArrayAnySegment, ArrayFileSegment, ArrayNumberSegment, ArrayObjectSegment, ArraySegment, ArrayStringSegment, FileSegment, FloatSegment, IntegerSegment, NoneSegment, ObjectSegment, Segment, StringSegment, ) from core.variables.types import SegmentType from core.variables.variables import ( ArrayAnyVariable, ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, FileVariable, FloatVariable, IntegerVariable, NoneVariable, ObjectVariable, SecretVariable, StringVariable, Variable, ) from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, ) class UnsupportedSegmentTypeError(Exception): pass class TypeMismatchError(Exception): pass # Define the constant SEGMENT_TO_VARIABLE_MAP = { StringSegment: StringVariable, IntegerSegment: IntegerVariable, FloatSegment: FloatVariable, ObjectSegment: ObjectVariable, FileSegment: FileVariable, ArrayStringSegment: ArrayStringVariable, ArrayNumberSegment: ArrayNumberVariable, ArrayObjectSegment: ArrayObjectVariable, ArrayFileSegment: ArrayFileVariable, ArrayAnySegment: ArrayAnyVariable, NoneSegment: NoneVariable, } def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: if not mapping.get("name"): raise VariableError("missing name") return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]]) def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: if not mapping.get("name"): raise VariableError("missing name") return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]]) def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: if not mapping.get("variable"): raise VariableError("missing variable") return mapping["variable"] def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable: """ This factory function is used to create the environment variable or the conversation variable, not support the File type. """ if (value_type := mapping.get("value_type")) is None: raise VariableError("missing value type") if (value := mapping.get("value")) is None: raise VariableError("missing value") result: Variable match value_type: case SegmentType.STRING: result = StringVariable.model_validate(mapping) case SegmentType.SECRET: result = SecretVariable.model_validate(mapping) case SegmentType.NUMBER | SegmentType.INTEGER if isinstance(value, int): mapping = dict(mapping) mapping["value_type"] = SegmentType.INTEGER result = IntegerVariable.model_validate(mapping) case SegmentType.NUMBER | SegmentType.FLOAT if isinstance(value, float): mapping = dict(mapping) mapping["value_type"] = SegmentType.FLOAT result = FloatVariable.model_validate(mapping) case SegmentType.NUMBER if not isinstance(value, float | int): raise VariableError(f"invalid number value {value}") case SegmentType.OBJECT if isinstance(value, dict): result = ObjectVariable.model_validate(mapping) case SegmentType.ARRAY_STRING if isinstance(value, list): result = ArrayStringVariable.model_validate(mapping) case SegmentType.ARRAY_NUMBER if isinstance(value, list): result = ArrayNumberVariable.model_validate(mapping) case SegmentType.ARRAY_OBJECT if isinstance(value, list): result = ArrayObjectVariable.model_validate(mapping) case _: raise VariableError(f"not supported value type {value_type}") if result.size > dify_config.MAX_VARIABLE_SIZE: raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") if not result.selector: result = result.model_copy(update={"selector": selector}) return cast(Variable, result) def infer_segment_type_from_value(value: Any, /) -> SegmentType: return build_segment(value).value_type def build_segment(value: Any, /) -> Segment: # NOTE: If you have runtime type information available, consider using the `build_segment_with_type` # below if value is None: return NoneSegment() if isinstance(value, str): return StringSegment(value=value) if isinstance(value, int): return IntegerSegment(value=value) if isinstance(value, float): return FloatSegment(value=value) if isinstance(value, dict): return ObjectSegment(value=value) if isinstance(value, File): return FileSegment(value=value) if isinstance(value, list): items = [build_segment(item) for item in value] types = {item.value_type for item in items} if all(isinstance(item, ArraySegment) for item in items): return ArrayAnySegment(value=value) elif len(types) != 1: if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}): return ArrayNumberSegment(value=value) return ArrayAnySegment(value=value) match types.pop(): case SegmentType.STRING: return ArrayStringSegment(value=value) case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: return ArrayNumberSegment(value=value) case SegmentType.OBJECT: return ArrayObjectSegment(value=value) case SegmentType.FILE: return ArrayFileSegment(value=value) case SegmentType.NONE: return ArrayAnySegment(value=value) case _: # This should be unreachable. raise ValueError(f"not supported value {value}") raise ValueError(f"not supported value {value}") _segment_factory: Mapping[SegmentType, type[Segment]] = { SegmentType.NONE: NoneSegment, SegmentType.STRING: StringSegment, SegmentType.INTEGER: IntegerSegment, SegmentType.FLOAT: FloatSegment, SegmentType.FILE: FileSegment, SegmentType.OBJECT: ObjectSegment, # Array types SegmentType.ARRAY_ANY: ArrayAnySegment, SegmentType.ARRAY_STRING: ArrayStringSegment, SegmentType.ARRAY_NUMBER: ArrayNumberSegment, SegmentType.ARRAY_OBJECT: ArrayObjectSegment, SegmentType.ARRAY_FILE: ArrayFileSegment, } def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: """ Build a segment with explicit type checking. This function creates a segment from a value while enforcing type compatibility with the specified segment_type. It provides stricter type validation compared to the standard build_segment function. Args: segment_type: The expected SegmentType for the resulting segment value: The value to be converted into a segment Returns: Segment: A segment instance of the appropriate type Raises: TypeMismatchError: If the value type doesn't match the expected segment_type Special Cases: - For empty list [] values, if segment_type is array[*], returns the corresponding array type - Type validation is performed before segment creation Examples: >>> build_segment_with_type(SegmentType.STRING, "hello") StringSegment(value="hello") >>> build_segment_with_type(SegmentType.ARRAY_STRING, []) ArrayStringSegment(value=[]) >>> build_segment_with_type(SegmentType.STRING, 123) # Raises TypeMismatchError """ # Handle None values if value is None: if segment_type == SegmentType.NONE: return NoneSegment() else: raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None") # Handle empty list special case for array types if isinstance(value, list) and len(value) == 0: if segment_type == SegmentType.ARRAY_ANY: return ArrayAnySegment(value=value) elif segment_type == SegmentType.ARRAY_STRING: return ArrayStringSegment(value=value) elif segment_type == SegmentType.ARRAY_NUMBER: return ArrayNumberSegment(value=value) elif segment_type == SegmentType.ARRAY_OBJECT: return ArrayObjectSegment(value=value) elif segment_type == SegmentType.ARRAY_FILE: return ArrayFileSegment(value=value) else: raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list") inferred_type = SegmentType.infer_segment_type(value) # Type compatibility checking if inferred_type is None: raise TypeMismatchError( f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}" ) if inferred_type == segment_type: segment_class = _segment_factory[segment_type] return segment_class(value_type=segment_type, value=value) elif segment_type == SegmentType.NUMBER and inferred_type in ( SegmentType.INTEGER, SegmentType.FLOAT, ): segment_class = _segment_factory[inferred_type] return segment_class(value_type=inferred_type, value=value) else: raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") def segment_to_variable( *, segment: Segment, selector: Sequence[str], id: str | None = None, name: str | None = None, description: str = "", ) -> Variable: if isinstance(segment, Variable): return segment name = name or selector[-1] id = id or str(uuid4()) segment_type = type(segment) if segment_type not in SEGMENT_TO_VARIABLE_MAP: raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] return cast( Variable, variable_class( id=id, name=name, description=description, value=segment.value, selector=list(selector), ), )