2024-10-02 11:42:27 -07:00
// Copyright (c) Microsoft Corporation. All rights reserved.
2024-05-09 15:36:20 -07:00
// KernelFunctionMiddlewareTests.cs
using AutoGen.Core ;
2024-08-27 14:37:47 -07:00
using AutoGen.OpenAI ;
using AutoGen.OpenAI.Extension ;
2024-05-09 15:36:20 -07:00
using AutoGen.Tests ;
2024-08-27 14:37:47 -07:00
using Azure ;
2024-05-09 15:36:20 -07:00
using Azure.AI.OpenAI ;
using FluentAssertions ;
using Microsoft.SemanticKernel ;
namespace AutoGen.SemanticKernel.Tests ;
/// <summary>
/// Integration tests for <see cref="KernelPluginMiddleware"/>: a Semantic Kernel plugin is registered as
/// middleware on an Azure OpenAI-backed <see cref="OpenAIChatAgent"/>, and each prompt is expected to be
/// answered by exactly one tool call whose result becomes the reply content.
/// </summary>
public class KernelFunctionMiddlewareTests
{
    [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
    public async Task ItRegisterKernelFunctionMiddlewareFromTestPluginTests()
    {
        var kernel = new Kernel();
        var plugin = kernel.ImportPluginFromType<TestPlugin>();
        var agent = CreateAgentWithPlugin(kernel, plugin);

        var reply = await agent.SendAsync("what's the status of the light?");
        AssertSingleToolCallReply(reply, expectedFunctionName: "GetState", expectedResult: "off");

        reply = await agent.SendAsync("change the status of the light to on");
        AssertSingleToolCallReply(reply, expectedFunctionName: "ChangeState", expectedResult: "The status of the light is now on");
    }

    [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
    public async Task ItRegisterKernelFunctionMiddlewareFromMethodTests()
    {
        var kernel = new Kernel();
        var getWeatherMethod = kernel.CreateFunctionFromMethod(
            (string location) => $"The weather in {location} is sunny.",
            functionName: "GetWeather",
            description: "Get the weather for a location.");
        var createPersonObjectMethod = kernel.CreateFunctionFromMethod(
            (string name, string email, int age) => new Person(name, email, age),
            functionName: "CreatePersonObject",
            description: "Creates a person object.");
        var plugin = kernel.ImportPluginFromFunctions("plugin", [getWeatherMethod, createPersonObjectMethod]);
        var agent = CreateAgentWithPlugin(kernel, plugin);

        var reply = await agent.SendAsync("what's the weather in Seattle?");
        AssertSingleToolCallReply(reply, expectedFunctionName: "GetWeather", expectedResult: "The weather in Seattle is sunny.");

        reply = await agent.SendAsync("Create a person object with name: John, email: 12345@gmail.com, age: 30");
        AssertSingleToolCallReply(reply, expectedFunctionName: "CreatePersonObject", expectedResult: "Name: John, Email: 12345@gmail.com, Age: 30");
    }

    /// <summary>
    /// Builds an Azure OpenAI-backed <see cref="OpenAIChatAgent"/> named "assistant" with the message
    /// connector and a <see cref="KernelPluginMiddleware"/> for <paramref name="plugin"/> registered.
    /// </summary>
    /// <param name="kernel">Kernel the plugin was imported into.</param>
    /// <param name="plugin">Plugin whose functions are exposed to the model as tools.</param>
    /// <returns>The fully wired agent.</returns>
    /// <exception cref="InvalidOperationException">A required AZURE_OPENAI_* environment variable is missing.</exception>
    private static IAgent CreateAgentWithPlugin(Kernel kernel, KernelPlugin plugin)
    {
        var endpoint = GetRequiredEnvironmentVariable("AZURE_OPENAI_ENDPOINT");
        var key = GetRequiredEnvironmentVariable("AZURE_OPENAI_API_KEY");
        var deployName = GetRequiredEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME");

        var openaiClient = new AzureOpenAIClient(
            endpoint: new Uri(endpoint),
            credential: new AzureKeyCredential(key));
        var kernelFunctionMiddleware = new KernelPluginMiddleware(kernel, plugin);

        return new OpenAIChatAgent(openaiClient.GetChatClient(deployName), "assistant")
            .RegisterMessageConnector()
            .RegisterMiddleware(kernelFunctionMiddleware);
    }

    /// <summary>
    /// Reads an environment variable that the test cannot run without.
    /// </summary>
    /// <param name="name">Name of the environment variable.</param>
    /// <returns>The variable's value.</returns>
    /// <exception cref="InvalidOperationException">The variable is not set.</exception>
    private static string GetRequiredEnvironmentVariable(string name)
        => Environment.GetEnvironmentVariable(name)
            ?? throw new InvalidOperationException($"Please set {name} environment variable.");

    /// <summary>
    /// Asserts that <paramref name="reply"/> is a <see cref="ToolCallAggregateMessage"/> made of exactly one
    /// call to <paramref name="expectedFunctionName"/>, that the call's result is <paramref name="expectedResult"/>,
    /// and that the reply's content equals that result.
    /// </summary>
    private static void AssertSingleToolCallReply(IMessage reply, string expectedFunctionName, string expectedResult)
    {
        // BeOfType both asserts and narrows the type, so no follow-up `is` check is needed;
        // checking the type first also gives a clearer failure than a content mismatch would.
        var aggregateMessage = reply.Should().BeOfType<ToolCallAggregateMessage>().Subject;

        var toolCallMessage = aggregateMessage.Message1;
        toolCallMessage.ToolCalls.Should().HaveCount(1);
        toolCallMessage.ToolCalls[0].FunctionName.Should().Be(expectedFunctionName);

        var toolCallResultMessage = aggregateMessage.Message2;
        toolCallResultMessage.ToolCalls.Should().HaveCount(1);
        toolCallResultMessage.ToolCalls[0].Result.Should().Be(expectedResult);

        reply.GetContent().Should().Be(expectedResult);
    }
}
/// <summary>
/// Mutable data holder returned by the <c>CreatePersonObject</c> kernel function in these tests.
/// Its string form is what the tool-call result is compared against.
/// </summary>
public class Person
{
    /// <summary>Creates a person from the given name, email address and age.</summary>
    /// <param name="name">Value for <see cref="Name"/>.</param>
    /// <param name="email">Value for <see cref="Email"/>.</param>
    /// <param name="age">Value for <see cref="Age"/>.</param>
    public Person(string name, string email, int age)
        => (this.Name, this.Email, this.Age) = (name, email, age);

    /// <summary>Gets or sets the person's name.</summary>
    public string Name { get; set; }

    /// <summary>Gets or sets the person's email address.</summary>
    public string Email { get; set; }

    /// <summary>Gets or sets the person's age.</summary>
    public int Age { get; set; }

    /// <summary>Formats the person as <c>Name: {Name}, Email: {Email}, Age: {Age}</c>.</summary>
    public override string ToString() => $"Name: {this.Name}, Email: {this.Email}, Age: {this.Age}";
}