| 
									
										
										
										
											2024-10-02 11:42:27 -07:00
										 |  |  | // Copyright (c) Microsoft Corporation. All rights reserved. | 
					
						
							| 
									
										
										
										
											2024-04-26 09:21:46 -07:00
										 |  |  | // MiddlewareTest.cs | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | using System; | 
					
						
							|  |  |  | using System.Collections.Generic; | 
					
						
							|  |  |  | using System.Linq; | 
					
						
							|  |  |  | using System.Text.Json; | 
					
						
							|  |  |  | using System.Threading.Tasks; | 
					
						
							|  |  |  | using FluentAssertions; | 
					
						
							| 
									
										
										
										
											2024-11-03 09:18:32 -08:00
										 |  |  | using Microsoft.Extensions.AI; | 
					
						
							| 
									
										
										
										
											2024-04-26 09:21:46 -07:00
										 |  |  | using Xunit; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace AutoGen.Tests; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-03 11:49:08 -05:00
										 |  |  | [Trait("Category", "UnitV1")] | 
					
						
							| 
									
										
										
										
											2024-04-26 09:21:46 -07:00
										 |  |  | public partial class MiddlewareTest | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |     [Function] | 
					
						
							|  |  |  |     public async Task<string> Echo(string message) | 
					
						
							|  |  |  |     { | 
					
						
							|  |  |  |         return $"[FUNC] {message}"; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     [Fact] | 
					
						
							|  |  |  |     public async Task HumanInputMiddlewareTestAsync() | 
					
						
							|  |  |  |     { | 
					
						
							|  |  |  |         var agent = new EchoAgent("echo"); | 
					
						
							|  |  |  |         var neverAskUserInputMW = new HumanInputMiddleware(mode: HumanInputMode.NEVER); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         var neverInputAgent = agent.RegisterMiddleware(neverAskUserInputMW); | 
					
						
							|  |  |  |         var reply = await neverInputAgent.SendAsync("hello"); | 
					
						
							|  |  |  |         reply.GetContent()!.Should().Be("hello"); | 
					
						
							|  |  |  |         reply.From.Should().Be("echo"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         var alwaysAskUserInputMW = new HumanInputMiddleware( | 
					
						
							|  |  |  |             mode: HumanInputMode.ALWAYS, | 
					
						
							|  |  |  |             getInput: () => "input"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         var alwaysInputAgent = agent.RegisterMiddleware(alwaysAskUserInputMW); | 
					
						
							|  |  |  |         reply = await alwaysInputAgent.SendAsync("hello"); | 
					
						
							|  |  |  |         reply.GetContent()!.Should().Be("input"); | 
					
						
							|  |  |  |         reply.From.Should().Be("echo"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // test auto mode | 
					
						
							|  |  |  |         // if the reply from echo is not terminate message, return the original reply | 
					
						
							|  |  |  |         var autoAskUserInputMW = new HumanInputMiddleware( | 
					
						
							|  |  |  |             mode: HumanInputMode.AUTO, | 
					
						
							|  |  |  |             isTermination: async (messages, ct) => messages.Last()?.GetContent() == "terminate", | 
					
						
							|  |  |  |             getInput: () => "input", | 
					
						
							|  |  |  |             exitKeyword: "exit"); | 
					
						
							|  |  |  |         var autoInputAgent = agent.RegisterMiddleware(autoAskUserInputMW); | 
					
						
							|  |  |  |         reply = await autoInputAgent.SendAsync("hello"); | 
					
						
							|  |  |  |         reply.GetContent()!.Should().Be("hello"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // if the reply from echo is terminate message, asking user for input | 
					
						
							|  |  |  |         reply = await autoInputAgent.SendAsync("terminate"); | 
					
						
							|  |  |  |         reply.GetContent()!.Should().Be("input"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // if the reply from echo is terminate message, and user input is exit, return the TERMINATE message | 
					
						
							|  |  |  |         autoAskUserInputMW = new HumanInputMiddleware( | 
					
						
							|  |  |  |             mode: HumanInputMode.AUTO, | 
					
						
							|  |  |  |             isTermination: async (messages, ct) => messages.Last().GetContent() == "terminate", | 
					
						
							|  |  |  |             getInput: () => "exit", | 
					
						
							|  |  |  |             exitKeyword: "exit"); | 
					
						
							|  |  |  |         autoInputAgent = agent.RegisterMiddleware(autoAskUserInputMW); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         reply = await autoInputAgent.SendAsync("terminate"); | 
					
						
							|  |  |  |         reply.IsGroupChatTerminateMessage().Should().BeTrue(); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     [Fact] | 
					
						
							|  |  |  |     public async Task FunctionCallMiddlewareTestAsync() | 
					
						
							|  |  |  |     { | 
					
						
							|  |  |  |         var agent = new EchoAgent("echo"); | 
					
						
							| 
									
										
										
										
											2025-02-12 16:02:37 -08:00
										 |  |  |         var args = new AutoGen.Tests.MiddlewareTest.EchoSchema { message = "hello" }; // make the format check happy on linux | 
					
						
							| 
									
										
										
										
											2024-04-26 09:21:46 -07:00
										 |  |  |         var argsJson = JsonSerializer.Serialize(args) ?? throw new InvalidOperationException("Failed to serialize args"); | 
					
						
							| 
									
										
										
										
											2024-11-03 09:18:32 -08:00
										 |  |  |         var functionCall = new ToolCall("Echo", argsJson); | 
					
						
							| 
									
										
										
										
											2024-04-26 09:21:46 -07:00
										 |  |  |         var functionCallAgent = agent.RegisterMiddleware(async (messages, options, agent, ct) => | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             if (options?.Functions is null) | 
					
						
							|  |  |  |             { | 
					
						
							|  |  |  |                 return await agent.GenerateReplyAsync(messages, options, ct); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-27 14:37:47 -07:00
										 |  |  |             return new ToolCallMessage(functionCall.FunctionName, functionCall.FunctionArguments, from: agent.Name); | 
					
						
							| 
									
										
										
										
											2024-04-26 09:21:46 -07:00
										 |  |  |         }); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // test 1 | 
					
						
							|  |  |  |         // middleware should invoke function call if the message is a function call message | 
					
						
							|  |  |  |         var mw = new FunctionCallMiddleware( | 
					
						
							| 
									
										
										
										
											2024-11-03 09:18:32 -08:00
										 |  |  |             functionMap: new Dictionary<string, Func<string, Task<string>>> { { "Echo", EchoWrapper } }); | 
					
						
							| 
									
										
										
										
											2024-04-26 09:21:46 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |         var testAgent = agent.RegisterMiddleware(mw); | 
					
						
							| 
									
										
										
										
											2024-08-27 14:37:47 -07:00
										 |  |  |         var functionCallMessage = new ToolCallMessage(functionCall.FunctionName, functionCall.FunctionArguments, from: "user"); | 
					
						
							| 
									
										
										
										
											2024-04-26 09:21:46 -07:00
										 |  |  |         var reply = await testAgent.SendAsync(functionCallMessage); | 
					
						
							|  |  |  |         reply.Should().BeOfType<ToolCallResultMessage>(); | 
					
						
							|  |  |  |         reply.GetContent()!.Should().Be("[FUNC] hello"); | 
					
						
							|  |  |  |         reply.From.Should().Be("echo"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // test 2 | 
					
						
							| 
									
										
										
										
											2024-11-03 09:18:32 -08:00
										 |  |  |         // middleware should work with AIFunction from M.E.A.I | 
					
						
							|  |  |  |         var getWeatherTool = AIFunctionFactory.Create(this.Echo); | 
					
						
							|  |  |  |         mw = new FunctionCallMiddleware([getWeatherTool]); | 
					
						
							|  |  |  |         testAgent = agent.RegisterMiddleware(mw); | 
					
						
							|  |  |  |         reply = await testAgent.SendAsync(functionCallMessage); | 
					
						
							|  |  |  |         reply.GetContent()!.Should().Be("[FUNC] hello"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // test 3 | 
					
						
							| 
									
										
										
										
											2024-04-26 09:21:46 -07:00
										 |  |  |         // middleware should invoke function call if agent reply is a function call message | 
					
						
							|  |  |  |         mw = new FunctionCallMiddleware( | 
					
						
							|  |  |  |             functions: [this.EchoFunctionContract], | 
					
						
							| 
									
										
										
										
											2024-11-03 09:18:32 -08:00
										 |  |  |             functionMap: new Dictionary<string, Func<string, Task<string>>> { { "Echo", EchoWrapper } }); | 
					
						
							| 
									
										
										
										
											2024-04-26 09:21:46 -07:00
										 |  |  |         testAgent = functionCallAgent.RegisterMiddleware(mw); | 
					
						
							|  |  |  |         reply = await testAgent.SendAsync("hello"); | 
					
						
							|  |  |  |         reply.GetContent()!.Should().Be("[FUNC] hello"); | 
					
						
							|  |  |  |         reply.From.Should().Be("echo"); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-03 09:18:32 -08:00
										 |  |  |         // test 4 | 
					
						
							| 
									
										
										
										
											2024-04-26 09:21:46 -07:00
										 |  |  |         // middleware should return original reply if the reply from agent is not a function call message | 
					
						
							|  |  |  |         mw = new FunctionCallMiddleware( | 
					
						
							| 
									
										
										
										
											2024-11-03 09:18:32 -08:00
										 |  |  |             functionMap: new Dictionary<string, Func<string, Task<string>>> { { "Echo", EchoWrapper } }); | 
					
						
							| 
									
										
										
										
											2024-04-26 09:21:46 -07:00
										 |  |  |         testAgent = agent.RegisterMiddleware(mw); | 
					
						
							|  |  |  |         reply = await testAgent.SendAsync("hello"); | 
					
						
							|  |  |  |         reply.GetContent()!.Should().Be("hello"); | 
					
						
							|  |  |  |         reply.From.Should().Be("echo"); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-03 09:18:32 -08:00
										 |  |  |         // test 5 | 
					
						
							| 
									
										
										
										
											2024-04-26 09:21:46 -07:00
										 |  |  |         // middleware should return an error message if the function name is not available when invoking the function from previous agent reply | 
					
						
							|  |  |  |         mw = new FunctionCallMiddleware( | 
					
						
							| 
									
										
										
										
											2024-11-03 09:18:32 -08:00
										 |  |  |             functionMap: new Dictionary<string, Func<string, Task<string>>> { { "Echo2", EchoWrapper } }); | 
					
						
							| 
									
										
										
										
											2024-04-26 09:21:46 -07:00
										 |  |  |         testAgent = agent.RegisterMiddleware(mw); | 
					
						
							|  |  |  |         reply = await testAgent.SendAsync(functionCallMessage); | 
					
						
							| 
									
										
										
										
											2024-11-03 09:18:32 -08:00
										 |  |  |         reply.GetContent()!.Should().Be("Function Echo is not available. Available functions are: Echo2"); | 
					
						
							| 
									
										
										
										
											2024-04-26 09:21:46 -07:00
										 |  |  |     } | 
					
						
							|  |  |  | } |