mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-10 06:44:20 +00:00
Formalize ChatAgent response as a dataclass with inner messages (#3990)
This commit is contained in:
parent
e63fd17ed5
commit
3d51ab76ae
@ -18,12 +18,16 @@ from autogen_core.components.tools import FunctionTool, Tool
|
|||||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||||
|
|
||||||
from .. import EVENT_LOGGER_NAME
|
from .. import EVENT_LOGGER_NAME
|
||||||
|
from ..base import Response
|
||||||
from ..messages import (
|
from ..messages import (
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
HandoffMessage,
|
HandoffMessage,
|
||||||
|
InnerMessage,
|
||||||
ResetMessage,
|
ResetMessage,
|
||||||
StopMessage,
|
StopMessage,
|
||||||
TextMessage,
|
TextMessage,
|
||||||
|
ToolCallMessage,
|
||||||
|
ToolCallResultMessages,
|
||||||
)
|
)
|
||||||
from ._base_chat_agent import BaseChatAgent
|
from ._base_chat_agent import BaseChatAgent
|
||||||
|
|
||||||
@ -214,7 +218,7 @@ class AssistantAgent(BaseChatAgent):
|
|||||||
return [TextMessage, HandoffMessage, StopMessage]
|
return [TextMessage, HandoffMessage, StopMessage]
|
||||||
return [TextMessage, StopMessage]
|
return [TextMessage, StopMessage]
|
||||||
|
|
||||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||||
# Add messages to the model context.
|
# Add messages to the model context.
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if isinstance(msg, ResetMessage):
|
if isinstance(msg, ResetMessage):
|
||||||
@ -222,6 +226,9 @@ class AssistantAgent(BaseChatAgent):
|
|||||||
else:
|
else:
|
||||||
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
|
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
|
||||||
|
|
||||||
|
# Inner messages.
|
||||||
|
inner_messages: List[InnerMessage] = []
|
||||||
|
|
||||||
# Generate an inference result based on the current model context.
|
# Generate an inference result based on the current model context.
|
||||||
llm_messages = self._system_messages + self._model_context
|
llm_messages = self._system_messages + self._model_context
|
||||||
result = await self._model_client.create(
|
result = await self._model_client.create(
|
||||||
@ -234,12 +241,16 @@ class AssistantAgent(BaseChatAgent):
|
|||||||
# Run tool calls until the model produces a string response.
|
# Run tool calls until the model produces a string response.
|
||||||
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
|
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
|
||||||
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
|
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
|
||||||
|
# Add the tool call message to the output.
|
||||||
|
inner_messages.append(ToolCallMessage(content=result.content, source=self.name))
|
||||||
|
|
||||||
# Execute the tool calls.
|
# Execute the tool calls.
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
|
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
|
||||||
)
|
)
|
||||||
event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
|
event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
|
||||||
self._model_context.append(FunctionExecutionResultMessage(content=results))
|
self._model_context.append(FunctionExecutionResultMessage(content=results))
|
||||||
|
inner_messages.append(ToolCallResultMessages(content=results, source=self.name))
|
||||||
|
|
||||||
# Detect handoff requests.
|
# Detect handoff requests.
|
||||||
handoffs: List[Handoff] = []
|
handoffs: List[Handoff] = []
|
||||||
@ -249,8 +260,13 @@ class AssistantAgent(BaseChatAgent):
|
|||||||
if len(handoffs) > 0:
|
if len(handoffs) > 0:
|
||||||
if len(handoffs) > 1:
|
if len(handoffs) > 1:
|
||||||
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
|
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
|
||||||
# Respond with a handoff message.
|
# Return the output messages to signal the handoff.
|
||||||
return HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name)
|
return Response(
|
||||||
|
chat_message=HandoffMessage(
|
||||||
|
content=handoffs[0].message, target=handoffs[0].target, source=self.name
|
||||||
|
),
|
||||||
|
inner_messages=inner_messages,
|
||||||
|
)
|
||||||
|
|
||||||
# Generate an inference result based on the current model context.
|
# Generate an inference result based on the current model context.
|
||||||
result = await self._model_client.create(
|
result = await self._model_client.create(
|
||||||
@ -262,9 +278,13 @@ class AssistantAgent(BaseChatAgent):
|
|||||||
# Detect stop request.
|
# Detect stop request.
|
||||||
request_stop = "terminate" in result.content.strip().lower()
|
request_stop = "terminate" in result.content.strip().lower()
|
||||||
if request_stop:
|
if request_stop:
|
||||||
return StopMessage(content=result.content, source=self.name)
|
return Response(
|
||||||
|
chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
|
||||||
|
)
|
||||||
|
|
||||||
return TextMessage(content=result.content, source=self.name)
|
return Response(
|
||||||
|
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
|
||||||
|
)
|
||||||
|
|
||||||
async def _execute_tool_call(
|
async def _execute_tool_call(
|
||||||
self, tool_call: FunctionCall, cancellation_token: CancellationToken
|
self, tool_call: FunctionCall, cancellation_token: CancellationToken
|
||||||
|
|||||||
@ -3,9 +3,8 @@ from typing import List, Sequence
|
|||||||
|
|
||||||
from autogen_core.base import CancellationToken
|
from autogen_core.base import CancellationToken
|
||||||
|
|
||||||
from ..base import ChatAgent, TaskResult, TerminationCondition
|
from ..base import ChatAgent, Response, TaskResult, TerminationCondition
|
||||||
from ..messages import ChatMessage
|
from ..messages import ChatMessage, InnerMessage, TextMessage
|
||||||
from ..teams import RoundRobinGroupChat
|
|
||||||
|
|
||||||
|
|
||||||
class BaseChatAgent(ChatAgent, ABC):
|
class BaseChatAgent(ChatAgent, ABC):
|
||||||
@ -37,8 +36,8 @@ class BaseChatAgent(ChatAgent, ABC):
|
|||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||||
"""Handle incoming messages and return a response message."""
|
"""Handles incoming messages and returns a response."""
|
||||||
...
|
...
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
@ -49,10 +48,12 @@ class BaseChatAgent(ChatAgent, ABC):
|
|||||||
termination_condition: TerminationCondition | None = None,
|
termination_condition: TerminationCondition | None = None,
|
||||||
) -> TaskResult:
|
) -> TaskResult:
|
||||||
"""Run the agent with the given task and return the result."""
|
"""Run the agent with the given task and return the result."""
|
||||||
group_chat = RoundRobinGroupChat(participants=[self])
|
if cancellation_token is None:
|
||||||
result = await group_chat.run(
|
cancellation_token = CancellationToken()
|
||||||
task=task,
|
first_message = TextMessage(content=task, source="user")
|
||||||
cancellation_token=cancellation_token,
|
response = await self.on_messages([first_message], cancellation_token)
|
||||||
termination_condition=termination_condition,
|
messages: List[InnerMessage | ChatMessage] = [first_message]
|
||||||
)
|
if response.inner_messages is not None:
|
||||||
return result
|
messages += response.inner_messages
|
||||||
|
messages.append(response.chat_message)
|
||||||
|
return TaskResult(messages=messages)
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from typing import List, Sequence
|
|||||||
from autogen_core.base import CancellationToken
|
from autogen_core.base import CancellationToken
|
||||||
from autogen_core.components.code_executor import CodeBlock, CodeExecutor, extract_markdown_code_blocks
|
from autogen_core.components.code_executor import CodeBlock, CodeExecutor, extract_markdown_code_blocks
|
||||||
|
|
||||||
|
from ..base import Response
|
||||||
from ..messages import ChatMessage, TextMessage
|
from ..messages import ChatMessage, TextMessage
|
||||||
from ._base_chat_agent import BaseChatAgent
|
from ._base_chat_agent import BaseChatAgent
|
||||||
|
|
||||||
@ -25,7 +26,7 @@ class CodeExecutorAgent(BaseChatAgent):
|
|||||||
"""The types of messages that the code executor agent produces."""
|
"""The types of messages that the code executor agent produces."""
|
||||||
return [TextMessage]
|
return [TextMessage]
|
||||||
|
|
||||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||||
# Extract code blocks from the messages.
|
# Extract code blocks from the messages.
|
||||||
code_blocks: List[CodeBlock] = []
|
code_blocks: List[CodeBlock] = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
@ -34,6 +35,6 @@ class CodeExecutorAgent(BaseChatAgent):
|
|||||||
if code_blocks:
|
if code_blocks:
|
||||||
# Execute the code blocks.
|
# Execute the code blocks.
|
||||||
result = await self._code_executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token)
|
result = await self._code_executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token)
|
||||||
return TextMessage(content=result.output, source=self.name)
|
return Response(chat_message=TextMessage(content=result.output, source=self.name))
|
||||||
else:
|
else:
|
||||||
return TextMessage(content="No code blocks found in the thread.", source=self.name)
|
return Response(chat_message=TextMessage(content="No code blocks found in the thread.", source=self.name))
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
from ._chat_agent import ChatAgent
|
from ._chat_agent import ChatAgent, Response
|
||||||
from ._task import TaskResult, TaskRunner
|
from ._task import TaskResult, TaskRunner
|
||||||
from ._team import Team
|
from ._team import Team
|
||||||
from ._termination import TerminatedException, TerminationCondition
|
from ._termination import TerminatedException, TerminationCondition
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ChatAgent",
|
"ChatAgent",
|
||||||
|
"Response",
|
||||||
"Team",
|
"Team",
|
||||||
"TerminatedException",
|
"TerminatedException",
|
||||||
"TerminationCondition",
|
"TerminationCondition",
|
||||||
|
|||||||
@ -1,12 +1,24 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
from typing import List, Protocol, Sequence, runtime_checkable
|
from typing import List, Protocol, Sequence, runtime_checkable
|
||||||
|
|
||||||
from autogen_core.base import CancellationToken
|
from autogen_core.base import CancellationToken
|
||||||
|
|
||||||
from ..messages import ChatMessage
|
from ..messages import ChatMessage, InnerMessage
|
||||||
from ._task import TaskResult, TaskRunner
|
from ._task import TaskResult, TaskRunner
|
||||||
from ._termination import TerminationCondition
|
from ._termination import TerminationCondition
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(kw_only=True)
|
||||||
|
class Response:
|
||||||
|
"""A response from calling :meth:`ChatAgent.on_messages`."""
|
||||||
|
|
||||||
|
chat_message: ChatMessage
|
||||||
|
"""A chat message produced by the agent as the response."""
|
||||||
|
|
||||||
|
inner_messages: List[InnerMessage] | None = None
|
||||||
|
"""Inner messages produced by the agent."""
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class ChatAgent(TaskRunner, Protocol):
|
class ChatAgent(TaskRunner, Protocol):
|
||||||
"""Protocol for a chat agent."""
|
"""Protocol for a chat agent."""
|
||||||
@ -29,8 +41,8 @@ class ChatAgent(TaskRunner, Protocol):
|
|||||||
"""The types of messages that the agent produces."""
|
"""The types of messages that the agent produces."""
|
||||||
...
|
...
|
||||||
|
|
||||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||||
"""Handle incoming messages and return a response message."""
|
"""Handles incoming messages and returns a response."""
|
||||||
...
|
...
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from typing import Protocol, Sequence
|
|||||||
|
|
||||||
from autogen_core.base import CancellationToken
|
from autogen_core.base import CancellationToken
|
||||||
|
|
||||||
from ..messages import ChatMessage
|
from ..messages import ChatMessage, InnerMessage
|
||||||
from ._termination import TerminationCondition
|
from ._termination import TerminationCondition
|
||||||
|
|
||||||
|
|
||||||
@ -11,7 +11,7 @@ from ._termination import TerminationCondition
|
|||||||
class TaskResult:
|
class TaskResult:
|
||||||
"""Result of running a task."""
|
"""Result of running a task."""
|
||||||
|
|
||||||
messages: Sequence[ChatMessage]
|
messages: Sequence[InnerMessage | ChatMessage]
|
||||||
"""Messages produced by the task."""
|
"""Messages produced by the task."""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from autogen_core.components import Image
|
from autogen_core.components import FunctionCall, Image
|
||||||
|
from autogen_core.components.models import FunctionExecutionResult
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
@ -49,8 +50,26 @@ class ResetMessage(BaseMessage):
|
|||||||
"""The content for the reset message."""
|
"""The content for the reset message."""
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallMessage(BaseMessage):
|
||||||
|
"""A message signaling the use of tools."""
|
||||||
|
|
||||||
|
content: List[FunctionCall]
|
||||||
|
"""The tool calls."""
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallResultMessages(BaseMessage):
|
||||||
|
"""A message signaling the results of tool calls."""
|
||||||
|
|
||||||
|
content: List[FunctionExecutionResult]
|
||||||
|
"""The tool call results."""
|
||||||
|
|
||||||
|
|
||||||
|
InnerMessage = ToolCallMessage | ToolCallResultMessages
|
||||||
|
"""Messages for intra-agent monologues."""
|
||||||
|
|
||||||
|
|
||||||
ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ResetMessage
|
ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ResetMessage
|
||||||
"""A message used by agents in a team."""
|
"""Messages for agent-to-agent communication."""
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -60,5 +79,7 @@ __all__ = [
|
|||||||
"StopMessage",
|
"StopMessage",
|
||||||
"HandoffMessage",
|
"HandoffMessage",
|
||||||
"ResetMessage",
|
"ResetMessage",
|
||||||
|
"ToolCallMessage",
|
||||||
|
"ToolCallResultMessages",
|
||||||
"ChatMessage",
|
"ChatMessage",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from autogen_core.base import (
|
|||||||
from autogen_core.components import ClosureAgent, TypeSubscription
|
from autogen_core.components import ClosureAgent, TypeSubscription
|
||||||
|
|
||||||
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
|
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
|
||||||
from ...messages import ChatMessage, TextMessage
|
from ...messages import ChatMessage, InnerMessage, TextMessage
|
||||||
from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent
|
from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent
|
||||||
from ._base_group_chat_manager import BaseGroupChatManager
|
from ._base_group_chat_manager import BaseGroupChatManager
|
||||||
from ._chat_agent_container import ChatAgentContainer
|
from ._chat_agent_container import ChatAgentContainer
|
||||||
@ -56,12 +56,13 @@ class BaseGroupChat(Team, ABC):
|
|||||||
def _create_participant_factory(
|
def _create_participant_factory(
|
||||||
self,
|
self,
|
||||||
parent_topic_type: str,
|
parent_topic_type: str,
|
||||||
|
output_topic_type: str,
|
||||||
agent: ChatAgent,
|
agent: ChatAgent,
|
||||||
) -> Callable[[], ChatAgentContainer]:
|
) -> Callable[[], ChatAgentContainer]:
|
||||||
def _factory() -> ChatAgentContainer:
|
def _factory() -> ChatAgentContainer:
|
||||||
id = AgentInstantiationContext.current_agent_id()
|
id = AgentInstantiationContext.current_agent_id()
|
||||||
assert id == AgentId(type=agent.name, key=self._team_id)
|
assert id == AgentId(type=agent.name, key=self._team_id)
|
||||||
container = ChatAgentContainer(parent_topic_type, agent)
|
container = ChatAgentContainer(parent_topic_type, output_topic_type, agent)
|
||||||
assert container.id == id
|
assert container.id == id
|
||||||
return container
|
return container
|
||||||
|
|
||||||
@ -85,6 +86,7 @@ class BaseGroupChat(Team, ABC):
|
|||||||
group_chat_manager_topic_type = group_chat_manager_agent_type.type
|
group_chat_manager_topic_type = group_chat_manager_agent_type.type
|
||||||
group_topic_type = "round_robin_group_topic"
|
group_topic_type = "round_robin_group_topic"
|
||||||
team_topic_type = "team_topic"
|
team_topic_type = "team_topic"
|
||||||
|
output_topic_type = "output_topic"
|
||||||
|
|
||||||
# Register participants.
|
# Register participants.
|
||||||
participant_topic_types: List[str] = []
|
participant_topic_types: List[str] = []
|
||||||
@ -97,7 +99,7 @@ class BaseGroupChat(Team, ABC):
|
|||||||
await ChatAgentContainer.register(
|
await ChatAgentContainer.register(
|
||||||
runtime,
|
runtime,
|
||||||
type=agent_type,
|
type=agent_type,
|
||||||
factory=self._create_participant_factory(group_topic_type, participant),
|
factory=self._create_participant_factory(group_topic_type, output_topic_type, participant),
|
||||||
)
|
)
|
||||||
# Add subscriptions for the participant.
|
# Add subscriptions for the participant.
|
||||||
await runtime.add_subscription(TypeSubscription(topic_type=topic_type, agent_type=agent_type))
|
await runtime.add_subscription(TypeSubscription(topic_type=topic_type, agent_type=agent_type))
|
||||||
@ -129,22 +131,22 @@ class BaseGroupChat(Team, ABC):
|
|||||||
TypeSubscription(topic_type=team_topic_type, agent_type=group_chat_manager_agent_type.type)
|
TypeSubscription(topic_type=team_topic_type, agent_type=group_chat_manager_agent_type.type)
|
||||||
)
|
)
|
||||||
|
|
||||||
group_chat_messages: List[ChatMessage] = []
|
output_messages: List[InnerMessage | ChatMessage] = []
|
||||||
|
|
||||||
async def collect_group_chat_messages(
|
async def collect_output_messages(
|
||||||
_runtime: AgentRuntime,
|
_runtime: AgentRuntime,
|
||||||
id: AgentId,
|
id: AgentId,
|
||||||
message: GroupChatPublishEvent,
|
message: InnerMessage | ChatMessage,
|
||||||
ctx: MessageContext,
|
ctx: MessageContext,
|
||||||
) -> None:
|
) -> None:
|
||||||
group_chat_messages.append(message.agent_message)
|
output_messages.append(message)
|
||||||
|
|
||||||
await ClosureAgent.register(
|
await ClosureAgent.register(
|
||||||
runtime,
|
runtime,
|
||||||
type="collect_group_chat_messages",
|
type="collect_output_messages",
|
||||||
closure=collect_group_chat_messages,
|
closure=collect_output_messages,
|
||||||
subscriptions=lambda: [
|
subscriptions=lambda: [
|
||||||
TypeSubscription(topic_type=group_topic_type, agent_type="collect_group_chat_messages"),
|
TypeSubscription(topic_type=output_topic_type, agent_type="collect_output_messages"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -154,8 +156,10 @@ class BaseGroupChat(Team, ABC):
|
|||||||
# Run the team by publishing the task to the team topic and then requesting the result.
|
# Run the team by publishing the task to the team topic and then requesting the result.
|
||||||
team_topic_id = TopicId(type=team_topic_type, source=self._team_id)
|
team_topic_id = TopicId(type=team_topic_type, source=self._team_id)
|
||||||
group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
|
group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
|
||||||
|
first_chat_message = TextMessage(content=task, source="user")
|
||||||
|
output_messages.append(first_chat_message)
|
||||||
await runtime.publish_message(
|
await runtime.publish_message(
|
||||||
GroupChatPublishEvent(agent_message=TextMessage(content=task, source="user")),
|
GroupChatPublishEvent(agent_message=first_chat_message),
|
||||||
topic_id=team_topic_id,
|
topic_id=team_topic_id,
|
||||||
)
|
)
|
||||||
await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id)
|
await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id)
|
||||||
@ -164,4 +168,4 @@ class BaseGroupChat(Team, ABC):
|
|||||||
await runtime.stop_when_idle()
|
await runtime.stop_when_idle()
|
||||||
|
|
||||||
# Return the result.
|
# Return the result.
|
||||||
return TaskResult(messages=group_chat_messages)
|
return TaskResult(messages=output_messages)
|
||||||
|
|||||||
@ -16,12 +16,14 @@ class ChatAgentContainer(SequentialRoutedAgent):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
parent_topic_type (str): The topic type of the parent orchestrator.
|
parent_topic_type (str): The topic type of the parent orchestrator.
|
||||||
|
output_topic_type (str): The topic type for the output.
|
||||||
agent (ChatAgent): The agent to delegate message handling to.
|
agent (ChatAgent): The agent to delegate message handling to.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, parent_topic_type: str, agent: ChatAgent) -> None:
|
def __init__(self, parent_topic_type: str, output_topic_type: str, agent: ChatAgent) -> None:
|
||||||
super().__init__(description=agent.description)
|
super().__init__(description=agent.description)
|
||||||
self._parent_topic_type = parent_topic_type
|
self._parent_topic_type = parent_topic_type
|
||||||
|
self._output_topic_type = output_topic_type
|
||||||
self._agent = agent
|
self._agent = agent
|
||||||
self._message_buffer: List[ChatMessage] = []
|
self._message_buffer: List[ChatMessage] = []
|
||||||
|
|
||||||
@ -36,18 +38,27 @@ class ChatAgentContainer(SequentialRoutedAgent):
|
|||||||
to the delegate agent and publish the response."""
|
to the delegate agent and publish the response."""
|
||||||
# Pass the messages in the buffer to the delegate agent.
|
# Pass the messages in the buffer to the delegate agent.
|
||||||
response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
|
response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
|
||||||
if not any(isinstance(response, msg_type) for msg_type in self._agent.produced_message_types):
|
if not any(isinstance(response.chat_message, msg_type) for msg_type in self._agent.produced_message_types):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The agent {self._agent.name} produced an unexpected message type: {type(response)}. "
|
f"The agent {self._agent.name} produced an unexpected message type: {type(response)}. "
|
||||||
f"Expected one of: {self._agent.produced_message_types}"
|
f"Expected one of: {self._agent.produced_message_types}. "
|
||||||
|
f"Check the agent's produced_message_types property."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Publish inner messages to the output topic.
|
||||||
|
if response.inner_messages is not None:
|
||||||
|
for inner_message in response.inner_messages:
|
||||||
|
await self.publish_message(inner_message, topic_id=DefaultTopicId(type=self._output_topic_type))
|
||||||
|
|
||||||
# Publish the response.
|
# Publish the response.
|
||||||
self._message_buffer.clear()
|
self._message_buffer.clear()
|
||||||
await self.publish_message(
|
await self.publish_message(
|
||||||
GroupChatPublishEvent(agent_message=response, source=self.id),
|
GroupChatPublishEvent(agent_message=response.chat_message, source=self.id),
|
||||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Publish the response to the output topic.
|
||||||
|
await self.publish_message(response.chat_message, topic_id=DefaultTopicId(type=self._output_topic_type))
|
||||||
|
|
||||||
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
|
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
|
||||||
raise ValueError(f"Unhandled message in agent container: {type(message)}")
|
raise ValueError(f"Unhandled message in agent container: {type(message)}")
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import pytest
|
|||||||
from autogen_agentchat import EVENT_LOGGER_NAME
|
from autogen_agentchat import EVENT_LOGGER_NAME
|
||||||
from autogen_agentchat.agents import AssistantAgent, Handoff
|
from autogen_agentchat.agents import AssistantAgent, Handoff
|
||||||
from autogen_agentchat.logging import FileLogHandler
|
from autogen_agentchat.logging import FileLogHandler
|
||||||
from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage
|
from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessages
|
||||||
from autogen_core.base import CancellationToken
|
from autogen_core.base import CancellationToken
|
||||||
from autogen_core.components.tools import FunctionTool
|
from autogen_core.components.tools import FunctionTool
|
||||||
from autogen_ext.models import OpenAIChatCompletionClient
|
from autogen_ext.models import OpenAIChatCompletionClient
|
||||||
@ -111,10 +111,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
||||||
)
|
)
|
||||||
result = await tool_use_agent.run("task")
|
result = await tool_use_agent.run("task")
|
||||||
assert len(result.messages) == 3
|
assert len(result.messages) == 4
|
||||||
assert isinstance(result.messages[0], TextMessage)
|
assert isinstance(result.messages[0], TextMessage)
|
||||||
assert isinstance(result.messages[1], TextMessage)
|
assert isinstance(result.messages[1], ToolCallMessage)
|
||||||
assert isinstance(result.messages[2], StopMessage)
|
assert isinstance(result.messages[2], ToolCallResultMessages)
|
||||||
|
assert isinstance(result.messages[3], TextMessage)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -162,5 +163,5 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
response = await tool_use_agent.on_messages(
|
response = await tool_use_agent.on_messages(
|
||||||
[TextMessage(content="task", source="user")], cancellation_token=CancellationToken()
|
[TextMessage(content="task", source="user")], cancellation_token=CancellationToken()
|
||||||
)
|
)
|
||||||
assert isinstance(response, HandoffMessage)
|
assert isinstance(response.chat_message, HandoffMessage)
|
||||||
assert response.target == "agent2"
|
assert response.chat_message.target == "agent2"
|
||||||
|
|||||||
@ -12,12 +12,15 @@ from autogen_agentchat.agents import (
|
|||||||
CodeExecutorAgent,
|
CodeExecutorAgent,
|
||||||
Handoff,
|
Handoff,
|
||||||
)
|
)
|
||||||
|
from autogen_agentchat.base import Response
|
||||||
from autogen_agentchat.logging import FileLogHandler
|
from autogen_agentchat.logging import FileLogHandler
|
||||||
from autogen_agentchat.messages import (
|
from autogen_agentchat.messages import (
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
HandoffMessage,
|
HandoffMessage,
|
||||||
StopMessage,
|
StopMessage,
|
||||||
TextMessage,
|
TextMessage,
|
||||||
|
ToolCallMessage,
|
||||||
|
ToolCallResultMessages,
|
||||||
)
|
)
|
||||||
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
|
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
|
||||||
from autogen_agentchat.teams import (
|
from autogen_agentchat.teams import (
|
||||||
@ -66,14 +69,14 @@ class _EchoAgent(BaseChatAgent):
|
|||||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||||
return [TextMessage]
|
return [TextMessage]
|
||||||
|
|
||||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||||
if len(messages) > 0:
|
if len(messages) > 0:
|
||||||
assert isinstance(messages[0], TextMessage)
|
assert isinstance(messages[0], TextMessage)
|
||||||
self._last_message = messages[0].content
|
self._last_message = messages[0].content
|
||||||
return TextMessage(content=messages[0].content, source=self.name)
|
return Response(chat_message=TextMessage(content=messages[0].content, source=self.name))
|
||||||
else:
|
else:
|
||||||
assert self._last_message is not None
|
assert self._last_message is not None
|
||||||
return TextMessage(content=self._last_message, source=self.name)
|
return Response(chat_message=TextMessage(content=self._last_message, source=self.name))
|
||||||
|
|
||||||
|
|
||||||
class _StopAgent(_EchoAgent):
|
class _StopAgent(_EchoAgent):
|
||||||
@ -86,11 +89,11 @@ class _StopAgent(_EchoAgent):
|
|||||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||||
return [TextMessage, StopMessage]
|
return [TextMessage, StopMessage]
|
||||||
|
|
||||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||||
self._count += 1
|
self._count += 1
|
||||||
if self._count < self._stop_at:
|
if self._count < self._stop_at:
|
||||||
return await super().on_messages(messages, cancellation_token)
|
return await super().on_messages(messages, cancellation_token)
|
||||||
return StopMessage(content="TERMINATE", source=self.name)
|
return Response(chat_message=StopMessage(content="TERMINATE", source=self.name))
|
||||||
|
|
||||||
|
|
||||||
def _pass_function(input: str) -> str:
|
def _pass_function(input: str) -> str:
|
||||||
@ -230,11 +233,13 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
|
|||||||
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
|
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(result.messages) == 4
|
assert len(result.messages) == 6
|
||||||
assert isinstance(result.messages[0], TextMessage) # task
|
assert isinstance(result.messages[0], TextMessage) # task
|
||||||
assert isinstance(result.messages[1], TextMessage) # tool use agent response
|
assert isinstance(result.messages[1], ToolCallMessage) # tool call
|
||||||
assert isinstance(result.messages[2], TextMessage) # echo agent response
|
assert isinstance(result.messages[2], ToolCallResultMessages) # tool call result
|
||||||
assert isinstance(result.messages[3], StopMessage) # tool use agent response
|
assert isinstance(result.messages[3], TextMessage) # tool use agent response
|
||||||
|
assert isinstance(result.messages[4], TextMessage) # echo agent response
|
||||||
|
assert isinstance(result.messages[5], StopMessage) # tool use agent response
|
||||||
|
|
||||||
context = tool_use_agent._model_context # pyright: ignore
|
context = tool_use_agent._model_context # pyright: ignore
|
||||||
assert context[0].content == "Write a program that prints 'Hello, world!'"
|
assert context[0].content == "Write a program that prints 'Hello, world!'"
|
||||||
@ -427,8 +432,12 @@ class _HandOffAgent(BaseChatAgent):
|
|||||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||||
return [HandoffMessage]
|
return [HandoffMessage]
|
||||||
|
|
||||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||||
return HandoffMessage(content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name)
|
return Response(
|
||||||
|
chat_message=HandoffMessage(
|
||||||
|
content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -513,9 +522,11 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
|
|||||||
agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
|
agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
|
||||||
team = Swarm([agnet1, agent2])
|
team = Swarm([agnet1, agent2])
|
||||||
result = await team.run("task", termination_condition=StopMessageTermination())
|
result = await team.run("task", termination_condition=StopMessageTermination())
|
||||||
assert len(result.messages) == 5
|
assert len(result.messages) == 7
|
||||||
assert result.messages[0].content == "task"
|
assert result.messages[0].content == "task"
|
||||||
assert result.messages[1].content == "handoff to agent2"
|
assert isinstance(result.messages[1], ToolCallMessage)
|
||||||
assert result.messages[2].content == "Transferred to agent1."
|
assert isinstance(result.messages[2], ToolCallResultMessages)
|
||||||
assert result.messages[3].content == "Hello"
|
assert result.messages[3].content == "handoff to agent2"
|
||||||
assert result.messages[4].content == "TERMINATE"
|
assert result.messages[4].content == "Transferred to agent1."
|
||||||
|
assert result.messages[5].content == "Hello"
|
||||||
|
assert result.messages[6].content == "TERMINATE"
|
||||||
|
|||||||
@ -251,6 +251,7 @@
|
|||||||
"from typing import List, Sequence\n",
|
"from typing import List, Sequence\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from autogen_agentchat.agents import BaseChatAgent\n",
|
"from autogen_agentchat.agents import BaseChatAgent\n",
|
||||||
|
"from autogen_agentchat.base import Response\n",
|
||||||
"from autogen_agentchat.messages import (\n",
|
"from autogen_agentchat.messages import (\n",
|
||||||
" ChatMessage,\n",
|
" ChatMessage,\n",
|
||||||
" StopMessage,\n",
|
" StopMessage,\n",
|
||||||
@ -266,11 +267,11 @@
|
|||||||
" def produced_message_types(self) -> List[type[ChatMessage]]:\n",
|
" def produced_message_types(self) -> List[type[ChatMessage]]:\n",
|
||||||
" return [TextMessage, StopMessage]\n",
|
" return [TextMessage, StopMessage]\n",
|
||||||
"\n",
|
"\n",
|
||||||
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:\n",
|
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
|
||||||
" user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n",
|
" user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n",
|
||||||
" if \"TERMINATE\" in user_input:\n",
|
" if \"TERMINATE\" in user_input:\n",
|
||||||
" return StopMessage(content=\"User has terminated the conversation.\", source=self.name)\n",
|
" return Response(chat_message=StopMessage(content=\"User has terminated the conversation.\", source=self.name))\n",
|
||||||
" return TextMessage(content=user_input, source=self.name)\n",
|
" return Response(chat_message=TextMessage(content=user_input, source=self.name))\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"user_proxy_agent = UserProxyAgent(name=\"user_proxy_agent\")\n",
|
"user_proxy_agent = UserProxyAgent(name=\"user_proxy_agent\")\n",
|
||||||
|
|||||||
@ -45,6 +45,7 @@
|
|||||||
" CodingAssistantAgent,\n",
|
" CodingAssistantAgent,\n",
|
||||||
" ToolUseAssistantAgent,\n",
|
" ToolUseAssistantAgent,\n",
|
||||||
")\n",
|
")\n",
|
||||||
|
"from autogen_agentchat.base import Response\n",
|
||||||
"from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage\n",
|
"from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage\n",
|
||||||
"from autogen_agentchat.task import StopMessageTermination\n",
|
"from autogen_agentchat.task import StopMessageTermination\n",
|
||||||
"from autogen_agentchat.teams import SelectorGroupChat\n",
|
"from autogen_agentchat.teams import SelectorGroupChat\n",
|
||||||
@ -75,11 +76,11 @@
|
|||||||
" def produced_message_types(self) -> List[type[ChatMessage]]:\n",
|
" def produced_message_types(self) -> List[type[ChatMessage]]:\n",
|
||||||
" return [TextMessage, StopMessage]\n",
|
" return [TextMessage, StopMessage]\n",
|
||||||
"\n",
|
"\n",
|
||||||
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:\n",
|
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
|
||||||
" user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n",
|
" user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n",
|
||||||
" if \"TERMINATE\" in user_input:\n",
|
" if \"TERMINATE\" in user_input:\n",
|
||||||
" return StopMessage(content=\"User has terminated the conversation.\", source=self.name)\n",
|
" return Response(chat_message=StopMessage(content=\"User has terminated the conversation.\", source=self.name))\n",
|
||||||
" return TextMessage(content=user_input, source=self.name)"
|
" return Response(chat_message=TextMessage(content=user_input, source=self.name))"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user