mirror of https://github.com/microsoft/autogen.git
synced 2025-08-15 12:11:30 +00:00

Ability to generate handoff message from AssistantAgent (#3968)

* Ability to generate handoff message from AssistantAgent
* Fix mypy
* Validation

Co-authored-by: Victor Dibia <victordibia@microsoft.com>

parent 14846a3e84
commit eb4b1f856e
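
In brief: AssistantAgent gains a `handoffs` parameter and a new `Handoff` configuration class, and it now returns a `HandoffMessage` when the model calls a handoff tool; `Swarm` routes on the message's new `target` field. A minimal usage sketch, adapted from the docstrings added in this commit (the model name and agent setup are illustrative):

    from autogen_ext.models import OpenAIChatCompletionClient
    from autogen_agentchat.agents import AssistantAgent
    from autogen_agentchat.teams import Swarm
    from autogen_agentchat.task import MaxMessageTermination

    model_client = OpenAIChatCompletionClient(model="gpt-4o")

    # A plain string such as "Bob" is expanded to Handoff(target="Bob").
    alice = AssistantAgent(
        "Alice",
        model_client=model_client,
        handoffs=["Bob"],
        system_message="You are Alice and you only answer questions about yourself.",
    )
    bob = AssistantAgent(
        "Bob", model_client=model_client, system_message="You are Bob and your birthday is on 1st January."
    )

    # The first participant speaks first; a HandoffMessage hands the floor to its target.
    team = Swarm([alice, bob])
    await team.run("What is Bob's birthday?", termination_condition=MaxMessageTermination(3))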
@@ -1,4 +1,4 @@
-from ._assistant_agent import AssistantAgent
+from ._assistant_agent import AssistantAgent, Handoff
 from ._base_chat_agent import BaseChatAgent
 from ._code_executor_agent import CodeExecutorAgent
 from ._coding_assistant_agent import CodingAssistantAgent

@@ -7,6 +7,7 @@ from ._tool_use_assistant_agent import ToolUseAssistantAgent
 __all__ = [
     "BaseChatAgent",
     "AssistantAgent",
+    "Handoff",
     "CodeExecutorAgent",
     "CodingAssistantAgent",
     "ToolUseAssistantAgent",
@@ -1,7 +1,7 @@
 import asyncio
 import json
 import logging
-from typing import Any, Awaitable, Callable, List, Sequence
+from typing import Any, Awaitable, Callable, Dict, List, Sequence
 
 from autogen_core.base import CancellationToken
 from autogen_core.components import FunctionCall

@@ -15,11 +15,12 @@ from autogen_core.components.models import (
     UserMessage,
 )
 from autogen_core.components.tools import FunctionTool, Tool
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field, model_validator
 
 from .. import EVENT_LOGGER_NAME
 from ..messages import (
     ChatMessage,
+    HandoffMessage,
     StopMessage,
     TextMessage,
 )
@@ -31,6 +32,9 @@ event_logger = logging.getLogger(EVENT_LOGGER_NAME)
 class ToolCallEvent(BaseModel):
     """A tool call event."""
 
+    source: str
+    """The source of the event."""
+
     tool_calls: List[FunctionCall]
     """The tool call message."""
 
@@ -40,12 +44,58 @@ class ToolCallEvent(BaseModel):
 class ToolCallResultEvent(BaseModel):
     """A tool call result event."""
 
+    source: str
+    """The source of the event."""
+
     tool_call_results: List[FunctionExecutionResult]
     """The tool call result message."""
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
 
+class Handoff(BaseModel):
+    """Handoff configuration for :class:`AssistantAgent`."""
+
+    target: str
+    """The name of the target agent to handoff to."""
+
+    description: str = Field(default=None)
+    """The description of the handoff such as the condition under which it should happen and the target agent's ability.
+    If not provided, it is generated from the target agent's name."""
+
+    name: str = Field(default=None)
+    """The name of this handoff configuration. If not provided, it is generated from the target agent's name."""
+
+    message: str = Field(default=None)
+    """The message to the target agent.
+    If not provided, it is generated from the target agent's name."""
+
+    @model_validator(mode="before")
+    @classmethod
+    def set_defaults(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        if values.get("description") is None:
+            values["description"] = f"Handoff to {values['target']}."
+        if values.get("name") is None:
+            values["name"] = f"transfer_to_{values['target']}".lower()
+        else:
+            name = values["name"]
+            if not isinstance(name, str):
+                raise ValueError(f"Handoff name must be a string: {values['name']}")
+            # Check if name is a valid identifier.
+            if not name.isidentifier():
+                raise ValueError(f"Handoff name must be a valid identifier: {values['name']}")
+        if values.get("message") is None:
+            values["message"] = (
+                f"Transferred to {values['target']}, adopting the role of {values['target']} immediately."
+            )
+        return values
+
+    @property
+    def handoff_tool(self) -> Tool:
+        """Create a handoff tool from this handoff configuration."""
+        return FunctionTool(lambda: self.message, name=self.name, description=self.description)
+
+
 class AssistantAgent(BaseChatAgent):
     """An agent that provides assistance with tool use.
 
@@ -55,8 +105,52 @@ class AssistantAgent(BaseChatAgent):
         name (str): The name of the agent.
         model_client (ChatCompletionClient): The model client to use for inference.
         tools (List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional): The tools to register with the agent.
+        handoffs (List[Handoff | str] | None, optional): The handoff configurations for the agent, allowing it to transfer to other agents by responding with a HandoffMessage.
+            If a handoff is a string, it should represent the target agent's name.
         description (str, optional): The description of the agent.
         system_message (str, optional): The system message for the model.
+
+    Raises:
+        ValueError: If tool names are not unique.
+        ValueError: If handoff names are not unique.
+        ValueError: If handoff names are not unique from tool names.
+
+    Examples:
+
+        The following example demonstrates how to create an assistant agent with
+        a model client and generate a response to a simple task.
+
+        .. code-block:: python
+
+            from autogen_ext.models import OpenAIChatCompletionClient
+            from autogen_agentchat.agents import AssistantAgent
+            from autogen_agentchat.task import MaxMessageTermination
+
+            model_client = OpenAIChatCompletionClient(model="gpt-4o")
+            agent = AssistantAgent(name="assistant", model_client=model_client)
+
+            await agent.run("What is the capital of France?", termination_condition=MaxMessageTermination(2))
+
+
+        The following example demonstrates how to create an assistant agent with
+        a model client and a tool, and generate a response to a simple task using the tool.
+
+        .. code-block:: python
+
+            from autogen_ext.models import OpenAIChatCompletionClient
+            from autogen_agentchat.agents import AssistantAgent
+            from autogen_agentchat.task import MaxMessageTermination
+
+
+            async def get_current_time() -> str:
+                return "The current time is 12:00 PM."
+
+
+            model_client = OpenAIChatCompletionClient(model="gpt-4o")
+            agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time])
+
+            await agent.run("What is the current time?", termination_condition=MaxMessageTermination(3))
+
     """
 
     def __init__(
@@ -65,6 +159,7 @@ class AssistantAgent(BaseChatAgent):
         model_client: ChatCompletionClient,
         *,
         tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
+        handoffs: List[Handoff | str] | None = None,
         description: str = "An agent that provides assistance with ability to use tools.",
         system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed.",
     ):
@@ -84,33 +179,71 @@ class AssistantAgent(BaseChatAgent):
                 self._tools.append(FunctionTool(tool, description=description))
             else:
                 raise ValueError(f"Unsupported tool type: {type(tool)}")
+        # Check if tool names are unique.
+        tool_names = [tool.name for tool in self._tools]
+        if len(tool_names) != len(set(tool_names)):
+            raise ValueError(f"Tool names must be unique: {tool_names}")
+        # Handoff tools.
+        self._handoff_tools: List[Tool] = []
+        self._handoffs: Dict[str, Handoff] = {}
+        if handoffs is not None:
+            for handoff in handoffs:
+                if isinstance(handoff, str):
+                    handoff = Handoff(target=handoff)
+                if isinstance(handoff, Handoff):
+                    self._handoff_tools.append(handoff.handoff_tool)
+                    self._handoffs[handoff.name] = handoff
+                else:
+                    raise ValueError(f"Unsupported handoff type: {type(handoff)}")
+        # Check if handoff tool names are unique.
+        handoff_tool_names = [tool.name for tool in self._handoff_tools]
+        if len(handoff_tool_names) != len(set(handoff_tool_names)):
+            raise ValueError(f"Handoff names must be unique: {handoff_tool_names}")
+        # Check if handoff tool names not in tool names.
+        if any(name in tool_names for name in handoff_tool_names):
+            raise ValueError(
+                f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
+            )
         self._model_context: List[LLMMessage] = []
 
     async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
         # Add messages to the model context.
         for msg in messages:
-            # TODO: add special handling for handoff messages
             self._model_context.append(UserMessage(content=msg.content, source=msg.source))
 
         # Generate an inference result based on the current model context.
         llm_messages = self._system_messages + self._model_context
-        result = await self._model_client.create(llm_messages, tools=self._tools, cancellation_token=cancellation_token)
+        result = await self._model_client.create(
+            llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
+        )
 
         # Add the response to the model context.
         self._model_context.append(AssistantMessage(content=result.content, source=self.name))
 
         # Run tool calls until the model produces a string response.
         while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
-            event_logger.debug(ToolCallEvent(tool_calls=result.content))
+            event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
            # Execute the tool calls.
             results = await asyncio.gather(
                 *[self._execute_tool_call(call, cancellation_token) for call in result.content]
             )
-            event_logger.debug(ToolCallResultEvent(tool_call_results=results))
+            event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
             self._model_context.append(FunctionExecutionResultMessage(content=results))
 
+            # Detect handoff requests.
+            handoffs: List[Handoff] = []
+            for call in result.content:
+                if call.name in self._handoffs:
+                    handoffs.append(self._handoffs[call.name])
+            if len(handoffs) > 0:
+                if len(handoffs) > 1:
+                    raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
+                # Respond with a handoff message.
+                return HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name)
+
             # Generate an inference result based on the current model context.
             result = await self._model_client.create(
-                self._model_context, tools=self._tools, cancellation_token=cancellation_token
+                self._model_context, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
             )
             self._model_context.append(AssistantMessage(content=result.content, source=self.name))
 
@@ -127,9 +260,9 @@
     ) -> FunctionExecutionResult:
         """Execute a tool call and return the result."""
         try:
-            if not self._tools:
+            if not self._tools + self._handoff_tools:
                 raise ValueError("No tools are available.")
-            tool = next((t for t in self._tools if t.name == tool_call.name), None)
+            tool = next((t for t in self._tools + self._handoff_tools if t.name == tool_call.name), None)
             if tool is None:
                 raise ValueError(f"The tool '{tool_call.name}' is not available.")
             arguments = json.loads(tool_call.arguments)
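
For reference, the `set_defaults` validator above derives any omitted `Handoff` fields from `target`; a short sketch of the resulting defaults (assertion values taken from the validator's f-strings):

    from autogen_agentchat.agents import Handoff

    handoff = Handoff(target="agent2")
    assert handoff.name == "transfer_to_agent2"
    assert handoff.description == "Handoff to agent2."
    assert handoff.message == "Transferred to agent2, adopting the role of agent2 immediately."

The `handoff_tool` property wraps `lambda: self.message` in a `FunctionTool`, so invoking the tool simply returns the handoff message for the model context.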
@@ -35,8 +35,11 @@ class StopMessage(BaseMessage):
 class HandoffMessage(BaseMessage):
     """A message requesting handoff of a conversation to another agent."""
 
+    target: str
+    """The name of the target agent to handoff to."""
+
     content: str
-    """The agent name to handoff the conversation to."""
+    """The handoff message to the target agent."""
 
 
 ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage
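
With this change, routing information moves out of `content` into the new required `target` field, leaving `content` as a free-form note to the target agent; `SwarmGroupChatManager` (next hunk) selects the speaker from `target`. A sketch of constructing the message under the new schema:

    from autogen_agentchat.messages import HandoffMessage

    msg = HandoffMessage(
        content="Transferred to Bob.",  # free-form message delivered to the target
        target="Bob",  # agent name used for speaker selection
        source="Alice",
    )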
@@ -37,7 +37,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
     async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str:
         """Select a speaker from the participants based on handoff message."""
         if len(thread) > 0 and isinstance(thread[-1].agent_message, HandoffMessage):
-            self._current_speaker = thread[-1].agent_message.content
+            self._current_speaker = thread[-1].agent_message.target
             if self._current_speaker not in self._participant_topic_types:
                 raise ValueError("The selected speaker in the handoff message is not a participant.")
             event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=self._current_speaker, source=self.id))
@@ -47,7 +47,40 @@ class SwarmGroupChatManager(BaseGroupChatManager):
 
 
 class Swarm(BaseGroupChat):
-    """(Experimental) A group chat that selects the next speaker based on handoff message only."""
+    """A group chat team that selects the next speaker based on handoff message only.
+
+    The first participant in the list of participants is the initial speaker.
+    The next speaker is selected based on the :class:`~autogen_agentchat.messages.HandoffMessage` message
+    sent by the current speaker. If no handoff message is sent, the current speaker
+    continues to be the speaker.
+
+    Args:
+        participants (List[ChatAgent]): The agents participating in the group chat. The first agent in the list is the initial speaker.
+
+    Examples:
+
+        .. code-block:: python
+
+            from autogen_ext.models import OpenAIChatCompletionClient
+            from autogen_agentchat.agents import AssistantAgent
+            from autogen_agentchat.teams import Swarm
+            from autogen_agentchat.task import MaxMessageTermination
+
+            model_client = OpenAIChatCompletionClient(model="gpt-4o")
+
+            agent1 = AssistantAgent(
+                "Alice",
+                model_client=model_client,
+                handoffs=["Bob"],
+                system_message="You are Alice and you only answer questions about yourself.",
+            )
+            agent2 = AssistantAgent(
+                "Bob", model_client=model_client, system_message="You are Bob and your birthday is on 1st January."
+            )
+
+            team = Swarm([agent1, agent2])
+            await team.run("What is bob's birthday?", termination_condition=MaxMessageTermination(3))
+    """
 
     def __init__(self, participants: List[ChatAgent]):
         super().__init__(participants, group_chat_manager_class=SwarmGroupChatManager)
@@ -1,10 +1,14 @@
 import asyncio
 import json
+import logging
 from typing import Any, AsyncGenerator, List
 
 import pytest
-from autogen_agentchat.agents import AssistantAgent
-from autogen_agentchat.messages import StopMessage, TextMessage
+from autogen_agentchat import EVENT_LOGGER_NAME
+from autogen_agentchat.agents import AssistantAgent, Handoff
+from autogen_agentchat.logging import FileLogHandler
+from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage
+from autogen_core.base import CancellationToken
 from autogen_core.components.tools import FunctionTool
 from autogen_ext.models import OpenAIChatCompletionClient
 from openai.resources.chat.completions import AsyncCompletions
@@ -14,6 +18,10 @@ from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
 from openai.types.completion_usage import CompletionUsage
 
+logger = logging.getLogger(EVENT_LOGGER_NAME)
+logger.setLevel(logging.DEBUG)
+logger.addHandler(FileLogHandler("test_assistant_agent.log"))
+
 
 class _MockChatCompletion:
     def __init__(self, chat_completions: List[ChatCompletion]) -> None:
@@ -107,3 +115,51 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     assert isinstance(result.messages[0], TextMessage)
     assert isinstance(result.messages[1], TextMessage)
     assert isinstance(result.messages[2], StopMessage)
+
+
+@pytest.mark.asyncio
+async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
+    handoff = Handoff(target="agent2")
+    model = "gpt-4o-2024-05-13"
+    chat_completions = [
+        ChatCompletion(
+            id="id1",
+            choices=[
+                Choice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=None,
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="1",
+                                type="function",
+                                function=Function(
+                                    name=handoff.name,
+                                    arguments=json.dumps({}),
+                                ),
+                            )
+                        ],
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+        ),
+    ]
+    mock = _MockChatCompletion(chat_completions)
+    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
+    tool_use_agent = AssistantAgent(
+        "tool_use_agent",
+        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
+        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
+        handoffs=[handoff],
+    )
+    response = await tool_use_agent.on_messages(
+        [TextMessage(content="task", source="user")], cancellation_token=CancellationToken()
+    )
+    assert isinstance(response, HandoffMessage)
+    assert response.target == "agent2"
@@ -10,6 +10,7 @@ from autogen_agentchat.agents import (
     AssistantAgent,
     BaseChatAgent,
     CodeExecutorAgent,
+    Handoff,
 )
 from autogen_agentchat.logging import FileLogHandler
 from autogen_agentchat.messages import (
@@ -415,11 +416,11 @@ class _HandOffAgent(BaseChatAgent):
         self._next_agent = next_agent
 
     async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
-        return HandoffMessage(content=self._next_agent, source=self.name)
+        return HandoffMessage(content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name)
 
 
 @pytest.mark.asyncio
-async def test_swarm() -> None:
+async def test_swarm_handoff() -> None:
     first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
     second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
     third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
@@ -428,8 +429,81 @@ async def test_swarm() -> None:
     result = await team.run("task", termination_condition=MaxMessageTermination(6))
     assert len(result.messages) == 6
     assert result.messages[0].content == "task"
-    assert result.messages[1].content == "third_agent"
-    assert result.messages[2].content == "first_agent"
-    assert result.messages[3].content == "second_agent"
-    assert result.messages[4].content == "third_agent"
-    assert result.messages[5].content == "first_agent"
+    assert result.messages[1].content == "Transferred to third_agent."
+    assert result.messages[2].content == "Transferred to first_agent."
+    assert result.messages[3].content == "Transferred to second_agent."
+    assert result.messages[4].content == "Transferred to third_agent."
+    assert result.messages[5].content == "Transferred to first_agent."
+
+
+@pytest.mark.asyncio
+async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
+    model = "gpt-4o-2024-05-13"
+    chat_completions = [
+        ChatCompletion(
+            id="id1",
+            choices=[
+                Choice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=None,
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="1",
+                                type="function",
+                                function=Function(
+                                    name="handoff_to_agent2",
+                                    arguments=json.dumps({}),
+                                ),
+                            )
+                        ],
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+        ),
+        ChatCompletion(
+            id="id2",
+            choices=[
+                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+        ),
+        ChatCompletion(
+            id="id2",
+            choices=[
+                Choice(
+                    finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+        ),
+    ]
+    mock = _MockChatCompletion(chat_completions)
+    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
+
+    agent1 = AssistantAgent(
+        "agent1",
+        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
+        handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")],
+    )
+    agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
+    team = Swarm([agent1, agent2])
+    result = await team.run("task", termination_condition=StopMessageTermination())
+    assert len(result.messages) == 5
+    assert result.messages[0].content == "task"
+    assert result.messages[1].content == "handoff to agent2"
+    assert result.messages[2].content == "Transferred to agent1."
+    assert result.messages[3].content == "Hello"
+    assert result.messages[4].content == "TERMINATE"
@@ -32,7 +32,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from autogen_ext.models import OpenAIChatCompletionClient, UserMessage\n",
+    "from autogen_core.components.models import UserMessage\n",
+    "from autogen_ext.models import OpenAIChatCompletionClient\n",
     "\n",
     "# Create an OpenAI model client.\n",
     "model_client = OpenAIChatCompletionClient(\n",
@@ -500,7 +501,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "autogen_core",
+   "display_name": ".venv",
    "language": "python",
    "name": "python3"
   },
@@ -514,7 +515,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.9"
+   "version": "3.11.5"
   }
  },
  "nbformat": 4,