mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-27 06:59:03 +00:00
Support datacontenttype and lay groundwork for unknown payloads (#444)
* Support datacontenttype and lay groundwork for unknown payloads * Update dotnet based on proto changes
This commit is contained in:
parent
f941fe15a6
commit
8504ea0bf2
@ -1,4 +1,5 @@
|
||||
using Agents;
|
||||
using Google.Protobuf;
|
||||
using Greeter.AgentWorker;
|
||||
using Microsoft.AI.Agents.Worker.Client;
|
||||
using AgentId = Microsoft.AI.Agents.Worker.Client.AgentId;
|
||||
@ -30,7 +31,11 @@ internal sealed class GreetingAgent(IAgentContext context, ILogger<GreetingAgent
|
||||
protected override Task<RpcResponse> HandleRequest(RpcRequest request)
|
||||
{
|
||||
logger.LogInformation("[{Id}] Received request: '{Request}'.", AgentId, request);
|
||||
return Task.FromResult(new RpcResponse() { Result = "Okay!" });
|
||||
return Task.FromResult(new RpcResponse() { Payload = new Payload {
|
||||
DataContentType = "text/plain",
|
||||
Data = ByteString.CopyFromUtf8("Hello, agents!"),
|
||||
DataType = "text"
|
||||
}});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -3,6 +3,8 @@ using System.Threading.Channels;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using System.Text.Json;
|
||||
using System.Diagnostics;
|
||||
using System.Text;
|
||||
using Google.Protobuf;
|
||||
|
||||
namespace Microsoft.AI.Agents.Worker.Client;
|
||||
|
||||
@ -80,12 +82,12 @@ public abstract class AgentBase
|
||||
{
|
||||
case Message.MessageOneofCase.Event:
|
||||
{
|
||||
var activity = ExtractActivity(msg.Event.DataType, msg.Event.Metadata);
|
||||
var activity = ExtractActivity(msg.Event.Payload.DataType, msg.Event.Metadata);
|
||||
await InvokeWithActivityAsync(
|
||||
static ((AgentBase Agent, Event Item) state) => state.Agent.HandleEvent(state.Item),
|
||||
(this, msg.Event),
|
||||
activity,
|
||||
msg.Event.DataType).ConfigureAwait(false);
|
||||
msg.Event.Payload.DataType).ConfigureAwait(false);
|
||||
}
|
||||
break;
|
||||
case Message.MessageOneofCase.Request:
|
||||
@ -143,7 +145,12 @@ public abstract class AgentBase
|
||||
Target = target,
|
||||
RequestId = requestId,
|
||||
Method = method,
|
||||
Data = JsonSerializer.Serialize(parameters)
|
||||
Payload = new Payload{
|
||||
DataType = "application/json",
|
||||
Data = ByteString.CopyFrom(JsonSerializer.Serialize(parameters), Encoding.UTF8),
|
||||
DataContentType = "application/json"
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
var activity = s_source.StartActivity($"Call '{method}'", ActivityKind.Client, Activity.Current?.Context ?? default);
|
||||
@ -175,8 +182,8 @@ public abstract class AgentBase
|
||||
|
||||
protected async ValueTask PublishEvent(Event item)
|
||||
{
|
||||
var activity = s_source.StartActivity($"PublishEvent '{item.DataType}'", ActivityKind.Client, Activity.Current?.Context ?? default);
|
||||
activity?.SetTag("peer.service", $"{item.DataType}/{item.Namespace}");
|
||||
var activity = s_source.StartActivity($"PublishEvent '{item.Payload.DataType}'", ActivityKind.Client, Activity.Current?.Context ?? default);
|
||||
activity?.SetTag("peer.service", $"{item.Payload.DataType}/{item.TopicSource}");
|
||||
|
||||
var completion = new TaskCompletionSource<RpcResponse>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
Context.DistributedContextPropagator.Inject(activity, item.Metadata, static (carrier, key, value) => ((IDictionary<string, string>)carrier!)[key] = value);
|
||||
@ -187,7 +194,7 @@ public abstract class AgentBase
|
||||
},
|
||||
(this, item, completion),
|
||||
activity,
|
||||
item.DataType).ConfigureAwait(false);
|
||||
item.Payload.DataType).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
protected virtual Task<RpcResponse> HandleRequest(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" });
|
||||
@ -224,7 +231,7 @@ public abstract class AgentBase
|
||||
activity.SetTag("exception.type", e.GetType().FullName);
|
||||
activity.SetTag("exception.message", e.Message);
|
||||
|
||||
// Note that "exception.stacktrace" is the full exception detail, not just the StackTrace property.
|
||||
// Note that "exception.stacktrace" is the full exception detail, not just the StackTrace property.
|
||||
// See https://opentelemetry.io/docs/specs/semconv/attributes-registry/exception/
|
||||
// and https://github.com/open-telemetry/opentelemetry-specification/pull/697#discussion_r453662519
|
||||
activity.SetTag("exception.stacktrace", e.ToString());
|
||||
|
||||
@ -2,14 +2,14 @@ using RpcAgentId = Agents.AgentId;
|
||||
|
||||
namespace Microsoft.AI.Agents.Worker.Client;
|
||||
|
||||
public sealed record class AgentId(string Name, string Namespace)
|
||||
public sealed record class AgentId(string Type, string Key)
|
||||
{
|
||||
public static implicit operator RpcAgentId(AgentId agentId) => new()
|
||||
{
|
||||
Name = agentId.Name,
|
||||
Namespace = agentId.Namespace
|
||||
Type = agentId.Type,
|
||||
Key = agentId.Key
|
||||
};
|
||||
|
||||
public static implicit operator AgentId(RpcAgentId agentId) => new(agentId.Name, agentId.Namespace);
|
||||
public override string ToString() => $"{Name}/{Namespace}";
|
||||
public static implicit operator AgentId(RpcAgentId agentId) => new(agentId.Type, agentId.Key);
|
||||
public override string ToString() => $"{Type}/{Key}";
|
||||
}
|
||||
|
||||
@ -143,7 +143,7 @@ public sealed class AgentWorkerRuntime : IHostedService, IDisposable
|
||||
var item = message.Event;
|
||||
foreach (var (typeName, _) in _agentTypes)
|
||||
{
|
||||
var agent = GetOrActivateAgent(new AgentId(typeName, item.Namespace));
|
||||
var agent = GetOrActivateAgent(new AgentId(typeName, item.TopicSource));
|
||||
agent.ReceiveMessage(message);
|
||||
}
|
||||
|
||||
@ -200,17 +200,17 @@ public sealed class AgentWorkerRuntime : IHostedService, IDisposable
|
||||
|
||||
private AgentBase GetOrActivateAgent(AgentId agentId)
|
||||
{
|
||||
if (!_agents.TryGetValue((agentId.Name, agentId.Namespace), out var agent))
|
||||
if (!_agents.TryGetValue((agentId.Type, agentId.Key), out var agent))
|
||||
{
|
||||
if (_agentTypes.TryGetValue(agentId.Name, out var agentType))
|
||||
if (_agentTypes.TryGetValue(agentId.Type, out var agentType))
|
||||
{
|
||||
var context = new AgentContext(agentId, this, _serviceProvider.GetRequiredService<ILogger<AgentBase>>(), _distributedContextPropagator);
|
||||
agent = (AgentBase)ActivatorUtilities.CreateInstance(_serviceProvider, agentType, context);
|
||||
_agents.TryAdd((agentId.Name, agentId.Namespace), agent);
|
||||
_agents.TryAdd((agentId.Type, agentId.Key), agent);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException($"Agent type '{agentId.Name}' is unknown.");
|
||||
throw new InvalidOperationException($"Agent type '{agentId.Type}' is unknown.");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -118,13 +118,13 @@ public sealed class AgentWorkerRegistryGrain : Grain, IAgentWorkerRegistryGrain
|
||||
public ValueTask<(IWorkerGateway? Gateway, bool NewPlacment)> GetOrPlaceAgent(AgentId agentId)
|
||||
{
|
||||
bool isNewPlacement;
|
||||
if (!_agentDirectory.TryGetValue((agentId.Name, agentId.Namespace), out var worker) || !_workerStates.ContainsKey(worker))
|
||||
if (!_agentDirectory.TryGetValue((agentId.Type, agentId.Key), out var worker) || !_workerStates.ContainsKey(worker))
|
||||
{
|
||||
worker = GetCompatibleWorkerCore(agentId.Name);
|
||||
worker = GetCompatibleWorkerCore(agentId.Type);
|
||||
if (worker is not null)
|
||||
{
|
||||
// New activation.
|
||||
_agentDirectory[(agentId.Name, agentId.Namespace)] = worker;
|
||||
_agentDirectory[(agentId.Type, agentId.Key)] = worker;
|
||||
isNewPlacement = true;
|
||||
}
|
||||
else
|
||||
|
||||
@ -49,11 +49,11 @@ internal sealed class WorkerGateway : BackgroundService, IWorkerGateway
|
||||
|
||||
public async ValueTask<RpcResponse> InvokeRequest(RpcRequest request)
|
||||
{
|
||||
(string Type, string Key) agentId = (request.Target.Name, request.Target.Namespace);
|
||||
(string Type, string Key) agentId = (request.Target.Type, request.Target.Key);
|
||||
if (!_agentDirectory.TryGetValue(agentId, out var connection) || connection.Completion.IsCompleted)
|
||||
{
|
||||
// Activate the agent on a compatible worker process.
|
||||
if (_supportedAgentTypes.TryGetValue(request.Target.Name, out var workers))
|
||||
if (_supportedAgentTypes.TryGetValue(request.Target.Type, out var workers))
|
||||
{
|
||||
connection = workers[Random.Shared.Next(workers.Count)];
|
||||
_agentDirectory[agentId] = connection;
|
||||
@ -163,7 +163,7 @@ internal sealed class WorkerGateway : BackgroundService, IWorkerGateway
|
||||
}
|
||||
|
||||
/*
|
||||
if (string.Equals("runtime", request.Target.Name, StringComparison.Ordinal))
|
||||
if (string.Equals("runtime", request.Target.Type, StringComparison.Ordinal))
|
||||
{
|
||||
if (string.Equals("subscribe", request.Method))
|
||||
{
|
||||
@ -184,7 +184,7 @@ internal sealed class WorkerGateway : BackgroundService, IWorkerGateway
|
||||
var (gateway, isPlacement) = await _gatewayRegistry.GetOrPlaceAgent(request.Target);
|
||||
if (gateway is null)
|
||||
{
|
||||
return new RpcResponse { Error = "Agent not found and no compatible gateways were found." };
|
||||
return new RpcResponse { Error = "Agent not found and no compatible gateways were found." };
|
||||
}
|
||||
|
||||
if (isPlacement)
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
using System.Text.Json;
|
||||
using Event = Microsoft.AI.Agents.Abstractions.Event;
|
||||
using RpcEvent = Agents.Event;
|
||||
using Payload = Agents.Payload;
|
||||
using Google.Protobuf;
|
||||
using System.Text;
|
||||
|
||||
namespace Microsoft.AI.Agents.Worker;
|
||||
|
||||
@ -10,13 +13,19 @@ public static class RpcEventExtensions
|
||||
{
|
||||
var result = new RpcEvent
|
||||
{
|
||||
Namespace = input.Namespace,
|
||||
DataType = input.Type,
|
||||
TopicSource = input.Namespace,
|
||||
// TODO: Is this the right way to handle topics?
|
||||
TopicType = input.Subject
|
||||
};
|
||||
|
||||
if (input.Data is not null)
|
||||
{
|
||||
result.Data = JsonSerializer.Serialize(input.Data);
|
||||
result.Payload = new Payload
|
||||
{
|
||||
Data = ByteString.CopyFrom(JsonSerializer.Serialize(input.Data), Encoding.UTF8),
|
||||
DataContentType = "application/json",
|
||||
DataType = input.Type
|
||||
};
|
||||
}
|
||||
|
||||
return result;
|
||||
@ -26,15 +35,20 @@ public static class RpcEventExtensions
|
||||
{
|
||||
var result = new Event
|
||||
{
|
||||
Type = input.DataType,
|
||||
Subject = input.Namespace,
|
||||
Namespace = input.Namespace,
|
||||
Type = input.Payload.DataType,
|
||||
Subject = input.TopicType,
|
||||
Namespace = input.TopicSource,
|
||||
Data = []
|
||||
};
|
||||
|
||||
if (input.Data is not null)
|
||||
if (input.Payload is not null)
|
||||
{
|
||||
result.Data = JsonSerializer.Deserialize<Dictionary<string, string>>(input.Data)!;
|
||||
if (input.Payload.DataContentType != "application/json")
|
||||
{
|
||||
throw new InvalidOperationException("Only application/json content type is supported");
|
||||
}
|
||||
|
||||
result.Data = JsonSerializer.Deserialize<Dictionary<string, string>>(input.Payload.Data.ToString(Encoding.UTF8))!;
|
||||
}
|
||||
|
||||
return result;
|
||||
|
||||
@ -3,8 +3,14 @@ syntax = "proto3";
|
||||
package agents;
|
||||
|
||||
message AgentId {
|
||||
string name = 1;
|
||||
string namespace = 2;
|
||||
string type = 1;
|
||||
string key = 2;
|
||||
}
|
||||
|
||||
message Payload {
|
||||
string data_type = 1;
|
||||
string data_content_type = 2;
|
||||
bytes data = 3;
|
||||
}
|
||||
|
||||
message RpcRequest {
|
||||
@ -12,26 +18,22 @@ message RpcRequest {
|
||||
AgentId source = 2;
|
||||
AgentId target = 3;
|
||||
string method = 4;
|
||||
string data_type = 5;
|
||||
string data = 6;
|
||||
map<string, string> metadata = 7;
|
||||
Payload payload = 5;
|
||||
map<string, string> metadata = 6;
|
||||
}
|
||||
|
||||
message RpcResponse {
|
||||
string request_id = 1;
|
||||
string result_type = 2;
|
||||
string result = 3;
|
||||
string error = 4;
|
||||
map<string, string> metadata = 5;
|
||||
Payload payload = 2;
|
||||
string error = 3;
|
||||
map<string, string> metadata = 4;
|
||||
}
|
||||
|
||||
message Event {
|
||||
string namespace = 1;
|
||||
string topic_type = 2;
|
||||
string topic_source = 3;
|
||||
string data_type = 4;
|
||||
string data = 5;
|
||||
map<string, string> metadata = 6;
|
||||
string topic_type = 1;
|
||||
string topic_source = 2;
|
||||
Payload payload = 3;
|
||||
map<string, string> metadata = 4;
|
||||
}
|
||||
|
||||
message RegisterAgentType {
|
||||
|
||||
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_core.base import MESSAGE_TYPE_REGISTRY, MessageContext
|
||||
from autogen_core.base import MESSAGE_TYPE_REGISTRY, MessageContext, try_get_known_codecs_for_type
|
||||
from autogen_core.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler
|
||||
|
||||
|
||||
@ -67,11 +67,11 @@ class GreeterAgent(RoutedAgent):
|
||||
|
||||
async def main() -> None:
|
||||
runtime = WorkerAgentRuntime()
|
||||
MESSAGE_TYPE_REGISTRY.add_type(Greeting)
|
||||
MESSAGE_TYPE_REGISTRY.add_type(AskToGreet)
|
||||
MESSAGE_TYPE_REGISTRY.add_type(Feedback)
|
||||
MESSAGE_TYPE_REGISTRY.add_type(ReturnedGreeting)
|
||||
MESSAGE_TYPE_REGISTRY.add_type(ReturnedFeedback)
|
||||
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(Greeting))
|
||||
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(AskToGreet))
|
||||
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(Feedback))
|
||||
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(ReturnedGreeting))
|
||||
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(ReturnedFeedback))
|
||||
await runtime.start(host_connection_string="localhost:50051")
|
||||
|
||||
await runtime.register("receiver", ReceiveAgent, lambda: [DefaultSubscription()])
|
||||
|
||||
@ -4,7 +4,13 @@ from dataclasses import dataclass
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_core.base import MESSAGE_TYPE_REGISTRY, AgentId, AgentInstantiationContext, MessageContext
|
||||
from autogen_core.base import (
|
||||
MESSAGE_TYPE_REGISTRY,
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
MessageContext,
|
||||
try_get_known_codecs_for_type,
|
||||
)
|
||||
from autogen_core.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler
|
||||
|
||||
|
||||
@ -55,9 +61,9 @@ class GreeterAgent(RoutedAgent):
|
||||
|
||||
async def main() -> None:
|
||||
runtime = WorkerAgentRuntime()
|
||||
MESSAGE_TYPE_REGISTRY.add_type(Greeting)
|
||||
MESSAGE_TYPE_REGISTRY.add_type(AskToGreet)
|
||||
MESSAGE_TYPE_REGISTRY.add_type(Feedback)
|
||||
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(Greeting))
|
||||
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(AskToGreet))
|
||||
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(Feedback))
|
||||
await runtime.start(host_connection_string="localhost:50051")
|
||||
|
||||
await runtime.register("receiver", lambda: ReceiveAgent(), lambda: [DefaultSubscription()])
|
||||
|
||||
@ -125,9 +125,9 @@ class HostRuntimeServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
||||
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None:
|
||||
# Deliver the message to a client given the target agent type.
|
||||
async with self._agent_type_to_client_id_lock:
|
||||
target_client_id = self._agent_type_to_client_id.get(request.target.name)
|
||||
target_client_id = self._agent_type_to_client_id.get(request.target.type)
|
||||
if target_client_id is None:
|
||||
logger.error(f"Agent {request.target.name} not found, failed to deliver message.")
|
||||
logger.error(f"Agent {request.target.type} not found, failed to deliver message.")
|
||||
return
|
||||
target_send_queue = self._send_queues.get(target_client_id)
|
||||
if target_send_queue is None:
|
||||
|
||||
@ -28,6 +28,8 @@ import grpc
|
||||
from grpc.aio import StreamStreamCall
|
||||
from typing_extensions import Self
|
||||
|
||||
from autogen_core.base import JSON_DATA_CONTENT_TYPE
|
||||
|
||||
from ..base import (
|
||||
MESSAGE_TYPE_REGISTRY,
|
||||
Agent,
|
||||
@ -247,14 +249,19 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
self._pending_requests[request_id_str] = future
|
||||
sender = cast(AgentId, sender)
|
||||
data_type = MESSAGE_TYPE_REGISTRY.type_name(message)
|
||||
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=data_type)
|
||||
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(
|
||||
message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE
|
||||
)
|
||||
runtime_message = agent_worker_pb2.Message(
|
||||
request=agent_worker_pb2.RpcRequest(
|
||||
request_id=request_id_str,
|
||||
target=agent_worker_pb2.AgentId(name=recipient.type, namespace=recipient.key),
|
||||
source=agent_worker_pb2.AgentId(name=sender.type, namespace=sender.key),
|
||||
data_type=data_type,
|
||||
data=serialized_message,
|
||||
target=agent_worker_pb2.AgentId(type=recipient.type, key=recipient.key),
|
||||
source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key),
|
||||
payload=agent_worker_pb2.Payload(
|
||||
data_type=data_type,
|
||||
data=serialized_message,
|
||||
data_content_type=JSON_DATA_CONTENT_TYPE,
|
||||
),
|
||||
)
|
||||
)
|
||||
# TODO: Find a way to handle timeouts/errors
|
||||
@ -277,10 +284,18 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
if self._host_connection is None:
|
||||
raise RuntimeError("Host connection is not set.")
|
||||
message_type = MESSAGE_TYPE_REGISTRY.type_name(message)
|
||||
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=message_type)
|
||||
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(
|
||||
message, type_name=message_type, data_content_type=JSON_DATA_CONTENT_TYPE
|
||||
)
|
||||
runtime_message = agent_worker_pb2.Message(
|
||||
event=agent_worker_pb2.Event(
|
||||
topic_type=topic_id.type, topic_source=topic_id.source, data_type=message_type, data=serialized_message
|
||||
topic_type=topic_id.type,
|
||||
topic_source=topic_id.source,
|
||||
payload=agent_worker_pb2.Payload(
|
||||
data_type=message_type,
|
||||
data=serialized_message,
|
||||
data_content_type=JSON_DATA_CONTENT_TYPE,
|
||||
),
|
||||
)
|
||||
)
|
||||
task = asyncio.create_task(self._host_connection.send(runtime_message))
|
||||
@ -305,13 +320,17 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
|
||||
async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
|
||||
assert self._host_connection is not None
|
||||
target = AgentId(request.target.name, request.target.namespace)
|
||||
source = AgentId(request.source.name, request.source.namespace)
|
||||
target = AgentId(request.target.type, request.target.key)
|
||||
source = AgentId(request.source.type, request.source.key)
|
||||
|
||||
logging.info(f"Processing request from {source} to {target}")
|
||||
|
||||
# Deserialize the message.
|
||||
message = MESSAGE_TYPE_REGISTRY.deserialize(request.data, type_name=request.data_type)
|
||||
message = MESSAGE_TYPE_REGISTRY.deserialize(
|
||||
request.payload.data,
|
||||
type_name=request.payload.data_type,
|
||||
data_content_type=request.payload.data_content_type,
|
||||
)
|
||||
|
||||
# Get the target agent and prepare the message context.
|
||||
target_agent = await self._get_agent(target)
|
||||
@ -339,14 +358,19 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
|
||||
# Serialize the result.
|
||||
result_type = MESSAGE_TYPE_REGISTRY.type_name(result)
|
||||
serialized_result = MESSAGE_TYPE_REGISTRY.serialize(result, type_name=result_type)
|
||||
serialized_result = MESSAGE_TYPE_REGISTRY.serialize(
|
||||
result, type_name=result_type, data_content_type=JSON_DATA_CONTENT_TYPE
|
||||
)
|
||||
|
||||
# Create the response message.
|
||||
response_message = agent_worker_pb2.Message(
|
||||
response=agent_worker_pb2.RpcResponse(
|
||||
request_id=request.request_id,
|
||||
result_type=result_type,
|
||||
result=serialized_result,
|
||||
payload=agent_worker_pb2.Payload(
|
||||
data_type=result_type,
|
||||
data=serialized_result,
|
||||
data_content_type=JSON_DATA_CONTENT_TYPE,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@ -355,7 +379,11 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
|
||||
async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> None:
|
||||
# Deserialize the result.
|
||||
result = MESSAGE_TYPE_REGISTRY.deserialize(response.result, type_name=response.result_type)
|
||||
result = MESSAGE_TYPE_REGISTRY.deserialize(
|
||||
response.payload.data,
|
||||
type_name=response.payload.data_type,
|
||||
data_content_type=response.payload.data_content_type,
|
||||
)
|
||||
# Get the future and set the result.
|
||||
future = self._pending_requests.pop(response.request_id)
|
||||
if len(response.error) > 0:
|
||||
@ -364,7 +392,9 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
future.set_result(result)
|
||||
|
||||
async def _process_event(self, event: agent_worker_pb2.Event) -> None:
|
||||
message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.data_type)
|
||||
message = MESSAGE_TYPE_REGISTRY.deserialize(
|
||||
event.payload.data, type_name=event.payload.data_type, data_content_type=event.payload.data_content_type
|
||||
)
|
||||
topic_id = TopicId(event.topic_type, event.topic_source)
|
||||
# Get the recipients for the topic.
|
||||
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
|
||||
|
||||
@ -24,7 +24,7 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"*\n\x07\x41gentId\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t\"\xf8\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12\x11\n\tdata_type\x18\x05 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\t\x12\x32\n\x08metadata\x18\x07 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xbb\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x13\n\x0bresult_type\x18\x02 \x01(\t\x12\x0e\n\x06result\x18\x03 \x01(\t\x12\r\n\x05\x65rror\x18\x04 \x01(\t\x12\x33\n\x08metadata\x18\x05 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc5\x01\n\x05\x45vent\x12\x11\n\tnamespace\x18\x01 \x01(\t\x12\x12\n\ntopic_type\x18\x02 \x01(\t\x12\x14\n\x0ctopic_source\x18\x03 \x01(\t\x12\x11\n\tdata_type\x18\x04 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\t\x12-\n\x08metadata\x18\x06 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"T\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x42\x0e\n\x0csubscription\"=\n\x0f\x41\x64\x64Subscription\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\xf0\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x12\x32\n\x0f\x61\x64\x64Subscription\x18\x05 \x01(\x0b\x32\x17.agents.AddSubscriptionH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\xf9\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xb3\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12 \n\x07payload\x18\x03 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x04 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"T\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x42\x0e\n\x0csubscription\"=\n\x0f\x41\x64\x64Subscription\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\xf0\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x12\x32\n\x0f\x61\x64\x64Subscription\x18\x05 \x01(\x0b\x32\x17.agents.AddSubscriptionH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
@ -38,29 +38,31 @@ if not _descriptor._USE_C_DESCRIPTORS:
|
||||
_globals['_EVENT_METADATAENTRY']._loaded_options = None
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_options = b'8\001'
|
||||
_globals['_AGENTID']._serialized_start=30
|
||||
_globals['_AGENTID']._serialized_end=72
|
||||
_globals['_RPCREQUEST']._serialized_start=75
|
||||
_globals['_RPCREQUEST']._serialized_end=323
|
||||
_globals['_RPCREQUEST_METADATAENTRY']._serialized_start=276
|
||||
_globals['_RPCREQUEST_METADATAENTRY']._serialized_end=323
|
||||
_globals['_RPCRESPONSE']._serialized_start=326
|
||||
_globals['_RPCRESPONSE']._serialized_end=513
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=276
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=323
|
||||
_globals['_EVENT']._serialized_start=516
|
||||
_globals['_EVENT']._serialized_end=713
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_start=276
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_end=323
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_start=715
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_end=748
|
||||
_globals['_TYPESUBSCRIPTION']._serialized_start=750
|
||||
_globals['_TYPESUBSCRIPTION']._serialized_end=808
|
||||
_globals['_SUBSCRIPTION']._serialized_start=810
|
||||
_globals['_SUBSCRIPTION']._serialized_end=894
|
||||
_globals['_ADDSUBSCRIPTION']._serialized_start=896
|
||||
_globals['_ADDSUBSCRIPTION']._serialized_end=957
|
||||
_globals['_MESSAGE']._serialized_start=960
|
||||
_globals['_MESSAGE']._serialized_end=1200
|
||||
_globals['_AGENTRPC']._serialized_start=1202
|
||||
_globals['_AGENTRPC']._serialized_end=1265
|
||||
_globals['_AGENTID']._serialized_end=66
|
||||
_globals['_PAYLOAD']._serialized_start=68
|
||||
_globals['_PAYLOAD']._serialized_end=137
|
||||
_globals['_RPCREQUEST']._serialized_start=140
|
||||
_globals['_RPCREQUEST']._serialized_end=389
|
||||
_globals['_RPCREQUEST_METADATAENTRY']._serialized_start=342
|
||||
_globals['_RPCREQUEST_METADATAENTRY']._serialized_end=389
|
||||
_globals['_RPCRESPONSE']._serialized_start=392
|
||||
_globals['_RPCRESPONSE']._serialized_end=576
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=342
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=389
|
||||
_globals['_EVENT']._serialized_start=579
|
||||
_globals['_EVENT']._serialized_end=758
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_start=342
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_end=389
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_start=760
|
||||
_globals['_REGISTERAGENTTYPE']._serialized_end=793
|
||||
_globals['_TYPESUBSCRIPTION']._serialized_start=795
|
||||
_globals['_TYPESUBSCRIPTION']._serialized_end=853
|
||||
_globals['_SUBSCRIPTION']._serialized_start=855
|
||||
_globals['_SUBSCRIPTION']._serialized_end=939
|
||||
_globals['_ADDSUBSCRIPTION']._serialized_start=941
|
||||
_globals['_ADDSUBSCRIPTION']._serialized_end=1002
|
||||
_globals['_MESSAGE']._serialized_start=1005
|
||||
_globals['_MESSAGE']._serialized_end=1245
|
||||
_globals['_AGENTRPC']._serialized_start=1247
|
||||
_globals['_AGENTRPC']._serialized_end=1310
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@ -16,20 +16,41 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
class AgentId(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
NAME_FIELD_NUMBER: builtins.int
|
||||
NAMESPACE_FIELD_NUMBER: builtins.int
|
||||
name: builtins.str
|
||||
namespace: builtins.str
|
||||
TYPE_FIELD_NUMBER: builtins.int
|
||||
KEY_FIELD_NUMBER: builtins.int
|
||||
type: builtins.str
|
||||
key: builtins.str
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: builtins.str = ...,
|
||||
namespace: builtins.str = ...,
|
||||
type: builtins.str = ...,
|
||||
key: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["name", b"name", "namespace", b"namespace"]) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["key", b"key", "type", b"type"]) -> None: ...
|
||||
|
||||
global___AgentId = AgentId
|
||||
|
||||
@typing.final
|
||||
class Payload(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
DATA_TYPE_FIELD_NUMBER: builtins.int
|
||||
DATA_CONTENT_TYPE_FIELD_NUMBER: builtins.int
|
||||
DATA_FIELD_NUMBER: builtins.int
|
||||
data_type: builtins.str
|
||||
data_content_type: builtins.str
|
||||
data: builtins.bytes
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
data_type: builtins.str = ...,
|
||||
data_content_type: builtins.str = ...,
|
||||
data: builtins.bytes = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["data", b"data", "data_content_type", b"data_content_type", "data_type", b"data_type"]) -> None: ...
|
||||
|
||||
global___Payload = Payload
|
||||
|
||||
@typing.final
|
||||
class RpcRequest(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
@ -54,18 +75,17 @@ class RpcRequest(google.protobuf.message.Message):
|
||||
SOURCE_FIELD_NUMBER: builtins.int
|
||||
TARGET_FIELD_NUMBER: builtins.int
|
||||
METHOD_FIELD_NUMBER: builtins.int
|
||||
DATA_TYPE_FIELD_NUMBER: builtins.int
|
||||
DATA_FIELD_NUMBER: builtins.int
|
||||
PAYLOAD_FIELD_NUMBER: builtins.int
|
||||
METADATA_FIELD_NUMBER: builtins.int
|
||||
request_id: builtins.str
|
||||
method: builtins.str
|
||||
data_type: builtins.str
|
||||
data: builtins.str
|
||||
@property
|
||||
def source(self) -> global___AgentId: ...
|
||||
@property
|
||||
def target(self) -> global___AgentId: ...
|
||||
@property
|
||||
def payload(self) -> global___Payload: ...
|
||||
@property
|
||||
def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ...
|
||||
def __init__(
|
||||
self,
|
||||
@ -74,12 +94,11 @@ class RpcRequest(google.protobuf.message.Message):
|
||||
source: global___AgentId | None = ...,
|
||||
target: global___AgentId | None = ...,
|
||||
method: builtins.str = ...,
|
||||
data_type: builtins.str = ...,
|
||||
data: builtins.str = ...,
|
||||
payload: global___Payload | None = ...,
|
||||
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["source", b"source", "target", b"target"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["data", b"data", "data_type", b"data_type", "metadata", b"metadata", "method", b"method", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["payload", b"payload", "source", b"source", "target", b"target"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["metadata", b"metadata", "method", b"method", "payload", b"payload", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ...
|
||||
|
||||
global___RpcRequest = RpcRequest
|
||||
|
||||
@ -104,26 +123,25 @@ class RpcResponse(google.protobuf.message.Message):
|
||||
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
|
||||
|
||||
REQUEST_ID_FIELD_NUMBER: builtins.int
|
||||
RESULT_TYPE_FIELD_NUMBER: builtins.int
|
||||
RESULT_FIELD_NUMBER: builtins.int
|
||||
PAYLOAD_FIELD_NUMBER: builtins.int
|
||||
ERROR_FIELD_NUMBER: builtins.int
|
||||
METADATA_FIELD_NUMBER: builtins.int
|
||||
request_id: builtins.str
|
||||
result_type: builtins.str
|
||||
result: builtins.str
|
||||
error: builtins.str
|
||||
@property
|
||||
def payload(self) -> global___Payload: ...
|
||||
@property
|
||||
def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
request_id: builtins.str = ...,
|
||||
result_type: builtins.str = ...,
|
||||
result: builtins.str = ...,
|
||||
payload: global___Payload | None = ...,
|
||||
error: builtins.str = ...,
|
||||
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["error", b"error", "metadata", b"metadata", "request_id", b"request_id", "result", b"result", "result_type", b"result_type"]) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["payload", b"payload"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["error", b"error", "metadata", b"metadata", "payload", b"payload", "request_id", b"request_id"]) -> None: ...
|
||||
|
||||
global___RpcResponse = RpcResponse
|
||||
|
||||
@ -147,30 +165,26 @@ class Event(google.protobuf.message.Message):
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
|
||||
|
||||
NAMESPACE_FIELD_NUMBER: builtins.int
|
||||
TOPIC_TYPE_FIELD_NUMBER: builtins.int
|
||||
TOPIC_SOURCE_FIELD_NUMBER: builtins.int
|
||||
DATA_TYPE_FIELD_NUMBER: builtins.int
|
||||
DATA_FIELD_NUMBER: builtins.int
|
||||
PAYLOAD_FIELD_NUMBER: builtins.int
|
||||
METADATA_FIELD_NUMBER: builtins.int
|
||||
namespace: builtins.str
|
||||
topic_type: builtins.str
|
||||
topic_source: builtins.str
|
||||
data_type: builtins.str
|
||||
data: builtins.str
|
||||
@property
|
||||
def payload(self) -> global___Payload: ...
|
||||
@property
|
||||
def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
namespace: builtins.str = ...,
|
||||
topic_type: builtins.str = ...,
|
||||
topic_source: builtins.str = ...,
|
||||
data_type: builtins.str = ...,
|
||||
data: builtins.str = ...,
|
||||
payload: global___Payload | None = ...,
|
||||
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["data", b"data", "data_type", b"data_type", "metadata", b"metadata", "namespace", b"namespace", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["payload", b"payload"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["metadata", b"metadata", "payload", b"payload", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ...
|
||||
|
||||
global___Event = Event
|
||||
|
||||
|
||||
@ -14,7 +14,14 @@ from ._base_agent import BaseAgent
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._message_context import MessageContext
|
||||
from ._message_handler_context import MessageHandlerContext
|
||||
from ._serialization import MESSAGE_TYPE_REGISTRY, Serialization, TypeDeserializer, TypeSerializer
|
||||
from ._serialization import (
|
||||
JSON_DATA_CONTENT_TYPE,
|
||||
MESSAGE_TYPE_REGISTRY,
|
||||
MessageCodec,
|
||||
Serialization,
|
||||
UnknownPayload,
|
||||
try_get_known_codecs_for_type,
|
||||
)
|
||||
from ._subscription import Subscription
|
||||
from ._subscription_context import SubscriptionInstantiationContext
|
||||
from ._topic import TopicId
|
||||
@ -30,8 +37,6 @@ __all__ = [
|
||||
"AgentChildren",
|
||||
"AgentInstantiationContext",
|
||||
"MESSAGE_TYPE_REGISTRY",
|
||||
"TypeSerializer",
|
||||
"TypeDeserializer",
|
||||
"TopicId",
|
||||
"Subscription",
|
||||
"MessageContext",
|
||||
@ -39,4 +44,8 @@ __all__ = [
|
||||
"AgentType",
|
||||
"SubscriptionInstantiationContext",
|
||||
"MessageHandlerContext",
|
||||
"JSON_DATA_CONTENT_TYPE",
|
||||
"MessageCodec",
|
||||
"try_get_known_codecs_for_type",
|
||||
"UnknownPayload",
|
||||
]
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, Mapping, Protocol, runtime_checkable
|
||||
from typing import Any, List, Mapping, Protocol, runtime_checkable
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
|
||||
@ -1,9 +1,23 @@
|
||||
import json
|
||||
from dataclasses import asdict
|
||||
from typing import Any, ClassVar, Dict, Protocol, TypeVar, cast, runtime_checkable
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, ClassVar, Dict, List, Protocol, TypeVar, cast, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class MessageCodec(Protocol[T]):
|
||||
@property
|
||||
def data_content_type(self) -> str: ...
|
||||
|
||||
@property
|
||||
def type_name(self) -> str: ...
|
||||
|
||||
def deserialize(self, payload: bytes) -> T: ...
|
||||
|
||||
def serialize(self, message: T) -> bytes: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class IsDataclass(Protocol):
|
||||
@ -26,53 +40,62 @@ def has_nested_base_model(cls: type[IsDataclass]) -> bool:
|
||||
return any(issubclass(f.type, BaseModel) for f in cls.__dataclass_fields__.values())
|
||||
|
||||
|
||||
T = TypeVar("T", covariant=True)
|
||||
|
||||
|
||||
class TypeDeserializer(Protocol[T]):
|
||||
def deserialize(self, message: str) -> T: ...
|
||||
|
||||
|
||||
U = TypeVar("U", contravariant=True)
|
||||
|
||||
|
||||
class TypeSerializer(Protocol[U]):
|
||||
def serialize(self, message: U) -> str: ...
|
||||
|
||||
|
||||
DataclassT = TypeVar("DataclassT", bound=IsDataclass)
|
||||
|
||||
JSON_DATA_CONTENT_TYPE = "application/json"
|
||||
|
||||
class DataclassTypeDeserializer(TypeDeserializer[DataclassT]):
|
||||
def __init__(self, cls: type[DataclassT]) -> None:
|
||||
|
||||
class DataclassJsonMessageCodec(MessageCodec[IsDataclass]):
|
||||
def __init__(self, cls: type[IsDataclass]) -> None:
|
||||
self.cls = cls
|
||||
|
||||
def deserialize(self, message: str) -> DataclassT:
|
||||
return self.cls(**json.loads(message))
|
||||
@property
|
||||
def data_content_type(self) -> str:
|
||||
return JSON_DATA_CONTENT_TYPE
|
||||
|
||||
@property
|
||||
def type_name(self) -> str:
|
||||
return _type_name(self.cls)
|
||||
|
||||
class DataclassTypeSerializer(TypeSerializer[IsDataclass]):
|
||||
def serialize(self, message: IsDataclass) -> str:
|
||||
def deserialize(self, payload: bytes) -> IsDataclass:
|
||||
message_str = payload.decode("utf-8")
|
||||
return self.cls(**json.loads(message_str))
|
||||
|
||||
def serialize(self, message: IsDataclass) -> bytes:
|
||||
if has_nested_dataclass(type(message)) or has_nested_base_model(type(message)):
|
||||
raise ValueError("Dataclass has nested dataclasses or base models, which are not supported")
|
||||
|
||||
return json.dumps(asdict(message))
|
||||
return json.dumps(asdict(message)).encode("utf-8")
|
||||
|
||||
|
||||
PydanticT = TypeVar("PydanticT", bound=BaseModel)
|
||||
|
||||
|
||||
class PydanticTypeDeserializer(TypeDeserializer[PydanticT]):
|
||||
class PydanticJsonMessageCodec(MessageCodec[PydanticT]):
|
||||
def __init__(self, cls: type[PydanticT]) -> None:
|
||||
self.cls = cls
|
||||
|
||||
def deserialize(self, message: str) -> PydanticT:
|
||||
return self.cls.model_validate_json(message)
|
||||
@property
|
||||
def data_content_type(self) -> str:
|
||||
return JSON_DATA_CONTENT_TYPE
|
||||
|
||||
@property
|
||||
def type_name(self) -> str:
|
||||
return _type_name(self.cls)
|
||||
|
||||
def deserialize(self, payload: bytes) -> PydanticT:
|
||||
message_str = payload.decode("utf-8")
|
||||
return self.cls.model_validate_json(message_str)
|
||||
|
||||
def serialize(self, message: PydanticT) -> bytes:
|
||||
return message.model_dump_json().encode("utf-8")
|
||||
|
||||
|
||||
class PydanticTypeSerializer(TypeSerializer[BaseModel]):
|
||||
def serialize(self, message: BaseModel) -> str:
|
||||
return message.model_dump_json()
|
||||
@dataclass
|
||||
class UnknownPayload:
|
||||
type_name: str
|
||||
data_content_type: str
|
||||
payload: bytes
|
||||
|
||||
|
||||
def _type_name(cls: type[Any] | Any) -> str:
|
||||
@ -85,38 +108,49 @@ def _type_name(cls: type[Any] | Any) -> str:
|
||||
V = TypeVar("V")
|
||||
|
||||
|
||||
def try_get_known_codecs_for_type(cls: type[Any]) -> list[MessageCodec[Any]]:
|
||||
# TODO: Support protobuf types
|
||||
codecs: List[MessageCodec[Any]] = []
|
||||
if issubclass(cls, BaseModel):
|
||||
codecs.append(PydanticJsonMessageCodec(cls))
|
||||
elif isinstance(cls, IsDataclass):
|
||||
codecs.append(DataclassJsonMessageCodec(cls))
|
||||
|
||||
return codecs
|
||||
|
||||
|
||||
class Serialization:
|
||||
def __init__(self) -> None:
|
||||
self._deserializers: Dict[str, TypeDeserializer[Any]] = {}
|
||||
self._serializers: Dict[str, TypeSerializer[Any]] = {}
|
||||
# type_name, data_content_type -> codec
|
||||
self._codecs: dict[tuple[str, str], MessageCodec[Any]] = {}
|
||||
|
||||
def add_type(self, message_type: type[BaseModel] | type[IsDataclass]) -> None:
|
||||
if issubclass(message_type, BaseModel):
|
||||
self.add_type_custom(
|
||||
_type_name(message_type), PydanticTypeDeserializer(message_type), PydanticTypeSerializer()
|
||||
)
|
||||
elif isinstance(message_type, IsDataclass):
|
||||
self.add_type_custom(
|
||||
_type_name(message_type), DataclassTypeDeserializer(message_type), DataclassTypeSerializer()
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {message_type}")
|
||||
def add_codec(self, codec: MessageCodec[Any] | List[MessageCodec[Any]]) -> None:
|
||||
if isinstance(codec, list):
|
||||
for c in codec:
|
||||
self.add_codec(c)
|
||||
return
|
||||
|
||||
def add_type_custom(self, type_name: str, deserializer: TypeDeserializer[V], serializer: TypeSerializer[V]) -> None:
|
||||
self._deserializers[type_name] = deserializer
|
||||
self._serializers[type_name] = serializer
|
||||
self._codecs[(codec.type_name, codec.data_content_type)] = codec
|
||||
|
||||
def deserialize(self, message: str, *, type_name: str) -> Any:
|
||||
return self._deserializers[type_name].deserialize(message)
|
||||
def deserialize(self, payload: bytes, *, type_name: str, data_content_type: str) -> Any:
|
||||
codec = self._codecs.get((type_name, data_content_type))
|
||||
if codec is None:
|
||||
return UnknownPayload(type_name, data_content_type, payload)
|
||||
|
||||
return codec.deserialize(payload)
|
||||
|
||||
def serialize(self, message: Any, *, type_name: str, data_content_type: str) -> bytes:
|
||||
codec = self._codecs.get((type_name, data_content_type))
|
||||
if codec is None:
|
||||
raise ValueError(f"Unknown type {type_name} with content type {data_content_type}")
|
||||
|
||||
return codec.serialize(message)
|
||||
|
||||
def is_registered(self, type_name: str, data_content_type: str) -> bool:
|
||||
return (type_name, data_content_type) in self._codecs
|
||||
|
||||
def type_name(self, message: Any) -> str:
|
||||
return _type_name(message)
|
||||
|
||||
def serialize(self, message: Any, *, type_name: str) -> str:
|
||||
return self._serializers[type_name].serialize(message)
|
||||
|
||||
def is_registered(self, type_name: str) -> bool:
|
||||
return type_name in self._deserializers
|
||||
|
||||
|
||||
MESSAGE_TYPE_REGISTRY = Serialization()
|
||||
|
||||
@ -8,7 +8,7 @@ from ..base._agent_id import AgentId
|
||||
from ..base._agent_instantiation import AgentInstantiationContext
|
||||
from ..base._agent_metadata import AgentMetadata
|
||||
from ..base._agent_runtime import AgentRuntime
|
||||
from ..base._serialization import MESSAGE_TYPE_REGISTRY
|
||||
from ..base._serialization import JSON_DATA_CONTENT_TYPE, MESSAGE_TYPE_REGISTRY, try_get_known_codecs_for_type
|
||||
from ..base.exceptions import CantHandleException
|
||||
from ._type_helpers import get_types
|
||||
|
||||
@ -59,9 +59,10 @@ class ClosureAgent(Agent):
|
||||
self._id: AgentId = id
|
||||
self._description = description
|
||||
subscription_types = get_subscriptions_from_closure(closure)
|
||||
# TODO fold this into runtime
|
||||
for message_type in subscription_types:
|
||||
if not MESSAGE_TYPE_REGISTRY.is_registered(MESSAGE_TYPE_REGISTRY.type_name(message_type)):
|
||||
MESSAGE_TYPE_REGISTRY.add_type(message_type)
|
||||
MESSAGE_TYPE_REGISTRY.add_codec(try_get_known_codecs_for_type(message_type))
|
||||
|
||||
self._subscriptions = [MESSAGE_TYPE_REGISTRY.type_name(message_type) for message_type in subscription_types]
|
||||
self._closure = closure
|
||||
|
||||
|
||||
@ -17,6 +17,8 @@ from typing import (
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from autogen_core.base import try_get_known_codecs_for_type
|
||||
|
||||
from ..base import MESSAGE_TYPE_REGISTRY, BaseAgent, MessageContext
|
||||
from ..base.exceptions import CantHandleException
|
||||
from ._type_helpers import AnyType, get_types
|
||||
@ -144,8 +146,8 @@ class RoutedAgent(BaseAgent):
|
||||
self._handlers[target_type] = message_handler
|
||||
|
||||
for message_type in self._handlers.keys():
|
||||
if not MESSAGE_TYPE_REGISTRY.is_registered(MESSAGE_TYPE_REGISTRY.type_name(message_type)):
|
||||
MESSAGE_TYPE_REGISTRY.add_type(message_type)
|
||||
for codec in try_get_known_codecs_for_type(message_type):
|
||||
MESSAGE_TYPE_REGISTRY.add_codec(codec)
|
||||
|
||||
super().__init__(description)
|
||||
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
#custom type
|
||||
|
||||
from pydantic import BaseModel
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from autogen_core.base import Serialization
|
||||
from autogen_core.base import JSON_DATA_CONTENT_TYPE, MessageCodec, try_get_known_codecs_for_type
|
||||
|
||||
class PydanticMessage(BaseModel):
|
||||
message: str
|
||||
@ -30,77 +29,86 @@ class NestingPydanticDataclassMessage:
|
||||
|
||||
def test_pydantic() -> None:
|
||||
serde = Serialization()
|
||||
serde.add_type(PydanticMessage)
|
||||
serde.add_codec(try_get_known_codecs_for_type(PydanticMessage))
|
||||
|
||||
message = PydanticMessage(message="hello")
|
||||
name = serde.type_name(message)
|
||||
json = serde.serialize(message, type_name=name)
|
||||
json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
|
||||
assert name == "PydanticMessage"
|
||||
assert json == '{"message":"hello"}'
|
||||
deserialized = serde.deserialize(json, type_name=name)
|
||||
assert json == b'{"message":"hello"}'
|
||||
deserialized = serde.deserialize(json, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
|
||||
assert deserialized == message
|
||||
|
||||
def test_nested_pydantic() -> None:
|
||||
serde = Serialization()
|
||||
serde.add_type(NestingPydanticMessage)
|
||||
serde.add_codec(try_get_known_codecs_for_type(NestingPydanticMessage))
|
||||
|
||||
message = NestingPydanticMessage(message="hello", nested=PydanticMessage(message="world"))
|
||||
name = serde.type_name(message)
|
||||
json = serde.serialize(message, type_name=name)
|
||||
assert json == '{"message":"hello","nested":{"message":"world"}}'
|
||||
deserialized = serde.deserialize(json, type_name=name)
|
||||
json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
|
||||
assert json == b'{"message":"hello","nested":{"message":"world"}}'
|
||||
deserialized = serde.deserialize(json, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
|
||||
assert deserialized == message
|
||||
|
||||
def test_dataclass() -> None:
|
||||
serde = Serialization()
|
||||
serde.add_type(DataclassMessage)
|
||||
serde.add_codec(try_get_known_codecs_for_type(DataclassMessage))
|
||||
|
||||
message = DataclassMessage(message="hello")
|
||||
name = serde.type_name(message)
|
||||
json = serde.serialize(message, type_name=name)
|
||||
assert json == '{"message": "hello"}'
|
||||
deserialized = serde.deserialize(json, type_name=name)
|
||||
json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
|
||||
assert json == b'{"message": "hello"}'
|
||||
deserialized = serde.deserialize(json, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
|
||||
assert deserialized == message
|
||||
|
||||
def test_nesting_dataclass_dataclass() -> None:
|
||||
serde = Serialization()
|
||||
serde.add_type(NestingDataclassMessage)
|
||||
serde.add_codec(try_get_known_codecs_for_type(NestingDataclassMessage))
|
||||
|
||||
message = NestingDataclassMessage(message="hello", nested=DataclassMessage(message="world"))
|
||||
name = serde.type_name(message)
|
||||
with pytest.raises(ValueError):
|
||||
_json = serde.serialize(message, type_name=name)
|
||||
_json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
|
||||
|
||||
def test_nesting_dataclass_pydantic() -> None:
|
||||
serde = Serialization()
|
||||
serde.add_type(NestingPydanticDataclassMessage)
|
||||
serde.add_codec(try_get_known_codecs_for_type(NestingPydanticDataclassMessage))
|
||||
|
||||
message = NestingPydanticDataclassMessage(message="hello", nested=PydanticMessage(message="world"))
|
||||
name = serde.type_name(message)
|
||||
with pytest.raises(ValueError):
|
||||
_json = serde.serialize(message, type_name=name)
|
||||
_json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
|
||||
|
||||
def test_invalid_type() -> None:
|
||||
serde = Serialization()
|
||||
try:
|
||||
serde.add_type(str) # type: ignore
|
||||
serde.add_codec(try_get_known_codecs_for_type(str))
|
||||
except ValueError as e:
|
||||
assert str(e) == "Unsupported type <class 'str'>"
|
||||
|
||||
def test_custom_type() -> None:
|
||||
serde = Serialization()
|
||||
|
||||
class CustomStringTypeDeserializer:
|
||||
def deserialize(self, message: str) -> str:
|
||||
class CustomStringTypeCodec(MessageCodec[str]):
|
||||
@property
|
||||
def data_content_type(self) -> str:
|
||||
return "str"
|
||||
|
||||
@property
|
||||
def type_name(self) -> str:
|
||||
return "custom_str"
|
||||
|
||||
def deserialize(self, payload: bytes) -> str:
|
||||
message = payload.decode("utf-8")
|
||||
return message[1:-1]
|
||||
|
||||
class CustomStringTypeSerializer:
|
||||
def serialize(self, message: str) -> str:
|
||||
return f'"{message}"'
|
||||
def serialize(self, message: str) -> bytes:
|
||||
return f'"{message}"'.encode("utf-8")
|
||||
|
||||
serde.add_type_custom("custom_str", CustomStringTypeDeserializer(), CustomStringTypeSerializer())
|
||||
|
||||
serde.add_codec(CustomStringTypeCodec())
|
||||
message = "hello"
|
||||
json = serde.serialize(message, type_name="custom_str")
|
||||
assert json == '"hello"'
|
||||
deserialized = serde.deserialize(json, type_name="custom_str")
|
||||
assert deserialized == message
|
||||
json = serde.serialize(message, type_name="custom_str", data_content_type="str")
|
||||
assert json == b'"hello"'
|
||||
deserialized = serde.deserialize(json, type_name="custom_str", data_content_type="str")
|
||||
assert deserialized == message
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user