Add infra for message serialization (#122)

This commit is contained in:
Jack Gerrits 2024-06-25 12:08:38 -04:00 committed by GitHub
parent 9afd1fbe93
commit f5f4c39238
2 changed files with 222 additions and 0 deletions

View File

@ -0,0 +1,116 @@
import json
from dataclasses import asdict
from typing import Any, ClassVar, Dict, Protocol, TypeVar, cast, runtime_checkable
from pydantic import BaseModel
@runtime_checkable
class IsDataclass(Protocol):
    """Structural type matched by anything exposing ``__dataclass_fields__``.

    Looking for this attribute is used here as the most reliable way to
    tell that an object (or class) is a dataclass. Because the protocol is
    ``runtime_checkable`` it can be used directly with ``isinstance``.
    """

    __dataclass_fields__: ClassVar[Dict[str, Any]]
def is_dataclass(cls: type[Any]) -> bool:
    """Return True when *cls* satisfies the ``IsDataclass`` protocol.

    The check is a runtime ``isinstance`` test against the protocol, so it
    succeeds for any object that carries ``__dataclass_fields__``.
    """
    matches_protocol = isinstance(cls, IsDataclass)
    return matches_protocol
def has_nested_dataclass(cls: type[IsDataclass]) -> bool:
    """Return True if any field of *cls* is annotated with a dataclass type.

    Only the direct field annotation (``field.type``) is inspected.
    """
    for field in cls.__dataclass_fields__.values():
        if is_dataclass(field.type):
            return True
    return False
def has_nested_base_model(cls: type[IsDataclass]) -> bool:
    """Return True if any field of *cls* is annotated with a BaseModel subclass.

    Only the direct field annotation (``field.type``) is inspected.
    Annotations that are not classes (for example ``Optional[str]``,
    ``List[int]`` or string annotations) cannot be BaseModel subclasses and
    are treated as "not a base model" instead of raising ``TypeError``.
    """
    for field in cls.__dataclass_fields__.values():
        try:
            if issubclass(field.type, BaseModel):
                return True
        except TypeError:
            # issubclass() rejects non-class arguments such as typing
            # generics or string annotations; those are never base models.
            continue
    return False
# Produced type of a deserializer; covariant because it only appears as a
# return value.
T = TypeVar("T", covariant=True)


class TypeDeserializer(Protocol[T]):
    """Protocol for objects that turn a serialized string into a message."""

    def deserialize(self, message: str) -> T:
        """Build and return a message object from its string form."""
        ...
# Accepted type of a serializer; contravariant because it only appears as a
# parameter.
U = TypeVar("U", contravariant=True)


class TypeSerializer(Protocol[U]):
    """Protocol for objects that turn a message into its string form."""

    def serialize(self, message: U) -> str:
        """Return the string representation of *message*."""
        ...
DataclassT = TypeVar("DataclassT", bound=IsDataclass)


class DataclassTypeDeserializer(TypeDeserializer[DataclassT]):
    """Deserializer that rebuilds a dataclass instance from a JSON object."""

    def __init__(self, cls: type[DataclassT]) -> None:
        """Remember the dataclass type that ``deserialize`` should construct."""
        self.cls = cls

    def deserialize(self, message: str) -> DataclassT:
        """Parse *message* as JSON and pass its members as keyword arguments."""
        field_values = json.loads(message)
        return self.cls(**field_values)
class DataclassTypeSerializer(TypeSerializer[IsDataclass]):
    """Serializer that dumps a flat dataclass instance to a JSON string."""

    def serialize(self, message: IsDataclass) -> str:
        """Return *message* encoded as JSON.

        Raises:
            ValueError: if the dataclass declares a field whose type is a
                dataclass or a pydantic base model.
        """
        message_cls = type(message)
        nested = has_nested_dataclass(message_cls) or has_nested_base_model(message_cls)
        if nested:
            raise ValueError("Dataclass has nested dataclasses or base models, which are not supported")

        as_mapping = asdict(message)
        return json.dumps(as_mapping)
PydanticT = TypeVar("PydanticT", bound=BaseModel)


class PydanticTypeDeserializer(TypeDeserializer[PydanticT]):
    """Deserializer that rebuilds a pydantic model from its JSON form."""

    def __init__(self, cls: type[PydanticT]) -> None:
        """Remember the model type that ``deserialize`` should construct."""
        self.cls = cls

    def deserialize(self, message: str) -> PydanticT:
        """Validate *message* as JSON against the stored model type."""
        model_cls = self.cls
        return model_cls.model_validate_json(message)
class PydanticTypeSerializer(TypeSerializer[BaseModel]):
    """Serializer that dumps a pydantic model to a JSON string."""

    def serialize(self, message: BaseModel) -> str:
        """Return the JSON produced by the model's ``model_dump_json``."""
        encoded = message.model_dump_json()
        return encoded
def _type_name(cls: type[Any] | Any) -> str:
if isinstance(cls, type):
return cls.__name__
else:
return cast(str, cls.__class__.__name__)
V = TypeVar("V")


class Serialization:
    """Registry mapping type names to a serializer/deserializer pair."""

    def __init__(self) -> None:
        """Start with no registered types."""
        self._deserializers: Dict[str, TypeDeserializer[Any]] = {}
        self._serializers: Dict[str, TypeSerializer[Any]] = {}

    def add_type(self, message_type: type[BaseModel] | type[IsDataclass]) -> None:
        """Register the built-in handlers for a pydantic model or dataclass.

        The type is registered under its class name.

        Raises:
            ValueError: if *message_type* is neither a BaseModel subclass
                nor a dataclass.
        """
        if issubclass(message_type, BaseModel):
            self.add_type_custom(
                _type_name(message_type),
                PydanticTypeDeserializer(message_type),
                PydanticTypeSerializer(),
            )
            return

        if isinstance(message_type, IsDataclass):
            self.add_type_custom(
                _type_name(message_type),
                DataclassTypeDeserializer(message_type),
                DataclassTypeSerializer(),
            )
            return

        raise ValueError(f"Unsupported type {message_type}")

    def add_type_custom(self, type_name: str, deserializer: TypeDeserializer[V], serializer: TypeSerializer[V]) -> None:
        """Register (or replace) the handlers stored under *type_name*."""
        self._serializers[type_name] = serializer
        self._deserializers[type_name] = deserializer

    def deserialize(self, message: str, *, type_name: str) -> Any:
        """Decode *message* with the deserializer registered for *type_name*.

        An unregistered *type_name* results in a ``KeyError`` from the lookup.
        """
        handler = self._deserializers[type_name]
        return handler.deserialize(message)

    def type_name(self, message: Any) -> str:
        """Return the name under which ``add_type`` registers *message*'s type."""
        return _type_name(message)

    def serialize(self, message: Any, *, type_name: str) -> str:
        """Encode *message* with the serializer registered for *type_name*.

        An unregistered *type_name* results in a ``KeyError`` from the lookup.
        """
        handler = self._serializers[type_name]
        return handler.serialize(message)

View File

@ -0,0 +1,106 @@
#custom type
from pydantic import BaseModel
from dataclasses import dataclass
import pytest
from agnext.application.message_serialization import Serialization
class PydanticMessage(BaseModel):
    """Flat pydantic message fixture with a single string field."""

    message: str
class NestingPydanticMessage(BaseModel):
    """Pydantic message fixture that embeds another pydantic model."""

    message: str
    nested: PydanticMessage
@dataclass
class DataclassMessage:
    """Flat dataclass message fixture with a single string field."""

    message: str
@dataclass
class NestingDataclassMessage:
    """Dataclass message fixture that embeds another dataclass."""

    message: str
    nested: DataclassMessage
@dataclass
class NestingPydanticDataclassMessage:
    """Dataclass message fixture that embeds a pydantic model."""

    message: str
    nested: PydanticMessage
def test_pydantic() -> None:
    """A flat pydantic model round-trips through the registry."""
    registry = Serialization()
    registry.add_type(PydanticMessage)

    original = PydanticMessage(message="hello")
    name = registry.type_name(original)
    payload = registry.serialize(original, type_name=name)

    assert name == "PydanticMessage"
    assert payload == '{"message":"hello"}'

    restored = registry.deserialize(payload, type_name=name)
    assert restored == original
def test_nested_pydantic() -> None:
    """A pydantic model containing another model round-trips."""
    registry = Serialization()
    registry.add_type(NestingPydanticMessage)

    inner = PydanticMessage(message="world")
    original = NestingPydanticMessage(message="hello", nested=inner)
    name = registry.type_name(original)
    payload = registry.serialize(original, type_name=name)

    assert payload == '{"message":"hello","nested":{"message":"world"}}'

    restored = registry.deserialize(payload, type_name=name)
    assert restored == original
def test_dataclass() -> None:
    """A flat dataclass round-trips through the registry."""
    registry = Serialization()
    registry.add_type(DataclassMessage)

    original = DataclassMessage(message="hello")
    name = registry.type_name(original)
    payload = registry.serialize(original, type_name=name)

    assert payload == '{"message": "hello"}'

    restored = registry.deserialize(payload, type_name=name)
    assert restored == original
def test_nesting_dataclass_dataclass() -> None:
    """Serializing a dataclass that nests a dataclass raises ValueError."""
    registry = Serialization()
    registry.add_type(NestingDataclassMessage)

    inner = DataclassMessage(message="world")
    original = NestingDataclassMessage(message="hello", nested=inner)
    name = registry.type_name(original)

    with pytest.raises(ValueError):
        registry.serialize(original, type_name=name)
def test_nesting_dataclass_pydantic() -> None:
    """Serializing a dataclass that nests a pydantic model raises ValueError."""
    registry = Serialization()
    registry.add_type(NestingPydanticDataclassMessage)

    inner = PydanticMessage(message="world")
    original = NestingPydanticDataclassMessage(message="hello", nested=inner)
    name = registry.type_name(original)

    with pytest.raises(ValueError):
        registry.serialize(original, type_name=name)
def test_invalid_type() -> None:
    """Registering a type that is neither a BaseModel nor a dataclass fails."""
    serde = Serialization()
    # pytest.raises is used instead of try/except so the test fails when
    # add_type accepts the type without raising.
    with pytest.raises(ValueError) as exc_info:
        serde.add_type(str)  # type: ignore
    assert str(exc_info.value) == "Unsupported type <class 'str'>"
def test_custom_type() -> None:
    """Custom handlers registered under an explicit name are used as-is."""

    class CustomStringTypeSerializer:
        """Wrap the string in double quotes."""

        def serialize(self, message: str) -> str:
            return f'"{message}"'

    class CustomStringTypeDeserializer:
        """Strip the first and last character of the payload."""

        def deserialize(self, message: str) -> str:
            return message[1:-1]

    registry = Serialization()
    registry.add_type_custom("custom_str", CustomStringTypeDeserializer(), CustomStringTypeSerializer())

    original = "hello"
    payload = registry.serialize(original, type_name="custom_str")
    assert payload == '"hello"'

    restored = registry.deserialize(payload, type_name="custom_str")
    assert restored == original