diff --git a/dotnet/samples/Greeter/Greeter.AgentWorker/Program.cs b/dotnet/samples/Greeter/Greeter.AgentWorker/Program.cs index ec2cfe01d..ddbc40d9d 100644 --- a/dotnet/samples/Greeter/Greeter.AgentWorker/Program.cs +++ b/dotnet/samples/Greeter/Greeter.AgentWorker/Program.cs @@ -1,4 +1,5 @@ using Agents; +using Google.Protobuf; using Greeter.AgentWorker; using Microsoft.AI.Agents.Worker.Client; using AgentId = Microsoft.AI.Agents.Worker.Client.AgentId; @@ -30,7 +31,11 @@ internal sealed class GreetingAgent(IAgentContext context, ILogger HandleRequest(RpcRequest request) { logger.LogInformation("[{Id}] Received request: '{Request}'.", AgentId, request); - return Task.FromResult(new RpcResponse() { Result = "Okay!" }); + return Task.FromResult(new RpcResponse() { Payload = new Payload { + DataContentType = "text/plain", + Data = ByteString.CopyFromUtf8("Hello, agents!"), + DataType = "text" + }}); } } diff --git a/dotnet/src/Microsoft.AI.Agents.Worker.Client/AgentBase.cs b/dotnet/src/Microsoft.AI.Agents.Worker.Client/AgentBase.cs index 1884c5284..d9df66e3d 100644 --- a/dotnet/src/Microsoft.AI.Agents.Worker.Client/AgentBase.cs +++ b/dotnet/src/Microsoft.AI.Agents.Worker.Client/AgentBase.cs @@ -3,6 +3,8 @@ using System.Threading.Channels; using Microsoft.Extensions.Logging; using System.Text.Json; using System.Diagnostics; +using System.Text; +using Google.Protobuf; namespace Microsoft.AI.Agents.Worker.Client; @@ -80,12 +82,12 @@ public abstract class AgentBase { case Message.MessageOneofCase.Event: { - var activity = ExtractActivity(msg.Event.DataType, msg.Event.Metadata); + var activity = ExtractActivity(msg.Event.Payload.DataType, msg.Event.Metadata); await InvokeWithActivityAsync( static ((AgentBase Agent, Event Item) state) => state.Agent.HandleEvent(state.Item), (this, msg.Event), activity, - msg.Event.DataType).ConfigureAwait(false); + msg.Event.Payload.DataType).ConfigureAwait(false); } break; case Message.MessageOneofCase.Request: @@ -143,7 +145,12 @@ public abstract class AgentBase Target = target, RequestId = requestId, Method = method, - Data = JsonSerializer.Serialize(parameters) + Payload = new Payload{ + DataType = "application/json", + Data = ByteString.CopyFrom(JsonSerializer.Serialize(parameters), Encoding.UTF8), + DataContentType = "application/json" + + } }; var activity = s_source.StartActivity($"Call '{method}'", ActivityKind.Client, Activity.Current?.Context ?? default); @@ -175,8 +182,8 @@ public abstract class AgentBase protected async ValueTask PublishEvent(Event item) { - var activity = s_source.StartActivity($"PublishEvent '{item.DataType}'", ActivityKind.Client, Activity.Current?.Context ?? default); - activity?.SetTag("peer.service", $"{item.DataType}/{item.Namespace}"); + var activity = s_source.StartActivity($"PublishEvent '{item.Payload.DataType}'", ActivityKind.Client, Activity.Current?.Context ?? default); + activity?.SetTag("peer.service", $"{item.Payload.DataType}/{item.TopicSource}"); var completion = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); Context.DistributedContextPropagator.Inject(activity, item.Metadata, static (carrier, key, value) => ((IDictionary)carrier!)[key] = value); @@ -187,7 +194,7 @@ public abstract class AgentBase }, (this, item, completion), activity, - item.DataType).ConfigureAwait(false); + item.Payload.DataType).ConfigureAwait(false); } protected virtual Task HandleRequest(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" }); @@ -224,7 +231,7 @@ public abstract class AgentBase activity.SetTag("exception.type", e.GetType().FullName); activity.SetTag("exception.message", e.Message); - // Note that "exception.stacktrace" is the full exception detail, not just the StackTrace property. + // Note that "exception.stacktrace" is the full exception detail, not just the StackTrace property. // See https://opentelemetry.io/docs/specs/semconv/attributes-registry/exception/ // and https://github.com/open-telemetry/opentelemetry-specification/pull/697#discussion_r453662519 activity.SetTag("exception.stacktrace", e.ToString()); diff --git a/dotnet/src/Microsoft.AI.Agents.Worker.Client/AgentId.cs b/dotnet/src/Microsoft.AI.Agents.Worker.Client/AgentId.cs index 7fd59fde3..cf52213d1 100644 --- a/dotnet/src/Microsoft.AI.Agents.Worker.Client/AgentId.cs +++ b/dotnet/src/Microsoft.AI.Agents.Worker.Client/AgentId.cs @@ -2,14 +2,14 @@ using RpcAgentId = Agents.AgentId; namespace Microsoft.AI.Agents.Worker.Client; -public sealed record class AgentId(string Name, string Namespace) +public sealed record class AgentId(string Type, string Key) { public static implicit operator RpcAgentId(AgentId agentId) => new() { - Name = agentId.Name, - Namespace = agentId.Namespace + Type = agentId.Type, + Key = agentId.Key }; - public static implicit operator AgentId(RpcAgentId agentId) => new(agentId.Name, agentId.Namespace); - public override string ToString() => $"{Name}/{Namespace}"; + public static implicit operator AgentId(RpcAgentId agentId) => new(agentId.Type, agentId.Key); + public override string ToString() => $"{Type}/{Key}"; } diff --git a/dotnet/src/Microsoft.AI.Agents.Worker.Client/AgentWorkerRuntime.cs b/dotnet/src/Microsoft.AI.Agents.Worker.Client/AgentWorkerRuntime.cs index 39e05e00b..00b90a6e2 100644 --- a/dotnet/src/Microsoft.AI.Agents.Worker.Client/AgentWorkerRuntime.cs +++ b/dotnet/src/Microsoft.AI.Agents.Worker.Client/AgentWorkerRuntime.cs @@ -143,7 +143,7 @@ public sealed class AgentWorkerRuntime : IHostedService, IDisposable var item = message.Event; foreach (var (typeName, _) in _agentTypes) { - var agent = GetOrActivateAgent(new AgentId(typeName, item.Namespace)); + var agent = GetOrActivateAgent(new AgentId(typeName, item.TopicSource)); agent.ReceiveMessage(message); } @@ -200,17 +200,17 @@ public sealed class AgentWorkerRuntime : IHostedService, IDisposable private AgentBase GetOrActivateAgent(AgentId agentId) { - if (!_agents.TryGetValue((agentId.Name, agentId.Namespace), out var agent)) + if (!_agents.TryGetValue((agentId.Type, agentId.Key), out var agent)) { - if (_agentTypes.TryGetValue(agentId.Name, out var agentType)) + if (_agentTypes.TryGetValue(agentId.Type, out var agentType)) { var context = new AgentContext(agentId, this, _serviceProvider.GetRequiredService>(), _distributedContextPropagator); agent = (AgentBase)ActivatorUtilities.CreateInstance(_serviceProvider, agentType, context); - _agents.TryAdd((agentId.Name, agentId.Namespace), agent); + _agents.TryAdd((agentId.Type, agentId.Key), agent); } else { - throw new InvalidOperationException($"Agent type '{agentId.Name}' is unknown."); + throw new InvalidOperationException($"Agent type '{agentId.Type}' is unknown."); } } diff --git a/dotnet/src/Microsoft.AI.Agents.Worker.Server/AgentWorkerRegistryGrain.cs b/dotnet/src/Microsoft.AI.Agents.Worker.Server/AgentWorkerRegistryGrain.cs index 026b2080c..b153d8cf9 100644 --- a/dotnet/src/Microsoft.AI.Agents.Worker.Server/AgentWorkerRegistryGrain.cs +++ b/dotnet/src/Microsoft.AI.Agents.Worker.Server/AgentWorkerRegistryGrain.cs @@ -118,13 +118,13 @@ public sealed class AgentWorkerRegistryGrain : Grain, IAgentWorkerRegistryGrain public ValueTask<(IWorkerGateway? Gateway, bool NewPlacment)> GetOrPlaceAgent(AgentId agentId) { bool isNewPlacement; - if (!_agentDirectory.TryGetValue((agentId.Name, agentId.Namespace), out var worker) || !_workerStates.ContainsKey(worker)) + if (!_agentDirectory.TryGetValue((agentId.Type, agentId.Key), out var worker) || !_workerStates.ContainsKey(worker)) { - worker = GetCompatibleWorkerCore(agentId.Name); + worker = GetCompatibleWorkerCore(agentId.Type); if (worker is not null) { // New activation. - _agentDirectory[(agentId.Name, agentId.Namespace)] = worker; + _agentDirectory[(agentId.Type, agentId.Key)] = worker; isNewPlacement = true; } else diff --git a/dotnet/src/Microsoft.AI.Agents.Worker.Server/WorkerGateway.cs b/dotnet/src/Microsoft.AI.Agents.Worker.Server/WorkerGateway.cs index f9bdea289..429e38d82 100644 --- a/dotnet/src/Microsoft.AI.Agents.Worker.Server/WorkerGateway.cs +++ b/dotnet/src/Microsoft.AI.Agents.Worker.Server/WorkerGateway.cs @@ -49,11 +49,11 @@ internal sealed class WorkerGateway : BackgroundService, IWorkerGateway public async ValueTask InvokeRequest(RpcRequest request) { - (string Type, string Key) agentId = (request.Target.Name, request.Target.Namespace); + (string Type, string Key) agentId = (request.Target.Type, request.Target.Key); if (!_agentDirectory.TryGetValue(agentId, out var connection) || connection.Completion.IsCompleted) { // Activate the agent on a compatible worker process. - if (_supportedAgentTypes.TryGetValue(request.Target.Name, out var workers)) + if (_supportedAgentTypes.TryGetValue(request.Target.Type, out var workers)) { connection = workers[Random.Shared.Next(workers.Count)]; _agentDirectory[agentId] = connection; @@ -163,7 +163,7 @@ internal sealed class WorkerGateway : BackgroundService, IWorkerGateway } /* - if (string.Equals("runtime", request.Target.Name, StringComparison.Ordinal)) + if (string.Equals("runtime", request.Target.Type, StringComparison.Ordinal)) { if (string.Equals("subscribe", request.Method)) { @@ -184,7 +184,7 @@ internal sealed class WorkerGateway : BackgroundService, IWorkerGateway var (gateway, isPlacement) = await _gatewayRegistry.GetOrPlaceAgent(request.Target); if (gateway is null) { - return new RpcResponse { Error = "Agent not found and no compatible gateways were found." }; + return new RpcResponse { Error = "Agent not found and no compatible gateways were found." }; } if (isPlacement) diff --git a/dotnet/src/Shared/RpcEventExtensions.cs b/dotnet/src/Shared/RpcEventExtensions.cs index 1c539db49..f36702bfb 100644 --- a/dotnet/src/Shared/RpcEventExtensions.cs +++ b/dotnet/src/Shared/RpcEventExtensions.cs @@ -1,6 +1,9 @@ using System.Text.Json; using Event = Microsoft.AI.Agents.Abstractions.Event; using RpcEvent = Agents.Event; +using Payload = Agents.Payload; +using Google.Protobuf; +using System.Text; namespace Microsoft.AI.Agents.Worker; @@ -10,13 +13,19 @@ public static class RpcEventExtensions { var result = new RpcEvent { - Namespace = input.Namespace, - DataType = input.Type, + TopicSource = input.Namespace, + // TODO: Is this the right way to handle topics? + TopicType = input.Subject }; if (input.Data is not null) { - result.Data = JsonSerializer.Serialize(input.Data); + result.Payload = new Payload + { + Data = ByteString.CopyFrom(JsonSerializer.Serialize(input.Data), Encoding.UTF8), + DataContentType = "application/json", + DataType = input.Type + }; } return result; @@ -26,15 +35,20 @@ public static class RpcEventExtensions { var result = new Event { - Type = input.DataType, - Subject = input.Namespace, - Namespace = input.Namespace, + Type = input.Payload.DataType, + Subject = input.TopicType, + Namespace = input.TopicSource, Data = [] }; - if (input.Data is not null) + if (input.Payload is not null) { - result.Data = JsonSerializer.Deserialize>(input.Data)!; + if (input.Payload.DataContentType != "application/json") + { + throw new InvalidOperationException("Only application/json content type is supported"); + } + + result.Data = JsonSerializer.Deserialize>(input.Payload.Data.ToString(Encoding.UTF8))!; } return result; diff --git a/protos/agent_worker.proto b/protos/agent_worker.proto index 38dbc117e..fc204feb8 100644 --- a/protos/agent_worker.proto +++ b/protos/agent_worker.proto @@ -3,8 +3,14 @@ syntax = "proto3"; package agents; message AgentId { - string name = 1; - string namespace = 2; + string type = 1; + string key = 2; +} + +message Payload { + string data_type = 1; + string data_content_type = 2; + bytes data = 3; } message RpcRequest { @@ -12,26 +18,22 @@ message RpcRequest { AgentId source = 2; AgentId target = 3; string method = 4; - string data_type = 5; - string data = 6; - map metadata = 7; + Payload payload = 5; + map metadata = 6; } message RpcResponse { string request_id = 1; - string result_type = 2; - string result = 3; - string error = 4; - map metadata = 5; + Payload payload = 2; + string error = 3; + map metadata = 4; } message Event { - string namespace = 1; - string topic_type = 2; - string topic_source = 3; - string data_type = 4; - string data = 5; - map metadata = 6; + string topic_type = 1; + string topic_source = 2; + Payload payload = 3; + map metadata = 4; } message RegisterAgentType { diff --git a/python/packages/autogen-core/samples/worker/run_worker_pub_sub.py b/python/packages/autogen-core/samples/worker/run_worker_pub_sub.py index 94dbb128c..5119c49ec 100644 --- a/python/packages/autogen-core/samples/worker/run_worker_pub_sub.py +++ b/python/packages/autogen-core/samples/worker/run_worker_pub_sub.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Any, NoReturn from autogen_core.application import WorkerAgentRuntime -from autogen_core.base import MESSAGE_TYPE_REGISTRY, MessageContext +from autogen_core.base import MESSAGE_TYPE_REGISTRY, MessageContext, try_get_known_codecs_for_type from autogen_core.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler @@ -67,11 +67,11 @@ class GreeterAgent(RoutedAgent): async def main() -> None: runtime = WorkerAgentRuntime() - MESSAGE_TYPE_REGISTRY.add_type(Greeting) - MESSAGE_TYPE_REGISTRY.add_type(AskToGreet) - MESSAGE_TYPE_REGISTRY.add_type(Feedback) - MESSAGE_TYPE_REGISTRY.add_type(ReturnedGreeting) - MESSAGE_TYPE_REGISTRY.add_type(ReturnedFeedback) + MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(Greeting)) + MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(AskToGreet)) + MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(Feedback)) + MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(ReturnedGreeting)) + MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(ReturnedFeedback)) await runtime.start(host_connection_string="localhost:50051") await runtime.register("receiver", ReceiveAgent, lambda: [DefaultSubscription()]) diff --git a/python/packages/autogen-core/samples/worker/run_worker_rpc.py b/python/packages/autogen-core/samples/worker/run_worker_rpc.py index b00f0ef64..bbc917f01 100644 --- a/python/packages/autogen-core/samples/worker/run_worker_rpc.py +++ b/python/packages/autogen-core/samples/worker/run_worker_rpc.py @@ -4,7 +4,13 @@ from dataclasses import dataclass from typing import Any, NoReturn from autogen_core.application import WorkerAgentRuntime -from autogen_core.base import MESSAGE_TYPE_REGISTRY, AgentId, AgentInstantiationContext, MessageContext +from autogen_core.base import ( + MESSAGE_TYPE_REGISTRY, + AgentId, + AgentInstantiationContext, + MessageContext, + try_get_known_codecs_for_type, +) from autogen_core.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler @@ -55,9 +61,9 @@ class GreeterAgent(RoutedAgent): async def main() -> None: runtime = WorkerAgentRuntime() - MESSAGE_TYPE_REGISTRY.add_type(Greeting) - MESSAGE_TYPE_REGISTRY.add_type(AskToGreet) - MESSAGE_TYPE_REGISTRY.add_type(Feedback) + MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(Greeting)) + MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(AskToGreet)) + MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(Feedback)) await runtime.start(host_connection_string="localhost:50051") await runtime.register("receiver", lambda: ReceiveAgent(), lambda: [DefaultSubscription()]) diff --git a/python/packages/autogen-core/src/autogen_core/application/_host_runtime_servicer.py b/python/packages/autogen-core/src/autogen_core/application/_host_runtime_servicer.py index 09b33ea56..be86a4fd2 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_host_runtime_servicer.py +++ b/python/packages/autogen-core/src/autogen_core/application/_host_runtime_servicer.py @@ -125,9 +125,9 @@ class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer): async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None: # Deliver the message to a client given the target agent type. async with self._agent_type_to_client_id_lock: - target_client_id = self._agent_type_to_client_id.get(request.target.name) + target_client_id = self._agent_type_to_client_id.get(request.target.type) if target_client_id is None: - logger.error(f"Agent {request.target.name} not found, failed to deliver message.") + logger.error(f"Agent {request.target.type} not found, failed to deliver message.") return target_send_queue = self._send_queues.get(target_client_id) if target_send_queue is None: diff --git a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py index a3a112d6b..d26cc4d99 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py @@ -28,6 +28,8 @@ import grpc from grpc.aio import StreamStreamCall from typing_extensions import Self +from autogen_core.base import JSON_DATA_CONTENT_TYPE + from ..base import ( MESSAGE_TYPE_REGISTRY, Agent, @@ -247,14 +249,19 @@ class WorkerAgentRuntime(AgentRuntime): self._pending_requests[request_id_str] = future sender = cast(AgentId, sender) data_type = MESSAGE_TYPE_REGISTRY.type_name(message) - serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=data_type) + serialized_message = MESSAGE_TYPE_REGISTRY.serialize( + message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE + ) runtime_message = agent_worker_pb2.Message( request=agent_worker_pb2.RpcRequest( request_id=request_id_str, - target=agent_worker_pb2.AgentId(name=recipient.type, namespace=recipient.key), - source=agent_worker_pb2.AgentId(name=sender.type, namespace=sender.key), - data_type=data_type, - data=serialized_message, + target=agent_worker_pb2.AgentId(type=recipient.type, key=recipient.key), + source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key), + payload=agent_worker_pb2.Payload( + data_type=data_type, + data=serialized_message, + data_content_type=JSON_DATA_CONTENT_TYPE, + ), ) ) # TODO: Find a way to handle timeouts/errors @@ -277,10 +284,18 @@ class WorkerAgentRuntime(AgentRuntime): if self._host_connection is None: raise RuntimeError("Host connection is not set.") message_type = MESSAGE_TYPE_REGISTRY.type_name(message) - serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=message_type) + serialized_message = MESSAGE_TYPE_REGISTRY.serialize( + message, type_name=message_type, data_content_type=JSON_DATA_CONTENT_TYPE + ) runtime_message = agent_worker_pb2.Message( event=agent_worker_pb2.Event( - topic_type=topic_id.type, topic_source=topic_id.source, data_type=message_type, data=serialized_message + topic_type=topic_id.type, + topic_source=topic_id.source, + payload=agent_worker_pb2.Payload( + data_type=message_type, + data=serialized_message, + data_content_type=JSON_DATA_CONTENT_TYPE, + ), ) ) task = asyncio.create_task(self._host_connection.send(runtime_message)) @@ -305,13 +320,17 @@ class WorkerAgentRuntime(AgentRuntime): async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None: assert self._host_connection is not None - target = AgentId(request.target.name, request.target.namespace) - source = AgentId(request.source.name, request.source.namespace) + target = AgentId(request.target.type, request.target.key) + source = AgentId(request.source.type, request.source.key) logging.info(f"Processing request from {source} to {target}") # Deserialize the message. - message = MESSAGE_TYPE_REGISTRY.deserialize(request.data, type_name=request.data_type) + message = MESSAGE_TYPE_REGISTRY.deserialize( + request.payload.data, + type_name=request.payload.data_type, + data_content_type=request.payload.data_content_type, + ) # Get the target agent and prepare the message context. target_agent = await self._get_agent(target) @@ -339,14 +358,19 @@ class WorkerAgentRuntime(AgentRuntime): # Serialize the result. result_type = MESSAGE_TYPE_REGISTRY.type_name(result) - serialized_result = MESSAGE_TYPE_REGISTRY.serialize(result, type_name=result_type) + serialized_result = MESSAGE_TYPE_REGISTRY.serialize( + result, type_name=result_type, data_content_type=JSON_DATA_CONTENT_TYPE + ) # Create the response message. response_message = agent_worker_pb2.Message( response=agent_worker_pb2.RpcResponse( request_id=request.request_id, - result_type=result_type, - result=serialized_result, + payload=agent_worker_pb2.Payload( + data_type=result_type, + data=serialized_result, + data_content_type=JSON_DATA_CONTENT_TYPE, + ), ) ) @@ -355,7 +379,11 @@ class WorkerAgentRuntime(AgentRuntime): async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> None: # Deserialize the result. - result = MESSAGE_TYPE_REGISTRY.deserialize(response.result, type_name=response.result_type) + result = MESSAGE_TYPE_REGISTRY.deserialize( + response.payload.data, + type_name=response.payload.data_type, + data_content_type=response.payload.data_content_type, + ) # Get the future and set the result. future = self._pending_requests.pop(response.request_id) if len(response.error) > 0: @@ -364,7 +392,9 @@ class WorkerAgentRuntime(AgentRuntime): future.set_result(result) async def _process_event(self, event: agent_worker_pb2.Event) -> None: - message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.data_type) + message = MESSAGE_TYPE_REGISTRY.deserialize( + event.payload.data, type_name=event.payload.data_type, data_content_type=event.payload.data_content_type + ) topic_id = TopicId(event.topic_type, event.topic_source) # Get the recipients for the topic. recipients = await self._subscription_manager.get_subscribed_recipients(topic_id) diff --git a/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.py b/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.py index 7d42949ea..e241d018e 100644 --- a/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.py +++ b/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.py @@ -24,7 +24,7 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"*\n\x07\x41gentId\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t\"\xf8\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12\x11\n\tdata_type\x18\x05 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\t\x12\x32\n\x08metadata\x18\x07 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xbb\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x13\n\x0bresult_type\x18\x02 \x01(\t\x12\x0e\n\x06result\x18\x03 \x01(\t\x12\r\n\x05\x65rror\x18\x04 \x01(\t\x12\x33\n\x08metadata\x18\x05 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc5\x01\n\x05\x45vent\x12\x11\n\tnamespace\x18\x01 \x01(\t\x12\x12\n\ntopic_type\x18\x02 \x01(\t\x12\x14\n\x0ctopic_source\x18\x03 \x01(\t\x12\x11\n\tdata_type\x18\x04 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\t\x12-\n\x08metadata\x18\x06 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"T\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x42\x0e\n\x0csubscription\"=\n\x0f\x41\x64\x64Subscription\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\xf0\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x12\x32\n\x0f\x61\x64\x64Subscription\x18\x05 \x01(\x0b\x32\x17.agents.AddSubscriptionH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xf9\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xb3\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12 \n\x07payload\x18\x03 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x04 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"T\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x42\x0e\n\x0csubscription\"=\n\x0f\x41\x64\x64Subscription\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\xf0\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x12\x32\n\x0f\x61\x64\x64Subscription\x18\x05 \x01(\x0b\x32\x17.agents.AddSubscriptionH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -38,29 +38,31 @@ if not _descriptor._USE_C_DESCRIPTORS: _globals['_EVENT_METADATAENTRY']._loaded_options = None _globals['_EVENT_METADATAENTRY']._serialized_options = b'8\001' _globals['_AGENTID']._serialized_start=30 - _globals['_AGENTID']._serialized_end=72 - _globals['_RPCREQUEST']._serialized_start=75 - _globals['_RPCREQUEST']._serialized_end=323 - _globals['_RPCREQUEST_METADATAENTRY']._serialized_start=276 - _globals['_RPCREQUEST_METADATAENTRY']._serialized_end=323 - _globals['_RPCRESPONSE']._serialized_start=326 - _globals['_RPCRESPONSE']._serialized_end=513 - _globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=276 - _globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=323 - _globals['_EVENT']._serialized_start=516 - _globals['_EVENT']._serialized_end=713 - _globals['_EVENT_METADATAENTRY']._serialized_start=276 - _globals['_EVENT_METADATAENTRY']._serialized_end=323 - _globals['_REGISTERAGENTTYPE']._serialized_start=715 - _globals['_REGISTERAGENTTYPE']._serialized_end=748 - _globals['_TYPESUBSCRIPTION']._serialized_start=750 - _globals['_TYPESUBSCRIPTION']._serialized_end=808 - _globals['_SUBSCRIPTION']._serialized_start=810 - _globals['_SUBSCRIPTION']._serialized_end=894 - _globals['_ADDSUBSCRIPTION']._serialized_start=896 - _globals['_ADDSUBSCRIPTION']._serialized_end=957 - _globals['_MESSAGE']._serialized_start=960 - _globals['_MESSAGE']._serialized_end=1200 - _globals['_AGENTRPC']._serialized_start=1202 - _globals['_AGENTRPC']._serialized_end=1265 + _globals['_AGENTID']._serialized_end=66 + _globals['_PAYLOAD']._serialized_start=68 + _globals['_PAYLOAD']._serialized_end=137 + _globals['_RPCREQUEST']._serialized_start=140 + _globals['_RPCREQUEST']._serialized_end=389 + _globals['_RPCREQUEST_METADATAENTRY']._serialized_start=342 + _globals['_RPCREQUEST_METADATAENTRY']._serialized_end=389 + _globals['_RPCRESPONSE']._serialized_start=392 + _globals['_RPCRESPONSE']._serialized_end=576 + _globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=342 + _globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=389 + _globals['_EVENT']._serialized_start=579 + _globals['_EVENT']._serialized_end=758 + _globals['_EVENT_METADATAENTRY']._serialized_start=342 + _globals['_EVENT_METADATAENTRY']._serialized_end=389 + _globals['_REGISTERAGENTTYPE']._serialized_start=760 + _globals['_REGISTERAGENTTYPE']._serialized_end=793 + _globals['_TYPESUBSCRIPTION']._serialized_start=795 + _globals['_TYPESUBSCRIPTION']._serialized_end=853 + _globals['_SUBSCRIPTION']._serialized_start=855 + _globals['_SUBSCRIPTION']._serialized_end=939 + _globals['_ADDSUBSCRIPTION']._serialized_start=941 + _globals['_ADDSUBSCRIPTION']._serialized_end=1002 + _globals['_MESSAGE']._serialized_start=1005 + _globals['_MESSAGE']._serialized_end=1245 + _globals['_AGENTRPC']._serialized_start=1247 + _globals['_AGENTRPC']._serialized_end=1310 # @@protoc_insertion_point(module_scope) diff --git a/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.pyi b/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.pyi index b3f4ff5b5..3a7564dfc 100644 --- a/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.pyi +++ b/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.pyi @@ -16,20 +16,41 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor class AgentId(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - NAME_FIELD_NUMBER: builtins.int - NAMESPACE_FIELD_NUMBER: builtins.int - name: builtins.str - namespace: builtins.str + TYPE_FIELD_NUMBER: builtins.int + KEY_FIELD_NUMBER: builtins.int + type: builtins.str + key: builtins.str def __init__( self, *, - name: builtins.str = ..., - namespace: builtins.str = ..., + type: builtins.str = ..., + key: builtins.str = ..., ) -> None: ... - def ClearField(self, field_name: typing.Literal["name", b"name", "namespace", b"namespace"]) -> None: ... + def ClearField(self, field_name: typing.Literal["key", b"key", "type", b"type"]) -> None: ... global___AgentId = AgentId +@typing.final +class Payload(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATA_TYPE_FIELD_NUMBER: builtins.int + DATA_CONTENT_TYPE_FIELD_NUMBER: builtins.int + DATA_FIELD_NUMBER: builtins.int + data_type: builtins.str + data_content_type: builtins.str + data: builtins.bytes + def __init__( + self, + *, + data_type: builtins.str = ..., + data_content_type: builtins.str = ..., + data: builtins.bytes = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["data", b"data", "data_content_type", b"data_content_type", "data_type", b"data_type"]) -> None: ... + +global___Payload = Payload + @typing.final class RpcRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -54,18 +75,17 @@ class RpcRequest(google.protobuf.message.Message): SOURCE_FIELD_NUMBER: builtins.int TARGET_FIELD_NUMBER: builtins.int METHOD_FIELD_NUMBER: builtins.int - DATA_TYPE_FIELD_NUMBER: builtins.int - DATA_FIELD_NUMBER: builtins.int + PAYLOAD_FIELD_NUMBER: builtins.int METADATA_FIELD_NUMBER: builtins.int request_id: builtins.str method: builtins.str - data_type: builtins.str - data: builtins.str @property def source(self) -> global___AgentId: ... @property def target(self) -> global___AgentId: ... @property + def payload(self) -> global___Payload: ... + @property def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ... def __init__( self, @@ -74,12 +94,11 @@ class RpcRequest(google.protobuf.message.Message): source: global___AgentId | None = ..., target: global___AgentId | None = ..., method: builtins.str = ..., - data_type: builtins.str = ..., - data: builtins.str = ..., + payload: global___Payload | None = ..., metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., ) -> None: ... - def HasField(self, field_name: typing.Literal["source", b"source", "target", b"target"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["data", b"data", "data_type", b"data_type", "metadata", b"metadata", "method", b"method", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ... + def HasField(self, field_name: typing.Literal["payload", b"payload", "source", b"source", "target", b"target"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["metadata", b"metadata", "method", b"method", "payload", b"payload", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ... global___RpcRequest = RpcRequest @@ -104,26 +123,25 @@ class RpcResponse(google.protobuf.message.Message): def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... REQUEST_ID_FIELD_NUMBER: builtins.int - RESULT_TYPE_FIELD_NUMBER: builtins.int - RESULT_FIELD_NUMBER: builtins.int + PAYLOAD_FIELD_NUMBER: builtins.int ERROR_FIELD_NUMBER: builtins.int METADATA_FIELD_NUMBER: builtins.int request_id: builtins.str - result_type: builtins.str - result: builtins.str error: builtins.str @property + def payload(self) -> global___Payload: ... + @property def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ... def __init__( self, *, request_id: builtins.str = ..., - result_type: builtins.str = ..., - result: builtins.str = ..., + payload: global___Payload | None = ..., error: builtins.str = ..., metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., ) -> None: ... - def ClearField(self, field_name: typing.Literal["error", b"error", "metadata", b"metadata", "request_id", b"request_id", "result", b"result", "result_type", b"result_type"]) -> None: ... + def HasField(self, field_name: typing.Literal["payload", b"payload"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["error", b"error", "metadata", b"metadata", "payload", b"payload", "request_id", b"request_id"]) -> None: ... global___RpcResponse = RpcResponse @@ -147,30 +165,26 @@ class Event(google.protobuf.message.Message): ) -> None: ... def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... - NAMESPACE_FIELD_NUMBER: builtins.int TOPIC_TYPE_FIELD_NUMBER: builtins.int TOPIC_SOURCE_FIELD_NUMBER: builtins.int - DATA_TYPE_FIELD_NUMBER: builtins.int - DATA_FIELD_NUMBER: builtins.int + PAYLOAD_FIELD_NUMBER: builtins.int METADATA_FIELD_NUMBER: builtins.int - namespace: builtins.str topic_type: builtins.str topic_source: builtins.str - data_type: builtins.str - data: builtins.str + @property + def payload(self) -> global___Payload: ... @property def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ... def __init__( self, *, - namespace: builtins.str = ..., topic_type: builtins.str = ..., topic_source: builtins.str = ..., - data_type: builtins.str = ..., - data: builtins.str = ..., + payload: global___Payload | None = ..., metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., ) -> None: ... - def ClearField(self, field_name: typing.Literal["data", b"data", "data_type", b"data_type", "metadata", b"metadata", "namespace", b"namespace", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ... + def HasField(self, field_name: typing.Literal["payload", b"payload"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["metadata", b"metadata", "payload", b"payload", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ... global___Event = Event diff --git a/python/packages/autogen-core/src/autogen_core/base/__init__.py b/python/packages/autogen-core/src/autogen_core/base/__init__.py index a75807d61..1b2047401 100644 --- a/python/packages/autogen-core/src/autogen_core/base/__init__.py +++ b/python/packages/autogen-core/src/autogen_core/base/__init__.py @@ -14,7 +14,14 @@ from ._base_agent import BaseAgent from ._cancellation_token import CancellationToken from ._message_context import MessageContext from ._message_handler_context import MessageHandlerContext -from ._serialization import MESSAGE_TYPE_REGISTRY, Serialization, TypeDeserializer, TypeSerializer +from ._serialization import ( + JSON_DATA_CONTENT_TYPE, + MESSAGE_TYPE_REGISTRY, + MessageCodec, + Serialization, + UnknownPayload, + try_get_known_codecs_for_type, +) from ._subscription import Subscription from ._subscription_context import SubscriptionInstantiationContext from ._topic import TopicId @@ -30,8 +37,6 @@ __all__ = [ "AgentChildren", "AgentInstantiationContext", "MESSAGE_TYPE_REGISTRY", - "TypeSerializer", - "TypeDeserializer", "TopicId", "Subscription", "MessageContext", @@ -39,4 +44,8 @@ __all__ = [ "AgentType", "SubscriptionInstantiationContext", "MessageHandlerContext", + "JSON_DATA_CONTENT_TYPE", + "MessageCodec", + "try_get_known_codecs_for_type", + "UnknownPayload", ] diff --git a/python/packages/autogen-core/src/autogen_core/base/_agent.py b/python/packages/autogen-core/src/autogen_core/base/_agent.py index 376efa254..192f2d1c0 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_agent.py +++ b/python/packages/autogen-core/src/autogen_core/base/_agent.py @@ -1,4 +1,4 @@ -from typing import Any, Mapping, Protocol, runtime_checkable +from typing import Any, List, Mapping, Protocol, runtime_checkable from ._agent_id import AgentId from ._agent_metadata import AgentMetadata diff --git a/python/packages/autogen-core/src/autogen_core/base/_serialization.py b/python/packages/autogen-core/src/autogen_core/base/_serialization.py index abf538594..cde31ccb0 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_serialization.py +++ b/python/packages/autogen-core/src/autogen_core/base/_serialization.py @@ -1,9 +1,23 @@ import json -from dataclasses import asdict -from typing import Any, ClassVar, Dict, Protocol, TypeVar, cast, runtime_checkable +from dataclasses import asdict, dataclass +from typing import Any, ClassVar, Dict, List, Protocol, TypeVar, cast, runtime_checkable from pydantic import BaseModel +T = TypeVar("T") + + +class MessageCodec(Protocol[T]): + @property + def data_content_type(self) -> str: ... + + @property + def type_name(self) -> str: ... + + def deserialize(self, payload: bytes) -> T: ... + + def serialize(self, message: T) -> bytes: ... + @runtime_checkable class IsDataclass(Protocol): @@ -26,53 +40,62 @@ def has_nested_base_model(cls: type[IsDataclass]) -> bool: return any(issubclass(f.type, BaseModel) for f in cls.__dataclass_fields__.values()) -T = TypeVar("T", covariant=True) - - -class TypeDeserializer(Protocol[T]): - def deserialize(self, message: str) -> T: ... - - -U = TypeVar("U", contravariant=True) - - -class TypeSerializer(Protocol[U]): - def serialize(self, message: U) -> str: ... - - DataclassT = TypeVar("DataclassT", bound=IsDataclass) +JSON_DATA_CONTENT_TYPE = "application/json" -class DataclassTypeDeserializer(TypeDeserializer[DataclassT]): - def __init__(self, cls: type[DataclassT]) -> None: + +class DataclassJsonMessageCodec(MessageCodec[IsDataclass]): + def __init__(self, cls: type[IsDataclass]) -> None: self.cls = cls - def deserialize(self, message: str) -> DataclassT: - return self.cls(**json.loads(message)) + @property + def data_content_type(self) -> str: + return JSON_DATA_CONTENT_TYPE + @property + def type_name(self) -> str: + return _type_name(self.cls) -class DataclassTypeSerializer(TypeSerializer[IsDataclass]): - def serialize(self, message: IsDataclass) -> str: + def deserialize(self, payload: bytes) -> IsDataclass: + message_str = payload.decode("utf-8") + return self.cls(**json.loads(message_str)) + + def serialize(self, message: IsDataclass) -> bytes: if has_nested_dataclass(type(message)) or has_nested_base_model(type(message)): raise ValueError("Dataclass has nested dataclasses or base models, which are not supported") - return json.dumps(asdict(message)) + return json.dumps(asdict(message)).encode("utf-8") PydanticT = TypeVar("PydanticT", bound=BaseModel) -class PydanticTypeDeserializer(TypeDeserializer[PydanticT]): +class PydanticJsonMessageCodec(MessageCodec[PydanticT]): def __init__(self, cls: type[PydanticT]) -> None: self.cls = cls - def deserialize(self, message: str) -> PydanticT: - return self.cls.model_validate_json(message) + @property + def data_content_type(self) -> str: + return JSON_DATA_CONTENT_TYPE + + @property + def type_name(self) -> str: + return _type_name(self.cls) + + def deserialize(self, payload: bytes) -> PydanticT: + message_str = payload.decode("utf-8") + return self.cls.model_validate_json(message_str) + + def serialize(self, message: PydanticT) -> bytes: + return message.model_dump_json().encode("utf-8") -class PydanticTypeSerializer(TypeSerializer[BaseModel]): - def serialize(self, message: BaseModel) -> str: - return message.model_dump_json() +@dataclass +class UnknownPayload: + type_name: str + data_content_type: str + payload: bytes def _type_name(cls: type[Any] | Any) -> str: @@ -85,38 +108,49 @@ def _type_name(cls: type[Any] | Any) -> str: V = TypeVar("V") +def try_get_known_codecs_for_type(cls: type[Any]) -> list[MessageCodec[Any]]: + # TODO: Support protobuf types + codecs: List[MessageCodec[Any]] = [] + if issubclass(cls, BaseModel): + codecs.append(PydanticJsonMessageCodec(cls)) + elif isinstance(cls, IsDataclass): + codecs.append(DataclassJsonMessageCodec(cls)) + + return codecs + + class Serialization: def __init__(self) -> None: - self._deserializers: Dict[str, TypeDeserializer[Any]] = {} - self._serializers: Dict[str, TypeSerializer[Any]] = {} + # type_name, data_content_type -> codec + self._codecs: dict[tuple[str, str], MessageCodec[Any]] = {} - def add_type(self, message_type: type[BaseModel] | type[IsDataclass]) -> None: - if issubclass(message_type, BaseModel): - self.add_type_custom( - _type_name(message_type), PydanticTypeDeserializer(message_type), PydanticTypeSerializer() - ) - elif isinstance(message_type, IsDataclass): - self.add_type_custom( - _type_name(message_type), DataclassTypeDeserializer(message_type), DataclassTypeSerializer() - ) - else: - raise ValueError(f"Unsupported type {message_type}") + def add_codec(self, codec: MessageCodec[Any] | List[MessageCodec[Any]]) -> None: + if isinstance(codec, list): + for c in codec: + self.add_codec(c) + return - def add_type_custom(self, type_name: str, deserializer: TypeDeserializer[V], serializer: TypeSerializer[V]) -> None: - self._deserializers[type_name] = deserializer - self._serializers[type_name] = serializer + self._codecs[(codec.type_name, codec.data_content_type)] = codec - def deserialize(self, message: str, *, type_name: str) -> Any: - return self._deserializers[type_name].deserialize(message) + def deserialize(self, payload: bytes, *, type_name: str, data_content_type: str) -> Any: + codec = self._codecs.get((type_name, data_content_type)) + if codec is None: + return UnknownPayload(type_name, data_content_type, payload) + + return codec.deserialize(payload) + + def serialize(self, message: Any, *, type_name: str, data_content_type: str) -> bytes: + codec = self._codecs.get((type_name, data_content_type)) + if codec is None: + raise ValueError(f"Unknown type {type_name} with content type {data_content_type}") + + return codec.serialize(message) + + def is_registered(self, type_name: str, data_content_type: str) -> bool: + return (type_name, data_content_type) in self._codecs def type_name(self, message: Any) -> str: return _type_name(message) - def serialize(self, message: Any, *, type_name: str) -> str: - return self._serializers[type_name].serialize(message) - - def is_registered(self, type_name: str) -> bool: - return type_name in self._deserializers - MESSAGE_TYPE_REGISTRY = Serialization() diff --git a/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py b/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py index ab8622fcd..401c00176 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py +++ b/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py @@ -8,7 +8,7 @@ from ..base._agent_id import AgentId from ..base._agent_instantiation import AgentInstantiationContext from ..base._agent_metadata import AgentMetadata from ..base._agent_runtime import AgentRuntime -from ..base._serialization import MESSAGE_TYPE_REGISTRY +from ..base._serialization import JSON_DATA_CONTENT_TYPE, MESSAGE_TYPE_REGISTRY, try_get_known_codecs_for_type from ..base.exceptions import CantHandleException from ._type_helpers import get_types @@ -59,9 +59,10 @@ class ClosureAgent(Agent): self._id: AgentId = id self._description = description subscription_types = get_subscriptions_from_closure(closure) + # TODO fold this into runtime for message_type in subscription_types: - if not MESSAGE_TYPE_REGISTRY.is_registered(MESSAGE_TYPE_REGISTRY.type_name(message_type)): - MESSAGE_TYPE_REGISTRY.add_type(message_type) + MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(message_type)) + self._subscriptions = [MESSAGE_TYPE_REGISTRY.type_name(message_type) for message_type in subscription_types] self._closure = closure diff --git a/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py b/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py index fb16b11c1..ceb058fe6 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py +++ b/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py @@ -17,6 +17,8 @@ from typing import ( runtime_checkable, ) +from autogen_core.base import try_get_known_codecs_for_type + from ..base import MESSAGE_TYPE_REGISTRY, BaseAgent, MessageContext from ..base.exceptions import CantHandleException from ._type_helpers import AnyType, get_types @@ -144,8 +146,8 @@ class RoutedAgent(BaseAgent): self._handlers[target_type] = message_handler for message_type in self._handlers.keys(): - if not MESSAGE_TYPE_REGISTRY.is_registered(MESSAGE_TYPE_REGISTRY.type_name(message_type)): - MESSAGE_TYPE_REGISTRY.add_type(message_type) + for codec in try_get_known_codecs_for_type(message_type): + MESSAGE_TYPE_REGISTRY.add_codec(codec) super().__init__(description) diff --git a/python/packages/autogen-core/tests/test_serialization.py b/python/packages/autogen-core/tests/test_serialization.py index e63905b68..bcfca802c 100644 --- a/python/packages/autogen-core/tests/test_serialization.py +++ b/python/packages/autogen-core/tests/test_serialization.py @@ -1,11 +1,10 @@ -#custom type - from pydantic import BaseModel from dataclasses import dataclass import pytest from autogen_core.base import Serialization +from autogen_core.base import JSON_DATA_CONTENT_TYPE, MessageCodec, try_get_known_codecs_for_type class PydanticMessage(BaseModel): message: str @@ -30,77 +29,86 @@ class NestingPydanticDataclassMessage: def test_pydantic() -> None: serde = Serialization() - serde.add_type(PydanticMessage) + serde.add_codec(try_get_known_codecs_for_type(PydanticMessage)) message = PydanticMessage(message="hello") name = serde.type_name(message) - json = serde.serialize(message, type_name=name) + json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) assert name == "PydanticMessage" - assert json == '{"message":"hello"}' - deserialized = serde.deserialize(json, type_name=name) + assert json == b'{"message":"hello"}' + deserialized = serde.deserialize(json, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) assert deserialized == message def test_nested_pydantic() -> None: serde = Serialization() - serde.add_type(NestingPydanticMessage) + serde.add_codec(try_get_known_codecs_for_type(NestingPydanticMessage)) message = NestingPydanticMessage(message="hello", nested=PydanticMessage(message="world")) name = serde.type_name(message) - json = serde.serialize(message, type_name=name) - assert json == '{"message":"hello","nested":{"message":"world"}}' - deserialized = serde.deserialize(json, type_name=name) + json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) + assert json == b'{"message":"hello","nested":{"message":"world"}}' + deserialized = serde.deserialize(json, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) assert deserialized == message def test_dataclass() -> None: serde = Serialization() - serde.add_type(DataclassMessage) + serde.add_codec(try_get_known_codecs_for_type(DataclassMessage)) message = DataclassMessage(message="hello") name = serde.type_name(message) - json = serde.serialize(message, type_name=name) - assert json == '{"message": "hello"}' - deserialized = serde.deserialize(json, type_name=name) + json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) + assert json == b'{"message": "hello"}' + deserialized = serde.deserialize(json, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) assert deserialized == message def test_nesting_dataclass_dataclass() -> None: serde = Serialization() - serde.add_type(NestingDataclassMessage) + serde.add_codec(try_get_known_codecs_for_type(NestingDataclassMessage)) message = NestingDataclassMessage(message="hello", nested=DataclassMessage(message="world")) name = serde.type_name(message) with pytest.raises(ValueError): - _json = serde.serialize(message, type_name=name) + _json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) def test_nesting_dataclass_pydantic() -> None: serde = Serialization() - serde.add_type(NestingPydanticDataclassMessage) + serde.add_codec(try_get_known_codecs_for_type(NestingPydanticDataclassMessage)) message = NestingPydanticDataclassMessage(message="hello", nested=PydanticMessage(message="world")) name = serde.type_name(message) with pytest.raises(ValueError): - _json = serde.serialize(message, type_name=name) + _json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) def test_invalid_type() -> None: serde = Serialization() try: - serde.add_type(str) # type: ignore + serde.add_codec(try_get_known_codecs_for_type(str)) except ValueError as e: assert str(e) == "Unsupported type " def test_custom_type() -> None: serde = Serialization() - class CustomStringTypeDeserializer: - def deserialize(self, message: str) -> str: + class CustomStringTypeCodec(MessageCodec[str]): + @property + def data_content_type(self) -> str: + return "str" + + @property + def type_name(self) -> str: + return "custom_str" + + def deserialize(self, payload: bytes) -> str: + message = payload.decode("utf-8") return message[1:-1] - class CustomStringTypeSerializer: - def serialize(self, message: str) -> str: - return f'"{message}"' + def serialize(self, message: str) -> bytes: + return f'"{message}"'.encode("utf-8") - serde.add_type_custom("custom_str", CustomStringTypeDeserializer(), CustomStringTypeSerializer()) + + serde.add_codec(CustomStringTypeCodec()) message = "hello" - json = serde.serialize(message, type_name="custom_str") - assert json == '"hello"' - deserialized = serde.deserialize(json, type_name="custom_str") - assert deserialized == message \ No newline at end of file + json = serde.serialize(message, type_name="custom_str", data_content_type="str") + assert json == b'"hello"' + deserialized = serde.deserialize(json, type_name="custom_str", data_content_type="str") + assert deserialized == message