mirror of
https://github.com/microsoft/autogen.git
synced 2025-07-08 09:31:51 +00:00

<!-- Thank you for your contribution! Please review https://microsoft.github.io/autogen/docs/Contribute before opening a pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR. If you don't have access to it, we will shortly find a reviewer and assign them to your PR. -->

## Why are these changes needed?

Moves to the stable 9.5.0 release instead of a preview (for the core Microsoft.Extensions.AI.Abstractions and Microsoft.Extensions.AI packages).

## Related issue number

<!-- For example: "Closes #1234" -->

## Checks

- [x] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes introduced in this PR.
- [x] I've made sure all auto checks have passed.
208 lines
9.2 KiB
C#
208 lines
9.2 KiB
C#
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// FunctionCallMiddleware.cs
|
|
|
|
using System;
|
|
using System.Collections.Generic;
|
|
using System.Linq;
|
|
using System.Runtime.CompilerServices;
|
|
using System.Text.Json;
|
|
using System.Threading;
|
|
using System.Threading.Tasks;
|
|
using Microsoft.Extensions.AI;
|
|
|
|
namespace AutoGen.Core;
|
|
|
|
/// <summary>
/// Middleware that processes function (tool) call messages, both those sent to an agent and those replied by an agent.
/// <para>If the last message in the context is a <see cref="ToolCallMessage"/>, its tool calls are invoked through this
/// middleware's function map and a <see cref="ToolCallResultMessage"/> is returned. In this situation the inner agent
/// is short-circuited and is not invoked. A tool call whose function is not in the function map produces an error
/// text as its result; if no function map was supplied at all, an <see cref="InvalidOperationException"/> is thrown.</para>
/// <para>Otherwise, the messages are sent to the inner agent with this middleware's function contracts added to the
/// reply options.</para>
/// <para>If the reply from the inner agent is a <see cref="ToolCallMessage"/> and at least one of its tool calls is
/// available in the function map, the available tools are invoked and a <see cref="ToolCallAggregateMessage"/> is returned.</para>
/// <para>If the reply is a <see cref="ToolCallMessage"/> but none of its tool calls is available in the function map,
/// or the reply is not a <see cref="ToolCallMessage"/>, the original reply from the inner agent is returned.</para>
/// <para>
/// When used as a streaming middleware, <see cref="ToolCallMessageUpdate"/> chunks are merged into a single
/// <see cref="ToolCallMessage"/> (only when a function map is present), a complete <see cref="ToolCallMessage"/>
/// replaces whatever has been merged so far, and every other message is passed through unchanged. After the stream
/// ends, the merged tool call message (if any) is processed exactly like a non-streaming reply.
/// </para>
/// </summary>
public class FunctionCallMiddleware : IStreamingMiddleware
{
    private readonly IEnumerable<FunctionContract>? functions;
    private readonly IDictionary<string, Func<string, Task<string>>>? functionMap;

    /// <summary>
    /// Create a new instance of <see cref="FunctionCallMiddleware"/> from function contracts and a function map.
    /// </summary>
    /// <param name="functions">function contracts that are appended to the reply options sent to the inner agent.</param>
    /// <param name="functionMap">map from function name to a delegate that takes the raw argument string and returns the result string.</param>
    /// <param name="name">optional middleware name. If not provided, the class name <see cref="FunctionCallMiddleware"/> will be used.</param>
    public FunctionCallMiddleware(
        IEnumerable<FunctionContract>? functions = null,
        IDictionary<string, Func<string, Task<string>>>? functionMap = null,
        string? name = null)
    {
        this.Name = name ?? nameof(FunctionCallMiddleware);
        this.functions = functions;
        this.functionMap = functionMap;
    }

    /// <summary>
    /// Create a new instance of <see cref="FunctionCallMiddleware"/> with a list of <see cref="AIFunction"/>.
    /// Each function contributes both its contract and an entry, keyed by its name, in the function map.
    /// </summary>
    /// <param name="functions">function list</param>
    /// <param name="name">optional middleware name. If not provided, the class name <see cref="FunctionCallMiddleware"/> will be used.</param>
    /// <exception cref="ArgumentNullException"><paramref name="functions"/> is null.</exception>
    public FunctionCallMiddleware(IEnumerable<AIFunction> functions, string? name = null)
    {
        if (functions is null)
        {
            throw new ArgumentNullException(nameof(functions));
        }

        // Materialize once: the sequence is consumed twice below and may be lazily evaluated.
        var functionList = functions.ToArray();

        this.Name = name ?? nameof(FunctionCallMiddleware);
        this.functions = functionList.Select(f => (FunctionContract)f).ToArray();
        this.functionMap = functionList.ToDictionary(f => f.Name, f => AIToolInvokeWrapper(f.InvokeAsync));
    }

    /// <summary>
    /// The middleware name: the name passed to the constructor, or the class name when none was given.
    /// </summary>
    public string? Name { get; }

    /// <summary>
    /// Process a non-streaming request. When the last message is a <see cref="ToolCallMessage"/> the tool calls are
    /// executed and the inner agent is skipped; otherwise the inner agent is invoked and a tool call reply, if any,
    /// is executed against the function map.
    /// </summary>
    /// <param name="context">the messages and reply options of the current request.</param>
    /// <param name="agent">the inner agent.</param>
    /// <param name="cancellationToken">token forwarded to the inner agent.</param>
    /// <returns>the tool call result, the aggregated tool call message, or the inner agent's original reply.</returns>
    public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
    {
        // LastOrDefault: an empty conversation is simply forwarded to the inner agent instead of throwing.
        var lastMessage = context.Messages.LastOrDefault();
        if (lastMessage is ToolCallMessage toolCallMessage)
        {
            return await this.InvokeToolCallMessagesBeforeInvokingAgentAsync(toolCallMessage, agent);
        }

        var options = this.BuildReplyOptions(context);
        var reply = await agent.GenerateReplyAsync(context.Messages, options, cancellationToken);

        // if the reply is a tool call message, try to invoke the requested functions from the function map.
        if (reply is ToolCallMessage toolCallMsg)
        {
            return await this.InvokeToolCallMessagesAfterInvokingAgentAsync(toolCallMsg, agent);
        }

        // for all other messages, just return the reply from the agent.
        return reply;
    }

    /// <summary>
    /// Process a streaming request. Mirrors the non-streaming overload: a trailing <see cref="ToolCallMessage"/>
    /// short-circuits the inner agent; otherwise streamed tool call updates are merged and executed once the
    /// stream completes, while all other messages are yielded as they arrive.
    /// </summary>
    /// <param name="context">the messages and reply options of the current request.</param>
    /// <param name="agent">the inner streaming agent.</param>
    /// <param name="cancellationToken">token forwarded to the inner agent.</param>
    public async IAsyncEnumerable<IMessage> InvokeAsync(
        MiddlewareContext context,
        IStreamingAgent agent,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // LastOrDefault: an empty conversation is simply forwarded to the inner agent instead of throwing.
        var lastMessage = context.Messages.LastOrDefault();
        if (lastMessage is ToolCallMessage toolCallMessage)
        {
            yield return await this.InvokeToolCallMessagesBeforeInvokingAgentAsync(toolCallMessage, agent);

            // Short-circuit: the inner agent must not be invoked once the tool calls have been answered.
            yield break;
        }

        var options = this.BuildReplyOptions(context);

        ToolCallMessage? mergedToolCall = null;
        await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken))
        {
            if (message is ToolCallMessageUpdate toolCallUpdate && this.functionMap is not null)
            {
                // Accumulate streamed chunks into one tool call message.
                if (mergedToolCall is null)
                {
                    mergedToolCall = new ToolCallMessage(toolCallUpdate);
                }
                else
                {
                    mergedToolCall.Update(toolCallUpdate);
                }
            }
            else if (message is ToolCallMessage completeToolCall)
            {
                // A complete tool call message supersedes anything merged so far.
                mergedToolCall = completeToolCall;
            }
            else
            {
                yield return message;
            }
        }

        if (mergedToolCall is not null)
        {
            yield return await this.InvokeToolCallMessagesAfterInvokingAgentAsync(mergedToolCall, agent);
        }
    }

    /// <summary>
    /// Copy the reply options from <paramref name="context"/> (or create new ones when absent) and place this
    /// middleware's function contracts in front of any functions already present in the options.
    /// </summary>
    private GenerateReplyOptions BuildReplyOptions(MiddlewareContext context)
    {
        var options = new GenerateReplyOptions(context.Options ?? new GenerateReplyOptions());
        var combinedFunctions = this.functions?.Concat(options.Functions ?? []) ?? options.Functions;
        options.Functions = combinedFunctions?.ToArray();

        return options;
    }

    /// <summary>
    /// Execute every tool call of an incoming <see cref="ToolCallMessage"/>. A call whose function is missing from the
    /// function map is answered with an error text listing the available functions.
    /// </summary>
    /// <exception cref="InvalidOperationException">no function map was supplied to this middleware.</exception>
    private async Task<ToolCallResultMessage> InvokeToolCallMessagesBeforeInvokingAgentAsync(ToolCallMessage toolCallMessage, IAgent agent)
    {
        var toolCallResult = new List<ToolCall>();
        foreach (var toolCall in toolCallMessage.ToolCalls)
        {
            var functionName = toolCall.FunctionName;
            var functionArguments = toolCall.FunctionArguments;
            if (this.functionMap?.TryGetValue(functionName, out var func) is true)
            {
                var result = await func(functionArguments);
                toolCallResult.Add(new ToolCall(functionName, functionArguments, result) { ToolCallId = toolCall.ToolCallId });
            }
            else if (this.functionMap is not null)
            {
                var errorMessage = $"Function {functionName} is not available. Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}";

                toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage) { ToolCallId = toolCall.ToolCallId });
            }
            else
            {
                throw new InvalidOperationException("FunctionMap is not available");
            }
        }

        return new ToolCallResultMessage(toolCallResult, from: agent.Name);
    }

    /// <summary>
    /// Execute the tool calls of an agent reply that are present in the function map; calls for unknown functions are
    /// skipped. Returns a <see cref="ToolCallAggregateMessage"/> when at least one call was executed, otherwise the
    /// original <paramref name="toolCallMsg"/>.
    /// </summary>
    private async Task<IMessage> InvokeToolCallMessagesAfterInvokingAgentAsync(ToolCallMessage toolCallMsg, IAgent agent)
    {
        var toolCallResult = new List<ToolCall>();
        foreach (var toolCall in toolCallMsg.ToolCalls)
        {
            var fName = toolCall.FunctionName;
            var fArgs = toolCall.FunctionArguments;
            if (this.functionMap?.TryGetValue(fName, out var func) is true)
            {
                var result = await func(fArgs);
                toolCallResult.Add(new ToolCall(fName, fArgs, result) { ToolCallId = toolCall.ToolCallId });
            }
        }

        if (toolCallResult.Count == 0)
        {
            // Nothing could be executed here: hand the tool call back to the caller untouched.
            return toolCallMsg;
        }

        var toolCallResultMessage = new ToolCallResultMessage(toolCallResult, from: agent.Name);
        return new ToolCallAggregateMessage(toolCallMsg, toolCallResultMessage, from: agent.Name);
    }

    /// <summary>
    /// Adapt an <see cref="AIFunction"/> style invoker to the string-in / string-out delegate stored in the function map.
    /// The argument string is parsed as a JSON object; a string result is returned as is, a <see cref="JsonElement"/>
    /// via its <c>ToString()</c>, and anything else is serialized to JSON.
    /// </summary>
    private static Func<string, Task<string>> AIToolInvokeWrapper(Func<AIFunctionArguments?, CancellationToken, ValueTask<object?>> lambda)
    {
        return async (string args) =>
        {
            // An empty payload is not valid JSON and would make Deserialize throw; treat it as "no arguments".
            var arguments = string.IsNullOrWhiteSpace(args)
                ? null
                : JsonSerializer.Deserialize<Dictionary<string, object?>>(args);

            // The function map delegate carries no cancellation token, so none can be forwarded here.
            var result = await lambda(new(arguments), CancellationToken.None);

            return result switch
            {
                string s => s,
                JsonElement e => e.ToString(),
                _ => JsonSerializer.Serialize(result),
            };
        };
    }
}
|