[.Net] fix #2695 and #2884 (#3069)

* add round robin orchestrator

* add constructor for orchestrators

* add tests

* revert change

* return single orchestrator

* address comment
This commit is contained in:
Xiaoyun Zhang 2024-07-10 15:12:42 -07:00 committed by GitHub
parent f55a98f32b
commit 4e95630fa9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 902 additions and 107 deletions

View File

@ -31,6 +31,7 @@
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.console" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitVersion)" />
<PackageReference Include="Moq" Version="4.20.70" />
</ItemGroup>
<ItemGroup Condition="'$(IsTestProject)' == 'true'">

View File

@ -67,7 +67,8 @@ public class AnthropicClientAgent : IStreamingAgent
Stream = shouldStream,
Temperature = (decimal?)options?.Temperature ?? _temperature,
Tools = _tools?.ToList(),
ToolChoice = _toolChoice ?? ToolChoice.Auto
ToolChoice = _toolChoice ?? (_tools is { Length: > 0 } ? ToolChoice.Auto : null),
StopSequences = options?.StopSequence?.ToArray(),
};
chatCompletionRequest.Messages = BuildMessages(messages);
@ -95,6 +96,22 @@ public class AnthropicClientAgent : IStreamingAgent
}
}
return chatMessages;
// merge messages with the same role
// fixing #2884
var mergedMessages = chatMessages.Aggregate(new List<ChatMessage>(), (acc, message) =>
{
if (acc.Count > 0 && acc.Last().Role == message.Role)
{
acc.Last().Content.AddRange(message.Content);
}
else
{
acc.Add(message);
}
return acc;
});
return mergedMessages;
}
}

View File

@ -1,4 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// AnthropicClient.cs
using System;
@ -90,13 +90,7 @@ public sealed class AnthropicClient : IDisposable
{
var res = await JsonSerializer.DeserializeAsync<ChatCompletionResponse>(
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)),
cancellationToken: cancellationToken);
if (res == null)
{
throw new Exception("Failed to deserialize response");
}
cancellationToken: cancellationToken) ?? throw new Exception("Failed to deserialize response");
if (res.Delta?.Type == "input_json_delta" && !string.IsNullOrEmpty(res.Delta.PartialJson) &&
currentEvent.ContentBlock != null)
{

View File

@ -7,10 +7,14 @@ using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
public interface IAgent
public interface IAgentMetaInformation
{
public string Name { get; }
}
public interface IAgent : IAgentMetaInformation
{
/// <summary>
/// Generate reply
/// </summary>

View File

@ -100,8 +100,7 @@ public static class GroupChatExtension
var msg = @$"From {x.From}:
{x.GetContent()}
<eof_msg>
round #
{i}";
round # {i}";
return new TextMessage(Role.User, content: msg);
});

View File

@ -15,6 +15,7 @@ public class GroupChat : IGroupChat
private List<IAgent> agents = new List<IAgent>();
private IEnumerable<IMessage> initializeMessages = new List<IMessage>();
private Graph? workflow = null;
private readonly IOrchestrator orchestrator;
public IEnumerable<IMessage>? Messages { get; private set; }
@ -36,6 +37,37 @@ public class GroupChat : IGroupChat
this.initializeMessages = initializeMessages ?? new List<IMessage>();
this.workflow = workflow;
if (admin is not null)
{
this.orchestrator = new RolePlayOrchestrator(admin, workflow);
}
else if (workflow is not null)
{
this.orchestrator = new WorkflowOrchestrator(workflow);
}
else
{
this.orchestrator = new RoundRobinOrchestrator();
}
this.Validation();
}
/// <summary>
/// Create a group chat which uses the <paramref name="orchestrator"/> to decide the next speaker(s).
/// </summary>
/// <param name="members"></param>
/// <param name="orchestrator"></param>
/// <param name="initializeMessages"></param>
public GroupChat(
IEnumerable<IAgent> members,
IOrchestrator orchestrator,
IEnumerable<IMessage>? initializeMessages = null)
{
this.agents = members.ToList();
this.initializeMessages = initializeMessages ?? new List<IMessage>();
this.orchestrator = orchestrator;
this.Validation();
}
@ -64,12 +96,6 @@ public class GroupChat : IGroupChat
throw new Exception("All agents in the workflow must be in the group chat.");
}
}
// must provide one of admin or workflow
if (this.admin == null && this.workflow == null)
{
throw new Exception("Must provide one of admin or workflow.");
}
}
/// <summary>
@ -81,6 +107,7 @@ public class GroupChat : IGroupChat
/// <param name="currentSpeaker">current speaker</param>
/// <param name="conversationHistory">conversation history</param>
/// <returns>next speaker.</returns>
[Obsolete("Please use RolePlayOrchestrator or WorkflowOrchestrator")]
public async Task<IAgent> SelectNextSpeakerAsync(IAgent currentSpeaker, IEnumerable<IMessage> conversationHistory)
{
var agentNames = this.agents.Select(x => x.Name).ToList();
@ -140,37 +167,40 @@ From {agentNames.First()}:
}
public async Task<IEnumerable<IMessage>> CallAsync(
IEnumerable<IMessage>? conversationWithName = null,
IEnumerable<IMessage>? chatHistory = null,
int maxRound = 10,
CancellationToken ct = default)
{
var conversationHistory = new List<IMessage>();
if (conversationWithName != null)
conversationHistory.AddRange(this.initializeMessages);
if (chatHistory != null)
{
conversationHistory.AddRange(conversationWithName);
conversationHistory.AddRange(chatHistory);
}
var roundLeft = maxRound;
var lastSpeaker = conversationHistory.LastOrDefault()?.From switch
while (roundLeft > 0)
{
null => this.agents.First(),
_ => this.agents.FirstOrDefault(x => x.Name == conversationHistory.Last().From) ?? throw new Exception("The agent is not in the group chat"),
};
var round = 0;
while (round < maxRound)
{
var currentSpeaker = await this.SelectNextSpeakerAsync(lastSpeaker, conversationHistory);
var processedConversation = this.ProcessConversationForAgent(this.initializeMessages, conversationHistory);
var result = await currentSpeaker.GenerateReplyAsync(processedConversation) ?? throw new Exception("No result is returned.");
conversationHistory.Add(result);
// if message is terminate message, then terminate the conversation
if (result?.IsGroupChatTerminateMessage() ?? false)
var orchestratorContext = new OrchestrationContext
{
Candidates = this.agents,
ChatHistory = conversationHistory,
};
var nextSpeaker = await this.orchestrator.GetNextSpeakerAsync(orchestratorContext, ct);
if (nextSpeaker == null)
{
break;
}
lastSpeaker = currentSpeaker;
round++;
var result = await nextSpeaker.GenerateReplyAsync(conversationHistory, cancellationToken: ct);
conversationHistory.Add(result);
if (result.IsGroupChatTerminateMessage())
{
return conversationHistory;
}
roundLeft--;
}
return conversationHistory;

View File

@ -3,9 +3,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
@ -25,76 +22,12 @@ public class SequentialGroupChat : RoundRobinGroupChat
/// <summary>
/// A group chat that allows agents to talk in a round-robin manner.
/// </summary>
public class RoundRobinGroupChat : IGroupChat
public class RoundRobinGroupChat : GroupChat
{
private readonly List<IAgent> agents = new List<IAgent>();
private readonly List<IMessage> initializeMessages = new List<IMessage>();
public RoundRobinGroupChat(
IEnumerable<IAgent> agents,
List<IMessage>? initializeMessages = null)
: base(agents, initializeMessages: initializeMessages)
{
this.agents.AddRange(agents);
this.initializeMessages = initializeMessages ?? new List<IMessage>();
}
/// <inheritdoc />
public void AddInitializeMessage(IMessage message)
{
this.SendIntroduction(message);
}
public async Task<IEnumerable<IMessage>> CallAsync(
IEnumerable<IMessage>? conversationWithName = null,
int maxRound = 10,
CancellationToken ct = default)
{
var conversationHistory = new List<IMessage>();
if (conversationWithName != null)
{
conversationHistory.AddRange(conversationWithName);
}
var lastSpeaker = conversationHistory.LastOrDefault()?.From switch
{
null => this.agents.First(),
_ => this.agents.FirstOrDefault(x => x.Name == conversationHistory.Last().From) ?? throw new Exception("The agent is not in the group chat"),
};
var round = 0;
while (round < maxRound)
{
var currentSpeaker = this.SelectNextSpeaker(lastSpeaker);
var processedConversation = this.ProcessConversationForAgent(this.initializeMessages, conversationHistory);
var result = await currentSpeaker.GenerateReplyAsync(processedConversation) ?? throw new Exception("No result is returned.");
conversationHistory.Add(result);
// if message is terminate message, then terminate the conversation
if (result?.IsGroupChatTerminateMessage() ?? false)
{
break;
}
lastSpeaker = currentSpeaker;
round++;
}
return conversationHistory;
}
public void SendIntroduction(IMessage message)
{
this.initializeMessages.Add(message);
}
private IAgent SelectNextSpeaker(IAgent currentSpeaker)
{
var index = this.agents.IndexOf(currentSpeaker);
if (index == -1)
{
throw new ArgumentException("The agent is not in the group chat", nameof(currentSpeaker));
}
var nextIndex = (index + 1) % this.agents.Count;
return this.agents[nextIndex];
}
}

View File

@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IOrchestrator.cs
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
public class OrchestrationContext
{
public IEnumerable<IAgent> Candidates { get; set; } = Array.Empty<IAgent>();
public IEnumerable<IMessage> ChatHistory { get; set; } = Array.Empty<IMessage>();
}
public interface IOrchestrator
{
/// <summary>
/// Return the next agent as the next speaker. return null if no agent is selected.
/// </summary>
/// <param name="context">orchestration context, such as candidate agents and chat history.</param>
/// <param name="cancellationToken">cancellation token</param>
public Task<IAgent?> GetNextSpeakerAsync(
OrchestrationContext context,
CancellationToken cancellationToken = default);
}

View File

@ -0,0 +1,116 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// RolePlayOrchestrator.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
public class RolePlayOrchestrator : IOrchestrator
{
private readonly IAgent admin;
private readonly Graph? workflow = null;
public RolePlayOrchestrator(IAgent admin, Graph? workflow = null)
{
this.admin = admin;
this.workflow = workflow;
}
public async Task<IAgent?> GetNextSpeakerAsync(
OrchestrationContext context,
CancellationToken cancellationToken = default)
{
var candidates = context.Candidates.ToList();
if (candidates.Count == 0)
{
return null;
}
if (candidates.Count == 1)
{
return candidates.First();
}
// if there's a workflow
// and the next available agent from the workflow is in the group chat
// then return the next agent from the workflow
if (this.workflow != null)
{
var lastMessage = context.ChatHistory.LastOrDefault();
if (lastMessage == null)
{
return null;
}
var currentSpeaker = candidates.First(candidates => candidates.Name == lastMessage.From);
var nextAgents = await this.workflow.TransitToNextAvailableAgentsAsync(currentSpeaker, context.ChatHistory);
nextAgents = nextAgents.Where(nextAgent => candidates.Any(candidate => candidate.Name == nextAgent.Name));
candidates = nextAgents.ToList();
if (!candidates.Any())
{
return null;
}
if (candidates is { Count: 1 })
{
return candidates.First();
}
}
// In this case, since there are more than one available agents from the workflow for the next speaker
// the admin will be invoked to decide the next speaker
var agentNames = candidates.Select(candidate => candidate.Name);
var rolePlayMessage = new TextMessage(Role.User,
content: $@"You are in a role play game. Carefully read the conversation history and carry on the conversation.
The available roles are:
{string.Join(",", agentNames)}
Each message will start with 'From name:', e.g:
From {agentNames.First()}:
//your message//.");
var chatHistoryWithName = this.ProcessConversationsForRolePlay(context.ChatHistory);
var messages = new IMessage[] { rolePlayMessage }.Concat(chatHistoryWithName);
var response = await this.admin.GenerateReplyAsync(
messages: messages,
options: new GenerateReplyOptions
{
Temperature = 0,
MaxToken = 128,
StopSequence = [":"],
Functions = null,
},
cancellationToken: cancellationToken);
var name = response.GetContent() ?? throw new Exception("No name is returned.");
// remove From
name = name!.Substring(5);
var candidate = candidates.FirstOrDefault(x => x.Name!.ToLower() == name.ToLower());
if (candidate != null)
{
return candidate;
}
var errorMessage = $"The response from admin is {name}, which is either not in the candidates list or not in the correct format.";
throw new Exception(errorMessage);
}
private IEnumerable<IMessage> ProcessConversationsForRolePlay(IEnumerable<IMessage> messages)
{
return messages.Select((x, i) =>
{
var msg = @$"From {x.From}:
{x.GetContent()}
<eof_msg>
round # {i}";
return new TextMessage(Role.User, content: msg);
});
}
}

View File

@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// RoundRobinOrchestrator.cs
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
/// <summary>
/// Return the next agent in a round-robin fashion.
/// <para>
/// If the last message is from one of the candidates, the next agent will be the next candidate in the list.
/// </para>
/// <para>
/// Otherwise, no agent will be selected. In this case, the orchestrator will return an empty list.
/// </para>
/// <para>
/// This orchestrator always return a single agent.
/// </para>
/// </summary>
public class RoundRobinOrchestrator : IOrchestrator
{
public async Task<IAgent?> GetNextSpeakerAsync(
OrchestrationContext context,
CancellationToken cancellationToken = default)
{
var lastMessage = context.ChatHistory.LastOrDefault();
if (lastMessage == null)
{
return null;
}
var candidates = context.Candidates.ToList();
var lastAgentIndex = candidates.FindIndex(a => a.Name == lastMessage.From);
if (lastAgentIndex == -1)
{
return null;
}
var nextAgentIndex = (lastAgentIndex + 1) % candidates.Count;
return candidates[nextAgentIndex];
}
}

View File

@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// WorkflowOrchestrator.cs
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
public class WorkflowOrchestrator : IOrchestrator
{
private readonly Graph workflow;
public WorkflowOrchestrator(Graph workflow)
{
this.workflow = workflow;
}
public async Task<IAgent?> GetNextSpeakerAsync(
OrchestrationContext context,
CancellationToken cancellationToken = default)
{
var lastMessage = context.ChatHistory.LastOrDefault();
if (lastMessage == null)
{
return null;
}
var candidates = context.Candidates.ToList();
var currentSpeaker = candidates.FirstOrDefault(candidates => candidates.Name == lastMessage.From);
if (currentSpeaker == null)
{
return null;
}
var nextAgents = await this.workflow.TransitToNextAvailableAgentsAsync(currentSpeaker, context.ChatHistory);
nextAgents = nextAgents.Where(nextAgent => candidates.Any(candidate => candidate.Name == nextAgent.Name));
candidates = nextAgents.ToList();
if (!candidates.Any())
{
return null;
}
if (candidates is { Count: 1 })
{
return candidates.First();
}
else
{
throw new System.Exception("There are more than one available agents from the workflow for the next speaker.");
}
}
}

View File

@ -97,6 +97,7 @@ public class MistralClientAgent : IStreamingAgent
var chatHistory = BuildChatHistory(messages);
var chatRequest = new ChatCompletionRequest(model: _model, messages: chatHistory.ToList(), temperature: options?.Temperature, randomSeed: _randomSeed)
{
Stop = options?.StopSequence,
MaxTokens = options?.MaxToken,
ResponseFormat = _jsonOutput ? new ResponseFormat() { ResponseFormatType = "json_object" } : null,
};

View File

@ -105,6 +105,9 @@ public class ChatCompletionRequest
[JsonPropertyName("random_seed")]
public int? RandomSeed { get; set; }
[JsonPropertyName("stop")]
public string[]? Stop { get; set; }
[JsonPropertyName("tools")]
public List<FunctionTool>? Tools { get; set; }

View File

@ -32,6 +32,30 @@ public class AnthropicClientAgentTest
reply.From.Should().Be(agent.Name);
}
[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task AnthropicAgentMergeMessageWithSameRoleTests()
{
// this test is added to fix issue #2884
var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey);
var agent = new AnthropicClientAgent(
client,
name: "AnthropicAgent",
AnthropicConstants.Claude3Haiku,
systemMessage: "You are a helpful AI assistant that convert user message to upper case")
.RegisterMessageConnector();
var uppCaseMessage = new TextMessage(Role.User, "abcdefg");
var anotherUserMessage = new TextMessage(Role.User, "hijklmn");
var assistantMessage = new TextMessage(Role.Assistant, "opqrst");
var anotherAssistantMessage = new TextMessage(Role.Assistant, "uvwxyz");
var yetAnotherUserMessage = new TextMessage(Role.User, "123456");
// just make sure it doesn't throw exception
var reply = await agent.SendAsync(chatHistory: [uppCaseMessage, anotherUserMessage, assistantMessage, anotherAssistantMessage, yetAnotherUserMessage]);
reply.GetContent().Should().NotBeNull();
}
[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task AnthropicAgentTestProcessImageAsync()
{

View File

@ -9,6 +9,7 @@
<ItemGroup>
<ProjectReference Include="..\..\sample\AutoGen.BasicSamples\AutoGen.BasicSample.csproj" />
<ProjectReference Include="..\..\src\AutoGen.Anthropic\AutoGen.Anthropic.csproj" />
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
<ProjectReference Include="..\..\src\AutoGen\AutoGen.csproj" />
</ItemGroup>

View File

@ -0,0 +1,331 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// RolePlayOrchestratorTests.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Anthropic;
using AutoGen.Anthropic.Extensions;
using AutoGen.Anthropic.Utils;
using AutoGen.Gemini;
using AutoGen.Mistral;
using AutoGen.Mistral.Extension;
using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;
using Azure.AI.OpenAI;
using FluentAssertions;
using Moq;
using Xunit;
namespace AutoGen.Tests;
public class RolePlayOrchestratorTests
{
[Fact]
public async Task ItReturnNextSpeakerTestAsync()
{
var admin = Mock.Of<IAgent>();
Mock.Get(admin).Setup(x => x.Name).Returns("Admin");
Mock.Get(admin).Setup(x => x.GenerateReplyAsync(
It.IsAny<IEnumerable<IMessage>>(),
It.IsAny<GenerateReplyOptions>(),
It.IsAny<CancellationToken>()))
.Callback<IEnumerable<IMessage>, GenerateReplyOptions, CancellationToken>((messages, option, _) =>
{
// verify prompt
var rolePlayPrompt = messages.First().GetContent();
rolePlayPrompt.Should().Contain("You are in a role play game. Carefully read the conversation history and carry on the conversation");
rolePlayPrompt.Should().Contain("The available roles are:");
rolePlayPrompt.Should().Contain("Alice,Bob");
rolePlayPrompt.Should().Contain("From Alice:");
option.StopSequence.Should().BeEquivalentTo([":"]);
option.Temperature.Should().Be(0);
option.MaxToken.Should().Be(128);
option.Functions.Should().BeNull();
})
.ReturnsAsync(new TextMessage(Role.Assistant, "From Alice"));
var alice = new EchoAgent("Alice");
var bob = new EchoAgent("Bob");
var orchestrator = new RolePlayOrchestrator(admin);
var context = new OrchestrationContext
{
Candidates = [alice, bob],
ChatHistory = [],
};
var speaker = await orchestrator.GetNextSpeakerAsync(context);
speaker.Should().Be(alice);
}
[Fact]
public async Task ItReturnNullWhenNoCandidateIsAvailableAsync()
{
var admin = Mock.Of<IAgent>();
var orchestrator = new RolePlayOrchestrator(admin);
var context = new OrchestrationContext
{
Candidates = [],
ChatHistory = [],
};
var speaker = await orchestrator.GetNextSpeakerAsync(context);
speaker.Should().BeNull();
}
[Fact]
public async Task ItReturnCandidateWhenOnlyOneCandidateIsAvailableAsync()
{
var admin = Mock.Of<IAgent>();
var alice = new EchoAgent("Alice");
var orchestrator = new RolePlayOrchestrator(admin);
var context = new OrchestrationContext
{
Candidates = [alice],
ChatHistory = [],
};
var speaker = await orchestrator.GetNextSpeakerAsync(context);
speaker.Should().Be(alice);
}
[Fact]
public async Task ItThrowExceptionWhenAdminFailsToFollowPromptAsync()
{
var admin = Mock.Of<IAgent>();
Mock.Get(admin).Setup(x => x.Name).Returns("Admin");
Mock.Get(admin).Setup(x => x.GenerateReplyAsync(
It.IsAny<IEnumerable<IMessage>>(),
It.IsAny<GenerateReplyOptions>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(new TextMessage(Role.Assistant, "I don't know")); // admin fails to follow the prompt and returns an invalid message
var alice = new EchoAgent("Alice");
var bob = new EchoAgent("Bob");
var orchestrator = new RolePlayOrchestrator(admin);
var context = new OrchestrationContext
{
Candidates = [alice, bob],
ChatHistory = [],
};
var action = async () => await orchestrator.GetNextSpeakerAsync(context);
await action.Should().ThrowAsync<Exception>()
.WithMessage("The response from admin is 't know, which is either not in the candidates list or not in the correct format.");
}
[Fact]
public async Task ItSelectNextSpeakerFromWorkflowIfProvided()
{
var workflow = new Graph();
var alice = new EchoAgent("Alice");
var bob = new EchoAgent("Bob");
var charlie = new EchoAgent("Charlie");
workflow.AddTransition(Transition.Create(alice, bob));
workflow.AddTransition(Transition.Create(bob, charlie));
workflow.AddTransition(Transition.Create(charlie, alice));
var admin = Mock.Of<IAgent>();
var orchestrator = new RolePlayOrchestrator(admin, workflow);
var context = new OrchestrationContext
{
Candidates = [alice, bob, charlie],
ChatHistory =
[
new TextMessage(Role.User, "Hello, Bob", from: "Alice"),
],
};
var speaker = await orchestrator.GetNextSpeakerAsync(context);
speaker.Should().Be(bob);
}
[Fact]
public async Task ItReturnNullIfNoAvailableAgentFromWorkflowAsync()
{
var workflow = new Graph();
var alice = new EchoAgent("Alice");
var bob = new EchoAgent("Bob");
workflow.AddTransition(Transition.Create(alice, bob));
var admin = Mock.Of<IAgent>();
var orchestrator = new RolePlayOrchestrator(admin, workflow);
var context = new OrchestrationContext
{
Candidates = [alice, bob],
ChatHistory =
[
new TextMessage(Role.User, "Hello, Alice", from: "Bob"),
],
};
var speaker = await orchestrator.GetNextSpeakerAsync(context);
speaker.Should().BeNull();
}
[Fact]
public async Task ItUseCandidatesFromWorflowAsync()
{
var workflow = new Graph();
var alice = new EchoAgent("Alice");
var bob = new EchoAgent("Bob");
var charlie = new EchoAgent("Charlie");
workflow.AddTransition(Transition.Create(alice, bob));
workflow.AddTransition(Transition.Create(alice, charlie));
var admin = Mock.Of<IAgent>();
Mock.Get(admin).Setup(x => x.GenerateReplyAsync(
It.IsAny<IEnumerable<IMessage>>(),
It.IsAny<GenerateReplyOptions>(),
It.IsAny<CancellationToken>()))
.Callback<IEnumerable<IMessage>, GenerateReplyOptions, CancellationToken>((messages, option, _) =>
{
messages.First().IsSystemMessage().Should().BeTrue();
// verify prompt
var rolePlayPrompt = messages.First().GetContent();
rolePlayPrompt.Should().Contain("Bob,Charlie");
rolePlayPrompt.Should().Contain("From Bob:");
option.StopSequence.Should().BeEquivalentTo([":"]);
option.Temperature.Should().Be(0);
option.MaxToken.Should().Be(128);
option.Functions.Should().BeEmpty();
})
.ReturnsAsync(new TextMessage(Role.Assistant, "From Bob"));
var orchestrator = new RolePlayOrchestrator(admin, workflow);
var context = new OrchestrationContext
{
Candidates = [alice, bob],
ChatHistory =
[
new TextMessage(Role.User, "Hello, Bob", from: "Alice"),
],
};
var speaker = await orchestrator.GetNextSpeakerAsync(context);
speaker.Should().Be(bob);
}
[ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
public async Task GPT_3_5_CoderReviewerRunnerTestAsync()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key));
var openAIChatAgent = new OpenAIChatAgent(
openAIClient: openaiClient,
name: "assistant",
modelName: deployName)
.RegisterMessageConnector();
await CoderReviewerRunnerTestAsync(openAIChatAgent);
}
[ApiKeyFact("GOOGLE_GEMINI_API_KEY")]
public async Task GoogleGemini_1_5_flash_001_CoderReviewerRunnerTestAsync()
{
var apiKey = Environment.GetEnvironmentVariable("GOOGLE_GEMINI_API_KEY") ?? throw new InvalidOperationException("GOOGLE_GEMINI_API_KEY is not set");
var geminiAgent = new GeminiChatAgent(
name: "gemini",
model: "gemini-1.5-flash-001",
apiKey: apiKey)
.RegisterMessageConnector();
await CoderReviewerRunnerTestAsync(geminiAgent);
}
[ApiKeyFact("ANTHROPIC_API_KEY")]
public async Task Claude3_Haiku_CoderReviewerRunnerTestAsync()
{
var apiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ?? throw new Exception("Please set ANTHROPIC_API_KEY environment variable.");
var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, apiKey);
var agent = new AnthropicClientAgent(
client,
name: "AnthropicAgent",
AnthropicConstants.Claude3Haiku,
systemMessage: "You are a helpful AI assistant that convert user message to upper case")
.RegisterMessageConnector();
await CoderReviewerRunnerTestAsync(agent);
}
[ApiKeyFact("MISTRAL_API_KEY")]
public async Task Mistra_7b_CoderReviewerRunnerTestAsync()
{
var apiKey = Environment.GetEnvironmentVariable("MISTRAL_API_KEY") ?? throw new InvalidOperationException("MISTRAL_API_KEY is not set.");
var client = new MistralClient(apiKey: apiKey);
var agent = new MistralClientAgent(
client: client,
name: "MistralClientAgent",
model: "open-mistral-7b")
.RegisterMessageConnector();
await CoderReviewerRunnerTestAsync(agent);
}
/// <summary>
/// This test is to mimic the conversation among coder, reviewer and runner.
/// The coder will write the code, the reviewer will review the code, and the runner will run the code.
/// </summary>
/// <param name="admin"></param>
/// <returns></returns>
public async Task CoderReviewerRunnerTestAsync(IAgent admin)
{
var coder = new EchoAgent("Coder");
var reviewer = new EchoAgent("Reviewer");
var runner = new EchoAgent("Runner");
var user = new EchoAgent("User");
var initializeMessage = new List<IMessage>
{
new TextMessage(Role.User, "Hello, I am user, I will provide the coding task, please write the code first, then review and run it", from: "User"),
new TextMessage(Role.User, "Hello, I am coder, I will write the code", from: "Coder"),
new TextMessage(Role.User, "Hello, I am reviewer, I will review the code", from: "Reviewer"),
new TextMessage(Role.User, "Hello, I am runner, I will run the code", from: "Runner"),
new TextMessage(Role.User, "how to print 'hello world' using C#", from: user.Name),
};
var chatHistory = new List<IMessage>()
{
new TextMessage(Role.User, """
```csharp
Console.WriteLine("Hello World");
```
""", from: coder.Name),
new TextMessage(Role.User, "The code looks good", from: reviewer.Name),
new TextMessage(Role.User, "The code runs successfully, the output is 'Hello World'", from: runner.Name),
};
var orchestrator = new RolePlayOrchestrator(admin);
foreach (var message in chatHistory)
{
var context = new OrchestrationContext
{
Candidates = [coder, reviewer, runner, user],
ChatHistory = initializeMessage,
};
var speaker = await orchestrator.GetNextSpeakerAsync(context);
speaker!.Name.Should().Be(message.From);
initializeMessage.Add(message);
}
// the last next speaker should be the user
var lastSpeaker = await orchestrator.GetNextSpeakerAsync(new OrchestrationContext
{
Candidates = [coder, reviewer, runner, user],
ChatHistory = initializeMessage,
});
lastSpeaker!.Name.Should().Be(user.Name);
}
}

View File

@ -0,0 +1,103 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// RoundRobinOrchestratorTests.cs
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using FluentAssertions;
using Xunit;
namespace AutoGen.Tests;
public class RoundRobinOrchestratorTests
{
[Fact]
public async Task ItReturnNextAgentAsync()
{
var orchestrator = new RoundRobinOrchestrator();
var context = new OrchestrationContext
{
Candidates = new List<IAgent>
{
new EchoAgent("Alice"),
new EchoAgent("Bob"),
new EchoAgent("Charlie"),
},
};
var messages = new List<IMessage>
{
new TextMessage(Role.User, "Hello, Alice", from: "Alice"),
new TextMessage(Role.User, "Hello, Bob", from: "Bob"),
new TextMessage(Role.User, "Hello, Charlie", from: "Charlie"),
};
var expected = new List<string> { "Bob", "Charlie", "Alice" };
var zip = messages.Zip(expected);
foreach (var (msg, expect) in zip)
{
context.ChatHistory = [msg];
var nextSpeaker = await orchestrator.GetNextSpeakerAsync(context);
Assert.Equal(expect, nextSpeaker!.Name);
}
}
[Fact]
public async Task ItReturnNullIfNoCandidates()
{
var orchestrator = new RoundRobinOrchestrator();
var context = new OrchestrationContext
{
Candidates = new List<IAgent>(),
ChatHistory = new List<IMessage>
{
new TextMessage(Role.User, "Hello, Alice", from: "Alice"),
},
};
var result = await orchestrator.GetNextSpeakerAsync(context);
Assert.Null(result);
}
[Fact]
public async Task ItReturnNullIfLastMessageIsNotFromCandidates()
{
var orchestrator = new RoundRobinOrchestrator();
var context = new OrchestrationContext
{
Candidates = new List<IAgent>
{
new EchoAgent("Alice"),
new EchoAgent("Bob"),
new EchoAgent("Charlie"),
},
ChatHistory = new List<IMessage>
{
new TextMessage(Role.User, "Hello, David", from: "David"),
},
};
var result = await orchestrator.GetNextSpeakerAsync(context);
result.Should().BeNull();
}
[Fact]
public async Task ItReturnEmptyListIfNoChatHistory()
{
var orchestrator = new RoundRobinOrchestrator();
var context = new OrchestrationContext
{
Candidates = new List<IAgent>
{
new EchoAgent("Alice"),
new EchoAgent("Bob"),
new EchoAgent("Charlie"),
},
};
var result = await orchestrator.GetNextSpeakerAsync(context);
result.Should().BeNull();
}
}

View File

@ -0,0 +1,112 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// WorkflowOrchestratorTests.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using FluentAssertions;
using Xunit;
namespace AutoGen.Tests;
public class WorkflowOrchestratorTests
{
[Fact]
public async Task ItReturnNextAgentAsync()
{
var workflow = new Graph();
var alice = new EchoAgent("Alice");
var bob = new EchoAgent("Bob");
var charlie = new EchoAgent("Charlie");
workflow.AddTransition(Transition.Create(alice, bob));
workflow.AddTransition(Transition.Create(bob, charlie));
workflow.AddTransition(Transition.Create(charlie, alice));
var orchestrator = new WorkflowOrchestrator(workflow);
var context = new OrchestrationContext
{
Candidates = [alice, bob, charlie]
};
var messages = new List<IMessage>
{
new TextMessage(Role.User, "Hello, Alice", from: "Alice"),
new TextMessage(Role.User, "Hello, Bob", from: "Bob"),
new TextMessage(Role.User, "Hello, Charlie", from: "Charlie"),
};
var expected = new List<string> { "Bob", "Charlie", "Alice" };
var zip = messages.Zip(expected);
foreach (var (msg, expect) in zip)
{
context.ChatHistory = [msg];
var result = await orchestrator.GetNextSpeakerAsync(context);
Assert.Equal(expect, result!.Name);
}
}
[Fact]
public async Task ItReturnNullIfNoCandidates()
{
var workflow = new Graph();
var orchestrator = new WorkflowOrchestrator(workflow);
var context = new OrchestrationContext
{
Candidates = new List<IAgent>(),
ChatHistory = new List<IMessage>
{
new TextMessage(Role.User, "Hello, Alice", from: "Alice"),
},
};
var nextAgent = await orchestrator.GetNextSpeakerAsync(context);
nextAgent.Should().BeNull();
}
[Fact]
public async Task ItReturnNullIfNoAgentIsAvailableFromWorkflowAsync()
{
var workflow = new Graph();
var alice = new EchoAgent("Alice");
var bob = new EchoAgent("Bob");
workflow.AddTransition(Transition.Create(alice, bob));
var orchestrator = new WorkflowOrchestrator(workflow);
var context = new OrchestrationContext
{
Candidates = [alice, bob],
ChatHistory = new List<IMessage>
{
new TextMessage(Role.User, "Hello, Bob", from: "Bob"),
},
};
var nextSpeaker = await orchestrator.GetNextSpeakerAsync(context);
nextSpeaker.Should().BeNull();
}
[Fact]
public async Task ItThrowExceptionWhenMoreThanOneAvailableAgentsFromWorkflowAsync()
{
var workflow = new Graph();
var alice = new EchoAgent("Alice");
var bob = new EchoAgent("Bob");
var charlie = new EchoAgent("Charlie");
workflow.AddTransition(Transition.Create(alice, bob));
workflow.AddTransition(Transition.Create(alice, charlie));
var orchestrator = new WorkflowOrchestrator(workflow);
var context = new OrchestrationContext
{
Candidates = [alice, bob, charlie],
ChatHistory = new List<IMessage>
{
new TextMessage(Role.User, "Hello, Bob", from: "Alice"),
},
};
var action = async () => await orchestrator.GetNextSpeakerAsync(context);
await action.Should().ThrowExactlyAsync<Exception>().WithMessage("There are more than one available agents from the workflow for the next speaker.");
}
}