mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-27 15:09:41 +00:00
Custom selector function for SelectorGroupChat (#4026)
* Custom selector function for SelectorGroupChat * Update documentation
This commit is contained in:
parent
369ffb511b
commit
173acc6638
@ -131,14 +131,19 @@ class AssistantAgent(BaseChatAgent):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from autogen_ext.models import OpenAIChatCompletionClient
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.task import MaxMessageTermination
|
||||
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
agent = AssistantAgent(name="assistant", model_client=model_client)
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
agent = AssistantAgent(name="assistant", model_client=model_client)
|
||||
|
||||
await agent.run("What is the capital of France?", termination_condition=MaxMessageTermination(2))
|
||||
result await agent.run("What is the capital of France?", termination_condition=MaxMessageTermination(2))
|
||||
print(result)
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
The following example demonstrates how to create an assistant agent with
|
||||
@ -146,6 +151,7 @@ class AssistantAgent(BaseChatAgent):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from autogen_ext.models import OpenAIChatCompletionClient
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.task import MaxMessageTermination
|
||||
@ -155,15 +161,18 @@ class AssistantAgent(BaseChatAgent):
|
||||
return "The current time is 12:00 PM."
|
||||
|
||||
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time])
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time])
|
||||
|
||||
stream = agent.run_stream("What is the current time?", termination_condition=MaxMessageTermination(3))
|
||||
stream = agent.run_stream("What is the current time?", termination_condition=MaxMessageTermination(3))
|
||||
|
||||
async for message in stream:
|
||||
print(message)
|
||||
async for message in stream:
|
||||
print(message)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
@ -22,19 +22,25 @@ class TerminationCondition(ABC):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from autogen_agentchat.teams import MaxTurnsTermination, TextMentionTermination
|
||||
|
||||
# Terminate the conversation after 10 turns or if the text "TERMINATE" is mentioned.
|
||||
cond1 = MaxTurnsTermination(10) | TextMentionTermination("TERMINATE")
|
||||
|
||||
# Terminate the conversation after 10 turns and if the text "TERMINATE" is mentioned.
|
||||
cond2 = MaxTurnsTermination(10) & TextMentionTermination("TERMINATE")
|
||||
async def main() -> None:
|
||||
# Terminate the conversation after 10 turns or if the text "TERMINATE" is mentioned.
|
||||
cond1 = MaxTurnsTermination(10) | TextMentionTermination("TERMINATE")
|
||||
|
||||
...
|
||||
# Terminate the conversation after 10 turns and if the text "TERMINATE" is mentioned.
|
||||
cond2 = MaxTurnsTermination(10) & TextMentionTermination("TERMINATE")
|
||||
|
||||
# Reset the termination condition.
|
||||
await cond1.reset()
|
||||
await cond2.reset()
|
||||
# ...
|
||||
|
||||
# Reset the termination condition.
|
||||
await cond1.reset()
|
||||
await cond2.reset()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
@property
|
||||
|
||||
@ -61,46 +61,55 @@ class RoundRobinGroupChat(BaseGroupChat):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from autogen_ext.models import OpenAIChatCompletionClient
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.teams import RoundRobinGroupChat
|
||||
from autogen_agentchat.task import StopMessageTermination
|
||||
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
async def get_weather(location: str) -> str:
|
||||
return f"The weather in {location} is sunny."
|
||||
|
||||
assistant = AssistantAgent(
|
||||
"Assistant",
|
||||
model_client=model_client,
|
||||
tools=[get_weather],
|
||||
)
|
||||
team = RoundRobinGroupChat([assistant])
|
||||
stream = team.run_stream("What's the weather in New York?", termination_condition=StopMessageTermination())
|
||||
async for message in stream:
|
||||
print(message)
|
||||
|
||||
|
||||
async def get_weather(location: str) -> str:
|
||||
return f"The weather in {location} is sunny."
|
||||
|
||||
|
||||
assistant = AssistantAgent(
|
||||
"Assistant",
|
||||
model_client=model_client,
|
||||
tools=[get_weather],
|
||||
)
|
||||
team = RoundRobinGroupChat([assistant])
|
||||
stream = team.run_stream("What's the weather in New York?", termination_condition=StopMessageTermination())
|
||||
async for message in stream:
|
||||
print(message)
|
||||
asyncio.run(main())
|
||||
|
||||
A team with multiple participants:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from autogen_ext.models import OpenAIChatCompletionClient
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.teams import RoundRobinGroupChat
|
||||
from autogen_agentchat.task import StopMessageTermination
|
||||
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent("Assistant1", model_client=model_client)
|
||||
agent2 = AssistantAgent("Assistant2", model_client=model_client)
|
||||
team = RoundRobinGroupChat([agent1, agent2])
|
||||
stream = team.run_stream("Tell me some jokes.", termination_condition=StopMessageTermination())
|
||||
async for message in stream:
|
||||
print(message)
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent("Assistant1", model_client=model_client)
|
||||
agent2 = AssistantAgent("Assistant2", model_client=model_client)
|
||||
team = RoundRobinGroupChat([agent1, agent2])
|
||||
stream = team.run_stream("Tell me some jokes.", termination_condition=StopMessageTermination())
|
||||
async for message in stream:
|
||||
print(message)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
def __init__(self, participants: List[ChatAgent]):
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Callable, Dict, List
|
||||
from typing import Callable, Dict, List, Sequence
|
||||
|
||||
from autogen_core.components.models import ChatCompletionClient, SystemMessage
|
||||
|
||||
from ... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
|
||||
from ...base import ChatAgent, TerminationCondition
|
||||
from ...messages import MultiModalMessage, StopMessage, TextMessage
|
||||
from ...messages import ChatMessage, MultiModalMessage, StopMessage, TextMessage
|
||||
from .._events import (
|
||||
GroupChatPublishEvent,
|
||||
GroupChatSelectSpeakerEvent,
|
||||
@ -20,7 +20,7 @@ event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
"""A group chat manager that selects the next speaker using a ChatCompletion
|
||||
model."""
|
||||
model and a custom selector function."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -32,6 +32,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
model_client: ChatCompletionClient,
|
||||
selector_prompt: str,
|
||||
allow_repeated_speaker: bool,
|
||||
selector_func: Callable[[Sequence[ChatMessage]], str | None] | None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
parent_topic_type,
|
||||
@ -44,12 +45,24 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
self._selector_prompt = selector_prompt
|
||||
self._previous_speaker: str | None = None
|
||||
self._allow_repeated_speaker = allow_repeated_speaker
|
||||
self._selector_func = selector_func
|
||||
|
||||
async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str:
|
||||
"""Selects the next speaker in a group chat using a ChatCompletion client.
|
||||
"""Selects the next speaker in a group chat using a ChatCompletion client,
|
||||
with the selector function as override if it returns a speaker name.
|
||||
|
||||
A key assumption is that the agent type is the same as the topic type, which we use as the agent name.
|
||||
"""
|
||||
|
||||
# Use the selector function if provided.
|
||||
if self._selector_func is not None:
|
||||
speaker = self._selector_func([msg.agent_message for msg in thread])
|
||||
if speaker is not None:
|
||||
# Skip the model based selection.
|
||||
event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=speaker, source=self.id))
|
||||
return speaker
|
||||
|
||||
# Construct the history of the conversation.
|
||||
history_messages: List[str] = []
|
||||
for event in thread:
|
||||
msg = event.agent_message
|
||||
@ -160,6 +173,10 @@ class SelectorGroupChat(BaseGroupChat):
|
||||
Must contain '{roles}', '{participants}', and '{history}' to be filled in.
|
||||
allow_repeated_speaker (bool, optional): Whether to allow the same speaker to be selected
|
||||
consecutively. Defaults to False.
|
||||
selector_func (Callable[[Sequence[ChatMessage]], str | None], optional): A custom selector
|
||||
function that takes the conversation history and returns the name of the next speaker.
|
||||
If provided, this function will be used to override the model to select the next speaker.
|
||||
If the function returns None, the model will be used to select the next speaker.
|
||||
|
||||
Raises:
|
||||
ValueError: If the number of participants is less than two or if the selector prompt is invalid.
|
||||
@ -175,43 +192,97 @@ class SelectorGroupChat(BaseGroupChat):
|
||||
from autogen_agentchat.teams import SelectorGroupChat
|
||||
from autogen_agentchat.task import StopMessageTermination
|
||||
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
async def lookup_hotel(location: str) -> str:
|
||||
return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."
|
||||
|
||||
async def lookup_flight(origin: str, destination: str) -> str:
|
||||
return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."
|
||||
|
||||
async def book_trip() -> str:
|
||||
return "Your trip is booked!"
|
||||
|
||||
travel_advisor = AssistantAgent(
|
||||
"Travel_Advisor",
|
||||
model_client,
|
||||
tools=[book_trip],
|
||||
description="Helps with travel planning.",
|
||||
)
|
||||
hotel_agent = AssistantAgent(
|
||||
"Hotel_Agent",
|
||||
model_client,
|
||||
tools=[lookup_hotel],
|
||||
description="Helps with hotel booking.",
|
||||
)
|
||||
flight_agent = AssistantAgent(
|
||||
"Flight_Agent",
|
||||
model_client,
|
||||
tools=[lookup_flight],
|
||||
description="Helps with flight booking.",
|
||||
)
|
||||
team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=model_client)
|
||||
stream = team.run_stream("Book a 3-day trip to new york.", termination_condition=StopMessageTermination())
|
||||
async for message in stream:
|
||||
print(message)
|
||||
|
||||
|
||||
async def lookup_hotel(location: str) -> str:
|
||||
return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
A team with a custom selector function:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from autogen_ext.models import OpenAIChatCompletionClient
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.teams import SelectorGroupChat
|
||||
from autogen_agentchat.task import TextMentionTermination
|
||||
|
||||
|
||||
async def lookup_flight(origin: str, destination: str) -> str:
|
||||
return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
def check_caculation(x: int, y: int, answer: int) -> str:
|
||||
if x + y == answer:
|
||||
return "Correct!"
|
||||
else:
|
||||
return "Incorrect!"
|
||||
|
||||
agent1 = AssistantAgent(
|
||||
"Agent1",
|
||||
model_client,
|
||||
description="For calculation",
|
||||
system_message="Calculate the sum of two numbers",
|
||||
)
|
||||
agent2 = AssistantAgent(
|
||||
"Agent2",
|
||||
model_client,
|
||||
tools=[check_caculation],
|
||||
description="For checking calculation",
|
||||
system_message="Check the answer and respond with 'Correct!' or 'Incorrect!'",
|
||||
)
|
||||
|
||||
def selector_func(messages):
|
||||
if len(messages) == 1 or messages[-1].content == "Incorrect!":
|
||||
return "Agent1"
|
||||
if messages[-1].source == "Agent1":
|
||||
return "Agent2"
|
||||
return None
|
||||
|
||||
team = SelectorGroupChat([agent1, agent2], model_client=model_client, selector_func=selector_func)
|
||||
|
||||
stream = team.run_stream("What is 1 + 1?", termination_condition=TextMentionTermination("Correct!"))
|
||||
async for message in stream:
|
||||
print(message)
|
||||
|
||||
|
||||
async def book_trip() -> str:
|
||||
return "Your trip is booked!"
|
||||
import asyncio
|
||||
|
||||
|
||||
travel_advisor = AssistantAgent(
|
||||
"Travel_Advisor",
|
||||
model_client,
|
||||
tools=[book_trip],
|
||||
description="Helps with travel planning.",
|
||||
)
|
||||
hotel_agent = AssistantAgent(
|
||||
"Hotel_Agent",
|
||||
model_client,
|
||||
tools=[lookup_hotel],
|
||||
description="Helps with hotel booking.",
|
||||
)
|
||||
flight_agent = AssistantAgent(
|
||||
"Flight_Agent",
|
||||
model_client,
|
||||
tools=[lookup_flight],
|
||||
description="Helps with flight booking.",
|
||||
)
|
||||
team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=model_client)
|
||||
stream = team.run_stream("Book a 3-day trip to new york.", termination_condition=StopMessageTermination())
|
||||
async for message in stream:
|
||||
print(message)
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -219,7 +290,6 @@ class SelectorGroupChat(BaseGroupChat):
|
||||
participants: List[ChatAgent],
|
||||
model_client: ChatCompletionClient,
|
||||
*,
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
selector_prompt: str = """You are in a role play game. The following roles are available:
|
||||
{roles}.
|
||||
Read the following conversation. Then select the next role from {participants} to play. Only return the role.
|
||||
@ -229,6 +299,7 @@ Read the following conversation. Then select the next role from {participants} t
|
||||
Read the above conversation. Then select the next role from {participants} to play. Only return the role.
|
||||
""",
|
||||
allow_repeated_speaker: bool = False,
|
||||
selector_func: Callable[[Sequence[ChatMessage]], str | None] | None = None,
|
||||
):
|
||||
super().__init__(participants, group_chat_manager_class=SelectorGroupChatManager)
|
||||
# Validate the participants.
|
||||
@ -244,6 +315,7 @@ Read the above conversation. Then select the next role from {participants} to pl
|
||||
self._selector_prompt = selector_prompt
|
||||
self._model_client = model_client
|
||||
self._allow_repeated_speaker = allow_repeated_speaker
|
||||
self._selector_func = selector_func
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
@ -262,4 +334,5 @@ Read the above conversation. Then select the next role from {participants} to pl
|
||||
self._model_client,
|
||||
self._selector_prompt,
|
||||
self._allow_repeated_speaker,
|
||||
self._selector_func,
|
||||
)
|
||||
|
||||
@ -61,28 +61,34 @@ class Swarm(BaseGroupChat):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from autogen_ext.models import OpenAIChatCompletionClient
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.teams import Swarm
|
||||
from autogen_agentchat.task import MaxMessageTermination
|
||||
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent(
|
||||
"Alice",
|
||||
model_client=model_client,
|
||||
handoffs=["Bob"],
|
||||
system_message="You are Alice and you only answer questions about yourself.",
|
||||
)
|
||||
agent2 = AssistantAgent(
|
||||
"Bob", model_client=model_client, system_message="You are Bob and your birthday is on 1st January."
|
||||
)
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
team = Swarm([agent1, agent2])
|
||||
agent1 = AssistantAgent(
|
||||
"Alice",
|
||||
model_client=model_client,
|
||||
handoffs=["Bob"],
|
||||
system_message="You are Alice and you only answer questions about yourself.",
|
||||
)
|
||||
agent2 = AssistantAgent(
|
||||
"Bob", model_client=model_client, system_message="You are Bob and your birthday is on 1st January."
|
||||
)
|
||||
|
||||
stream = team.run_stream("What is bob's birthday?", termination_condition=MaxMessageTermination(3))
|
||||
async for message in stream:
|
||||
print(message)
|
||||
team = Swarm([agent1, agent2])
|
||||
|
||||
stream = team.run_stream("What is bob's birthday?", termination_condition=MaxMessageTermination(3))
|
||||
async for message in stream:
|
||||
print(message)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
def __init__(self, participants: List[ChatAgent]):
|
||||
|
||||
@ -493,6 +493,54 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte
|
||||
index += 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_selector_group_chat_custom_selector(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent3", role="assistant"))
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
|
||||
),
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
agent1 = _EchoAgent("agent1", description="echo agent 1")
|
||||
agent2 = _EchoAgent("agent2", description="echo agent 2")
|
||||
agent3 = _EchoAgent("agent3", description="echo agent 3")
|
||||
agent4 = _EchoAgent("agent4", description="echo agent 4")
|
||||
|
||||
def _select_agent(messages: Sequence[ChatMessage]) -> str | None:
|
||||
if len(messages) == 0:
|
||||
return "agent1"
|
||||
elif messages[-1].source == "agent1":
|
||||
return "agent2"
|
||||
elif messages[-1].source == "agent2":
|
||||
return None
|
||||
elif messages[-1].source == "agent3":
|
||||
return "agent4"
|
||||
else:
|
||||
return "agent1"
|
||||
|
||||
team = SelectorGroupChat(
|
||||
participants=[agent1, agent2, agent3, agent4],
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
selector_func=_select_agent,
|
||||
)
|
||||
result = await team.run("task", termination_condition=MaxMessageTermination(6))
|
||||
assert len(result.messages) == 6
|
||||
assert result.messages[1].source == "agent1"
|
||||
assert result.messages[2].source == "agent2"
|
||||
assert result.messages[3].source == "agent3"
|
||||
assert result.messages[4].source == "agent4"
|
||||
assert result.messages[5].source == "agent1"
|
||||
|
||||
|
||||
class _HandOffAgent(BaseChatAgent):
|
||||
def __init__(self, name: str, description: str, next_agent: str) -> None:
|
||||
super().__init__(name, description)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user