Move handoff to base in agentchat (#4509)

This commit is contained in:
Eric Zhu 2024-12-03 14:34:55 -08:00 committed by GitHub
parent 5235bbc0d6
commit 50e84b945e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 85 additions and 60 deletions

View File

@ -1,4 +1,4 @@
from ._assistant_agent import AssistantAgent, Handoff
from ._assistant_agent import AssistantAgent, Handoff # type: ignore
from ._base_chat_agent import BaseChatAgent
from ._code_executor_agent import CodeExecutorAgent
from ._coding_assistant_agent import CodingAssistantAgent

View File

@ -1,6 +1,7 @@
import asyncio
import json
import logging
import warnings
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Sequence
from autogen_core.base import CancellationToken
@ -15,9 +16,10 @@ from autogen_core.components.models import (
UserMessage,
)
from autogen_core.components.tools import FunctionTool, Tool
from pydantic import BaseModel, Field, model_validator
from typing_extensions import deprecated
from .. import EVENT_LOGGER_NAME
from ..base import Handoff as HandoffBase
from ..base import Response
from ..messages import (
AgentMessage,
@ -33,51 +35,16 @@ from ._base_chat_agent import BaseChatAgent
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
class Handoff(BaseModel):
"""Handoff configuration for :class:`AssistantAgent`."""
@deprecated("Moved to autogen_agentchat.base.Handoff. Will remove in 0.4.0.", stacklevel=2)
class Handoff(HandoffBase):
"""[DEPRECATED] Handoff configuration. Moved to :class:`autogen_agentchat.base.Handoff`. Will remove in 0.4.0."""
target: str
"""The name of the target agent to handoff to."""
description: str = Field(default=None)
"""The description of the handoff such as the condition under which it should happen and the target agent's ability.
If not provided, it is generated from the target agent's name."""
name: str = Field(default=None)
"""The name of this handoff configuration. If not provided, it is generated from the target agent's name."""
message: str = Field(default=None)
"""The message to the target agent.
If not provided, it is generated from the target agent's name."""
@model_validator(mode="before")
@classmethod
def set_defaults(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values.get("description") is None:
values["description"] = f"Handoff to {values['target']}."
if values.get("name") is None:
values["name"] = f"transfer_to_{values['target']}".lower()
else:
name = values["name"]
if not isinstance(name, str):
raise ValueError(f"Handoff name must be a string: {values['name']}")
# Check if name is a valid identifier.
if not name.isidentifier():
raise ValueError(f"Handoff name must be a valid identifier: {values['name']}")
if values.get("message") is None:
values["message"] = (
f"Transferred to {values['target']}, adopting the role of {values['target']} immediately."
)
return values
@property
def handoff_tool(self) -> Tool:
"""Create a handoff tool from this handoff configuration."""
def _handoff_tool() -> str:
return self.message
return FunctionTool(_handoff_tool, name=self.name, description=self.description)
def model_post_init(self, __context: Any) -> None:
warnings.warn(
"Handoff was moved to autogen_agentchat.base.Handoff. Importing from this will be removed in 0.4.0.",
DeprecationWarning,
stacklevel=2,
)
class AssistantAgent(BaseChatAgent):
@ -87,7 +54,7 @@ class AssistantAgent(BaseChatAgent):
name (str): The name of the agent.
model_client (ChatCompletionClient): The model client to use for inference.
tools (List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional): The tools to register with the agent.
handoffs (List[Handoff | str] | None, optional): The handoff configurations for the agent,
handoffs (List[HandoffBase | str] | None, optional): The handoff configurations for the agent,
allowing it to transfer to other agents by responding with a :class:`HandoffMessage`.
The transfer is only executed when the team is in :class:`~autogen_agentchat.teams.Swarm`.
If a handoff is a string, it should represent the target agent's name.
@ -204,7 +171,7 @@ class AssistantAgent(BaseChatAgent):
model_client: ChatCompletionClient,
*,
tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
handoffs: List[Handoff | str] | None = None,
handoffs: List[HandoffBase | str] | None = None,
description: str = "An agent that provides assistance with ability to use tools.",
system_message: str
| None = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
@ -236,14 +203,14 @@ class AssistantAgent(BaseChatAgent):
raise ValueError(f"Tool names must be unique: {tool_names}")
# Handoff tools.
self._handoff_tools: List[Tool] = []
self._handoffs: Dict[str, Handoff] = {}
self._handoffs: Dict[str, HandoffBase] = {}
if handoffs is not None:
if model_client.capabilities["function_calling"] is False:
raise ValueError("The model does not support function calling, which is needed for handoffs.")
for handoff in handoffs:
if isinstance(handoff, str):
handoff = Handoff(target=handoff)
if isinstance(handoff, Handoff):
handoff = HandoffBase(target=handoff)
if isinstance(handoff, HandoffBase):
self._handoff_tools.append(handoff.handoff_tool)
self._handoffs[handoff.name] = handoff
else:
@ -312,7 +279,7 @@ class AssistantAgent(BaseChatAgent):
yield tool_call_result_msg
# Detect handoff requests.
handoffs: List[Handoff] = []
handoffs: List[HandoffBase] = []
for call in result.content:
if call.name in self._handoffs:
handoffs.append(self._handoffs[call.name])

View File

@ -1,4 +1,5 @@
from ._chat_agent import ChatAgent, Response
from ._handoff import Handoff
from ._task import TaskResult, TaskRunner
from ._team import Team
from ._termination import TerminatedException, TerminationCondition
@ -11,4 +12,5 @@ __all__ = [
"TerminationCondition",
"TaskResult",
"TaskRunner",
"Handoff",
]

View File

@ -0,0 +1,56 @@
import logging
from typing import Any, Dict
from autogen_core.components.tools import FunctionTool, Tool
from pydantic import BaseModel, Field, model_validator
from .. import EVENT_LOGGER_NAME
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
class Handoff(BaseModel):
"""Handoff configuration."""
target: str
"""The name of the target agent to handoff to."""
description: str = Field(default=None)
"""The description of the handoff such as the condition under which it should happen and the target agent's ability.
If not provided, it is generated from the target agent's name."""
name: str = Field(default=None)
"""The name of this handoff configuration. If not provided, it is generated from the target agent's name."""
message: str = Field(default=None)
"""The message to the target agent.
If not provided, it is generated from the target agent's name."""
@model_validator(mode="before")
@classmethod
def set_defaults(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values.get("description") is None:
values["description"] = f"Handoff to {values['target']}."
if values.get("name") is None:
values["name"] = f"transfer_to_{values['target']}".lower()
else:
name = values["name"]
if not isinstance(name, str):
raise ValueError(f"Handoff name must be a string: {values['name']}")
# Check if name is a valid identifier.
if not name.isidentifier():
raise ValueError(f"Handoff name must be a valid identifier: {values['name']}")
if values.get("message") is None:
values["message"] = (
f"Transferred to {values['target']}, adopting the role of {values['target']} immediately."
)
return values
@property
def handoff_tool(self) -> Tool:
"""Create a handoff tool from this handoff configuration."""
def _handoff_tool() -> str:
return self.message
return FunctionTool(_handoff_tool, name=self.name, description=self.description)

View File

@ -5,8 +5,8 @@ from typing import Any, AsyncGenerator, List
import pytest
from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import AssistantAgent, Handoff
from autogen_agentchat.base import TaskResult
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import Handoff, TaskResult
from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import (
HandoffMessage,

View File

@ -10,9 +10,8 @@ from autogen_agentchat.agents import (
AssistantAgent,
BaseChatAgent,
CodeExecutorAgent,
Handoff,
)
from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.base import Handoff, Response, TaskResult
from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import (
AgentMessage,

View File

@ -251,7 +251,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.6"
}
},
"nbformat": 4,

View File

@ -39,7 +39,7 @@
"\n",
"For {py:class}`~autogen_agentchat.agents.AssistantAgent`, you can set the\n",
"`handoffs` argument to specify which agents it can hand off to. You can\n",
"use {py:class}`~autogen_agentchat.agents.Handoff` to customize the message\n",
"use {py:class}`~autogen_agentchat.base.Handoff` to customize the message\n",
"content and handoff behavior.\n",
"\n",
"The overall process can be summarized as follows:\n",

View File

@ -655,11 +655,12 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from autogen_agentchat.agents import AssistantAgent, Handoff\n",
"from autogen_agentchat.agents import AssistantAgent\n",
"from autogen_agentchat.base import Handoff\n",
"from autogen_agentchat.task import HandoffTermination, TextMentionTermination\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_ext.models import OpenAIChatCompletionClient\n",