From cf1365763ca5c7ea39ffceab6c06fe1663c32b0d Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Thu, 13 Mar 2025 12:41:13 -0400 Subject: [PATCH] feat: Implement AgentChat.NET Termination Conditions (#5839) Closes #5801 --- .../AgentChat/Abstractions/Messages.cs | 10 + .../AgentChat/Abstractions/Termination.cs | 94 ++-- .../AgentChat/Abstractions/Usage.cs | 10 + .../Terminations/ExternalTermination.cs | 53 ++ .../Terminations/FunctionCallTermination.cs | 51 ++ .../Terminations/HandoffTermination.cs | 53 ++ .../Terminations/MaxMessageTermination.cs | 52 ++ .../Terminations/SourceMatchTermination.cs | 50 ++ .../Terminations/TextMentionTermination.cs | 76 +++ .../Terminations/TextMessageTermination.cs | 63 +++ .../Terminations/TimeoutTermination.cs | 59 +++ .../Terminations/TokenUsageTermination.cs | 73 +++ .../AgentChatSmokeTest.cs | 1 + .../LifecycleObjectTests.cs | 1 + .../RunContextStackTests.cs | 1 + .../TerminationConditionTests.cs | 476 ++++++++++++++++++ 16 files changed, 1078 insertions(+), 45 deletions(-) create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Usage.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/ExternalTermination.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/FunctionCallTermination.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/HandoffTermination.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/MaxMessageTermination.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/SourceMatchTermination.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/TextMentionTermination.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/TextMessageTermination.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/TimeoutTermination.cs create mode 100644 dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/TokenUsageTermination.cs create mode 100644 dotnet/test/Microsoft.AutoGen.AgentChat.Tests/TerminationConditionTests.cs diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Messages.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Messages.cs index fcc1871bb..97e1d5b76 100644 --- a/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Messages.cs +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Messages.cs @@ -25,6 +25,11 @@ public abstract class AgentMessage /// public required string Source { get; set; } + /// + /// The usage incurred when producing this message. + /// + public RequestUsage? ModelUsage { get; set; } + // IMPORTANT NOTE: Unlike the ITypeMarshal implementation in ProtobufTypeMarshal, // the .ToWire() call on this is intended to be used for directly converting a concrete message type to its leaf representation. // In the context of Protobuf these may not be the same due to discriminated union types being real types, as opposed to @@ -495,6 +500,11 @@ public class FunctionExecutionResult /// public required string Id { get; set; } + /// + /// The name of the function that was called. + /// + public required string Name { get; set; } + /// /// The result of calling the function. /// diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Termination.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Termination.cs index 352b25298..21027cc36 100644 --- a/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Termination.cs +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Termination.cs @@ -3,6 +3,29 @@ namespace Microsoft.AutoGen.AgentChat.Abstractions; +public static class TerminationConditionExtensions +{ + /// + /// Combine this termination condition with another using a logical OR. + /// + /// Another termination condition. + /// The combined termination condition, with appropriate short-circuiting. + public static ITerminationCondition Or(this ITerminationCondition this_, ITerminationCondition other) + { + return new CombinerCondition(CombinerCondition.Or, this_, other); + } + + /// + /// Combine this termination condition with another using a logical AND. + /// + /// Another termination condition. + /// The combined termination condition, with appropriate short-circuiting. + public static ITerminationCondition And(this ITerminationCondition this_, ITerminationCondition other) + { + return new CombinerCondition(CombinerCondition.And, this_, other); + } +} + /// /// A stateful condition that determines when a conversation should be terminated. /// @@ -12,7 +35,8 @@ namespace Microsoft.AutoGen.AgentChat.Abstractions; /// /// Once a termination condition has been reached, it must be before it can be used again. /// -/// Termination conditions can be combined using the and methods. +/// Termination conditions can be combined using the and +/// methods. /// public interface ITerminationCondition { @@ -38,23 +62,37 @@ public interface ITerminationCondition public void Reset(); /// - /// Combine this termination condition with another using a logical OR. + /// Combine two termination conditions with another using an associative, short-circuiting OR. /// - /// Another termination condition. - /// The combined termination condition, with appropriate short-circuiting. - public ITerminationCondition Or(ITerminationCondition other) + /// + /// The left-hand side termination condition. If this condition is already a disjunction, the RHS condition is added to the list of clauses. + /// + /// + /// The right-hand side termination condition. If the LHS condition is already a disjunction, this condition is added to the list of clauses. + /// + /// + /// The combined termination condition, with appropriate short-circuiting. + /// + public static ITerminationCondition operator |(ITerminationCondition left, ITerminationCondition right) { - return new CombinerCondition(CombinerCondition.Or, this, other); + return left.Or(right); } /// - /// Combine this termination condition with another using a logical AND. + /// Combine two termination conditions with another using an associative, short-circuiting AND. /// - /// Another termination condition. - /// The combined termination condition, with appropriate short-circuiting. - public ITerminationCondition And(ITerminationCondition other) + /// + /// The left-hand side termination condition. If this condition is already a conjunction, the RHS condition is added to the list of clauses. + /// + /// + /// The right-hand side termination condition. If the LHS condition is already a conjunction, this condition is added to the list of clauses. + /// + /// + /// The combined termination condition, with appropriate short-circuiting. + /// + public static ITerminationCondition operator &(ITerminationCondition left, ITerminationCondition right) { - return new CombinerCondition(CombinerCondition.And, this, other); + return left.And(right); } } @@ -167,38 +205,4 @@ internal sealed class CombinerCondition : ITerminationCondition return null; } - - /// - /// - /// If this condition is already a disjunction, the new condition is added to the list of clauses. - /// - ITerminationCondition ITerminationCondition.Or(ITerminationCondition other) - { - if (this.conjunction == Or) - { - this.clauses.Add(other); - return this; - } - else - { - return new CombinerCondition(Or, this, new CombinerCondition(Or, other)); - } - } - - /// - /// - /// If this condition is already a conjunction, the new condition is added to the list of clauses. - /// - ITerminationCondition ITerminationCondition.And(ITerminationCondition other) - { - if (this.conjunction == And) - { - this.clauses.Add(other); - return this; - } - else - { - return new CombinerCondition(And, this, new CombinerCondition(And, other)); - } - } } diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Usage.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Usage.cs new file mode 100644 index 000000000..78d04221d --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Abstractions/Usage.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Usage.cs + +namespace Microsoft.AutoGen.AgentChat.Abstractions; + +public struct RequestUsage +{ + public int PromptTokens { get; set; } + public int CompletionTokens { get; set; } +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/ExternalTermination.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/ExternalTermination.cs new file mode 100644 index 000000000..9d2155ce0 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/ExternalTermination.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ExternalTermination.cs + +using Microsoft.AutoGen.AgentChat.Abstractions; + +namespace Microsoft.AutoGen.AgentChat.Terminations; + +/// +/// A that is externally controlled by calling the method. +/// +public sealed class ExternalTermination : ITerminationCondition +{ + public ExternalTermination() + { + this.TerminationQueued = false; + this.IsTerminated = false; + } + + public bool TerminationQueued { get; private set; } + public bool IsTerminated { get; private set; } + + /// + /// Set the termination condition to terminated. + /// + public void Set() + { + this.TerminationQueued = true; + } + + public ValueTask CheckAndUpdateAsync(IList messages) + { + if (this.IsTerminated) + { + throw new TerminatedException(); + } + + if (this.TerminationQueued) + { + this.IsTerminated = true; + string message = "External termination requested."; + StopMessage result = new() { Content = message, Source = nameof(ExternalTermination) }; + return ValueTask.FromResult(result); + } + + return ValueTask.FromResult(null); + } + + public void Reset() + { + this.TerminationQueued = false; + this.IsTerminated = false; + } +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/FunctionCallTermination.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/FunctionCallTermination.cs new file mode 100644 index 000000000..20dd93769 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/FunctionCallTermination.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// FunctionCallTermination.cs + +using Microsoft.AutoGen.AgentChat.Abstractions; + +namespace Microsoft.AutoGen.AgentChat.Terminations; + +/// +/// Terminate the conversation if a with a specific name is received. +/// +public sealed class FunctionCallTermination : ITerminationCondition +{ + /// + /// Initializes a new instance of the class. + /// + /// The name of the function to look for in the messages. + public FunctionCallTermination(string functionName) + { + this.FunctionName = functionName; + this.IsTerminated = false; + } + + public string FunctionName { get; } + public bool IsTerminated { get; private set; } + + public ValueTask CheckAndUpdateAsync(IList messages) + { + if (this.IsTerminated) + { + throw new TerminatedException(); + } + + foreach (AgentMessage item in messages) + { + if (item is ToolCallExecutionEvent toolCallEvent && toolCallEvent.Content.Any(execution => execution.Name == this.FunctionName)) + { + this.IsTerminated = true; + string message = $"Function '{this.FunctionName}' was executed."; + StopMessage result = new() { Content = message, Source = nameof(FunctionCallTermination) }; + return ValueTask.FromResult(result); + } + } + + return ValueTask.FromResult(null); + } + + public void Reset() + { + this.IsTerminated = false; + } +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/HandoffTermination.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/HandoffTermination.cs new file mode 100644 index 000000000..229466963 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/HandoffTermination.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// HandoffTermination.cs + +using Microsoft.AutoGen.AgentChat.Abstractions; + +namespace Microsoft.AutoGen.AgentChat.Terminations; + +/// +/// Terminate the conversation if a with the given +/// is received. +/// +public sealed class HandoffTermination : ITerminationCondition +{ + /// + /// Initializes a new instance of the class. + /// + /// The target of the handoff message. + public HandoffTermination(string target) + { + this.Target = target; + this.IsTerminated = false; + } + + public string Target { get; } + public bool IsTerminated { get; private set; } + + public ValueTask CheckAndUpdateAsync(IList messages) + { + if (this.IsTerminated) + { + throw new TerminatedException(); + } + + foreach (AgentMessage item in messages) + { + if (item is HandoffMessage handoffMessage && handoffMessage.Target == this.Target) + { + this.IsTerminated = true; + + string message = $"Handoff to {handoffMessage.Target} from {handoffMessage.Source} detected."; + StopMessage result = new() { Content = message, Source = nameof(HandoffTermination) }; + return ValueTask.FromResult(result); + } + } + + return ValueTask.FromResult(null); + } + + public void Reset() + { + this.IsTerminated = false; + } +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/MaxMessageTermination.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/MaxMessageTermination.cs new file mode 100644 index 000000000..fa2275f46 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/MaxMessageTermination.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// MaxMessageTermination.cs + +using Microsoft.AutoGen.AgentChat.Abstractions; + +namespace Microsoft.AutoGen.AgentChat.Terminations; + +/// +/// Terminate the conversation after a maximum number of messages have been exchanged. +/// +public sealed class MaxMessageTermination : ITerminationCondition +{ + /// + /// Initializes a new instance of the class. + /// + /// The maximum number of messages allowed in the conversation. + public MaxMessageTermination(int maxMessages, bool includeAgentEvent = false) + { + this.MaxMessages = maxMessages; + this.MessageCount = 0; + this.IncludeAgentEvent = includeAgentEvent; + } + + public int MaxMessages { get; } + public int MessageCount { get; private set; } + public bool IncludeAgentEvent { get; } + + public bool IsTerminated => this.MessageCount >= this.MaxMessages; + + public ValueTask CheckAndUpdateAsync(IList messages) + { + if (this.IsTerminated) + { + throw new TerminatedException(); + } + + this.MessageCount += messages.Where(m => this.IncludeAgentEvent || m is not AgentEvent).Count(); + + if (this.IsTerminated) + { + StopMessage result = new() { Content = "Max message count reached", Source = nameof(MaxMessageTermination) }; + return ValueTask.FromResult(result); + } + + return ValueTask.FromResult(null); + } + + public void Reset() + { + this.MessageCount = 0; + } +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/SourceMatchTermination.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/SourceMatchTermination.cs new file mode 100644 index 000000000..e938e0e78 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/SourceMatchTermination.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SourceMatchTermination.cs + +using Microsoft.AutoGen.AgentChat.Abstractions; + +namespace Microsoft.AutoGen.AgentChat.Terminations; + +/// +/// Terminate the conversation after a specific source responds. +/// +public sealed class SourceMatchTermination : ITerminationCondition +{ + /// + /// Initializes a new instance of the class. + /// + /// List of source names to terminate the conversation. + public SourceMatchTermination(params IEnumerable sources) + { + this.Sources = [.. sources]; + } + + public HashSet Sources { get; } + public bool IsTerminated { get; private set; } + + public ValueTask CheckAndUpdateAsync(IList messages) + { + if (this.IsTerminated) + { + throw new TerminatedException(); + } + + foreach (AgentMessage item in messages) + { + if (this.Sources.Contains(item.Source)) + { + this.IsTerminated = true; + string message = $"'{item.Source}' answered."; + StopMessage result = new() { Content = message, Source = nameof(SourceMatchTermination) }; + return ValueTask.FromResult(result); + } + } + + return ValueTask.FromResult(null); + } + + public void Reset() + { + this.IsTerminated = false; + } +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/TextMentionTermination.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/TextMentionTermination.cs new file mode 100644 index 000000000..95684a2bb --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/TextMentionTermination.cs @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// TextMentionTermination.cs + +using Microsoft.AutoGen.AgentChat.Abstractions; +using Microsoft.Extensions.AI; + +namespace Microsoft.AutoGen.AgentChat.Terminations; + +/// +/// Terminate the conversation if a specific text is mentioned. +/// +public sealed class TextMentionTermination : ITerminationCondition +{ + /// + /// Initializes a new instance of the class. + /// + /// The text to look for in the messages. + /// Check only the messages of the specified agents for the text to look for. + public TextMentionTermination(string targetText, IEnumerable? sources = null) + { + this.TargetText = targetText; + this.Sources = sources != null ? [.. sources] : null; + this.IsTerminated = false; + } + + public string TargetText { get; } + public HashSet? Sources { get; } + + public bool IsTerminated { get; private set; } + + private bool CheckMultiModalData(MultiModalData data) + { + return data.ContentType == MultiModalData.Type.String && + ((TextContent)data.AIContent).Text.Contains(this.TargetText); + } + + public ValueTask CheckAndUpdateAsync(IList messages) + { + if (this.IsTerminated) + { + throw new TerminatedException(); + } + + foreach (AgentMessage item in messages) + { + if (this.Sources != null && !this.Sources.Contains(item.Source)) + { + continue; + } + + bool hasMentions = item switch + { + TextMessage textMessage => textMessage.Content.Contains(this.TargetText), + MultiModalMessage multiModalMessage => multiModalMessage.Content.Any(CheckMultiModalData), + StopMessage stopMessage => stopMessage.Content.Contains(this.TargetText), + ToolCallSummaryMessage toolCallSummaryMessage => toolCallSummaryMessage.Content.Contains(this.TargetText), + + _ => false + }; + + if (hasMentions) + { + this.IsTerminated = true; + StopMessage result = new() { Content = "Text mention received", Source = nameof(TextMentionTermination) }; + return ValueTask.FromResult(result); + } + } + + return ValueTask.FromResult(null); + } + + public void Reset() + { + this.IsTerminated = false; + } +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/TextMessageTermination.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/TextMessageTermination.cs new file mode 100644 index 000000000..392bd281f --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/TextMessageTermination.cs @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// TextMessageTermination.cs + +using Microsoft.AutoGen.AgentChat.Abstractions; + +namespace Microsoft.AutoGen.AgentChat.Terminations; + +/// +/// Terminate the conversation if a is received. +/// +/// This termination condition checks for TextMessage instances in the message sequence. When a TextMessage is found, +/// it terminates the conversation if either: +/// +/// +/// No source was specified (terminates on any ) +/// The message source matches the specified source +/// +/// +/// +public sealed class TextMessageTermination : ITerminationCondition +{ + /// + /// Initializes a new instance of the class. + /// + /// + /// The source name to match against incoming messages. If null, matches any source. + /// Defaults to null. + /// + public TextMessageTermination(string? source = null) + { + this.Source = source; + this.IsTerminated = false; + } + + public string? Source { get; } + public bool IsTerminated { get; private set; } + + public ValueTask CheckAndUpdateAsync(IList messages) + { + if (this.IsTerminated) + { + throw new TerminatedException(); + } + + foreach (AgentMessage item in messages) + { + if (item is TextMessage textMessage && (this.Source == null || textMessage.Source == this.Source)) + { + this.IsTerminated = true; + string message = $"Text message received from '{textMessage.Source}'."; + StopMessage result = new() { Content = message, Source = nameof(TextMessageTermination) }; + return ValueTask.FromResult(result); + } + } + + return ValueTask.FromResult(null); + } + + public void Reset() + { + this.IsTerminated = false; + } +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/TimeoutTermination.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/TimeoutTermination.cs new file mode 100644 index 000000000..015709d2e --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/TimeoutTermination.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// TimeoutTermination.cs + +using Microsoft.AutoGen.AgentChat.Abstractions; + +namespace Microsoft.AutoGen.AgentChat.Terminations; + +/// +/// Terminate the conversation after the specified duration has passed. +/// +public sealed class TimeoutTermination : ITerminationCondition +{ + /// + /// Initializes a new instance of the class. + /// + /// The maximum duration before terminating the conversation. + public TimeoutTermination(TimeSpan timeout) + { + this.Timeout = timeout; + this.StartTime = DateTime.UtcNow; + } + + /// + /// Initializes a new instance of the class. + /// + /// The maximum duration in seconds before terminating the conversation. + public TimeoutTermination(float seconds) : this(TimeSpan.FromSeconds(seconds)) + { + } + + public TimeSpan Timeout { get; } + public DateTime StartTime { get; private set; } + + public bool IsTerminated { get; private set; } + + public ValueTask CheckAndUpdateAsync(IList messages) + { + if (this.IsTerminated) + { + throw new TerminatedException(); + } + + if (DateTime.UtcNow - this.StartTime >= this.Timeout) + { + this.IsTerminated = true; + string message = $"Timeout of {this.Timeout.TotalSeconds} seconds reached."; + StopMessage result = new() { Content = message, Source = nameof(TimeoutTermination) }; + return ValueTask.FromResult(result); + } + + return ValueTask.FromResult(null); + } + + public void Reset() + { + this.IsTerminated = false; + this.StartTime = DateTime.UtcNow; + } +} diff --git a/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/TokenUsageTermination.cs b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/TokenUsageTermination.cs new file mode 100644 index 000000000..a20ca4da2 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/AgentChat/Terminations/TokenUsageTermination.cs @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// TokenUsageTermination.cs + +using Microsoft.AutoGen.AgentChat.Abstractions; + +namespace Microsoft.AutoGen.AgentChat.Terminations; + +/// +/// Terminate the conversation if the token usage limit is reached. +/// +public sealed class TokenUsageTermination : ITerminationCondition +{ + /// + /// Initializes a new instance of the class. + /// + /// The maximum total number of tokens allowed in the conversation. + /// The maximum number of prompt tokens allowed in the conversation. + /// The maximum number of completion tokens allowed in the conversation. + public TokenUsageTermination(int? maxTotalTokens = null, int? maxPromptTokens = null, int? maxCompletionTokens = null) + { + this.MaxTotalTokens = maxTotalTokens; + this.MaxPromptTokens = maxPromptTokens; + this.MaxCompletionTokens = maxCompletionTokens; + + this.PromptTokenCount = 0; + this.CompletionTokenCount = 0; + } + + public int? MaxTotalTokens { get; } + public int? MaxPromptTokens { get; } + public int? MaxCompletionTokens { get; } + + public int TotalTokenCount => this.PromptTokenCount + this.CompletionTokenCount; + public int PromptTokenCount { get; private set; } + public int CompletionTokenCount { get; private set; } + + public bool IsTerminated => + (this.MaxTotalTokens != null && this.TotalTokenCount >= this.MaxTotalTokens) || + (this.MaxPromptTokens != null && this.PromptTokenCount >= this.MaxPromptTokens) || + (this.MaxCompletionTokens != null && this.CompletionTokenCount >= this.MaxCompletionTokens); + + public ValueTask CheckAndUpdateAsync(IList messages) + { + if (this.IsTerminated) + { + throw new TerminatedException(); + } + + foreach (AgentMessage item in messages) + { + if (item.ModelUsage is RequestUsage usage) + { + this.PromptTokenCount += usage.PromptTokens; + this.CompletionTokenCount += usage.CompletionTokens; + } + } + + if (this.IsTerminated) + { + string message = $"Token usage limit reached, total token count: {this.TotalTokenCount}, prompt token count: {this.PromptTokenCount}, completion token count: {this.CompletionTokenCount}."; + StopMessage result = new() { Content = message, Source = nameof(TokenUsageTermination) }; + return ValueTask.FromResult(result); + } + + return ValueTask.FromResult(null); + } + + public void Reset() + { + this.PromptTokenCount = 0; + this.CompletionTokenCount = 0; + } +} diff --git a/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/AgentChatSmokeTest.cs b/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/AgentChatSmokeTest.cs index 97305b28b..1c9262e43 100644 --- a/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/AgentChatSmokeTest.cs +++ b/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/AgentChatSmokeTest.cs @@ -12,6 +12,7 @@ using Xunit; namespace Microsoft.AutoGen.AgentChat.Tests; +[Trait("Category", "UnitV2")] public class AgentChatSmokeTest { public class SpeakMessageAgent : ChatAgentBase diff --git a/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/LifecycleObjectTests.cs b/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/LifecycleObjectTests.cs index 4e8f4cdc5..5daa0de29 100644 --- a/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/LifecycleObjectTests.cs +++ b/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/LifecycleObjectTests.cs @@ -72,6 +72,7 @@ internal sealed class LifecycleObjectFixture : LifecycleObject } } +[Trait("Category", "UnitV2")] public class LifecycleObjectTests { /* diff --git a/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/RunContextStackTests.cs b/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/RunContextStackTests.cs index fde9ecdca..89e8b4d66 100644 --- a/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/RunContextStackTests.cs +++ b/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/RunContextStackTests.cs @@ -8,6 +8,7 @@ using Xunit; namespace Microsoft.AutoGen.AgentChat.Tests; +[Trait("Category", "UnitV2")] public class RunContextStackTests { public static IRunContextLayer CreateLayer(Action>? setupAction = null) diff --git a/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/TerminationConditionTests.cs b/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/TerminationConditionTests.cs new file mode 100644 index 000000000..db54e456a --- /dev/null +++ b/dotnet/test/Microsoft.AutoGen.AgentChat.Tests/TerminationConditionTests.cs @@ -0,0 +1,476 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// TerminationConditionTests.cs + +using FluentAssertions; +using Microsoft.AutoGen.AgentChat.Abstractions; +using Microsoft.AutoGen.AgentChat.Terminations; +using Microsoft.Extensions.AI; +using Xunit; + +namespace Microsoft.AutoGen.AgentChat.Tests; + +[Trait("Category", "UnitV2")] +public static class TerminationExtensions +{ + public static async Task InvokeExpectingNullAsync(this TTermination termination, IList messages, bool reset = true) + where TTermination : ITerminationCondition + { + (await termination.CheckAndUpdateAsync(messages)).Should().BeNull(); + termination.IsTerminated.Should().BeFalse(); + + if (reset) + { + termination.Reset(); + } + } + + private static readonly HashSet AnonymousTerminationConditions = ["CombinerCondition", nameof(ITerminationCondition)]; + public static async Task InvokeExpectingStopAsync(this TTermination termination, IList messages, bool reset = true) + where TTermination : ITerminationCondition + { + StopMessage? stopMessage = await termination.CheckAndUpdateAsync(messages); + stopMessage.Should().NotBeNull(); + + string name = typeof(TTermination).Name; + if (!AnonymousTerminationConditions.Contains(name)) + { + stopMessage!.Source.Should().Be(typeof(TTermination).Name); + } + + termination.IsTerminated.Should().BeTrue(); + + if (reset) + { + termination.Reset(); + } + } + + public static async Task InvokeExpectingFailureAsync(this TTermination termination, IList messages, bool reset = true) + where TTermination : ITerminationCondition + { + Func failureAction = () => termination.CheckAndUpdateAsync(messages).AsTask(); + await failureAction.Should().ThrowAsync(); + termination.IsTerminated.Should().BeTrue(); + + if (reset) + { + termination.Reset(); + } + } +} + +public class TerminationConditionTests +{ + [Fact] + public async Task Test_HandoffTermination() + { + HandoffTermination termination = new("target"); + termination.IsTerminated.Should().BeFalse(); + + TextMessage textMessage = new() { Content = "Hello", Source = "user" }; + HandoffMessage targetHandoffMessage = new() { Target = "target", Source = "user", Context = "Hello" }; + HandoffMessage otherHandoffMessage = new() { Target = "another", Source = "user", Context = "Hello" }; + + await termination.InvokeExpectingNullAsync([]); + await termination.InvokeExpectingNullAsync([textMessage]); + await termination.InvokeExpectingStopAsync([targetHandoffMessage]); + await termination.InvokeExpectingNullAsync([otherHandoffMessage]); + await termination.InvokeExpectingStopAsync([textMessage, targetHandoffMessage], reset: false); + + await termination.InvokeExpectingFailureAsync([], reset: false); + + termination.Reset(); + termination.IsTerminated.Should().BeFalse(); + } + + [Fact] + public async Task StopMessageTermination() + { + StopMessageTermination termination = new(); + termination.IsTerminated.Should().BeFalse(); + + TextMessage textMessage = new() { Content = "Hello", Source = "user" }; + TextMessage otherMessage = new() { Content = "World", Source = "aser" }; + StopMessage stopMessage = new() { Content = "Stop", Source = "user" }; + + await termination.InvokeExpectingNullAsync([]); + await termination.InvokeExpectingNullAsync([textMessage]); + await termination.InvokeExpectingStopAsync([stopMessage]); + await termination.InvokeExpectingNullAsync([textMessage, otherMessage]); + await termination.InvokeExpectingStopAsync([textMessage, stopMessage], reset: false); + + await termination.InvokeExpectingFailureAsync([], reset: false); + + termination.Reset(); + termination.IsTerminated.Should().BeFalse(); + } + + [Fact] + public async Task Test_TextMesssageTermination() + { + TextMessageTermination termination = new(); + termination.IsTerminated.Should().BeFalse(); + + TextMessage userMessage = new() { Content = "Hello", Source = "user" }; + TextMessage agentMessage = new() { Content = "World", Source = "agent" }; + StopMessage stopMessage = new() { Content = "Stop", Source = "user" }; + + await termination.InvokeExpectingNullAsync([]); + await termination.InvokeExpectingStopAsync([userMessage]); + await termination.InvokeExpectingStopAsync([agentMessage]); + await termination.InvokeExpectingNullAsync([stopMessage]); + + termination = new("user"); + + await termination.InvokeExpectingNullAsync([agentMessage]); + await termination.InvokeExpectingNullAsync([stopMessage]); + await termination.InvokeExpectingStopAsync([userMessage], reset: false); + + await termination.InvokeExpectingFailureAsync([], reset: false); + + termination.Reset(); + termination.IsTerminated.Should().BeFalse(); + } + + [Fact] + public async Task MaxMessageTermination() + { + MaxMessageTermination termination = new(2); + termination.IsTerminated.Should().BeFalse(); + + TextMessage textMessage = new() { Content = "Hello", Source = "user" }; + TextMessage otherMessage = new() { Content = "World", Source = "agent" }; + UserInputRequestedEvent uiRequest = new() { Source = "agent", RequestId = "1" }; + + await termination.InvokeExpectingNullAsync([]); + await termination.InvokeExpectingNullAsync([textMessage]); + await termination.InvokeExpectingStopAsync([textMessage, otherMessage]); + await termination.InvokeExpectingNullAsync([textMessage, uiRequest]); + + termination = new(2, includeAgentEvent: true); + + await termination.InvokeExpectingStopAsync([textMessage, uiRequest], reset: false); + + await termination.InvokeExpectingFailureAsync([], reset: false); + + termination.Reset(); + termination.IsTerminated.Should().BeFalse(); + } + + [Fact] + public async Task Test_TextMentionTermination() + { + TextMentionTermination termination = new("stop"); + termination.IsTerminated.Should().BeFalse(); + + TextMessage textMessage = new() { Content = "Hello", Source = "user" }; + TextMessage userStopMessage = new() { Content = "stop", Source = "user" }; + TextMessage agentStopMessage = new() { Content = "stop", Source = "agent" }; + + await termination.InvokeExpectingNullAsync([]); + await termination.InvokeExpectingNullAsync([textMessage]); + await termination.InvokeExpectingStopAsync([userStopMessage]); + + termination = new("stop", sources: ["agent"]); + + await termination.InvokeExpectingNullAsync([textMessage]); + await termination.InvokeExpectingNullAsync([userStopMessage]); + await termination.InvokeExpectingStopAsync([agentStopMessage], reset: false); + + await termination.InvokeExpectingFailureAsync([], reset: false); + + termination.Reset(); + termination.IsTerminated.Should().BeFalse(); + } + + [Fact] + public async Task Text_TokenUsageTermination() + { + TokenUsageTermination termination = new(10); + termination.IsTerminated.Should().BeFalse(); + + RequestUsage usage_10_10 = new() { CompletionTokens = 10, PromptTokens = 10 }; + RequestUsage usage_01_01 = new() { CompletionTokens = 1, PromptTokens = 1 }; + RequestUsage usage_05_00 = new() { CompletionTokens = 5, PromptTokens = 0 }; + RequestUsage usage_00_05 = new() { CompletionTokens = 0, PromptTokens = 5 }; + + await termination.InvokeExpectingNullAsync([]); + + await termination.InvokeExpectingStopAsync([ + new TextMessage { Content = "Hello", Source = "user", ModelUsage = usage_10_10 }, + ]); + + await termination.InvokeExpectingNullAsync([ + new TextMessage { Content = "Hello", Source = "user", ModelUsage = usage_01_01 }, + new TextMessage { Content = "World", Source = "agent", ModelUsage = usage_01_01 }, + ]); + + await termination.InvokeExpectingStopAsync([ + new TextMessage { Content = "Hello", Source = "user", ModelUsage = usage_05_00 }, + new TextMessage { Content = "World", Source = "agent", ModelUsage = usage_00_05 }, + ], reset: false); + + await termination.InvokeExpectingFailureAsync([], reset: false); + + termination.Reset(); + termination.IsTerminated.Should().BeFalse(); + } + + public class AgentTextEvent : AgentEvent + { + public required string Content { get; set; } + + public override Extensions.AI.ChatMessage ToCompletionClientMessage(ChatRole role) + { + return new Extensions.AI.ChatMessage(ChatRole.Assistant, this.Content); + } + } + + [Fact] + public async Task Text_Termination_AndCombinator() + { + ITerminationCondition lhsClause = new MaxMessageTermination(2); + ITerminationCondition rhsClause = new TextMentionTermination("stop"); + + ITerminationCondition termination = lhsClause & rhsClause; + termination.IsTerminated.Should().BeFalse(); + + TextMessage userMessage = new() { Content = "Hello", Source = "user" }; + AgentTextEvent agentMessage = new() { Content = "World", Source = "agent" }; + + TextMessage userStopMessage = new() { Content = "stop", Source = "user" }; + + await termination.InvokeExpectingNullAsync([]); + + await termination.InvokeExpectingNullAsync([userMessage]); + + await termination.InvokeExpectingNullAsync([userMessage, agentMessage], reset: false); + lhsClause.IsTerminated.Should().BeFalse(); + rhsClause.IsTerminated.Should().BeFalse(); + + await termination.InvokeExpectingStopAsync([userStopMessage], reset: false); + + lhsClause.IsTerminated.Should().BeTrue(); + rhsClause.IsTerminated.Should().BeTrue(); + termination.IsTerminated.Should().BeTrue(); + + await termination.InvokeExpectingFailureAsync([], reset: false); + + lhsClause.IsTerminated.Should().BeTrue(); + rhsClause.IsTerminated.Should().BeTrue(); + termination.IsTerminated.Should().BeTrue(); + + termination.Reset(); + termination.IsTerminated.Should().BeFalse(); + + await termination.InvokeExpectingNullAsync([userMessage, agentMessage], reset: false); + lhsClause.IsTerminated.Should().BeFalse(); + rhsClause.IsTerminated.Should().BeFalse(); + + await termination.InvokeExpectingNullAsync([userMessage], reset: false); + + lhsClause.IsTerminated.Should().BeTrue(); + rhsClause.IsTerminated.Should().BeFalse(); + termination.IsTerminated.Should().BeFalse(); + + await termination.InvokeExpectingNullAsync([userMessage], reset: false); + + lhsClause.IsTerminated.Should().BeTrue(); + rhsClause.IsTerminated.Should().BeFalse(); + termination.IsTerminated.Should().BeFalse(); + + await termination.InvokeExpectingStopAsync([userStopMessage], reset: false); + + lhsClause.IsTerminated.Should().BeTrue(); + rhsClause.IsTerminated.Should().BeTrue(); + termination.IsTerminated.Should().BeTrue(); + + await termination.InvokeExpectingFailureAsync([], reset: false); + + lhsClause.IsTerminated.Should().BeTrue(); + rhsClause.IsTerminated.Should().BeTrue(); + termination.IsTerminated.Should().BeTrue(); + + termination.Reset(); + termination.IsTerminated.Should().BeFalse(); + + await termination.InvokeExpectingNullAsync([agentMessage, userStopMessage], reset: false); + + lhsClause.IsTerminated.Should().BeFalse(); + rhsClause.IsTerminated.Should().BeTrue(); + termination.IsTerminated.Should().BeFalse(); + + await termination.InvokeExpectingStopAsync([userMessage], reset: false); + lhsClause.IsTerminated.Should().BeTrue(); + rhsClause.IsTerminated.Should().BeTrue(); + termination.IsTerminated.Should().BeTrue(); + + await termination.InvokeExpectingFailureAsync([], reset: false); + lhsClause.IsTerminated.Should().BeTrue(); + rhsClause.IsTerminated.Should().BeTrue(); + termination.IsTerminated.Should().BeTrue(); + + termination.Reset(); + termination.IsTerminated.Should().BeFalse(); + } + + [Fact] + public async Task Test_Termination_OrCombiner() + { + ITerminationCondition lhsClause = new MaxMessageTermination(3); + ITerminationCondition rhsClause = new TextMentionTermination("stop"); + + ITerminationCondition termination = lhsClause | rhsClause; + termination.IsTerminated.Should().BeFalse(); + + TextMessage userMessage = new() { Content = "Hello", Source = "user" }; + AgentTextEvent agentMessage = new() { Content = "World", Source = "agent" }; + TextMessage userStopMessage = new() { Content = "stop", Source = "user" }; + + await termination.InvokeExpectingNullAsync([]); + await termination.InvokeExpectingNullAsync([userMessage]); + await termination.InvokeExpectingNullAsync([userMessage, agentMessage]); + + await termination.InvokeExpectingNullAsync([userMessage, agentMessage, userMessage], reset: false); + lhsClause.IsTerminated.Should().BeFalse(); + rhsClause.IsTerminated.Should().BeFalse(); + termination.IsTerminated.Should().BeFalse(); + + termination.Reset(); + termination.IsTerminated.Should().BeFalse(); + + await termination.InvokeExpectingStopAsync([userMessage, agentMessage, userStopMessage], reset: false); + lhsClause.IsTerminated.Should().BeFalse(); + rhsClause.IsTerminated.Should().BeTrue(); + termination.IsTerminated.Should().BeTrue(); + + await termination.InvokeExpectingFailureAsync([], reset: false); + lhsClause.IsTerminated.Should().BeFalse(); + rhsClause.IsTerminated.Should().BeTrue(); + termination.IsTerminated.Should().BeTrue(); + + termination.Reset(); + termination.IsTerminated.Should().BeFalse(); + + await termination.InvokeExpectingStopAsync([userMessage, userMessage, userMessage], reset: false); + lhsClause.IsTerminated.Should().BeTrue(); + rhsClause.IsTerminated.Should().BeFalse(); + termination.IsTerminated.Should().BeTrue(); + + await termination.InvokeExpectingFailureAsync([], reset: false); + lhsClause.IsTerminated.Should().BeTrue(); + rhsClause.IsTerminated.Should().BeFalse(); + termination.IsTerminated.Should().BeTrue(); + + termination.Reset(); + termination.IsTerminated.Should().BeFalse(); + } + + [Fact] + public async Task Test_TimeoutTermination() + { + TextMessage userMessage = new() { Content = "Hello", Source = "user" }; + + TimeoutTermination termination = new(0.15f); + termination.IsTerminated.Should().BeFalse(); + + await termination.InvokeExpectingNullAsync([]); + + await Task.Delay(TimeSpan.FromSeconds(0.20f)); + + await termination.InvokeExpectingStopAsync([], reset: false); + + await termination.InvokeExpectingFailureAsync([], reset: false); + + termination.Reset(); + termination.IsTerminated.Should().BeFalse(); + + await termination.InvokeExpectingNullAsync([userMessage]); + + await Task.Delay(TimeSpan.FromSeconds(0.20f)); + + await termination.InvokeExpectingStopAsync([], reset: false); + } + + [Fact] + public async Task Test_ExternalTermination() + { + ExternalTermination termination = new(); + termination.IsTerminated.Should().BeFalse(); + + TextMessage userMessage = new() { Content = "Hello", Source = "user" }; + + await termination.InvokeExpectingNullAsync([]); + await termination.InvokeExpectingNullAsync([userMessage]); + + termination.Set(); + termination.IsTerminated.Should().BeFalse(); // We only terminate on the next check + + await termination.InvokeExpectingStopAsync([], reset: false); + await termination.InvokeExpectingFailureAsync([], reset: false); + + termination.Reset(); + termination.IsTerminated.Should().BeFalse(); + + await termination.InvokeExpectingNullAsync([userMessage]); + } + + private ToolCallRequestEvent CreateFunctionRequest(string functionName, string id = "1", string arguments = "") + { + ToolCallRequestEvent result = new ToolCallRequestEvent + { + Source = "agent" + }; + + result.Content.Add( + new FunctionCall + { + Id = id, + Name = functionName, + Arguments = arguments, + }); + + return result; + } + + private ToolCallExecutionEvent CreateFunctionResponse(string functionName, string id = "1", string content = "") + { + ToolCallExecutionEvent result = new ToolCallExecutionEvent + { + Source = "agent" + }; + + result.Content.Add( + new FunctionExecutionResult + { + Id = id, + Name = functionName, + Content = content, + }); + + return result; + } + + [Fact] + public async Task Test_FunctionCallTermination() + { + FunctionCallTermination termination = new("test_function"); + termination.IsTerminated.Should().BeFalse(); + + TextMessage userMessage = new() { Content = "Hello", Source = "user" }; + ToolCallRequestEvent toolCallRequest = CreateFunctionRequest("test_function"); + ToolCallExecutionEvent testExecution = CreateFunctionResponse("test_function"); + ToolCallExecutionEvent otherExecution = CreateFunctionResponse("other_function"); + + await termination.InvokeExpectingNullAsync([]); + await termination.InvokeExpectingNullAsync([userMessage]); + await termination.InvokeExpectingNullAsync([toolCallRequest]); + await termination.InvokeExpectingNullAsync([otherExecution]); + await termination.InvokeExpectingStopAsync([testExecution], reset: false); + + await termination.InvokeExpectingFailureAsync([], reset: false); + + termination.Reset(); + termination.IsTerminated.Should().BeFalse(); + } +}