mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-16 20:51:38 +00:00
Serialize to Proto.Any for python serializer (#4404)
This commit is contained in:
parent
a4e6d0d977
commit
bd77ccbd7b
@ -2,6 +2,7 @@ import json
|
|||||||
from dataclasses import asdict, dataclass, fields
|
from dataclasses import asdict, dataclass, fields
|
||||||
from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable
|
from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable
|
||||||
|
|
||||||
|
from google.protobuf import any_pb2
|
||||||
from google.protobuf.message import Message
|
from google.protobuf.message import Message
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -149,29 +150,35 @@ class PydanticJsonMessageSerializer(MessageSerializer[PydanticT]):
|
|||||||
ProtobufT = TypeVar("ProtobufT", bound=Message)
|
ProtobufT = TypeVar("ProtobufT", bound=Message)
|
||||||
|
|
||||||
|
|
||||||
|
# This class serializes to and from a google.protobuf.Any message that has been serialized to a string
|
||||||
class ProtobufMessageSerializer(MessageSerializer[ProtobufT]):
|
class ProtobufMessageSerializer(MessageSerializer[ProtobufT]):
|
||||||
def __init__(self, cls: type[ProtobufT]) -> None:
|
def __init__(self, cls: type[ProtobufT]) -> None:
|
||||||
self.cls = cls
|
self.cls = cls
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data_content_type(self) -> str:
|
def data_content_type(self) -> str:
|
||||||
# TODO: This should be PROTOBUF_DATA_CONTENT_TYPE. There are currently
|
return PROTOBUF_DATA_CONTENT_TYPE
|
||||||
# a couple of hard coded places where the system assumes the
|
|
||||||
# content is JSON_DATA_CONTENT_TYPE which will need to be fixed
|
|
||||||
# first.
|
|
||||||
return JSON_DATA_CONTENT_TYPE
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type_name(self) -> str:
|
def type_name(self) -> str:
|
||||||
return _type_name(self.cls)
|
return _type_name(self.cls)
|
||||||
|
|
||||||
def deserialize(self, payload: bytes) -> ProtobufT:
|
def deserialize(self, payload: bytes) -> ProtobufT:
|
||||||
ret = self.cls()
|
# Parse payload into a proto any
|
||||||
ret.ParseFromString(payload)
|
any_proto = any_pb2.Any()
|
||||||
return ret
|
any_proto.ParseFromString(payload)
|
||||||
|
|
||||||
|
destination_message = self.cls()
|
||||||
|
|
||||||
|
if not any_proto.Unpack(destination_message): # type: ignore
|
||||||
|
raise ValueError(f"Failed to unpack payload into {self.cls}")
|
||||||
|
|
||||||
|
return destination_message
|
||||||
|
|
||||||
def serialize(self, message: ProtobufT) -> bytes:
|
def serialize(self, message: ProtobufT) -> bytes:
|
||||||
return message.SerializeToString()
|
any_proto = any_pb2.Any()
|
||||||
|
any_proto.Pack(message) # type: ignore
|
||||||
|
return any_proto.SerializeToString()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -8,7 +8,11 @@ from autogen_core.base import (
|
|||||||
SerializationRegistry,
|
SerializationRegistry,
|
||||||
try_get_known_serializers_for_type,
|
try_get_known_serializers_for_type,
|
||||||
)
|
)
|
||||||
from autogen_core.base._serialization import DataclassJsonMessageSerializer, PydanticJsonMessageSerializer
|
from autogen_core.base._serialization import (
|
||||||
|
PROTOBUF_DATA_CONTENT_TYPE,
|
||||||
|
DataclassJsonMessageSerializer,
|
||||||
|
PydanticJsonMessageSerializer,
|
||||||
|
)
|
||||||
from autogen_core.components import Image
|
from autogen_core.components import Image
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
from protos.serialization_test_pb2 import NestingProtoMessage, ProtoMessage
|
from protos.serialization_test_pb2 import NestingProtoMessage, ProtoMessage
|
||||||
@ -90,12 +94,10 @@ def test_proto() -> None:
|
|||||||
|
|
||||||
message = ProtoMessage(message="hello")
|
message = ProtoMessage(message="hello")
|
||||||
name = serde.type_name(message)
|
name = serde.type_name(message)
|
||||||
# TODO: should be PROTO_DATA_CONTENT_TYPE
|
data = serde.serialize(message, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
|
||||||
data = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
|
|
||||||
assert name == "ProtoMessage"
|
assert name == "ProtoMessage"
|
||||||
# TODO: assert data == stuff
|
deserialized = serde.deserialize(data, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
|
||||||
deserialized = serde.deserialize(data, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
|
assert deserialized.message == message.message
|
||||||
assert deserialized == message
|
|
||||||
|
|
||||||
|
|
||||||
def test_nested_proto() -> None:
|
def test_nested_proto() -> None:
|
||||||
@ -104,14 +106,10 @@ def test_nested_proto() -> None:
|
|||||||
|
|
||||||
message = NestingProtoMessage(message="hello", nested=ProtoMessage(message="world"))
|
message = NestingProtoMessage(message="hello", nested=ProtoMessage(message="world"))
|
||||||
name = serde.type_name(message)
|
name = serde.type_name(message)
|
||||||
|
data = serde.serialize(message, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
|
||||||
# TODO: should be PROTO_DATA_CONTENT_TYPE
|
deserialized = serde.deserialize(data, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
|
||||||
data = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
|
assert deserialized.message == message.message
|
||||||
|
assert deserialized.nested.message == message.nested.message
|
||||||
# TODO: assert data == stuff
|
|
||||||
|
|
||||||
deserialized = serde.deserialize(data, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
|
|
||||||
assert deserialized == message
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
Loading…
x
Reference in New Issue
Block a user