mirror of
				https://github.com/microsoft/autogen.git
				synced 2025-10-24 14:29:06 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			131 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			C#
		
	
	
	
	
	
			
		
		
	
	
			131 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			C#
		
	
	
	
	
	
| // Copyright (c) Microsoft Corporation. All rights reserved.
 | |
| // KernelFunctionMiddlewareTests.cs
 | |
| 
 | |
| using System.ClientModel;
 | |
| using AutoGen.Core;
 | |
| using AutoGen.OpenAI;
 | |
| using AutoGen.OpenAI.Extension;
 | |
| using AutoGen.Tests;
 | |
| using Azure.AI.OpenAI;
 | |
| using FluentAssertions;
 | |
| using Microsoft.SemanticKernel;
 | |
| using Xunit;
 | |
| 
 | |
| namespace AutoGen.SemanticKernel.Tests;
 | |
| 
 | |
/// <summary>
/// Tests that an <see cref="OpenAIChatAgent"/> with a <see cref="KernelPluginMiddleware"/> registered
/// replies with a <see cref="ToolCallAggregateMessage"/> produced from Semantic Kernel functions.
/// Both tests build an <see cref="AzureOpenAIClient"/> from the AZURE_OPENAI_ENDPOINT,
/// AZURE_OPENAI_API_KEY and AZURE_OPENAI_DEPLOY_NAME environment variables.
/// </summary>
// NOTE(review): ApiKeyFact is declared outside this file; presumably it skips the test when the
// listed environment variables are not set — confirm against its implementation.
[Trait("Category", "UnitV1")]
public class KernelFunctionMiddlewareTests
{
    /// <summary>
    /// Imports <c>TestPlugin</c> as a kernel plugin and checks that the agent answers the
    /// light-status prompts through the <c>GetState</c> and <c>ChangeState</c> functions.
    /// </summary>
    [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
    public async Task ItRegisterKernelFunctionMiddlewareFromTestPluginTests()
    {
        var (openaiClient, deployName) = CreateAzureOpenAIClient();

        var kernel = new Kernel();
        var plugin = kernel.ImportPluginFromType<TestPlugin>();
        var kernelFunctionMiddleware = new KernelPluginMiddleware(kernel, plugin);

        var agent = new OpenAIChatAgent(openaiClient.GetChatClient(deployName), "assistant")
            .RegisterMessageConnector()
            .RegisterMiddleware(kernelFunctionMiddleware);

        var reply = await agent.SendAsync("what's the status of the light?");
        reply.GetContent().Should().Be("off");

        // BeOfType hands back the typed subject, so no separate pattern-match guard is needed.
        var getStateMessage = reply.Should().BeOfType<ToolCallAggregateMessage>().Subject;
        getStateMessage.Message1.ToolCalls.Should().HaveCount(1);
        getStateMessage.Message1.ToolCalls[0].FunctionName.Should().Be("GetState");
        getStateMessage.Message2.ToolCalls.Should().HaveCount(1);
        getStateMessage.Message2.ToolCalls[0].Result.Should().Be("off");

        reply = await agent.SendAsync("change the status of the light to on");
        reply.GetContent().Should().Be("The status of the light is now on");

        var changeStateMessage = reply.Should().BeOfType<ToolCallAggregateMessage>().Subject;
        changeStateMessage.Message1.ToolCalls.Should().HaveCount(1);
        changeStateMessage.Message1.ToolCalls[0].FunctionName.Should().Be("ChangeState");
        changeStateMessage.Message2.ToolCalls.Should().HaveCount(1);
    }

    /// <summary>
    /// Builds a plugin from two lambda-backed kernel functions (<c>GetWeather</c> and
    /// <c>CreatePersonObject</c>) and checks that the agent answers each prompt through the
    /// matching function.
    /// </summary>
    [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
    public async Task ItRegisterKernelFunctionMiddlewareFromMethodTests()
    {
        var (openaiClient, deployName) = CreateAzureOpenAIClient();

        var kernel = new Kernel();
        var getWeatherMethod = kernel.CreateFunctionFromMethod(
            (string location) => $"The weather in {location} is sunny.",
            functionName: "GetWeather",
            description: "Get the weather for a location.");
        var createPersonObjectMethod = kernel.CreateFunctionFromMethod(
            (string name, string email, int age) => new Person(name, email, age),
            functionName: "CreatePersonObject",
            description: "Creates a person object.");
        var plugin = kernel.ImportPluginFromFunctions("plugin", [getWeatherMethod, createPersonObjectMethod]);
        var kernelFunctionMiddleware = new KernelPluginMiddleware(kernel, plugin);

        var agent = new OpenAIChatAgent(openaiClient.GetChatClient(deployName), "assistant")
            .RegisterMessageConnector()
            .RegisterMiddleware(kernelFunctionMiddleware);

        var reply = await agent.SendAsync("what's the weather in Seattle?");
        reply.GetContent().Should().Be("The weather in Seattle is sunny.");

        var getWeatherMessage = reply.Should().BeOfType<ToolCallAggregateMessage>().Subject;
        getWeatherMessage.Message1.ToolCalls.Should().HaveCount(1);
        getWeatherMessage.Message1.ToolCalls[0].FunctionName.Should().Be("GetWeather");
        getWeatherMessage.Message2.ToolCalls.Should().HaveCount(1);

        reply = await agent.SendAsync("Create a person object with name: John, email: 12345@gmail.com, age: 30");
        // The expected text is the format produced by Person.ToString().
        reply.GetContent().Should().Be("Name: John, Email: 12345@gmail.com, Age: 30");

        var createPersonObjectMessage = reply.Should().BeOfType<ToolCallAggregateMessage>().Subject;
        createPersonObjectMessage.Message1.ToolCalls.Should().HaveCount(1);
        createPersonObjectMessage.Message1.ToolCalls[0].FunctionName.Should().Be("CreatePersonObject");
        createPersonObjectMessage.Message2.ToolCalls.Should().HaveCount(1);
    }

    /// <summary>
    /// Creates the Azure OpenAI client shared by the tests, together with the deployment name
    /// the chat client should target.
    /// </summary>
    /// <returns>The configured client and the value of AZURE_OPENAI_DEPLOY_NAME.</returns>
    /// <exception cref="InvalidOperationException">A required environment variable is not set.</exception>
    private static (AzureOpenAIClient Client, string DeployName) CreateAzureOpenAIClient()
    {
        var endpoint = GetRequiredEnvironmentVariable("AZURE_OPENAI_ENDPOINT");
        var key = GetRequiredEnvironmentVariable("AZURE_OPENAI_API_KEY");
        var deployName = GetRequiredEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME");

        var client = new AzureOpenAIClient(
            endpoint: new Uri(endpoint),
            credential: new ApiKeyCredential(key));

        return (client, deployName);
    }

    /// <summary>
    /// Reads an environment variable and fails with a descriptive error when it is missing.
    /// </summary>
    /// <param name="name">Name of the environment variable to read.</param>
    /// <returns>The variable's value.</returns>
    /// <exception cref="InvalidOperationException">The variable is not set.</exception>
    private static string GetRequiredEnvironmentVariable(string name)
    {
        return Environment.GetEnvironmentVariable(name)
            ?? throw new InvalidOperationException($"Please set {name} environment variable.");
    }
}
 | |
| 
 | |
/// <summary>
/// Simple data holder with a name, an email address and an age.
/// </summary>
public class Person
{
    /// <summary>
    /// Creates a person from the supplied values.
    /// </summary>
    /// <param name="name">Value stored in <see cref="Name"/>.</param>
    /// <param name="email">Value stored in <see cref="Email"/>.</param>
    /// <param name="age">Value stored in <see cref="Age"/>.</param>
    public Person(string name, string email, int age)
    {
        Name = name;
        Email = email;
        Age = age;
    }

    /// <summary>The person's name.</summary>
    public string Name { get; set; }

    /// <summary>The person's email address.</summary>
    public string Email { get; set; }

    /// <summary>The person's age.</summary>
    public int Age { get; set; }

    /// <summary>
    /// Formats the person as <c>Name: {Name}, Email: {Email}, Age: {Age}</c>.
    /// </summary>
    public override string ToString() => $"Name: {Name}, Email: {Email}, Age: {Age}";
}
 | 
