* fix bug and add tests

* update
This commit is contained in:
Xiaoyun Zhang 2024-05-20 20:19:07 -07:00 committed by GitHub
parent 31d2d37d88
commit 3e6f073373
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 145 additions and 16 deletions

View File

@ -92,5 +92,15 @@ public partial class Example03_Agent_FunctionCall
calculateTax.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
calculateTax.GetToolCalls().Should().HaveCount(1);
calculateTax.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax));
// parallel function calls
var calculateTaxes = await agent.SendAsync("calculate tax: 100, 0.1; calculate tax: 200, 0.2");
calculateTaxes.GetContent().Should().Be("tax is 10\ntax is 40"); // "tax is 10\n tax is 40
calculateTaxes.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
calculateTaxes.GetToolCalls().Should().HaveCount(2);
calculateTaxes.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax));
// send aggregate message back to llm to get the final result
var finalResult = await agent.SendAsync(calculateTaxes);
}
}

View File

@ -169,7 +169,7 @@ public static class MessageExtension
TextMessage textMessage => textMessage.Content,
Message msg => msg.Content,
ToolCallResultMessage toolCallResultMessage => toolCallResultMessage.ToolCalls.Count == 1 ? toolCallResultMessage.ToolCalls.First().Result : null,
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => aggregateMessage.Message2.ToolCalls.Count == 1 ? aggregateMessage.Message2.ToolCalls.First().Result : null,
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => string.Join("\n", aggregateMessage.Message2.ToolCalls.Where(x => x.Result is not null).Select(x => x.Result)),
_ => null,
};
}

View File

@ -26,6 +26,8 @@ public class ToolCall
public string FunctionArguments { get; set; }
public string? ToolCallId { get; set; }
public string? Result { get; set; }
public override string ToString()

View File

@ -128,13 +128,13 @@ public class FunctionCallMiddleware : IStreamingMiddleware
if (this.functionMap?.TryGetValue(functionName, out var func) is true)
{
var result = await func(functionArguments);
toolCallResult.Add(new ToolCall(functionName, functionArguments, result));
toolCallResult.Add(new ToolCall(functionName, functionArguments, result) { ToolCallId = toolCall.ToolCallId });
}
else if (this.functionMap is not null)
{
var errorMessage = $"Function {functionName} is not available. Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}";
toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage));
toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage) { ToolCallId = toolCall.ToolCallId });
}
else
{
@ -156,7 +156,7 @@ public class FunctionCallMiddleware : IStreamingMiddleware
if (this.functionMap?.TryGetValue(fName, out var func) is true)
{
var result = await func(fArgs);
toolCallResult.Add(new ToolCall(fName, fArgs, result));
toolCallResult.Add(new ToolCall(fName, fArgs, result) { ToolCallId = toolCall.ToolCallId });
}
}

View File

@ -152,7 +152,7 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa
.Where(tc => tc is ChatCompletionsFunctionToolCall)
.Select(tc => (ChatCompletionsFunctionToolCall)tc);
var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments));
var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments) { ToolCallId = tc.Id });
return new ToolCallMessage(toolCalls, from);
}
@ -322,7 +322,7 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa
throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent");
}
var toolCall = message.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments));
var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
var chatRequestMessage = new ChatRequestAssistantMessage(string.Empty) { Name = message.From };
foreach (var tc in toolCall)
{
@ -336,7 +336,7 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa
{
return message.ToolCalls
.Where(tc => tc.Result is not null)
.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName));
.Select((tc, i) => new ChatRequestToolMessage(tc.Result, tc.ToolCallId ?? $"{tc.FunctionName}_{i}"));
}
private IEnumerable<ChatRequestMessage> ProcessMessage(IAgent agent, Message message)

View File

@ -145,7 +145,7 @@
"Type": "Function",
"Name": "test",
"Arguments": "test",
"Id": "test"
"Id": "test_0"
}
],
"FunctionCallName": null,
@ -159,7 +159,7 @@
{
"Role": "tool",
"Content": "result",
"ToolCallId": "test"
"ToolCallId": "test_0"
}
]
},
@ -169,12 +169,12 @@
{
"Role": "tool",
"Content": "test",
"ToolCallId": "result"
"ToolCallId": "result_0"
},
{
"Role": "tool",
"Content": "test",
"ToolCallId": "result"
"ToolCallId": "result_1"
}
]
},
@ -190,13 +190,13 @@
"Type": "Function",
"Name": "test",
"Arguments": "test",
"Id": "test"
"Id": "test_0"
},
{
"Type": "Function",
"Name": "test",
"Arguments": "test",
"Id": "test"
"Id": "test_1"
}
],
"FunctionCallName": null,
@ -216,7 +216,7 @@
"Type": "Function",
"Name": "test",
"Arguments": "test",
"Id": "test"
"Id": "test_0"
}
],
"FunctionCallName": null,
@ -225,7 +225,7 @@
{
"Role": "tool",
"Content": "result",
"ToolCallId": "test"
"ToolCallId": "test_0"
}
]
}

View File

@ -293,6 +293,7 @@ public class OpenAIMessageTests
chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.First();
functionToolCall.Name.Should().Be("test");
functionToolCall.Id.Should().Be("test_0");
functionToolCall.Arguments.Should().Be("test");
return await innerAgent.GenerateReplyAsync(msgs);
})
@ -303,6 +304,41 @@ public class OpenAIMessageTests
await agent.GenerateReplyAsync([message]);
}
[Fact]
public async Task ItProcessParallelToolCallMessageAsync()
{
var middleware = new OpenAIChatRequestMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
var innerMessage = msgs.Last();
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().BeNullOrEmpty();
chatRequestMessage.Name.Should().Be("assistant");
chatRequestMessage.ToolCalls.Count().Should().Be(2);
for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
{
chatRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatCompletionsFunctionToolCall>();
var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.ElementAt(i);
functionToolCall.Name.Should().Be("test");
functionToolCall.Id.Should().Be($"test_{i}");
functionToolCall.Arguments.Should().Be("test");
}
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);
// user message
var toolCalls = new[]
{
new ToolCall("test", "test"),
new ToolCall("test", "test"),
};
IMessage message = new ToolCallMessage(toolCalls, "assistant");
await agent.GenerateReplyAsync([message]);
}
[Fact]
public async Task ItThrowExceptionWhenProcessingToolCallMessageFromUserAndStrictModeIsTrueAsync()
{
@ -326,7 +362,7 @@ public class OpenAIMessageTests
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().Be("result");
chatRequestMessage.ToolCallId.Should().Be("test");
chatRequestMessage.ToolCallId.Should().Be("test_0");
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);
@ -336,6 +372,37 @@ public class OpenAIMessageTests
await agent.GenerateReplyAsync([message]);
}
[Fact]
public async Task ItProcessParallelToolCallResultMessageAsync()
{
var middleware = new OpenAIChatRequestMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
msgs.Count().Should().Be(2);
for (int i = 0; i < msgs.Count(); i++)
{
var innerMessage = msgs.ElementAt(i);
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().Be("result");
chatRequestMessage.ToolCallId.Should().Be($"test_{i}");
}
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);
// user message
var toolCalls = new[]
{
new ToolCall("test", "test", "result"),
new ToolCall("test", "test", "result"),
};
IMessage message = new ToolCallResultMessage(toolCalls, "user");
await agent.GenerateReplyAsync([message]);
}
[Fact]
public async Task ItProcessFunctionCallMiddlewareMessageFromUserAsync()
{
@ -372,6 +439,7 @@ public class OpenAIMessageTests
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
chatRequestMessage.Content.Should().Be("result");
chatRequestMessage.ToolCallId.Should().Be("test_0");
var toolCallMessage = msgs.First();
toolCallMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
@ -381,6 +449,7 @@ public class OpenAIMessageTests
toolCallRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.First();
functionToolCall.Name.Should().Be("test");
functionToolCall.Id.Should().Be("test_0");
functionToolCall.Arguments.Should().Be("test");
return await innerAgent.GenerateReplyAsync(msgs);
})
@ -393,6 +462,54 @@ public class OpenAIMessageTests
await agent.GenerateReplyAsync([aggregateMessage]);
}
[Fact]
public async Task ItProcessParallelFunctionCallMiddlewareMessageFromAssistantAsync()
{
var middleware = new OpenAIChatRequestMessageConnector();
var agent = new EchoAgent("assistant")
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
{
msgs.Count().Should().Be(3);
var toolCallMessage = msgs.First();
toolCallMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var toolCallRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)toolCallMessage!).Content;
toolCallRequestMessage.Content.Should().BeNullOrEmpty();
toolCallRequestMessage.ToolCalls.Count().Should().Be(2);
for (int i = 0; i < toolCallRequestMessage.ToolCalls.Count(); i++)
{
toolCallRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatCompletionsFunctionToolCall>();
var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.ElementAt(i);
functionToolCall.Name.Should().Be("test");
functionToolCall.Id.Should().Be($"test_{i}");
functionToolCall.Arguments.Should().Be("test");
}
for (int i = 1; i < msgs.Count(); i++)
{
var toolCallResultMessage = msgs.ElementAt(i);
toolCallResultMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
var toolCallResultRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)toolCallResultMessage!).Content;
toolCallResultRequestMessage.Content.Should().Be("result");
toolCallResultRequestMessage.ToolCallId.Should().Be($"test_{i - 1}");
}
return await innerAgent.GenerateReplyAsync(msgs);
})
.RegisterMiddleware(middleware);
// user message
var toolCalls = new[]
{
new ToolCall("test", "test", "result"),
new ToolCall("test", "test", "result"),
};
var toolCallMessage = new ToolCallMessage(toolCalls, "assistant");
var toolCallResultMessage = new ToolCallResultMessage(toolCalls, "assistant");
var aggregateMessage = new AggregateMessage<ToolCallMessage, ToolCallResultMessage>(toolCallMessage, toolCallResultMessage, "assistant");
await agent.GenerateReplyAsync([aggregateMessage]);
}
[Fact]
public async Task ItConvertChatResponseMessageToTextMessageAsync()
{