Use agentchat message types rather than core's model client message types (#662)

* Use agentchat message types rather than core's model client message types

* Merge remote-tracking branch 'origin/main' into ekzhu-tool-use-assistant
This commit is contained in:
Eric Zhu 2024-09-28 08:40:13 -07:00 committed by GitHub
parent 43c85d68e0
commit 18efc2314a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 142 additions and 83 deletions

View File

@@ -1,10 +1,23 @@
from ._base_chat_agent import BaseChatAgent, ChatMessage
from .coding._code_executor_agent import CodeExecutorAgent
from .coding._coding_assistant_agent import CodingAssistantAgent
from ._base_chat_agent import (
BaseChatAgent,
ChatMessage,
MultiModalMessage,
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessage,
)
from ._code_executor_agent import CodeExecutorAgent
from ._coding_assistant_agent import CodingAssistantAgent
__all__ = [
"BaseChatAgent",
"ChatMessage",
"TextMessage",
"MultiModalMessage",
"ToolCallMessage",
"ToolCallResultMessage",
"StopMessage",
"CodeExecutorAgent",
"CodingAssistantAgent",
]

View File

@@ -1,20 +1,56 @@
from abc import ABC, abstractmethod
from typing import Sequence
from typing import List, Sequence
from autogen_core.base import CancellationToken
from autogen_core.components.models import AssistantMessage, UserMessage
from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult
from pydantic import BaseModel
class ChatMessage(BaseModel):
"""A chat message from a user or agent."""
class BaseMessage(BaseModel):
"""A base message."""
content: UserMessage | AssistantMessage
source: str
"""The name of the agent that sent this message."""
class TextMessage(BaseMessage):
"""A text message."""
content: str
"""The content of the message."""
request_pause: bool
"""A flag indicating whether the current conversation session should be
paused after processing this message."""
class MultiModalMessage(BaseMessage):
    """A multimodal message.

    Carries mixed text-and-image content: ``content`` is an ordered list in
    which each element is either a plain ``str`` fragment or an
    ``autogen_core.components.Image``.
    """

    # Ordered sequence of text fragments and images composing the message.
    content: List[str | Image]
    """The content of the message."""
class ToolCallMessage(BaseMessage):
    """A message containing a list of function calls.

    Emitted by an agent that wants tools executed; the outcomes are expected
    to arrive as a ``ToolCallResultMessage`` (see sibling class).
    """

    # The function calls requested by the sending agent.
    content: List[FunctionCall]
    """The list of function calls."""
class ToolCallResultMessage(BaseMessage):
    """A message containing the results of function calls.

    Complements ``ToolCallMessage``: each element of ``content`` is the
    execution result of one previously requested function call.
    """

    # One FunctionExecutionResult per executed function call.
    content: List[FunctionExecutionResult]
    """The list of function execution results."""
class StopMessage(BaseMessage):
    """A message requesting stop of a conversation.

    Receivers treat this as a termination signal for the current chat
    session; ``content`` carries the human-readable reason/final text.
    """

    # Free-form text explaining or accompanying the stop request.
    content: str
    """The content for the stop message."""
# Union alias covering every concrete message type agents exchange in a team;
# used throughout agentchat signatures instead of core model-client types.
ChatMessage = TextMessage | MultiModalMessage | ToolCallMessage | ToolCallResultMessage | StopMessage
"""A message used by agents in a team."""
class BaseChatAgent(ABC):

View File

@@ -2,9 +2,8 @@ from typing import List, Sequence
from autogen_core.base import CancellationToken
from autogen_core.components.code_executor import CodeBlock, CodeExecutor, extract_markdown_code_blocks
from autogen_core.components.models import UserMessage
from .._base_chat_agent import BaseChatAgent, ChatMessage
from ._base_chat_agent import BaseChatAgent, ChatMessage, TextMessage
class CodeExecutorAgent(BaseChatAgent):
@@ -21,14 +20,11 @@ class CodeExecutorAgent(BaseChatAgent):
# Extract code blocks from the messages.
code_blocks: List[CodeBlock] = []
for msg in messages:
if isinstance(msg.content, UserMessage) and isinstance(msg.content.content, str):
code_blocks.extend(extract_markdown_code_blocks(msg.content.content))
if isinstance(msg, TextMessage):
code_blocks.extend(extract_markdown_code_blocks(msg.content))
if code_blocks:
# Execute the code blocks.
result = await self._code_executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token)
return ChatMessage(content=UserMessage(content=result.output, source=self.name), request_pause=False)
return TextMessage(content=result.output, source=self.name)
else:
return ChatMessage(
content=UserMessage(content="No code blocks found in the thread.", source=self.name),
request_pause=False,
)
return TextMessage(content="No code blocks found in the thread.", source=self.name)

View File

@@ -1,9 +1,15 @@
from typing import List, Sequence
from autogen_core.base import CancellationToken
from autogen_core.components.models import AssistantMessage, ChatCompletionClient, SystemMessage, UserMessage
from autogen_core.components.models import (
AssistantMessage,
ChatCompletionClient,
LLMMessage,
SystemMessage,
UserMessage,
)
from .._base_chat_agent import BaseChatAgent, ChatMessage
from ._base_chat_agent import BaseChatAgent, ChatMessage, MultiModalMessage, StopMessage, TextMessage
class CodingAssistantAgent(BaseChatAgent):
@@ -27,22 +33,26 @@ Reply "TERMINATE" in the end when everything is done."""
super().__init__(name=name, description=self.DESCRIPTION)
self._model_client = model_client
self._system_messages = [SystemMessage(content=self.SYSTEM_MESSAGE)]
self._message_thread: List[UserMessage | AssistantMessage] = []
self._model_context: List[LLMMessage] = []
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
# Add messages to the thread.
# Add messages to the model context and detect stopping.
for msg in messages:
self._message_thread.append(msg.content)
if not isinstance(msg, TextMessage | MultiModalMessage | StopMessage):
raise ValueError(f"Unsupported message type: {type(msg)}")
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
# Generate an inference result based on the thread.
llm_messages = self._system_messages + self._message_thread
# Generate an inference result based on the current model context.
llm_messages = self._system_messages + self._model_context
result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
assert isinstance(result.content, str)
# Add the response to the thread.
self._message_thread.append(AssistantMessage(content=result.content, source=self.name))
# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
# Detect pause request.
request_pause = "terminate" in result.content.strip().lower()
# Detect stop request.
request_stop = "terminate" in result.content.strip().lower()
if request_stop:
return StopMessage(content=result.content, source=self.name)
return ChatMessage(content=UserMessage(content=result.content, source=self.name), request_pause=request_pause)
return TextMessage(content=result.content, source=self.name)

View File

@@ -3,8 +3,8 @@ from typing import List
from autogen_core.base import MessageContext
from autogen_core.components import DefaultTopicId, RoutedAgent, event
from ...agents import BaseChatAgent, ChatMessage
from ._messages import ContentPublishEvent, ContentRequestEvent
from ...agents import BaseChatAgent, MultiModalMessage, StopMessage, TextMessage
from ._events import ContentPublishEvent, ContentRequestEvent
class BaseChatAgentContainer(RoutedAgent):
@@ -21,20 +21,26 @@ class BaseChatAgentContainer(RoutedAgent):
super().__init__(description=agent.description)
self._parent_topic_type = parent_topic_type
self._agent = agent
self._message_buffer: List[ChatMessage] = []
self._message_buffer: List[TextMessage | MultiModalMessage | StopMessage] = []
@event
async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None:
"""Handle a content publish event by appending the content to the buffer."""
self._message_buffer.append(ChatMessage(content=message.content, request_pause=message.request_pause))
if not isinstance(message.agent_message, TextMessage | MultiModalMessage | StopMessage):
raise ValueError(
f"Unexpected message type: {type(message.agent_message)}. "
"The message must be a text, multimodal, or stop message."
)
self._message_buffer.append(message.agent_message)
@event
async def handle_content_request(self, message: ContentRequestEvent, ctx: MessageContext) -> None:
"""Handle a content request event by passing the messages in the buffer
to the delegate agent and publish the response."""
response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
# TODO: handle tool call messages.
assert isinstance(response, TextMessage | MultiModalMessage | StopMessage)
self._message_buffer.clear()
await self.publish_message(
ContentPublishEvent(content=response.content, request_pause=response.request_pause),
topic_id=DefaultTopicId(type=self._parent_topic_type),
ContentPublishEvent(agent_message=response), topic_id=DefaultTopicId(type=self._parent_topic_type)
)

View File

@@ -3,9 +3,9 @@ from typing import List
from autogen_core.base import MessageContext, TopicId
from autogen_core.components import RoutedAgent, event
from autogen_core.components.models import AssistantMessage, UserMessage
from ._messages import ContentPublishEvent, ContentRequestEvent
from ...agents import MultiModalMessage, StopMessage, TextMessage
from ._events import ContentPublishEvent, ContentRequestEvent
class BaseGroupChatManager(RoutedAgent):
@@ -47,7 +47,7 @@ class BaseGroupChatManager(RoutedAgent):
raise ValueError("The group topic type must not be the same as the parent topic type.")
self._participant_topic_types = participant_topic_types
self._participant_descriptions = participant_descriptions
self._message_thread: List[UserMessage | AssistantMessage] = []
self._message_thread: List[TextMessage | MultiModalMessage | StopMessage] = []
@event
async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None:
@@ -61,23 +61,27 @@ class BaseGroupChatManager(RoutedAgent):
group_chat_topic_id = TopicId(type=self._group_topic_type, source=ctx.topic_id.source)
# TODO: use something else other than print.
assert isinstance(message.content, UserMessage) or isinstance(message.content, AssistantMessage)
sys.stdout.write(f"{'-'*80}\n{message.content.source}:\n{message.content.content}\n")
sys.stdout.write(f"{'-'*80}\n{message.agent_message.source}:\n{message.agent_message.content}\n")
# Process event from parent.
if ctx.topic_id.type == self._parent_topic_type:
self._message_thread.append(message.content)
self._message_thread.append(message.agent_message)
await self.publish_message(message, topic_id=group_chat_topic_id)
return
# Process event from the group chat this agent manages.
assert ctx.topic_id.type == self._group_topic_type
self._message_thread.append(message.content)
self._message_thread.append(message.agent_message)
if message.request_pause:
# If the message is a stop message, publish the last message as a TextMessage to the parent topic.
# TODO: custom handling the final message.
if isinstance(message.agent_message, StopMessage):
parent_topic_id = TopicId(type=self._parent_topic_type, source=ctx.topic_id.source)
await self.publish_message(
ContentPublishEvent(content=message.content, request_pause=True), topic_id=parent_topic_id
ContentPublishEvent(
agent_message=TextMessage(content=message.agent_message.content, source=self.metadata["type"])
),
topic_id=parent_topic_id,
)
return
@@ -100,7 +104,7 @@ class BaseGroupChatManager(RoutedAgent):
participant_topic_id = TopicId(type=speaker_topic_type, source=ctx.topic_id.source)
await self.publish_message(ContentRequestEvent(), topic_id=participant_topic_id)
async def select_speaker(self, thread: List[UserMessage | AssistantMessage]) -> str:
async def select_speaker(self, thread: List[TextMessage | MultiModalMessage | StopMessage]) -> str:
"""Select a speaker from the participants and return the
topic type of the selected speaker."""
raise NotImplementedError("Method not implemented")

View File

@@ -0,0 +1,21 @@
from pydantic import BaseModel
from ...agents import MultiModalMessage, StopMessage, TextMessage
class ContentPublishEvent(BaseModel):
    """An event for sharing some data.

    Agents receiving this event should update their internal state (e.g.,
    append to message history) with the content of the event.
    """

    # The chat message being broadcast. Note the union deliberately covers
    # only text, multimodal, and stop messages — tool-call message types are
    # excluded from inter-agent publication.
    agent_message: TextMessage | MultiModalMessage | StopMessage
    """The message published by the agent."""
class ContentRequestEvent(BaseModel):
    """An event for requesting to publish a content event.

    Upon receiving this event, the agent should publish a ContentPublishEvent.
    """

    # Pure signal event: intentionally carries no fields.
    ...

View File

@@ -1,25 +0,0 @@
from autogen_core.components.models import AssistantMessage, UserMessage
from pydantic import BaseModel
class ContentPublishEvent(BaseModel):
"""An event message for sharing some data. Agents receive this message should
update their internal state (e.g., append to message history) with the
content of the message.
"""
content: UserMessage | AssistantMessage
"""The content of the message."""
request_pause: bool
"""A flag indicating whether the current conversation session should be
paused after processing this message."""
class ContentRequestEvent(BaseModel):
"""An event message for requesting to publish a content message.
Upon receiving this message, the agent should publish a ContentPublishEvent
message.
"""
...

View File

@@ -5,12 +5,11 @@ from typing import Callable, List
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, AgentInstantiationContext, AgentRuntime, AgentType, MessageContext, TopicId
from autogen_core.components import ClosureAgent, TypeSubscription
from autogen_core.components.models import UserMessage
from ...agents import BaseChatAgent
from ...agents import BaseChatAgent, TextMessage
from .._base_team import BaseTeam, TeamRunResult
from ._base_chat_agent_container import BaseChatAgentContainer
from ._messages import ContentPublishEvent, ContentRequestEvent
from ._events import ContentPublishEvent, ContentRequestEvent
from ._round_robin_group_chat_manager import RoundRobinGroupChatManager
@@ -106,7 +105,7 @@ class RoundRobinGroupChat(BaseTeam):
team_topic_id = TopicId(type=team_topic_type, source=self._team_id)
group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
await runtime.publish_message(
ContentPublishEvent(content=UserMessage(content=task, source="user"), request_pause=False),
ContentPublishEvent(agent_message=TextMessage(content=task, source="user")),
topic_id=team_topic_id,
)
await runtime.publish_message(ContentRequestEvent(), topic_id=group_chat_manager_topic_id)
@@ -121,7 +120,7 @@
assert (
last_message is not None
and isinstance(last_message.content, UserMessage)
and isinstance(last_message.content.content, str)
and isinstance(last_message.agent_message, TextMessage)
and isinstance(last_message.agent_message.content, str)
)
return TeamRunResult(last_message.content.content)
return TeamRunResult(last_message.agent_message.content)

View File

@@ -1,7 +1,6 @@
from typing import List
from autogen_core.components.models import AssistantMessage, UserMessage
from ...agents import MultiModalMessage, StopMessage, TextMessage
from ._base_group_chat_manager import BaseGroupChatManager
@@ -23,7 +22,7 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
)
self._next_speaker_index = 0
async def select_speaker(self, thread: List[UserMessage | AssistantMessage]) -> str:
async def select_speaker(self, thread: List[TextMessage | MultiModalMessage | StopMessage]) -> str:
"""Select a speaker from the participants in a round-robin fashion."""
current_speaker_index = self._next_speaker_index
self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_topic_types)