mirror of
https://github.com/langgenius/dify.git
synced 2025-07-26 02:50:34 +00:00

refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025) This PR addresses serialization issues in the VariablePool model by separating the `value_type` tags for `IntegerSegment`/`FloatSegment` and `IntegerVariable`/`FloatVariable`. Previously, both Integer and Float types shared the same `SegmentType.NUMBER` tag, causing conflicts during serialization. Key changes: - Introduce distinct `value_type` tags for Integer and Float segments/variables - Add `VariableUnion` and `SegmentUnion` types for proper type discrimination - Leverage Pydantic's discriminated union feature for seamless serialization/deserialization - Enable accurate serialization of data structures containing these types Closes #22024.
232 lines
6.1 KiB
Python
232 lines
6.1 KiB
Python
import json
|
|
import sys
|
|
from collections.abc import Mapping, Sequence
|
|
from typing import Annotated, Any, TypeAlias
|
|
|
|
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
|
|
|
|
from core.file import File
|
|
|
|
from .types import SegmentType
|
|
|
|
|
|
class Segment(BaseModel):
    """Runtime value container used while a workflow executes.

    ``Segment`` is abstract: use one of its subclasses, each of which fixes
    ``value_type`` through a class-level default and narrows the type of
    ``value``. Instances are immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    value_type: SegmentType
    value: Any

    @field_validator("value_type")
    @classmethod
    def validate_value_type(cls, value):
        """Ensure an explicitly supplied ``value_type`` equals the class default.

        Pydantic skips field validators for defaulted fields unless
        ``validate_default`` is set, so this only runs when a caller (or a
        deserialized dict) passes ``value_type`` explicitly.

        Raises:
            ValueError: when ``value`` differs from the declared default;
                Pydantic surfaces this to the caller as a ``ValidationError``.
        """
        declared_default = cls.model_fields["value_type"].default
        if value != declared_default:
            raise ValueError("Cannot modify 'value_type'")
        return value

    @property
    def text(self) -> str:
        """Plain-text rendering; ``str(value)`` unless a subclass overrides it."""
        return str(self.value)

    @property
    def log(self) -> str:
        """Rendering used for logs; ``str(value)`` unless a subclass overrides it."""
        return str(self.value)

    @property
    def markdown(self) -> str:
        """Markdown rendering; ``str(value)`` unless a subclass overrides it."""
        return str(self.value)

    @property
    def size(self) -> int:
        """Return ``sys.getsizeof(self.value)``, in bytes.

        ``sys.getsizeof`` is shallow: for containers it measures the container
        object itself, not the objects it references.
        """
        return sys.getsizeof(self.value)

    def to_object(self) -> Any:
        """Return the wrapped ``value`` unchanged."""
        return self.value
|
|
|
|
|
|
class NoneSegment(Segment):
    """Segment representing "no value".

    ``value`` is always ``None``, and all three textual renderings are the
    empty string rather than ``"None"``.
    """

    value_type: SegmentType = SegmentType.NONE
    value: None = None

    @property
    def text(self) -> str:
        """Empty string (``None`` renders as nothing)."""
        return ""

    @property
    def log(self) -> str:
        """Empty string (``None`` renders as nothing)."""
        return ""

    @property
    def markdown(self) -> str:
        """Empty string (``None`` renders as nothing)."""
        return ""
|
|
|
|
|
|
class StringSegment(Segment):
    """Segment wrapping a single ``str`` value; renders via ``str(value)``."""

    value_type: SegmentType = SegmentType.STRING
    value: str
|
|
|
|
|
|
class FloatSegment(Segment):
    """Segment wrapping a single ``float`` value.

    Integers use the separate ``IntegerSegment`` (``SegmentType.INTEGER``), so
    the two numeric kinds carry distinct ``value_type`` tags.
    """

    value_type: SegmentType = SegmentType.FLOAT
    value: float

    # NOTE(QuantumGhost): equality of FloatSegments holding NaN does not follow
    # float semantics (where ``nan != nan``). This test does not pass:
    #
    #     def test_float_segment_and_nan():
    #         nan = float("nan")
    #         assert nan != nan
    #
    #         f1 = FloatSegment(value=float("nan"))
    #         f2 = FloatSegment(value=float("nan"))
    #         assert f1 != f2
    #
    #         f3 = FloatSegment(value=nan)
    #         f4 = FloatSegment(value=nan)
    #         assert f3 != f4
    #
    # NOTE(review): a likely cause is that Python container comparison treats
    # identical objects as equal without calling ``__eq__``, so two segments
    # holding the *same* NaN object (f3/f4) can compare equal. Confirm against
    # Pydantic's ``BaseModel.__eq__``.
|
|
|
|
|
|
class IntegerSegment(Segment):
    """Segment wrapping a single ``int`` value.

    Kept separate from ``FloatSegment`` so each numeric kind has its own
    ``value_type`` tag for discriminated (de)serialization.
    """

    value_type: SegmentType = SegmentType.INTEGER
    value: int
|
|
|
|
|
|
class ObjectSegment(Segment):
    """Segment wrapping a string-keyed mapping (a JSON-style object).

    All renderings are JSON with non-ASCII characters kept as-is: ``text`` is
    compact, while ``log`` and ``markdown`` are pretty-printed with a 2-space
    indent.
    """

    value_type: SegmentType = SegmentType.OBJECT
    value: Mapping[str, Any]

    def _value_json(self, indent: int | None = None) -> str:
        # Serialize the Pydantic-dumped form of ``value`` rather than
        # ``self.value`` itself; presumably so nested models/objects inside the
        # mapping are turned into plain data ``json.dumps`` accepts (confirm).
        dumped_value = self.model_dump()["value"]
        return json.dumps(dumped_value, ensure_ascii=False, indent=indent)

    @property
    def text(self) -> str:
        """Compact JSON rendering of ``value``."""
        return self._value_json()

    @property
    def log(self) -> str:
        """Pretty-printed (indent=2) JSON rendering of ``value``."""
        return self._value_json(indent=2)

    @property
    def markdown(self) -> str:
        """Pretty-printed (indent=2) JSON rendering of ``value``."""
        return self._value_json(indent=2)
|
|
|
|
|
|
class ArraySegment(Segment):
    """Abstract base for segments whose ``value`` is a sequence.

    It declares no ``value_type`` default of its own; concrete subclasses set
    one and narrow ``value`` to a typed ``Sequence``.
    """

    @property
    def markdown(self) -> str:
        """Render each item with ``str()``, one item per line."""
        # Feed str.join directly instead of building a throwaway list by hand.
        return "\n".join(map(str, self.value))
|
|
|
|
|
|
class FileSegment(Segment):
    """Segment wrapping a single ``File``.

    Only ``markdown`` produces output (delegating to the file's own
    ``markdown`` attribute); ``text`` and ``log`` are always empty strings.
    """

    value_type: SegmentType = SegmentType.FILE
    value: File

    @property
    def markdown(self) -> str:
        """Markdown provided by the wrapped ``File``."""
        return self.value.markdown

    @property
    def log(self) -> str:
        """Always the empty string for files."""
        return ""

    @property
    def text(self) -> str:
        """Always the empty string for files."""
        return ""
|
|
|
|
|
|
class ArrayAnySegment(ArraySegment):
    """Array segment whose items may be of any type."""

    value_type: SegmentType = SegmentType.ARRAY_ANY
    value: Sequence[Any]
|
|
|
|
|
|
class ArrayStringSegment(ArraySegment):
    """Array segment holding a sequence of strings."""

    value_type: SegmentType = SegmentType.ARRAY_STRING
    value: Sequence[str]

    @property
    def text(self) -> str:
        """Render ``value`` as a JSON array, keeping non-ASCII characters as-is."""
        return json.dumps(self.value, ensure_ascii=False)
|
|
|
|
|
|
class ArrayNumberSegment(ArraySegment):
    """Array segment holding numbers; items may be ``float`` or ``int``."""

    value_type: SegmentType = SegmentType.ARRAY_NUMBER
    value: Sequence[float | int]
|
|
|
|
|
|
class ArrayObjectSegment(ArraySegment):
    """Array segment holding string-keyed mappings (JSON-style objects)."""

    value_type: SegmentType = SegmentType.ARRAY_OBJECT
    value: Sequence[Mapping[str, Any]]
|
|
|
|
|
|
class ArrayFileSegment(ArraySegment):
    """Array segment holding ``File`` objects.

    ``markdown`` joins each file's own markdown, one per line; ``text`` and
    ``log`` are always empty strings.
    """

    value_type: SegmentType = SegmentType.ARRAY_FILE
    value: Sequence[File]

    @property
    def markdown(self) -> str:
        """Each file's ``markdown`` attribute, joined with newlines."""
        # Pass a generator straight to str.join instead of a hand-built list.
        return "\n".join(item.markdown for item in self.value)

    @property
    def log(self) -> str:
        """Always the empty string for file arrays."""
        return ""

    @property
    def text(self) -> str:
        """Always the empty string for file arrays."""
        return ""
|
|
|
|
|
|
def get_segment_discriminator(v: Any) -> SegmentType | None:
    """Return the ``SegmentType`` tag that selects a ``SegmentUnion`` variant.

    Used as the callable passed to Pydantic's ``Discriminator``.

    Args:
        v: The input being validated, either an existing ``Segment`` instance
            or a ``dict`` that may carry a ``"value_type"`` entry.

    Returns:
        ``v.value_type`` for a ``Segment``. For a ``dict``, the ``"value_type"``
        entry converted to ``SegmentType``. ``None`` if the entry is missing,
        is ``None``, or is not a valid ``SegmentType``, and for any other kind
        of input. ``None`` signals "no matching tag" to Pydantic.
    """
    if isinstance(v, Segment):
        return v.value_type
    if not isinstance(v, dict):
        # Neither a Segment nor a dict: no tag can be derived.
        return None

    raw_tag = v.get("value_type")
    if raw_tag is None:
        return None
    try:
        return SegmentType(raw_tag)
    except ValueError:
        # Unrecognized tag value.
        return None
|
|
|
|
|
|
# The `SegmentUnion` type is used to enable serialization and deserialization with Pydantic.
# Use `Segment` for type hinting when serialization is not required.
#
# How it works: `get_segment_discriminator` maps each input (a `Segment`
# instance, or a dict carrying a "value_type" entry) to a `SegmentType`, and
# Pydantic selects the variant whose `Tag` equals that value.
#
# Note:
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - Each variant's `Tag` must equal that class's `value_type` default.
# - The union must include all non-abstract subclasses of `Segment`, except:
#   - `SegmentGroup`, which is not added to the variable pool.
#   - `Variable` and its subclasses, which are handled by `VariableUnion`.
#   (`Segment` and `ArraySegment` themselves declare no `value_type` default
#   and are therefore not listed.)
# - NOTE(review): Pydantic's `Tag` expects a string, so this presumably relies
#   on `SegmentType` being a str-based enum; confirm in `.types`.
SegmentUnion: TypeAlias = Annotated[
    (
        Annotated[NoneSegment, Tag(SegmentType.NONE)]
        | Annotated[StringSegment, Tag(SegmentType.STRING)]
        | Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
        | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
        | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
        | Annotated[FileSegment, Tag(SegmentType.FILE)]
        | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
        | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
        | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
        | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
        | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
    ),
    Discriminator(get_segment_discriminator),
]
|