import dataclasses
from collections.abc import Mapping
from typing import Any, Generic, TypeAlias, TypeVar, overload

from configs import dify_config
from core.file.models import File
from core.variables.segments import (
    ArrayFileSegment,
    ArraySegment,
    BooleanSegment,
    FileSegment,
    FloatSegment,
    IntegerSegment,
    NoneSegment,
    ObjectSegment,
    Segment,
    StringSegment,
)
from core.variables.utils import dumps_with_segments

_MAX_DEPTH = 100


class _QAKeys:
    """dict keys for _QAStructure"""

    QA_CHUNKS = "qa_chunks"
    QUESTION = "question"
    ANSWER = "answer"


class _PCKeys:
    """dict keys for _ParentChildStructure"""

    PARENT_MODE = "parent_mode"
    PARENT_CHILD_CHUNKS = "parent_child_chunks"
    PARENT_CONTENT = "parent_content"
    CHILD_CONTENTS = "child_contents"


_T = TypeVar("_T")


@dataclasses.dataclass(frozen=True)
class _PartResult(Generic[_T]):
    value: _T
    value_size: int
    truncated: bool


class MaxDepthExceededError(Exception):
    pass


class UnknownTypeError(Exception):
    pass


JSONTypes: TypeAlias = int | float | str | list | dict | None | bool


@dataclasses.dataclass(frozen=True)
class TruncationResult:
    result: Segment
    truncated: bool


class VariableTruncator:
    """
    Handles variable truncation with structure-preserving strategies.

    This class implements intelligent truncation that prioritizes maintaining data structure
    integrity while ensuring the final size doesn't exceed the specified limits.

    Uses recursive size calculation to avoid repeated JSON serialization.
    """

    def __init__(
        self,
        string_length_limit: int = 5000,
        array_element_limit: int = 20,
        max_size_bytes: int = 1024_000,  # 1,024,000 bytes (~1 MB)
    ):
        if string_length_limit <= 3:
            raise ValueError("string_length_limit should be greater than 3.")
        self._string_length_limit = string_length_limit

        if array_element_limit <= 0:
            raise ValueError("array_element_limit should be greater than 0.")
        self._array_element_limit = array_element_limit

        if max_size_bytes <= 0:
            raise ValueError("max_size_bytes should be greater than 0.")
        self._max_size_bytes = max_size_bytes

    @classmethod
    def default(cls) -> "VariableTruncator":
        return VariableTruncator(
            max_size_bytes=dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE,
            array_element_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH,
            string_length_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH,
        )
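
    # Usage sketch (illustrative): a truncator can be obtained via `default()`, which
    # reads the limits from `dify_config`, or constructed with explicit limits. The
    # numbers below are arbitrary example values, not recommended settings.
    #
    #     truncator = VariableTruncator.default()
    #     small = VariableTruncator(string_length_limit=50, array_element_limit=5, max_size_bytes=200)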

    def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
        """
        `truncate_variable_mapping` is responsible for truncating variable mappings
        generated during workflow execution, such as `inputs`, `process_data`, or `outputs`
        of a WorkflowNodeExecution record. This ensures the mappings remain within the
        specified size limits while preserving their structure.
        """
        budget = self._max_size_bytes
        is_truncated = False
        truncated_mapping: dict[str, Any] = {}
        length = len(v)
        used_size = 0
        for key, value in v.items():
            used_size += self.calculate_json_size(key)
            if used_size > budget:
                truncated_mapping[key] = "..."
                continue
            value_budget = (budget - used_size) // (length - len(truncated_mapping))
            if isinstance(value, Segment):
                part_result = self._truncate_segment(value, value_budget)
            else:
                part_result = self._truncate_json_primitives(value, value_budget)
            is_truncated = is_truncated or part_result.truncated
            truncated_mapping[key] = part_result.value
            used_size += part_result.value_size
        return truncated_mapping, is_truncated
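
    # Illustrative sketch of the mapping-level behavior, assuming deliberately small
    # custom limits (the values here are made up for the example, not defaults):
    #
    #     truncator = VariableTruncator(string_length_limit=50, array_element_limit=5, max_size_bytes=200)
    #     outputs = {"text": "x" * 500, "items": list(range(100))}
    #     truncated, was_truncated = truncator.truncate_variable_mapping(outputs)
    #     # truncated["text"] == "x" * 50 + "...", truncated["items"] == [0, 1, 2, 3, 4],
    #     # and was_truncated is True; each value receives a share of the remaining byte budget.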

    @staticmethod
    def _segment_need_truncation(segment: Segment) -> bool:
        if isinstance(
            segment,
            (NoneSegment, FloatSegment, IntegerSegment, FileSegment, BooleanSegment, ArrayFileSegment),
        ):
            return False
        return True

    @staticmethod
    def _json_value_needs_truncation(value: Any) -> bool:
        if value is None:
            return False
        if isinstance(value, (bool, int, float)):
            return False
        return True

    def truncate(self, segment: Segment) -> TruncationResult:
        if isinstance(segment, StringSegment):
            result = self._truncate_segment(segment, self._string_length_limit)
        else:
            result = self._truncate_segment(segment, self._max_size_bytes)

        if result.value_size > self._max_size_bytes:
            if isinstance(result.value, str):
                result = self._truncate_string(result.value, self._max_size_bytes)
                return TruncationResult(StringSegment(value=result.value), True)

            # Apply the final fallback - convert to a JSON string and truncate that.
            json_str = dumps_with_segments(result.value, ensure_ascii=False)
            if len(json_str) > self._max_size_bytes:
                json_str = json_str[: self._max_size_bytes] + "..."
            return TruncationResult(result=StringSegment(value=json_str), truncated=True)

        return TruncationResult(
            result=segment.model_copy(update={"value": result.value.value}), truncated=result.truncated
        )
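
    # Illustrative sketch, assuming the default max_size_bytes and a deliberately small
    # string limit (chosen only to make the effect visible):
    #
    #     truncator = VariableTruncator(string_length_limit=10)
    #     res = truncator.truncate(StringSegment(value="a" * 100))
    #     # res.result.value == "aaaaa..." and res.truncated is True; the JSON-string
    #     # fallback above only kicks in when the estimated size still exceeds max_size_bytes.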

    def _truncate_segment(self, segment: Segment, target_size: int) -> _PartResult[Segment]:
        """
        Apply type-specific truncation to a segment.

        Args:
            segment: the Segment to truncate.
            target_size: the size budget (estimated JSON size) for the truncated value.

        Returns:
            A _PartResult holding the (possibly truncated) segment, its estimated size,
            and whether truncation occurred.
        """

        if not VariableTruncator._segment_need_truncation(segment):
            return _PartResult(segment, self.calculate_json_size(segment.value), False)

        result: _PartResult[Any]
        # Apply type-specific truncation with target size
        if isinstance(segment, ArraySegment):
            result = self._truncate_array(segment.value, target_size)
        elif isinstance(segment, StringSegment):
            result = self._truncate_string(segment.value, target_size)
        elif isinstance(segment, ObjectSegment):
            result = self._truncate_object(segment.value, target_size)
        else:
            raise AssertionError("this should be unreachable.")

        return _PartResult(
            value=segment.model_copy(update={"value": result.value}),
            value_size=result.value_size,
            truncated=result.truncated,
        )

    @staticmethod
    def calculate_json_size(value: Any, depth=0) -> int:
        """Recursively estimate the JSON-serialized size without actually serializing."""
        if isinstance(value, Segment):
            return VariableTruncator.calculate_json_size(value.value)
        if depth > _MAX_DEPTH:
            raise MaxDepthExceededError()
        if isinstance(value, str):
            # Ideally, the size of strings should be calculated based on their utf-8 encoded length.
            # However, this adds complexity as we would need to compute encoded sizes consistently
            # throughout the code. Therefore, we approximate the size using the string's length.
            # Rough estimate: number of characters, plus 2 for quotes
            return len(value) + 2
        elif isinstance(value, bool):
            # NOTE: bool must be checked before int, because bool is a subclass of int.
            return 4 if value else 5  # "true" or "false"
        elif isinstance(value, (int, float)):
            return len(str(value))
        elif value is None:
            return 4  # "null"
        elif isinstance(value, list):
            # Size = sum of elements + separators + brackets
            total = 2  # "[]"
            for i, item in enumerate(value):
                if i > 0:
                    total += 1  # ","
                total += VariableTruncator.calculate_json_size(item, depth=depth + 1)
            return total
        elif isinstance(value, dict):
            # Size = sum of keys + values + separators + brackets
            total = 2  # "{}"
            for index, key in enumerate(value.keys()):
                if index > 0:
                    total += 1  # ","
                total += VariableTruncator.calculate_json_size(str(key), depth=depth + 1)  # Key as string
                total += 1  # ":"
                total += VariableTruncator.calculate_json_size(value[key], depth=depth + 1)
            return total
        elif isinstance(value, File):
            return VariableTruncator.calculate_json_size(value.model_dump(), depth=depth + 1)
        else:
            raise UnknownTypeError(f"got unknown type {type(value)}")
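
    # Worked example of the size estimate (it approximates the length of minified JSON):
    #
    #     VariableTruncator.calculate_json_size({"a": "xy"})
    #     # -> 2 (braces) + 3 ('"a"') + 1 (":") + 4 ('"xy"') == 10, i.e. len('{"a":"xy"}')
    #     VariableTruncator.calculate_json_size([1, True, None])
    #     # -> 2 + 1 + 1 + 4 + 1 + 4 == 13, i.e. len('[1,true,null]')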

    def _truncate_string(self, value: str, target_size: int) -> _PartResult[str]:
        if (size := self.calculate_json_size(value)) < target_size:
            return _PartResult(value, size, False)
        if target_size < 5:
            return _PartResult("...", 5, True)
        truncated_size = min(self._string_length_limit, target_size - 5)
        truncated_value = value[:truncated_size] + "..."
        return _PartResult(truncated_value, self.calculate_json_size(truncated_value), True)
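
    # Illustrative arithmetic, assuming the default string_length_limit of 5000: for a
    # 10,000-character string and target_size=5000, the estimated size is 10,002, so the
    # value is cut to min(5000, 5000 - 5) = 4995 characters plus "...", and the reported
    # size becomes 4998 + 2 = 5000.
    #
    #     VariableTruncator()._truncate_string("x" * 10_000, 5000).value_size  # -> 5000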

    def _truncate_array(self, value: list, target_size: int) -> _PartResult[list]:
        """
        Truncate an array in two steps:

        1. First limit it to at most `array_element_limit` items.
        2. If it is still too large, truncate the individual items.
        """

        truncated_value: list[Any] = []
        truncated = False
        used_size = self.calculate_json_size([])

        target_length = self._array_element_limit

        for i, item in enumerate(value):
            # Dirty fix:
            # The output of a `Start` node may contain a list of `File` elements,
            # causing an `AssertionError` while invoking `_truncate_json_primitives`.
            #
            # This check ensures that `list[File]` values are handled separately.
            if isinstance(item, File):
                truncated_value.append(item)
                continue
            if i >= target_length:
                return _PartResult(truncated_value, used_size, True)
            if i > 0:
                used_size += 1  # Account for comma

            if used_size > target_size:
                # The remaining items no longer fit into the budget and are dropped.
                truncated = True
                break

            part_result = self._truncate_json_primitives(item, target_size - used_size)
            truncated_value.append(part_result.value)
            used_size += part_result.value_size
            truncated = truncated or part_result.truncated
        return _PartResult(truncated_value, used_size, truncated)
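
    # Illustrative sketch, assuming the default array_element_limit of 20: a list of 50
    # small integers is cut down to its first 20 elements before any per-item truncation
    # is needed.
    #
    #     VariableTruncator()._truncate_array(list(range(50)), 10_000).value  # -> [0, 1, ..., 19]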

    @classmethod
    def _maybe_qa_structure(cls, m: Mapping[str, Any]) -> bool:
        qa_chunks = m.get(_QAKeys.QA_CHUNKS)
        if qa_chunks is None:
            return False
        if not isinstance(qa_chunks, list):
            return False
        return True

    @classmethod
    def _maybe_parent_child_structure(cls, m: Mapping[str, Any]) -> bool:
        parent_mode = m.get(_PCKeys.PARENT_MODE)
        if parent_mode is None:
            return False
        if not isinstance(parent_mode, str):
            return False
        parent_child_chunks = m.get(_PCKeys.PARENT_CHILD_CHUNKS)
        if parent_child_chunks is None:
            return False
        if not isinstance(parent_child_chunks, list):
            return False

        return True

    def _truncate_object(self, mapping: Mapping[str, Any], target_size: int) -> _PartResult[Mapping[str, Any]]:
        """
        Truncate an object while giving priority to preserving its keys.

        Strategy:
        1. Keep all keys and truncate values to fit within the budget.
        2. If the object is still too large, drop keys starting from the end.
        """
        if not mapping:
            return _PartResult(mapping, self.calculate_json_size(mapping), False)

        truncated_obj = {}
        truncated = False
        used_size = self.calculate_json_size({})

        # Sort keys to ensure deterministic behavior
        sorted_keys = sorted(mapping.keys())

        for i, key in enumerate(sorted_keys):
            if used_size > target_size:
                # No more room for additional key-value pairs
                truncated = True
                break

            pair_size = 0

            if i > 0:
                pair_size += 1  # Account for comma

            # Calculate the budget for this key-value pair. Keys are never truncated,
            # so that the structure of the object is preserved.
            key_size = self.calculate_json_size(key) + 1  # +1 for ":"
            pair_size += key_size
            remaining_pairs = len(sorted_keys) - i
            value_budget = max(0, (target_size - pair_size - used_size) // remaining_pairs)

            if value_budget <= 0:
                truncated = True
                break

            # Truncate the value to fit within the budget
            value = mapping[key]
            if isinstance(value, Segment):
                value_result = self._truncate_segment(value, value_budget)
            else:
                value_result = self._truncate_json_primitives(value, value_budget)

            truncated_obj[key] = value_result.value
            pair_size += value_result.value_size
            used_size += pair_size

            if value_result.truncated:
                truncated = True

        return _PartResult(truncated_obj, used_size, truncated)
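
    # Worked example of the per-key budget split, assuming a two-key object and
    # target_size=100: for the first key "a", used_size starts at 2 ("{}"), the key costs
    # 3 + 1 (":"), and remaining_pairs is 2, so its value gets
    # (100 - 4 - 2) // 2 == 47 bytes of budget; the second value then receives whatever
    # budget remains after the first pair is accounted for.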

    @overload
    def _truncate_json_primitives(self, val: str, target_size: int) -> _PartResult[str]: ...

    @overload
    def _truncate_json_primitives(self, val: list, target_size: int) -> _PartResult[list]: ...

    @overload
    def _truncate_json_primitives(self, val: dict, target_size: int) -> _PartResult[dict]: ...

    @overload
    def _truncate_json_primitives(self, val: bool, target_size: int) -> _PartResult[bool]: ...  # type: ignore

    @overload
    def _truncate_json_primitives(self, val: int, target_size: int) -> _PartResult[int]: ...

    @overload
    def _truncate_json_primitives(self, val: float, target_size: int) -> _PartResult[float]: ...

    @overload
    def _truncate_json_primitives(self, val: None, target_size: int) -> _PartResult[None]: ...

    def _truncate_json_primitives(
        self, val: str | list | dict | bool | int | float | None, target_size: int
    ) -> _PartResult[Any]:
        """Truncate a value within an object to fit within the budget."""
        if isinstance(val, str):
            return self._truncate_string(val, target_size)
        elif isinstance(val, list):
            return self._truncate_array(val, target_size)
        elif isinstance(val, dict):
            return self._truncate_object(val, target_size)
        elif val is None or isinstance(val, (bool, int, float)):
            return _PartResult(val, self.calculate_json_size(val), False)
        else:
            raise AssertionError("this statement should be unreachable.")