Custom selector function for SelectorGroupChat (#4026)

* Custom selector function for SelectorGroupChat

* Update documentation
This commit is contained in:
Eric Zhu 2024-11-01 09:08:29 -07:00 committed by GitHub
parent 369ffb511b
commit 173acc6638
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 237 additions and 86 deletions

View File

@ -131,14 +131,19 @@ class AssistantAgent(BaseChatAgent):
.. code-block:: python
import asyncio
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.task import MaxMessageTermination
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent(name="assistant", model_client=model_client)
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent(name="assistant", model_client=model_client)
await agent.run("What is the capital of France?", termination_condition=MaxMessageTermination(2))
result = await agent.run("What is the capital of France?", termination_condition=MaxMessageTermination(2))
print(result)
asyncio.run(main())
The following example demonstrates how to create an assistant agent with
@ -146,6 +151,7 @@ class AssistantAgent(BaseChatAgent):
.. code-block:: python
import asyncio
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.task import MaxMessageTermination
@ -155,15 +161,18 @@ class AssistantAgent(BaseChatAgent):
return "The current time is 12:00 PM."
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time])
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time])
stream = agent.run_stream("What is the current time?", termination_condition=MaxMessageTermination(3))
stream = agent.run_stream("What is the current time?", termination_condition=MaxMessageTermination(3))
async for message in stream:
print(message)
async for message in stream:
print(message)
asyncio.run(main())
"""
def __init__(

View File

@ -22,19 +22,25 @@ class TerminationCondition(ABC):
.. code-block:: python
import asyncio
from autogen_agentchat.teams import MaxTurnsTermination, TextMentionTermination
# Terminate the conversation after 10 turns or if the text "TERMINATE" is mentioned.
cond1 = MaxTurnsTermination(10) | TextMentionTermination("TERMINATE")
# Terminate the conversation after 10 turns and if the text "TERMINATE" is mentioned.
cond2 = MaxTurnsTermination(10) & TextMentionTermination("TERMINATE")
async def main() -> None:
# Terminate the conversation after 10 turns or if the text "TERMINATE" is mentioned.
cond1 = MaxTurnsTermination(10) | TextMentionTermination("TERMINATE")
...
# Terminate the conversation after 10 turns and if the text "TERMINATE" is mentioned.
cond2 = MaxTurnsTermination(10) & TextMentionTermination("TERMINATE")
# Reset the termination condition.
await cond1.reset()
await cond2.reset()
# ...
# Reset the termination condition.
await cond1.reset()
await cond2.reset()
asyncio.run(main())
"""
@property

View File

@ -61,46 +61,55 @@ class RoundRobinGroupChat(BaseGroupChat):
.. code-block:: python
import asyncio
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.task import StopMessageTermination
model_client = OpenAIChatCompletionClient(model="gpt-4o")
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
async def get_weather(location: str) -> str:
return f"The weather in {location} is sunny."
assistant = AssistantAgent(
"Assistant",
model_client=model_client,
tools=[get_weather],
)
team = RoundRobinGroupChat([assistant])
stream = team.run_stream("What's the weather in New York?", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
async def get_weather(location: str) -> str:
return f"The weather in {location} is sunny."
assistant = AssistantAgent(
"Assistant",
model_client=model_client,
tools=[get_weather],
)
team = RoundRobinGroupChat([assistant])
stream = team.run_stream("What's the weather in New York?", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
asyncio.run(main())
A team with multiple participants:
.. code-block:: python
import asyncio
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.task import StopMessageTermination
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent1 = AssistantAgent("Assistant1", model_client=model_client)
agent2 = AssistantAgent("Assistant2", model_client=model_client)
team = RoundRobinGroupChat([agent1, agent2])
stream = team.run_stream("Tell me some jokes.", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent1 = AssistantAgent("Assistant1", model_client=model_client)
agent2 = AssistantAgent("Assistant2", model_client=model_client)
team = RoundRobinGroupChat([agent1, agent2])
stream = team.run_stream("Tell me some jokes.", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
asyncio.run(main())
"""
def __init__(self, participants: List[ChatAgent]):

View File

@ -1,12 +1,12 @@
import logging
import re
from typing import Callable, Dict, List
from typing import Callable, Dict, List, Sequence
from autogen_core.components.models import ChatCompletionClient, SystemMessage
from ... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
from ...base import ChatAgent, TerminationCondition
from ...messages import MultiModalMessage, StopMessage, TextMessage
from ...messages import ChatMessage, MultiModalMessage, StopMessage, TextMessage
from .._events import (
GroupChatPublishEvent,
GroupChatSelectSpeakerEvent,
@ -20,7 +20,7 @@ event_logger = logging.getLogger(EVENT_LOGGER_NAME)
class SelectorGroupChatManager(BaseGroupChatManager):
"""A group chat manager that selects the next speaker using a ChatCompletion
model."""
model and a custom selector function."""
def __init__(
self,
@ -32,6 +32,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
model_client: ChatCompletionClient,
selector_prompt: str,
allow_repeated_speaker: bool,
selector_func: Callable[[Sequence[ChatMessage]], str | None] | None,
) -> None:
super().__init__(
parent_topic_type,
@ -44,12 +45,24 @@ class SelectorGroupChatManager(BaseGroupChatManager):
self._selector_prompt = selector_prompt
self._previous_speaker: str | None = None
self._allow_repeated_speaker = allow_repeated_speaker
self._selector_func = selector_func
async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str:
"""Selects the next speaker in a group chat using a ChatCompletion client.
"""Selects the next speaker in a group chat using a ChatCompletion client,
with the selector function as override if it returns a speaker name.
A key assumption is that the agent type is the same as the topic type, which we use as the agent name.
"""
# Use the selector function if provided.
if self._selector_func is not None:
speaker = self._selector_func([msg.agent_message for msg in thread])
if speaker is not None:
# Skip the model based selection.
event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=speaker, source=self.id))
return speaker
# Construct the history of the conversation.
history_messages: List[str] = []
for event in thread:
msg = event.agent_message
@ -160,6 +173,10 @@ class SelectorGroupChat(BaseGroupChat):
Must contain '{roles}', '{participants}', and '{history}' to be filled in.
allow_repeated_speaker (bool, optional): Whether to allow the same speaker to be selected
consecutively. Defaults to False.
selector_func (Callable[[Sequence[ChatMessage]], str | None], optional): A custom selector
function that takes the conversation history and returns the name of the next speaker.
If provided, this function will be used to override the model to select the next speaker.
If the function returns None, the model will be used to select the next speaker.
Raises:
ValueError: If the number of participants is less than two or if the selector prompt is invalid.
@ -175,43 +192,97 @@ class SelectorGroupChat(BaseGroupChat):
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.task import StopMessageTermination
model_client = OpenAIChatCompletionClient(model="gpt-4o")
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
async def lookup_hotel(location: str) -> str:
return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."
async def lookup_flight(origin: str, destination: str) -> str:
return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."
async def book_trip() -> str:
return "Your trip is booked!"
travel_advisor = AssistantAgent(
"Travel_Advisor",
model_client,
tools=[book_trip],
description="Helps with travel planning.",
)
hotel_agent = AssistantAgent(
"Hotel_Agent",
model_client,
tools=[lookup_hotel],
description="Helps with hotel booking.",
)
flight_agent = AssistantAgent(
"Flight_Agent",
model_client,
tools=[lookup_flight],
description="Helps with flight booking.",
)
team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=model_client)
stream = team.run_stream("Book a 3-day trip to new york.", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
async def lookup_hotel(location: str) -> str:
return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."
import asyncio
asyncio.run(main())
A team with a custom selector function:
.. code-block:: python
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.task import TextMentionTermination
async def lookup_flight(origin: str, destination: str) -> str:
return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
def check_calculation(x: int, y: int, answer: int) -> str:
if x + y == answer:
return "Correct!"
else:
return "Incorrect!"
agent1 = AssistantAgent(
"Agent1",
model_client,
description="For calculation",
system_message="Calculate the sum of two numbers",
)
agent2 = AssistantAgent(
"Agent2",
model_client,
tools=[check_calculation],
description="For checking calculation",
system_message="Check the answer and respond with 'Correct!' or 'Incorrect!'",
)
def selector_func(messages):
if len(messages) == 1 or messages[-1].content == "Incorrect!":
return "Agent1"
if messages[-1].source == "Agent1":
return "Agent2"
return None
team = SelectorGroupChat([agent1, agent2], model_client=model_client, selector_func=selector_func)
stream = team.run_stream("What is 1 + 1?", termination_condition=TextMentionTermination("Correct!"))
async for message in stream:
print(message)
async def book_trip() -> str:
return "Your trip is booked!"
import asyncio
travel_advisor = AssistantAgent(
"Travel_Advisor",
model_client,
tools=[book_trip],
description="Helps with travel planning.",
)
hotel_agent = AssistantAgent(
"Hotel_Agent",
model_client,
tools=[lookup_hotel],
description="Helps with hotel booking.",
)
flight_agent = AssistantAgent(
"Flight_Agent",
model_client,
tools=[lookup_flight],
description="Helps with flight booking.",
)
team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=model_client)
stream = team.run_stream("Book a 3-day trip to new york.", termination_condition=StopMessageTermination())
async for message in stream:
print(message)
asyncio.run(main())
"""
def __init__(
@ -219,7 +290,6 @@ class SelectorGroupChat(BaseGroupChat):
participants: List[ChatAgent],
model_client: ChatCompletionClient,
*,
termination_condition: TerminationCondition | None = None,
selector_prompt: str = """You are in a role play game. The following roles are available:
{roles}.
Read the following conversation. Then select the next role from {participants} to play. Only return the role.
@ -229,6 +299,7 @@ Read the following conversation. Then select the next role from {participants} t
Read the above conversation. Then select the next role from {participants} to play. Only return the role.
""",
allow_repeated_speaker: bool = False,
selector_func: Callable[[Sequence[ChatMessage]], str | None] | None = None,
):
super().__init__(participants, group_chat_manager_class=SelectorGroupChatManager)
# Validate the participants.
@ -244,6 +315,7 @@ Read the above conversation. Then select the next role from {participants} to pl
self._selector_prompt = selector_prompt
self._model_client = model_client
self._allow_repeated_speaker = allow_repeated_speaker
self._selector_func = selector_func
def _create_group_chat_manager_factory(
self,
@ -262,4 +334,5 @@ Read the above conversation. Then select the next role from {participants} to pl
self._model_client,
self._selector_prompt,
self._allow_repeated_speaker,
self._selector_func,
)

View File

@ -61,28 +61,34 @@ class Swarm(BaseGroupChat):
.. code-block:: python
import asyncio
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import Swarm
from autogen_agentchat.task import MaxMessageTermination
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent1 = AssistantAgent(
"Alice",
model_client=model_client,
handoffs=["Bob"],
system_message="You are Alice and you only answer questions about yourself.",
)
agent2 = AssistantAgent(
"Bob", model_client=model_client, system_message="You are Bob and your birthday is on 1st January."
)
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
team = Swarm([agent1, agent2])
agent1 = AssistantAgent(
"Alice",
model_client=model_client,
handoffs=["Bob"],
system_message="You are Alice and you only answer questions about yourself.",
)
agent2 = AssistantAgent(
"Bob", model_client=model_client, system_message="You are Bob and your birthday is on 1st January."
)
stream = team.run_stream("What is bob's birthday?", termination_condition=MaxMessageTermination(3))
async for message in stream:
print(message)
team = Swarm([agent1, agent2])
stream = team.run_stream("What is bob's birthday?", termination_condition=MaxMessageTermination(3))
async for message in stream:
print(message)
asyncio.run(main())
"""
def __init__(self, participants: List[ChatAgent]):

View File

@ -493,6 +493,54 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte
index += 1
@pytest.mark.asyncio
async def test_selector_group_chat_custom_selector(monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify that a custom ``selector_func`` overrides model-based speaker selection,
    and that returning ``None`` from it falls back to the (mocked) model.

    NOTE(review): SOURCE here is a scraped diff with indentation stripped; the
    indentation below is reconstructed conventionally — confirm against the repo.
    """
    model = "gpt-4o-2024-05-13"
    # Single canned model response: the model (when consulted) picks "agent3".
    # It is consumed exactly once — on the turn where the selector returns None.
    chat_completions = [
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent3", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    # Patch the OpenAI client's create call so no network request is made.
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    agent1 = _EchoAgent("agent1", description="echo agent 1")
    agent2 = _EchoAgent("agent2", description="echo agent 2")
    agent3 = _EchoAgent("agent3", description="echo agent 3")
    agent4 = _EchoAgent("agent4", description="echo agent 4")

    def _select_agent(messages: Sequence[ChatMessage]) -> str | None:
        # Deterministic rotation: agent1 -> agent2 -> (None: defer to model,
        # which the mock resolves to agent3) -> agent4 -> agent1 -> ...
        if len(messages) == 0:
            return "agent1"
        elif messages[-1].source == "agent1":
            return "agent2"
        elif messages[-1].source == "agent2":
            return None
        elif messages[-1].source == "agent3":
            return "agent4"
        else:
            return "agent1"

    team = SelectorGroupChat(
        participants=[agent1, agent2, agent3, agent4],
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        selector_func=_select_agent,
    )
    # Cap the run at 6 messages; messages[0] is the task itself.
    result = await team.run("task", termination_condition=MaxMessageTermination(6))
    assert len(result.messages) == 6
    assert result.messages[1].source == "agent1"
    assert result.messages[2].source == "agent2"
    # agent3 was chosen by the mocked model, not by the selector function.
    assert result.messages[3].source == "agent3"
    assert result.messages[4].source == "agent4"
    assert result.messages[5].source == "agent1"
class _HandOffAgent(BaseChatAgent):
def __init__(self, name: str, description: str, next_agent: str) -> None:
super().__init__(name, description)