From eb4b1f856e5df5d25a84f1bbb2ac1d461a5dba17 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Tue, 29 Oct 2024 08:04:14 -0700 Subject: [PATCH] Ability to generate handoff message from AssistantAgent (#3968) * Ability to generate handoff message from AssistantAgent * Fix mypy * Validation --------- Co-authored-by: Victor Dibia --- .../src/autogen_agentchat/agents/__init__.py | 3 +- .../agents/_assistant_agent.py | 151 ++++++++++++++++-- .../src/autogen_agentchat/messages.py | 5 +- .../teams/_group_chat/_swarm_group_chat.py | 37 ++++- .../tests/test_assistant_agent.py | 60 ++++++- .../tests/test_group_chat.py | 88 +++++++++- .../framework/model-clients.ipynb | 7 +- 7 files changed, 326 insertions(+), 25 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py index 2f3258860..7eb35962b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py @@ -1,4 +1,4 @@ -from ._assistant_agent import AssistantAgent +from ._assistant_agent import AssistantAgent, Handoff from ._base_chat_agent import BaseChatAgent from ._code_executor_agent import CodeExecutorAgent from ._coding_assistant_agent import CodingAssistantAgent @@ -7,6 +7,7 @@ from ._tool_use_assistant_agent import ToolUseAssistantAgent __all__ = [ "BaseChatAgent", "AssistantAgent", + "Handoff", "CodeExecutorAgent", "CodingAssistantAgent", "ToolUseAssistantAgent", diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 6523b6d42..11e243afb 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -1,7 +1,7 @@ import asyncio import json import logging -from typing import Any, Awaitable, Callable, List, Sequence +from typing import Any, Awaitable, Callable, Dict, List, Sequence from autogen_core.base import CancellationToken from autogen_core.components import FunctionCall @@ -15,11 +15,12 @@ from autogen_core.components.models import ( UserMessage, ) from autogen_core.components.tools import FunctionTool, Tool -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field, model_validator from .. import EVENT_LOGGER_NAME from ..messages import ( ChatMessage, + HandoffMessage, StopMessage, TextMessage, ) @@ -31,6 +32,9 @@ event_logger = logging.getLogger(EVENT_LOGGER_NAME) class ToolCallEvent(BaseModel): """A tool call event.""" + source: str + """The source of the event.""" + tool_calls: List[FunctionCall] """The tool call message.""" @@ -40,12 +44,58 @@ class ToolCallEvent(BaseModel): class ToolCallResultEvent(BaseModel): """A tool call result event.""" + source: str + """The source of the event.""" + tool_call_results: List[FunctionExecutionResult] """The tool call result message.""" model_config = ConfigDict(arbitrary_types_allowed=True) +class Handoff(BaseModel): + """Handoff configuration for :class:`AssistantAgent`.""" + + target: str + """The name of the target agent to handoff to.""" + + description: str = Field(default=None) + """The description of the handoff such as the condition under which it should happen and the target agent's ability. + If not provided, it is generated from the target agent's name.""" + + name: str = Field(default=None) + """The name of this handoff configuration. If not provided, it is generated from the target agent's name.""" + + message: str = Field(default=None) + """The message to the target agent. + If not provided, it is generated from the target agent's name.""" + + @model_validator(mode="before") + @classmethod + def set_defaults(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if values.get("description") is None: + values["description"] = f"Handoff to {values['target']}." + if values.get("name") is None: + values["name"] = f"transfer_to_{values['target']}".lower() + else: + name = values["name"] + if not isinstance(name, str): + raise ValueError(f"Handoff name must be a string: {values['name']}") + # Check if name is a valid identifier. + if not name.isidentifier(): + raise ValueError(f"Handoff name must be a valid identifier: {values['name']}") + if values.get("message") is None: + values["message"] = ( + f"Transferred to {values['target']}, adopting the role of {values['target']} immediately." + ) + return values + + @property + def handoff_tool(self) -> Tool: + """Create a handoff tool from this handoff configuration.""" + return FunctionTool(lambda: self.message, name=self.name, description=self.description) + + class AssistantAgent(BaseChatAgent): """An agent that provides assistance with tool use. @@ -55,8 +105,52 @@ class AssistantAgent(BaseChatAgent): name (str): The name of the agent. model_client (ChatCompletionClient): The model client to use for inference. tools (List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional): The tools to register with the agent. + handoffs (List[Handoff | str] | None, optional): The handoff configurations for the agent, allowing it to transfer to other agents by responding with a HandoffMessage. + If a handoff is a string, it should represent the target agent's name. description (str, optional): The description of the agent. system_message (str, optional): The system message for the model. + + Raises: + ValueError: If tool names are not unique. + ValueError: If handoff names are not unique. + ValueError: If handoff names are not unique from tool names. + + Examples: + + The following example demonstrates how to create an assistant agent with + a model client and generate a response to a simple task. + + .. code-block:: python + + from autogen_ext.models import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.task import MaxMessageTermination + + model_client = OpenAIChatCompletionClient(model="gpt-4o") + agent = AssistantAgent(name="assistant", model_client=model_client) + + await agent.run("What is the capital of France?", termination_condition=MaxMessageTermination(2)) + + + The following example demonstrates how to create an assistant agent with + a model client and a tool, and generate a response to a simple task using the tool. + + .. code-block:: python + + from autogen_ext.models import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.task import MaxMessageTermination + + + async def get_current_time() -> str: + return "The current time is 12:00 PM." + + + model_client = OpenAIChatCompletionClient(model="gpt-4o") + agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time]) + + await agent.run("What is the current time?", termination_condition=MaxMessageTermination(3)) + """ def __init__( @@ -65,6 +159,7 @@ class AssistantAgent(BaseChatAgent): model_client: ChatCompletionClient, *, tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None, + handoffs: List[Handoff | str] | None = None, description: str = "An agent that provides assistance with ability to use tools.", system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed.", ): @@ -84,33 +179,71 @@ class AssistantAgent(BaseChatAgent): self._tools.append(FunctionTool(tool, description=description)) else: raise ValueError(f"Unsupported tool type: {type(tool)}") + # Check if tool names are unique. + tool_names = [tool.name for tool in self._tools] + if len(tool_names) != len(set(tool_names)): + raise ValueError(f"Tool names must be unique: {tool_names}") + # Handoff tools. + self._handoff_tools: List[Tool] = [] + self._handoffs: Dict[str, Handoff] = {} + if handoffs is not None: + for handoff in handoffs: + if isinstance(handoff, str): + handoff = Handoff(target=handoff) + if isinstance(handoff, Handoff): + self._handoff_tools.append(handoff.handoff_tool) + self._handoffs[handoff.name] = handoff + else: + raise ValueError(f"Unsupported handoff type: {type(handoff)}") + # Check if handoff tool names are unique. + handoff_tool_names = [tool.name for tool in self._handoff_tools] + if len(handoff_tool_names) != len(set(handoff_tool_names)): + raise ValueError(f"Handoff names must be unique: {handoff_tool_names}") + # Check if handoff tool names not in tool names. + if any(name in tool_names for name in handoff_tool_names): + raise ValueError( + f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}" + ) self._model_context: List[LLMMessage] = [] async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: # Add messages to the model context. for msg in messages: - # TODO: add special handling for handoff messages self._model_context.append(UserMessage(content=msg.content, source=msg.source)) # Generate an inference result based on the current model context. llm_messages = self._system_messages + self._model_context - result = await self._model_client.create(llm_messages, tools=self._tools, cancellation_token=cancellation_token) + result = await self._model_client.create( + llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token + ) # Add the response to the model context. self._model_context.append(AssistantMessage(content=result.content, source=self.name)) # Run tool calls until the model produces a string response. while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content): - event_logger.debug(ToolCallEvent(tool_calls=result.content)) + event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name)) # Execute the tool calls. results = await asyncio.gather( *[self._execute_tool_call(call, cancellation_token) for call in result.content] ) - event_logger.debug(ToolCallResultEvent(tool_call_results=results)) + event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name)) self._model_context.append(FunctionExecutionResultMessage(content=results)) + + # Detect handoff requests. + handoffs: List[Handoff] = [] + for call in result.content: + if call.name in self._handoffs: + handoffs.append(self._handoffs[call.name]) + if len(handoffs) > 0: + if len(handoffs) > 1: + raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}") + # Respond with a handoff message. + return HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name) + # Generate an inference result based on the current model context. result = await self._model_client.create( - self._model_context, tools=self._tools, cancellation_token=cancellation_token + self._model_context, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token ) self._model_context.append(AssistantMessage(content=result.content, source=self.name)) @@ -127,9 +260,9 @@ class AssistantAgent(BaseChatAgent): ) -> FunctionExecutionResult: """Execute a tool call and return the result.""" try: - if not self._tools: + if not self._tools + self._handoff_tools: raise ValueError("No tools are available.") - tool = next((t for t in self._tools if t.name == tool_call.name), None) + tool = next((t for t in self._tools + self._handoff_tools if t.name == tool_call.name), None) if tool is None: raise ValueError(f"The tool '{tool_call.name}' is not available.") arguments = json.loads(tool_call.arguments) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 99bd0c888..505ec3cb8 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -35,8 +35,11 @@ class StopMessage(BaseMessage): class HandoffMessage(BaseMessage): """A message requesting handoff of a conversation to another agent.""" + target: str + """The name of the target agent to handoff to.""" + content: str - """The agent name to handoff the conversation to.""" + """The handoff message to the target agent.""" ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py index 4f2d08afc..7c24ac4c1 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py @@ -37,7 +37,7 @@ class SwarmGroupChatManager(BaseGroupChatManager): async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str: """Select a speaker from the participants based on handoff message.""" if len(thread) > 0 and isinstance(thread[-1].agent_message, HandoffMessage): - self._current_speaker = thread[-1].agent_message.content + self._current_speaker = thread[-1].agent_message.target if self._current_speaker not in self._participant_topic_types: raise ValueError("The selected speaker in the handoff message is not a participant.") event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=self._current_speaker, source=self.id)) @@ -47,7 +47,40 @@ class SwarmGroupChatManager(BaseGroupChatManager): class Swarm(BaseGroupChat): - """(Experimental) A group chat that selects the next speaker based on handoff message only.""" + """A group chat team that selects the next speaker based on handoff message only. + + The first participant in the list of participants is the initial speaker. + The next speaker is selected based on the :class:`~autogen_agentchat.messages.HandoffMessage` message + sent by the current speaker. If no handoff message is sent, the current speaker + continues to be the speaker. + + Args: + participants (List[ChatAgent]): The agents participating in the group chat. The first agent in the list is the initial speaker. + + Examples: + + .. code-block:: python + + from autogen_ext.models import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.teams import Swarm + from autogen_agentchat.task import MaxMessageTermination + + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + agent1 = AssistantAgent( + "Alice", + model_client=model_client, + handoffs=["Bob"], + system_message="You are Alice and you only answer questions about yourself.", + ) + agent2 = AssistantAgent( + "Bob", model_client=model_client, system_message="You are Bob and your birthday is on 1st January." + ) + + team = Swarm([agent1, agent2]) + await team.run("What is bob's birthday?", termination_condition=MaxMessageTermination(3)) + """ def __init__(self, participants: List[ChatAgent]): super().__init__(participants, group_chat_manager_class=SwarmGroupChatManager) diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 9a243a5a2..bff941b90 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -1,10 +1,14 @@ import asyncio import json +import logging from typing import Any, AsyncGenerator, List import pytest -from autogen_agentchat.agents import AssistantAgent -from autogen_agentchat.messages import StopMessage, TextMessage +from autogen_agentchat import EVENT_LOGGER_NAME +from autogen_agentchat.agents import AssistantAgent, Handoff +from autogen_agentchat.logging import FileLogHandler +from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage +from autogen_core.base import CancellationToken from autogen_core.components.tools import FunctionTool from autogen_ext.models import OpenAIChatCompletionClient from openai.resources.chat.completions import AsyncCompletions @@ -14,6 +18,10 @@ from openai.types.chat.chat_completion_message import ChatCompletionMessage from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function from openai.types.completion_usage import CompletionUsage +logger = logging.getLogger(EVENT_LOGGER_NAME) +logger.setLevel(logging.DEBUG) +logger.addHandler(FileLogHandler("test_assistant_agent.log")) + class _MockChatCompletion: def __init__(self, chat_completions: List[ChatCompletion]) -> None: @@ -107,3 +115,51 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: assert isinstance(result.messages[0], TextMessage) assert isinstance(result.messages[1], TextMessage) assert isinstance(result.messages[2], StopMessage) + + +@pytest.mark.asyncio +async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: + handoff = Handoff(target="agent2") + model = "gpt-4o-2024-05-13" + chat_completions = [ + ChatCompletion( + id="id1", + choices=[ + Choice( + finish_reason="tool_calls", + index=0, + message=ChatCompletionMessage( + content=None, + tool_calls=[ + ChatCompletionMessageToolCall( + id="1", + type="function", + function=Function( + name=handoff.name, + arguments=json.dumps({}), + ), + ) + ], + role="assistant", + ), + ) + ], + created=0, + model=model, + object="chat.completion", + usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ), + ] + mock = _MockChatCompletion(chat_completions) + monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create) + tool_use_agent = AssistantAgent( + "tool_use_agent", + model_client=OpenAIChatCompletionClient(model=model, api_key=""), + tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")], + handoffs=[handoff], + ) + response = await tool_use_agent.on_messages( + [TextMessage(content="task", source="user")], cancellation_token=CancellationToken() + ) + assert isinstance(response, HandoffMessage) + assert response.target == "agent2" diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 3f3c8a3b8..d209de3bd 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -10,6 +10,7 @@ from autogen_agentchat.agents import ( AssistantAgent, BaseChatAgent, CodeExecutorAgent, + Handoff, ) from autogen_agentchat.logging import FileLogHandler from autogen_agentchat.messages import ( @@ -415,11 +416,11 @@ class _HandOffAgent(BaseChatAgent): self._next_agent = next_agent async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: - return HandoffMessage(content=self._next_agent, source=self.name) + return HandoffMessage(content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name) @pytest.mark.asyncio -async def test_swarm() -> None: +async def test_swarm_handoff() -> None: first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent") second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent") third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent") @@ -428,8 +429,81 @@ async def test_swarm() -> None: result = await team.run("task", termination_condition=MaxMessageTermination(6)) assert len(result.messages) == 6 assert result.messages[0].content == "task" - assert result.messages[1].content == "third_agent" - assert result.messages[2].content == "first_agent" - assert result.messages[3].content == "second_agent" - assert result.messages[4].content == "third_agent" - assert result.messages[5].content == "first_agent" + assert result.messages[1].content == "Transferred to third_agent." + assert result.messages[2].content == "Transferred to first_agent." + assert result.messages[3].content == "Transferred to second_agent." + assert result.messages[4].content == "Transferred to third_agent." + assert result.messages[5].content == "Transferred to first_agent." + + +@pytest.mark.asyncio +async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None: + model = "gpt-4o-2024-05-13" + chat_completions = [ + ChatCompletion( + id="id1", + choices=[ + Choice( + finish_reason="tool_calls", + index=0, + message=ChatCompletionMessage( + content=None, + tool_calls=[ + ChatCompletionMessageToolCall( + id="1", + type="function", + function=Function( + name="handoff_to_agent2", + arguments=json.dumps({}), + ), + ) + ], + role="assistant", + ), + ) + ], + created=0, + model=model, + object="chat.completion", + usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ), + ChatCompletion( + id="id2", + choices=[ + Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant")) + ], + created=0, + model=model, + object="chat.completion", + usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ), + ChatCompletion( + id="id2", + choices=[ + Choice( + finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant") + ) + ], + created=0, + model=model, + object="chat.completion", + usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ), + ] + mock = _MockChatCompletion(chat_completions) + monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create) + + agnet1 = AssistantAgent( + "agent1", + model_client=OpenAIChatCompletionClient(model=model, api_key=""), + handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")], + ) + agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1") + team = Swarm([agnet1, agent2]) + result = await team.run("task", termination_condition=StopMessageTermination()) + assert len(result.messages) == 5 + assert result.messages[0].content == "task" + assert result.messages[1].content == "handoff to agent2" + assert result.messages[2].content == "Transferred to agent1." + assert result.messages[3].content == "Hello" + assert result.messages[4].content == "TERMINATE" diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/model-clients.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/model-clients.ipynb index 1e5f5c293..2a7f00710 100644 --- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/model-clients.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/model-clients.ipynb @@ -32,7 +32,8 @@ "metadata": {}, "outputs": [], "source": [ - "from autogen_ext.models import OpenAIChatCompletionClient, UserMessage\n", + "from autogen_core.components.models import UserMessage\n", + "from autogen_ext.models import OpenAIChatCompletionClient\n", "\n", "# Create an OpenAI model client.\n", "model_client = OpenAIChatCompletionClient(\n", @@ -500,7 +501,7 @@ ], "metadata": { "kernelspec": { - "display_name": "autogen_core", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -514,7 +515,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.11.5" } }, "nbformat": 4,