mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-27 23:18:59 +00:00
Add infra for message serialization (#122)
This commit is contained in:
parent
9afd1fbe93
commit
f5f4c39238
116
python/src/agnext/application/message_serialization.py
Normal file
116
python/src/agnext/application/message_serialization.py
Normal file
@ -0,0 +1,116 @@
|
||||
import json
|
||||
from dataclasses import asdict
|
||||
from typing import Any, ClassVar, Dict, Protocol, TypeVar, cast, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class IsDataclass(Protocol):
|
||||
# as already noted in comments, checking for this attribute is currently
|
||||
# the most reliable way to ascertain that something is a dataclass
|
||||
__dataclass_fields__: ClassVar[Dict[str, Any]]
|
||||
|
||||
|
||||
def is_dataclass(cls: type[Any]) -> bool:
|
||||
return isinstance(cls, IsDataclass)
|
||||
|
||||
|
||||
def has_nested_dataclass(cls: type[IsDataclass]) -> bool:
|
||||
# iterate fields and check if any of them are dataclasses
|
||||
return any(is_dataclass(f.type) for f in cls.__dataclass_fields__.values())
|
||||
|
||||
|
||||
def has_nested_base_model(cls: type[IsDataclass]) -> bool:
|
||||
# iterate fields and check if any of them are basebodels
|
||||
return any(issubclass(f.type, BaseModel) for f in cls.__dataclass_fields__.values())
|
||||
|
||||
|
||||
T = TypeVar("T", covariant=True)
|
||||
|
||||
|
||||
class TypeDeserializer(Protocol[T]):
|
||||
def deserialize(self, message: str) -> T: ...
|
||||
|
||||
|
||||
U = TypeVar("U", contravariant=True)
|
||||
|
||||
|
||||
class TypeSerializer(Protocol[U]):
|
||||
def serialize(self, message: U) -> str: ...
|
||||
|
||||
|
||||
DataclassT = TypeVar("DataclassT", bound=IsDataclass)
|
||||
|
||||
|
||||
class DataclassTypeDeserializer(TypeDeserializer[DataclassT]):
|
||||
def __init__(self, cls: type[DataclassT]) -> None:
|
||||
self.cls = cls
|
||||
|
||||
def deserialize(self, message: str) -> DataclassT:
|
||||
return self.cls(**json.loads(message))
|
||||
|
||||
|
||||
class DataclassTypeSerializer(TypeSerializer[IsDataclass]):
|
||||
def serialize(self, message: IsDataclass) -> str:
|
||||
if has_nested_dataclass(type(message)) or has_nested_base_model(type(message)):
|
||||
raise ValueError("Dataclass has nested dataclasses or base models, which are not supported")
|
||||
|
||||
return json.dumps(asdict(message))
|
||||
|
||||
|
||||
PydanticT = TypeVar("PydanticT", bound=BaseModel)
|
||||
|
||||
|
||||
class PydanticTypeDeserializer(TypeDeserializer[PydanticT]):
|
||||
def __init__(self, cls: type[PydanticT]) -> None:
|
||||
self.cls = cls
|
||||
|
||||
def deserialize(self, message: str) -> PydanticT:
|
||||
return self.cls.model_validate_json(message)
|
||||
|
||||
|
||||
class PydanticTypeSerializer(TypeSerializer[BaseModel]):
|
||||
def serialize(self, message: BaseModel) -> str:
|
||||
return message.model_dump_json()
|
||||
|
||||
|
||||
def _type_name(cls: type[Any] | Any) -> str:
|
||||
if isinstance(cls, type):
|
||||
return cls.__name__
|
||||
else:
|
||||
return cast(str, cls.__class__.__name__)
|
||||
|
||||
|
||||
V = TypeVar("V")
|
||||
|
||||
|
||||
class Serialization:
|
||||
def __init__(self) -> None:
|
||||
self._deserializers: Dict[str, TypeDeserializer[Any]] = {}
|
||||
self._serializers: Dict[str, TypeSerializer[Any]] = {}
|
||||
|
||||
def add_type(self, message_type: type[BaseModel] | type[IsDataclass]) -> None:
|
||||
if issubclass(message_type, BaseModel):
|
||||
self.add_type_custom(
|
||||
_type_name(message_type), PydanticTypeDeserializer(message_type), PydanticTypeSerializer()
|
||||
)
|
||||
elif isinstance(message_type, IsDataclass):
|
||||
self.add_type_custom(
|
||||
_type_name(message_type), DataclassTypeDeserializer(message_type), DataclassTypeSerializer()
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {message_type}")
|
||||
|
||||
def add_type_custom(self, type_name: str, deserializer: TypeDeserializer[V], serializer: TypeSerializer[V]) -> None:
|
||||
self._deserializers[type_name] = deserializer
|
||||
self._serializers[type_name] = serializer
|
||||
|
||||
def deserialize(self, message: str, *, type_name: str) -> Any:
|
||||
return self._deserializers[type_name].deserialize(message)
|
||||
|
||||
def type_name(self, message: Any) -> str:
|
||||
return _type_name(message)
|
||||
|
||||
def serialize(self, message: Any, *, type_name: str) -> str:
|
||||
return self._serializers[type_name].serialize(message)
|
||||
106
python/tests/test_serialization.py
Normal file
106
python/tests/test_serialization.py
Normal file
@ -0,0 +1,106 @@
|
||||
#custom type
|
||||
|
||||
from pydantic import BaseModel
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from agnext.application.message_serialization import Serialization
|
||||
|
||||
class PydanticMessage(BaseModel):
|
||||
message: str
|
||||
|
||||
class NestingPydanticMessage(BaseModel):
|
||||
message: str
|
||||
nested: PydanticMessage
|
||||
|
||||
@dataclass
|
||||
class DataclassMessage:
|
||||
message: str
|
||||
|
||||
@dataclass
|
||||
class NestingDataclassMessage:
|
||||
message: str
|
||||
nested: DataclassMessage
|
||||
|
||||
@dataclass
|
||||
class NestingPydanticDataclassMessage:
|
||||
message: str
|
||||
nested: PydanticMessage
|
||||
|
||||
def test_pydantic() -> None:
|
||||
serde = Serialization()
|
||||
serde.add_type(PydanticMessage)
|
||||
|
||||
message = PydanticMessage(message="hello")
|
||||
name = serde.type_name(message)
|
||||
json = serde.serialize(message, type_name=name)
|
||||
assert name == "PydanticMessage"
|
||||
assert json == '{"message":"hello"}'
|
||||
deserialized = serde.deserialize(json, type_name=name)
|
||||
assert deserialized == message
|
||||
|
||||
def test_nested_pydantic() -> None:
|
||||
serde = Serialization()
|
||||
serde.add_type(NestingPydanticMessage)
|
||||
|
||||
message = NestingPydanticMessage(message="hello", nested=PydanticMessage(message="world"))
|
||||
name = serde.type_name(message)
|
||||
json = serde.serialize(message, type_name=name)
|
||||
assert json == '{"message":"hello","nested":{"message":"world"}}'
|
||||
deserialized = serde.deserialize(json, type_name=name)
|
||||
assert deserialized == message
|
||||
|
||||
def test_dataclass() -> None:
|
||||
serde = Serialization()
|
||||
serde.add_type(DataclassMessage)
|
||||
|
||||
message = DataclassMessage(message="hello")
|
||||
name = serde.type_name(message)
|
||||
json = serde.serialize(message, type_name=name)
|
||||
assert json == '{"message": "hello"}'
|
||||
deserialized = serde.deserialize(json, type_name=name)
|
||||
assert deserialized == message
|
||||
|
||||
def test_nesting_dataclass_dataclass() -> None:
|
||||
serde = Serialization()
|
||||
serde.add_type(NestingDataclassMessage)
|
||||
|
||||
message = NestingDataclassMessage(message="hello", nested=DataclassMessage(message="world"))
|
||||
name = serde.type_name(message)
|
||||
with pytest.raises(ValueError):
|
||||
_json = serde.serialize(message, type_name=name)
|
||||
|
||||
def test_nesting_dataclass_pydantic() -> None:
|
||||
serde = Serialization()
|
||||
serde.add_type(NestingPydanticDataclassMessage)
|
||||
|
||||
message = NestingPydanticDataclassMessage(message="hello", nested=PydanticMessage(message="world"))
|
||||
name = serde.type_name(message)
|
||||
with pytest.raises(ValueError):
|
||||
_json = serde.serialize(message, type_name=name)
|
||||
|
||||
def test_invalid_type() -> None:
|
||||
serde = Serialization()
|
||||
try:
|
||||
serde.add_type(str) # type: ignore
|
||||
except ValueError as e:
|
||||
assert str(e) == "Unsupported type <class 'str'>"
|
||||
|
||||
def test_custom_type() -> None:
|
||||
serde = Serialization()
|
||||
|
||||
class CustomStringTypeDeserializer:
|
||||
def deserialize(self, message: str) -> str:
|
||||
return message[1:-1]
|
||||
|
||||
class CustomStringTypeSerializer:
|
||||
def serialize(self, message: str) -> str:
|
||||
return f'"{message}"'
|
||||
|
||||
serde.add_type_custom("custom_str", CustomStringTypeDeserializer(), CustomStringTypeSerializer())
|
||||
message = "hello"
|
||||
json = serde.serialize(message, type_name="custom_str")
|
||||
assert json == '"hello"'
|
||||
deserialized = serde.deserialize(json, type_name="custom_str")
|
||||
assert deserialized == message
|
||||
Loading…
x
Reference in New Issue
Block a user