Serialize to Proto.Any for python serializer (#4404)

This commit is contained in:
Jack Gerrits 2024-11-27 10:32:01 -05:00 committed by GitHub
parent a4e6d0d977
commit bd77ccbd7b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 23 deletions

View File

@ -2,6 +2,7 @@ import json
from dataclasses import asdict, dataclass, fields from dataclasses import asdict, dataclass, fields
from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable
from google.protobuf import any_pb2
from google.protobuf.message import Message from google.protobuf.message import Message
from pydantic import BaseModel from pydantic import BaseModel
@ -149,29 +150,35 @@ class PydanticJsonMessageSerializer(MessageSerializer[PydanticT]):
ProtobufT = TypeVar("ProtobufT", bound=Message) ProtobufT = TypeVar("ProtobufT", bound=Message)
# This class serializes to and from a google.protobuf.Any message that has been serialized to a string
class ProtobufMessageSerializer(MessageSerializer[ProtobufT]): class ProtobufMessageSerializer(MessageSerializer[ProtobufT]):
def __init__(self, cls: type[ProtobufT]) -> None: def __init__(self, cls: type[ProtobufT]) -> None:
self.cls = cls self.cls = cls
@property @property
def data_content_type(self) -> str: def data_content_type(self) -> str:
# TODO: This should be PROTOBUF_DATA_CONTENT_TYPE. There are currently return PROTOBUF_DATA_CONTENT_TYPE
# a couple of hard coded places where the system assumes the
# content is JSON_DATA_CONTENT_TYPE which will need to be fixed
# first.
return JSON_DATA_CONTENT_TYPE
@property @property
def type_name(self) -> str: def type_name(self) -> str:
return _type_name(self.cls) return _type_name(self.cls)
def deserialize(self, payload: bytes) -> ProtobufT: def deserialize(self, payload: bytes) -> ProtobufT:
ret = self.cls() # Parse payload into a proto any
ret.ParseFromString(payload) any_proto = any_pb2.Any()
return ret any_proto.ParseFromString(payload)
destination_message = self.cls()
if not any_proto.Unpack(destination_message): # type: ignore
raise ValueError(f"Failed to unpack payload into {self.cls}")
return destination_message
def serialize(self, message: ProtobufT) -> bytes: def serialize(self, message: ProtobufT) -> bytes:
return message.SerializeToString() any_proto = any_pb2.Any()
any_proto.Pack(message) # type: ignore
return any_proto.SerializeToString()
@dataclass @dataclass

View File

@ -8,7 +8,11 @@ from autogen_core.base import (
SerializationRegistry, SerializationRegistry,
try_get_known_serializers_for_type, try_get_known_serializers_for_type,
) )
from autogen_core.base._serialization import DataclassJsonMessageSerializer, PydanticJsonMessageSerializer from autogen_core.base._serialization import (
PROTOBUF_DATA_CONTENT_TYPE,
DataclassJsonMessageSerializer,
PydanticJsonMessageSerializer,
)
from autogen_core.components import Image from autogen_core.components import Image
from PIL import Image as PILImage from PIL import Image as PILImage
from protos.serialization_test_pb2 import NestingProtoMessage, ProtoMessage from protos.serialization_test_pb2 import NestingProtoMessage, ProtoMessage
@ -90,12 +94,10 @@ def test_proto() -> None:
message = ProtoMessage(message="hello") message = ProtoMessage(message="hello")
name = serde.type_name(message) name = serde.type_name(message)
# TODO: should be PROTO_DATA_CONTENT_TYPE data = serde.serialize(message, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
data = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
assert name == "ProtoMessage" assert name == "ProtoMessage"
# TODO: assert data == stuff deserialized = serde.deserialize(data, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
deserialized = serde.deserialize(data, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) assert deserialized.message == message.message
assert deserialized == message
def test_nested_proto() -> None: def test_nested_proto() -> None:
@ -104,14 +106,10 @@ def test_nested_proto() -> None:
message = NestingProtoMessage(message="hello", nested=ProtoMessage(message="world")) message = NestingProtoMessage(message="hello", nested=ProtoMessage(message="world"))
name = serde.type_name(message) name = serde.type_name(message)
data = serde.serialize(message, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
# TODO: should be PROTO_DATA_CONTENT_TYPE deserialized = serde.deserialize(data, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
data = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) assert deserialized.message == message.message
assert deserialized.nested.message == message.nested.message
# TODO: assert data == stuff
deserialized = serde.deserialize(data, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
assert deserialized == message
@dataclass @dataclass