mirror of
				https://github.com/microsoft/autogen.git
				synced 2025-11-04 03:39:52 +00:00 
			
		
		
		
	Formalize ChatAgent response as a dataclass with inner messages (#3990)
				
					
				
			This commit is contained in:
		
							parent
							
								
									e63fd17ed5
								
							
						
					
					
						commit
						3d51ab76ae
					
				@ -18,12 +18,16 @@ from autogen_core.components.tools import FunctionTool, Tool
 | 
			
		||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
 | 
			
		||||
 | 
			
		||||
from .. import EVENT_LOGGER_NAME
 | 
			
		||||
from ..base import Response
 | 
			
		||||
from ..messages import (
 | 
			
		||||
    ChatMessage,
 | 
			
		||||
    HandoffMessage,
 | 
			
		||||
    InnerMessage,
 | 
			
		||||
    ResetMessage,
 | 
			
		||||
    StopMessage,
 | 
			
		||||
    TextMessage,
 | 
			
		||||
    ToolCallMessage,
 | 
			
		||||
    ToolCallResultMessages,
 | 
			
		||||
)
 | 
			
		||||
from ._base_chat_agent import BaseChatAgent
 | 
			
		||||
 | 
			
		||||
@ -214,7 +218,7 @@ class AssistantAgent(BaseChatAgent):
 | 
			
		||||
            return [TextMessage, HandoffMessage, StopMessage]
 | 
			
		||||
        return [TextMessage, StopMessage]
 | 
			
		||||
 | 
			
		||||
    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
 | 
			
		||||
    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
 | 
			
		||||
        # Add messages to the model context.
 | 
			
		||||
        for msg in messages:
 | 
			
		||||
            if isinstance(msg, ResetMessage):
 | 
			
		||||
@ -222,6 +226,9 @@ class AssistantAgent(BaseChatAgent):
 | 
			
		||||
            else:
 | 
			
		||||
                self._model_context.append(UserMessage(content=msg.content, source=msg.source))
 | 
			
		||||
 | 
			
		||||
        # Inner messages.
 | 
			
		||||
        inner_messages: List[InnerMessage] = []
 | 
			
		||||
 | 
			
		||||
        # Generate an inference result based on the current model context.
 | 
			
		||||
        llm_messages = self._system_messages + self._model_context
 | 
			
		||||
        result = await self._model_client.create(
 | 
			
		||||
@ -234,12 +241,16 @@ class AssistantAgent(BaseChatAgent):
 | 
			
		||||
        # Run tool calls until the model produces a string response.
 | 
			
		||||
        while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
 | 
			
		||||
            event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
 | 
			
		||||
            # Add the tool call message to the output.
 | 
			
		||||
            inner_messages.append(ToolCallMessage(content=result.content, source=self.name))
 | 
			
		||||
 | 
			
		||||
            # Execute the tool calls.
 | 
			
		||||
            results = await asyncio.gather(
 | 
			
		||||
                *[self._execute_tool_call(call, cancellation_token) for call in result.content]
 | 
			
		||||
            )
 | 
			
		||||
            event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
 | 
			
		||||
            self._model_context.append(FunctionExecutionResultMessage(content=results))
 | 
			
		||||
            inner_messages.append(ToolCallResultMessages(content=results, source=self.name))
 | 
			
		||||
 | 
			
		||||
            # Detect handoff requests.
 | 
			
		||||
            handoffs: List[Handoff] = []
 | 
			
		||||
@ -249,8 +260,13 @@ class AssistantAgent(BaseChatAgent):
 | 
			
		||||
            if len(handoffs) > 0:
 | 
			
		||||
                if len(handoffs) > 1:
 | 
			
		||||
                    raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
 | 
			
		||||
                # Respond with a handoff message.
 | 
			
		||||
                return HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name)
 | 
			
		||||
                # Return the output messages to signal the handoff.
 | 
			
		||||
                return Response(
 | 
			
		||||
                    chat_message=HandoffMessage(
 | 
			
		||||
                        content=handoffs[0].message, target=handoffs[0].target, source=self.name
 | 
			
		||||
                    ),
 | 
			
		||||
                    inner_messages=inner_messages,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            # Generate an inference result based on the current model context.
 | 
			
		||||
            result = await self._model_client.create(
 | 
			
		||||
@ -262,9 +278,13 @@ class AssistantAgent(BaseChatAgent):
 | 
			
		||||
        # Detect stop request.
 | 
			
		||||
        request_stop = "terminate" in result.content.strip().lower()
 | 
			
		||||
        if request_stop:
 | 
			
		||||
            return StopMessage(content=result.content, source=self.name)
 | 
			
		||||
            return Response(
 | 
			
		||||
                chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        return TextMessage(content=result.content, source=self.name)
 | 
			
		||||
        return Response(
 | 
			
		||||
            chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def _execute_tool_call(
 | 
			
		||||
        self, tool_call: FunctionCall, cancellation_token: CancellationToken
 | 
			
		||||
 | 
			
		||||
@ -3,9 +3,8 @@ from typing import List, Sequence
 | 
			
		||||
 | 
			
		||||
from autogen_core.base import CancellationToken
 | 
			
		||||
 | 
			
		||||
from ..base import ChatAgent, TaskResult, TerminationCondition
 | 
			
		||||
from ..messages import ChatMessage
 | 
			
		||||
from ..teams import RoundRobinGroupChat
 | 
			
		||||
from ..base import ChatAgent, Response, TaskResult, TerminationCondition
 | 
			
		||||
from ..messages import ChatMessage, InnerMessage, TextMessage
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseChatAgent(ChatAgent, ABC):
 | 
			
		||||
@ -37,8 +36,8 @@ class BaseChatAgent(ChatAgent, ABC):
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
 | 
			
		||||
        """Handle incoming messages and return a response message."""
 | 
			
		||||
    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
 | 
			
		||||
        """Handles incoming messages and returns a response."""
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    async def run(
 | 
			
		||||
@ -49,10 +48,12 @@ class BaseChatAgent(ChatAgent, ABC):
 | 
			
		||||
        termination_condition: TerminationCondition | None = None,
 | 
			
		||||
    ) -> TaskResult:
 | 
			
		||||
        """Run the agent with the given task and return the result."""
 | 
			
		||||
        group_chat = RoundRobinGroupChat(participants=[self])
 | 
			
		||||
        result = await group_chat.run(
 | 
			
		||||
            task=task,
 | 
			
		||||
            cancellation_token=cancellation_token,
 | 
			
		||||
            termination_condition=termination_condition,
 | 
			
		||||
        )
 | 
			
		||||
        return result
 | 
			
		||||
        if cancellation_token is None:
 | 
			
		||||
            cancellation_token = CancellationToken()
 | 
			
		||||
        first_message = TextMessage(content=task, source="user")
 | 
			
		||||
        response = await self.on_messages([first_message], cancellation_token)
 | 
			
		||||
        messages: List[InnerMessage | ChatMessage] = [first_message]
 | 
			
		||||
        if response.inner_messages is not None:
 | 
			
		||||
            messages += response.inner_messages
 | 
			
		||||
        messages.append(response.chat_message)
 | 
			
		||||
        return TaskResult(messages=messages)
 | 
			
		||||
 | 
			
		||||
@ -3,6 +3,7 @@ from typing import List, Sequence
 | 
			
		||||
from autogen_core.base import CancellationToken
 | 
			
		||||
from autogen_core.components.code_executor import CodeBlock, CodeExecutor, extract_markdown_code_blocks
 | 
			
		||||
 | 
			
		||||
from ..base import Response
 | 
			
		||||
from ..messages import ChatMessage, TextMessage
 | 
			
		||||
from ._base_chat_agent import BaseChatAgent
 | 
			
		||||
 | 
			
		||||
@ -25,7 +26,7 @@ class CodeExecutorAgent(BaseChatAgent):
 | 
			
		||||
        """The types of messages that the code executor agent produces."""
 | 
			
		||||
        return [TextMessage]
 | 
			
		||||
 | 
			
		||||
    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
 | 
			
		||||
    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
 | 
			
		||||
        # Extract code blocks from the messages.
 | 
			
		||||
        code_blocks: List[CodeBlock] = []
 | 
			
		||||
        for msg in messages:
 | 
			
		||||
@ -34,6 +35,6 @@ class CodeExecutorAgent(BaseChatAgent):
 | 
			
		||||
        if code_blocks:
 | 
			
		||||
            # Execute the code blocks.
 | 
			
		||||
            result = await self._code_executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token)
 | 
			
		||||
            return TextMessage(content=result.output, source=self.name)
 | 
			
		||||
            return Response(chat_message=TextMessage(content=result.output, source=self.name))
 | 
			
		||||
        else:
 | 
			
		||||
            return TextMessage(content="No code blocks found in the thread.", source=self.name)
 | 
			
		||||
            return Response(chat_message=TextMessage(content="No code blocks found in the thread.", source=self.name))
 | 
			
		||||
 | 
			
		||||
@ -1,10 +1,11 @@
 | 
			
		||||
from ._chat_agent import ChatAgent
 | 
			
		||||
from ._chat_agent import ChatAgent, Response
 | 
			
		||||
from ._task import TaskResult, TaskRunner
 | 
			
		||||
from ._team import Team
 | 
			
		||||
from ._termination import TerminatedException, TerminationCondition
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
    "ChatAgent",
 | 
			
		||||
    "Response",
 | 
			
		||||
    "Team",
 | 
			
		||||
    "TerminatedException",
 | 
			
		||||
    "TerminationCondition",
 | 
			
		||||
 | 
			
		||||
@ -1,12 +1,24 @@
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import List, Protocol, Sequence, runtime_checkable
 | 
			
		||||
 | 
			
		||||
from autogen_core.base import CancellationToken
 | 
			
		||||
 | 
			
		||||
from ..messages import ChatMessage
 | 
			
		||||
from ..messages import ChatMessage, InnerMessage
 | 
			
		||||
from ._task import TaskResult, TaskRunner
 | 
			
		||||
from ._termination import TerminationCondition
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass(kw_only=True)
 | 
			
		||||
class Response:
 | 
			
		||||
    """A response from calling :meth:`ChatAgent.on_messages`."""
 | 
			
		||||
 | 
			
		||||
    chat_message: ChatMessage
 | 
			
		||||
    """A chat message produced by the agent as the response."""
 | 
			
		||||
 | 
			
		||||
    inner_messages: List[InnerMessage] | None = None
 | 
			
		||||
    """Inner messages produced by the agent."""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@runtime_checkable
 | 
			
		||||
class ChatAgent(TaskRunner, Protocol):
 | 
			
		||||
    """Protocol for a chat agent."""
 | 
			
		||||
@ -29,8 +41,8 @@ class ChatAgent(TaskRunner, Protocol):
 | 
			
		||||
        """The types of messages that the agent produces."""
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
 | 
			
		||||
        """Handle incoming messages and return a response message."""
 | 
			
		||||
    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
 | 
			
		||||
        """Handles incoming messages and returns a response."""
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    async def run(
 | 
			
		||||
 | 
			
		||||
@ -3,7 +3,7 @@ from typing import Protocol, Sequence
 | 
			
		||||
 | 
			
		||||
from autogen_core.base import CancellationToken
 | 
			
		||||
 | 
			
		||||
from ..messages import ChatMessage
 | 
			
		||||
from ..messages import ChatMessage, InnerMessage
 | 
			
		||||
from ._termination import TerminationCondition
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -11,7 +11,7 @@ from ._termination import TerminationCondition
 | 
			
		||||
class TaskResult:
 | 
			
		||||
    """Result of running a task."""
 | 
			
		||||
 | 
			
		||||
    messages: Sequence[ChatMessage]
 | 
			
		||||
    messages: Sequence[InnerMessage | ChatMessage]
 | 
			
		||||
    """Messages produced by the task."""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
from typing import List
 | 
			
		||||
 | 
			
		||||
from autogen_core.components import Image
 | 
			
		||||
from autogen_core.components import FunctionCall, Image
 | 
			
		||||
from autogen_core.components.models import FunctionExecutionResult
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -49,8 +50,26 @@ class ResetMessage(BaseMessage):
 | 
			
		||||
    """The content for the reset message."""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ToolCallMessage(BaseMessage):
 | 
			
		||||
    """A message signaling the use of tools."""
 | 
			
		||||
 | 
			
		||||
    content: List[FunctionCall]
 | 
			
		||||
    """The tool calls."""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ToolCallResultMessages(BaseMessage):
 | 
			
		||||
    """A message signaling the results of tool calls."""
 | 
			
		||||
 | 
			
		||||
    content: List[FunctionExecutionResult]
 | 
			
		||||
    """The tool call results."""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
InnerMessage = ToolCallMessage | ToolCallResultMessages
 | 
			
		||||
"""Messages for intra-agent monologues."""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ResetMessage
 | 
			
		||||
"""A message used by agents in a team."""
 | 
			
		||||
"""Messages for agent-to-agent communication."""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
@ -60,5 +79,7 @@ __all__ = [
 | 
			
		||||
    "StopMessage",
 | 
			
		||||
    "HandoffMessage",
 | 
			
		||||
    "ResetMessage",
 | 
			
		||||
    "ToolCallMessage",
 | 
			
		||||
    "ToolCallResultMessages",
 | 
			
		||||
    "ChatMessage",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
@ -15,7 +15,7 @@ from autogen_core.base import (
 | 
			
		||||
from autogen_core.components import ClosureAgent, TypeSubscription
 | 
			
		||||
 | 
			
		||||
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
 | 
			
		||||
from ...messages import ChatMessage, TextMessage
 | 
			
		||||
from ...messages import ChatMessage, InnerMessage, TextMessage
 | 
			
		||||
from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent
 | 
			
		||||
from ._base_group_chat_manager import BaseGroupChatManager
 | 
			
		||||
from ._chat_agent_container import ChatAgentContainer
 | 
			
		||||
@ -56,12 +56,13 @@ class BaseGroupChat(Team, ABC):
 | 
			
		||||
    def _create_participant_factory(
 | 
			
		||||
        self,
 | 
			
		||||
        parent_topic_type: str,
 | 
			
		||||
        output_topic_type: str,
 | 
			
		||||
        agent: ChatAgent,
 | 
			
		||||
    ) -> Callable[[], ChatAgentContainer]:
 | 
			
		||||
        def _factory() -> ChatAgentContainer:
 | 
			
		||||
            id = AgentInstantiationContext.current_agent_id()
 | 
			
		||||
            assert id == AgentId(type=agent.name, key=self._team_id)
 | 
			
		||||
            container = ChatAgentContainer(parent_topic_type, agent)
 | 
			
		||||
            container = ChatAgentContainer(parent_topic_type, output_topic_type, agent)
 | 
			
		||||
            assert container.id == id
 | 
			
		||||
            return container
 | 
			
		||||
 | 
			
		||||
@ -85,6 +86,7 @@ class BaseGroupChat(Team, ABC):
 | 
			
		||||
        group_chat_manager_topic_type = group_chat_manager_agent_type.type
 | 
			
		||||
        group_topic_type = "round_robin_group_topic"
 | 
			
		||||
        team_topic_type = "team_topic"
 | 
			
		||||
        output_topic_type = "output_topic"
 | 
			
		||||
 | 
			
		||||
        # Register participants.
 | 
			
		||||
        participant_topic_types: List[str] = []
 | 
			
		||||
@ -97,7 +99,7 @@ class BaseGroupChat(Team, ABC):
 | 
			
		||||
            await ChatAgentContainer.register(
 | 
			
		||||
                runtime,
 | 
			
		||||
                type=agent_type,
 | 
			
		||||
                factory=self._create_participant_factory(group_topic_type, participant),
 | 
			
		||||
                factory=self._create_participant_factory(group_topic_type, output_topic_type, participant),
 | 
			
		||||
            )
 | 
			
		||||
            # Add subscriptions for the participant.
 | 
			
		||||
            await runtime.add_subscription(TypeSubscription(topic_type=topic_type, agent_type=agent_type))
 | 
			
		||||
@ -129,22 +131,22 @@ class BaseGroupChat(Team, ABC):
 | 
			
		||||
            TypeSubscription(topic_type=team_topic_type, agent_type=group_chat_manager_agent_type.type)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        group_chat_messages: List[ChatMessage] = []
 | 
			
		||||
        output_messages: List[InnerMessage | ChatMessage] = []
 | 
			
		||||
 | 
			
		||||
        async def collect_group_chat_messages(
 | 
			
		||||
        async def collect_output_messages(
 | 
			
		||||
            _runtime: AgentRuntime,
 | 
			
		||||
            id: AgentId,
 | 
			
		||||
            message: GroupChatPublishEvent,
 | 
			
		||||
            message: InnerMessage | ChatMessage,
 | 
			
		||||
            ctx: MessageContext,
 | 
			
		||||
        ) -> None:
 | 
			
		||||
            group_chat_messages.append(message.agent_message)
 | 
			
		||||
            output_messages.append(message)
 | 
			
		||||
 | 
			
		||||
        await ClosureAgent.register(
 | 
			
		||||
            runtime,
 | 
			
		||||
            type="collect_group_chat_messages",
 | 
			
		||||
            closure=collect_group_chat_messages,
 | 
			
		||||
            type="collect_output_messages",
 | 
			
		||||
            closure=collect_output_messages,
 | 
			
		||||
            subscriptions=lambda: [
 | 
			
		||||
                TypeSubscription(topic_type=group_topic_type, agent_type="collect_group_chat_messages"),
 | 
			
		||||
                TypeSubscription(topic_type=output_topic_type, agent_type="collect_output_messages"),
 | 
			
		||||
            ],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -154,8 +156,10 @@ class BaseGroupChat(Team, ABC):
 | 
			
		||||
        # Run the team by publishing the task to the team topic and then requesting the result.
 | 
			
		||||
        team_topic_id = TopicId(type=team_topic_type, source=self._team_id)
 | 
			
		||||
        group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
 | 
			
		||||
        first_chat_message = TextMessage(content=task, source="user")
 | 
			
		||||
        output_messages.append(first_chat_message)
 | 
			
		||||
        await runtime.publish_message(
 | 
			
		||||
            GroupChatPublishEvent(agent_message=TextMessage(content=task, source="user")),
 | 
			
		||||
            GroupChatPublishEvent(agent_message=first_chat_message),
 | 
			
		||||
            topic_id=team_topic_id,
 | 
			
		||||
        )
 | 
			
		||||
        await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id)
 | 
			
		||||
@ -164,4 +168,4 @@ class BaseGroupChat(Team, ABC):
 | 
			
		||||
        await runtime.stop_when_idle()
 | 
			
		||||
 | 
			
		||||
        # Return the result.
 | 
			
		||||
        return TaskResult(messages=group_chat_messages)
 | 
			
		||||
        return TaskResult(messages=output_messages)
 | 
			
		||||
 | 
			
		||||
@ -16,12 +16,14 @@ class ChatAgentContainer(SequentialRoutedAgent):
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        parent_topic_type (str): The topic type of the parent orchestrator.
 | 
			
		||||
        output_topic_type (str): The topic type for the output.
 | 
			
		||||
        agent (ChatAgent): The agent to delegate message handling to.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, parent_topic_type: str, agent: ChatAgent) -> None:
 | 
			
		||||
    def __init__(self, parent_topic_type: str, output_topic_type: str, agent: ChatAgent) -> None:
 | 
			
		||||
        super().__init__(description=agent.description)
 | 
			
		||||
        self._parent_topic_type = parent_topic_type
 | 
			
		||||
        self._output_topic_type = output_topic_type
 | 
			
		||||
        self._agent = agent
 | 
			
		||||
        self._message_buffer: List[ChatMessage] = []
 | 
			
		||||
 | 
			
		||||
@ -36,18 +38,27 @@ class ChatAgentContainer(SequentialRoutedAgent):
 | 
			
		||||
        to the delegate agent and publish the response."""
 | 
			
		||||
        # Pass the messages in the buffer to the delegate agent.
 | 
			
		||||
        response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
 | 
			
		||||
        if not any(isinstance(response, msg_type) for msg_type in self._agent.produced_message_types):
 | 
			
		||||
        if not any(isinstance(response.chat_message, msg_type) for msg_type in self._agent.produced_message_types):
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"The agent {self._agent.name} produced an unexpected message type: {type(response)}. "
 | 
			
		||||
                f"Expected one of: {self._agent.produced_message_types}"
 | 
			
		||||
                f"Expected one of: {self._agent.produced_message_types}. "
 | 
			
		||||
                f"Check the agent's produced_message_types property."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # Publish inner messages to the output topic.
 | 
			
		||||
        if response.inner_messages is not None:
 | 
			
		||||
            for inner_message in response.inner_messages:
 | 
			
		||||
                await self.publish_message(inner_message, topic_id=DefaultTopicId(type=self._output_topic_type))
 | 
			
		||||
 | 
			
		||||
        # Publish the response.
 | 
			
		||||
        self._message_buffer.clear()
 | 
			
		||||
        await self.publish_message(
 | 
			
		||||
            GroupChatPublishEvent(agent_message=response, source=self.id),
 | 
			
		||||
            GroupChatPublishEvent(agent_message=response.chat_message, source=self.id),
 | 
			
		||||
            topic_id=DefaultTopicId(type=self._parent_topic_type),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Publish the response to the output topic.
 | 
			
		||||
        await self.publish_message(response.chat_message, topic_id=DefaultTopicId(type=self._output_topic_type))
 | 
			
		||||
 | 
			
		||||
    async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
 | 
			
		||||
        raise ValueError(f"Unhandled message in agent container: {type(message)}")
 | 
			
		||||
 | 
			
		||||
@ -7,7 +7,7 @@ import pytest
 | 
			
		||||
from autogen_agentchat import EVENT_LOGGER_NAME
 | 
			
		||||
from autogen_agentchat.agents import AssistantAgent, Handoff
 | 
			
		||||
from autogen_agentchat.logging import FileLogHandler
 | 
			
		||||
from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage
 | 
			
		||||
from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessages
 | 
			
		||||
from autogen_core.base import CancellationToken
 | 
			
		||||
from autogen_core.components.tools import FunctionTool
 | 
			
		||||
from autogen_ext.models import OpenAIChatCompletionClient
 | 
			
		||||
@ -111,10 +111,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
 | 
			
		||||
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
 | 
			
		||||
    )
 | 
			
		||||
    result = await tool_use_agent.run("task")
 | 
			
		||||
    assert len(result.messages) == 3
 | 
			
		||||
    assert len(result.messages) == 4
 | 
			
		||||
    assert isinstance(result.messages[0], TextMessage)
 | 
			
		||||
    assert isinstance(result.messages[1], TextMessage)
 | 
			
		||||
    assert isinstance(result.messages[2], StopMessage)
 | 
			
		||||
    assert isinstance(result.messages[1], ToolCallMessage)
 | 
			
		||||
    assert isinstance(result.messages[2], ToolCallResultMessages)
 | 
			
		||||
    assert isinstance(result.messages[3], TextMessage)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.asyncio
 | 
			
		||||
@ -162,5 +163,5 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
 | 
			
		||||
    response = await tool_use_agent.on_messages(
 | 
			
		||||
        [TextMessage(content="task", source="user")], cancellation_token=CancellationToken()
 | 
			
		||||
    )
 | 
			
		||||
    assert isinstance(response, HandoffMessage)
 | 
			
		||||
    assert response.target == "agent2"
 | 
			
		||||
    assert isinstance(response.chat_message, HandoffMessage)
 | 
			
		||||
    assert response.chat_message.target == "agent2"
 | 
			
		||||
 | 
			
		||||
@ -12,12 +12,15 @@ from autogen_agentchat.agents import (
 | 
			
		||||
    CodeExecutorAgent,
 | 
			
		||||
    Handoff,
 | 
			
		||||
)
 | 
			
		||||
from autogen_agentchat.base import Response
 | 
			
		||||
from autogen_agentchat.logging import FileLogHandler
 | 
			
		||||
from autogen_agentchat.messages import (
 | 
			
		||||
    ChatMessage,
 | 
			
		||||
    HandoffMessage,
 | 
			
		||||
    StopMessage,
 | 
			
		||||
    TextMessage,
 | 
			
		||||
    ToolCallMessage,
 | 
			
		||||
    ToolCallResultMessages,
 | 
			
		||||
)
 | 
			
		||||
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
 | 
			
		||||
from autogen_agentchat.teams import (
 | 
			
		||||
@ -66,14 +69,14 @@ class _EchoAgent(BaseChatAgent):
 | 
			
		||||
    def produced_message_types(self) -> List[type[ChatMessage]]:
 | 
			
		||||
        return [TextMessage]
 | 
			
		||||
 | 
			
		||||
    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
 | 
			
		||||
    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
 | 
			
		||||
        if len(messages) > 0:
 | 
			
		||||
            assert isinstance(messages[0], TextMessage)
 | 
			
		||||
            self._last_message = messages[0].content
 | 
			
		||||
            return TextMessage(content=messages[0].content, source=self.name)
 | 
			
		||||
            return Response(chat_message=TextMessage(content=messages[0].content, source=self.name))
 | 
			
		||||
        else:
 | 
			
		||||
            assert self._last_message is not None
 | 
			
		||||
            return TextMessage(content=self._last_message, source=self.name)
 | 
			
		||||
            return Response(chat_message=TextMessage(content=self._last_message, source=self.name))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _StopAgent(_EchoAgent):
 | 
			
		||||
@ -86,11 +89,11 @@ class _StopAgent(_EchoAgent):
 | 
			
		||||
    def produced_message_types(self) -> List[type[ChatMessage]]:
 | 
			
		||||
        return [TextMessage, StopMessage]
 | 
			
		||||
 | 
			
		||||
    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
 | 
			
		||||
    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
 | 
			
		||||
        self._count += 1
 | 
			
		||||
        if self._count < self._stop_at:
 | 
			
		||||
            return await super().on_messages(messages, cancellation_token)
 | 
			
		||||
        return StopMessage(content="TERMINATE", source=self.name)
 | 
			
		||||
        return Response(chat_message=StopMessage(content="TERMINATE", source=self.name))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _pass_function(input: str) -> str:
 | 
			
		||||
@ -230,11 +233,13 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
 | 
			
		||||
        "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert len(result.messages) == 4
 | 
			
		||||
    assert len(result.messages) == 6
 | 
			
		||||
    assert isinstance(result.messages[0], TextMessage)  # task
 | 
			
		||||
    assert isinstance(result.messages[1], TextMessage)  # tool use agent response
 | 
			
		||||
    assert isinstance(result.messages[2], TextMessage)  # echo agent response
 | 
			
		||||
    assert isinstance(result.messages[3], StopMessage)  # tool use agent response
 | 
			
		||||
    assert isinstance(result.messages[1], ToolCallMessage)  # tool call
 | 
			
		||||
    assert isinstance(result.messages[2], ToolCallResultMessages)  # tool call result
 | 
			
		||||
    assert isinstance(result.messages[3], TextMessage)  # tool use agent response
 | 
			
		||||
    assert isinstance(result.messages[4], TextMessage)  # echo agent response
 | 
			
		||||
    assert isinstance(result.messages[5], StopMessage)  # tool use agent response
 | 
			
		||||
 | 
			
		||||
    context = tool_use_agent._model_context  # pyright: ignore
 | 
			
		||||
    assert context[0].content == "Write a program that prints 'Hello, world!'"
 | 
			
		||||
@ -427,8 +432,12 @@ class _HandOffAgent(BaseChatAgent):
 | 
			
		||||
    def produced_message_types(self) -> List[type[ChatMessage]]:
 | 
			
		||||
        return [HandoffMessage]
 | 
			
		||||
 | 
			
		||||
    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
 | 
			
		||||
        return HandoffMessage(content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name)
 | 
			
		||||
    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
 | 
			
		||||
        return Response(
 | 
			
		||||
            chat_message=HandoffMessage(
 | 
			
		||||
                content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.asyncio
 | 
			
		||||
@ -513,9 +522,11 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
 | 
			
		||||
    agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
 | 
			
		||||
    team = Swarm([agnet1, agent2])
 | 
			
		||||
    result = await team.run("task", termination_condition=StopMessageTermination())
 | 
			
		||||
    assert len(result.messages) == 5
 | 
			
		||||
    assert len(result.messages) == 7
 | 
			
		||||
    assert result.messages[0].content == "task"
 | 
			
		||||
    assert result.messages[1].content == "handoff to agent2"
 | 
			
		||||
    assert result.messages[2].content == "Transferred to agent1."
 | 
			
		||||
    assert result.messages[3].content == "Hello"
 | 
			
		||||
    assert result.messages[4].content == "TERMINATE"
 | 
			
		||||
    assert isinstance(result.messages[1], ToolCallMessage)
 | 
			
		||||
    assert isinstance(result.messages[2], ToolCallResultMessages)
 | 
			
		||||
    assert result.messages[3].content == "handoff to agent2"
 | 
			
		||||
    assert result.messages[4].content == "Transferred to agent1."
 | 
			
		||||
    assert result.messages[5].content == "Hello"
 | 
			
		||||
    assert result.messages[6].content == "TERMINATE"
 | 
			
		||||
 | 
			
		||||
@ -251,6 +251,7 @@
 | 
			
		||||
    "from typing import List, Sequence\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "from autogen_agentchat.agents import BaseChatAgent\n",
 | 
			
		||||
    "from autogen_agentchat.base import Response\n",
 | 
			
		||||
    "from autogen_agentchat.messages import (\n",
 | 
			
		||||
    "    ChatMessage,\n",
 | 
			
		||||
    "    StopMessage,\n",
 | 
			
		||||
@ -266,11 +267,11 @@
 | 
			
		||||
    "    def produced_message_types(self) -> List[type[ChatMessage]]:\n",
 | 
			
		||||
    "        return [TextMessage, StopMessage]\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:\n",
 | 
			
		||||
    "    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
 | 
			
		||||
    "        user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n",
 | 
			
		||||
    "        if \"TERMINATE\" in user_input:\n",
 | 
			
		||||
    "            return StopMessage(content=\"User has terminated the conversation.\", source=self.name)\n",
 | 
			
		||||
    "        return TextMessage(content=user_input, source=self.name)\n",
 | 
			
		||||
    "            return Response(chat_message=StopMessage(content=\"User has terminated the conversation.\", source=self.name))\n",
 | 
			
		||||
    "        return Response(chat_message=TextMessage(content=user_input, source=self.name))\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "user_proxy_agent = UserProxyAgent(name=\"user_proxy_agent\")\n",
 | 
			
		||||
 | 
			
		||||
@ -45,6 +45,7 @@
 | 
			
		||||
    "    CodingAssistantAgent,\n",
 | 
			
		||||
    "    ToolUseAssistantAgent,\n",
 | 
			
		||||
    ")\n",
 | 
			
		||||
    "from autogen_agentchat.base import Response\n",
 | 
			
		||||
    "from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage\n",
 | 
			
		||||
    "from autogen_agentchat.task import StopMessageTermination\n",
 | 
			
		||||
    "from autogen_agentchat.teams import SelectorGroupChat\n",
 | 
			
		||||
@ -75,11 +76,11 @@
 | 
			
		||||
    "    def produced_message_types(self) -> List[type[ChatMessage]]:\n",
 | 
			
		||||
    "        return [TextMessage, StopMessage]\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:\n",
 | 
			
		||||
    "    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
 | 
			
		||||
    "        user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n",
 | 
			
		||||
    "        if \"TERMINATE\" in user_input:\n",
 | 
			
		||||
    "            return StopMessage(content=\"User has terminated the conversation.\", source=self.name)\n",
 | 
			
		||||
    "        return TextMessage(content=user_input, source=self.name)"
 | 
			
		||||
    "            return Response(chat_message=StopMessage(content=\"User has terminated the conversation.\", source=self.name))\n",
 | 
			
		||||
    "        return Response(chat_message=TextMessage(content=user_input, source=self.name))"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user