Serialize to Proto.Any for python serializer (#4404)

This commit is contained in:
Jack Gerrits 2024-11-27 10:32:01 -05:00 committed by GitHub
parent a4e6d0d977
commit bd77ccbd7b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 23 deletions

View File

@ -2,6 +2,7 @@ import json
from dataclasses import asdict, dataclass, fields
from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable
from google.protobuf import any_pb2
from google.protobuf.message import Message
from pydantic import BaseModel
@ -149,29 +150,35 @@ class PydanticJsonMessageSerializer(MessageSerializer[PydanticT]):
ProtobufT = TypeVar("ProtobufT", bound=Message)
# This class serializes to and from a google.protobuf.Any message that has been serialized to a string
class ProtobufMessageSerializer(MessageSerializer[ProtobufT]):
def __init__(self, cls: type[ProtobufT]) -> None:
self.cls = cls
@property
def data_content_type(self) -> str:
# TODO: This should be PROTOBUF_DATA_CONTENT_TYPE. There are currently
# a couple of hard coded places where the system assumes the
# content is JSON_DATA_CONTENT_TYPE which will need to be fixed
# first.
return JSON_DATA_CONTENT_TYPE
return PROTOBUF_DATA_CONTENT_TYPE
@property
def type_name(self) -> str:
return _type_name(self.cls)
def deserialize(self, payload: bytes) -> ProtobufT:
ret = self.cls()
ret.ParseFromString(payload)
return ret
# Parse payload into a proto any
any_proto = any_pb2.Any()
any_proto.ParseFromString(payload)
destination_message = self.cls()
if not any_proto.Unpack(destination_message): # type: ignore
raise ValueError(f"Failed to unpack payload into {self.cls}")
return destination_message
def serialize(self, message: ProtobufT) -> bytes:
return message.SerializeToString()
any_proto = any_pb2.Any()
any_proto.Pack(message) # type: ignore
return any_proto.SerializeToString()
@dataclass

View File

@ -8,7 +8,11 @@ from autogen_core.base import (
SerializationRegistry,
try_get_known_serializers_for_type,
)
from autogen_core.base._serialization import DataclassJsonMessageSerializer, PydanticJsonMessageSerializer
from autogen_core.base._serialization import (
PROTOBUF_DATA_CONTENT_TYPE,
DataclassJsonMessageSerializer,
PydanticJsonMessageSerializer,
)
from autogen_core.components import Image
from PIL import Image as PILImage
from protos.serialization_test_pb2 import NestingProtoMessage, ProtoMessage
@ -90,12 +94,10 @@ def test_proto() -> None:
message = ProtoMessage(message="hello")
name = serde.type_name(message)
# TODO: should be PROTO_DATA_CONTENT_TYPE
data = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
data = serde.serialize(message, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
assert name == "ProtoMessage"
# TODO: assert data == stuff
deserialized = serde.deserialize(data, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
assert deserialized == message
deserialized = serde.deserialize(data, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
assert deserialized.message == message.message
def test_nested_proto() -> None:
@ -104,14 +106,10 @@ def test_nested_proto() -> None:
message = NestingProtoMessage(message="hello", nested=ProtoMessage(message="world"))
name = serde.type_name(message)
# TODO: should be PROTO_DATA_CONTENT_TYPE
data = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
# TODO: assert data == stuff
deserialized = serde.deserialize(data, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
assert deserialized == message
data = serde.serialize(message, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
deserialized = serde.deserialize(data, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
assert deserialized.message == message.message
assert deserialized.nested.message == message.nested.message
@dataclass