stop setting name field when assistant message contains tool call (#3481)

This commit is contained in:
Xiaoyun Zhang 2024-09-05 13:54:30 -07:00 committed by GitHub
parent 40cfe07a95
commit a44b86f26e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 170 additions and 17 deletions

View File

@ -335,7 +335,10 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa
var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments)); var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
var textContent = message.GetContent() ?? string.Empty; var textContent = message.GetContent() ?? string.Empty;
var chatRequestMessage = new ChatRequestAssistantMessage(textContent) { Name = message.From };
// don't include the name field when it's tool call message.
// fix https://github.com/microsoft/autogen/issues/3437
var chatRequestMessage = new ChatRequestAssistantMessage(textContent);
foreach (var tc in toolCall) foreach (var tc in toolCall)
{ {
chatRequestMessage.ToolCalls.Add(tc); chatRequestMessage.ToolCalls.Add(tc);

View File

@ -322,7 +322,10 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa
var toolCallParts = message.ToolCalls.Select((tc, i) => ChatToolCall.CreateFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments)); var toolCallParts = message.ToolCalls.Select((tc, i) => ChatToolCall.CreateFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
var textContent = message.GetContent() ?? null; var textContent = message.GetContent() ?? null;
var chatRequestMessage = new AssistantChatMessage(toolCallParts, textContent) { ParticipantName = message.From };
// Don't set participant name for assistant when it is tool call
// fix https://github.com/microsoft/autogen/issues/3437
var chatRequestMessage = new AssistantChatMessage(toolCallParts, textContent);
return [chatRequestMessage]; return [chatRequestMessage];
} }

View File

@ -139,7 +139,7 @@
{ {
"Role": "assistant", "Role": "assistant",
"Content": [], "Content": [],
"Name": "assistant", "Name": null,
"TooCall": [ "TooCall": [
{ {
"Type": "Function", "Type": "Function",
@ -184,7 +184,7 @@
{ {
"Role": "assistant", "Role": "assistant",
"Content": [], "Content": [],
"Name": "assistant", "Name": null,
"TooCall": [ "TooCall": [
{ {
"Type": "Function", "Type": "Function",
@ -210,7 +210,7 @@
{ {
"Role": "assistant", "Role": "assistant",
"Content": [], "Content": [],
"Name": "assistant", "Name": null,
"TooCall": [ "TooCall": [
{ {
"Type": "Function", "Type": "Function",

View File

@ -27,6 +27,12 @@ public partial class OpenAIChatAgentTest
return $"The weather in {location} is sunny."; return $"The weather in {location} is sunny.";
} }
[Function]
public async Task<string> CalculateTaxAsync(string location, double income)
{
return $"[CalculateTax] The tax in {location} for income {income} is 1000.";
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task BasicConversationTestAsync() public async Task BasicConversationTestAsync()
{ {
@ -246,6 +252,65 @@ public partial class OpenAIChatAgentTest
respond.GetContent()?.Should().NotBeNullOrEmpty(); respond.GetContent()?.Should().NotBeNullOrEmpty();
} }
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task ItProduceValidContentAfterFunctionCall()
{
// https://github.com/microsoft/autogen/issues/3437
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var options = new ChatCompletionOptions()
{
Temperature = 0.7f,
MaxTokens = 1,
};
var agentName = "assistant";
var getWeatherToolCall = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}");
var getWeatherToolCallResult = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}", "The weather in Seattle is sunny.");
var getWeatherToolCallMessage = new ToolCallMessage([getWeatherToolCall], from: agentName);
var getWeatherToolCallResultMessage = new ToolCallResultMessage([getWeatherToolCallResult], from: agentName);
var getWeatherAggregateMessage = new ToolCallAggregateMessage(getWeatherToolCallMessage, getWeatherToolCallResultMessage, from: agentName);
var calculateTaxToolCall = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}");
var calculateTaxToolCallResult = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}", "The tax in Seattle for income 1000 is 1000.");
var calculateTaxToolCallMessage = new ToolCallMessage([calculateTaxToolCall], from: agentName);
var calculateTaxToolCallResultMessage = new ToolCallResultMessage([calculateTaxToolCallResult], from: agentName);
var calculateTaxAggregateMessage = new ToolCallAggregateMessage(calculateTaxToolCallMessage, calculateTaxToolCallResultMessage, from: agentName);
var chatHistory = new List<IMessage>()
{
new TextMessage(Role.User, "What's the weather in Seattle", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in Seattle is sunny, now check the tax in seattle", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in Paris", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in Paris is sunny, now check the tax in Paris", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in New York", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in New York is sunny, now check the tax in New York", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in London", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in London is sunny, now check the tax in London", from: "admin"),
};
var agent = new OpenAIChatAgent(
chatClient: openaiClient.GetChatClient(deployName),
name: "assistant",
options: options)
.RegisterMessageConnector();
var res = await agent.GenerateReplyAsync(chatHistory, new GenerateReplyOptions
{
MaxToken = 1024,
Functions = [this.GetWeatherAsyncFunctionContract, this.CalculateTaxAsyncFunctionContract],
});
}
private OpenAIClient CreateOpenAIClientFromAzureOpenAI() private OpenAIClient CreateOpenAIClientFromAzureOpenAI()
{ {
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");

View File

@ -276,7 +276,10 @@ public class OpenAIMessageTests
var innerMessage = msgs.Last(); var innerMessage = msgs.Last();
innerMessage!.Should().BeOfType<MessageEnvelope<ChatMessage>>(); innerMessage!.Should().BeOfType<MessageEnvelope<ChatMessage>>();
var chatRequestMessage = (AssistantChatMessage)((MessageEnvelope<ChatMessage>)innerMessage!).Content; var chatRequestMessage = (AssistantChatMessage)((MessageEnvelope<ChatMessage>)innerMessage!).Content;
chatRequestMessage.ParticipantName.Should().Be("assistant"); // when the message is a tool call message
// the name field should not be set
// please visit OpenAIChatRequestMessageConnector class for more information
chatRequestMessage.ParticipantName.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(1); chatRequestMessage.ToolCalls.Count().Should().Be(1);
chatRequestMessage.Content.First().Text.Should().Be("textContent"); chatRequestMessage.Content.First().Text.Should().Be("textContent");
chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatToolCall>(); chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatToolCall>();
@ -307,7 +310,10 @@ public class OpenAIMessageTests
innerMessage!.Should().BeOfType<MessageEnvelope<ChatMessage>>(); innerMessage!.Should().BeOfType<MessageEnvelope<ChatMessage>>();
var chatRequestMessage = (AssistantChatMessage)((MessageEnvelope<ChatMessage>)innerMessage!).Content; var chatRequestMessage = (AssistantChatMessage)((MessageEnvelope<ChatMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().BeNullOrEmpty(); chatRequestMessage.Content.Should().BeNullOrEmpty();
chatRequestMessage.ParticipantName.Should().Be("assistant"); // when the message is a tool call message
// the name field should not be set
// please visit OpenAIChatRequestMessageConnector class for more information
chatRequestMessage.ParticipantName.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(2); chatRequestMessage.ToolCalls.Count().Should().Be(2);
for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++) for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
{ {

View File

@ -81,7 +81,7 @@
{ {
"Role": "assistant", "Role": "assistant",
"Content": "", "Content": "",
"Name": "assistant", "Name": null,
"TooCall": [ "TooCall": [
{ {
"Type": "Function", "Type": "Function",
@ -126,7 +126,7 @@
{ {
"Role": "assistant", "Role": "assistant",
"Content": "", "Content": "",
"Name": "assistant", "Name": null,
"TooCall": [ "TooCall": [
{ {
"Type": "Function", "Type": "Function",
@ -152,7 +152,7 @@
{ {
"Role": "assistant", "Role": "assistant",
"Content": "", "Content": "",
"Name": "assistant", "Name": null,
"TooCall": [ "TooCall": [
{ {
"Type": "Function", "Type": "Function",

View File

@ -4,6 +4,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text.Json;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using AutoGen.OpenAI.V1.Extension; using AutoGen.OpenAI.V1.Extension;
@ -45,7 +46,11 @@ namespace AutoGen.OpenAI.V1.Tests
_output.WriteLine($"agent name: {agent.Name}"); _output.WriteLine($"agent name: {agent.Name}");
foreach (var message in messages) foreach (var message in messages)
{ {
_output.WriteLine(message.FormatMessage()); if (message is IMessage<object> envelope)
{
var json = JsonSerializer.Serialize(envelope.Content, new JsonSerializerOptions { WriteIndented = true });
_output.WriteLine(json);
}
} }
throw; throw;
@ -149,9 +154,9 @@ You create math question and ask student to answer it.
Then you check if the answer is correct. Then you check if the answer is correct.
If the answer is wrong, you ask student to fix it", If the answer is wrong, you ask student to fix it",
modelName: model) modelName: model)
.RegisterMessageConnector() .RegisterMiddleware(Print)
.RegisterStreamingMiddleware(functionCallMiddleware) .RegisterMiddleware(new OpenAIChatRequestMessageConnector())
.RegisterMiddleware(Print); .RegisterMiddleware(functionCallMiddleware);
return teacher; return teacher;
} }

View File

@ -22,7 +22,13 @@ public partial class OpenAIChatAgentTest
[Function] [Function]
public async Task<string> GetWeatherAsync(string location) public async Task<string> GetWeatherAsync(string location)
{ {
return $"The weather in {location} is sunny."; return $"[GetWeather] The weather in {location} is sunny.";
}
[Function]
public async Task<string> CalculateTaxAsync(string location, double income)
{
return $"[CalculateTax] The tax in {location} for income {income} is 1000.";
} }
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
@ -270,6 +276,64 @@ public partial class OpenAIChatAgentTest
action.Should().ThrowExactly<ArgumentException>().WithMessage("Messages should not be provided in options"); action.Should().ThrowExactly<ArgumentException>().WithMessage("Messages should not be provided in options");
} }
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task ItProduceValidContentAfterFunctionCall()
{
// https://github.com/microsoft/autogen/issues/3437
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
var openaiClient = CreateOpenAIClientFromAzureOpenAI();
var options = new ChatCompletionsOptions(deployName, [])
{
Temperature = 0.7f,
MaxTokens = 1,
};
var agentName = "assistant";
var getWeatherToolCall = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}");
var getWeatherToolCallResult = new ToolCall(this.GetWeatherAsyncFunctionContract.Name, "{\"location\":\"Seattle\"}", "The weather in Seattle is sunny.");
var getWeatherToolCallMessage = new ToolCallMessage([getWeatherToolCall], from: agentName);
var getWeatherToolCallResultMessage = new ToolCallResultMessage([getWeatherToolCallResult], from: agentName);
var getWeatherAggregateMessage = new ToolCallAggregateMessage(getWeatherToolCallMessage, getWeatherToolCallResultMessage, from: agentName);
var calculateTaxToolCall = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}");
var calculateTaxToolCallResult = new ToolCall(this.CalculateTaxAsyncFunctionContract.Name, "{\"location\":\"Seattle\",\"income\":1000}", "The tax in Seattle for income 1000 is 1000.");
var calculateTaxToolCallMessage = new ToolCallMessage([calculateTaxToolCall], from: agentName);
var calculateTaxToolCallResultMessage = new ToolCallResultMessage([calculateTaxToolCallResult], from: agentName);
var calculateTaxAggregateMessage = new ToolCallAggregateMessage(calculateTaxToolCallMessage, calculateTaxToolCallResultMessage, from: agentName);
var chatHistory = new List<IMessage>()
{
new TextMessage(Role.User, "What's the weather in Seattle", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in Seattle is sunny, now check the tax in seattle", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in Paris", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in Paris is sunny, now check the tax in Paris", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in New York", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in New York is sunny, now check the tax in New York", from: "admin"),
calculateTaxAggregateMessage,
new TextMessage(Role.User, "what's the weather in London", from: "user"),
getWeatherAggregateMessage,
new TextMessage(Role.User, "The weather in London is sunny, now check the tax in London", from: "admin"),
};
var agent = new OpenAIChatAgent(
openAIClient: openaiClient,
name: "assistant",
options: options)
.RegisterMessageConnector();
var res = await agent.GenerateReplyAsync(chatHistory, new GenerateReplyOptions
{
MaxToken = 1024,
Functions = [this.GetWeatherAsyncFunctionContract, this.CalculateTaxAsyncFunctionContract],
});
}
private OpenAIClient CreateOpenAIClientFromAzureOpenAI() private OpenAIClient CreateOpenAIClientFromAzureOpenAI()
{ {
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");

View File

@ -278,7 +278,10 @@ public class OpenAIMessageTests
var innerMessage = msgs.Last(); var innerMessage = msgs.Last();
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>(); innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content; var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Name.Should().Be("assistant"); // when the message is a tool call message
// the name field should not be set
// please visit OpenAIChatRequestMessageConnector class for more information
chatRequestMessage.Name.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(1); chatRequestMessage.ToolCalls.Count().Should().Be(1);
chatRequestMessage.Content.Should().Be("textContent"); chatRequestMessage.Content.Should().Be("textContent");
chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>(); chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
@ -309,7 +312,11 @@ public class OpenAIMessageTests
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>(); innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content; var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().BeNullOrEmpty(); chatRequestMessage.Content.Should().BeNullOrEmpty();
chatRequestMessage.Name.Should().Be("assistant");
// when the message is a tool call message
// the name field should not be set
// please visit OpenAIChatRequestMessageConnector class for more information
chatRequestMessage.Name.Should().BeNullOrEmpty();
chatRequestMessage.ToolCalls.Count().Should().Be(2); chatRequestMessage.ToolCalls.Count().Should().Be(2);
for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++) for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
{ {