2024-10-02 11:42:27 -07:00
// Copyright (c) Microsoft Corporation. All rights reserved.
2024-04-26 09:21:46 -07:00
// FunctionCallMiddleware.cs
using System ;
using System.Collections.Generic ;
using System.Linq ;
using System.Runtime.CompilerServices ;
2024-11-03 09:18:32 -08:00
using System.Text.Json ;
2024-04-26 09:21:46 -07:00
using System.Threading ;
using System.Threading.Tasks ;
2024-11-03 09:18:32 -08:00
using Microsoft.Extensions.AI ;
2024-04-26 09:21:46 -07:00
namespace AutoGen.Core ;
/// <summary>
/// A middleware that handles tool-call (function-call) messages, both those sent to an agent and those replied by an agent.
/// <para>If the last message in the context is a <see cref="ToolCallMessage"/>, its tool calls are invoked through this
/// middleware's function map and a <see cref="ToolCallResultMessage"/> is returned. A tool call whose function is missing
/// from the function map produces an error text as its result. In this situation the inner agent is short-circuited and is not invoked.</para>
/// <para>Otherwise, the messages are sent to the inner agent, with this middleware's function contracts added to the reply options.</para>
/// <para>If the reply from the inner agent is a <see cref="ToolCallMessage"/> and at least one of its tool calls is available
/// in this middleware's function map, the available tool calls are invoked and a <see cref="ToolCallAggregateMessage"/> is returned.</para>
/// <para>If the reply from the inner agent is a <see cref="ToolCallMessage"/> but none of its tool calls is available in this
/// middleware's function map, or the reply is not a <see cref="ToolCallMessage"/>, the original reply from the inner agent is returned.</para>
/// <para>
/// When used as a streaming middleware, <see cref="ToolCallMessageUpdate"/> chunks are merged into a single
/// <see cref="ToolCallMessage"/> (only when a function map is configured), a complete <see cref="ToolCallMessage"/> replaces
/// whatever has been merged so far, and every other message is forwarded to the caller as it arrives. After the stream ends,
/// the merged tool-call message (if any) is processed the same way as a non-streaming reply.
/// </para>
/// </summary>
public class FunctionCallMiddleware : IStreamingMiddleware
{
    private readonly IEnumerable<FunctionContract>? functions;
    private readonly IDictionary<string, Func<string, Task<string>>>? functionMap;

    /// <summary>
    /// Create a new instance of <see cref="FunctionCallMiddleware"/>.
    /// </summary>
    /// <param name="functions">function contracts that are added to the reply options passed to the inner agent.</param>
    /// <param name="functionMap">map from function name to the delegate that executes it; the delegate receives the raw argument string of the tool call.</param>
    /// <param name="name">optional middleware name. If not provided, the class name <see cref="FunctionCallMiddleware"/> will be used.</param>
    public FunctionCallMiddleware(
        IEnumerable<FunctionContract>? functions = null,
        IDictionary<string, Func<string, Task<string>>>? functionMap = null,
        string? name = null)
    {
        this.Name = name ?? nameof(FunctionCallMiddleware);
        this.functions = functions;
        this.functionMap = functionMap;
    }

    /// <summary>
    /// Create a new instance of <see cref="FunctionCallMiddleware"/> with a list of <see cref="AIFunction"/>.
    /// </summary>
    /// <param name="functions">function list</param>
    /// <param name="name">optional middleware name. If not provided, the class name <see cref="FunctionCallMiddleware"/> will be used.</param>
    /// <exception cref="ArgumentNullException"><paramref name="functions"/> is null.</exception>
    public FunctionCallMiddleware(IEnumerable<AIFunction> functions, string? name = null)
    {
        if (functions is null)
        {
            throw new ArgumentNullException(nameof(functions));
        }

        this.Name = name ?? nameof(FunctionCallMiddleware);

        // Materialize once: the sequence is consumed twice below and may be lazily evaluated.
        var functionList = functions.ToArray();
        this.functions = functionList.Select(f => (FunctionContract)f).ToArray();
        this.functionMap = functionList.ToDictionary(f => f.Name, f => this.AIToolInvokeWrapper(f.InvokeAsync));
    }

    /// <summary>
    /// The name of this middleware.
    /// </summary>
    public string? Name { get; }

    /// <summary>
    /// Process the context in non-streaming mode.
    /// </summary>
    /// <param name="context">the messages and reply options of the current call.</param>
    /// <param name="agent">the inner agent.</param>
    /// <param name="cancellationToken">token forwarded to the inner agent.</param>
    /// <returns>
    /// The tool-call results when the last message is a <see cref="ToolCallMessage"/>; otherwise the inner agent's reply,
    /// wrapped in a <see cref="ToolCallAggregateMessage"/> when it contains tool calls this middleware could invoke.
    /// </returns>
    public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
    {
        // LastOrDefault rather than Last: an empty conversation carries no pending tool call
        // and should simply be forwarded to the agent instead of throwing.
        if (context.Messages.LastOrDefault() is ToolCallMessage toolCallMessage)
        {
            return await this.InvokeToolCallMessagesBeforeInvokingAgentAsync(toolCallMessage, agent);
        }

        var options = this.BuildReplyOptions(context);
        var reply = await agent.GenerateReplyAsync(context.Messages, options, cancellationToken);

        // if the reply is a tool-call message, invoke the functions available in the function map
        // and return the aggregated result instead of the bare tool-call message.
        if (reply is ToolCallMessage toolCallMsg)
        {
            return await this.InvokeToolCallMessagesAfterInvokingAgentAsync(toolCallMsg, agent);
        }

        // for all other messages, just return the reply from the agent.
        return reply;
    }

    /// <summary>
    /// Process the context in streaming mode.
    /// </summary>
    /// <param name="context">the messages and reply options of the current call.</param>
    /// <param name="agent">the inner streaming agent.</param>
    /// <param name="cancellationToken">token forwarded to the inner agent.</param>
    /// <returns>
    /// The tool-call results only, when the last message is a <see cref="ToolCallMessage"/>; otherwise the non-tool-call
    /// messages streamed by the inner agent, followed by the processed tool-call message if one was received.
    /// </returns>
    public async IAsyncEnumerable<IMessage> InvokeAsync(
        MiddlewareContext context,
        IStreamingAgent agent,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (context.Messages.LastOrDefault() is ToolCallMessage toolCallMessage)
        {
            yield return await this.InvokeToolCallMessagesBeforeInvokingAgentAsync(toolCallMessage, agent);

            // Short-circuit: the pending tool calls have been answered, so the inner agent must not run.
            // This mirrors the early return of the non-streaming overload.
            yield break;
        }

        var options = this.BuildReplyOptions(context);

        ToolCallMessage? mergedToolCallMessage = null;
        await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken))
        {
            if (message is ToolCallMessageUpdate toolCallMessageUpdate && this.functionMap is not null)
            {
                // Accumulate the streamed chunks into a single tool-call message.
                if (mergedToolCallMessage is null)
                {
                    mergedToolCallMessage = new ToolCallMessage(toolCallMessageUpdate);
                }
                else
                {
                    mergedToolCallMessage.Update(toolCallMessageUpdate);
                }
            }
            else if (message is ToolCallMessage completeToolCallMessage)
            {
                // A complete tool-call message supersedes anything merged so far.
                mergedToolCallMessage = completeToolCallMessage;
            }
            else
            {
                yield return message;
            }
        }

        if (mergedToolCallMessage is not null)
        {
            yield return await this.InvokeToolCallMessagesAfterInvokingAgentAsync(mergedToolCallMessage, agent);
        }
    }

    /// <summary>
    /// Copy the reply options of <paramref name="context"/> and put this middleware's function contracts
    /// in front of the functions already present in those options.
    /// </summary>
    private GenerateReplyOptions BuildReplyOptions(MiddlewareContext context)
    {
        var options = new GenerateReplyOptions(context.Options ?? new GenerateReplyOptions());
        var combinedFunctions = this.functions?.Concat(options.Functions ?? []) ?? options.Functions;
        options.Functions = combinedFunctions?.ToArray();

        return options;
    }

    /// <summary>
    /// Invoke every tool call of a <see cref="ToolCallMessage"/> that was sent to the agent.
    /// A tool call whose function is not in the function map gets an error text as its result.
    /// </summary>
    /// <exception cref="InvalidOperationException">a tool call has to be processed but no function map is configured.</exception>
    private async Task<ToolCallResultMessage> InvokeToolCallMessagesBeforeInvokingAgentAsync(ToolCallMessage toolCallMessage, IAgent agent)
    {
        var toolCallResult = new List<ToolCall>();
        foreach (var toolCall in toolCallMessage.ToolCalls)
        {
            // Checked per tool call (not up front) so a message without tool calls never throws.
            if (this.functionMap is null)
            {
                throw new InvalidOperationException("FunctionMap is not available");
            }

            var functionName = toolCall.FunctionName;
            var functionArguments = toolCall.FunctionArguments;

            string result;
            if (this.functionMap.TryGetValue(functionName, out var func))
            {
                result = await func(functionArguments);
            }
            else
            {
                // Report the problem as the tool result rather than failing the whole message.
                var availableFunctions = string.Join(" , ", this.functionMap.Select(f => f.Key));
                result = $"Function {functionName} is not available. Available functions are: {availableFunctions}";
            }

            toolCallResult.Add(new ToolCall(functionName, functionArguments, result) { ToolCallId = toolCall.ToolCallId });
        }

        return new ToolCallResultMessage(toolCallResult, from: agent.Name);
    }

    /// <summary>
    /// Invoke the tool calls of an agent reply that are available in the function map.
    /// </summary>
    /// <returns>
    /// A <see cref="ToolCallAggregateMessage"/> pairing the reply with its results when at least one tool call was invoked;
    /// otherwise the unchanged <paramref name="toolCallMsg"/>.
    /// </returns>
    private async Task<IMessage> InvokeToolCallMessagesAfterInvokingAgentAsync(ToolCallMessage toolCallMsg, IAgent agent)
    {
        var toolCallResult = new List<ToolCall>();
        foreach (var toolCall in toolCallMsg.ToolCalls)
        {
            var fName = toolCall.FunctionName;
            var fArgs = toolCall.FunctionArguments;

            // Tool calls that this middleware cannot serve are skipped silently.
            if (this.functionMap?.TryGetValue(fName, out var func) is true)
            {
                var result = await func(fArgs);
                toolCallResult.Add(new ToolCall(fName, fArgs, result) { ToolCallId = toolCall.ToolCallId });
            }
        }

        if (toolCallResult.Count == 0)
        {
            return toolCallMsg;
        }

        var toolCallResultMessage = new ToolCallResultMessage(toolCallResult, from: agent.Name);
        return new ToolCallAggregateMessage(toolCallMsg, toolCallResultMessage, from: agent.Name);
    }

    /// <summary>
    /// Adapt an <see cref="AIFunction"/> invocation delegate to the string-in / string-out shape used by the function map.
    /// The argument string is parsed as a JSON object; the result is returned as-is when it is a string,
    /// via <see cref="JsonElement.ToString"/> when it is a <see cref="JsonElement"/>, and JSON-serialized otherwise.
    /// </summary>
    private Func<string, Task<string>> AIToolInvokeWrapper(Func<AIFunctionArguments?, CancellationToken, ValueTask<object?>> lambda)
    {
        return async (string args) =>
        {
            // An empty argument string is not valid JSON and would make Deserialize throw;
            // treat it as a call without arguments.
            var arguments = string.IsNullOrWhiteSpace(args)
                ? new Dictionary<string, object?>()
                : JsonSerializer.Deserialize<Dictionary<string, object?>>(args);

            // The function-map delegate shape carries no cancellation token, so none can be forwarded here.
            var result = await lambda(new(arguments), CancellationToken.None);

            return result switch
            {
                string s => s,
                JsonElement e => e.ToString(),
                _ => JsonSerializer.Serialize(result),
            };
        };
    }
}