// Copyright (c) Microsoft Corporation. All rights reserved.
// FunctionCallMiddleware.cs

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

namespace AutoGen.Core;

/// <summary>
/// The middleware that processes function (tool) call messages, both those sent to an agent
/// and those replied by an agent.
/// <para>
/// If the last message is a <see cref="ToolCallMessage"/>, its tool calls are invoked through this
/// middleware's function map and a <see cref="ToolCallResultMessage"/> is returned. In this situation
/// the inner agent is short-circuited and is not invoked. A tool call whose function is missing from
/// the function map produces an error-text result instead of being invoked.
/// </para>
/// <para>
/// Otherwise, the messages are sent to the inner agent. If the reply from the inner agent is a
/// <see cref="ToolCallMessage"/> and at least one of its tool calls is available in this middleware's
/// function map, the available tools are invoked and a <see cref="ToolCallAggregateMessage"/> is returned.
/// If none of the tool calls is available in the function map, or the reply is not a
/// <see cref="ToolCallMessage"/>, the original reply from the inner agent is returned.
/// </para>
/// <para>
/// When used as a streaming middleware, <see cref="ToolCallMessageUpdate"/> chunks from the inner agent
/// are merged into a single <see cref="ToolCallMessage"/> (only when a function map was supplied) and
/// that merged message is processed once the stream ends. Every other chunk is forwarded unchanged.
/// </para>
/// </summary>
public class FunctionCallMiddleware : IStreamingMiddleware
{
    // NOTE(review): the element type FunctionContract is inferred from how these definitions are
    // concatenated with GenerateReplyOptions.Functions; confirm against that property's declaration.
    private readonly IEnumerable<FunctionContract>? functions;

    // Maps a function name to its implementation: (arguments) => result.
    private readonly IDictionary<string, Func<string, Task<string>>>? functionMap;

    /// <summary>
    /// Creates a new <see cref="FunctionCallMiddleware"/>.
    /// </summary>
    /// <param name="functions">Function definitions that are added to the options passed to the inner agent.</param>
    /// <param name="functionMap">Function implementations, keyed by function name, used to invoke tool calls.</param>
    /// <param name="name">The middleware name; defaults to <c>nameof(FunctionCallMiddleware)</c> when null.</param>
    public FunctionCallMiddleware(
        IEnumerable<FunctionContract>? functions = null,
        IDictionary<string, Func<string, Task<string>>>? functionMap = null,
        string? name = null)
    {
        this.Name = name ?? nameof(FunctionCallMiddleware);
        this.functions = functions;
        this.functionMap = functionMap;
    }

    /// <summary>
    /// The name of this middleware.
    /// </summary>
    public string? Name { get; }

    /// <summary>
    /// Handles tool calls around a non-streaming call to <paramref name="agent"/>.
    /// </summary>
    /// <param name="context">The messages and options of the current call.</param>
    /// <param name="agent">The inner agent.</param>
    /// <param name="cancellationToken">Forwarded to the inner agent.</param>
    /// <returns>
    /// The tool call results when the last message is a <see cref="ToolCallMessage"/>; otherwise the inner
    /// agent's reply, wrapped in a <see cref="ToolCallAggregateMessage"/> when the reply was a
    /// <see cref="ToolCallMessage"/> and at least one of its tools was invoked.
    /// </returns>
    /// <exception cref="InvalidOperationException">
    /// The last message is a <see cref="ToolCallMessage"/> containing tool calls but no function map was supplied.
    /// </exception>
    public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
    {
        var lastMessage = context.Messages.Last();
        if (lastMessage is ToolCallMessage toolCallMessage)
        {
            // Short-circuit: answer the pending tool calls without invoking the inner agent.
            return await this.InvokeToolCallMessagesBeforeInvokingAgentAsync(toolCallMessage, agent);
        }

        var options = this.CreateOptionsWithCombinedFunctions(context);
        var reply = await agent.GenerateReplyAsync(context.Messages, options, cancellationToken);

        // If the reply is a tool call message, invoke whichever of its functions are available in the function map.
        if (reply is ToolCallMessage toolCallReply)
        {
            return await this.InvokeToolCallMessagesAfterInvokingAgentAsync(toolCallReply, agent);
        }

        // For all other messages, just return the reply from the agent.
        return reply;
    }

    /// <summary>
    /// Handles tool calls around a streaming call to <paramref name="agent"/>.
    /// </summary>
    /// <param name="context">The messages and options of the current call.</param>
    /// <param name="agent">The inner streaming agent.</param>
    /// <param name="cancellationToken">Forwarded to the inner agent.</param>
    /// <returns>
    /// A single tool call result message when the last message is a <see cref="ToolCallMessage"/>; otherwise
    /// the inner agent's chunks, where <see cref="ToolCallMessageUpdate"/> chunks are withheld (when a function
    /// map was supplied), merged, and emitted as one final message after the stream ends.
    /// </returns>
    /// <exception cref="InvalidOperationException">
    /// The last message is a <see cref="ToolCallMessage"/> containing tool calls but no function map was supplied.
    /// </exception>
    public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
        MiddlewareContext context,
        IStreamingAgent agent,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        var lastMessage = context.Messages.Last();
        if (lastMessage is ToolCallMessage toolCallMessage)
        {
            yield return await this.InvokeToolCallMessagesBeforeInvokingAgentAsync(toolCallMessage, agent);

            // Short-circuit, matching the non-streaming overload: the inner agent must not be invoked
            // once the pending tool calls have been answered.
            yield break;
        }

        var options = this.CreateOptionsWithCombinedFunctions(context);

        // Accumulates the tool call that the inner agent streams as incremental updates.
        ToolCallMessage? mergedToolCall = null;
        await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken))
        {
            if (message is ToolCallMessageUpdate toolCallUpdate && this.functionMap != null)
            {
                if (mergedToolCall is null)
                {
                    mergedToolCall = new ToolCallMessage(toolCallUpdate);
                }
                else
                {
                    mergedToolCall.Update(toolCallUpdate);
                }
            }
            else
            {
                // Not a tool call update (or there is no function map to invoke with): forward as is.
                yield return message;
            }
        }

        if (mergedToolCall is not null)
        {
            yield return await this.InvokeToolCallMessagesAfterInvokingAgentAsync(mergedToolCall, agent);
        }
    }

    /// <summary>
    /// Builds the options passed to the inner agent: a new <see cref="GenerateReplyOptions"/> constructed from
    /// the context's options (or from a default instance when the context has none), whose functions are this
    /// middleware's functions followed by the functions already present in the options.
    /// </summary>
    private GenerateReplyOptions CreateOptionsWithCombinedFunctions(MiddlewareContext context)
    {
        var options = new GenerateReplyOptions(context.Options ?? new GenerateReplyOptions());
        var combinedFunctions = this.functions?.Concat(options.Functions ?? []) ?? options.Functions;
        options.Functions = combinedFunctions?.ToArray();

        return options;
    }

    /// <summary>
    /// Invokes every tool call of <paramref name="toolCallMessage"/> and collects the results. A tool call
    /// whose function is not in the function map gets an error text as its result.
    /// </summary>
    /// <exception cref="InvalidOperationException">There is a tool call to process but no function map was supplied.</exception>
    private async Task<IMessage> InvokeToolCallMessagesBeforeInvokingAgentAsync(ToolCallMessage toolCallMessage, IAgent agent)
    {
        var toolCallResult = new List<ToolCall>();
        var toolCalls = toolCallMessage.ToolCalls;
        foreach (var toolCall in toolCalls)
        {
            var functionName = toolCall.FunctionName;
            var functionArguments = toolCall.FunctionArguments;
            if (this.functionMap?.TryGetValue(functionName, out var func) is true)
            {
                var result = await func(functionArguments);
                toolCallResult.Add(new ToolCall(functionName, functionArguments, result) { ToolCallId = toolCall.ToolCallId });
            }
            else if (this.functionMap is not null)
            {
                // Report the unknown function as the tool result instead of throwing.
                var errorMessage = $"Function {functionName} is not available. Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}";
                toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage) { ToolCallId = toolCall.ToolCallId });
            }
            else
            {
                throw new InvalidOperationException("FunctionMap is not available");
            }
        }

        return new ToolCallResultMessage(toolCallResult, from: agent.Name);
    }

    /// <summary>
    /// Invokes the tool calls of the inner agent's reply that are available in the function map; tool calls
    /// that are not available are skipped. Returns a <see cref="ToolCallAggregateMessage"/> combining the
    /// reply and the results when at least one tool was invoked, otherwise the reply itself.
    /// </summary>
    private async Task<IMessage> InvokeToolCallMessagesAfterInvokingAgentAsync(ToolCallMessage toolCallMsg, IAgent agent)
    {
        var toolCallsReply = toolCallMsg.ToolCalls;
        var toolCallResult = new List<ToolCall>();
        foreach (var toolCall in toolCallsReply)
        {
            var fName = toolCall.FunctionName;
            var fArgs = toolCall.FunctionArguments;
            if (this.functionMap?.TryGetValue(fName, out var func) is true)
            {
                var result = await func(fArgs);
                toolCallResult.Add(new ToolCall(fName, fArgs, result) { ToolCallId = toolCall.ToolCallId });
            }
        }

        if (toolCallResult.Count > 0)
        {
            var toolCallResultMessage = new ToolCallResultMessage(toolCallResult, from: agent.Name);
            return new ToolCallAggregateMessage(toolCallMsg, toolCallResultMessage, from: agent.Name);
        }

        return toolCallMsg;
    }
}