# dify/api/core/variables/segments.py
# (Repository viewer header: 232 lines, 6.1 KiB, Python.)

import json
import sys
from collections.abc import Mapping, Sequence
from typing import Annotated, Any, TypeAlias
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
from core.file import File
from .types import SegmentType
class Segment(BaseModel):
    """Runtime value type used while a workflow is executing.

    Note: this class is abstract; use one of its subclasses instead.
    """

    # frozen=True: Pydantic rejects field assignment after construction.
    model_config = ConfigDict(frozen=True)

    value_type: SegmentType
    value: Any

    @field_validator("value_type")
    @classmethod
    def validate_value_type(cls, value):
        """Ensure `value_type` equals the default declared for the field.

        Returns:
            The validated value, unchanged.

        Raises:
            ValueError: if the provided value differs from the default of
                the `value_type` field on `cls`.
        """
        expected = cls.model_fields["value_type"].default
        if value != expected:
            raise ValueError("Cannot modify 'value_type'")
        return value

    @property
    def text(self) -> str:
        """Return `str()` of the wrapped value."""
        return str(self.value)

    @property
    def log(self) -> str:
        """Return `str()` of the wrapped value (log rendering)."""
        return str(self.value)

    @property
    def markdown(self) -> str:
        """Return `str()` of the wrapped value (markdown rendering)."""
        return str(self.value)

    @property
    def size(self) -> int:
        """Return the size of the value in bytes.

        The figure comes from `sys.getsizeof`, which is shallow: objects
        referenced by the value are not followed.
        """
        return sys.getsizeof(self.value)

    def to_object(self) -> Any:
        """Return the wrapped value itself."""
        return self.value
class NoneSegment(Segment):
    """Segment holding `None`; every textual rendering is the empty string."""

    value_type: SegmentType = SegmentType.NONE
    value: None = None

    @property
    def text(self) -> str:
        """Return an empty string."""
        return ""

    @property
    def log(self) -> str:
        """Return an empty string."""
        return ""

    @property
    def markdown(self) -> str:
        """Return an empty string."""
        return ""
class StringSegment(Segment):
    """Segment whose value is a `str`."""

    value_type: SegmentType = SegmentType.STRING
    value: str
class FloatSegment(Segment):
    """Segment whose value is a `float`."""

    value_type: SegmentType = SegmentType.FLOAT
    value: float

    # NOTE(QuantumGhost): equality between FloatSegment instances holding `NaN`
    # seems to have some problems. The following tests cannot pass.
    #
    # def test_float_segment_and_nan():
    #     nan = float("nan")
    #     assert nan != nan
    #
    #     f1 = FloatSegment(value=float("nan"))
    #     f2 = FloatSegment(value=float("nan"))
    #     assert f1 != f2
    #
    #     f3 = FloatSegment(value=nan)
    #     f4 = FloatSegment(value=nan)
    #     assert f3 != f4
class IntegerSegment(Segment):
    """Segment whose value is an `int`."""

    value_type: SegmentType = SegmentType.INTEGER
    value: int
class ObjectSegment(Segment):
    """Segment whose value is a string-keyed mapping, rendered as JSON.

    Each rendering takes the `value` entry of `model_dump()` rather than
    `self.value` directly -- presumably so nested models are converted to
    plain data before JSON encoding (confirm against Pydantic's behavior).
    """

    value_type: SegmentType = SegmentType.OBJECT
    value: Mapping[str, Any]

    @property
    def text(self) -> str:
        """Return the value as single-line JSON with non-ASCII kept as-is."""
        dumped_value = self.model_dump()["value"]
        return json.dumps(dumped_value, ensure_ascii=False)

    @property
    def log(self) -> str:
        """Return the value as JSON indented by two spaces."""
        dumped_value = self.model_dump()["value"]
        return json.dumps(dumped_value, ensure_ascii=False, indent=2)

    @property
    def markdown(self) -> str:
        """Return the value as JSON indented by two spaces."""
        dumped_value = self.model_dump()["value"]
        return json.dumps(dumped_value, ensure_ascii=False, indent=2)
class ArraySegment(Segment):
    """Common base for segments whose value is iterated item by item."""

    @property
    def markdown(self) -> str:
        """Return `str()` of every item, one item per line."""
        return "\n".join(str(element) for element in self.value)
class FileSegment(Segment):
    """Segment wrapping a single `File`.

    `text` and `log` are empty strings; only `markdown` produces output.
    """

    value_type: SegmentType = SegmentType.FILE
    value: File

    @property
    def text(self) -> str:
        """Return an empty string."""
        return ""

    @property
    def log(self) -> str:
        """Return an empty string."""
        return ""

    @property
    def markdown(self) -> str:
        """Return the `markdown` attribute of the wrapped file."""
        return self.value.markdown
class ArrayAnySegment(ArraySegment):
    """Array segment whose items may be of any type."""

    value_type: SegmentType = SegmentType.ARRAY_ANY
    value: Sequence[Any]
class ArrayStringSegment(ArraySegment):
    """Array segment whose items are strings."""

    value_type: SegmentType = SegmentType.ARRAY_STRING
    value: Sequence[str]

    @property
    def text(self) -> str:
        """Return the items encoded as a JSON array.

        `ensure_ascii=False` keeps non-ASCII characters unescaped in the
        output.
        """
        return json.dumps(self.value, ensure_ascii=False)
class ArrayNumberSegment(ArraySegment):
    """Array segment whose items are floats or ints."""

    value_type: SegmentType = SegmentType.ARRAY_NUMBER
    value: Sequence[float | int]
class ArrayObjectSegment(ArraySegment):
    """Array segment whose items are string-keyed mappings."""

    value_type: SegmentType = SegmentType.ARRAY_OBJECT
    value: Sequence[Mapping[str, Any]]
class ArrayFileSegment(ArraySegment):
    """Array segment whose items are `File` objects.

    `text` and `log` are empty strings; only `markdown` produces output.
    """

    value_type: SegmentType = SegmentType.ARRAY_FILE
    value: Sequence[File]

    @property
    def text(self) -> str:
        """Return an empty string."""
        return ""

    @property
    def log(self) -> str:
        """Return an empty string."""
        return ""

    @property
    def markdown(self) -> str:
        """Return the `markdown` attribute of every file, one per line."""
        return "\n".join(file_item.markdown for file_item in self.value)
def get_segment_discriminator(v: Any) -> SegmentType | None:
    """Work out which `SegmentType` tag applies to `v`.

    Args:
        v: Either a `Segment` instance or a `dict`; any other input yields
            `None`.

    Returns:
        The instance's `value_type` for a `Segment`; for a `dict`, its
        "value_type" entry converted to `SegmentType`. `None` when the entry
        is missing, is `None`, or is not a valid `SegmentType` value.
    """
    if isinstance(v, Segment):
        return v.value_type

    if not isinstance(v, dict):
        # return None if the discriminator value isn't found
        return None

    raw_type = v.get("value_type")
    if raw_type is None:
        return None

    try:
        return SegmentType(raw_type)
    except ValueError:
        # The raw value does not correspond to any SegmentType member.
        return None
# The `SegmentUnion` type is used to enable serialization and deserialization with Pydantic.
# Use `Segment` for type hinting when serialization is not required.
#
# Note:
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
#   - `SegmentGroup`, which is not added to the variable pool.
#   - `Variable` and its subclasses, which are handled by `VariableUnion`.
#
# Each variant below is tagged with its `SegmentType`. The union is discriminated by
# `get_segment_discriminator`, which returns the `SegmentType` for an input, or None
# when it cannot determine one.
SegmentUnion: TypeAlias = Annotated[
(
Annotated[NoneSegment, Tag(SegmentType.NONE)]
| Annotated[StringSegment, Tag(SegmentType.STRING)]
| Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
| Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
| Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
| Annotated[FileSegment, Tag(SegmentType.FILE)]
| Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
| Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
| Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
),
# Callable discriminator: its return value is matched against the Tag(...) values above.
Discriminator(get_segment_discriminator),
]