mirror of
https://github.com/microsoft/autogen.git
synced 2025-06-26 22:30:10 +00:00

<!-- Thank you for your contribution! Please review https://microsoft.github.io/autogen/docs/Contribute before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? <!-- Please give a short summary of the change and the problem this solves. --> ## Related issue number <!-- For example: "Closes #1234" --> ## Checks - [ ] I've included any doc changes needed for https://microsoft.github.io/autogen/. See https://microsoft.github.io/autogen/docs/Contribute#documentation to build and test documentation locally. - [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [ ] I've made sure all auto checks have passed.
135 lines
4.7 KiB
C#
135 lines
4.7 KiB
C#
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// TestGrpcWorkerConnection.cs
|
|
|
|
using System.Threading.Channels;
|
|
using Grpc.Core;
|
|
using Microsoft.AutoGen.Protobuf;
|
|
|
|
namespace Microsoft.AutoGen.Core.Grpc.Tests;
|
|
|
|
/// <summary>
/// Test stand-in for a gRPC worker connection. It runs a read pump and a write pump over the
/// supplied streams; every recognized inbound message is re-wrapped and queued so that the
/// write pump sends it back on <see cref="ResponseStream"/>.
/// </summary>
internal sealed class TestGrpcWorkerConnection : IAsyncDisposable
{
    private static long s_nextConnectionId;

    private Task _readTask = Task.CompletedTask;
    private Task _writeTask = Task.CompletedTask;
    private readonly string _connectionId = Interlocked.Increment(ref s_nextConnectionId).ToString();
    private readonly object _lock = new();
    private readonly HashSet<string> _supportedTypes = [];
    private readonly CancellationTokenSource _shutdownCancellationToken = new();

    // Captured once in the constructor so the pumps never touch the source's Token property,
    // which throws ObjectDisposedException once the source has been disposed.
    private readonly CancellationToken _shutdownToken;

    // 0 until DisposeAsync has run once; 1 afterwards.
    private int _disposed;

    /// <summary>
    /// Completes when both pumps started by <see cref="Connect"/> have finished.
    /// Already completed until <see cref="Connect"/> is called.
    /// </summary>
    public Task Completion { get; private set; } = Task.CompletedTask;

    /// <summary>Inbound message stream consumed by <see cref="RunReadPump"/>.</summary>
    public IAsyncStreamReader<Message> RequestStream { get; }

    /// <summary>Outbound message stream written by <see cref="RunWritePump"/>.</summary>
    public IServerStreamWriter<Message> ResponseStream { get; }

    /// <summary>The call context supplied at construction.</summary>
    public ServerCallContext ServerCallContext { get; }

    private readonly Channel<Message> _outboundMessages;

    /// <summary>
    /// Creates a connection over the given streams. No pump runs until <see cref="Connect"/> is called.
    /// </summary>
    /// <param name="requestStream">Stream the read pump consumes.</param>
    /// <param name="responseStream">Stream the write pump writes to.</param>
    /// <param name="context">Call context exposed through <see cref="ServerCallContext"/>.</param>
    public TestGrpcWorkerConnection(IAsyncStreamReader<Message> requestStream, IServerStreamWriter<Message> responseStream, ServerCallContext context)
    {
        RequestStream = requestStream;
        ResponseStream = responseStream;
        ServerCallContext = context;
        _shutdownToken = _shutdownCancellationToken.Token;

        // Many producers (SendMessage callers) but only the write pump reads.
        _outboundMessages = Channel.CreateUnbounded<Message>(new UnboundedChannelOptions { AllowSynchronousContinuations = true, SingleReader = true, SingleWriter = false });
    }

    /// <summary>
    /// Starts the read and write pumps on the thread pool.
    /// </summary>
    /// <returns>A task that completes when both pumps have finished; also stored in <see cref="Completion"/>.</returns>
    public Task Connect()
    {
        // Suppress execution-context flow while queuing the pumps so they do not capture the
        // caller's ambient context. Only restore it if this call was the one that suppressed it.
        var didSuppress = false;
        if (!ExecutionContext.IsFlowSuppressed())
        {
            didSuppress = true;
            ExecutionContext.SuppressFlow();
        }

        try
        {
            _readTask = Task.Run(RunReadPump);
            _writeTask = Task.Run(RunWritePump);
        }
        finally
        {
            if (didSuppress)
            {
                ExecutionContext.RestoreFlow();
            }
        }

        return Completion = Task.WhenAll(_readTask, _writeTask);
    }

    /// <summary>Records <paramref name="type"/> in the set of supported types.</summary>
    /// <param name="type">Type name to add.</param>
    public void AddSupportedType(string type)
    {
        lock (_lock)
        {
            _supportedTypes.Add(type);
        }
    }

    /// <summary>Returns a snapshot copy of the supported types.</summary>
    /// <returns>A new set; mutating it does not affect this connection.</returns>
    public HashSet<string> GetSupportedTypes()
    {
        lock (_lock)
        {
            return new HashSet<string>(_supportedTypes);
        }
    }

    /// <summary>Queues <paramref name="message"/> for the write pump.</summary>
    /// <param name="message">Message to enqueue on the outbound channel.</param>
    public async Task SendMessage(Message message)
    {
        await _outboundMessages.Writer.WriteAsync(message).ConfigureAwait(false);
    }

    /// <summary>
    /// Reads <see cref="RequestStream"/> until it ends or shutdown is signalled, re-queuing each
    /// request, response, or cloud-event message through <see cref="SendMessage"/>.
    /// Always signals shutdown on exit so the write pump stops too.
    /// </summary>
    /// <exception cref="RpcException">
    /// Thrown with <see cref="StatusCode.InvalidArgument"/> when a message carries none of the handled payload kinds.
    /// </exception>
    public async Task RunReadPump()
    {
        await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
        try
        {
            await foreach (var message in RequestStream.ReadAllAsync(_shutdownToken))
            {
                // Disabled gateway hook retained from the original source:
                //_gateway.OnReceivedMessageAsync(this, message, _shutdownCancellationToken.Token).Ignore();
                switch (message.MessageCase)
                {
                    case Message.MessageOneofCase.Request:
                        await SendMessage(new Message { Request = message.Request }).ConfigureAwait(false);
                        break;
                    case Message.MessageOneofCase.Response:
                        await SendMessage(new Message { Response = message.Response }).ConfigureAwait(false);
                        break;
                    case Message.MessageOneofCase.CloudEvent:
                        await SendMessage(new Message { CloudEvent = message.CloudEvent }).ConfigureAwait(false);
                        break;
                    default:
                        // if it wasn't recognized return bad request
                        throw new RpcException(new Status(StatusCode.InvalidArgument, $"Unknown message type for message '{message}'"));
                }
            }
        }
        catch (OperationCanceledException)
        {
            // Shutdown was requested; ending the pump quietly is the expected outcome.
        }
        finally
        {
            SignalShutdown();
            //_gateway.OnRemoveWorkerProcess(this);
        }
    }

    /// <summary>
    /// Drains the outbound channel into <see cref="ResponseStream"/> until shutdown is signalled.
    /// Always signals shutdown on exit so the read pump stops too.
    /// </summary>
    public async Task RunWritePump()
    {
        await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
        try
        {
            await foreach (var message in _outboundMessages.Reader.ReadAllAsync(_shutdownToken))
            {
                await ResponseStream.WriteAsync(message).ConfigureAwait(false);
            }
        }
        catch (OperationCanceledException)
        {
            // Shutdown was requested; ending the pump quietly is the expected outcome.
        }
        finally
        {
            SignalShutdown();
        }
    }

    /// <summary>
    /// Signals shutdown, waits for the pumps tracked by <see cref="Completion"/> without
    /// rethrowing their failures, then releases the cancellation source. Safe to call more than once.
    /// </summary>
    public async ValueTask DisposeAsync()
    {
        // Only the first caller cancels and disposes; later callers just wait for completion.
        var isFirstCall = Interlocked.Exchange(ref _disposed, 1) == 0;
        if (isFirstCall)
        {
            _shutdownCancellationToken.Cancel();
        }

        await Completion.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);

        if (isFirstCall)
        {
            // Both tracked pumps have finished, so nothing started by Connect still uses the source.
            _shutdownCancellationToken.Dispose();
        }
    }

    /// <summary>Returns a display name of the form <c>Connection-{id}</c>.</summary>
    public override string ToString() => $"Connection-{_connectionId}";

    // Requests shutdown of both pumps. A pump invoked directly (not via Connect) can outlive
    // DisposeAsync, in which case the source is already disposed and there is nothing left to signal.
    private void SignalShutdown()
    {
        try
        {
            _shutdownCancellationToken.Cancel();
        }
        catch (ObjectDisposedException)
        {
        }
    }
}
|