From bd77ccbd7b96f1a7f63a41f44e5c6ea0c99d75bb Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Wed, 27 Nov 2024 10:32:01 -0500 Subject: [PATCH] Serialize to Proto.Any for python serializer (#4404) --- .../src/autogen_core/base/_serialization.py | 25 +++++++++++------- .../autogen-core/tests/test_serialization.py | 26 +++++++++---------- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/base/_serialization.py b/python/packages/autogen-core/src/autogen_core/base/_serialization.py index 51fd531fe..74e028641 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_serialization.py +++ b/python/packages/autogen-core/src/autogen_core/base/_serialization.py @@ -2,6 +2,7 @@ import json from dataclasses import asdict, dataclass, fields from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable +from google.protobuf import any_pb2 from google.protobuf.message import Message from pydantic import BaseModel @@ -149,29 +150,35 @@ class PydanticJsonMessageSerializer(MessageSerializer[PydanticT]): ProtobufT = TypeVar("ProtobufT", bound=Message) +# This class serializes to and from a google.protobuf.Any message that has been serialized to a string class ProtobufMessageSerializer(MessageSerializer[ProtobufT]): def __init__(self, cls: type[ProtobufT]) -> None: self.cls = cls @property def data_content_type(self) -> str: - # TODO: This should be PROTOBUF_DATA_CONTENT_TYPE. There are currently - # a couple of hard coded places where the system assumes the - # content is JSON_DATA_CONTENT_TYPE which will need to be fixed - # first. - return JSON_DATA_CONTENT_TYPE + return PROTOBUF_DATA_CONTENT_TYPE @property def type_name(self) -> str: return _type_name(self.cls) def deserialize(self, payload: bytes) -> ProtobufT: - ret = self.cls() - ret.ParseFromString(payload) - return ret + # Parse payload into a proto any + any_proto = any_pb2.Any() + any_proto.ParseFromString(payload) + + destination_message = self.cls() + + if not any_proto.Unpack(destination_message): # type: ignore + raise ValueError(f"Failed to unpack payload into {self.cls}") + + return destination_message def serialize(self, message: ProtobufT) -> bytes: - return message.SerializeToString() + any_proto = any_pb2.Any() + any_proto.Pack(message) # type: ignore + return any_proto.SerializeToString() @dataclass diff --git a/python/packages/autogen-core/tests/test_serialization.py b/python/packages/autogen-core/tests/test_serialization.py index 6b5568411..f6ab2067c 100644 --- a/python/packages/autogen-core/tests/test_serialization.py +++ b/python/packages/autogen-core/tests/test_serialization.py @@ -8,7 +8,11 @@ from autogen_core.base import ( SerializationRegistry, try_get_known_serializers_for_type, ) -from autogen_core.base._serialization import DataclassJsonMessageSerializer, PydanticJsonMessageSerializer +from autogen_core.base._serialization import ( + PROTOBUF_DATA_CONTENT_TYPE, + DataclassJsonMessageSerializer, + PydanticJsonMessageSerializer, +) from autogen_core.components import Image from PIL import Image as PILImage from protos.serialization_test_pb2 import NestingProtoMessage, ProtoMessage @@ -90,12 +94,10 @@ def test_proto() -> None: message = ProtoMessage(message="hello") name = serde.type_name(message) - # TODO: should be PROTO_DATA_CONTENT_TYPE - data = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) + data = serde.serialize(message, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE) assert name == "ProtoMessage" - # TODO: assert data == stuff - deserialized = serde.deserialize(data, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) - assert deserialized == message + deserialized = serde.deserialize(data, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE) + assert deserialized.message == message.message def test_nested_proto() -> None: @@ -104,14 +106,10 @@ def test_nested_proto() -> None: message = NestingProtoMessage(message="hello", nested=ProtoMessage(message="world")) name = serde.type_name(message) - - # TODO: should be PROTO_DATA_CONTENT_TYPE - data = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) - - # TODO: assert data == stuff - - deserialized = serde.deserialize(data, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) - assert deserialized == message + data = serde.serialize(message, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE) + deserialized = serde.deserialize(data, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE) + assert deserialized.message == message.message + assert deserialized.nested.message == message.nested.message @dataclass