mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-30 00:30:23 +00:00
Dotnet Grpc worker implementation (#5245)
Co-authored-by: Jacob Alber <jaalber@microsoft.com> Co-authored-by: Ryan Sweet <rysweet@microsoft.com>
This commit is contained in:
parent
9030f75b4d
commit
08f9830bf7
34
.github/workflows/dotnet-build.yml
vendored
34
.github/workflows/dotnet-build.yml
vendored
@ -107,6 +107,38 @@ jobs:
|
||||
- name: Unit Test V2
|
||||
run: dotnet test --no-build -bl --configuration Release --filter "Category=UnitV2"
|
||||
|
||||
grpc-unit-tests:
|
||||
name: Dotnet Grpc unit tests
|
||||
needs: paths-filter
|
||||
if: needs.paths-filter.outputs.hasChanges == 'true'
|
||||
defaults:
|
||||
run:
|
||||
working-directory: dotnet
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ ubuntu-latest ]
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
- name: Setup .NET 8.0
|
||||
uses: actions/setup-dotnet@v4
|
||||
with:
|
||||
dotnet-version: '8.0.x'
|
||||
- name: Install dev certs
|
||||
run: dotnet --version && dotnet dev-certs https --trust
|
||||
- name: Restore dependencies
|
||||
run: |
|
||||
# dotnet nuget add source --name dotnet-tool https://pkgs.dev.azure.com/dnceng/public/_packaging/dotnet-tools/nuget/v3/index.json --configfile NuGet.config
|
||||
dotnet restore -bl
|
||||
- name: Build
|
||||
run: dotnet build --no-restore --configuration Release -bl /p:SignAssembly=true
|
||||
- name: GRPC tests
|
||||
run: dotnet test --no-build -bl --configuration Release --filter "Category=GRPC"
|
||||
|
||||
integration-test:
|
||||
strategy:
|
||||
fail-fast: true
|
||||
@ -224,6 +256,8 @@ jobs:
|
||||
with:
|
||||
dotnet-version: '8.0.x'
|
||||
global-json-file: dotnet/global.json
|
||||
- name: Install dev certs
|
||||
run: dotnet --version && dotnet dev-certs https --trust
|
||||
- name: Restore dependencies
|
||||
run: |
|
||||
dotnet restore -bl
|
||||
|
||||
1
.github/workflows/dotnet-release.yml
vendored
1
.github/workflows/dotnet-release.yml
vendored
@ -52,6 +52,7 @@ jobs:
|
||||
run: |
|
||||
echo "Build AutoGen"
|
||||
dotnet build --no-restore --configuration Release -bl /p:SignAssembly=true
|
||||
- run: sudo dotnet dev-certs https --trust --no-password
|
||||
- name: Unit Test
|
||||
run: dotnet test --no-build -bl --configuration Release
|
||||
env:
|
||||
|
||||
@ -118,6 +118,10 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Hello", "Hello", "{F42F9C8E
|
||||
EndProject
|
||||
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Core.Grpc", "src\Microsoft.AutoGen\Core.Grpc\Microsoft.AutoGen.Core.Grpc.csproj", "{3D83C6DB-ACEA-48F3-959F-145CCD2EE135}"
|
||||
EndProject
|
||||
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "GettingStartedGrpc", "samples\GettingStartedGrpc\GettingStartedGrpc.csproj", "{C3740DF1-18B1-4607-81E4-302F0308C848}"
|
||||
EndProject
|
||||
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Core.Grpc.Tests", "test\Microsoft.AutoGen.Core.Grpc.Tests\Microsoft.AutoGen.Core.Grpc.Tests.csproj", "{23A028D3-5EB1-4FA0-9CD1-A1340B830579}"
|
||||
EndProject
|
||||
Global
|
||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution
|
||||
Debug|Any CPU = Debug|Any CPU
|
||||
@ -306,6 +310,14 @@ Global
|
||||
{AAD593FE-A49B-425E-A9FE-A0022CD25E3D}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{AAD593FE-A49B-425E-A9FE-A0022CD25E3D}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{AAD593FE-A49B-425E-A9FE-A0022CD25E3D}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{C3740DF1-18B1-4607-81E4-302F0308C848}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{C3740DF1-18B1-4607-81E4-302F0308C848}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{C3740DF1-18B1-4607-81E4-302F0308C848}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{C3740DF1-18B1-4607-81E4-302F0308C848}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
EndGlobalSection
|
||||
GlobalSection(SolutionProperties) = preSolution
|
||||
HideSolutionNode = FALSE
|
||||
@ -359,6 +371,8 @@ Global
|
||||
{3D83C6DB-ACEA-48F3-959F-145CCD2EE135} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
|
||||
{AAD593FE-A49B-425E-A9FE-A0022CD25E3D} = {F42F9C8E-7BD9-4687-9B63-AFFA461AF5C1}
|
||||
{F42F9C8E-7BD9-4687-9B63-AFFA461AF5C1} = {CE0AA8D5-12B8-4628-9589-DAD8CB0DDCF6}
|
||||
{C3740DF1-18B1-4607-81E4-302F0308C848} = {CE0AA8D5-12B8-4628-9589-DAD8CB0DDCF6}
|
||||
{23A028D3-5EB1-4FA0-9CD1-A1340B830579} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
|
||||
EndGlobalSection
|
||||
GlobalSection(ExtensibilityGlobals) = postSolution
|
||||
SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B}
|
||||
|
||||
34
dotnet/samples/GettingStartedGrpc/Checker.cs
Normal file
34
dotnet/samples/GettingStartedGrpc/Checker.cs
Normal file
@ -0,0 +1,34 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Checker.cs
|
||||
|
||||
using Microsoft.AutoGen.Contracts;
|
||||
using Microsoft.AutoGen.Core;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
using TerminationF = System.Func<int, bool>;
|
||||
|
||||
namespace GettingStartedGrpcSample;
|
||||
|
||||
[TypeSubscription("default")]
|
||||
public class Checker(
|
||||
AgentId id,
|
||||
IAgentRuntime runtime,
|
||||
IHostApplicationLifetime hostApplicationLifetime,
|
||||
TerminationF runUntilFunc
|
||||
) :
|
||||
BaseAgent(id, runtime, "Modifier", null),
|
||||
IHandle<Events.CountUpdate>
|
||||
{
|
||||
public async ValueTask HandleAsync(Events.CountUpdate item, MessageContext messageContext)
|
||||
{
|
||||
if (!runUntilFunc(item.NewCount))
|
||||
{
|
||||
Console.WriteLine($"\nChecker:\n{item.NewCount} passed the check, continue.");
|
||||
await this.PublishMessageAsync(new Events.CountMessage { Content = item.NewCount }, new TopicId("default"));
|
||||
}
|
||||
else
|
||||
{
|
||||
Console.WriteLine($"\nChecker:\n{item.NewCount} failed the check, stopping.");
|
||||
hostApplicationLifetime.StopApplication();
|
||||
}
|
||||
}
|
||||
}
|
||||
26
dotnet/samples/GettingStartedGrpc/GettingStartedGrpc.csproj
Normal file
26
dotnet/samples/GettingStartedGrpc/GettingStartedGrpc.csproj
Normal file
@ -0,0 +1,26 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<OutputType>Exe</OutputType>
|
||||
<TargetFramework>net8.0</TargetFramework>
|
||||
<RootNamespace>getting_started</RootNamespace>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<Nullable>enable</Nullable>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\src\Microsoft.AutoGen\Contracts\Microsoft.AutoGen.Contracts.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.AutoGen\Core\Microsoft.AutoGen.Core.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.AutoGen\Core.Grpc\Microsoft.AutoGen.Core.Grpc.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
|
||||
<ItemGroup>
|
||||
<Protobuf Include="message.proto" GrpcServices="Client" Link="Protos\message.proto" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Grpc.Tools" PrivateAssets="All" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
29
dotnet/samples/GettingStartedGrpc/Modifier.cs
Normal file
29
dotnet/samples/GettingStartedGrpc/Modifier.cs
Normal file
@ -0,0 +1,29 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Modifier.cs
|
||||
|
||||
using Microsoft.AutoGen.Contracts;
|
||||
using Microsoft.AutoGen.Core;
|
||||
|
||||
using ModifyF = System.Func<int, int>;
|
||||
|
||||
namespace GettingStartedGrpcSample;
|
||||
|
||||
[TypeSubscription("default")]
|
||||
public class Modifier(
|
||||
AgentId id,
|
||||
IAgentRuntime runtime,
|
||||
ModifyF modifyFunc
|
||||
) :
|
||||
BaseAgent(id, runtime, "Modifier", null),
|
||||
IHandle<Events.CountMessage>
|
||||
{
|
||||
|
||||
public async ValueTask HandleAsync(Events.CountMessage item, MessageContext messageContext)
|
||||
{
|
||||
int newValue = modifyFunc(item.Content);
|
||||
Console.WriteLine($"\nModifier:\nModified {item.Content} to {newValue}");
|
||||
|
||||
var updateMessage = new Events.CountUpdate { NewCount = newValue };
|
||||
await this.PublishMessageAsync(updateMessage, topic: new TopicId("default"));
|
||||
}
|
||||
}
|
||||
36
dotnet/samples/GettingStartedGrpc/Program.cs
Normal file
36
dotnet/samples/GettingStartedGrpc/Program.cs
Normal file
@ -0,0 +1,36 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Program.cs
|
||||
using GettingStartedGrpcSample;
|
||||
using Microsoft.AutoGen.Contracts;
|
||||
using Microsoft.AutoGen.Core;
|
||||
using Microsoft.AutoGen.Core.Grpc;
|
||||
using Microsoft.Extensions.DependencyInjection.Extensions;
|
||||
using ModifyF = System.Func<int, int>;
|
||||
using TerminationF = System.Func<int, bool>;
|
||||
|
||||
ModifyF modifyFunc = (int x) => x - 1;
|
||||
TerminationF runUntilFunc = (int x) =>
|
||||
{
|
||||
return x <= 1;
|
||||
};
|
||||
|
||||
AgentsAppBuilder appBuilder = new AgentsAppBuilder();
|
||||
appBuilder.AddGrpcAgentWorker("http://localhost:50051");
|
||||
|
||||
appBuilder.Services.TryAddSingleton(modifyFunc);
|
||||
appBuilder.Services.TryAddSingleton(runUntilFunc);
|
||||
|
||||
appBuilder.AddAgent<Checker>("Checker");
|
||||
appBuilder.AddAgent<Modifier>("Modifier");
|
||||
|
||||
var app = await appBuilder.BuildAsync();
|
||||
await app.StartAsync();
|
||||
|
||||
// Send the initial count to the agents app, running on the `local` runtime, and pass through the registered services via the application `builder`
|
||||
await app.PublishMessageAsync(new GettingStartedGrpcSample.Events.CountMessage
|
||||
{
|
||||
Content = 10
|
||||
}, new TopicId("default"));
|
||||
|
||||
// Run until application shutdown
|
||||
await app.WaitForShutdownAsync();
|
||||
11
dotnet/samples/GettingStartedGrpc/message.proto
Normal file
11
dotnet/samples/GettingStartedGrpc/message.proto
Normal file
@ -0,0 +1,11 @@
|
||||
syntax = "proto3";
|
||||
|
||||
option csharp_namespace = "GettingStartedGrpcSample.Events";
|
||||
|
||||
message CountMessage {
|
||||
int32 content = 1;
|
||||
}
|
||||
|
||||
message CountUpdate {
|
||||
int32 new_count = 1;
|
||||
}
|
||||
@ -0,0 +1,76 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// AgentsAppBuilderExtensions.cs
|
||||
|
||||
using System.Diagnostics;
|
||||
using Grpc.Core;
|
||||
using Grpc.Net.Client.Configuration;
|
||||
using Microsoft.AutoGen.Contracts;
|
||||
using Microsoft.AutoGen.Protobuf;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.DependencyInjection.Extensions;
|
||||
using Microsoft.Extensions.Logging;
|
||||
namespace Microsoft.AutoGen.Core.Grpc;
|
||||
|
||||
public static class AgentsAppBuilderExtensions
|
||||
{
|
||||
private const string _defaultAgentServiceAddress = "http://localhost:53071";
|
||||
|
||||
// TODO: How do we ensure AddGrpcAgentWorker and UseInProcessRuntime are mutually exclusive?
|
||||
public static AgentsAppBuilder AddGrpcAgentWorker(this AgentsAppBuilder builder, string? agentServiceAddress = null)
|
||||
{
|
||||
builder.Services.AddGrpcClient<AgentRpc.AgentRpcClient>(options =>
|
||||
{
|
||||
options.Address = new Uri(agentServiceAddress ?? builder.Configuration.GetValue("AGENT_HOST", _defaultAgentServiceAddress));
|
||||
options.ChannelOptionsActions.Add(channelOptions =>
|
||||
{
|
||||
var loggerFactory = new LoggerFactory();
|
||||
if (Debugger.IsAttached)
|
||||
{
|
||||
channelOptions.HttpHandler = new SocketsHttpHandler
|
||||
{
|
||||
EnableMultipleHttp2Connections = false,
|
||||
KeepAlivePingDelay = TimeSpan.FromSeconds(200),
|
||||
KeepAlivePingTimeout = TimeSpan.FromSeconds(100),
|
||||
KeepAlivePingPolicy = HttpKeepAlivePingPolicy.Always
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
channelOptions.HttpHandler = new SocketsHttpHandler
|
||||
{
|
||||
EnableMultipleHttp2Connections = true,
|
||||
KeepAlivePingDelay = TimeSpan.FromSeconds(20),
|
||||
KeepAlivePingTimeout = TimeSpan.FromSeconds(10),
|
||||
KeepAlivePingPolicy = HttpKeepAlivePingPolicy.WithActiveRequests
|
||||
};
|
||||
}
|
||||
|
||||
var methodConfig = new MethodConfig
|
||||
{
|
||||
Names = { MethodName.Default },
|
||||
RetryPolicy = new RetryPolicy
|
||||
{
|
||||
MaxAttempts = 5,
|
||||
InitialBackoff = TimeSpan.FromSeconds(1),
|
||||
MaxBackoff = TimeSpan.FromSeconds(5),
|
||||
BackoffMultiplier = 1.5,
|
||||
RetryableStatusCodes = { StatusCode.Unavailable }
|
||||
}
|
||||
};
|
||||
|
||||
channelOptions.ServiceConfig = new() { MethodConfigs = { methodConfig } };
|
||||
channelOptions.ThrowOperationCanceledOnCancellation = true;
|
||||
});
|
||||
});
|
||||
|
||||
builder.Services.TryAddSingleton(DistributedContextPropagator.Current);
|
||||
builder.Services.AddSingleton<IAgentRuntime, GrpcAgentRuntime>();
|
||||
builder.Services.AddHostedService<GrpcAgentRuntime>(services =>
|
||||
{
|
||||
return (services.GetRequiredService<IAgentRuntime>() as GrpcAgentRuntime)!;
|
||||
});
|
||||
|
||||
return builder;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,43 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// CloudEventExtensions.cs
|
||||
|
||||
using Microsoft.AutoGen.Contracts;
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc;
|
||||
|
||||
internal static class CloudEventExtensions
|
||||
{
|
||||
// Convert an ISubscrptionDefinition to a Protobuf Subscription
|
||||
internal static CloudEvent CreateCloudEvent(Google.Protobuf.WellKnownTypes.Any payload, TopicId topic, string dataType, AgentId? sender, string messageId)
|
||||
{
|
||||
var attributes = new Dictionary<string, CloudEvent.Types.CloudEventAttributeValue>
|
||||
{
|
||||
{
|
||||
Constants.DATA_CONTENT_TYPE_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = Constants.DATA_CONTENT_TYPE_PROTOBUF_VALUE }
|
||||
},
|
||||
{
|
||||
Constants.DATA_SCHEMA_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = dataType }
|
||||
},
|
||||
{
|
||||
Constants.MESSAGE_KIND_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = Constants.MESSAGE_KIND_VALUE_PUBLISH }
|
||||
}
|
||||
};
|
||||
|
||||
if (sender != null)
|
||||
{
|
||||
var senderNonNull = (AgentId)sender;
|
||||
attributes.Add(Constants.AGENT_SENDER_TYPE_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = senderNonNull.Type });
|
||||
attributes.Add(Constants.AGENT_SENDER_KEY_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = senderNonNull.Key });
|
||||
}
|
||||
|
||||
return new CloudEvent
|
||||
{
|
||||
ProtoData = payload,
|
||||
Type = topic.Type,
|
||||
Source = topic.Source,
|
||||
Id = messageId,
|
||||
Attributes = { attributes }
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
21
dotnet/src/Microsoft.AutoGen/Core.Grpc/Constants.cs
Normal file
21
dotnet/src/Microsoft.AutoGen/Core.Grpc/Constants.cs
Normal file
@ -0,0 +1,21 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Constants.cs
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc;
|
||||
|
||||
public static class Constants
|
||||
{
|
||||
public const string DATA_CONTENT_TYPE_PROTOBUF_VALUE = "application/x-protobuf";
|
||||
public const string DATA_CONTENT_TYPE_JSON_VALUE = "application/json";
|
||||
public const string DATA_CONTENT_TYPE_TEXT_VALUE = "text/plain";
|
||||
|
||||
public const string DATA_CONTENT_TYPE_ATTR = "datacontenttype";
|
||||
public const string DATA_SCHEMA_ATTR = "dataschema";
|
||||
public const string AGENT_SENDER_TYPE_ATTR = "agagentsendertype";
|
||||
public const string AGENT_SENDER_KEY_ATTR = "agagentsenderkey";
|
||||
|
||||
public const string MESSAGE_KIND_ATTR = "agmsgkind";
|
||||
public const string MESSAGE_KIND_VALUE_PUBLISH = "publish";
|
||||
public const string MESSAGE_KIND_VALUE_RPC_REQUEST = "rpc_request";
|
||||
public const string MESSAGE_KIND_VALUE_RPC_RESPONSE = "rpc_response";
|
||||
}
|
||||
430
dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs
Normal file
430
dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs
Normal file
@ -0,0 +1,430 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GrpcAgentRuntime.cs
|
||||
|
||||
using System.Collections.Concurrent;
|
||||
using Grpc.Core;
|
||||
using Microsoft.AutoGen.Contracts;
|
||||
using Microsoft.AutoGen.Protobuf;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc;
|
||||
|
||||
internal sealed class AgentsContainer(IAgentRuntime hostingRuntime)
|
||||
{
|
||||
private readonly IAgentRuntime hostingRuntime = hostingRuntime;
|
||||
|
||||
private Dictionary<Contracts.AgentId, IHostableAgent> agentInstances = new();
|
||||
public Dictionary<string, ISubscriptionDefinition> Subscriptions = new();
|
||||
private Dictionary<AgentType, Func<Contracts.AgentId, IAgentRuntime, ValueTask<IHostableAgent>>> agentFactories = new();
|
||||
|
||||
public async ValueTask<IHostableAgent> EnsureAgentAsync(Contracts.AgentId agentId)
|
||||
{
|
||||
if (!this.agentInstances.TryGetValue(agentId, out IHostableAgent? agent))
|
||||
{
|
||||
if (!this.agentFactories.TryGetValue(agentId.Type, out Func<Contracts.AgentId, IAgentRuntime, ValueTask<IHostableAgent>>? factoryFunc))
|
||||
{
|
||||
throw new Exception($"Agent with name {agentId.Type} not found.");
|
||||
}
|
||||
|
||||
agent = await factoryFunc(agentId, this.hostingRuntime);
|
||||
this.agentInstances.Add(agentId, agent);
|
||||
}
|
||||
|
||||
return this.agentInstances[agentId];
|
||||
}
|
||||
|
||||
public async ValueTask<Contracts.AgentId> GetAgentAsync(Contracts.AgentId agentId, bool lazy = true)
|
||||
{
|
||||
if (!lazy)
|
||||
{
|
||||
await this.EnsureAgentAsync(agentId);
|
||||
}
|
||||
|
||||
return agentId;
|
||||
}
|
||||
|
||||
public AgentType RegisterAgentFactory(AgentType type, Func<Contracts.AgentId, IAgentRuntime, ValueTask<IHostableAgent>> factoryFunc)
|
||||
{
|
||||
if (this.agentFactories.ContainsKey(type))
|
||||
{
|
||||
throw new Exception($"Agent factory with type {type} already exists.");
|
||||
}
|
||||
|
||||
this.agentFactories.Add(type, factoryFunc);
|
||||
return type;
|
||||
}
|
||||
|
||||
public void AddSubscription(ISubscriptionDefinition subscription)
|
||||
{
|
||||
if (this.Subscriptions.ContainsKey(subscription.Id))
|
||||
{
|
||||
throw new Exception($"Subscription with id {subscription.Id} already exists.");
|
||||
}
|
||||
|
||||
this.Subscriptions.Add(subscription.Id, subscription);
|
||||
}
|
||||
|
||||
public bool RemoveSubscriptionAsync(string subscriptionId)
|
||||
{
|
||||
if (!this.Subscriptions.ContainsKey(subscriptionId))
|
||||
{
|
||||
throw new Exception($"Subscription with id {subscriptionId} does not exist.");
|
||||
}
|
||||
|
||||
return this.Subscriptions.Remove(subscriptionId);
|
||||
}
|
||||
|
||||
public HashSet<AgentType> RegisteredAgentTypes => this.agentFactories.Keys.ToHashSet();
|
||||
public IEnumerable<IHostableAgent> LiveAgents => this.agentInstances.Values;
|
||||
}
|
||||
|
||||
public sealed class GrpcAgentRuntime : IHostedService, IAgentRuntime, IMessageSink<Message>, IDisposable
|
||||
{
|
||||
public GrpcAgentRuntime(AgentRpc.AgentRpcClient client,
|
||||
IHostApplicationLifetime hostApplicationLifetime,
|
||||
IServiceProvider serviceProvider,
|
||||
ILogger<GrpcAgentRuntime> logger)
|
||||
{
|
||||
this._client = client;
|
||||
this._logger = logger;
|
||||
this._shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(hostApplicationLifetime.ApplicationStopping);
|
||||
|
||||
this._messageRouter = new GrpcMessageRouter(client, this, _clientId, logger, this._shutdownCts.Token);
|
||||
this._agentsContainer = new AgentsContainer(this);
|
||||
|
||||
this.ServiceProvider = serviceProvider;
|
||||
}
|
||||
|
||||
// Request ID -> ResultSink<...>
|
||||
private readonly ConcurrentDictionary<string, ResultSink<object?>> _pendingRequests = new();
|
||||
|
||||
private readonly AgentRpc.AgentRpcClient _client;
|
||||
private readonly GrpcMessageRouter _messageRouter;
|
||||
|
||||
private readonly ILogger<GrpcAgentRuntime> _logger;
|
||||
private readonly CancellationTokenSource _shutdownCts;
|
||||
|
||||
private readonly AgentsContainer _agentsContainer;
|
||||
|
||||
public IServiceProvider ServiceProvider { get; }
|
||||
|
||||
private Guid _clientId = Guid.NewGuid();
|
||||
private CallOptions CallOptions
|
||||
{
|
||||
get
|
||||
{
|
||||
var metadata = new Metadata
|
||||
{
|
||||
{ "client-id", this._clientId.ToString() }
|
||||
};
|
||||
return new CallOptions(headers: metadata);
|
||||
}
|
||||
}
|
||||
|
||||
public IProtoSerializationRegistry SerializationRegistry { get; } = new ProtobufSerializationRegistry();
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
this._shutdownCts.Cancel();
|
||||
this._messageRouter.Dispose();
|
||||
}
|
||||
|
||||
private async ValueTask HandleRequest(RpcRequest request, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (request is null)
|
||||
{
|
||||
throw new InvalidOperationException("Request is null.");
|
||||
}
|
||||
if (request.Payload is null)
|
||||
{
|
||||
throw new InvalidOperationException("Payload is null.");
|
||||
}
|
||||
if (request.Target is null)
|
||||
{
|
||||
throw new InvalidOperationException("Target is null.");
|
||||
}
|
||||
|
||||
var agentId = request.Target;
|
||||
var agent = await this._agentsContainer.EnsureAgentAsync(agentId.FromProtobuf());
|
||||
|
||||
// Convert payload back to object
|
||||
var payload = request.Payload;
|
||||
var message = payload.ToObject(SerializationRegistry);
|
||||
|
||||
var messageContext = new MessageContext(request.RequestId, cancellationToken)
|
||||
{
|
||||
Sender = request.Source?.FromProtobuf() ?? null,
|
||||
Topic = null,
|
||||
IsRpc = true
|
||||
};
|
||||
|
||||
var result = await agent.OnMessageAsync(message, messageContext);
|
||||
|
||||
if (result is not null)
|
||||
{
|
||||
var response = new RpcResponse
|
||||
{
|
||||
RequestId = request.RequestId,
|
||||
Payload = result.ToPayload(SerializationRegistry)
|
||||
};
|
||||
|
||||
var responseMessage = new Message
|
||||
{
|
||||
Response = response
|
||||
};
|
||||
|
||||
await this._messageRouter.RouteMessageAsync(responseMessage, cancellationToken);
|
||||
}
|
||||
}
|
||||
|
||||
private async ValueTask HandleResponse(RpcResponse request, CancellationToken _ = default)
|
||||
{
|
||||
if (request is null)
|
||||
{
|
||||
throw new InvalidOperationException("Request is null.");
|
||||
}
|
||||
if (request.Payload is null)
|
||||
{
|
||||
throw new InvalidOperationException("Payload is null.");
|
||||
}
|
||||
if (request.RequestId is null)
|
||||
{
|
||||
throw new InvalidOperationException("RequestId is null.");
|
||||
}
|
||||
|
||||
if (_pendingRequests.TryRemove(request.RequestId, out var resultSink))
|
||||
{
|
||||
var payload = request.Payload;
|
||||
var message = payload.ToObject(SerializationRegistry);
|
||||
resultSink.SetResult(message);
|
||||
}
|
||||
}
|
||||
|
||||
private async ValueTask HandlePublish(CloudEvent evt, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (evt is null)
|
||||
{
|
||||
throw new InvalidOperationException("CloudEvent is null.");
|
||||
}
|
||||
if (evt.ProtoData is null)
|
||||
{
|
||||
throw new InvalidOperationException("ProtoData is null.");
|
||||
}
|
||||
if (evt.Attributes is null)
|
||||
{
|
||||
throw new InvalidOperationException("Attributes is null.");
|
||||
}
|
||||
|
||||
var topic = new TopicId(evt.Type, evt.Source);
|
||||
Contracts.AgentId? sender = null;
|
||||
if (evt.Attributes.TryGetValue(Constants.AGENT_SENDER_TYPE_ATTR, out var typeValue) && evt.Attributes.TryGetValue(Constants.AGENT_SENDER_KEY_ATTR, out var keyValue))
|
||||
{
|
||||
sender = new Contracts.AgentId
|
||||
{
|
||||
Type = typeValue.CeString,
|
||||
Key = keyValue.CeString
|
||||
};
|
||||
}
|
||||
|
||||
var messageId = evt.Id;
|
||||
var typeName = evt.Attributes[Constants.DATA_SCHEMA_ATTR].CeString;
|
||||
var serializer = SerializationRegistry.GetSerializer(typeName) ?? throw new Exception();
|
||||
var message = serializer.Deserialize(evt.ProtoData);
|
||||
|
||||
var messageContext = new MessageContext(messageId, cancellationToken)
|
||||
{
|
||||
Sender = sender,
|
||||
Topic = topic,
|
||||
IsRpc = false
|
||||
};
|
||||
|
||||
// Iterate over subscriptions values to find receiving agents
|
||||
foreach (var subscription in this._agentsContainer.Subscriptions.Values)
|
||||
{
|
||||
if (subscription.Matches(topic))
|
||||
{
|
||||
var recipient = subscription.MapToAgent(topic);
|
||||
var agent = await this._agentsContainer.EnsureAgentAsync(recipient);
|
||||
await agent.OnMessageAsync(message, messageContext);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public ValueTask StartAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
return this._messageRouter.StartAsync(cancellationToken);
|
||||
}
|
||||
|
||||
Task IHostedService.StartAsync(CancellationToken cancellationToken) => this._messageRouter.StartAsync(cancellationToken).AsTask();
|
||||
|
||||
public Task StopAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
return this._messageRouter.StopAsync();
|
||||
}
|
||||
|
||||
public async ValueTask<object?> SendMessageAsync(object message, Contracts.AgentId recepient, Contracts.AgentId? sender = null, string? messageId = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (!SerializationRegistry.Exists(message.GetType()))
|
||||
{
|
||||
SerializationRegistry.RegisterSerializer(message.GetType());
|
||||
}
|
||||
|
||||
var payload = message.ToPayload(SerializationRegistry);
|
||||
var request = new RpcRequest
|
||||
{
|
||||
RequestId = Guid.NewGuid().ToString(),
|
||||
Source = sender?.ToProtobuf() ?? null,
|
||||
Target = recepient.ToProtobuf(),
|
||||
Payload = payload,
|
||||
};
|
||||
|
||||
Message msg = new()
|
||||
{
|
||||
Request = request
|
||||
};
|
||||
|
||||
// Create a future that will be completed when the response is received
|
||||
var resultSink = new ResultSink<object?>();
|
||||
this._pendingRequests.TryAdd(request.RequestId, resultSink);
|
||||
await this._messageRouter.RouteMessageAsync(msg, cancellationToken);
|
||||
|
||||
return await resultSink.Future;
|
||||
}
|
||||
|
||||
public async ValueTask PublishMessageAsync(object message, TopicId topic, Contracts.AgentId? sender = null, string? messageId = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (!SerializationRegistry.Exists(message.GetType()))
|
||||
{
|
||||
SerializationRegistry.RegisterSerializer(message.GetType());
|
||||
}
|
||||
var protoAny = (SerializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message);
|
||||
var typeName = SerializationRegistry.TypeNameResolver.ResolveTypeName(message.GetType());
|
||||
|
||||
var cloudEvent = CloudEventExtensions.CreateCloudEvent(protoAny, topic, typeName, sender, messageId ?? Guid.NewGuid().ToString());
|
||||
|
||||
Message msg = new()
|
||||
{
|
||||
CloudEvent = cloudEvent
|
||||
};
|
||||
|
||||
await this._messageRouter.RouteMessageAsync(msg, cancellationToken);
|
||||
}
|
||||
|
||||
public ValueTask<Contracts.AgentId> GetAgentAsync(Contracts.AgentId agentId, bool lazy = true) => this._agentsContainer.GetAgentAsync(agentId, lazy);
|
||||
|
||||
public ValueTask<Contracts.AgentId> GetAgentAsync(AgentType agentType, string key = "default", bool lazy = true)
|
||||
=> this.GetAgentAsync(new Contracts.AgentId(agentType, key), lazy);
|
||||
|
||||
public ValueTask<Contracts.AgentId> GetAgentAsync(string agent, string key = "default", bool lazy = true)
|
||||
=> this.GetAgentAsync(new Contracts.AgentId(agent, key), lazy);
|
||||
|
||||
public async ValueTask<IDictionary<string, object>> SaveAgentStateAsync(Contracts.AgentId agentId)
|
||||
{
|
||||
IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId);
|
||||
return await agent.SaveStateAsync();
|
||||
}
|
||||
|
||||
public async ValueTask LoadAgentStateAsync(Contracts.AgentId agentId, IDictionary<string, object> state)
|
||||
{
|
||||
IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId);
|
||||
await agent.LoadStateAsync(state);
|
||||
}
|
||||
|
||||
public async ValueTask<AgentMetadata> GetAgentMetadataAsync(Contracts.AgentId agentId)
|
||||
{
|
||||
IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId);
|
||||
return agent.Metadata;
|
||||
}
|
||||
|
||||
public async ValueTask AddSubscriptionAsync(ISubscriptionDefinition subscription)
|
||||
{
|
||||
this._agentsContainer.AddSubscription(subscription);
|
||||
|
||||
var _ = await this._client.AddSubscriptionAsync(new AddSubscriptionRequest
|
||||
{
|
||||
Subscription = subscription.ToProtobuf()
|
||||
}, this.CallOptions);
|
||||
}
|
||||
|
||||
public async ValueTask RemoveSubscriptionAsync(string subscriptionId)
|
||||
{
|
||||
this._agentsContainer.RemoveSubscriptionAsync(subscriptionId);
|
||||
|
||||
await this._client.RemoveSubscriptionAsync(new RemoveSubscriptionRequest
|
||||
{
|
||||
Id = subscriptionId
|
||||
}, this.CallOptions);
|
||||
}
|
||||
|
||||
public async ValueTask<AgentType> RegisterAgentFactoryAsync(AgentType type, Func<Contracts.AgentId, IAgentRuntime, ValueTask<IHostableAgent>> factoryFunc)
|
||||
{
|
||||
this._agentsContainer.RegisterAgentFactory(type, factoryFunc);
|
||||
|
||||
await this._client.RegisterAgentAsync(new RegisterAgentTypeRequest
|
||||
{
|
||||
Type = type,
|
||||
}, this.CallOptions);
|
||||
|
||||
return type;
|
||||
}
|
||||
|
||||
public ValueTask<AgentProxy> TryGetAgentProxyAsync(Contracts.AgentId agentId)
|
||||
{
|
||||
// TODO: Do we want to support getting remote agent proxies?
|
||||
return ValueTask.FromResult(new AgentProxy(agentId, this));
|
||||
}
|
||||
|
||||
public async ValueTask<IDictionary<string, object>> SaveStateAsync()
|
||||
{
|
||||
Dictionary<string, object> state = new();
|
||||
foreach (var agent in this._agentsContainer.LiveAgents)
|
||||
{
|
||||
state[agent.Id.ToString()] = await agent.SaveStateAsync();
|
||||
}
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
public async ValueTask LoadStateAsync(IDictionary<string, object> state)
|
||||
{
|
||||
HashSet<AgentType> registeredTypes = this._agentsContainer.RegisteredAgentTypes;
|
||||
|
||||
foreach (var agentIdStr in state.Keys)
|
||||
{
|
||||
Contracts.AgentId agentId = Contracts.AgentId.FromStr(agentIdStr);
|
||||
if (state[agentIdStr] is not IDictionary<string, object> agentStateDict)
|
||||
{
|
||||
throw new Exception($"Agent state for {agentId} is not a {typeof(IDictionary<string, object>)}: {state[agentIdStr].GetType()}");
|
||||
}
|
||||
|
||||
if (registeredTypes.Contains(agentId.Type))
|
||||
{
|
||||
IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId);
|
||||
await agent.LoadStateAsync(agentStateDict);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async ValueTask OnMessageAsync(Message message, CancellationToken cancellation = default)
|
||||
{
|
||||
switch (message.MessageCase)
|
||||
{
|
||||
case Message.MessageOneofCase.Request:
|
||||
var request = message.Request ?? throw new InvalidOperationException("Request is null.");
|
||||
await HandleRequest(request);
|
||||
break;
|
||||
case Message.MessageOneofCase.Response:
|
||||
var response = message.Response ?? throw new InvalidOperationException("Response is null.");
|
||||
await HandleResponse(response);
|
||||
break;
|
||||
case Message.MessageOneofCase.CloudEvent:
|
||||
var cloudEvent = message.CloudEvent ?? throw new InvalidOperationException("CloudEvent is null.");
|
||||
await HandlePublish(cloudEvent);
|
||||
break;
|
||||
default:
|
||||
throw new InvalidOperationException($"Unexpected message '{message}'.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
296
dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcMessageRouter.cs
Normal file
296
dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcMessageRouter.cs
Normal file
@ -0,0 +1,296 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GrpcMessageRouter.cs
|
||||
|
||||
using System.Threading.Channels;
|
||||
using Grpc.Core;
|
||||
using Microsoft.AutoGen.Protobuf;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc;
|
||||
|
||||
// TODO: Consider whether we want to just reuse IHandle
|
||||
internal interface IMessageSink<TMessage>
|
||||
{
|
||||
public ValueTask OnMessageAsync(TMessage message, CancellationToken cancellation = default);
|
||||
}
|
||||
|
||||
internal sealed class AutoRestartChannel : IDisposable
|
||||
{
|
||||
private readonly object _channelLock = new();
|
||||
private readonly AgentRpc.AgentRpcClient _client;
|
||||
private readonly Guid _clientId;
|
||||
private readonly ILogger<GrpcAgentRuntime> _logger;
|
||||
private readonly CancellationTokenSource _shutdownCts;
|
||||
private AsyncDuplexStreamingCall<Message, Message>? _channel;
|
||||
|
||||
public AutoRestartChannel(AgentRpc.AgentRpcClient client,
|
||||
Guid clientId,
|
||||
ILogger<GrpcAgentRuntime> logger,
|
||||
CancellationToken shutdownCancellation = default)
|
||||
{
|
||||
_client = client;
|
||||
_clientId = clientId;
|
||||
_logger = logger;
|
||||
_shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(shutdownCancellation);
|
||||
}
|
||||
|
||||
public void EnsureConnected()
|
||||
{
|
||||
_logger.LogInformation("Connecting to gRPC endpoint " + Environment.GetEnvironmentVariable("AGENT_HOST"));
|
||||
|
||||
if (this.RecreateChannel(null) == null)
|
||||
{
|
||||
throw new Exception("Failed to connect to gRPC endpoint.");
|
||||
};
|
||||
}
|
||||
|
||||
public AsyncDuplexStreamingCall<Message, Message> StreamingCall
|
||||
{
|
||||
get
|
||||
{
|
||||
if (_channel is { } channel)
|
||||
{
|
||||
return channel;
|
||||
}
|
||||
|
||||
lock (_channelLock)
|
||||
{
|
||||
if (_channel is not null)
|
||||
{
|
||||
return _channel;
|
||||
}
|
||||
|
||||
return RecreateChannel(null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public AsyncDuplexStreamingCall<Message, Message> RecreateChannel() => RecreateChannel(this._channel);
|
||||
|
||||
private AsyncDuplexStreamingCall<Message, Message> RecreateChannel(AsyncDuplexStreamingCall<Message, Message>? ownedChannel)
|
||||
{
|
||||
// Make sure we are only re-creating the channel if it does not exit or we are the owner.
|
||||
if (_channel is null || _channel == ownedChannel)
|
||||
{
|
||||
lock (_channelLock)
|
||||
{
|
||||
if (_channel is null || _channel == ownedChannel)
|
||||
{
|
||||
var metadata = new Metadata
|
||||
{
|
||||
{ "client-id", _clientId.ToString() }
|
||||
};
|
||||
_channel?.Dispose();
|
||||
_channel = _client.OpenChannel(cancellationToken: _shutdownCts.Token, headers: metadata);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return _channel;
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
IDisposable? channelDisposable = Interlocked.Exchange(ref this._channel, null);
|
||||
channelDisposable?.Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
internal sealed class GrpcMessageRouter(AgentRpc.AgentRpcClient client,
|
||||
IMessageSink<Message> incomingMessageSink,
|
||||
Guid clientId,
|
||||
ILogger<GrpcAgentRuntime> logger,
|
||||
CancellationToken shutdownCancellation = default) : IDisposable
|
||||
{
|
||||
private static readonly BoundedChannelOptions DefaultChannelOptions = new BoundedChannelOptions(1024)
|
||||
{
|
||||
AllowSynchronousContinuations = true,
|
||||
SingleReader = true,
|
||||
SingleWriter = false,
|
||||
FullMode = BoundedChannelFullMode.Wait
|
||||
};
|
||||
|
||||
private readonly ILogger<GrpcAgentRuntime> _logger = logger;
|
||||
|
||||
private readonly CancellationTokenSource _shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(shutdownCancellation);
|
||||
|
||||
private readonly IMessageSink<Message> _incomingMessageSink = incomingMessageSink;
|
||||
private readonly Channel<(Message Message, TaskCompletionSource WriteCompletionSource)> _outboundMessagesChannel
|
||||
// TODO: Enable a way to configure the channel options
|
||||
= Channel.CreateBounded<(Message, TaskCompletionSource)>(DefaultChannelOptions);
|
||||
|
||||
private readonly AutoRestartChannel _incomingMessageChannel = new AutoRestartChannel(client, clientId, logger, shutdownCancellation);
|
||||
|
||||
private Task? _readTask;
|
||||
private Task? _writeTask;
|
||||
|
||||
private async Task RunReadPump()
|
||||
{
|
||||
var cachedChannel = _incomingMessageChannel.StreamingCall;
|
||||
while (!_shutdownCts.Token.IsCancellationRequested)
|
||||
{
|
||||
try
|
||||
{
|
||||
await foreach (var message in cachedChannel.ResponseStream.ReadAllAsync(_shutdownCts.Token))
|
||||
{
|
||||
// next if message is null
|
||||
if (message == null)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
await _incomingMessageSink.OnMessageAsync(message, _shutdownCts.Token);
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Time to shut down.
|
||||
break;
|
||||
}
|
||||
catch (Exception ex) when (!_shutdownCts.IsCancellationRequested)
|
||||
{
|
||||
_logger.LogError(ex, "Error reading from channel.");
|
||||
cachedChannel = this._incomingMessageChannel.RecreateChannel();
|
||||
}
|
||||
catch
|
||||
{
|
||||
// Shutdown requested.
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async Task RunWritePump()
|
||||
{
|
||||
var cachedChannel = this._incomingMessageChannel.StreamingCall;
|
||||
var outboundMessages = _outboundMessagesChannel.Reader;
|
||||
while (!_shutdownCts.IsCancellationRequested)
|
||||
{
|
||||
(Message Message, TaskCompletionSource WriteCompletionSource) item = default;
|
||||
try
|
||||
{
|
||||
await outboundMessages.WaitToReadAsync().ConfigureAwait(false);
|
||||
|
||||
// Read the next message if we don't already have an unsent message
|
||||
// waiting to be sent.
|
||||
if (!outboundMessages.TryRead(out item))
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
while (!_shutdownCts.IsCancellationRequested)
|
||||
{
|
||||
await cachedChannel.RequestStream.WriteAsync(item.Message, _shutdownCts.Token).ConfigureAwait(false);
|
||||
item.WriteCompletionSource.TrySetResult();
|
||||
break;
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// Time to shut down.
|
||||
item.WriteCompletionSource?.TrySetCanceled();
|
||||
break;
|
||||
}
|
||||
catch (RpcException ex) when (ex.StatusCode == StatusCode.Unavailable)
|
||||
{
|
||||
// we could not connect to the endpoint - most likely we have the wrong port or failed ssl
|
||||
// we need to let the user know what port we tried to connect to and then do backoff and retry
|
||||
_logger.LogError(ex, "Error connecting to GRPC endpoint {Endpoint}.", Environment.GetEnvironmentVariable("AGENT_HOST"));
|
||||
break;
|
||||
}
|
||||
catch (RpcException ex) when (ex.StatusCode == StatusCode.OK)
|
||||
{
|
||||
_logger.LogError(ex, "Error writing to channel, continuing (Status OK). {ex}", cachedChannel.ToString());
|
||||
break;
|
||||
}
|
||||
catch (Exception ex) when (!_shutdownCts.IsCancellationRequested)
|
||||
{
|
||||
item.WriteCompletionSource?.TrySetException(ex);
|
||||
_logger.LogError(ex, $"Error writing to channel.{ex}");
|
||||
cachedChannel = this._incomingMessageChannel.RecreateChannel();
|
||||
continue;
|
||||
}
|
||||
catch
|
||||
{
|
||||
// Shutdown requested.
|
||||
item.WriteCompletionSource?.TrySetCanceled();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
while (outboundMessages.TryRead(out var item))
|
||||
{
|
||||
item.WriteCompletionSource.TrySetCanceled();
|
||||
}
|
||||
}
|
||||
|
||||
public ValueTask RouteMessageAsync(Message message, CancellationToken cancellation = default)
|
||||
{
|
||||
var tcs = new TaskCompletionSource();
|
||||
return _outboundMessagesChannel.Writer.WriteAsync((message, tcs), cancellation);
|
||||
}
|
||||
|
||||
public ValueTask StartAsync(CancellationToken cancellation)
|
||||
{
|
||||
// TODO: Should we error out on a noncancellable token?
|
||||
|
||||
this._incomingMessageChannel.EnsureConnected();
|
||||
var didSuppress = false;
|
||||
|
||||
// Make sure we do not mistakenly flow the ExecutionContext into the background pumping tasks.
|
||||
if (!ExecutionContext.IsFlowSuppressed())
|
||||
{
|
||||
didSuppress = true;
|
||||
ExecutionContext.SuppressFlow();
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
_readTask = Task.Run(RunReadPump, cancellation);
|
||||
_writeTask = Task.Run(RunWritePump, cancellation);
|
||||
|
||||
return ValueTask.CompletedTask;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
return ValueTask.FromException(ex);
|
||||
}
|
||||
finally
|
||||
{
|
||||
if (didSuppress)
|
||||
{
|
||||
ExecutionContext.RestoreFlow();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No point in returning a ValueTask here, since we are awaiting the two tasks
|
||||
public async Task StopAsync()
|
||||
{
|
||||
_shutdownCts.Cancel();
|
||||
|
||||
_outboundMessagesChannel.Writer.TryComplete();
|
||||
|
||||
List<Task> pendingTasks = new();
|
||||
if (_readTask is { } readTask)
|
||||
{
|
||||
pendingTasks.Add(readTask);
|
||||
}
|
||||
|
||||
if (_writeTask is { } writeTask)
|
||||
{
|
||||
pendingTasks.Add(writeTask);
|
||||
}
|
||||
|
||||
await Task.WhenAll(pendingTasks).ConfigureAwait(false);
|
||||
|
||||
this._incomingMessageChannel.Dispose();
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
_outboundMessagesChannel.Writer.TryComplete();
|
||||
this._incomingMessageChannel.Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,23 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// IAgentMessageSerializer.cs
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc;
|
||||
/// <summary>
|
||||
/// Interface for serializing and deserializing agent messages.
|
||||
/// </summary>
|
||||
public interface IAgentMessageSerializer
|
||||
{
|
||||
/// <summary>
|
||||
/// Serialize an agent message.
|
||||
/// </summary>
|
||||
/// <param name="message">The message to serialize.</param>
|
||||
/// <returns>The serialized message.</returns>
|
||||
Google.Protobuf.WellKnownTypes.Any Serialize(object message);
|
||||
|
||||
/// <summary>
|
||||
/// Deserialize an agent message.
|
||||
/// </summary>
|
||||
/// <param name="message">The message to deserialize.</param>
|
||||
/// <returns>The deserialized message.</returns>
|
||||
object Deserialize(Google.Protobuf.WellKnownTypes.Any message);
|
||||
}
|
||||
@ -0,0 +1,102 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// IAgentRuntimeExtensions.cs
|
||||
|
||||
using System.Diagnostics;
|
||||
using Google.Protobuf.Collections;
|
||||
using Microsoft.AutoGen.Contracts;
|
||||
using Microsoft.AutoGen.Protobuf;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using static Microsoft.AutoGen.Contracts.CloudEvent.Types;
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc;
|
||||
|
||||
public static class GrpcAgentRuntimeExtensions
|
||||
{
|
||||
public static (string?, string?) GetTraceIdAndState(GrpcAgentRuntime runtime, IDictionary<string, string> metadata)
|
||||
{
|
||||
var dcp = runtime.ServiceProvider.GetRequiredService<DistributedContextPropagator>();
|
||||
dcp.ExtractTraceIdAndState(metadata,
|
||||
static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
|
||||
{
|
||||
var metadata = (IDictionary<string, string>)carrier!;
|
||||
fieldValues = null;
|
||||
metadata.TryGetValue(fieldName, out fieldValue);
|
||||
},
|
||||
out var traceParent,
|
||||
out var traceState);
|
||||
return (traceParent, traceState);
|
||||
}
|
||||
public static (string?, string?) GetTraceIdAndState(GrpcAgentRuntime worker, MapField<string, CloudEventAttributeValue> metadata)
|
||||
{
|
||||
var dcp = worker.ServiceProvider.GetRequiredService<DistributedContextPropagator>();
|
||||
dcp.ExtractTraceIdAndState(metadata,
|
||||
static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
|
||||
{
|
||||
var metadata = (MapField<string, CloudEventAttributeValue>)carrier!;
|
||||
fieldValues = null;
|
||||
metadata.TryGetValue(fieldName, out var ceValue);
|
||||
fieldValue = ceValue?.CeString;
|
||||
},
|
||||
out var traceParent,
|
||||
out var traceState);
|
||||
return (traceParent, traceState);
|
||||
}
|
||||
public static void Update(GrpcAgentRuntime worker, RpcRequest request, Activity? activity = null)
|
||||
{
|
||||
var dcp = worker.ServiceProvider.GetRequiredService<DistributedContextPropagator>();
|
||||
dcp.Inject(activity, request.Metadata, static (carrier, key, value) =>
|
||||
{
|
||||
var metadata = (IDictionary<string, string>)carrier!;
|
||||
if (metadata.TryGetValue(key, out _))
|
||||
{
|
||||
metadata[key] = value;
|
||||
}
|
||||
else
|
||||
{
|
||||
metadata.Add(key, value);
|
||||
}
|
||||
});
|
||||
}
|
||||
public static void Update(GrpcAgentRuntime worker, CloudEvent cloudEvent, Activity? activity = null)
|
||||
{
|
||||
var dcp = worker.ServiceProvider.GetRequiredService<DistributedContextPropagator>();
|
||||
dcp.Inject(activity, cloudEvent.Attributes, static (carrier, key, value) =>
|
||||
{
|
||||
var mapField = (MapField<string, CloudEventAttributeValue>)carrier!;
|
||||
if (mapField.TryGetValue(key, out var ceValue))
|
||||
{
|
||||
mapField[key] = new CloudEventAttributeValue { CeString = value };
|
||||
}
|
||||
else
|
||||
{
|
||||
mapField.Add(key, new CloudEventAttributeValue { CeString = value });
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public static IDictionary<string, string> ExtractMetadata(GrpcAgentRuntime worker, IDictionary<string, string> metadata)
|
||||
{
|
||||
var dcp = worker.ServiceProvider.GetRequiredService<DistributedContextPropagator>();
|
||||
var baggage = dcp.ExtractBaggage(metadata, static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
|
||||
{
|
||||
var metadata = (IDictionary<string, string>)carrier!;
|
||||
fieldValues = null;
|
||||
metadata.TryGetValue(fieldName, out fieldValue);
|
||||
});
|
||||
|
||||
return baggage as IDictionary<string, string> ?? new Dictionary<string, string>();
|
||||
}
|
||||
public static IDictionary<string, string> ExtractMetadata(GrpcAgentRuntime worker, MapField<string, CloudEventAttributeValue> metadata)
|
||||
{
|
||||
var dcp = worker.ServiceProvider.GetRequiredService<DistributedContextPropagator>();
|
||||
var baggage = dcp.ExtractBaggage(metadata, static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
|
||||
{
|
||||
var metadata = (MapField<string, CloudEventAttributeValue>)carrier!;
|
||||
fieldValues = null;
|
||||
metadata.TryGetValue(fieldName, out var ceValue);
|
||||
fieldValue = ceValue?.CeString;
|
||||
});
|
||||
|
||||
return baggage as IDictionary<string, string> ?? new Dictionary<string, string>();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,10 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// IProtobufMessageSerializer.cs
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc;
|
||||
|
||||
public interface IProtobufMessageSerializer
|
||||
{
|
||||
Google.Protobuf.WellKnownTypes.Any Serialize(object input);
|
||||
object Deserialize(Google.Protobuf.WellKnownTypes.Any input);
|
||||
}
|
||||
@ -0,0 +1,27 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// ISerializationRegistry.cs
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc;
|
||||
|
||||
public interface IProtoSerializationRegistry
|
||||
{
|
||||
/// <summary>
|
||||
/// Registers a serializer for the specified type.
|
||||
/// </summary>
|
||||
/// <param name="type">The type to register.</param>
|
||||
void RegisterSerializer(System.Type type) => RegisterSerializer(type, new ProtobufMessageSerializer(type));
|
||||
|
||||
void RegisterSerializer(System.Type type, IProtobufMessageSerializer serializer);
|
||||
|
||||
/// <summary>
|
||||
/// Gets the serializer for the specified type.
|
||||
/// </summary>
|
||||
/// <param name="type">The type to get the serializer for.</param>
|
||||
/// <returns>The serializer for the specified type.</returns>
|
||||
IProtobufMessageSerializer? GetSerializer(System.Type type) => GetSerializer(TypeNameResolver.ResolveTypeName(type));
|
||||
IProtobufMessageSerializer? GetSerializer(string typeName);
|
||||
|
||||
ITypeNameResolver TypeNameResolver { get; }
|
||||
|
||||
bool Exists(System.Type type);
|
||||
}
|
||||
@ -0,0 +1,9 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// ITypeNameResolver.cs
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc;
|
||||
|
||||
public interface ITypeNameResolver
|
||||
{
|
||||
string ResolveTypeName(Type input);
|
||||
}
|
||||
@ -14,7 +14,6 @@
|
||||
<ItemGroup>
|
||||
<Protobuf Include="..\..\..\..\protos\agent_worker.proto" GrpcServices="Client;Server" Link="Protos\agent_worker.proto" />
|
||||
<Protobuf Include="..\..\..\..\protos\cloudevent.proto" GrpcServices="Client;Server" Link="Protos\cloudevent.proto" />
|
||||
<Protobuf Include="..\..\..\..\protos\agent_events.proto" GrpcServices="Client;Server" Link="Protos\agent_events.proto" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
|
||||
@ -0,0 +1,60 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// ProtobufConversionExtensions.cs
|
||||
|
||||
using Microsoft.AutoGen.Contracts;
|
||||
using Microsoft.AutoGen.Protobuf;
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc;
|
||||
|
||||
public static class ProtobufConversionExtensions
|
||||
{
|
||||
// Convert an ISubscrptionDefinition to a Protobuf Subscription
|
||||
public static Subscription? ToProtobuf(this ISubscriptionDefinition subscriptionDefinition)
|
||||
{
|
||||
// Check if is a TypeSubscription
|
||||
if (subscriptionDefinition is Contracts.TypeSubscription typeSubscription)
|
||||
{
|
||||
return new Subscription
|
||||
{
|
||||
Id = typeSubscription.Id,
|
||||
TypeSubscription = new Protobuf.TypeSubscription
|
||||
{
|
||||
TopicType = typeSubscription.TopicType,
|
||||
AgentType = typeSubscription.AgentType
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Check if is a TypePrefixSubscription
|
||||
if (subscriptionDefinition is Contracts.TypePrefixSubscription typePrefixSubscription)
|
||||
{
|
||||
return new Subscription
|
||||
{
|
||||
Id = typePrefixSubscription.Id,
|
||||
TypePrefixSubscription = new Protobuf.TypePrefixSubscription
|
||||
{
|
||||
TopicTypePrefix = typePrefixSubscription.TopicTypePrefix,
|
||||
AgentType = typePrefixSubscription.AgentType
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
// Convert AgentId from Protobuf to AgentId
|
||||
public static Contracts.AgentId FromProtobuf(this Protobuf.AgentId agentId)
|
||||
{
|
||||
return new Contracts.AgentId(agentId.Type, agentId.Key);
|
||||
}
|
||||
|
||||
// Convert AgentId from AgentId to Protobuf
|
||||
public static Protobuf.AgentId ToProtobuf(this Contracts.AgentId agentId)
|
||||
{
|
||||
return new Protobuf.AgentId
|
||||
{
|
||||
Type = agentId.Type,
|
||||
Key = agentId.Key
|
||||
};
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,46 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// ProtobufMessageSerializer.cs
|
||||
|
||||
using Google.Protobuf;
|
||||
using Google.Protobuf.WellKnownTypes;
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc;
|
||||
|
||||
/// <summary>
|
||||
/// Interface for serializing and deserializing agent messages.
|
||||
/// </summary>
|
||||
public class ProtobufMessageSerializer : IProtobufMessageSerializer
|
||||
{
|
||||
private System.Type _concreteType;
|
||||
|
||||
public ProtobufMessageSerializer(System.Type concreteType)
|
||||
{
|
||||
_concreteType = concreteType;
|
||||
}
|
||||
|
||||
public object Deserialize(Any message)
|
||||
{
|
||||
// Check if the concrete type is a proto IMessage
|
||||
if (typeof(IMessage).IsAssignableFrom(_concreteType))
|
||||
{
|
||||
var nameOfMethod = nameof(Any.Unpack);
|
||||
var result = message.GetType().GetMethods().Where(m => m.Name == nameOfMethod && m.IsGenericMethod).First().MakeGenericMethod(_concreteType).Invoke(message, null);
|
||||
return result as IMessage ?? throw new ArgumentException("Failed to deserialize", nameof(message));
|
||||
}
|
||||
|
||||
// Raise an exception if the concrete type is not a proto IMessage
|
||||
throw new ArgumentException("Concrete type must be a proto IMessage", nameof(_concreteType));
|
||||
}
|
||||
|
||||
public Any Serialize(object message)
|
||||
{
|
||||
// Check if message is a proto IMessage
|
||||
if (message is IMessage protoMessage)
|
||||
{
|
||||
return Any.Pack(protoMessage);
|
||||
}
|
||||
|
||||
// Raise an exception if the message is not a proto IMessage
|
||||
throw new ArgumentException("Message must be a proto IMessage", nameof(message));
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,37 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// ProtobufSerializationRegistry.cs
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc;
|
||||
|
||||
public class ProtobufSerializationRegistry : IProtoSerializationRegistry
|
||||
{
|
||||
private readonly Dictionary<string, IProtobufMessageSerializer> _serializers
|
||||
= new Dictionary<string, IProtobufMessageSerializer>();
|
||||
|
||||
public ITypeNameResolver TypeNameResolver => new ProtobufTypeNameResolver();
|
||||
|
||||
public bool Exists(Type type)
|
||||
{
|
||||
return _serializers.ContainsKey(TypeNameResolver.ResolveTypeName(type));
|
||||
}
|
||||
|
||||
public IProtobufMessageSerializer? GetSerializer(Type type)
|
||||
{
|
||||
return GetSerializer(TypeNameResolver.ResolveTypeName(type));
|
||||
}
|
||||
|
||||
public IProtobufMessageSerializer? GetSerializer(string typeName)
|
||||
{
|
||||
_serializers.TryGetValue(typeName, out var serializer);
|
||||
return serializer;
|
||||
}
|
||||
|
||||
public void RegisterSerializer(Type type, IProtobufMessageSerializer serializer)
|
||||
{
|
||||
if (_serializers.ContainsKey(TypeNameResolver.ResolveTypeName(type)))
|
||||
{
|
||||
throw new InvalidOperationException($"Serializer already registered for {type.FullName}");
|
||||
}
|
||||
_serializers[TypeNameResolver.ResolveTypeName(type)] = serializer;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,23 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// ProtobufTypeNameResolver.cs
|
||||
|
||||
using Google.Protobuf;
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc;
|
||||
|
||||
public class ProtobufTypeNameResolver : ITypeNameResolver
|
||||
{
|
||||
public string ResolveTypeName(Type input)
|
||||
{
|
||||
if (typeof(IMessage).IsAssignableFrom(input))
|
||||
{
|
||||
// TODO: Consider changing this to avoid instantiation...
|
||||
var protoMessage = (IMessage?)Activator.CreateInstance(input) ?? throw new InvalidOperationException($"Failed to create instance of {input.FullName}");
|
||||
return protoMessage.Descriptor.FullName;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new ArgumentException("Input must be a protobuf message.");
|
||||
}
|
||||
}
|
||||
}
|
||||
42
dotnet/src/Microsoft.AutoGen/Core.Grpc/RpcExtensions.cs
Normal file
42
dotnet/src/Microsoft.AutoGen/Core.Grpc/RpcExtensions.cs
Normal file
@ -0,0 +1,42 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// RpcExtensions.cs
|
||||
|
||||
using Google.Protobuf;
|
||||
using Microsoft.AutoGen.Protobuf;
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc;
|
||||
|
||||
internal static class RpcExtensions
|
||||
{
|
||||
|
||||
public static Payload ToPayload(this object message, IProtoSerializationRegistry serializationRegistry)
|
||||
{
|
||||
if (!serializationRegistry.Exists(message.GetType()))
|
||||
{
|
||||
serializationRegistry.RegisterSerializer(message.GetType());
|
||||
}
|
||||
var rpcMessage = (serializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message);
|
||||
|
||||
var typeName = serializationRegistry.TypeNameResolver.ResolveTypeName(message.GetType());
|
||||
const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf";
|
||||
|
||||
// Protobuf any to byte array
|
||||
Payload payload = new()
|
||||
{
|
||||
DataType = typeName,
|
||||
DataContentType = PAYLOAD_DATA_CONTENT_TYPE,
|
||||
Data = rpcMessage.ToByteString()
|
||||
};
|
||||
|
||||
return payload;
|
||||
}
|
||||
|
||||
public static object ToObject(this Payload payload, IProtoSerializationRegistry serializationRegistry)
|
||||
{
|
||||
var typeName = payload.DataType;
|
||||
var data = payload.Data;
|
||||
var serializer = serializationRegistry.GetSerializer(typeName) ?? throw new Exception();
|
||||
var any = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(data);
|
||||
return serializer.Deserialize(any);
|
||||
}
|
||||
}
|
||||
@ -4,6 +4,7 @@
|
||||
using System.Diagnostics;
|
||||
using System.Reflection;
|
||||
using Microsoft.AutoGen.Contracts;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
|
||||
@ -21,6 +22,7 @@ public class AgentsAppBuilder
|
||||
}
|
||||
|
||||
public IServiceCollection Services => this.builder.Services;
|
||||
public IConfiguration Configuration => this.builder.Configuration;
|
||||
|
||||
public void AddAgentsFromAssemblies()
|
||||
{
|
||||
|
||||
@ -1,263 +1,152 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// AgentGrpcTests.cs
|
||||
|
||||
using System.Collections.Concurrent;
|
||||
using System.Text.Json;
|
||||
using FluentAssertions;
|
||||
using Google.Protobuf.Reflection;
|
||||
using Microsoft.AutoGen.Contracts;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
// using Microsoft.AutoGen.Core.Tests;
|
||||
using Microsoft.AutoGen.Core.Grpc.Tests.Protobuf;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Xunit;
|
||||
using static Microsoft.AutoGen.Core.Grpc.Tests.AgentGrpcTests;
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc.Tests;
|
||||
|
||||
[Trait("Category", "UnitV2")]
|
||||
[Trait("Category", "GRPC")]
|
||||
public class AgentGrpcTests
|
||||
{
|
||||
/// <summary>
|
||||
/// Verify that if the agent is not initialized via AgentWorker, it should throw the correct exception.
|
||||
/// </summary>
|
||||
/// <returns>void</returns>
|
||||
[Fact]
|
||||
public async Task Agent_ShouldThrowException_WhenNotInitialized()
|
||||
public async Task AgentShouldNotReceiveMessagesWhenNotSubscribedTest()
|
||||
{
|
||||
using var runtime = new GrpcRuntime();
|
||||
var (_, agent) = runtime.Start(false); // Do not initialize
|
||||
var fixture = new GrpcAgentRuntimeFixture();
|
||||
var runtime = (GrpcAgentRuntime)await fixture.Start();
|
||||
|
||||
// Expect an exception when calling AddSubscriptionAsync because the agent is uninitialized
|
||||
await Assert.ThrowsAsync<UninitializedAgentWorker.AgentInitalizedIncorrectlyException>(
|
||||
async () => await agent.AddSubscriptionAsync("TestEvent")
|
||||
);
|
||||
}
|
||||
Logger<BaseAgent> logger = new(new LoggerFactory());
|
||||
TestProtobufAgent agent = null!;
|
||||
|
||||
/// <summary>
|
||||
/// validate that the agent is initialized correctly with implicit subs
|
||||
/// </summary>
|
||||
/// <returns>void</returns>
|
||||
[Fact]
|
||||
public async Task Agent_ShouldInitializeCorrectly()
|
||||
{
|
||||
using var runtime = new GrpcRuntime();
|
||||
var (worker, agent) = runtime.Start();
|
||||
Assert.Equal(nameof(GrpcAgentRuntime), worker.GetType().Name);
|
||||
await Task.Delay(5000);
|
||||
var subscriptions = await agent.GetSubscriptionsAsync();
|
||||
Assert.Equal(2, subscriptions.Count);
|
||||
}
|
||||
/// <summary>
|
||||
/// Test AddSubscriptionAsync method
|
||||
/// </summary>
|
||||
/// <returns>void</returns>
|
||||
[Fact]
|
||||
public async Task SubscribeAsync_UnsubscribeAsync_and_GetSubscriptionsTest()
|
||||
{
|
||||
using var runtime = new GrpcRuntime();
|
||||
var (_, agent) = runtime.Start();
|
||||
await agent.AddSubscriptionAsync("TestEvent");
|
||||
await Task.Delay(100);
|
||||
var subscriptions = await agent.GetSubscriptionsAsync().ConfigureAwait(true);
|
||||
var found = false;
|
||||
foreach (var subscription in subscriptions)
|
||||
await runtime.RegisterAgentFactoryAsync("MyAgent", async (id, runtime) =>
|
||||
{
|
||||
if (subscription.TypeSubscription.TopicType == "TestEvent")
|
||||
{
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
Assert.True(found);
|
||||
await agent.RemoveSubscriptionAsync("TestEvent").ConfigureAwait(true);
|
||||
await Task.Delay(1000);
|
||||
subscriptions = await agent.GetSubscriptionsAsync().ConfigureAwait(true);
|
||||
found = false;
|
||||
foreach (var subscription in subscriptions)
|
||||
{
|
||||
if (subscription.TypeSubscription.TopicType == "TestEvent")
|
||||
{
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
Assert.False(found);
|
||||
}
|
||||
agent = new TestProtobufAgent(id, runtime, logger);
|
||||
return await ValueTask.FromResult(agent);
|
||||
});
|
||||
|
||||
/// <summary>
|
||||
/// Test StoreAsync and ReadAsync methods
|
||||
/// </summary>
|
||||
/// <returns>void</returns>
|
||||
[Fact]
|
||||
public async Task StoreAsync_and_ReadAsyncTest()
|
||||
{
|
||||
using var runtime = new GrpcRuntime();
|
||||
var (_, agent) = runtime.Start();
|
||||
Dictionary<string, string> state = new()
|
||||
{
|
||||
{ "testdata", "Active" }
|
||||
};
|
||||
await agent.StoreAsync(new AgentState
|
||||
{
|
||||
AgentId = agent.AgentId,
|
||||
TextData = JsonSerializer.Serialize(state)
|
||||
}).ConfigureAwait(true);
|
||||
var readState = await agent.ReadAsync<AgentState>(agent.AgentId).ConfigureAwait(true);
|
||||
var read = JsonSerializer.Deserialize<Dictionary<string, string>>(readState.TextData) ?? new Dictionary<string, string> { { "data", "No state data found" } };
|
||||
read.TryGetValue("testdata", out var value);
|
||||
Assert.Equal("Active", value);
|
||||
}
|
||||
// Ensure the agent is actually created
|
||||
AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false);
|
||||
|
||||
// Validate agent ID
|
||||
agentId.Should().Be(agent.Id, "Agent ID should match the registered agent");
|
||||
|
||||
/// <summary>
|
||||
/// Test PublishMessageAsync method and ReceiveMessage method
|
||||
/// </summary>
|
||||
/// <returns>void</returns>
|
||||
[Fact]
|
||||
public async Task PublishMessageAsync_and_ReceiveMessageTest()
|
||||
{
|
||||
using var runtime = new GrpcRuntime();
|
||||
var (_, agent) = runtime.Start();
|
||||
var topicType = "TestTopic";
|
||||
await agent.AddSubscriptionAsync(topicType).ConfigureAwait(true);
|
||||
var subscriptions = await agent.GetSubscriptionsAsync().ConfigureAwait(true);
|
||||
var found = false;
|
||||
foreach (var subscription in subscriptions)
|
||||
{
|
||||
if (subscription.TypeSubscription.TopicType == topicType)
|
||||
{
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
Assert.True(found);
|
||||
await agent.PublishMessageAsync(new TextMessage()
|
||||
{
|
||||
Source = topicType,
|
||||
TextMessage_ = "buffer"
|
||||
}, topicType).ConfigureAwait(true);
|
||||
await Task.Delay(100);
|
||||
Assert.True(TestAgent.ReceivedMessages.ContainsKey(topicType));
|
||||
runtime.Stop();
|
||||
|
||||
await runtime.PublishMessageAsync(new Protobuf.TextMessage { Source = topicType, Content = "test" }, new TopicId(topicType)).ConfigureAwait(true);
|
||||
|
||||
agent.ReceivedMessages.Any().Should().BeFalse("Agent should not receive messages when not subscribed.");
|
||||
fixture.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeCorrectHandler()
|
||||
public async Task AgentShouldReceiveMessagesWhenSubscribedTest()
|
||||
{
|
||||
var agent = new TestAgent(new AgentsMetadata(TypeRegistry.Empty, new Dictionary<string, Type>(), new Dictionary<Type, HashSet<string>>(), new Dictionary<Type, HashSet<string>>()), new Logger<Agent>(new LoggerFactory()));
|
||||
var fixture = new GrpcAgentRuntimeFixture();
|
||||
var runtime = (GrpcAgentRuntime)await fixture.Start();
|
||||
|
||||
await agent.HandleObjectAsync("hello world");
|
||||
await agent.HandleObjectAsync(42);
|
||||
Logger<BaseAgent> logger = new(new LoggerFactory());
|
||||
SubscribedProtobufAgent agent = null!;
|
||||
|
||||
agent.ReceivedItems.Should().HaveCount(2);
|
||||
agent.ReceivedItems[0].Should().Be("hello world");
|
||||
agent.ReceivedItems[1].Should().Be(42);
|
||||
await runtime.RegisterAgentFactoryAsync("MyAgent", async (id, runtime) =>
|
||||
{
|
||||
agent = new SubscribedProtobufAgent(id, runtime, logger);
|
||||
return await ValueTask.FromResult(agent);
|
||||
});
|
||||
|
||||
// Ensure the agent is actually created
|
||||
AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false);
|
||||
|
||||
// Validate agent ID
|
||||
agentId.Should().Be(agent.Id, "Agent ID should match the registered agent");
|
||||
|
||||
await runtime.RegisterImplicitAgentSubscriptionsAsync<SubscribedProtobufAgent>("MyAgent");
|
||||
|
||||
var topicType = "TestTopic";
|
||||
|
||||
await runtime.PublishMessageAsync(new TextMessage { Source = topicType, Content = "test" }, new TopicId(topicType)).ConfigureAwait(true);
|
||||
|
||||
// Wait for the message to be processed
|
||||
await Task.Delay(100);
|
||||
|
||||
agent.ReceivedMessages.Any().Should().BeTrue("Agent should receive messages when subscribed.");
|
||||
fixture.Dispose();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The test agent is a simple agent that is used for testing purposes.
|
||||
/// </summary>
|
||||
public class TestAgent(
|
||||
[FromKeyedServices("AgentsMetadata")] AgentsMetadata eventTypes,
|
||||
Logger<Agent>? logger = null) : Agent(eventTypes, logger), IHandle<TextMessage>
|
||||
[Fact]
|
||||
public async Task SendMessageAsyncShouldReturnResponseTest()
|
||||
{
|
||||
public Task Handle(TextMessage item, CancellationToken cancellationToken = default)
|
||||
{
|
||||
ReceivedMessages[item.Source] = item.TextMessage_;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
public Task Handle(string item)
|
||||
{
|
||||
ReceivedItems.Add(item);
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
public Task Handle(int item)
|
||||
{
|
||||
ReceivedItems.Add(item);
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
public List<object> ReceivedItems { get; private set; } = [];
|
||||
// Arrange
|
||||
var fixture = new GrpcAgentRuntimeFixture();
|
||||
var runtime = (GrpcAgentRuntime)await fixture.Start();
|
||||
|
||||
/// <summary>
|
||||
/// Key: source
|
||||
/// Value: message
|
||||
/// </summary>
|
||||
public static ConcurrentDictionary<string, object> ReceivedMessages { get; private set; } = new();
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// GrpcRuntimeFixture - provides a fixture for the agent runtime.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// This fixture is used to provide a runtime for the agent tests.
|
||||
/// However, it is shared between tests. So operations from one test can affect another.
|
||||
/// </remarks>
|
||||
public sealed class GrpcRuntime : IDisposable
|
||||
{
|
||||
public IHost Client { get; private set; }
|
||||
public IHost? AppHost { get; private set; }
|
||||
|
||||
public GrpcRuntime()
|
||||
{
|
||||
Environment.SetEnvironmentVariable("ASPNETCORE_ENVIRONMENT", "Development");
|
||||
AppHost = Host.CreateDefaultBuilder().Build();
|
||||
Client = Host.CreateDefaultBuilder().Build();
|
||||
}
|
||||
|
||||
private static int GetAvailablePort()
|
||||
{
|
||||
using var listener = new System.Net.Sockets.TcpListener(System.Net.IPAddress.Loopback, 0);
|
||||
listener.Start();
|
||||
int port = ((System.Net.IPEndPoint)listener.LocalEndpoint).Port;
|
||||
listener.Stop();
|
||||
return port;
|
||||
}
|
||||
|
||||
private static async Task<IHost> StartClientAsync()
|
||||
{
|
||||
return await AgentsApp.StartAsync().ConfigureAwait(false);
|
||||
}
|
||||
private static async Task<IHost> StartAppHostAsync()
|
||||
{
|
||||
return await Microsoft.AutoGen.Runtime.Grpc.Host.StartAsync(local: false, useGrpc: true).ConfigureAwait(false);
|
||||
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Start - gets a new port and starts fresh instances
|
||||
/// </summary>
|
||||
public (IAgentRuntime, TestAgent) Start(bool initialize = true)
|
||||
{
|
||||
int port = GetAvailablePort(); // Get a new port per test run
|
||||
|
||||
// Update environment variables so each test runs independently
|
||||
Environment.SetEnvironmentVariable("ASPNETCORE_HTTPS_PORTS", port.ToString());
|
||||
Environment.SetEnvironmentVariable("AGENT_HOST", $"https://localhost:{port}");
|
||||
|
||||
AppHost = StartAppHostAsync().GetAwaiter().GetResult();
|
||||
Client = StartClientAsync().GetAwaiter().GetResult();
|
||||
|
||||
var agent = ActivatorUtilities.CreateInstance<TestAgent>(Client.Services);
|
||||
var worker = Client.Services.GetRequiredService<IAgentRuntime>();
|
||||
if (initialize)
|
||||
{
|
||||
Agent.Initialize(worker, agent);
|
||||
}
|
||||
|
||||
return (worker, agent);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Stop - stops the agent and ensures cleanup
|
||||
/// </summary>
|
||||
public void Stop()
|
||||
{
|
||||
Client?.StopAsync().GetAwaiter().GetResult();
|
||||
AppHost?.StopAsync().GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Dispose - Ensures cleanup after each test
|
||||
/// </summary>
|
||||
public void Dispose()
|
||||
{
|
||||
Stop();
|
||||
Logger<BaseAgent> logger = new(new LoggerFactory());
|
||||
await runtime.RegisterAgentFactoryAsync("MyAgent", async (id, runtime) => await ValueTask.FromResult(new TestProtobufAgent(id, runtime, logger)));
|
||||
var agentId = new AgentId("MyAgent", "default");
|
||||
var response = await runtime.SendMessageAsync(new RpcTextMessage { Source = "TestTopic", Content = "Request" }, agentId);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(response);
|
||||
Assert.IsType<RpcTextMessage>(response);
|
||||
if (response is RpcTextMessage responseString)
|
||||
{
|
||||
Assert.Equal("Request", responseString.Content);
|
||||
}
|
||||
fixture.Dispose();
|
||||
}
|
||||
|
||||
public class ReceiverAgent(AgentId id,
|
||||
IAgentRuntime runtime) : BaseAgent(id, runtime, "Receiver Agent", null),
|
||||
IHandle<TextMessage>
|
||||
{
|
||||
public ValueTask HandleAsync(TextMessage item, MessageContext messageContext)
|
||||
{
|
||||
ReceivedItems.Add(item.Content);
|
||||
return ValueTask.CompletedTask;
|
||||
}
|
||||
|
||||
public List<string> ReceivedItems { get; private set; } = [];
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task SubscribeAsyncRemoveSubscriptionAsyncAndGetSubscriptionsTest()
|
||||
{
|
||||
var fixture = new GrpcAgentRuntimeFixture();
|
||||
var runtime = (GrpcAgentRuntime)await fixture.Start();
|
||||
ReceiverAgent? agent = null;
|
||||
await runtime.RegisterAgentFactoryAsync("MyAgent", async (id, runtime) =>
|
||||
{
|
||||
agent = new ReceiverAgent(id, runtime);
|
||||
return await ValueTask.FromResult(agent);
|
||||
});
|
||||
|
||||
Assert.Null(agent);
|
||||
await runtime.GetAgentAsync("MyAgent", lazy: false);
|
||||
Assert.NotNull(agent);
|
||||
Assert.True(agent.ReceivedItems.Count == 0);
|
||||
|
||||
var topicTypeName = "TestTopic";
|
||||
await runtime.PublishMessageAsync(new TextMessage { Source = "topic", Content = "test" }, new TopicId(topicTypeName));
|
||||
await Task.Delay(100);
|
||||
|
||||
Assert.True(agent.ReceivedItems.Count == 0);
|
||||
|
||||
var subscription = new TypeSubscription(topicTypeName, "MyAgent");
|
||||
await runtime.AddSubscriptionAsync(subscription);
|
||||
|
||||
await runtime.PublishMessageAsync(new TextMessage { Source = "topic", Content = "test" }, new TopicId(topicTypeName));
|
||||
await Task.Delay(100);
|
||||
|
||||
Assert.True(agent.ReceivedItems.Count == 1);
|
||||
Assert.Equal("test", agent.ReceivedItems[0]);
|
||||
|
||||
await runtime.RemoveSubscriptionAsync(subscription.Id);
|
||||
await runtime.PublishMessageAsync(new TextMessage { Source = "topic", Content = "test" }, new TopicId(topicTypeName));
|
||||
await Task.Delay(100);
|
||||
|
||||
Assert.True(agent.ReceivedItems.Count == 1);
|
||||
fixture.Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,83 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GrpcAgentRuntimeFixture.cs
|
||||
using Microsoft.AspNetCore.Builder;
|
||||
using Microsoft.AutoGen.Contracts;
|
||||
// using Microsoft.AutoGen.Core.Tests;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc.Tests;
|
||||
/// <summary>
|
||||
/// Fixture for setting up the gRPC agent runtime for testing.
|
||||
/// </summary>
|
||||
public sealed class GrpcAgentRuntimeFixture : IDisposable
|
||||
{
|
||||
/// the gRPC agent runtime.
|
||||
public AgentsApp? Client { get; private set; }
|
||||
/// mock server for testing.
|
||||
public WebApplication? Server { get; private set; }
|
||||
|
||||
public GrpcAgentRuntimeFixture()
|
||||
{
|
||||
}
|
||||
/// <summary>
|
||||
/// Start - gets a new port and starts fresh instances
|
||||
/// </summary>
|
||||
public async Task<IAgentRuntime> Start(bool initialize = true)
|
||||
{
|
||||
int port = GetAvailablePort(); // Get a new port per test run
|
||||
|
||||
// Update environment variables so each test runs independently
|
||||
Environment.SetEnvironmentVariable("ASPNETCORE_HTTPS_PORTS", port.ToString());
|
||||
Environment.SetEnvironmentVariable("AGENT_HOST", $"https://localhost:{port}");
|
||||
Environment.SetEnvironmentVariable("ASPNETCORE_ENVIRONMENT", "Development");
|
||||
Server = ServerBuilder().Result;
|
||||
await Server.StartAsync().ConfigureAwait(true);
|
||||
Client = ClientBuilder().Result;
|
||||
await Client.StartAsync().ConfigureAwait(true);
|
||||
|
||||
var worker = Client.Services.GetRequiredService<IAgentRuntime>();
|
||||
|
||||
return (worker);
|
||||
}
|
||||
private static async Task<AgentsApp> ClientBuilder()
|
||||
{
|
||||
var appBuilder = new AgentsAppBuilder();
|
||||
appBuilder.AddGrpcAgentWorker();
|
||||
appBuilder.AddAgent<TestProtobufAgent>("TestAgent");
|
||||
return await appBuilder.BuildAsync();
|
||||
}
|
||||
private static async Task<WebApplication> ServerBuilder()
|
||||
{
|
||||
var builder = WebApplication.CreateBuilder();
|
||||
builder.Services.AddGrpc();
|
||||
var app = builder.Build();
|
||||
app.MapGrpcService<GrpcAgentServiceFixture>();
|
||||
return app;
|
||||
}
|
||||
private static int GetAvailablePort()
|
||||
{
|
||||
using var listener = new System.Net.Sockets.TcpListener(System.Net.IPAddress.Loopback, 0);
|
||||
listener.Start();
|
||||
int port = ((System.Net.IPEndPoint)listener.LocalEndpoint).Port;
|
||||
listener.Stop();
|
||||
return port;
|
||||
}
|
||||
/// <summary>
|
||||
/// Stop - stops the agent and ensures cleanup
|
||||
/// </summary>
|
||||
public void Stop()
|
||||
{
|
||||
(Client as IHost)?.StopAsync(TimeSpan.FromSeconds(30)).GetAwaiter().GetResult();
|
||||
Server?.StopAsync().GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Dispose - Ensures cleanup after each test
|
||||
/// </summary>
|
||||
public void Dispose()
|
||||
{
|
||||
Stop();
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,34 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// GrpcAgentServiceFixture.cs
|
||||
using Grpc.Core;
|
||||
using Microsoft.AutoGen.Protobuf;
|
||||
namespace Microsoft.AutoGen.Core.Grpc.Tests;
|
||||
|
||||
/// <summary>
|
||||
/// This fixture is largely just a loopback as we are testing the client side logic of the GrpcAgentRuntime in isolation from the rest of the system.
|
||||
/// </summary>
|
||||
public sealed class GrpcAgentServiceFixture() : AgentRpc.AgentRpcBase
|
||||
{
|
||||
public override async Task OpenChannel(IAsyncStreamReader<Message> requestStream, IServerStreamWriter<Message> responseStream, ServerCallContext context)
|
||||
{
|
||||
try
|
||||
{
|
||||
var workerProcess = new TestGrpcWorkerConnection(requestStream, responseStream, context);
|
||||
await workerProcess.Connect().ConfigureAwait(true);
|
||||
}
|
||||
catch
|
||||
{
|
||||
if (context.CancellationToken.IsCancellationRequested)
|
||||
{
|
||||
return;
|
||||
}
|
||||
throw;
|
||||
}
|
||||
}
|
||||
public override async Task<GetStateResponse> GetState(AgentId request, ServerCallContext context) => new GetStateResponse { AgentState = new AgentState { AgentId = request } };
|
||||
public override async Task<SaveStateResponse> SaveState(AgentState request, ServerCallContext context) => new SaveStateResponse { };
|
||||
public override async Task<AddSubscriptionResponse> AddSubscription(AddSubscriptionRequest request, ServerCallContext context) => new AddSubscriptionResponse { };
|
||||
public override async Task<RemoveSubscriptionResponse> RemoveSubscription(RemoveSubscriptionRequest request, ServerCallContext context) => new RemoveSubscriptionResponse { };
|
||||
public override async Task<GetSubscriptionsResponse> GetSubscriptions(GetSubscriptionsRequest request, ServerCallContext context) => new GetSubscriptionsResponse { };
|
||||
public override async Task<RegisterAgentTypeResponse> RegisterAgent(RegisterAgentTypeRequest request, ServerCallContext context) => new RegisterAgentTypeResponse { };
|
||||
}
|
||||
@ -10,8 +10,17 @@
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\src\Microsoft.AutoGen\Core\Microsoft.AutoGen.Core.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.AutoGen\Core.Grpc\Microsoft.AutoGen.Core.Grpc.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.AutoGen\AgentHost\Microsoft.AutoGen.AgentHost.csproj" />
|
||||
<PackageReference Include="Microsoft.Extensions.Hosting" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<Protobuf Include="./messages.proto" GrpcServices="Client;Server" Link="Protos\messages.proto" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Grpc.AspNetCore" />
|
||||
<PackageReference Include="Grpc.Net.ClientFactory" />
|
||||
<PackageReference Include="Grpc.Tools" PrivateAssets="All" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
||||
@ -0,0 +1,134 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// TestGrpcWorkerConnection.cs
|
||||
|
||||
using System.Threading.Channels;
|
||||
using Grpc.Core;
|
||||
using Microsoft.AutoGen.Protobuf;
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc.Tests;
|
||||
|
||||
internal sealed class TestGrpcWorkerConnection : IAsyncDisposable
|
||||
{
|
||||
private static long s_nextConnectionId;
|
||||
private Task _readTask = Task.CompletedTask;
|
||||
private Task _writeTask = Task.CompletedTask;
|
||||
private readonly string _connectionId = Interlocked.Increment(ref s_nextConnectionId).ToString();
|
||||
private readonly object _lock = new();
|
||||
private readonly HashSet<string> _supportedTypes = [];
|
||||
private readonly CancellationTokenSource _shutdownCancellationToken = new();
|
||||
public Task Completion { get; private set; } = Task.CompletedTask;
|
||||
public IAsyncStreamReader<Message> RequestStream { get; }
|
||||
public IServerStreamWriter<Message> ResponseStream { get; }
|
||||
public ServerCallContext ServerCallContext { get; }
|
||||
private readonly Channel<Message> _outboundMessages;
|
||||
public TestGrpcWorkerConnection(IAsyncStreamReader<Message> requestStream, IServerStreamWriter<Message> responseStream, ServerCallContext context)
|
||||
{
|
||||
RequestStream = requestStream;
|
||||
ResponseStream = responseStream;
|
||||
ServerCallContext = context;
|
||||
_outboundMessages = Channel.CreateUnbounded<Message>(new UnboundedChannelOptions { AllowSynchronousContinuations = true, SingleReader = true, SingleWriter = false });
|
||||
}
|
||||
public Task Connect()
|
||||
{
|
||||
var didSuppress = false;
|
||||
if (!ExecutionContext.IsFlowSuppressed())
|
||||
{
|
||||
didSuppress = true;
|
||||
ExecutionContext.SuppressFlow();
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
_readTask = Task.Run(RunReadPump);
|
||||
_writeTask = Task.Run(RunWritePump);
|
||||
}
|
||||
finally
|
||||
{
|
||||
if (didSuppress)
|
||||
{
|
||||
ExecutionContext.RestoreFlow();
|
||||
}
|
||||
}
|
||||
|
||||
return Completion = Task.WhenAll(_readTask, _writeTask);
|
||||
}
|
||||
public void AddSupportedType(string type)
|
||||
{
|
||||
lock (_lock)
|
||||
{
|
||||
_supportedTypes.Add(type);
|
||||
}
|
||||
}
|
||||
public HashSet<string> GetSupportedTypes()
|
||||
{
|
||||
lock (_lock)
|
||||
{
|
||||
return new HashSet<string>(_supportedTypes);
|
||||
}
|
||||
}
|
||||
public async Task SendMessage(Message message)
|
||||
{
|
||||
await _outboundMessages.Writer.WriteAsync(message).ConfigureAwait(false);
|
||||
}
|
||||
public async Task RunReadPump()
|
||||
{
|
||||
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
|
||||
try
|
||||
{
|
||||
await foreach (var message in RequestStream.ReadAllAsync(_shutdownCancellationToken.Token))
|
||||
{
|
||||
//_gateway.OnReceivedMessageAsync(this, message, _shutdownCancellationToken.Token).Ignore();
|
||||
switch (message.MessageCase)
|
||||
{
|
||||
case Message.MessageOneofCase.Request:
|
||||
await SendMessage(new Message { Request = message.Request }).ConfigureAwait(false);
|
||||
break;
|
||||
case Message.MessageOneofCase.Response:
|
||||
await SendMessage(new Message { Response = message.Response }).ConfigureAwait(false);
|
||||
break;
|
||||
case Message.MessageOneofCase.CloudEvent:
|
||||
await SendMessage(new Message { CloudEvent = message.CloudEvent }).ConfigureAwait(false);
|
||||
break;
|
||||
default:
|
||||
// if it wasn't recognized return bad request
|
||||
throw new RpcException(new Status(StatusCode.InvalidArgument, $"Unknown message type for message '{message}'"));
|
||||
};
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
}
|
||||
finally
|
||||
{
|
||||
_shutdownCancellationToken.Cancel();
|
||||
//_gateway.OnRemoveWorkerProcess(this);
|
||||
}
|
||||
}
|
||||
|
||||
public async Task RunWritePump()
|
||||
{
|
||||
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
|
||||
try
|
||||
{
|
||||
await foreach (var message in _outboundMessages.Reader.ReadAllAsync(_shutdownCancellationToken.Token))
|
||||
{
|
||||
await ResponseStream.WriteAsync(message);
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
}
|
||||
finally
|
||||
{
|
||||
_shutdownCancellationToken.Cancel();
|
||||
}
|
||||
}
|
||||
|
||||
public async ValueTask DisposeAsync()
|
||||
{
|
||||
_shutdownCancellationToken.Cancel();
|
||||
await Completion.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
|
||||
}
|
||||
|
||||
public override string ToString() => $"Connection-{_connectionId}";
|
||||
}
|
||||
@ -0,0 +1,50 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// TestProtobufAgent.cs
|
||||
|
||||
using Microsoft.AutoGen.Contracts;
|
||||
using Microsoft.AutoGen.Core.Grpc.Tests.Protobuf;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Microsoft.AutoGen.Core.Grpc.Tests;
|
||||
|
||||
/// <summary>
|
||||
/// The test agent is a simple agent that is used for testing purposes.
|
||||
/// </summary>
|
||||
public class TestProtobufAgent(AgentId id,
|
||||
IAgentRuntime runtime,
|
||||
Logger<BaseAgent>? logger = null) : BaseAgent(id, runtime, "Test Agent", logger),
|
||||
IHandle<TextMessage>,
|
||||
IHandle<RpcTextMessage, RpcTextMessage>
|
||||
|
||||
{
|
||||
public ValueTask HandleAsync(TextMessage item, MessageContext messageContext)
|
||||
{
|
||||
ReceivedMessages[item.Source] = item.Content;
|
||||
return ValueTask.CompletedTask;
|
||||
}
|
||||
|
||||
public ValueTask<RpcTextMessage> HandleAsync(RpcTextMessage item, MessageContext messageContext)
|
||||
{
|
||||
ReceivedMessages[item.Source] = item.Content;
|
||||
return ValueTask.FromResult(new RpcTextMessage { Source = item.Source, Content = item.Content });
|
||||
}
|
||||
|
||||
public List<object> ReceivedItems { get; private set; } = [];
|
||||
|
||||
/// <summary>
|
||||
/// Key: source
|
||||
/// Value: message
|
||||
/// </summary>
|
||||
private readonly Dictionary<string, object> _receivedMessages = new();
|
||||
public Dictionary<string, object> ReceivedMessages => _receivedMessages;
|
||||
}
|
||||
|
||||
[TypeSubscription("TestTopic")]
|
||||
public class SubscribedProtobufAgent : TestProtobufAgent
|
||||
{
|
||||
public SubscribedProtobufAgent(AgentId id,
|
||||
IAgentRuntime runtime,
|
||||
Logger<BaseAgent>? logger = null) : base(id, runtime, logger)
|
||||
{
|
||||
}
|
||||
}
|
||||
13
dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/messages.proto
Normal file
13
dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/messages.proto
Normal file
@ -0,0 +1,13 @@
|
||||
syntax = "proto3";
|
||||
|
||||
option csharp_namespace = "Microsoft.AutoGen.Core.Grpc.Tests.Protobuf";
|
||||
|
||||
message TextMessage {
|
||||
string content = 1;
|
||||
string source = 2;
|
||||
}
|
||||
|
||||
message RpcTextMessage {
|
||||
string content = 1;
|
||||
string source = 2;
|
||||
}
|
||||
@ -109,7 +109,7 @@ public class AgentTests()
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task SubscribeAsyncRemoveSubscriptionAsyncAndGetSubscriptionsTest()
|
||||
public async Task SubscribeAsyncRemoveSubscriptionAsyncTest()
|
||||
{
|
||||
var runtime = new InProcessRuntime();
|
||||
await runtime.StartAsync();
|
||||
|
||||
@ -1,43 +0,0 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package agents;
|
||||
|
||||
option csharp_namespace = "Microsoft.AutoGen.Contracts";
|
||||
message TextMessage {
|
||||
string textMessage = 1;
|
||||
string source = 2;
|
||||
}
|
||||
message Input {
|
||||
string message = 1;
|
||||
}
|
||||
message InputProcessed {
|
||||
string route = 1;
|
||||
}
|
||||
message Output {
|
||||
string message = 1;
|
||||
}
|
||||
message OutputWritten {
|
||||
string route = 1;
|
||||
}
|
||||
message IOError {
|
||||
string message = 1;
|
||||
}
|
||||
message NewMessageReceived {
|
||||
string message = 1;
|
||||
}
|
||||
message ResponseGenerated {
|
||||
string response = 1;
|
||||
}
|
||||
message GoodBye {
|
||||
string message = 1;
|
||||
}
|
||||
message MessageStored {
|
||||
string message = 1;
|
||||
}
|
||||
message ConversationClosed {
|
||||
string user_id = 1;
|
||||
string user_message = 2;
|
||||
}
|
||||
message Shutdown {
|
||||
string message = 1;
|
||||
}
|
||||
@ -1,8 +0,0 @@
|
||||
syntax = "proto3";
|
||||
package agents;
|
||||
|
||||
option csharp_namespace = "Microsoft.AutoGen.Contracts";
|
||||
|
||||
message AgentState {
|
||||
string message = 1;
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user