Dotnet Grpc worker implementation (#5245)

Co-authored-by: Jacob Alber <jaalber@microsoft.com>
Co-authored-by: Ryan Sweet <rysweet@microsoft.com>
This commit is contained in:
Jack Gerrits 2025-02-05 11:34:02 -05:00 committed by GitHub
parent 9030f75b4d
commit 08f9830bf7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
35 changed files with 1878 additions and 286 deletions

View File

@ -107,6 +107,38 @@ jobs:
- name: Unit Test V2
run: dotnet test --no-build -bl --configuration Release --filter "Category=UnitV2"
grpc-unit-tests:
name: Dotnet Grpc unit tests
needs: paths-filter
if: needs.paths-filter.outputs.hasChanges == 'true'
defaults:
run:
working-directory: dotnet
strategy:
fail-fast: false
matrix:
os: [ ubuntu-latest ]
runs-on: ${{ matrix.os }}
timeout-minutes: 30
steps:
- uses: actions/checkout@v4
with:
lfs: true
- name: Setup .NET 8.0
uses: actions/setup-dotnet@v4
with:
dotnet-version: '8.0.x'
- name: Install dev certs
run: dotnet --version && dotnet dev-certs https --trust
- name: Restore dependencies
run: |
# dotnet nuget add source --name dotnet-tool https://pkgs.dev.azure.com/dnceng/public/_packaging/dotnet-tools/nuget/v3/index.json --configfile NuGet.config
dotnet restore -bl
- name: Build
run: dotnet build --no-restore --configuration Release -bl /p:SignAssembly=true
- name: GRPC tests
run: dotnet test --no-build -bl --configuration Release --filter "Category=GRPC"
integration-test:
strategy:
fail-fast: true
@ -224,6 +256,8 @@ jobs:
with:
dotnet-version: '8.0.x'
global-json-file: dotnet/global.json
- name: Install dev certs
run: dotnet --version && dotnet dev-certs https --trust
- name: Restore dependencies
run: |
dotnet restore -bl

View File

@ -52,6 +52,7 @@ jobs:
run: |
echo "Build AutoGen"
dotnet build --no-restore --configuration Release -bl /p:SignAssembly=true
- run: sudo dotnet dev-certs https --trust --no-password
- name: Unit Test
run: dotnet test --no-build -bl --configuration Release
env:

View File

@ -118,6 +118,10 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Hello", "Hello", "{F42F9C8E
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Core.Grpc", "src\Microsoft.AutoGen\Core.Grpc\Microsoft.AutoGen.Core.Grpc.csproj", "{3D83C6DB-ACEA-48F3-959F-145CCD2EE135}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "GettingStartedGrpc", "samples\GettingStartedGrpc\GettingStartedGrpc.csproj", "{C3740DF1-18B1-4607-81E4-302F0308C848}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Core.Grpc.Tests", "test\Microsoft.AutoGen.Core.Grpc.Tests\Microsoft.AutoGen.Core.Grpc.Tests.csproj", "{23A028D3-5EB1-4FA0-9CD1-A1340B830579}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@ -306,6 +310,14 @@ Global
{AAD593FE-A49B-425E-A9FE-A0022CD25E3D}.Debug|Any CPU.Build.0 = Debug|Any CPU
{AAD593FE-A49B-425E-A9FE-A0022CD25E3D}.Release|Any CPU.ActiveCfg = Release|Any CPU
{AAD593FE-A49B-425E-A9FE-A0022CD25E3D}.Release|Any CPU.Build.0 = Release|Any CPU
{C3740DF1-18B1-4607-81E4-302F0308C848}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{C3740DF1-18B1-4607-81E4-302F0308C848}.Debug|Any CPU.Build.0 = Debug|Any CPU
{C3740DF1-18B1-4607-81E4-302F0308C848}.Release|Any CPU.ActiveCfg = Release|Any CPU
{C3740DF1-18B1-4607-81E4-302F0308C848}.Release|Any CPU.Build.0 = Release|Any CPU
{23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Debug|Any CPU.Build.0 = Debug|Any CPU
{23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Release|Any CPU.ActiveCfg = Release|Any CPU
{23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@ -359,6 +371,8 @@ Global
{3D83C6DB-ACEA-48F3-959F-145CCD2EE135} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
{AAD593FE-A49B-425E-A9FE-A0022CD25E3D} = {F42F9C8E-7BD9-4687-9B63-AFFA461AF5C1}
{F42F9C8E-7BD9-4687-9B63-AFFA461AF5C1} = {CE0AA8D5-12B8-4628-9589-DAD8CB0DDCF6}
{C3740DF1-18B1-4607-81E4-302F0308C848} = {CE0AA8D5-12B8-4628-9589-DAD8CB0DDCF6}
{23A028D3-5EB1-4FA0-9CD1-A1340B830579} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B}

View File

@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Checker.cs
using Microsoft.AutoGen.Contracts;
using Microsoft.AutoGen.Core;
using Microsoft.Extensions.Hosting;
using TerminationF = System.Func<int, bool>;
namespace GettingStartedGrpcSample;
[TypeSubscription("default")]
public class Checker(
AgentId id,
IAgentRuntime runtime,
IHostApplicationLifetime hostApplicationLifetime,
TerminationF runUntilFunc
) :
BaseAgent(id, runtime, "Modifier", null),
IHandle<Events.CountUpdate>
{
public async ValueTask HandleAsync(Events.CountUpdate item, MessageContext messageContext)
{
if (!runUntilFunc(item.NewCount))
{
Console.WriteLine($"\nChecker:\n{item.NewCount} passed the check, continue.");
await this.PublishMessageAsync(new Events.CountMessage { Content = item.NewCount }, new TopicId("default"));
}
else
{
Console.WriteLine($"\nChecker:\n{item.NewCount} failed the check, stopping.");
hostApplicationLifetime.StopApplication();
}
}
}

View File

@ -0,0 +1,26 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net8.0</TargetFramework>
<RootNamespace>getting_started</RootNamespace>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Microsoft.AutoGen\Contracts\Microsoft.AutoGen.Contracts.csproj" />
<ProjectReference Include="..\..\src\Microsoft.AutoGen\Core\Microsoft.AutoGen.Core.csproj" />
<ProjectReference Include="..\..\src\Microsoft.AutoGen\Core.Grpc\Microsoft.AutoGen.Core.Grpc.csproj" />
</ItemGroup>
<ItemGroup>
<Protobuf Include="message.proto" GrpcServices="Client" Link="Protos\message.proto" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Grpc.Tools" PrivateAssets="All" />
</ItemGroup>
</Project>

View File

@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Modifier.cs
using Microsoft.AutoGen.Contracts;
using Microsoft.AutoGen.Core;
using ModifyF = System.Func<int, int>;
namespace GettingStartedGrpcSample;
[TypeSubscription("default")]
public class Modifier(
AgentId id,
IAgentRuntime runtime,
ModifyF modifyFunc
) :
BaseAgent(id, runtime, "Modifier", null),
IHandle<Events.CountMessage>
{
public async ValueTask HandleAsync(Events.CountMessage item, MessageContext messageContext)
{
int newValue = modifyFunc(item.Content);
Console.WriteLine($"\nModifier:\nModified {item.Content} to {newValue}");
var updateMessage = new Events.CountUpdate { NewCount = newValue };
await this.PublishMessageAsync(updateMessage, topic: new TopicId("default"));
}
}

View File

@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Program.cs
using GettingStartedGrpcSample;
using Microsoft.AutoGen.Contracts;
using Microsoft.AutoGen.Core;
using Microsoft.AutoGen.Core.Grpc;
using Microsoft.Extensions.DependencyInjection.Extensions;
using ModifyF = System.Func<int, int>;
using TerminationF = System.Func<int, bool>;
ModifyF modifyFunc = (int x) => x - 1;
TerminationF runUntilFunc = (int x) =>
{
return x <= 1;
};
AgentsAppBuilder appBuilder = new AgentsAppBuilder();
appBuilder.AddGrpcAgentWorker("http://localhost:50051");
appBuilder.Services.TryAddSingleton(modifyFunc);
appBuilder.Services.TryAddSingleton(runUntilFunc);
appBuilder.AddAgent<Checker>("Checker");
appBuilder.AddAgent<Modifier>("Modifier");
var app = await appBuilder.BuildAsync();
await app.StartAsync();
// Send the initial count to the agents app, running on the `local` runtime, and pass through the registered services via the application `builder`
await app.PublishMessageAsync(new GettingStartedGrpcSample.Events.CountMessage
{
Content = 10
}, new TopicId("default"));
// Run until application shutdown
await app.WaitForShutdownAsync();

View File

@ -0,0 +1,11 @@
syntax = "proto3";
option csharp_namespace = "GettingStartedGrpcSample.Events";
message CountMessage {
int32 content = 1;
}
message CountUpdate {
int32 new_count = 1;
}

View File

@ -0,0 +1,76 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentsAppBuilderExtensions.cs
using System.Diagnostics;
using Grpc.Core;
using Grpc.Net.Client.Configuration;
using Microsoft.AutoGen.Contracts;
using Microsoft.AutoGen.Protobuf;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Logging;
namespace Microsoft.AutoGen.Core.Grpc;
public static class AgentsAppBuilderExtensions
{
private const string _defaultAgentServiceAddress = "http://localhost:53071";
// TODO: How do we ensure AddGrpcAgentWorker and UseInProcessRuntime are mutually exclusive?
public static AgentsAppBuilder AddGrpcAgentWorker(this AgentsAppBuilder builder, string? agentServiceAddress = null)
{
builder.Services.AddGrpcClient<AgentRpc.AgentRpcClient>(options =>
{
options.Address = new Uri(agentServiceAddress ?? builder.Configuration.GetValue("AGENT_HOST", _defaultAgentServiceAddress));
options.ChannelOptionsActions.Add(channelOptions =>
{
var loggerFactory = new LoggerFactory();
if (Debugger.IsAttached)
{
channelOptions.HttpHandler = new SocketsHttpHandler
{
EnableMultipleHttp2Connections = false,
KeepAlivePingDelay = TimeSpan.FromSeconds(200),
KeepAlivePingTimeout = TimeSpan.FromSeconds(100),
KeepAlivePingPolicy = HttpKeepAlivePingPolicy.Always
};
}
else
{
channelOptions.HttpHandler = new SocketsHttpHandler
{
EnableMultipleHttp2Connections = true,
KeepAlivePingDelay = TimeSpan.FromSeconds(20),
KeepAlivePingTimeout = TimeSpan.FromSeconds(10),
KeepAlivePingPolicy = HttpKeepAlivePingPolicy.WithActiveRequests
};
}
var methodConfig = new MethodConfig
{
Names = { MethodName.Default },
RetryPolicy = new RetryPolicy
{
MaxAttempts = 5,
InitialBackoff = TimeSpan.FromSeconds(1),
MaxBackoff = TimeSpan.FromSeconds(5),
BackoffMultiplier = 1.5,
RetryableStatusCodes = { StatusCode.Unavailable }
}
};
channelOptions.ServiceConfig = new() { MethodConfigs = { methodConfig } };
channelOptions.ThrowOperationCanceledOnCancellation = true;
});
});
builder.Services.TryAddSingleton(DistributedContextPropagator.Current);
builder.Services.AddSingleton<IAgentRuntime, GrpcAgentRuntime>();
builder.Services.AddHostedService<GrpcAgentRuntime>(services =>
{
return (services.GetRequiredService<IAgentRuntime>() as GrpcAgentRuntime)!;
});
return builder;
}
}

View File

@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// CloudEventExtensions.cs
using Microsoft.AutoGen.Contracts;
namespace Microsoft.AutoGen.Core.Grpc;
internal static class CloudEventExtensions
{
// Convert an ISubscrptionDefinition to a Protobuf Subscription
internal static CloudEvent CreateCloudEvent(Google.Protobuf.WellKnownTypes.Any payload, TopicId topic, string dataType, AgentId? sender, string messageId)
{
var attributes = new Dictionary<string, CloudEvent.Types.CloudEventAttributeValue>
{
{
Constants.DATA_CONTENT_TYPE_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = Constants.DATA_CONTENT_TYPE_PROTOBUF_VALUE }
},
{
Constants.DATA_SCHEMA_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = dataType }
},
{
Constants.MESSAGE_KIND_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = Constants.MESSAGE_KIND_VALUE_PUBLISH }
}
};
if (sender != null)
{
var senderNonNull = (AgentId)sender;
attributes.Add(Constants.AGENT_SENDER_TYPE_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = senderNonNull.Type });
attributes.Add(Constants.AGENT_SENDER_KEY_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = senderNonNull.Key });
}
return new CloudEvent
{
ProtoData = payload,
Type = topic.Type,
Source = topic.Source,
Id = messageId,
Attributes = { attributes }
};
}
}

View File

@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Constants.cs
namespace Microsoft.AutoGen.Core.Grpc;
public static class Constants
{
public const string DATA_CONTENT_TYPE_PROTOBUF_VALUE = "application/x-protobuf";
public const string DATA_CONTENT_TYPE_JSON_VALUE = "application/json";
public const string DATA_CONTENT_TYPE_TEXT_VALUE = "text/plain";
public const string DATA_CONTENT_TYPE_ATTR = "datacontenttype";
public const string DATA_SCHEMA_ATTR = "dataschema";
public const string AGENT_SENDER_TYPE_ATTR = "agagentsendertype";
public const string AGENT_SENDER_KEY_ATTR = "agagentsenderkey";
public const string MESSAGE_KIND_ATTR = "agmsgkind";
public const string MESSAGE_KIND_VALUE_PUBLISH = "publish";
public const string MESSAGE_KIND_VALUE_RPC_REQUEST = "rpc_request";
public const string MESSAGE_KIND_VALUE_RPC_RESPONSE = "rpc_response";
}

View File

@ -0,0 +1,430 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GrpcAgentRuntime.cs
using System.Collections.Concurrent;
using Grpc.Core;
using Microsoft.AutoGen.Contracts;
using Microsoft.AutoGen.Protobuf;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
namespace Microsoft.AutoGen.Core.Grpc;
internal sealed class AgentsContainer(IAgentRuntime hostingRuntime)
{
private readonly IAgentRuntime hostingRuntime = hostingRuntime;
private Dictionary<Contracts.AgentId, IHostableAgent> agentInstances = new();
public Dictionary<string, ISubscriptionDefinition> Subscriptions = new();
private Dictionary<AgentType, Func<Contracts.AgentId, IAgentRuntime, ValueTask<IHostableAgent>>> agentFactories = new();
public async ValueTask<IHostableAgent> EnsureAgentAsync(Contracts.AgentId agentId)
{
if (!this.agentInstances.TryGetValue(agentId, out IHostableAgent? agent))
{
if (!this.agentFactories.TryGetValue(agentId.Type, out Func<Contracts.AgentId, IAgentRuntime, ValueTask<IHostableAgent>>? factoryFunc))
{
throw new Exception($"Agent with name {agentId.Type} not found.");
}
agent = await factoryFunc(agentId, this.hostingRuntime);
this.agentInstances.Add(agentId, agent);
}
return this.agentInstances[agentId];
}
public async ValueTask<Contracts.AgentId> GetAgentAsync(Contracts.AgentId agentId, bool lazy = true)
{
if (!lazy)
{
await this.EnsureAgentAsync(agentId);
}
return agentId;
}
public AgentType RegisterAgentFactory(AgentType type, Func<Contracts.AgentId, IAgentRuntime, ValueTask<IHostableAgent>> factoryFunc)
{
if (this.agentFactories.ContainsKey(type))
{
throw new Exception($"Agent factory with type {type} already exists.");
}
this.agentFactories.Add(type, factoryFunc);
return type;
}
public void AddSubscription(ISubscriptionDefinition subscription)
{
if (this.Subscriptions.ContainsKey(subscription.Id))
{
throw new Exception($"Subscription with id {subscription.Id} already exists.");
}
this.Subscriptions.Add(subscription.Id, subscription);
}
public bool RemoveSubscriptionAsync(string subscriptionId)
{
if (!this.Subscriptions.ContainsKey(subscriptionId))
{
throw new Exception($"Subscription with id {subscriptionId} does not exist.");
}
return this.Subscriptions.Remove(subscriptionId);
}
public HashSet<AgentType> RegisteredAgentTypes => this.agentFactories.Keys.ToHashSet();
public IEnumerable<IHostableAgent> LiveAgents => this.agentInstances.Values;
}
public sealed class GrpcAgentRuntime : IHostedService, IAgentRuntime, IMessageSink<Message>, IDisposable
{
public GrpcAgentRuntime(AgentRpc.AgentRpcClient client,
IHostApplicationLifetime hostApplicationLifetime,
IServiceProvider serviceProvider,
ILogger<GrpcAgentRuntime> logger)
{
this._client = client;
this._logger = logger;
this._shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(hostApplicationLifetime.ApplicationStopping);
this._messageRouter = new GrpcMessageRouter(client, this, _clientId, logger, this._shutdownCts.Token);
this._agentsContainer = new AgentsContainer(this);
this.ServiceProvider = serviceProvider;
}
// Request ID -> ResultSink<...>
private readonly ConcurrentDictionary<string, ResultSink<object?>> _pendingRequests = new();
private readonly AgentRpc.AgentRpcClient _client;
private readonly GrpcMessageRouter _messageRouter;
private readonly ILogger<GrpcAgentRuntime> _logger;
private readonly CancellationTokenSource _shutdownCts;
private readonly AgentsContainer _agentsContainer;
public IServiceProvider ServiceProvider { get; }
private Guid _clientId = Guid.NewGuid();
private CallOptions CallOptions
{
get
{
var metadata = new Metadata
{
{ "client-id", this._clientId.ToString() }
};
return new CallOptions(headers: metadata);
}
}
public IProtoSerializationRegistry SerializationRegistry { get; } = new ProtobufSerializationRegistry();
public void Dispose()
{
this._shutdownCts.Cancel();
this._messageRouter.Dispose();
}
private async ValueTask HandleRequest(RpcRequest request, CancellationToken cancellationToken = default)
{
if (request is null)
{
throw new InvalidOperationException("Request is null.");
}
if (request.Payload is null)
{
throw new InvalidOperationException("Payload is null.");
}
if (request.Target is null)
{
throw new InvalidOperationException("Target is null.");
}
var agentId = request.Target;
var agent = await this._agentsContainer.EnsureAgentAsync(agentId.FromProtobuf());
// Convert payload back to object
var payload = request.Payload;
var message = payload.ToObject(SerializationRegistry);
var messageContext = new MessageContext(request.RequestId, cancellationToken)
{
Sender = request.Source?.FromProtobuf() ?? null,
Topic = null,
IsRpc = true
};
var result = await agent.OnMessageAsync(message, messageContext);
if (result is not null)
{
var response = new RpcResponse
{
RequestId = request.RequestId,
Payload = result.ToPayload(SerializationRegistry)
};
var responseMessage = new Message
{
Response = response
};
await this._messageRouter.RouteMessageAsync(responseMessage, cancellationToken);
}
}
private async ValueTask HandleResponse(RpcResponse request, CancellationToken _ = default)
{
if (request is null)
{
throw new InvalidOperationException("Request is null.");
}
if (request.Payload is null)
{
throw new InvalidOperationException("Payload is null.");
}
if (request.RequestId is null)
{
throw new InvalidOperationException("RequestId is null.");
}
if (_pendingRequests.TryRemove(request.RequestId, out var resultSink))
{
var payload = request.Payload;
var message = payload.ToObject(SerializationRegistry);
resultSink.SetResult(message);
}
}
private async ValueTask HandlePublish(CloudEvent evt, CancellationToken cancellationToken = default)
{
if (evt is null)
{
throw new InvalidOperationException("CloudEvent is null.");
}
if (evt.ProtoData is null)
{
throw new InvalidOperationException("ProtoData is null.");
}
if (evt.Attributes is null)
{
throw new InvalidOperationException("Attributes is null.");
}
var topic = new TopicId(evt.Type, evt.Source);
Contracts.AgentId? sender = null;
if (evt.Attributes.TryGetValue(Constants.AGENT_SENDER_TYPE_ATTR, out var typeValue) && evt.Attributes.TryGetValue(Constants.AGENT_SENDER_KEY_ATTR, out var keyValue))
{
sender = new Contracts.AgentId
{
Type = typeValue.CeString,
Key = keyValue.CeString
};
}
var messageId = evt.Id;
var typeName = evt.Attributes[Constants.DATA_SCHEMA_ATTR].CeString;
var serializer = SerializationRegistry.GetSerializer(typeName) ?? throw new Exception();
var message = serializer.Deserialize(evt.ProtoData);
var messageContext = new MessageContext(messageId, cancellationToken)
{
Sender = sender,
Topic = topic,
IsRpc = false
};
// Iterate over subscriptions values to find receiving agents
foreach (var subscription in this._agentsContainer.Subscriptions.Values)
{
if (subscription.Matches(topic))
{
var recipient = subscription.MapToAgent(topic);
var agent = await this._agentsContainer.EnsureAgentAsync(recipient);
await agent.OnMessageAsync(message, messageContext);
}
}
}
public ValueTask StartAsync(CancellationToken cancellationToken)
{
return this._messageRouter.StartAsync(cancellationToken);
}
Task IHostedService.StartAsync(CancellationToken cancellationToken) => this._messageRouter.StartAsync(cancellationToken).AsTask();
public Task StopAsync(CancellationToken cancellationToken)
{
return this._messageRouter.StopAsync();
}
public async ValueTask<object?> SendMessageAsync(object message, Contracts.AgentId recepient, Contracts.AgentId? sender = null, string? messageId = null, CancellationToken cancellationToken = default)
{
if (!SerializationRegistry.Exists(message.GetType()))
{
SerializationRegistry.RegisterSerializer(message.GetType());
}
var payload = message.ToPayload(SerializationRegistry);
var request = new RpcRequest
{
RequestId = Guid.NewGuid().ToString(),
Source = sender?.ToProtobuf() ?? null,
Target = recepient.ToProtobuf(),
Payload = payload,
};
Message msg = new()
{
Request = request
};
// Create a future that will be completed when the response is received
var resultSink = new ResultSink<object?>();
this._pendingRequests.TryAdd(request.RequestId, resultSink);
await this._messageRouter.RouteMessageAsync(msg, cancellationToken);
return await resultSink.Future;
}
public async ValueTask PublishMessageAsync(object message, TopicId topic, Contracts.AgentId? sender = null, string? messageId = null, CancellationToken cancellationToken = default)
{
if (!SerializationRegistry.Exists(message.GetType()))
{
SerializationRegistry.RegisterSerializer(message.GetType());
}
var protoAny = (SerializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message);
var typeName = SerializationRegistry.TypeNameResolver.ResolveTypeName(message.GetType());
var cloudEvent = CloudEventExtensions.CreateCloudEvent(protoAny, topic, typeName, sender, messageId ?? Guid.NewGuid().ToString());
Message msg = new()
{
CloudEvent = cloudEvent
};
await this._messageRouter.RouteMessageAsync(msg, cancellationToken);
}
public ValueTask<Contracts.AgentId> GetAgentAsync(Contracts.AgentId agentId, bool lazy = true) => this._agentsContainer.GetAgentAsync(agentId, lazy);
public ValueTask<Contracts.AgentId> GetAgentAsync(AgentType agentType, string key = "default", bool lazy = true)
=> this.GetAgentAsync(new Contracts.AgentId(agentType, key), lazy);
public ValueTask<Contracts.AgentId> GetAgentAsync(string agent, string key = "default", bool lazy = true)
=> this.GetAgentAsync(new Contracts.AgentId(agent, key), lazy);
public async ValueTask<IDictionary<string, object>> SaveAgentStateAsync(Contracts.AgentId agentId)
{
IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId);
return await agent.SaveStateAsync();
}
public async ValueTask LoadAgentStateAsync(Contracts.AgentId agentId, IDictionary<string, object> state)
{
IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId);
await agent.LoadStateAsync(state);
}
public async ValueTask<AgentMetadata> GetAgentMetadataAsync(Contracts.AgentId agentId)
{
IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId);
return agent.Metadata;
}
public async ValueTask AddSubscriptionAsync(ISubscriptionDefinition subscription)
{
this._agentsContainer.AddSubscription(subscription);
var _ = await this._client.AddSubscriptionAsync(new AddSubscriptionRequest
{
Subscription = subscription.ToProtobuf()
}, this.CallOptions);
}
public async ValueTask RemoveSubscriptionAsync(string subscriptionId)
{
this._agentsContainer.RemoveSubscriptionAsync(subscriptionId);
await this._client.RemoveSubscriptionAsync(new RemoveSubscriptionRequest
{
Id = subscriptionId
}, this.CallOptions);
}
public async ValueTask<AgentType> RegisterAgentFactoryAsync(AgentType type, Func<Contracts.AgentId, IAgentRuntime, ValueTask<IHostableAgent>> factoryFunc)
{
this._agentsContainer.RegisterAgentFactory(type, factoryFunc);
await this._client.RegisterAgentAsync(new RegisterAgentTypeRequest
{
Type = type,
}, this.CallOptions);
return type;
}
public ValueTask<AgentProxy> TryGetAgentProxyAsync(Contracts.AgentId agentId)
{
// TODO: Do we want to support getting remote agent proxies?
return ValueTask.FromResult(new AgentProxy(agentId, this));
}
public async ValueTask<IDictionary<string, object>> SaveStateAsync()
{
Dictionary<string, object> state = new();
foreach (var agent in this._agentsContainer.LiveAgents)
{
state[agent.Id.ToString()] = await agent.SaveStateAsync();
}
return state;
}
public async ValueTask LoadStateAsync(IDictionary<string, object> state)
{
HashSet<AgentType> registeredTypes = this._agentsContainer.RegisteredAgentTypes;
foreach (var agentIdStr in state.Keys)
{
Contracts.AgentId agentId = Contracts.AgentId.FromStr(agentIdStr);
if (state[agentIdStr] is not IDictionary<string, object> agentStateDict)
{
throw new Exception($"Agent state for {agentId} is not a {typeof(IDictionary<string, object>)}: {state[agentIdStr].GetType()}");
}
if (registeredTypes.Contains(agentId.Type))
{
IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId);
await agent.LoadStateAsync(agentStateDict);
}
}
}
public async ValueTask OnMessageAsync(Message message, CancellationToken cancellation = default)
{
switch (message.MessageCase)
{
case Message.MessageOneofCase.Request:
var request = message.Request ?? throw new InvalidOperationException("Request is null.");
await HandleRequest(request);
break;
case Message.MessageOneofCase.Response:
var response = message.Response ?? throw new InvalidOperationException("Response is null.");
await HandleResponse(response);
break;
case Message.MessageOneofCase.CloudEvent:
var cloudEvent = message.CloudEvent ?? throw new InvalidOperationException("CloudEvent is null.");
await HandlePublish(cloudEvent);
break;
default:
throw new InvalidOperationException($"Unexpected message '{message}'.");
}
}
}

View File

@ -0,0 +1,296 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GrpcMessageRouter.cs
using System.Threading.Channels;
using Grpc.Core;
using Microsoft.AutoGen.Protobuf;
using Microsoft.Extensions.Logging;
namespace Microsoft.AutoGen.Core.Grpc;
// TODO: Consider whether we want to just reuse IHandle
internal interface IMessageSink<TMessage>
{
public ValueTask OnMessageAsync(TMessage message, CancellationToken cancellation = default);
}
internal sealed class AutoRestartChannel : IDisposable
{
private readonly object _channelLock = new();
private readonly AgentRpc.AgentRpcClient _client;
private readonly Guid _clientId;
private readonly ILogger<GrpcAgentRuntime> _logger;
private readonly CancellationTokenSource _shutdownCts;
private AsyncDuplexStreamingCall<Message, Message>? _channel;
public AutoRestartChannel(AgentRpc.AgentRpcClient client,
Guid clientId,
ILogger<GrpcAgentRuntime> logger,
CancellationToken shutdownCancellation = default)
{
_client = client;
_clientId = clientId;
_logger = logger;
_shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(shutdownCancellation);
}
public void EnsureConnected()
{
_logger.LogInformation("Connecting to gRPC endpoint " + Environment.GetEnvironmentVariable("AGENT_HOST"));
if (this.RecreateChannel(null) == null)
{
throw new Exception("Failed to connect to gRPC endpoint.");
};
}
public AsyncDuplexStreamingCall<Message, Message> StreamingCall
{
get
{
if (_channel is { } channel)
{
return channel;
}
lock (_channelLock)
{
if (_channel is not null)
{
return _channel;
}
return RecreateChannel(null);
}
}
}
public AsyncDuplexStreamingCall<Message, Message> RecreateChannel() => RecreateChannel(this._channel);
private AsyncDuplexStreamingCall<Message, Message> RecreateChannel(AsyncDuplexStreamingCall<Message, Message>? ownedChannel)
{
// Make sure we are only re-creating the channel if it does not exit or we are the owner.
if (_channel is null || _channel == ownedChannel)
{
lock (_channelLock)
{
if (_channel is null || _channel == ownedChannel)
{
var metadata = new Metadata
{
{ "client-id", _clientId.ToString() }
};
_channel?.Dispose();
_channel = _client.OpenChannel(cancellationToken: _shutdownCts.Token, headers: metadata);
}
}
}
return _channel;
}
public void Dispose()
{
IDisposable? channelDisposable = Interlocked.Exchange(ref this._channel, null);
channelDisposable?.Dispose();
}
}
internal sealed class GrpcMessageRouter(AgentRpc.AgentRpcClient client,
IMessageSink<Message> incomingMessageSink,
Guid clientId,
ILogger<GrpcAgentRuntime> logger,
CancellationToken shutdownCancellation = default) : IDisposable
{
private static readonly BoundedChannelOptions DefaultChannelOptions = new BoundedChannelOptions(1024)
{
AllowSynchronousContinuations = true,
SingleReader = true,
SingleWriter = false,
FullMode = BoundedChannelFullMode.Wait
};
private readonly ILogger<GrpcAgentRuntime> _logger = logger;
private readonly CancellationTokenSource _shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(shutdownCancellation);
private readonly IMessageSink<Message> _incomingMessageSink = incomingMessageSink;
private readonly Channel<(Message Message, TaskCompletionSource WriteCompletionSource)> _outboundMessagesChannel
// TODO: Enable a way to configure the channel options
= Channel.CreateBounded<(Message, TaskCompletionSource)>(DefaultChannelOptions);
private readonly AutoRestartChannel _incomingMessageChannel = new AutoRestartChannel(client, clientId, logger, shutdownCancellation);
private Task? _readTask;
private Task? _writeTask;
private async Task RunReadPump()
{
var cachedChannel = _incomingMessageChannel.StreamingCall;
while (!_shutdownCts.Token.IsCancellationRequested)
{
try
{
await foreach (var message in cachedChannel.ResponseStream.ReadAllAsync(_shutdownCts.Token))
{
// next if message is null
if (message == null)
{
continue;
}
await _incomingMessageSink.OnMessageAsync(message, _shutdownCts.Token);
}
}
catch (OperationCanceledException)
{
// Time to shut down.
break;
}
catch (Exception ex) when (!_shutdownCts.IsCancellationRequested)
{
_logger.LogError(ex, "Error reading from channel.");
cachedChannel = this._incomingMessageChannel.RecreateChannel();
}
catch
{
// Shutdown requested.
break;
}
}
}
private async Task RunWritePump()
{
var cachedChannel = this._incomingMessageChannel.StreamingCall;
var outboundMessages = _outboundMessagesChannel.Reader;
while (!_shutdownCts.IsCancellationRequested)
{
(Message Message, TaskCompletionSource WriteCompletionSource) item = default;
try
{
await outboundMessages.WaitToReadAsync().ConfigureAwait(false);
// Read the next message if we don't already have an unsent message
// waiting to be sent.
if (!outboundMessages.TryRead(out item))
{
break;
}
while (!_shutdownCts.IsCancellationRequested)
{
await cachedChannel.RequestStream.WriteAsync(item.Message, _shutdownCts.Token).ConfigureAwait(false);
item.WriteCompletionSource.TrySetResult();
break;
}
}
catch (OperationCanceledException)
{
// Time to shut down.
item.WriteCompletionSource?.TrySetCanceled();
break;
}
catch (RpcException ex) when (ex.StatusCode == StatusCode.Unavailable)
{
// we could not connect to the endpoint - most likely we have the wrong port or failed ssl
// we need to let the user know what port we tried to connect to and then do backoff and retry
_logger.LogError(ex, "Error connecting to GRPC endpoint {Endpoint}.", Environment.GetEnvironmentVariable("AGENT_HOST"));
break;
}
catch (RpcException ex) when (ex.StatusCode == StatusCode.OK)
{
_logger.LogError(ex, "Error writing to channel, continuing (Status OK). {ex}", cachedChannel.ToString());
break;
}
catch (Exception ex) when (!_shutdownCts.IsCancellationRequested)
{
item.WriteCompletionSource?.TrySetException(ex);
_logger.LogError(ex, $"Error writing to channel.{ex}");
cachedChannel = this._incomingMessageChannel.RecreateChannel();
continue;
}
catch
{
// Shutdown requested.
item.WriteCompletionSource?.TrySetCanceled();
break;
}
}
while (outboundMessages.TryRead(out var item))
{
item.WriteCompletionSource.TrySetCanceled();
}
}
public ValueTask RouteMessageAsync(Message message, CancellationToken cancellation = default)
{
var tcs = new TaskCompletionSource();
return _outboundMessagesChannel.Writer.WriteAsync((message, tcs), cancellation);
}
public ValueTask StartAsync(CancellationToken cancellation)
{
// TODO: Should we error out on a noncancellable token?
this._incomingMessageChannel.EnsureConnected();
var didSuppress = false;
// Make sure we do not mistakenly flow the ExecutionContext into the background pumping tasks.
if (!ExecutionContext.IsFlowSuppressed())
{
didSuppress = true;
ExecutionContext.SuppressFlow();
}
try
{
_readTask = Task.Run(RunReadPump, cancellation);
_writeTask = Task.Run(RunWritePump, cancellation);
return ValueTask.CompletedTask;
}
catch (Exception ex)
{
return ValueTask.FromException(ex);
}
finally
{
if (didSuppress)
{
ExecutionContext.RestoreFlow();
}
}
}
// No point in returning a ValueTask here, since we are awaiting the two tasks
public async Task StopAsync()
{
_shutdownCts.Cancel();
_outboundMessagesChannel.Writer.TryComplete();
List<Task> pendingTasks = new();
if (_readTask is { } readTask)
{
pendingTasks.Add(readTask);
}
if (_writeTask is { } writeTask)
{
pendingTasks.Add(writeTask);
}
await Task.WhenAll(pendingTasks).ConfigureAwait(false);
this._incomingMessageChannel.Dispose();
}
public void Dispose()
{
_outboundMessagesChannel.Writer.TryComplete();
this._incomingMessageChannel.Dispose();
}
}

View File

@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IAgentMessageSerializer.cs
namespace Microsoft.AutoGen.Core.Grpc;
/// <summary>
/// Interface for serializing and deserializing agent messages.
/// </summary>
public interface IAgentMessageSerializer
{
/// <summary>
/// Serialize an agent message.
/// </summary>
/// <param name="message">The message to serialize.</param>
/// <returns>The serialized message.</returns>
Google.Protobuf.WellKnownTypes.Any Serialize(object message);
/// <summary>
/// Deserialize an agent message.
/// </summary>
/// <param name="message">The message to deserialize.</param>
/// <returns>The deserialized message.</returns>
object Deserialize(Google.Protobuf.WellKnownTypes.Any message);
}

View File

@ -0,0 +1,102 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IAgentRuntimeExtensions.cs
using System.Diagnostics;
using Google.Protobuf.Collections;
using Microsoft.AutoGen.Contracts;
using Microsoft.AutoGen.Protobuf;
using Microsoft.Extensions.DependencyInjection;
using static Microsoft.AutoGen.Contracts.CloudEvent.Types;
namespace Microsoft.AutoGen.Core.Grpc;
public static class GrpcAgentRuntimeExtensions
{
public static (string?, string?) GetTraceIdAndState(GrpcAgentRuntime runtime, IDictionary<string, string> metadata)
{
var dcp = runtime.ServiceProvider.GetRequiredService<DistributedContextPropagator>();
dcp.ExtractTraceIdAndState(metadata,
static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
{
var metadata = (IDictionary<string, string>)carrier!;
fieldValues = null;
metadata.TryGetValue(fieldName, out fieldValue);
},
out var traceParent,
out var traceState);
return (traceParent, traceState);
}
public static (string?, string?) GetTraceIdAndState(GrpcAgentRuntime worker, MapField<string, CloudEventAttributeValue> metadata)
{
var dcp = worker.ServiceProvider.GetRequiredService<DistributedContextPropagator>();
dcp.ExtractTraceIdAndState(metadata,
static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
{
var metadata = (MapField<string, CloudEventAttributeValue>)carrier!;
fieldValues = null;
metadata.TryGetValue(fieldName, out var ceValue);
fieldValue = ceValue?.CeString;
},
out var traceParent,
out var traceState);
return (traceParent, traceState);
}
public static void Update(GrpcAgentRuntime worker, RpcRequest request, Activity? activity = null)
{
var dcp = worker.ServiceProvider.GetRequiredService<DistributedContextPropagator>();
dcp.Inject(activity, request.Metadata, static (carrier, key, value) =>
{
var metadata = (IDictionary<string, string>)carrier!;
if (metadata.TryGetValue(key, out _))
{
metadata[key] = value;
}
else
{
metadata.Add(key, value);
}
});
}
public static void Update(GrpcAgentRuntime worker, CloudEvent cloudEvent, Activity? activity = null)
{
var dcp = worker.ServiceProvider.GetRequiredService<DistributedContextPropagator>();
dcp.Inject(activity, cloudEvent.Attributes, static (carrier, key, value) =>
{
var mapField = (MapField<string, CloudEventAttributeValue>)carrier!;
if (mapField.TryGetValue(key, out var ceValue))
{
mapField[key] = new CloudEventAttributeValue { CeString = value };
}
else
{
mapField.Add(key, new CloudEventAttributeValue { CeString = value });
}
});
}
public static IDictionary<string, string> ExtractMetadata(GrpcAgentRuntime worker, IDictionary<string, string> metadata)
{
var dcp = worker.ServiceProvider.GetRequiredService<DistributedContextPropagator>();
var baggage = dcp.ExtractBaggage(metadata, static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
{
var metadata = (IDictionary<string, string>)carrier!;
fieldValues = null;
metadata.TryGetValue(fieldName, out fieldValue);
});
return baggage as IDictionary<string, string> ?? new Dictionary<string, string>();
}
public static IDictionary<string, string> ExtractMetadata(GrpcAgentRuntime worker, MapField<string, CloudEventAttributeValue> metadata)
{
var dcp = worker.ServiceProvider.GetRequiredService<DistributedContextPropagator>();
var baggage = dcp.ExtractBaggage(metadata, static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
{
var metadata = (MapField<string, CloudEventAttributeValue>)carrier!;
fieldValues = null;
metadata.TryGetValue(fieldName, out var ceValue);
fieldValue = ceValue?.CeString;
});
return baggage as IDictionary<string, string> ?? new Dictionary<string, string>();
}
}

View File

@ -0,0 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IProtobufMessageSerializer.cs
namespace Microsoft.AutoGen.Core.Grpc;
public interface IProtobufMessageSerializer
{
Google.Protobuf.WellKnownTypes.Any Serialize(object input);
object Deserialize(Google.Protobuf.WellKnownTypes.Any input);
}

View File

@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ISerializationRegistry.cs
namespace Microsoft.AutoGen.Core.Grpc;
public interface IProtoSerializationRegistry
{
/// <summary>
/// Registers a serializer for the specified type.
/// </summary>
/// <param name="type">The type to register.</param>
void RegisterSerializer(System.Type type) => RegisterSerializer(type, new ProtobufMessageSerializer(type));
void RegisterSerializer(System.Type type, IProtobufMessageSerializer serializer);
/// <summary>
/// Gets the serializer for the specified type.
/// </summary>
/// <param name="type">The type to get the serializer for.</param>
/// <returns>The serializer for the specified type.</returns>
IProtobufMessageSerializer? GetSerializer(System.Type type) => GetSerializer(TypeNameResolver.ResolveTypeName(type));
IProtobufMessageSerializer? GetSerializer(string typeName);
ITypeNameResolver TypeNameResolver { get; }
bool Exists(System.Type type);
}

View File

@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ITypeNameResolver.cs
namespace Microsoft.AutoGen.Core.Grpc;
public interface ITypeNameResolver
{
string ResolveTypeName(Type input);
}

View File

@ -14,7 +14,6 @@
<ItemGroup>
<Protobuf Include="..\..\..\..\protos\agent_worker.proto" GrpcServices="Client;Server" Link="Protos\agent_worker.proto" />
<Protobuf Include="..\..\..\..\protos\cloudevent.proto" GrpcServices="Client;Server" Link="Protos\cloudevent.proto" />
<Protobuf Include="..\..\..\..\protos\agent_events.proto" GrpcServices="Client;Server" Link="Protos\agent_events.proto" />
</ItemGroup>
<ItemGroup>

View File

@ -0,0 +1,60 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ProtobufConversionExtensions.cs
using Microsoft.AutoGen.Contracts;
using Microsoft.AutoGen.Protobuf;
namespace Microsoft.AutoGen.Core.Grpc;
public static class ProtobufConversionExtensions
{
// Convert an ISubscrptionDefinition to a Protobuf Subscription
public static Subscription? ToProtobuf(this ISubscriptionDefinition subscriptionDefinition)
{
// Check if is a TypeSubscription
if (subscriptionDefinition is Contracts.TypeSubscription typeSubscription)
{
return new Subscription
{
Id = typeSubscription.Id,
TypeSubscription = new Protobuf.TypeSubscription
{
TopicType = typeSubscription.TopicType,
AgentType = typeSubscription.AgentType
}
};
}
// Check if is a TypePrefixSubscription
if (subscriptionDefinition is Contracts.TypePrefixSubscription typePrefixSubscription)
{
return new Subscription
{
Id = typePrefixSubscription.Id,
TypePrefixSubscription = new Protobuf.TypePrefixSubscription
{
TopicTypePrefix = typePrefixSubscription.TopicTypePrefix,
AgentType = typePrefixSubscription.AgentType
}
};
}
return null;
}
// Convert AgentId from Protobuf to AgentId
public static Contracts.AgentId FromProtobuf(this Protobuf.AgentId agentId)
{
return new Contracts.AgentId(agentId.Type, agentId.Key);
}
// Convert AgentId from AgentId to Protobuf
public static Protobuf.AgentId ToProtobuf(this Contracts.AgentId agentId)
{
return new Protobuf.AgentId
{
Type = agentId.Type,
Key = agentId.Key
};
}
}

View File

@ -0,0 +1,46 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ProtobufMessageSerializer.cs
using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;
namespace Microsoft.AutoGen.Core.Grpc;
/// <summary>
/// Interface for serializing and deserializing agent messages.
/// </summary>
public class ProtobufMessageSerializer : IProtobufMessageSerializer
{
private System.Type _concreteType;
public ProtobufMessageSerializer(System.Type concreteType)
{
_concreteType = concreteType;
}
public object Deserialize(Any message)
{
// Check if the concrete type is a proto IMessage
if (typeof(IMessage).IsAssignableFrom(_concreteType))
{
var nameOfMethod = nameof(Any.Unpack);
var result = message.GetType().GetMethods().Where(m => m.Name == nameOfMethod && m.IsGenericMethod).First().MakeGenericMethod(_concreteType).Invoke(message, null);
return result as IMessage ?? throw new ArgumentException("Failed to deserialize", nameof(message));
}
// Raise an exception if the concrete type is not a proto IMessage
throw new ArgumentException("Concrete type must be a proto IMessage", nameof(_concreteType));
}
public Any Serialize(object message)
{
// Check if message is a proto IMessage
if (message is IMessage protoMessage)
{
return Any.Pack(protoMessage);
}
// Raise an exception if the message is not a proto IMessage
throw new ArgumentException("Message must be a proto IMessage", nameof(message));
}
}

View File

@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ProtobufSerializationRegistry.cs
namespace Microsoft.AutoGen.Core.Grpc;
public class ProtobufSerializationRegistry : IProtoSerializationRegistry
{
private readonly Dictionary<string, IProtobufMessageSerializer> _serializers
= new Dictionary<string, IProtobufMessageSerializer>();
public ITypeNameResolver TypeNameResolver => new ProtobufTypeNameResolver();
public bool Exists(Type type)
{
return _serializers.ContainsKey(TypeNameResolver.ResolveTypeName(type));
}
public IProtobufMessageSerializer? GetSerializer(Type type)
{
return GetSerializer(TypeNameResolver.ResolveTypeName(type));
}
public IProtobufMessageSerializer? GetSerializer(string typeName)
{
_serializers.TryGetValue(typeName, out var serializer);
return serializer;
}
public void RegisterSerializer(Type type, IProtobufMessageSerializer serializer)
{
if (_serializers.ContainsKey(TypeNameResolver.ResolveTypeName(type)))
{
throw new InvalidOperationException($"Serializer already registered for {type.FullName}");
}
_serializers[TypeNameResolver.ResolveTypeName(type)] = serializer;
}
}

View File

@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ProtobufTypeNameResolver.cs
using Google.Protobuf;
namespace Microsoft.AutoGen.Core.Grpc;
public class ProtobufTypeNameResolver : ITypeNameResolver
{
public string ResolveTypeName(Type input)
{
if (typeof(IMessage).IsAssignableFrom(input))
{
// TODO: Consider changing this to avoid instantiation...
var protoMessage = (IMessage?)Activator.CreateInstance(input) ?? throw new InvalidOperationException($"Failed to create instance of {input.FullName}");
return protoMessage.Descriptor.FullName;
}
else
{
throw new ArgumentException("Input must be a protobuf message.");
}
}
}

View File

@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// RpcExtensions.cs
using Google.Protobuf;
using Microsoft.AutoGen.Protobuf;
namespace Microsoft.AutoGen.Core.Grpc;
internal static class RpcExtensions
{
public static Payload ToPayload(this object message, IProtoSerializationRegistry serializationRegistry)
{
if (!serializationRegistry.Exists(message.GetType()))
{
serializationRegistry.RegisterSerializer(message.GetType());
}
var rpcMessage = (serializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message);
var typeName = serializationRegistry.TypeNameResolver.ResolveTypeName(message.GetType());
const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf";
// Protobuf any to byte array
Payload payload = new()
{
DataType = typeName,
DataContentType = PAYLOAD_DATA_CONTENT_TYPE,
Data = rpcMessage.ToByteString()
};
return payload;
}
public static object ToObject(this Payload payload, IProtoSerializationRegistry serializationRegistry)
{
var typeName = payload.DataType;
var data = payload.Data;
var serializer = serializationRegistry.GetSerializer(typeName) ?? throw new Exception();
var any = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(data);
return serializer.Deserialize(any);
}
}

View File

@ -4,6 +4,7 @@
using System.Diagnostics;
using System.Reflection;
using Microsoft.AutoGen.Contracts;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
@ -21,6 +22,7 @@ public class AgentsAppBuilder
}
public IServiceCollection Services => this.builder.Services;
public IConfiguration Configuration => this.builder.Configuration;
public void AddAgentsFromAssemblies()
{

View File

@ -1,263 +1,152 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentGrpcTests.cs
using System.Collections.Concurrent;
using System.Text.Json;
using FluentAssertions;
using Google.Protobuf.Reflection;
using Microsoft.AutoGen.Contracts;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
// using Microsoft.AutoGen.Core.Tests;
using Microsoft.AutoGen.Core.Grpc.Tests.Protobuf;
using Microsoft.Extensions.Logging;
using Xunit;
using static Microsoft.AutoGen.Core.Grpc.Tests.AgentGrpcTests;
namespace Microsoft.AutoGen.Core.Grpc.Tests;
[Trait("Category", "UnitV2")]
[Trait("Category", "GRPC")]
public class AgentGrpcTests
{
/// <summary>
/// Verify that if the agent is not initialized via AgentWorker, it should throw the correct exception.
/// </summary>
/// <returns>void</returns>
[Fact]
public async Task Agent_ShouldThrowException_WhenNotInitialized()
public async Task AgentShouldNotReceiveMessagesWhenNotSubscribedTest()
{
using var runtime = new GrpcRuntime();
var (_, agent) = runtime.Start(false); // Do not initialize
var fixture = new GrpcAgentRuntimeFixture();
var runtime = (GrpcAgentRuntime)await fixture.Start();
// Expect an exception when calling AddSubscriptionAsync because the agent is uninitialized
await Assert.ThrowsAsync<UninitializedAgentWorker.AgentInitalizedIncorrectlyException>(
async () => await agent.AddSubscriptionAsync("TestEvent")
);
}
Logger<BaseAgent> logger = new(new LoggerFactory());
TestProtobufAgent agent = null!;
/// <summary>
/// validate that the agent is initialized correctly with implicit subs
/// </summary>
/// <returns>void</returns>
[Fact]
public async Task Agent_ShouldInitializeCorrectly()
{
using var runtime = new GrpcRuntime();
var (worker, agent) = runtime.Start();
Assert.Equal(nameof(GrpcAgentRuntime), worker.GetType().Name);
await Task.Delay(5000);
var subscriptions = await agent.GetSubscriptionsAsync();
Assert.Equal(2, subscriptions.Count);
}
/// <summary>
/// Test AddSubscriptionAsync method
/// </summary>
/// <returns>void</returns>
[Fact]
public async Task SubscribeAsync_UnsubscribeAsync_and_GetSubscriptionsTest()
{
using var runtime = new GrpcRuntime();
var (_, agent) = runtime.Start();
await agent.AddSubscriptionAsync("TestEvent");
await Task.Delay(100);
var subscriptions = await agent.GetSubscriptionsAsync().ConfigureAwait(true);
var found = false;
foreach (var subscription in subscriptions)
await runtime.RegisterAgentFactoryAsync("MyAgent", async (id, runtime) =>
{
if (subscription.TypeSubscription.TopicType == "TestEvent")
{
found = true;
}
}
Assert.True(found);
await agent.RemoveSubscriptionAsync("TestEvent").ConfigureAwait(true);
await Task.Delay(1000);
subscriptions = await agent.GetSubscriptionsAsync().ConfigureAwait(true);
found = false;
foreach (var subscription in subscriptions)
{
if (subscription.TypeSubscription.TopicType == "TestEvent")
{
found = true;
}
}
Assert.False(found);
}
agent = new TestProtobufAgent(id, runtime, logger);
return await ValueTask.FromResult(agent);
});
/// <summary>
/// Test StoreAsync and ReadAsync methods
/// </summary>
/// <returns>void</returns>
[Fact]
public async Task StoreAsync_and_ReadAsyncTest()
{
using var runtime = new GrpcRuntime();
var (_, agent) = runtime.Start();
Dictionary<string, string> state = new()
{
{ "testdata", "Active" }
};
await agent.StoreAsync(new AgentState
{
AgentId = agent.AgentId,
TextData = JsonSerializer.Serialize(state)
}).ConfigureAwait(true);
var readState = await agent.ReadAsync<AgentState>(agent.AgentId).ConfigureAwait(true);
var read = JsonSerializer.Deserialize<Dictionary<string, string>>(readState.TextData) ?? new Dictionary<string, string> { { "data", "No state data found" } };
read.TryGetValue("testdata", out var value);
Assert.Equal("Active", value);
}
// Ensure the agent is actually created
AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false);
// Validate agent ID
agentId.Should().Be(agent.Id, "Agent ID should match the registered agent");
/// <summary>
/// Test PublishMessageAsync method and ReceiveMessage method
/// </summary>
/// <returns>void</returns>
[Fact]
public async Task PublishMessageAsync_and_ReceiveMessageTest()
{
using var runtime = new GrpcRuntime();
var (_, agent) = runtime.Start();
var topicType = "TestTopic";
await agent.AddSubscriptionAsync(topicType).ConfigureAwait(true);
var subscriptions = await agent.GetSubscriptionsAsync().ConfigureAwait(true);
var found = false;
foreach (var subscription in subscriptions)
{
if (subscription.TypeSubscription.TopicType == topicType)
{
found = true;
}
}
Assert.True(found);
await agent.PublishMessageAsync(new TextMessage()
{
Source = topicType,
TextMessage_ = "buffer"
}, topicType).ConfigureAwait(true);
await Task.Delay(100);
Assert.True(TestAgent.ReceivedMessages.ContainsKey(topicType));
runtime.Stop();
await runtime.PublishMessageAsync(new Protobuf.TextMessage { Source = topicType, Content = "test" }, new TopicId(topicType)).ConfigureAwait(true);
agent.ReceivedMessages.Any().Should().BeFalse("Agent should not receive messages when not subscribed.");
fixture.Dispose();
}
[Fact]
public async Task InvokeCorrectHandler()
public async Task AgentShouldReceiveMessagesWhenSubscribedTest()
{
var agent = new TestAgent(new AgentsMetadata(TypeRegistry.Empty, new Dictionary<string, Type>(), new Dictionary<Type, HashSet<string>>(), new Dictionary<Type, HashSet<string>>()), new Logger<Agent>(new LoggerFactory()));
var fixture = new GrpcAgentRuntimeFixture();
var runtime = (GrpcAgentRuntime)await fixture.Start();
await agent.HandleObjectAsync("hello world");
await agent.HandleObjectAsync(42);
Logger<BaseAgent> logger = new(new LoggerFactory());
SubscribedProtobufAgent agent = null!;
agent.ReceivedItems.Should().HaveCount(2);
agent.ReceivedItems[0].Should().Be("hello world");
agent.ReceivedItems[1].Should().Be(42);
await runtime.RegisterAgentFactoryAsync("MyAgent", async (id, runtime) =>
{
agent = new SubscribedProtobufAgent(id, runtime, logger);
return await ValueTask.FromResult(agent);
});
// Ensure the agent is actually created
AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false);
// Validate agent ID
agentId.Should().Be(agent.Id, "Agent ID should match the registered agent");
await runtime.RegisterImplicitAgentSubscriptionsAsync<SubscribedProtobufAgent>("MyAgent");
var topicType = "TestTopic";
await runtime.PublishMessageAsync(new TextMessage { Source = topicType, Content = "test" }, new TopicId(topicType)).ConfigureAwait(true);
// Wait for the message to be processed
await Task.Delay(100);
agent.ReceivedMessages.Any().Should().BeTrue("Agent should receive messages when subscribed.");
fixture.Dispose();
}
/// <summary>
/// The test agent is a simple agent that is used for testing purposes.
/// </summary>
public class TestAgent(
[FromKeyedServices("AgentsMetadata")] AgentsMetadata eventTypes,
Logger<Agent>? logger = null) : Agent(eventTypes, logger), IHandle<TextMessage>
[Fact]
public async Task SendMessageAsyncShouldReturnResponseTest()
{
public Task Handle(TextMessage item, CancellationToken cancellationToken = default)
{
ReceivedMessages[item.Source] = item.TextMessage_;
return Task.CompletedTask;
}
public Task Handle(string item)
{
ReceivedItems.Add(item);
return Task.CompletedTask;
}
public Task Handle(int item)
{
ReceivedItems.Add(item);
return Task.CompletedTask;
}
public List<object> ReceivedItems { get; private set; } = [];
// Arrange
var fixture = new GrpcAgentRuntimeFixture();
var runtime = (GrpcAgentRuntime)await fixture.Start();
/// <summary>
/// Key: source
/// Value: message
/// </summary>
public static ConcurrentDictionary<string, object> ReceivedMessages { get; private set; } = new();
}
}
/// <summary>
/// GrpcRuntimeFixture - provides a fixture for the agent runtime.
/// </summary>
/// <remarks>
/// This fixture is used to provide a runtime for the agent tests.
/// However, it is shared between tests. So operations from one test can affect another.
/// </remarks>
public sealed class GrpcRuntime : IDisposable
{
public IHost Client { get; private set; }
public IHost? AppHost { get; private set; }
public GrpcRuntime()
{
Environment.SetEnvironmentVariable("ASPNETCORE_ENVIRONMENT", "Development");
AppHost = Host.CreateDefaultBuilder().Build();
Client = Host.CreateDefaultBuilder().Build();
}
private static int GetAvailablePort()
{
using var listener = new System.Net.Sockets.TcpListener(System.Net.IPAddress.Loopback, 0);
listener.Start();
int port = ((System.Net.IPEndPoint)listener.LocalEndpoint).Port;
listener.Stop();
return port;
}
private static async Task<IHost> StartClientAsync()
{
return await AgentsApp.StartAsync().ConfigureAwait(false);
}
private static async Task<IHost> StartAppHostAsync()
{
return await Microsoft.AutoGen.Runtime.Grpc.Host.StartAsync(local: false, useGrpc: true).ConfigureAwait(false);
}
/// <summary>
/// Start - gets a new port and starts fresh instances
/// </summary>
public (IAgentRuntime, TestAgent) Start(bool initialize = true)
{
int port = GetAvailablePort(); // Get a new port per test run
// Update environment variables so each test runs independently
Environment.SetEnvironmentVariable("ASPNETCORE_HTTPS_PORTS", port.ToString());
Environment.SetEnvironmentVariable("AGENT_HOST", $"https://localhost:{port}");
AppHost = StartAppHostAsync().GetAwaiter().GetResult();
Client = StartClientAsync().GetAwaiter().GetResult();
var agent = ActivatorUtilities.CreateInstance<TestAgent>(Client.Services);
var worker = Client.Services.GetRequiredService<IAgentRuntime>();
if (initialize)
{
Agent.Initialize(worker, agent);
}
return (worker, agent);
}
/// <summary>
/// Stop - stops the agent and ensures cleanup
/// </summary>
public void Stop()
{
Client?.StopAsync().GetAwaiter().GetResult();
AppHost?.StopAsync().GetAwaiter().GetResult();
}
/// <summary>
/// Dispose - Ensures cleanup after each test
/// </summary>
public void Dispose()
{
Stop();
Logger<BaseAgent> logger = new(new LoggerFactory());
await runtime.RegisterAgentFactoryAsync("MyAgent", async (id, runtime) => await ValueTask.FromResult(new TestProtobufAgent(id, runtime, logger)));
var agentId = new AgentId("MyAgent", "default");
var response = await runtime.SendMessageAsync(new RpcTextMessage { Source = "TestTopic", Content = "Request" }, agentId);
// Assert
Assert.NotNull(response);
Assert.IsType<RpcTextMessage>(response);
if (response is RpcTextMessage responseString)
{
Assert.Equal("Request", responseString.Content);
}
fixture.Dispose();
}
public class ReceiverAgent(AgentId id,
IAgentRuntime runtime) : BaseAgent(id, runtime, "Receiver Agent", null),
IHandle<TextMessage>
{
public ValueTask HandleAsync(TextMessage item, MessageContext messageContext)
{
ReceivedItems.Add(item.Content);
return ValueTask.CompletedTask;
}
public List<string> ReceivedItems { get; private set; } = [];
}
[Fact]
public async Task SubscribeAsyncRemoveSubscriptionAsyncAndGetSubscriptionsTest()
{
var fixture = new GrpcAgentRuntimeFixture();
var runtime = (GrpcAgentRuntime)await fixture.Start();
ReceiverAgent? agent = null;
await runtime.RegisterAgentFactoryAsync("MyAgent", async (id, runtime) =>
{
agent = new ReceiverAgent(id, runtime);
return await ValueTask.FromResult(agent);
});
Assert.Null(agent);
await runtime.GetAgentAsync("MyAgent", lazy: false);
Assert.NotNull(agent);
Assert.True(agent.ReceivedItems.Count == 0);
var topicTypeName = "TestTopic";
await runtime.PublishMessageAsync(new TextMessage { Source = "topic", Content = "test" }, new TopicId(topicTypeName));
await Task.Delay(100);
Assert.True(agent.ReceivedItems.Count == 0);
var subscription = new TypeSubscription(topicTypeName, "MyAgent");
await runtime.AddSubscriptionAsync(subscription);
await runtime.PublishMessageAsync(new TextMessage { Source = "topic", Content = "test" }, new TopicId(topicTypeName));
await Task.Delay(100);
Assert.True(agent.ReceivedItems.Count == 1);
Assert.Equal("test", agent.ReceivedItems[0]);
await runtime.RemoveSubscriptionAsync(subscription.Id);
await runtime.PublishMessageAsync(new TextMessage { Source = "topic", Content = "test" }, new TopicId(topicTypeName));
await Task.Delay(100);
Assert.True(agent.ReceivedItems.Count == 1);
fixture.Dispose();
}
}

View File

@ -0,0 +1,83 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GrpcAgentRuntimeFixture.cs
using Microsoft.AspNetCore.Builder;
using Microsoft.AutoGen.Contracts;
// using Microsoft.AutoGen.Core.Tests;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
namespace Microsoft.AutoGen.Core.Grpc.Tests;
/// <summary>
/// Fixture for setting up the gRPC agent runtime for testing.
/// </summary>
public sealed class GrpcAgentRuntimeFixture : IDisposable
{
/// the gRPC agent runtime.
public AgentsApp? Client { get; private set; }
/// mock server for testing.
public WebApplication? Server { get; private set; }
public GrpcAgentRuntimeFixture()
{
}
/// <summary>
/// Start - gets a new port and starts fresh instances
/// </summary>
public async Task<IAgentRuntime> Start(bool initialize = true)
{
int port = GetAvailablePort(); // Get a new port per test run
// Update environment variables so each test runs independently
Environment.SetEnvironmentVariable("ASPNETCORE_HTTPS_PORTS", port.ToString());
Environment.SetEnvironmentVariable("AGENT_HOST", $"https://localhost:{port}");
Environment.SetEnvironmentVariable("ASPNETCORE_ENVIRONMENT", "Development");
Server = ServerBuilder().Result;
await Server.StartAsync().ConfigureAwait(true);
Client = ClientBuilder().Result;
await Client.StartAsync().ConfigureAwait(true);
var worker = Client.Services.GetRequiredService<IAgentRuntime>();
return (worker);
}
private static async Task<AgentsApp> ClientBuilder()
{
var appBuilder = new AgentsAppBuilder();
appBuilder.AddGrpcAgentWorker();
appBuilder.AddAgent<TestProtobufAgent>("TestAgent");
return await appBuilder.BuildAsync();
}
private static async Task<WebApplication> ServerBuilder()
{
var builder = WebApplication.CreateBuilder();
builder.Services.AddGrpc();
var app = builder.Build();
app.MapGrpcService<GrpcAgentServiceFixture>();
return app;
}
private static int GetAvailablePort()
{
using var listener = new System.Net.Sockets.TcpListener(System.Net.IPAddress.Loopback, 0);
listener.Start();
int port = ((System.Net.IPEndPoint)listener.LocalEndpoint).Port;
listener.Stop();
return port;
}
/// <summary>
/// Stop - stops the agent and ensures cleanup
/// </summary>
public void Stop()
{
(Client as IHost)?.StopAsync(TimeSpan.FromSeconds(30)).GetAwaiter().GetResult();
Server?.StopAsync().GetAwaiter().GetResult();
}
/// <summary>
/// Dispose - Ensures cleanup after each test
/// </summary>
public void Dispose()
{
Stop();
}
}

View File

@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GrpcAgentServiceFixture.cs
using Grpc.Core;
using Microsoft.AutoGen.Protobuf;
namespace Microsoft.AutoGen.Core.Grpc.Tests;
/// <summary>
/// This fixture is largely just a loopback as we are testing the client side logic of the GrpcAgentRuntime in isolation from the rest of the system.
/// </summary>
public sealed class GrpcAgentServiceFixture() : AgentRpc.AgentRpcBase
{
public override async Task OpenChannel(IAsyncStreamReader<Message> requestStream, IServerStreamWriter<Message> responseStream, ServerCallContext context)
{
try
{
var workerProcess = new TestGrpcWorkerConnection(requestStream, responseStream, context);
await workerProcess.Connect().ConfigureAwait(true);
}
catch
{
if (context.CancellationToken.IsCancellationRequested)
{
return;
}
throw;
}
}
public override async Task<GetStateResponse> GetState(AgentId request, ServerCallContext context) => new GetStateResponse { AgentState = new AgentState { AgentId = request } };
public override async Task<SaveStateResponse> SaveState(AgentState request, ServerCallContext context) => new SaveStateResponse { };
public override async Task<AddSubscriptionResponse> AddSubscription(AddSubscriptionRequest request, ServerCallContext context) => new AddSubscriptionResponse { };
public override async Task<RemoveSubscriptionResponse> RemoveSubscription(RemoveSubscriptionRequest request, ServerCallContext context) => new RemoveSubscriptionResponse { };
public override async Task<GetSubscriptionsResponse> GetSubscriptions(GetSubscriptionsRequest request, ServerCallContext context) => new GetSubscriptionsResponse { };
public override async Task<RegisterAgentTypeResponse> RegisterAgent(RegisterAgentTypeRequest request, ServerCallContext context) => new RegisterAgentTypeResponse { };
}

View File

@ -10,8 +10,17 @@
<ItemGroup>
<ProjectReference Include="..\..\src\Microsoft.AutoGen\Core\Microsoft.AutoGen.Core.csproj" />
<ProjectReference Include="..\..\src\Microsoft.AutoGen\Core.Grpc\Microsoft.AutoGen.Core.Grpc.csproj" />
<ProjectReference Include="..\..\src\Microsoft.AutoGen\AgentHost\Microsoft.AutoGen.AgentHost.csproj" />
<PackageReference Include="Microsoft.Extensions.Hosting" />
</ItemGroup>
<ItemGroup>
<Protobuf Include="./messages.proto" GrpcServices="Client;Server" Link="Protos\messages.proto" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Grpc.AspNetCore" />
<PackageReference Include="Grpc.Net.ClientFactory" />
<PackageReference Include="Grpc.Tools" PrivateAssets="All" />
</ItemGroup>
</Project>

View File

@ -0,0 +1,134 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// TestGrpcWorkerConnection.cs
using System.Threading.Channels;
using Grpc.Core;
using Microsoft.AutoGen.Protobuf;
namespace Microsoft.AutoGen.Core.Grpc.Tests;
internal sealed class TestGrpcWorkerConnection : IAsyncDisposable
{
private static long s_nextConnectionId;
private Task _readTask = Task.CompletedTask;
private Task _writeTask = Task.CompletedTask;
private readonly string _connectionId = Interlocked.Increment(ref s_nextConnectionId).ToString();
private readonly object _lock = new();
private readonly HashSet<string> _supportedTypes = [];
private readonly CancellationTokenSource _shutdownCancellationToken = new();
public Task Completion { get; private set; } = Task.CompletedTask;
public IAsyncStreamReader<Message> RequestStream { get; }
public IServerStreamWriter<Message> ResponseStream { get; }
public ServerCallContext ServerCallContext { get; }
private readonly Channel<Message> _outboundMessages;
public TestGrpcWorkerConnection(IAsyncStreamReader<Message> requestStream, IServerStreamWriter<Message> responseStream, ServerCallContext context)
{
RequestStream = requestStream;
ResponseStream = responseStream;
ServerCallContext = context;
_outboundMessages = Channel.CreateUnbounded<Message>(new UnboundedChannelOptions { AllowSynchronousContinuations = true, SingleReader = true, SingleWriter = false });
}
public Task Connect()
{
var didSuppress = false;
if (!ExecutionContext.IsFlowSuppressed())
{
didSuppress = true;
ExecutionContext.SuppressFlow();
}
try
{
_readTask = Task.Run(RunReadPump);
_writeTask = Task.Run(RunWritePump);
}
finally
{
if (didSuppress)
{
ExecutionContext.RestoreFlow();
}
}
return Completion = Task.WhenAll(_readTask, _writeTask);
}
public void AddSupportedType(string type)
{
lock (_lock)
{
_supportedTypes.Add(type);
}
}
public HashSet<string> GetSupportedTypes()
{
lock (_lock)
{
return new HashSet<string>(_supportedTypes);
}
}
public async Task SendMessage(Message message)
{
await _outboundMessages.Writer.WriteAsync(message).ConfigureAwait(false);
}
public async Task RunReadPump()
{
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
try
{
await foreach (var message in RequestStream.ReadAllAsync(_shutdownCancellationToken.Token))
{
//_gateway.OnReceivedMessageAsync(this, message, _shutdownCancellationToken.Token).Ignore();
switch (message.MessageCase)
{
case Message.MessageOneofCase.Request:
await SendMessage(new Message { Request = message.Request }).ConfigureAwait(false);
break;
case Message.MessageOneofCase.Response:
await SendMessage(new Message { Response = message.Response }).ConfigureAwait(false);
break;
case Message.MessageOneofCase.CloudEvent:
await SendMessage(new Message { CloudEvent = message.CloudEvent }).ConfigureAwait(false);
break;
default:
// if it wasn't recognized return bad request
throw new RpcException(new Status(StatusCode.InvalidArgument, $"Unknown message type for message '{message}'"));
};
}
}
catch (OperationCanceledException)
{
}
finally
{
_shutdownCancellationToken.Cancel();
//_gateway.OnRemoveWorkerProcess(this);
}
}
public async Task RunWritePump()
{
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
try
{
await foreach (var message in _outboundMessages.Reader.ReadAllAsync(_shutdownCancellationToken.Token))
{
await ResponseStream.WriteAsync(message);
}
}
catch (OperationCanceledException)
{
}
finally
{
_shutdownCancellationToken.Cancel();
}
}
public async ValueTask DisposeAsync()
{
_shutdownCancellationToken.Cancel();
await Completion.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
}
public override string ToString() => $"Connection-{_connectionId}";
}

View File

@ -0,0 +1,50 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// TestProtobufAgent.cs
using Microsoft.AutoGen.Contracts;
using Microsoft.AutoGen.Core.Grpc.Tests.Protobuf;
using Microsoft.Extensions.Logging;
namespace Microsoft.AutoGen.Core.Grpc.Tests;
/// <summary>
/// The test agent is a simple agent that is used for testing purposes.
/// </summary>
public class TestProtobufAgent(AgentId id,
IAgentRuntime runtime,
Logger<BaseAgent>? logger = null) : BaseAgent(id, runtime, "Test Agent", logger),
IHandle<TextMessage>,
IHandle<RpcTextMessage, RpcTextMessage>
{
public ValueTask HandleAsync(TextMessage item, MessageContext messageContext)
{
ReceivedMessages[item.Source] = item.Content;
return ValueTask.CompletedTask;
}
public ValueTask<RpcTextMessage> HandleAsync(RpcTextMessage item, MessageContext messageContext)
{
ReceivedMessages[item.Source] = item.Content;
return ValueTask.FromResult(new RpcTextMessage { Source = item.Source, Content = item.Content });
}
public List<object> ReceivedItems { get; private set; } = [];
/// <summary>
/// Key: source
/// Value: message
/// </summary>
private readonly Dictionary<string, object> _receivedMessages = new();
public Dictionary<string, object> ReceivedMessages => _receivedMessages;
}
[TypeSubscription("TestTopic")]
public class SubscribedProtobufAgent : TestProtobufAgent
{
public SubscribedProtobufAgent(AgentId id,
IAgentRuntime runtime,
Logger<BaseAgent>? logger = null) : base(id, runtime, logger)
{
}
}

View File

@ -0,0 +1,13 @@
syntax = "proto3";
option csharp_namespace = "Microsoft.AutoGen.Core.Grpc.Tests.Protobuf";
message TextMessage {
string content = 1;
string source = 2;
}
message RpcTextMessage {
string content = 1;
string source = 2;
}

View File

@ -109,7 +109,7 @@ public class AgentTests()
}
[Fact]
public async Task SubscribeAsyncRemoveSubscriptionAsyncAndGetSubscriptionsTest()
public async Task SubscribeAsyncRemoveSubscriptionAsyncTest()
{
var runtime = new InProcessRuntime();
await runtime.StartAsync();

View File

@ -1,43 +0,0 @@
syntax = "proto3";
package agents;
option csharp_namespace = "Microsoft.AutoGen.Contracts";
message TextMessage {
string textMessage = 1;
string source = 2;
}
message Input {
string message = 1;
}
message InputProcessed {
string route = 1;
}
message Output {
string message = 1;
}
message OutputWritten {
string route = 1;
}
message IOError {
string message = 1;
}
message NewMessageReceived {
string message = 1;
}
message ResponseGenerated {
string response = 1;
}
message GoodBye {
string message = 1;
}
message MessageStored {
string message = 1;
}
message ConversationClosed {
string user_id = 1;
string user_message = 2;
}
message Shutdown {
string message = 1;
}

View File

@ -1,8 +0,0 @@
syntax = "proto3";
package agents;
option csharp_namespace = "Microsoft.AutoGen.Contracts";
message AgentState {
string message = 1;
}