autogen/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs
Stephen Toub fffa61f639
Update to stable Microsoft.Extensions.AI release (#6552)
<!-- Thank you for your contribution! Please review
https://microsoft.github.io/autogen/docs/Contribute before opening a
pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed?

Moves to the stable 9.5.0 release instead of a preview (for the core
Microsoft.Extensions.AI.Abstractions and Microsoft.Extensions.AI
packages).

## Related issue number

<!-- For example: "Closes #1234" -->

## Checks

- [x] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [x] I've made sure all auto checks have passed.
2025-05-19 16:23:56 -04:00

208 lines
9.2 KiB
C#

// Copyright (c) Microsoft Corporation. All rights reserved.
// FunctionCallMiddleware.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
namespace AutoGen.Core;
/// <summary>
/// Middleware that processes tool (function) call messages, both those sent to an agent and those replied by an agent.
/// <para>If the last message is a <see cref="ToolCallMessage"/>, its tool calls are invoked through this middleware's
/// function map and a <see cref="ToolCallResultMessage"/> is returned. In this situation the inner agent is
/// short-circuited and is not invoked.</para>
/// <para>Otherwise, the messages are sent to the inner agent together with this middleware's function contracts.</para>
/// <para>If the reply from the inner agent is a <see cref="ToolCallMessage"/> and at least one of its tool calls is
/// available in this middleware's function map, those tool calls are invoked and a
/// <see cref="ToolCallAggregateMessage"/> is returned.</para>
/// <para>If the reply from the inner agent is a <see cref="ToolCallMessage"/> but none of its tool calls is available in
/// this middleware's function map, or the reply is not a <see cref="ToolCallMessage"/>, the original reply from the
/// inner agent is returned.</para>
/// <para>
/// When used as a streaming middleware, <see cref="ToolCallMessageUpdate"/> chunks are merged into a single
/// <see cref="ToolCallMessage"/> (only when a function map is configured) which is invoked once the stream completes.
/// A complete <see cref="ToolCallMessage"/> in the stream replaces whatever was merged so far. Every other message is
/// forwarded to the caller unchanged.
/// </para>
/// </summary>
public class FunctionCallMiddleware : IStreamingMiddleware
{
    private readonly IEnumerable<FunctionContract>? functions;
    private readonly IDictionary<string, Func<string, Task<string>>>? functionMap;

    /// <summary>
    /// Create a new instance of <see cref="FunctionCallMiddleware"/> from function contracts and a function map.
    /// </summary>
    /// <param name="functions">function contracts that are appended to the options passed to the inner agent.</param>
    /// <param name="functionMap">map from function name to the delegate that executes it; the delegate receives the raw argument string.</param>
    /// <param name="name">optional middleware name. If not provided, the class name <see cref="FunctionCallMiddleware"/> will be used.</param>
    public FunctionCallMiddleware(
        IEnumerable<FunctionContract>? functions = null,
        IDictionary<string, Func<string, Task<string>>>? functionMap = null,
        string? name = null)
    {
        this.Name = name ?? nameof(FunctionCallMiddleware);
        this.functions = functions;
        this.functionMap = functionMap;
    }

    /// <summary>
    /// Create a new instance of <see cref="FunctionCallMiddleware"/> with a list of <see cref="AIFunction"/>.
    /// </summary>
    /// <param name="functions">function list</param>
    /// <param name="name">optional middleware name. If not provided, the class name <see cref="FunctionCallMiddleware"/> will be used.</param>
    /// <exception cref="ArgumentNullException"><paramref name="functions"/> is null.</exception>
    public FunctionCallMiddleware(IEnumerable<AIFunction> functions, string? name = null)
    {
        if (functions is null)
        {
            throw new ArgumentNullException(nameof(functions));
        }

        // Materialize once: the sequence feeds both the contract list and the function map,
        // and a lazy source must not be enumerated twice.
        var functionList = functions.ToArray();

        this.Name = name ?? nameof(FunctionCallMiddleware);
        this.functions = functionList.Select(f => (FunctionContract)f).ToArray();
        this.functionMap = functionList.ToDictionary(f => f.Name, f => AIToolInvokeWrapper(f.InvokeAsync));
    }

    /// <summary>
    /// The name of this middleware.
    /// </summary>
    public string? Name { get; }

    /// <summary>
    /// Non-streaming entry point. Invokes pending tool calls directly when the last message is a
    /// <see cref="ToolCallMessage"/>; otherwise calls the inner agent and invokes any tool calls found in its reply.
    /// </summary>
    /// <param name="context">the messages and options of the current request.</param>
    /// <param name="agent">the inner agent.</param>
    /// <param name="cancellationToken">token forwarded to the inner agent.</param>
    /// <returns>
    /// A <see cref="ToolCallResultMessage"/> when the inner agent was short-circuited; otherwise the message produced by
    /// processing the agent's reply (the reply itself when it is not an invokable tool call).
    /// </returns>
    public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
    {
        // LastOrDefault rather than Last: an empty history is not an error here, it simply goes to the agent.
        if (context.Messages.LastOrDefault() is ToolCallMessage pendingToolCall)
        {
            return await this.InvokeToolCallMessagesBeforeInvokingAgentAsync(pendingToolCall, agent);
        }

        var options = this.CreateReplyOptions(context);
        var reply = await agent.GenerateReplyAsync(context.Messages, options, cancellationToken);

        // If the reply is a tool call, try to execute it locally instead of handing it back as-is.
        if (reply is ToolCallMessage toolCallReply)
        {
            return await this.InvokeToolCallMessagesAfterInvokingAgentAsync(toolCallReply, agent);
        }

        // For all other messages, just return the reply from the agent.
        return reply;
    }

    /// <summary>
    /// Streaming entry point. Invokes pending tool calls directly when the last message is a
    /// <see cref="ToolCallMessage"/>; otherwise streams the inner agent's reply, merging tool call updates and
    /// invoking the merged tool call after the stream ends.
    /// </summary>
    /// <param name="context">the messages and options of the current request.</param>
    /// <param name="agent">the inner streaming agent.</param>
    /// <param name="cancellationToken">token forwarded to the inner agent.</param>
    /// <returns>the forwarded stream messages, followed by the tool call outcome when a tool call was received.</returns>
    public async IAsyncEnumerable<IMessage> InvokeAsync(
        MiddlewareContext context,
        IStreamingAgent agent,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (context.Messages.LastOrDefault() is ToolCallMessage pendingToolCall)
        {
            yield return await this.InvokeToolCallMessagesBeforeInvokingAgentAsync(pendingToolCall, agent);

            // Short-circuit the inner agent, matching the non-streaming overload and the class contract.
            yield break;
        }

        var options = this.CreateReplyOptions(context);
        ToolCallMessage? mergedToolCall = null;

        await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken))
        {
            if (message is ToolCallMessageUpdate update && this.functionMap != null)
            {
                // Accumulate partial tool call chunks; they are swallowed here and surfaced as one result below.
                if (mergedToolCall is null)
                {
                    mergedToolCall = new ToolCallMessage(update);
                }
                else
                {
                    mergedToolCall.Update(update);
                }
            }
            else if (message is ToolCallMessage completeToolCall)
            {
                // A complete tool call message supersedes anything merged so far.
                mergedToolCall = completeToolCall;
            }
            else
            {
                yield return message;
            }
        }

        if (mergedToolCall is not null)
        {
            yield return await this.InvokeToolCallMessagesAfterInvokingAgentAsync(mergedToolCall, agent);
        }
    }

    /// <summary>
    /// Copy the request options and prepend this middleware's function contracts to the functions already present.
    /// </summary>
    private GenerateReplyOptions CreateReplyOptions(MiddlewareContext context)
    {
        // Work on a copy so the caller's options instance is never mutated.
        var options = new GenerateReplyOptions(context.Options ?? new GenerateReplyOptions());
        var combinedFunctions = this.functions?.Concat(options.Functions ?? []) ?? options.Functions;
        options.Functions = combinedFunctions?.ToArray();

        return options;
    }

    /// <summary>
    /// Invoke every tool call of a <see cref="ToolCallMessage"/> that arrived as the last input message.
    /// A tool call whose function is missing from the function map yields an error text as its result.
    /// </summary>
    /// <exception cref="InvalidOperationException">a tool call is present but no function map was configured.</exception>
    private async Task<ToolCallResultMessage> InvokeToolCallMessagesBeforeInvokingAgentAsync(ToolCallMessage toolCallMessage, IAgent agent)
    {
        var toolCallResult = new List<ToolCall>();
        foreach (var toolCall in toolCallMessage.ToolCalls)
        {
            var functionName = toolCall.FunctionName;
            var functionArguments = toolCall.FunctionArguments;

            if (this.functionMap is null)
            {
                throw new InvalidOperationException("FunctionMap is not available");
            }

            if (this.functionMap.TryGetValue(functionName, out var func))
            {
                var result = await func(functionArguments);
                toolCallResult.Add(new ToolCall(functionName, functionArguments, result) { ToolCallId = toolCall.ToolCallId });
            }
            else
            {
                // Report the unknown function as the call's result so the conversation can continue.
                var errorMessage = $"Function {functionName} is not available. Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}";
                toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage) { ToolCallId = toolCall.ToolCallId });
            }
        }

        return new ToolCallResultMessage(toolCallResult, from: agent.Name);
    }

    /// <summary>
    /// Invoke the tool calls of an agent reply that are present in the function map; other tool calls are skipped.
    /// </summary>
    /// <returns>
    /// A <see cref="ToolCallAggregateMessage"/> pairing the reply with its results when at least one tool call was
    /// invoked; otherwise the original <paramref name="toolCallMsg"/>.
    /// </returns>
    private async Task<IMessage> InvokeToolCallMessagesAfterInvokingAgentAsync(ToolCallMessage toolCallMsg, IAgent agent)
    {
        var toolCallResult = new List<ToolCall>();
        foreach (var toolCall in toolCallMsg.ToolCalls)
        {
            var fName = toolCall.FunctionName;
            var fArgs = toolCall.FunctionArguments;
            if (this.functionMap?.TryGetValue(fName, out var func) is true)
            {
                var result = await func(fArgs);
                toolCallResult.Add(new ToolCall(fName, fArgs, result) { ToolCallId = toolCall.ToolCallId });
            }
        }

        if (toolCallResult.Count == 0)
        {
            // Nothing could be executed locally: hand the tool call back untouched.
            return toolCallMsg;
        }

        var toolCallResultMessage = new ToolCallResultMessage(toolCallResult, from: agent.Name);
        return new ToolCallAggregateMessage(toolCallMsg, toolCallResultMessage, from: agent.Name);
    }

    /// <summary>
    /// Adapt an <see cref="AIFunction"/>-style invoker to the string-in/string-out delegate stored in the function map.
    /// The argument string is parsed as a JSON object; the result is returned as-is when it is a string, via
    /// <see cref="JsonElement.ToString"/> when it is a <see cref="JsonElement"/>, and JSON-serialized otherwise.
    /// </summary>
    private static Func<string, Task<string>> AIToolInvokeWrapper(Func<AIFunctionArguments?, CancellationToken, ValueTask<object?>> lambda)
    {
        return async (string args) =>
        {
            // An empty payload is not valid JSON and would make Deserialize throw; treat it as "no arguments".
            var arguments = string.IsNullOrWhiteSpace(args)
                ? null
                : JsonSerializer.Deserialize<Dictionary<string, object?>>(args);

            // The function map delegate carries no cancellation token, so none can be forwarded here.
            var result = await lambda(new AIFunctionArguments(arguments), CancellationToken.None);

            return result switch
            {
                string s => s,
                JsonElement e => e.ToString(),
                _ => JsonSerializer.Serialize(result),
            };
        };
    }
}