mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-24 13:39:24 +00:00
parent
31d2d37d88
commit
3e6f073373
@ -92,5 +92,15 @@ public partial class Example03_Agent_FunctionCall
|
||||
calculateTax.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
|
||||
calculateTax.GetToolCalls().Should().HaveCount(1);
|
||||
calculateTax.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax));
|
||||
|
||||
// parallel function calls
|
||||
var calculateTaxes = await agent.SendAsync("calculate tax: 100, 0.1; calculate tax: 200, 0.2");
|
||||
calculateTaxes.GetContent().Should().Be("tax is 10\ntax is 40"); // "tax is 10\n tax is 40
|
||||
calculateTaxes.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
|
||||
calculateTaxes.GetToolCalls().Should().HaveCount(2);
|
||||
calculateTaxes.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax));
|
||||
|
||||
// send aggregate message back to llm to get the final result
|
||||
var finalResult = await agent.SendAsync(calculateTaxes);
|
||||
}
|
||||
}
|
||||
|
||||
@ -169,7 +169,7 @@ public static class MessageExtension
|
||||
TextMessage textMessage => textMessage.Content,
|
||||
Message msg => msg.Content,
|
||||
ToolCallResultMessage toolCallResultMessage => toolCallResultMessage.ToolCalls.Count == 1 ? toolCallResultMessage.ToolCalls.First().Result : null,
|
||||
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => aggregateMessage.Message2.ToolCalls.Count == 1 ? aggregateMessage.Message2.ToolCalls.First().Result : null,
|
||||
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => string.Join("\n", aggregateMessage.Message2.ToolCalls.Where(x => x.Result is not null).Select(x => x.Result)),
|
||||
_ => null,
|
||||
};
|
||||
}
|
||||
|
||||
@ -26,6 +26,8 @@ public class ToolCall
|
||||
|
||||
public string FunctionArguments { get; set; }
|
||||
|
||||
public string? ToolCallId { get; set; }
|
||||
|
||||
public string? Result { get; set; }
|
||||
|
||||
public override string ToString()
|
||||
|
||||
@ -128,13 +128,13 @@ public class FunctionCallMiddleware : IStreamingMiddleware
|
||||
if (this.functionMap?.TryGetValue(functionName, out var func) is true)
|
||||
{
|
||||
var result = await func(functionArguments);
|
||||
toolCallResult.Add(new ToolCall(functionName, functionArguments, result));
|
||||
toolCallResult.Add(new ToolCall(functionName, functionArguments, result) { ToolCallId = toolCall.ToolCallId });
|
||||
}
|
||||
else if (this.functionMap is not null)
|
||||
{
|
||||
var errorMessage = $"Function {functionName} is not available. Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}";
|
||||
|
||||
toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage));
|
||||
toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage) { ToolCallId = toolCall.ToolCallId });
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -156,7 +156,7 @@ public class FunctionCallMiddleware : IStreamingMiddleware
|
||||
if (this.functionMap?.TryGetValue(fName, out var func) is true)
|
||||
{
|
||||
var result = await func(fArgs);
|
||||
toolCallResult.Add(new ToolCall(fName, fArgs, result));
|
||||
toolCallResult.Add(new ToolCall(fName, fArgs, result) { ToolCallId = toolCall.ToolCallId });
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -152,7 +152,7 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa
|
||||
.Where(tc => tc is ChatCompletionsFunctionToolCall)
|
||||
.Select(tc => (ChatCompletionsFunctionToolCall)tc);
|
||||
|
||||
var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments));
|
||||
var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments) { ToolCallId = tc.Id });
|
||||
|
||||
return new ToolCallMessage(toolCalls, from);
|
||||
}
|
||||
@ -322,7 +322,7 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa
|
||||
throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent");
|
||||
}
|
||||
|
||||
var toolCall = message.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments));
|
||||
var toolCall = message.ToolCalls.Select((tc, i) => new ChatCompletionsFunctionToolCall(tc.ToolCallId ?? $"{tc.FunctionName}_{i}", tc.FunctionName, tc.FunctionArguments));
|
||||
var chatRequestMessage = new ChatRequestAssistantMessage(string.Empty) { Name = message.From };
|
||||
foreach (var tc in toolCall)
|
||||
{
|
||||
@ -336,7 +336,7 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa
|
||||
{
|
||||
return message.ToolCalls
|
||||
.Where(tc => tc.Result is not null)
|
||||
.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName));
|
||||
.Select((tc, i) => new ChatRequestToolMessage(tc.Result, tc.ToolCallId ?? $"{tc.FunctionName}_{i}"));
|
||||
}
|
||||
|
||||
private IEnumerable<ChatRequestMessage> ProcessMessage(IAgent agent, Message message)
|
||||
|
||||
@ -145,7 +145,7 @@
|
||||
"Type": "Function",
|
||||
"Name": "test",
|
||||
"Arguments": "test",
|
||||
"Id": "test"
|
||||
"Id": "test_0"
|
||||
}
|
||||
],
|
||||
"FunctionCallName": null,
|
||||
@ -159,7 +159,7 @@
|
||||
{
|
||||
"Role": "tool",
|
||||
"Content": "result",
|
||||
"ToolCallId": "test"
|
||||
"ToolCallId": "test_0"
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -169,12 +169,12 @@
|
||||
{
|
||||
"Role": "tool",
|
||||
"Content": "test",
|
||||
"ToolCallId": "result"
|
||||
"ToolCallId": "result_0"
|
||||
},
|
||||
{
|
||||
"Role": "tool",
|
||||
"Content": "test",
|
||||
"ToolCallId": "result"
|
||||
"ToolCallId": "result_1"
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -190,13 +190,13 @@
|
||||
"Type": "Function",
|
||||
"Name": "test",
|
||||
"Arguments": "test",
|
||||
"Id": "test"
|
||||
"Id": "test_0"
|
||||
},
|
||||
{
|
||||
"Type": "Function",
|
||||
"Name": "test",
|
||||
"Arguments": "test",
|
||||
"Id": "test"
|
||||
"Id": "test_1"
|
||||
}
|
||||
],
|
||||
"FunctionCallName": null,
|
||||
@ -216,7 +216,7 @@
|
||||
"Type": "Function",
|
||||
"Name": "test",
|
||||
"Arguments": "test",
|
||||
"Id": "test"
|
||||
"Id": "test_0"
|
||||
}
|
||||
],
|
||||
"FunctionCallName": null,
|
||||
@ -225,7 +225,7 @@
|
||||
{
|
||||
"Role": "tool",
|
||||
"Content": "result",
|
||||
"ToolCallId": "test"
|
||||
"ToolCallId": "test_0"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -293,6 +293,7 @@ public class OpenAIMessageTests
|
||||
chatRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
|
||||
var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.First();
|
||||
functionToolCall.Name.Should().Be("test");
|
||||
functionToolCall.Id.Should().Be("test_0");
|
||||
functionToolCall.Arguments.Should().Be("test");
|
||||
return await innerAgent.GenerateReplyAsync(msgs);
|
||||
})
|
||||
@ -303,6 +304,41 @@ public class OpenAIMessageTests
|
||||
await agent.GenerateReplyAsync([message]);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItProcessParallelToolCallMessageAsync()
|
||||
{
|
||||
var middleware = new OpenAIChatRequestMessageConnector();
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
|
||||
{
|
||||
var innerMessage = msgs.Last();
|
||||
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
|
||||
var chatRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
|
||||
chatRequestMessage.Content.Should().BeNullOrEmpty();
|
||||
chatRequestMessage.Name.Should().Be("assistant");
|
||||
chatRequestMessage.ToolCalls.Count().Should().Be(2);
|
||||
for (int i = 0; i < chatRequestMessage.ToolCalls.Count(); i++)
|
||||
{
|
||||
chatRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatCompletionsFunctionToolCall>();
|
||||
var functionToolCall = (ChatCompletionsFunctionToolCall)chatRequestMessage.ToolCalls.ElementAt(i);
|
||||
functionToolCall.Name.Should().Be("test");
|
||||
functionToolCall.Id.Should().Be($"test_{i}");
|
||||
functionToolCall.Arguments.Should().Be("test");
|
||||
}
|
||||
return await innerAgent.GenerateReplyAsync(msgs);
|
||||
})
|
||||
.RegisterMiddleware(middleware);
|
||||
|
||||
// user message
|
||||
var toolCalls = new[]
|
||||
{
|
||||
new ToolCall("test", "test"),
|
||||
new ToolCall("test", "test"),
|
||||
};
|
||||
IMessage message = new ToolCallMessage(toolCalls, "assistant");
|
||||
await agent.GenerateReplyAsync([message]);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItThrowExceptionWhenProcessingToolCallMessageFromUserAndStrictModeIsTrueAsync()
|
||||
{
|
||||
@ -326,7 +362,7 @@ public class OpenAIMessageTests
|
||||
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
|
||||
var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
|
||||
chatRequestMessage.Content.Should().Be("result");
|
||||
chatRequestMessage.ToolCallId.Should().Be("test");
|
||||
chatRequestMessage.ToolCallId.Should().Be("test_0");
|
||||
return await innerAgent.GenerateReplyAsync(msgs);
|
||||
})
|
||||
.RegisterMiddleware(middleware);
|
||||
@ -336,6 +372,37 @@ public class OpenAIMessageTests
|
||||
await agent.GenerateReplyAsync([message]);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItProcessParallelToolCallResultMessageAsync()
|
||||
{
|
||||
var middleware = new OpenAIChatRequestMessageConnector();
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
|
||||
{
|
||||
msgs.Count().Should().Be(2);
|
||||
|
||||
for (int i = 0; i < msgs.Count(); i++)
|
||||
{
|
||||
var innerMessage = msgs.ElementAt(i);
|
||||
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
|
||||
var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
|
||||
chatRequestMessage.Content.Should().Be("result");
|
||||
chatRequestMessage.ToolCallId.Should().Be($"test_{i}");
|
||||
}
|
||||
return await innerAgent.GenerateReplyAsync(msgs);
|
||||
})
|
||||
.RegisterMiddleware(middleware);
|
||||
|
||||
// user message
|
||||
var toolCalls = new[]
|
||||
{
|
||||
new ToolCall("test", "test", "result"),
|
||||
new ToolCall("test", "test", "result"),
|
||||
};
|
||||
IMessage message = new ToolCallResultMessage(toolCalls, "user");
|
||||
await agent.GenerateReplyAsync([message]);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItProcessFunctionCallMiddlewareMessageFromUserAsync()
|
||||
{
|
||||
@ -372,6 +439,7 @@ public class OpenAIMessageTests
|
||||
innerMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
|
||||
var chatRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)innerMessage!).Content;
|
||||
chatRequestMessage.Content.Should().Be("result");
|
||||
chatRequestMessage.ToolCallId.Should().Be("test_0");
|
||||
|
||||
var toolCallMessage = msgs.First();
|
||||
toolCallMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
|
||||
@ -381,6 +449,7 @@ public class OpenAIMessageTests
|
||||
toolCallRequestMessage.ToolCalls.First().Should().BeOfType<ChatCompletionsFunctionToolCall>();
|
||||
var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.First();
|
||||
functionToolCall.Name.Should().Be("test");
|
||||
functionToolCall.Id.Should().Be("test_0");
|
||||
functionToolCall.Arguments.Should().Be("test");
|
||||
return await innerAgent.GenerateReplyAsync(msgs);
|
||||
})
|
||||
@ -393,6 +462,54 @@ public class OpenAIMessageTests
|
||||
await agent.GenerateReplyAsync([aggregateMessage]);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItProcessParallelFunctionCallMiddlewareMessageFromAssistantAsync()
|
||||
{
|
||||
var middleware = new OpenAIChatRequestMessageConnector();
|
||||
var agent = new EchoAgent("assistant")
|
||||
.RegisterMiddleware(async (msgs, _, innerAgent, _) =>
|
||||
{
|
||||
msgs.Count().Should().Be(3);
|
||||
var toolCallMessage = msgs.First();
|
||||
toolCallMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
|
||||
var toolCallRequestMessage = (ChatRequestAssistantMessage)((MessageEnvelope<ChatRequestMessage>)toolCallMessage!).Content;
|
||||
toolCallRequestMessage.Content.Should().BeNullOrEmpty();
|
||||
toolCallRequestMessage.ToolCalls.Count().Should().Be(2);
|
||||
|
||||
for (int i = 0; i < toolCallRequestMessage.ToolCalls.Count(); i++)
|
||||
{
|
||||
toolCallRequestMessage.ToolCalls.ElementAt(i).Should().BeOfType<ChatCompletionsFunctionToolCall>();
|
||||
var functionToolCall = (ChatCompletionsFunctionToolCall)toolCallRequestMessage.ToolCalls.ElementAt(i);
|
||||
functionToolCall.Name.Should().Be("test");
|
||||
functionToolCall.Id.Should().Be($"test_{i}");
|
||||
functionToolCall.Arguments.Should().Be("test");
|
||||
}
|
||||
|
||||
for (int i = 1; i < msgs.Count(); i++)
|
||||
{
|
||||
var toolCallResultMessage = msgs.ElementAt(i);
|
||||
toolCallResultMessage!.Should().BeOfType<MessageEnvelope<ChatRequestMessage>>();
|
||||
var toolCallResultRequestMessage = (ChatRequestToolMessage)((MessageEnvelope<ChatRequestMessage>)toolCallResultMessage!).Content;
|
||||
toolCallResultRequestMessage.Content.Should().Be("result");
|
||||
toolCallResultRequestMessage.ToolCallId.Should().Be($"test_{i - 1}");
|
||||
}
|
||||
|
||||
return await innerAgent.GenerateReplyAsync(msgs);
|
||||
})
|
||||
.RegisterMiddleware(middleware);
|
||||
|
||||
// user message
|
||||
var toolCalls = new[]
|
||||
{
|
||||
new ToolCall("test", "test", "result"),
|
||||
new ToolCall("test", "test", "result"),
|
||||
};
|
||||
var toolCallMessage = new ToolCallMessage(toolCalls, "assistant");
|
||||
var toolCallResultMessage = new ToolCallResultMessage(toolCalls, "assistant");
|
||||
var aggregateMessage = new AggregateMessage<ToolCallMessage, ToolCallResultMessage>(toolCallMessage, toolCallResultMessage, "assistant");
|
||||
await agent.GenerateReplyAsync([aggregateMessage]);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ItConvertChatResponseMessageToTextMessageAsync()
|
||||
{
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user