mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-15 20:21:10 +00:00
ensure that cancellation token is passed in InvokeWithActivityAsync (#4329)
* ensure that cancellation token is passed in InvokeWithActivityAsync * add comments and baggange is not nullable * store ncrunch settings * shange signature to have nullable activity at the end of Update * correct spelling case * primary contructor * add docs and make async interface accept cancellation tokens * address code ql error
This commit is contained in:
parent
01dc56b244
commit
d186a41ed1
8
dotnet/AutoGen.v3.ncrunchsolution
Normal file
8
dotnet/AutoGen.v3.ncrunchsolution
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<SolutionConfiguration>
|
||||||
|
<Settings>
|
||||||
|
<AllowParallelTestExecution>True</AllowParallelTestExecution>
|
||||||
|
<EnableRDI>True</EnableRDI>
|
||||||
|
<RdiConfigured>True</RdiConfigured>
|
||||||
|
<SolutionConfigured>True</SolutionConfigured>
|
||||||
|
</Settings>
|
||||||
|
</SolutionConfiguration>
|
@ -15,8 +15,8 @@ public interface IAgentRuntime
|
|||||||
ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken = default);
|
ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken = default);
|
||||||
ValueTask SendMessageAsync(Message message, CancellationToken cancellationToken = default);
|
ValueTask SendMessageAsync(Message message, CancellationToken cancellationToken = default);
|
||||||
ValueTask PublishEventAsync(CloudEvent @event, CancellationToken cancellationToken = default);
|
ValueTask PublishEventAsync(CloudEvent @event, CancellationToken cancellationToken = default);
|
||||||
void Update(Activity? activity, RpcRequest request);
|
void Update(RpcRequest request, Activity? activity);
|
||||||
void Update(Activity? activity, CloudEvent cloudEvent);
|
void Update(CloudEvent cloudEvent, Activity? activity);
|
||||||
(string?, string?) GetTraceIDandState(IDictionary<string, string> metadata);
|
(string?, string?) GetTraceIdAndState(IDictionary<string, string> metadata);
|
||||||
IDictionary<string, string> ExtractMetadata(IDictionary<string, string> metadata);
|
IDictionary<string, string> ExtractMetadata(IDictionary<string, string> metadata);
|
||||||
}
|
}
|
||||||
|
@ -3,8 +3,24 @@
|
|||||||
|
|
||||||
namespace Microsoft.AutoGen.Abstractions;
|
namespace Microsoft.AutoGen.Abstractions;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Interface for managing the state of an agent.
|
||||||
|
/// </summary>
|
||||||
public interface IAgentState
|
public interface IAgentState
|
||||||
{
|
{
|
||||||
ValueTask<AgentState> ReadStateAsync();
|
/// <summary>
|
||||||
ValueTask<string> WriteStateAsync(AgentState state, string eTag);
|
/// Reads the current state of the agent asynchronously.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="cancellationToken">A token to cancel the operation.</param>
|
||||||
|
/// <returns>A task that represents the asynchronous read operation. The task result contains the current state of the agent.</returns>
|
||||||
|
ValueTask<AgentState> ReadStateAsync(CancellationToken cancellationToken = default);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Writes the specified state of the agent asynchronously.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="state">The state to write.</param>
|
||||||
|
/// <param name="eTag">The ETag for concurrency control.</param>
|
||||||
|
/// <param name="cancellationToken">A token to cancel the operation.</param>
|
||||||
|
/// <returns>A task that represents the asynchronous write operation. The task result contains the ETag of the written state.</returns>
|
||||||
|
ValueTask<string> WriteStateAsync(AgentState state, string eTag, CancellationToken cancellationToken = default);
|
||||||
}
|
}
|
||||||
|
@ -93,7 +93,7 @@ public abstract class AgentBase : IAgentBase, IHandle
|
|||||||
{
|
{
|
||||||
var activity = this.ExtractActivity(msg.CloudEvent.Type, msg.CloudEvent.Metadata);
|
var activity = this.ExtractActivity(msg.CloudEvent.Type, msg.CloudEvent.Metadata);
|
||||||
await this.InvokeWithActivityAsync(
|
await this.InvokeWithActivityAsync(
|
||||||
static ((AgentBase Agent, CloudEvent Item) state) => state.Agent.CallHandler(state.Item),
|
static ((AgentBase Agent, CloudEvent Item) state, CancellationToken _) => state.Agent.CallHandler(state.Item),
|
||||||
(this, msg.CloudEvent),
|
(this, msg.CloudEvent),
|
||||||
activity,
|
activity,
|
||||||
msg.CloudEvent.Type, cancellationToken).ConfigureAwait(false);
|
msg.CloudEvent.Type, cancellationToken).ConfigureAwait(false);
|
||||||
@ -103,7 +103,7 @@ public abstract class AgentBase : IAgentBase, IHandle
|
|||||||
{
|
{
|
||||||
var activity = this.ExtractActivity(msg.Request.Method, msg.Request.Metadata);
|
var activity = this.ExtractActivity(msg.Request.Method, msg.Request.Metadata);
|
||||||
await this.InvokeWithActivityAsync(
|
await this.InvokeWithActivityAsync(
|
||||||
static ((AgentBase Agent, RpcRequest Request) state) => state.Agent.OnRequestCoreAsync(state.Request),
|
static ((AgentBase Agent, RpcRequest Request) state, CancellationToken ct) => state.Agent.OnRequestCoreAsync(state.Request, ct),
|
||||||
(this, msg.Request),
|
(this, msg.Request),
|
||||||
activity,
|
activity,
|
||||||
msg.Request.Method, cancellationToken).ConfigureAwait(false);
|
msg.Request.Method, cancellationToken).ConfigureAwait(false);
|
||||||
@ -142,8 +142,8 @@ public abstract class AgentBase : IAgentBase, IHandle
|
|||||||
}
|
}
|
||||||
public async Task<T> ReadAsync<T>(AgentId agentId, CancellationToken cancellationToken = default) where T : IMessage, new()
|
public async Task<T> ReadAsync<T>(AgentId agentId, CancellationToken cancellationToken = default) where T : IMessage, new()
|
||||||
{
|
{
|
||||||
var agentstate = await _context.ReadAsync(agentId, cancellationToken).ConfigureAwait(false);
|
var agentState = await _context.ReadAsync(agentId, cancellationToken).ConfigureAwait(false);
|
||||||
return agentstate.FromAgentState<T>();
|
return agentState.FromAgentState<T>();
|
||||||
}
|
}
|
||||||
private void OnResponseCore(RpcResponse response)
|
private void OnResponseCore(RpcResponse response)
|
||||||
{
|
{
|
||||||
@ -195,9 +195,9 @@ public abstract class AgentBase : IAgentBase, IHandle
|
|||||||
activity?.SetTag("peer.service", target.ToString());
|
activity?.SetTag("peer.service", target.ToString());
|
||||||
|
|
||||||
var completion = new TaskCompletionSource<RpcResponse>(TaskCreationOptions.RunContinuationsAsynchronously);
|
var completion = new TaskCompletionSource<RpcResponse>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||||
_context.Update(activity, request);
|
_context.Update(request, activity);
|
||||||
await this.InvokeWithActivityAsync(
|
await this.InvokeWithActivityAsync(
|
||||||
static async ((AgentBase Agent, RpcRequest Request, TaskCompletionSource<RpcResponse>) state) =>
|
static async ((AgentBase Agent, RpcRequest Request, TaskCompletionSource<RpcResponse>) state, CancellationToken ct) =>
|
||||||
{
|
{
|
||||||
var (self, request, completion) = state;
|
var (self, request, completion) = state;
|
||||||
|
|
||||||
@ -206,7 +206,7 @@ public abstract class AgentBase : IAgentBase, IHandle
|
|||||||
self._pendingRequests[request.RequestId] = completion;
|
self._pendingRequests[request.RequestId] = completion;
|
||||||
}
|
}
|
||||||
|
|
||||||
await state.Agent._context.SendRequestAsync(state.Agent, state.Request).ConfigureAwait(false);
|
await state.Agent._context.SendRequestAsync(state.Agent, state.Request, ct).ConfigureAwait(false);
|
||||||
|
|
||||||
await completion.Task.ConfigureAwait(false);
|
await completion.Task.ConfigureAwait(false);
|
||||||
},
|
},
|
||||||
@ -231,11 +231,11 @@ public abstract class AgentBase : IAgentBase, IHandle
|
|||||||
activity?.SetTag("peer.service", $"{item.Type}/{item.Source}");
|
activity?.SetTag("peer.service", $"{item.Type}/{item.Source}");
|
||||||
|
|
||||||
// TODO: fix activity
|
// TODO: fix activity
|
||||||
_context.Update(activity, item);
|
_context.Update(item, activity);
|
||||||
await this.InvokeWithActivityAsync(
|
await this.InvokeWithActivityAsync(
|
||||||
static async ((AgentBase Agent, CloudEvent Event) state) =>
|
static async ((AgentBase Agent, CloudEvent Event) state, CancellationToken ct) =>
|
||||||
{
|
{
|
||||||
await state.Agent._context.PublishEventAsync(state.Event).ConfigureAwait(false);
|
await state.Agent._context.PublishEventAsync(state.Event, ct).ConfigureAwait(false);
|
||||||
},
|
},
|
||||||
(this, item),
|
(this, item),
|
||||||
activity,
|
activity,
|
||||||
|
@ -5,15 +5,25 @@ using System.Diagnostics;
|
|||||||
|
|
||||||
namespace Microsoft.AutoGen.Agents;
|
namespace Microsoft.AutoGen.Agents;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Provides extension methods for the <see cref="AgentBase"/> class.
|
||||||
|
/// </summary>
|
||||||
public static class AgentBaseExtensions
|
public static class AgentBaseExtensions
|
||||||
{
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Extracts an <see cref="Activity"/> from the given agent and metadata.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="agent">The agent from which to extract the activity.</param>
|
||||||
|
/// <param name="activityName">The name of the activity.</param>
|
||||||
|
/// <param name="metadata">The metadata containing trace information.</param>
|
||||||
|
/// <returns>The extracted <see cref="Activity"/> or null if extraction fails.</returns>
|
||||||
public static Activity? ExtractActivity(this AgentBase agent, string activityName, IDictionary<string, string> metadata)
|
public static Activity? ExtractActivity(this AgentBase agent, string activityName, IDictionary<string, string> metadata)
|
||||||
{
|
{
|
||||||
Activity? activity;
|
Activity? activity;
|
||||||
(var traceParent, var traceState) = agent.Context.GetTraceIDandState(metadata);
|
var (traceParent, traceState) = agent.Context.GetTraceIdAndState(metadata);
|
||||||
if (!string.IsNullOrEmpty(traceParent))
|
if (!string.IsNullOrEmpty(traceParent))
|
||||||
{
|
{
|
||||||
if (ActivityContext.TryParse(traceParent, traceState, isRemote: true, out ActivityContext parentContext))
|
if (ActivityContext.TryParse(traceParent, traceState, isRemote: true, out var parentContext))
|
||||||
{
|
{
|
||||||
// traceParent is a W3CId
|
// traceParent is a W3CId
|
||||||
activity = AgentBase.s_source.CreateActivity(activityName, ActivityKind.Server, parentContext);
|
activity = AgentBase.s_source.CreateActivity(activityName, ActivityKind.Server, parentContext);
|
||||||
@ -33,15 +43,12 @@ public static class AgentBaseExtensions
|
|||||||
|
|
||||||
var baggage = agent.Context.ExtractMetadata(metadata);
|
var baggage = agent.Context.ExtractMetadata(metadata);
|
||||||
|
|
||||||
if (baggage is not null)
|
|
||||||
{
|
|
||||||
foreach (var baggageItem in baggage)
|
foreach (var baggageItem in baggage)
|
||||||
{
|
{
|
||||||
activity.AddBaggage(baggageItem.Key, baggageItem.Value);
|
activity.AddBaggage(baggageItem.Key, baggageItem.Value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
activity = AgentBase.s_source.CreateActivity(activityName, ActivityKind.Server);
|
activity = AgentBase.s_source.CreateActivity(activityName, ActivityKind.Server);
|
||||||
@ -49,7 +56,19 @@ public static class AgentBaseExtensions
|
|||||||
|
|
||||||
return activity;
|
return activity;
|
||||||
}
|
}
|
||||||
public static async Task InvokeWithActivityAsync<TState>(this AgentBase agent, Func<TState, Task> func, TState state, Activity? activity, string methodName, CancellationToken cancellationToken = default)
|
|
||||||
|
/// <summary>
|
||||||
|
/// Invokes a function asynchronously within the context of an <see cref="Activity"/>.
|
||||||
|
/// </summary>
|
||||||
|
/// <typeparam name="TState">The type of the state parameter.</typeparam>
|
||||||
|
/// <param name="agent">The agent invoking the function.</param>
|
||||||
|
/// <param name="func">The function to invoke.</param>
|
||||||
|
/// <param name="state">The state parameter to pass to the function.</param>
|
||||||
|
/// <param name="activity">The activity within which to invoke the function.</param>
|
||||||
|
/// <param name="methodName">The name of the method being invoked.</param>
|
||||||
|
/// <param name="cancellationToken">A token to monitor for cancellation requests.</param>
|
||||||
|
/// <returns>A task representing the asynchronous operation.</returns>
|
||||||
|
public static async Task InvokeWithActivityAsync<TState>(this AgentBase agent, Func<TState, CancellationToken, Task> func, TState state, Activity? activity, string methodName, CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
if (activity is not null && activity.StartTimeUtc == default)
|
if (activity is not null && activity.StartTimeUtc == default)
|
||||||
{
|
{
|
||||||
@ -63,7 +82,7 @@ public static class AgentBaseExtensions
|
|||||||
|
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
await func(state).ConfigureAwait(false);
|
await func(state, cancellationToken).ConfigureAwait(false);
|
||||||
if (activity is not null && activity.IsAllDataRequested)
|
if (activity is not null && activity.IsAllDataRequested)
|
||||||
{
|
{
|
||||||
activity.SetStatus(ActivityStatusCode.Ok);
|
activity.SetStatus(ActivityStatusCode.Ok);
|
||||||
|
@ -15,7 +15,7 @@ internal sealed class AgentRuntime(AgentId agentId, IAgentWorker worker, ILogger
|
|||||||
public ILogger<AgentBase> Logger { get; } = logger;
|
public ILogger<AgentBase> Logger { get; } = logger;
|
||||||
public IAgentBase? AgentInstance { get; set; }
|
public IAgentBase? AgentInstance { get; set; }
|
||||||
private DistributedContextPropagator DistributedContextPropagator { get; } = distributedContextPropagator;
|
private DistributedContextPropagator DistributedContextPropagator { get; } = distributedContextPropagator;
|
||||||
public (string?, string?) GetTraceIDandState(IDictionary<string, string> metadata)
|
public (string?, string?) GetTraceIdAndState(IDictionary<string, string> metadata)
|
||||||
{
|
{
|
||||||
DistributedContextPropagator.ExtractTraceIdAndState(metadata,
|
DistributedContextPropagator.ExtractTraceIdAndState(metadata,
|
||||||
static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
|
static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
|
||||||
@ -28,11 +28,11 @@ internal sealed class AgentRuntime(AgentId agentId, IAgentWorker worker, ILogger
|
|||||||
out var traceState);
|
out var traceState);
|
||||||
return (traceParent, traceState);
|
return (traceParent, traceState);
|
||||||
}
|
}
|
||||||
public void Update(Activity? activity, RpcRequest request)
|
public void Update(RpcRequest request, Activity? activity = null)
|
||||||
{
|
{
|
||||||
DistributedContextPropagator.Inject(activity, request.Metadata, static (carrier, key, value) => ((IDictionary<string, string>)carrier!)[key] = value);
|
DistributedContextPropagator.Inject(activity, request.Metadata, static (carrier, key, value) => ((IDictionary<string, string>)carrier!)[key] = value);
|
||||||
}
|
}
|
||||||
public void Update(Activity? activity, CloudEvent cloudEvent)
|
public void Update(CloudEvent cloudEvent, Activity? activity = null)
|
||||||
{
|
{
|
||||||
DistributedContextPropagator.Inject(activity, cloudEvent.Metadata, static (carrier, key, value) => ((IDictionary<string, string>)carrier!)[key] = value);
|
DistributedContextPropagator.Inject(activity, cloudEvent.Metadata, static (carrier, key, value) => ((IDictionary<string, string>)carrier!)[key] = value);
|
||||||
}
|
}
|
||||||
|
@ -5,16 +5,14 @@ using Google.Protobuf;
|
|||||||
using Microsoft.AutoGen.Abstractions;
|
using Microsoft.AutoGen.Abstractions;
|
||||||
using Microsoft.Extensions.AI;
|
using Microsoft.Extensions.AI;
|
||||||
namespace Microsoft.AutoGen.Agents;
|
namespace Microsoft.AutoGen.Agents;
|
||||||
public abstract class InferenceAgent<T> : AgentBase where T : IMessage, new()
|
public abstract class InferenceAgent<T>(
|
||||||
{
|
|
||||||
protected IChatClient ChatClient { get; }
|
|
||||||
public InferenceAgent(
|
|
||||||
IAgentRuntime context,
|
IAgentRuntime context,
|
||||||
EventTypes typeRegistry, IChatClient client
|
EventTypes typeRegistry,
|
||||||
) : base(context, typeRegistry)
|
IChatClient client)
|
||||||
{
|
: AgentBase(context, typeRegistry)
|
||||||
ChatClient = client;
|
where T : IMessage, new()
|
||||||
}
|
{
|
||||||
|
protected IChatClient ChatClient { get; } = client;
|
||||||
|
|
||||||
private Task<ChatCompletion> CompleteAsync(
|
private Task<ChatCompletion> CompleteAsync(
|
||||||
IList<ChatMessage> chatMessages,
|
IList<ChatMessage> chatMessages,
|
||||||
|
@ -7,7 +7,8 @@ namespace Microsoft.AutoGen.Agents;
|
|||||||
|
|
||||||
internal sealed class AgentStateGrain([PersistentState("state", "AgentStateStore")] IPersistentState<AgentState> state) : Grain, IAgentState
|
internal sealed class AgentStateGrain([PersistentState("state", "AgentStateStore")] IPersistentState<AgentState> state) : Grain, IAgentState
|
||||||
{
|
{
|
||||||
public async ValueTask<string> WriteStateAsync(AgentState newState, string eTag)
|
/// <inheritdoc />
|
||||||
|
public async ValueTask<string> WriteStateAsync(AgentState newState, string eTag, CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
// etags for optimistic concurrency control
|
// etags for optimistic concurrency control
|
||||||
// if the Etag is null, its a new state
|
// if the Etag is null, its a new state
|
||||||
@ -27,7 +28,8 @@ internal sealed class AgentStateGrain([PersistentState("state", "AgentStateStore
|
|||||||
return state.Etag;
|
return state.Etag;
|
||||||
}
|
}
|
||||||
|
|
||||||
public ValueTask<AgentState> ReadStateAsync()
|
/// <inheritdoc />
|
||||||
|
public ValueTask<AgentState> ReadStateAsync(CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
return ValueTask.FromResult(state.State);
|
return ValueTask.FromResult(state.State);
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user