Support datacontenttype and lay groundwork for unknown payloads (#444)

* Support datacontenttype and lay groundwork for unknown payloads

* Update dotnet based on proto changes
This commit is contained in:
Jack Gerrits 2024-09-05 16:36:59 -04:00 committed by GitHub
parent f941fe15a6
commit 8504ea0bf2
20 changed files with 358 additions and 224 deletions

View File

@ -1,4 +1,5 @@
using Agents;
using Google.Protobuf;
using Greeter.AgentWorker;
using Microsoft.AI.Agents.Worker.Client;
using AgentId = Microsoft.AI.Agents.Worker.Client.AgentId;
@ -30,7 +31,11 @@ internal sealed class GreetingAgent(IAgentContext context, ILogger<GreetingAgent
protected override Task<RpcResponse> HandleRequest(RpcRequest request)
{
logger.LogInformation("[{Id}] Received request: '{Request}'.", AgentId, request);
return Task.FromResult(new RpcResponse() { Result = "Okay!" });
return Task.FromResult(new RpcResponse() { Payload = new Payload {
DataContentType = "text/plain",
Data = ByteString.CopyFromUtf8("Hello, agents!"),
DataType = "text"
}});
}
}

View File

@ -3,6 +3,8 @@ using System.Threading.Channels;
using Microsoft.Extensions.Logging;
using System.Text.Json;
using System.Diagnostics;
using System.Text;
using Google.Protobuf;
namespace Microsoft.AI.Agents.Worker.Client;
@ -80,12 +82,12 @@ public abstract class AgentBase
{
case Message.MessageOneofCase.Event:
{
var activity = ExtractActivity(msg.Event.DataType, msg.Event.Metadata);
var activity = ExtractActivity(msg.Event.Payload.DataType, msg.Event.Metadata);
await InvokeWithActivityAsync(
static ((AgentBase Agent, Event Item) state) => state.Agent.HandleEvent(state.Item),
(this, msg.Event),
activity,
msg.Event.DataType).ConfigureAwait(false);
msg.Event.Payload.DataType).ConfigureAwait(false);
}
break;
case Message.MessageOneofCase.Request:
@ -143,7 +145,12 @@ public abstract class AgentBase
Target = target,
RequestId = requestId,
Method = method,
Data = JsonSerializer.Serialize(parameters)
Payload = new Payload{
DataType = "application/json",
Data = ByteString.CopyFrom(JsonSerializer.Serialize(parameters), Encoding.UTF8),
DataContentType = "application/json"
}
};
var activity = s_source.StartActivity($"Call '{method}'", ActivityKind.Client, Activity.Current?.Context ?? default);
@ -175,8 +182,8 @@ public abstract class AgentBase
protected async ValueTask PublishEvent(Event item)
{
var activity = s_source.StartActivity($"PublishEvent '{item.DataType}'", ActivityKind.Client, Activity.Current?.Context ?? default);
activity?.SetTag("peer.service", $"{item.DataType}/{item.Namespace}");
var activity = s_source.StartActivity($"PublishEvent '{item.Payload.DataType}'", ActivityKind.Client, Activity.Current?.Context ?? default);
activity?.SetTag("peer.service", $"{item.Payload.DataType}/{item.TopicSource}");
var completion = new TaskCompletionSource<RpcResponse>(TaskCreationOptions.RunContinuationsAsynchronously);
Context.DistributedContextPropagator.Inject(activity, item.Metadata, static (carrier, key, value) => ((IDictionary<string, string>)carrier!)[key] = value);
@ -187,7 +194,7 @@ public abstract class AgentBase
},
(this, item, completion),
activity,
item.DataType).ConfigureAwait(false);
item.Payload.DataType).ConfigureAwait(false);
}
protected virtual Task<RpcResponse> HandleRequest(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" });
@ -224,7 +231,7 @@ public abstract class AgentBase
activity.SetTag("exception.type", e.GetType().FullName);
activity.SetTag("exception.message", e.Message);
// Note that "exception.stacktrace" is the full exception detail, not just the StackTrace property.
// Note that "exception.stacktrace" is the full exception detail, not just the StackTrace property.
// See https://opentelemetry.io/docs/specs/semconv/attributes-registry/exception/
// and https://github.com/open-telemetry/opentelemetry-specification/pull/697#discussion_r453662519
activity.SetTag("exception.stacktrace", e.ToString());

View File

@ -2,14 +2,14 @@ using RpcAgentId = Agents.AgentId;
namespace Microsoft.AI.Agents.Worker.Client;
public sealed record class AgentId(string Name, string Namespace)
public sealed record class AgentId(string Type, string Key)
{
public static implicit operator RpcAgentId(AgentId agentId) => new()
{
Name = agentId.Name,
Namespace = agentId.Namespace
Type = agentId.Type,
Key = agentId.Key
};
public static implicit operator AgentId(RpcAgentId agentId) => new(agentId.Name, agentId.Namespace);
public override string ToString() => $"{Name}/{Namespace}";
public static implicit operator AgentId(RpcAgentId agentId) => new(agentId.Type, agentId.Key);
public override string ToString() => $"{Type}/{Key}";
}

View File

@ -143,7 +143,7 @@ public sealed class AgentWorkerRuntime : IHostedService, IDisposable
var item = message.Event;
foreach (var (typeName, _) in _agentTypes)
{
var agent = GetOrActivateAgent(new AgentId(typeName, item.Namespace));
var agent = GetOrActivateAgent(new AgentId(typeName, item.TopicSource));
agent.ReceiveMessage(message);
}
@ -200,17 +200,17 @@ public sealed class AgentWorkerRuntime : IHostedService, IDisposable
private AgentBase GetOrActivateAgent(AgentId agentId)
{
if (!_agents.TryGetValue((agentId.Name, agentId.Namespace), out var agent))
if (!_agents.TryGetValue((agentId.Type, agentId.Key), out var agent))
{
if (_agentTypes.TryGetValue(agentId.Name, out var agentType))
if (_agentTypes.TryGetValue(agentId.Type, out var agentType))
{
var context = new AgentContext(agentId, this, _serviceProvider.GetRequiredService<ILogger<AgentBase>>(), _distributedContextPropagator);
agent = (AgentBase)ActivatorUtilities.CreateInstance(_serviceProvider, agentType, context);
_agents.TryAdd((agentId.Name, agentId.Namespace), agent);
_agents.TryAdd((agentId.Type, agentId.Key), agent);
}
else
{
throw new InvalidOperationException($"Agent type '{agentId.Name}' is unknown.");
throw new InvalidOperationException($"Agent type '{agentId.Type}' is unknown.");
}
}

View File

@ -118,13 +118,13 @@ public sealed class AgentWorkerRegistryGrain : Grain, IAgentWorkerRegistryGrain
public ValueTask<(IWorkerGateway? Gateway, bool NewPlacment)> GetOrPlaceAgent(AgentId agentId)
{
bool isNewPlacement;
if (!_agentDirectory.TryGetValue((agentId.Name, agentId.Namespace), out var worker) || !_workerStates.ContainsKey(worker))
if (!_agentDirectory.TryGetValue((agentId.Type, agentId.Key), out var worker) || !_workerStates.ContainsKey(worker))
{
worker = GetCompatibleWorkerCore(agentId.Name);
worker = GetCompatibleWorkerCore(agentId.Type);
if (worker is not null)
{
// New activation.
_agentDirectory[(agentId.Name, agentId.Namespace)] = worker;
_agentDirectory[(agentId.Type, agentId.Key)] = worker;
isNewPlacement = true;
}
else

View File

@ -49,11 +49,11 @@ internal sealed class WorkerGateway : BackgroundService, IWorkerGateway
public async ValueTask<RpcResponse> InvokeRequest(RpcRequest request)
{
(string Type, string Key) agentId = (request.Target.Name, request.Target.Namespace);
(string Type, string Key) agentId = (request.Target.Type, request.Target.Key);
if (!_agentDirectory.TryGetValue(agentId, out var connection) || connection.Completion.IsCompleted)
{
// Activate the agent on a compatible worker process.
if (_supportedAgentTypes.TryGetValue(request.Target.Name, out var workers))
if (_supportedAgentTypes.TryGetValue(request.Target.Type, out var workers))
{
connection = workers[Random.Shared.Next(workers.Count)];
_agentDirectory[agentId] = connection;
@ -163,7 +163,7 @@ internal sealed class WorkerGateway : BackgroundService, IWorkerGateway
}
/*
if (string.Equals("runtime", request.Target.Name, StringComparison.Ordinal))
if (string.Equals("runtime", request.Target.Type, StringComparison.Ordinal))
{
if (string.Equals("subscribe", request.Method))
{
@ -184,7 +184,7 @@ internal sealed class WorkerGateway : BackgroundService, IWorkerGateway
var (gateway, isPlacement) = await _gatewayRegistry.GetOrPlaceAgent(request.Target);
if (gateway is null)
{
return new RpcResponse { Error = "Agent not found and no compatible gateways were found." };
return new RpcResponse { Error = "Agent not found and no compatible gateways were found." };
}
if (isPlacement)

View File

@ -1,6 +1,9 @@
using System.Text.Json;
using Event = Microsoft.AI.Agents.Abstractions.Event;
using RpcEvent = Agents.Event;
using Payload = Agents.Payload;
using Google.Protobuf;
using System.Text;
namespace Microsoft.AI.Agents.Worker;
@ -10,13 +13,19 @@ public static class RpcEventExtensions
{
var result = new RpcEvent
{
Namespace = input.Namespace,
DataType = input.Type,
TopicSource = input.Namespace,
// TODO: Is this the right way to handle topics?
TopicType = input.Subject
};
if (input.Data is not null)
{
result.Data = JsonSerializer.Serialize(input.Data);
result.Payload = new Payload
{
Data = ByteString.CopyFrom(JsonSerializer.Serialize(input.Data), Encoding.UTF8),
DataContentType = "application/json",
DataType = input.Type
};
}
return result;
@ -26,15 +35,20 @@ public static class RpcEventExtensions
{
var result = new Event
{
Type = input.DataType,
Subject = input.Namespace,
Namespace = input.Namespace,
Type = input.Payload.DataType,
Subject = input.TopicType,
Namespace = input.TopicSource,
Data = []
};
if (input.Data is not null)
if (input.Payload is not null)
{
result.Data = JsonSerializer.Deserialize<Dictionary<string, string>>(input.Data)!;
if (input.Payload.DataContentType != "application/json")
{
throw new InvalidOperationException("Only application/json content type is supported");
}
result.Data = JsonSerializer.Deserialize<Dictionary<string, string>>(input.Payload.Data.ToString(Encoding.UTF8))!;
}
return result;

View File

@ -3,8 +3,14 @@ syntax = "proto3";
package agents;
message AgentId {
string name = 1;
string namespace = 2;
string type = 1;
string key = 2;
}
message Payload {
string data_type = 1;
string data_content_type = 2;
bytes data = 3;
}
message RpcRequest {
@ -12,26 +18,22 @@ message RpcRequest {
AgentId source = 2;
AgentId target = 3;
string method = 4;
string data_type = 5;
string data = 6;
map<string, string> metadata = 7;
Payload payload = 5;
map<string, string> metadata = 6;
}
message RpcResponse {
string request_id = 1;
string result_type = 2;
string result = 3;
string error = 4;
map<string, string> metadata = 5;
Payload payload = 2;
string error = 3;
map<string, string> metadata = 4;
}
message Event {
string namespace = 1;
string topic_type = 2;
string topic_source = 3;
string data_type = 4;
string data = 5;
map<string, string> metadata = 6;
string topic_type = 1;
string topic_source = 2;
Payload payload = 3;
map<string, string> metadata = 4;
}
message RegisterAgentType {

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Any, NoReturn
from autogen_core.application import WorkerAgentRuntime
from autogen_core.base import MESSAGE_TYPE_REGISTRY, MessageContext
from autogen_core.base import MESSAGE_TYPE_REGISTRY, MessageContext, try_get_known_codecs_for_type
from autogen_core.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler
@ -67,11 +67,11 @@ class GreeterAgent(RoutedAgent):
async def main() -> None:
runtime = WorkerAgentRuntime()
MESSAGE_TYPE_REGISTRY.add_type(Greeting)
MESSAGE_TYPE_REGISTRY.add_type(AskToGreet)
MESSAGE_TYPE_REGISTRY.add_type(Feedback)
MESSAGE_TYPE_REGISTRY.add_type(ReturnedGreeting)
MESSAGE_TYPE_REGISTRY.add_type(ReturnedFeedback)
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(Greeting))
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(AskToGreet))
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(Feedback))
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(ReturnedGreeting))
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(ReturnedFeedback))
await runtime.start(host_connection_string="localhost:50051")
await runtime.register("receiver", ReceiveAgent, lambda: [DefaultSubscription()])

View File

@ -4,7 +4,13 @@ from dataclasses import dataclass
from typing import Any, NoReturn
from autogen_core.application import WorkerAgentRuntime
from autogen_core.base import MESSAGE_TYPE_REGISTRY, AgentId, AgentInstantiationContext, MessageContext
from autogen_core.base import (
MESSAGE_TYPE_REGISTRY,
AgentId,
AgentInstantiationContext,
MessageContext,
try_get_known_codecs_for_type,
)
from autogen_core.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler
@ -55,9 +61,9 @@ class GreeterAgent(RoutedAgent):
async def main() -> None:
runtime = WorkerAgentRuntime()
MESSAGE_TYPE_REGISTRY.add_type(Greeting)
MESSAGE_TYPE_REGISTRY.add_type(AskToGreet)
MESSAGE_TYPE_REGISTRY.add_type(Feedback)
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(Greeting))
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(AskToGreet))
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(Feedback))
await runtime.start(host_connection_string="localhost:50051")
await runtime.register("receiver", lambda: ReceiveAgent(), lambda: [DefaultSubscription()])

View File

@ -125,9 +125,9 @@ class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None:
# Deliver the message to a client given the target agent type.
async with self._agent_type_to_client_id_lock:
target_client_id = self._agent_type_to_client_id.get(request.target.name)
target_client_id = self._agent_type_to_client_id.get(request.target.type)
if target_client_id is None:
logger.error(f"Agent {request.target.name} not found, failed to deliver message.")
logger.error(f"Agent {request.target.type} not found, failed to deliver message.")
return
target_send_queue = self._send_queues.get(target_client_id)
if target_send_queue is None:

View File

@ -28,6 +28,8 @@ import grpc
from grpc.aio import StreamStreamCall
from typing_extensions import Self
from autogen_core.base import JSON_DATA_CONTENT_TYPE
from ..base import (
MESSAGE_TYPE_REGISTRY,
Agent,
@ -247,14 +249,19 @@ class WorkerAgentRuntime(AgentRuntime):
self._pending_requests[request_id_str] = future
sender = cast(AgentId, sender)
data_type = MESSAGE_TYPE_REGISTRY.type_name(message)
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=data_type)
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(
message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE
)
runtime_message = agent_worker_pb2.Message(
request=agent_worker_pb2.RpcRequest(
request_id=request_id_str,
target=agent_worker_pb2.AgentId(name=recipient.type, namespace=recipient.key),
source=agent_worker_pb2.AgentId(name=sender.type, namespace=sender.key),
data_type=data_type,
data=serialized_message,
target=agent_worker_pb2.AgentId(type=recipient.type, key=recipient.key),
source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key),
payload=agent_worker_pb2.Payload(
data_type=data_type,
data=serialized_message,
data_content_type=JSON_DATA_CONTENT_TYPE,
),
)
)
# TODO: Find a way to handle timeouts/errors
@ -277,10 +284,18 @@ class WorkerAgentRuntime(AgentRuntime):
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
message_type = MESSAGE_TYPE_REGISTRY.type_name(message)
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=message_type)
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(
message, type_name=message_type, data_content_type=JSON_DATA_CONTENT_TYPE
)
runtime_message = agent_worker_pb2.Message(
event=agent_worker_pb2.Event(
topic_type=topic_id.type, topic_source=topic_id.source, data_type=message_type, data=serialized_message
topic_type=topic_id.type,
topic_source=topic_id.source,
payload=agent_worker_pb2.Payload(
data_type=message_type,
data=serialized_message,
data_content_type=JSON_DATA_CONTENT_TYPE,
),
)
)
task = asyncio.create_task(self._host_connection.send(runtime_message))
@ -305,13 +320,17 @@ class WorkerAgentRuntime(AgentRuntime):
async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
assert self._host_connection is not None
target = AgentId(request.target.name, request.target.namespace)
source = AgentId(request.source.name, request.source.namespace)
target = AgentId(request.target.type, request.target.key)
source = AgentId(request.source.type, request.source.key)
logging.info(f"Processing request from {source} to {target}")
# Deserialize the message.
message = MESSAGE_TYPE_REGISTRY.deserialize(request.data, type_name=request.data_type)
message = MESSAGE_TYPE_REGISTRY.deserialize(
request.payload.data,
type_name=request.payload.data_type,
data_content_type=request.payload.data_content_type,
)
# Get the target agent and prepare the message context.
target_agent = await self._get_agent(target)
@ -339,14 +358,19 @@ class WorkerAgentRuntime(AgentRuntime):
# Serialize the result.
result_type = MESSAGE_TYPE_REGISTRY.type_name(result)
serialized_result = MESSAGE_TYPE_REGISTRY.serialize(result, type_name=result_type)
serialized_result = MESSAGE_TYPE_REGISTRY.serialize(
result, type_name=result_type, data_content_type=JSON_DATA_CONTENT_TYPE
)
# Create the response message.
response_message = agent_worker_pb2.Message(
response=agent_worker_pb2.RpcResponse(
request_id=request.request_id,
result_type=result_type,
result=serialized_result,
payload=agent_worker_pb2.Payload(
data_type=result_type,
data=serialized_result,
data_content_type=JSON_DATA_CONTENT_TYPE,
),
)
)
@ -355,7 +379,11 @@ class WorkerAgentRuntime(AgentRuntime):
async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> None:
# Deserialize the result.
result = MESSAGE_TYPE_REGISTRY.deserialize(response.result, type_name=response.result_type)
result = MESSAGE_TYPE_REGISTRY.deserialize(
response.payload.data,
type_name=response.payload.data_type,
data_content_type=response.payload.data_content_type,
)
# Get the future and set the result.
future = self._pending_requests.pop(response.request_id)
if len(response.error) > 0:
@ -364,7 +392,9 @@ class WorkerAgentRuntime(AgentRuntime):
future.set_result(result)
async def _process_event(self, event: agent_worker_pb2.Event) -> None:
message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.data_type)
message = MESSAGE_TYPE_REGISTRY.deserialize(
event.payload.data, type_name=event.payload.data_type, data_content_type=event.payload.data_content_type
)
topic_id = TopicId(event.topic_type, event.topic_source)
# Get the recipients for the topic.
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)

View File

@ -24,7 +24,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"*\n\x07\x41gentId\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t\"\xf8\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12\x11\n\tdata_type\x18\x05 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\t\x12\x32\n\x08metadata\x18\x07 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xbb\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x13\n\x0bresult_type\x18\x02 \x01(\t\x12\x0e\n\x06result\x18\x03 \x01(\t\x12\r\n\x05\x65rror\x18\x04 \x01(\t\x12\x33\n\x08metadata\x18\x05 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc5\x01\n\x05\x45vent\x12\x11\n\tnamespace\x18\x01 \x01(\t\x12\x12\n\ntopic_type\x18\x02 \x01(\t\x12\x14\n\x0ctopic_source\x18\x03 \x01(\t\x12\x11\n\tdata_type\x18\x04 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\t\x12-\n\x08metadata\x18\x06 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"T\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x42\x0e\n\x0csubscription\"=\n\x0f\x41\x64\x64Subscription\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\xf0\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x12\x32\n\x0f\x61\x64\x64Subscription\x18\x05 \x01(\x0b\x32\x17.agents.AddSubscriptionH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xf9\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xb3\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12 \n\x07payload\x18\x03 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x04 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"T\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x42\x0e\n\x0csubscription\"=\n\x0f\x41\x64\x64Subscription\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\xf0\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x12\x32\n\x0f\x61\x64\x64Subscription\x18\x05 \x01(\x0b\x32\x17.agents.AddSubscriptionH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@ -38,29 +38,31 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals['_EVENT_METADATAENTRY']._loaded_options = None
_globals['_EVENT_METADATAENTRY']._serialized_options = b'8\001'
_globals['_AGENTID']._serialized_start=30
_globals['_AGENTID']._serialized_end=72
_globals['_RPCREQUEST']._serialized_start=75
_globals['_RPCREQUEST']._serialized_end=323
_globals['_RPCREQUEST_METADATAENTRY']._serialized_start=276
_globals['_RPCREQUEST_METADATAENTRY']._serialized_end=323
_globals['_RPCRESPONSE']._serialized_start=326
_globals['_RPCRESPONSE']._serialized_end=513
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=276
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=323
_globals['_EVENT']._serialized_start=516
_globals['_EVENT']._serialized_end=713
_globals['_EVENT_METADATAENTRY']._serialized_start=276
_globals['_EVENT_METADATAENTRY']._serialized_end=323
_globals['_REGISTERAGENTTYPE']._serialized_start=715
_globals['_REGISTERAGENTTYPE']._serialized_end=748
_globals['_TYPESUBSCRIPTION']._serialized_start=750
_globals['_TYPESUBSCRIPTION']._serialized_end=808
_globals['_SUBSCRIPTION']._serialized_start=810
_globals['_SUBSCRIPTION']._serialized_end=894
_globals['_ADDSUBSCRIPTION']._serialized_start=896
_globals['_ADDSUBSCRIPTION']._serialized_end=957
_globals['_MESSAGE']._serialized_start=960
_globals['_MESSAGE']._serialized_end=1200
_globals['_AGENTRPC']._serialized_start=1202
_globals['_AGENTRPC']._serialized_end=1265
_globals['_AGENTID']._serialized_end=66
_globals['_PAYLOAD']._serialized_start=68
_globals['_PAYLOAD']._serialized_end=137
_globals['_RPCREQUEST']._serialized_start=140
_globals['_RPCREQUEST']._serialized_end=389
_globals['_RPCREQUEST_METADATAENTRY']._serialized_start=342
_globals['_RPCREQUEST_METADATAENTRY']._serialized_end=389
_globals['_RPCRESPONSE']._serialized_start=392
_globals['_RPCRESPONSE']._serialized_end=576
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=342
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=389
_globals['_EVENT']._serialized_start=579
_globals['_EVENT']._serialized_end=758
_globals['_EVENT_METADATAENTRY']._serialized_start=342
_globals['_EVENT_METADATAENTRY']._serialized_end=389
_globals['_REGISTERAGENTTYPE']._serialized_start=760
_globals['_REGISTERAGENTTYPE']._serialized_end=793
_globals['_TYPESUBSCRIPTION']._serialized_start=795
_globals['_TYPESUBSCRIPTION']._serialized_end=853
_globals['_SUBSCRIPTION']._serialized_start=855
_globals['_SUBSCRIPTION']._serialized_end=939
_globals['_ADDSUBSCRIPTION']._serialized_start=941
_globals['_ADDSUBSCRIPTION']._serialized_end=1002
_globals['_MESSAGE']._serialized_start=1005
_globals['_MESSAGE']._serialized_end=1245
_globals['_AGENTRPC']._serialized_start=1247
_globals['_AGENTRPC']._serialized_end=1310
# @@protoc_insertion_point(module_scope)

View File

@ -16,20 +16,41 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
class AgentId(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
NAME_FIELD_NUMBER: builtins.int
NAMESPACE_FIELD_NUMBER: builtins.int
name: builtins.str
namespace: builtins.str
TYPE_FIELD_NUMBER: builtins.int
KEY_FIELD_NUMBER: builtins.int
type: builtins.str
key: builtins.str
def __init__(
self,
*,
name: builtins.str = ...,
namespace: builtins.str = ...,
type: builtins.str = ...,
key: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["name", b"name", "namespace", b"namespace"]) -> None: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "type", b"type"]) -> None: ...
global___AgentId = AgentId
@typing.final
class Payload(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
DATA_TYPE_FIELD_NUMBER: builtins.int
DATA_CONTENT_TYPE_FIELD_NUMBER: builtins.int
DATA_FIELD_NUMBER: builtins.int
data_type: builtins.str
data_content_type: builtins.str
data: builtins.bytes
def __init__(
self,
*,
data_type: builtins.str = ...,
data_content_type: builtins.str = ...,
data: builtins.bytes = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "data_content_type", b"data_content_type", "data_type", b"data_type"]) -> None: ...
global___Payload = Payload
@typing.final
class RpcRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@ -54,18 +75,17 @@ class RpcRequest(google.protobuf.message.Message):
SOURCE_FIELD_NUMBER: builtins.int
TARGET_FIELD_NUMBER: builtins.int
METHOD_FIELD_NUMBER: builtins.int
DATA_TYPE_FIELD_NUMBER: builtins.int
DATA_FIELD_NUMBER: builtins.int
PAYLOAD_FIELD_NUMBER: builtins.int
METADATA_FIELD_NUMBER: builtins.int
request_id: builtins.str
method: builtins.str
data_type: builtins.str
data: builtins.str
@property
def source(self) -> global___AgentId: ...
@property
def target(self) -> global___AgentId: ...
@property
def payload(self) -> global___Payload: ...
@property
def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ...
def __init__(
self,
@ -74,12 +94,11 @@ class RpcRequest(google.protobuf.message.Message):
source: global___AgentId | None = ...,
target: global___AgentId | None = ...,
method: builtins.str = ...,
data_type: builtins.str = ...,
data: builtins.str = ...,
payload: global___Payload | None = ...,
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["source", b"source", "target", b"target"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "data_type", b"data_type", "metadata", b"metadata", "method", b"method", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ...
def HasField(self, field_name: typing.Literal["payload", b"payload", "source", b"source", "target", b"target"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["metadata", b"metadata", "method", b"method", "payload", b"payload", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ...
global___RpcRequest = RpcRequest
@ -104,26 +123,25 @@ class RpcResponse(google.protobuf.message.Message):
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
REQUEST_ID_FIELD_NUMBER: builtins.int
RESULT_TYPE_FIELD_NUMBER: builtins.int
RESULT_FIELD_NUMBER: builtins.int
PAYLOAD_FIELD_NUMBER: builtins.int
ERROR_FIELD_NUMBER: builtins.int
METADATA_FIELD_NUMBER: builtins.int
request_id: builtins.str
result_type: builtins.str
result: builtins.str
error: builtins.str
@property
def payload(self) -> global___Payload: ...
@property
def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ...
def __init__(
self,
*,
request_id: builtins.str = ...,
result_type: builtins.str = ...,
result: builtins.str = ...,
payload: global___Payload | None = ...,
error: builtins.str = ...,
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["error", b"error", "metadata", b"metadata", "request_id", b"request_id", "result", b"result", "result_type", b"result_type"]) -> None: ...
def HasField(self, field_name: typing.Literal["payload", b"payload"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["error", b"error", "metadata", b"metadata", "payload", b"payload", "request_id", b"request_id"]) -> None: ...
global___RpcResponse = RpcResponse
@ -147,30 +165,26 @@ class Event(google.protobuf.message.Message):
) -> None: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
NAMESPACE_FIELD_NUMBER: builtins.int
TOPIC_TYPE_FIELD_NUMBER: builtins.int
TOPIC_SOURCE_FIELD_NUMBER: builtins.int
DATA_TYPE_FIELD_NUMBER: builtins.int
DATA_FIELD_NUMBER: builtins.int
PAYLOAD_FIELD_NUMBER: builtins.int
METADATA_FIELD_NUMBER: builtins.int
namespace: builtins.str
topic_type: builtins.str
topic_source: builtins.str
data_type: builtins.str
data: builtins.str
@property
def payload(self) -> global___Payload: ...
@property
def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ...
def __init__(
self,
*,
namespace: builtins.str = ...,
topic_type: builtins.str = ...,
topic_source: builtins.str = ...,
data_type: builtins.str = ...,
data: builtins.str = ...,
payload: global___Payload | None = ...,
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "data_type", b"data_type", "metadata", b"metadata", "namespace", b"namespace", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ...
def HasField(self, field_name: typing.Literal["payload", b"payload"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["metadata", b"metadata", "payload", b"payload", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ...
global___Event = Event

View File

@ -14,7 +14,14 @@ from ._base_agent import BaseAgent
from ._cancellation_token import CancellationToken
from ._message_context import MessageContext
from ._message_handler_context import MessageHandlerContext
from ._serialization import MESSAGE_TYPE_REGISTRY, Serialization, TypeDeserializer, TypeSerializer
from ._serialization import (
JSON_DATA_CONTENT_TYPE,
MESSAGE_TYPE_REGISTRY,
MessageCodec,
Serialization,
UnknownPayload,
try_get_known_codecs_for_type,
)
from ._subscription import Subscription
from ._subscription_context import SubscriptionInstantiationContext
from ._topic import TopicId
@ -30,8 +37,6 @@ __all__ = [
"AgentChildren",
"AgentInstantiationContext",
"MESSAGE_TYPE_REGISTRY",
"TypeSerializer",
"TypeDeserializer",
"TopicId",
"Subscription",
"MessageContext",
@ -39,4 +44,8 @@ __all__ = [
"AgentType",
"SubscriptionInstantiationContext",
"MessageHandlerContext",
"JSON_DATA_CONTENT_TYPE",
"MessageCodec",
"try_get_known_codecs_for_type",
"UnknownPayload",
]

View File

@ -1,4 +1,4 @@
from typing import Any, Mapping, Protocol, runtime_checkable
from typing import Any, List, Mapping, Protocol, runtime_checkable
from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata

View File

@ -1,9 +1,23 @@
import json
from dataclasses import asdict
from typing import Any, ClassVar, Dict, Protocol, TypeVar, cast, runtime_checkable
from dataclasses import asdict, dataclass
from typing import Any, ClassVar, Dict, List, Protocol, TypeVar, cast, runtime_checkable
from pydantic import BaseModel
T = TypeVar("T")
class MessageCodec(Protocol[T]):
@property
def data_content_type(self) -> str: ...
@property
def type_name(self) -> str: ...
def deserialize(self, payload: bytes) -> T: ...
def serialize(self, message: T) -> bytes: ...
@runtime_checkable
class IsDataclass(Protocol):
@ -26,53 +40,62 @@ def has_nested_base_model(cls: type[IsDataclass]) -> bool:
return any(issubclass(f.type, BaseModel) for f in cls.__dataclass_fields__.values())
T = TypeVar("T", covariant=True)
class TypeDeserializer(Protocol[T]):
def deserialize(self, message: str) -> T: ...
U = TypeVar("U", contravariant=True)
class TypeSerializer(Protocol[U]):
def serialize(self, message: U) -> str: ...
DataclassT = TypeVar("DataclassT", bound=IsDataclass)
JSON_DATA_CONTENT_TYPE = "application/json"
class DataclassTypeDeserializer(TypeDeserializer[DataclassT]):
def __init__(self, cls: type[DataclassT]) -> None:
class DataclassJsonMessageCodec(MessageCodec[IsDataclass]):
def __init__(self, cls: type[IsDataclass]) -> None:
self.cls = cls
def deserialize(self, message: str) -> DataclassT:
return self.cls(**json.loads(message))
@property
def data_content_type(self) -> str:
return JSON_DATA_CONTENT_TYPE
@property
def type_name(self) -> str:
return _type_name(self.cls)
class DataclassTypeSerializer(TypeSerializer[IsDataclass]):
def serialize(self, message: IsDataclass) -> str:
def deserialize(self, payload: bytes) -> IsDataclass:
message_str = payload.decode("utf-8")
return self.cls(**json.loads(message_str))
def serialize(self, message: IsDataclass) -> bytes:
if has_nested_dataclass(type(message)) or has_nested_base_model(type(message)):
raise ValueError("Dataclass has nested dataclasses or base models, which are not supported")
return json.dumps(asdict(message))
return json.dumps(asdict(message)).encode("utf-8")
PydanticT = TypeVar("PydanticT", bound=BaseModel)
class PydanticTypeDeserializer(TypeDeserializer[PydanticT]):
class PydanticJsonMessageCodec(MessageCodec[PydanticT]):
def __init__(self, cls: type[PydanticT]) -> None:
self.cls = cls
def deserialize(self, message: str) -> PydanticT:
return self.cls.model_validate_json(message)
@property
def data_content_type(self) -> str:
return JSON_DATA_CONTENT_TYPE
@property
def type_name(self) -> str:
return _type_name(self.cls)
def deserialize(self, payload: bytes) -> PydanticT:
message_str = payload.decode("utf-8")
return self.cls.model_validate_json(message_str)
def serialize(self, message: PydanticT) -> bytes:
return message.model_dump_json().encode("utf-8")
class PydanticTypeSerializer(TypeSerializer[BaseModel]):
def serialize(self, message: BaseModel) -> str:
return message.model_dump_json()
@dataclass
class UnknownPayload:
type_name: str
data_content_type: str
payload: bytes
def _type_name(cls: type[Any] | Any) -> str:
@ -85,38 +108,49 @@ def _type_name(cls: type[Any] | Any) -> str:
V = TypeVar("V")
def try_get_known_codecs_for_type(cls: type[Any]) -> list[MessageCodec[Any]]:
# TODO: Support protobuf types
codecs: List[MessageCodec[Any]] = []
if issubclass(cls, BaseModel):
codecs.append(PydanticJsonMessageCodec(cls))
elif isinstance(cls, IsDataclass):
codecs.append(DataclassJsonMessageCodec(cls))
return codecs
class Serialization:
def __init__(self) -> None:
self._deserializers: Dict[str, TypeDeserializer[Any]] = {}
self._serializers: Dict[str, TypeSerializer[Any]] = {}
# type_name, data_content_type -> codec
self._codecs: dict[tuple[str, str], MessageCodec[Any]] = {}
def add_type(self, message_type: type[BaseModel] | type[IsDataclass]) -> None:
if issubclass(message_type, BaseModel):
self.add_type_custom(
_type_name(message_type), PydanticTypeDeserializer(message_type), PydanticTypeSerializer()
)
elif isinstance(message_type, IsDataclass):
self.add_type_custom(
_type_name(message_type), DataclassTypeDeserializer(message_type), DataclassTypeSerializer()
)
else:
raise ValueError(f"Unsupported type {message_type}")
def add_codec(self, codec: MessageCodec[Any] | List[MessageCodec[Any]]) -> None:
if isinstance(codec, list):
for c in codec:
self.add_codec(c)
return
def add_type_custom(self, type_name: str, deserializer: TypeDeserializer[V], serializer: TypeSerializer[V]) -> None:
self._deserializers[type_name] = deserializer
self._serializers[type_name] = serializer
self._codecs[(codec.type_name, codec.data_content_type)] = codec
def deserialize(self, message: str, *, type_name: str) -> Any:
return self._deserializers[type_name].deserialize(message)
def deserialize(self, payload: bytes, *, type_name: str, data_content_type: str) -> Any:
codec = self._codecs.get((type_name, data_content_type))
if codec is None:
return UnknownPayload(type_name, data_content_type, payload)
return codec.deserialize(payload)
def serialize(self, message: Any, *, type_name: str, data_content_type: str) -> bytes:
codec = self._codecs.get((type_name, data_content_type))
if codec is None:
raise ValueError(f"Unknown type {type_name} with content type {data_content_type}")
return codec.serialize(message)
def is_registered(self, type_name: str, data_content_type: str) -> bool:
return (type_name, data_content_type) in self._codecs
def type_name(self, message: Any) -> str:
return _type_name(message)
def serialize(self, message: Any, *, type_name: str) -> str:
return self._serializers[type_name].serialize(message)
def is_registered(self, type_name: str) -> bool:
return type_name in self._deserializers
MESSAGE_TYPE_REGISTRY = Serialization()

View File

@ -8,7 +8,7 @@ from ..base._agent_id import AgentId
from ..base._agent_instantiation import AgentInstantiationContext
from ..base._agent_metadata import AgentMetadata
from ..base._agent_runtime import AgentRuntime
from ..base._serialization import MESSAGE_TYPE_REGISTRY
from ..base._serialization import JSON_DATA_CONTENT_TYPE, MESSAGE_TYPE_REGISTRY, try_get_known_codecs_for_type
from ..base.exceptions import CantHandleException
from ._type_helpers import get_types
@ -59,9 +59,10 @@ class ClosureAgent(Agent):
self._id: AgentId = id
self._description = description
subscription_types = get_subscriptions_from_closure(closure)
# TODO fold this into runtime
for message_type in subscription_types:
if not MESSAGE_TYPE_REGISTRY.is_registered(MESSAGE_TYPE_REGISTRY.type_name(message_type)):
MESSAGE_TYPE_REGISTRY.add_type(message_type)
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(message_type))
self._subscriptions = [MESSAGE_TYPE_REGISTRY.type_name(message_type) for message_type in subscription_types]
self._closure = closure

View File

@ -17,6 +17,8 @@ from typing import (
runtime_checkable,
)
from autogen_core.base import try_get_known_codecs_for_type
from ..base import MESSAGE_TYPE_REGISTRY, BaseAgent, MessageContext
from ..base.exceptions import CantHandleException
from ._type_helpers import AnyType, get_types
@ -144,8 +146,8 @@ class RoutedAgent(BaseAgent):
self._handlers[target_type] = message_handler
for message_type in self._handlers.keys():
if not MESSAGE_TYPE_REGISTRY.is_registered(MESSAGE_TYPE_REGISTRY.type_name(message_type)):
MESSAGE_TYPE_REGISTRY.add_type(message_type)
for codec in try_get_known_codecs_for_type(message_type):
MESSAGE_TYPE_REGISTRY.add_codec(codec)
super().__init__(description)

View File

@ -1,11 +1,10 @@
#custom type
from pydantic import BaseModel
from dataclasses import dataclass
import pytest
from autogen_core.base import Serialization
from autogen_core.base import JSON_DATA_CONTENT_TYPE, MessageCodec, try_get_known_codecs_for_type
class PydanticMessage(BaseModel):
message: str
@ -30,77 +29,86 @@ class NestingPydanticDataclassMessage:
def test_pydantic() -> None:
serde = Serialization()
serde.add_type(PydanticMessage)
serde.add_codec(try_get_known_codecs_for_type(PydanticMessage))
message = PydanticMessage(message="hello")
name = serde.type_name(message)
json = serde.serialize(message, type_name=name)
json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
assert name == "PydanticMessage"
assert json == '{"message":"hello"}'
deserialized = serde.deserialize(json, type_name=name)
assert json == b'{"message":"hello"}'
deserialized = serde.deserialize(json, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
assert deserialized == message
def test_nested_pydantic() -> None:
serde = Serialization()
serde.add_type(NestingPydanticMessage)
serde.add_codec(try_get_known_codecs_for_type(NestingPydanticMessage))
message = NestingPydanticMessage(message="hello", nested=PydanticMessage(message="world"))
name = serde.type_name(message)
json = serde.serialize(message, type_name=name)
assert json == '{"message":"hello","nested":{"message":"world"}}'
deserialized = serde.deserialize(json, type_name=name)
json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
assert json == b'{"message":"hello","nested":{"message":"world"}}'
deserialized = serde.deserialize(json, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
assert deserialized == message
def test_dataclass() -> None:
serde = Serialization()
serde.add_type(DataclassMessage)
serde.add_codec(try_get_known_codecs_for_type(DataclassMessage))
message = DataclassMessage(message="hello")
name = serde.type_name(message)
json = serde.serialize(message, type_name=name)
assert json == '{"message": "hello"}'
deserialized = serde.deserialize(json, type_name=name)
json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
assert json == b'{"message": "hello"}'
deserialized = serde.deserialize(json, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
assert deserialized == message
def test_nesting_dataclass_dataclass() -> None:
serde = Serialization()
serde.add_type(NestingDataclassMessage)
serde.add_codec(try_get_known_codecs_for_type(NestingDataclassMessage))
message = NestingDataclassMessage(message="hello", nested=DataclassMessage(message="world"))
name = serde.type_name(message)
with pytest.raises(ValueError):
_json = serde.serialize(message, type_name=name)
_json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
def test_nesting_dataclass_pydantic() -> None:
serde = Serialization()
serde.add_type(NestingPydanticDataclassMessage)
serde.add_codec(try_get_known_codecs_for_type(NestingPydanticDataclassMessage))
message = NestingPydanticDataclassMessage(message="hello", nested=PydanticMessage(message="world"))
name = serde.type_name(message)
with pytest.raises(ValueError):
_json = serde.serialize(message, type_name=name)
_json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
def test_invalid_type() -> None:
serde = Serialization()
try:
serde.add_type(str) # type: ignore
serde.add_codec(try_get_known_codecs_for_type(str))
except ValueError as e:
assert str(e) == "Unsupported type <class 'str'>"
def test_custom_type() -> None:
serde = Serialization()
class CustomStringTypeDeserializer:
def deserialize(self, message: str) -> str:
class CustomStringTypeCodec(MessageCodec[str]):
@property
def data_content_type(self) -> str:
return "str"
@property
def type_name(self) -> str:
return "custom_str"
def deserialize(self, payload: bytes) -> str:
message = payload.decode("utf-8")
return message[1:-1]
class CustomStringTypeSerializer:
def serialize(self, message: str) -> str:
return f'"{message}"'
def serialize(self, message: str) -> bytes:
return f'"{message}"'.encode("utf-8")
serde.add_type_custom("custom_str", CustomStringTypeDeserializer(), CustomStringTypeSerializer())
serde.add_codec(CustomStringTypeCodec())
message = "hello"
json = serde.serialize(message, type_name="custom_str")
assert json == '"hello"'
deserialized = serde.deserialize(json, type_name="custom_str")
assert deserialized == message
json = serde.serialize(message, type_name="custom_str", data_content_type="str")
assert json == b'"hello"'
deserialized = serde.deserialize(json, type_name="custom_str", data_content_type="str")
assert deserialized == message