Mirror of https://github.com/microsoft/autogen.git (synced 2025-12-26 14:38:50 +00:00)
Use agentchat message types rather than core's model client message types (#662)

* Use agentchat message types rather than core's model client message types
* Merge remote-tracking branch 'origin/main' into ekzhu-tool-use-assistant

This commit is contained in:
parent 43c85d68e0
commit 18efc2314a
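In short, agents now exchange agentchat's own message classes (TextMessage, MultiModalMessage, StopMessage, and so on) instead of wrapping the core model-client types (UserMessage/AssistantMessage) in a ChatMessage with a request_pause flag. A minimal before/after sketch, assuming the autogen_agentchat.agents import path that the package __init__ in this diff implies:

# Before this commit (for comparison only): a wrapper around core types.
# ChatMessage(content=UserMessage(content="hi", source="user"), request_pause=False)

# After this commit: agentchat's own message type, carrying the source directly.
from autogen_agentchat.agents import TextMessage

message = TextMessage(content="hi", source="user")
print(message.source, message.content)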
@@ -1,10 +1,23 @@
-from ._base_chat_agent import BaseChatAgent, ChatMessage
-from .coding._code_executor_agent import CodeExecutorAgent
-from .coding._coding_assistant_agent import CodingAssistantAgent
+from ._base_chat_agent import (
+    BaseChatAgent,
+    ChatMessage,
+    MultiModalMessage,
+    StopMessage,
+    TextMessage,
+    ToolCallMessage,
+    ToolCallResultMessage,
+)
+from ._code_executor_agent import CodeExecutorAgent
+from ._coding_assistant_agent import CodingAssistantAgent
 
 __all__ = [
     "BaseChatAgent",
     "ChatMessage",
+    "TextMessage",
+    "MultiModalMessage",
+    "ToolCallMessage",
+    "ToolCallResultMessage",
+    "StopMessage",
     "CodeExecutorAgent",
     "CodingAssistantAgent",
 ]
@@ -1,20 +1,56 @@
 from abc import ABC, abstractmethod
-from typing import Sequence
+from typing import List, Sequence
 
 from autogen_core.base import CancellationToken
-from autogen_core.components.models import AssistantMessage, UserMessage
+from autogen_core.components import FunctionCall, Image
+from autogen_core.components.models import FunctionExecutionResult
 from pydantic import BaseModel
 
 
-class ChatMessage(BaseModel):
-    """A chat message from a user or agent."""
+class BaseMessage(BaseModel):
+    """A base message."""
 
-    content: UserMessage | AssistantMessage
+    source: str
+    """The name of the agent that sent this message."""
+
+
+class TextMessage(BaseMessage):
+    """A text message."""
+
+    content: str
     """The content of the message."""
 
-    request_pause: bool
-    """A flag indicating whether the current conversation session should be
-    paused after processing this message."""
+
+class MultiModalMessage(BaseMessage):
+    """A multimodal message."""
+
+    content: List[str | Image]
+    """The content of the message."""
+
+
+class ToolCallMessage(BaseMessage):
+    """A message containing a list of function calls."""
+
+    content: List[FunctionCall]
+    """The list of function calls."""
+
+
+class ToolCallResultMessage(BaseMessage):
+    """A message containing the results of function calls."""
+
+    content: List[FunctionExecutionResult]
+    """The list of function execution results."""
+
+
+class StopMessage(BaseMessage):
+    """A message requesting stop of a conversation."""
+
+    content: str
+    """The content for the stop message."""
+
+
+ChatMessage = TextMessage | MultiModalMessage | ToolCallMessage | ToolCallResultMessage | StopMessage
+"""A message used by agents in a team."""
 
 
 class BaseChatAgent(ABC):
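A brief usage sketch of the message classes introduced above. The autogen_agentchat.agents import path follows the package __init__ in this diff; the render helper and the sample values are purely illustrative:

from autogen_agentchat.agents import MultiModalMessage, StopMessage, TextMessage

def render(message: TextMessage | MultiModalMessage | StopMessage) -> str:
    # ChatMessage is now a plain union, so ordinary isinstance checks work.
    if isinstance(message, MultiModalMessage):
        parts = [part if isinstance(part, str) else "<image>" for part in message.content]
        return f"{message.source}: {' '.join(parts)}"
    return f"{message.source}: {message.content}"

print(render(TextMessage(content="Hello!", source="assistant")))
print(render(StopMessage(content="TERMINATE", source="assistant")))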
@@ -2,9 +2,8 @@ from typing import List, Sequence
 
 from autogen_core.base import CancellationToken
 from autogen_core.components.code_executor import CodeBlock, CodeExecutor, extract_markdown_code_blocks
-from autogen_core.components.models import UserMessage
 
-from .._base_chat_agent import BaseChatAgent, ChatMessage
+from ._base_chat_agent import BaseChatAgent, ChatMessage, TextMessage
 
 
 class CodeExecutorAgent(BaseChatAgent):
@@ -21,14 +20,11 @@ class CodeExecutorAgent(BaseChatAgent):
         # Extract code blocks from the messages.
         code_blocks: List[CodeBlock] = []
         for msg in messages:
-            if isinstance(msg.content, UserMessage) and isinstance(msg.content.content, str):
-                code_blocks.extend(extract_markdown_code_blocks(msg.content.content))
+            if isinstance(msg, TextMessage):
+                code_blocks.extend(extract_markdown_code_blocks(msg.content))
         if code_blocks:
             # Execute the code blocks.
             result = await self._code_executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token)
-            return ChatMessage(content=UserMessage(content=result.output, source=self.name), request_pause=False)
+            return TextMessage(content=result.output, source=self.name)
         else:
-            return ChatMessage(
-                content=UserMessage(content="No code blocks found in the thread.", source=self.name),
-                request_pause=False,
-            )
+            return TextMessage(content="No code blocks found in the thread.", source=self.name)
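As the hunk above shows, the executor agent now pulls code blocks straight out of TextMessage.content. A minimal sketch of that extraction step in isolation, using only names that appear in this diff plus the assumed autogen_agentchat.agents import path:

from autogen_core.components.code_executor import extract_markdown_code_blocks

from autogen_agentchat.agents import TextMessage

task = TextMessage(
    content="Please run this:\n```python\nprint(2 + 2)\n```",
    source="user",
)
code_blocks = extract_markdown_code_blocks(task.content)
print([block.code for block in code_blocks])  # e.g. ['print(2 + 2)\n']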
@@ -1,9 +1,15 @@
 from typing import List, Sequence
 
 from autogen_core.base import CancellationToken
-from autogen_core.components.models import AssistantMessage, ChatCompletionClient, SystemMessage, UserMessage
+from autogen_core.components.models import (
+    AssistantMessage,
+    ChatCompletionClient,
+    LLMMessage,
+    SystemMessage,
+    UserMessage,
+)
 
-from .._base_chat_agent import BaseChatAgent, ChatMessage
+from ._base_chat_agent import BaseChatAgent, ChatMessage, MultiModalMessage, StopMessage, TextMessage
 
 
 class CodingAssistantAgent(BaseChatAgent):
@@ -27,22 +33,26 @@ Reply "TERMINATE" in the end when everything is done."""
         super().__init__(name=name, description=self.DESCRIPTION)
         self._model_client = model_client
         self._system_messages = [SystemMessage(content=self.SYSTEM_MESSAGE)]
-        self._message_thread: List[UserMessage | AssistantMessage] = []
+        self._model_context: List[LLMMessage] = []
 
     async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
-        # Add messages to the thread.
+        # Add messages to the model context and detect stopping.
        for msg in messages:
-            self._message_thread.append(msg.content)
+            if not isinstance(msg, TextMessage | MultiModalMessage | StopMessage):
+                raise ValueError(f"Unsupported message type: {type(msg)}")
+            self._model_context.append(UserMessage(content=msg.content, source=msg.source))
 
-        # Generate an inference result based on the thread.
-        llm_messages = self._system_messages + self._message_thread
+        # Generate an inference result based on the current model context.
+        llm_messages = self._system_messages + self._model_context
         result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
         assert isinstance(result.content, str)
 
-        # Add the response to the thread.
-        self._message_thread.append(AssistantMessage(content=result.content, source=self.name))
+        # Add the response to the model context.
+        self._model_context.append(AssistantMessage(content=result.content, source=self.name))
 
-        # Detect pause request.
-        request_pause = "terminate" in result.content.strip().lower()
+        # Detect stop request.
+        request_stop = "terminate" in result.content.strip().lower()
+        if request_stop:
+            return StopMessage(content=result.content, source=self.name)
 
-        return ChatMessage(content=UserMessage(content=result.content, source=self.name), request_pause=request_pause)
+        return TextMessage(content=result.content, source=self.name)
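From a caller's point of view, the flow above replaces the old request_pause flag with the type of the returned message. A hedged sketch of how a caller might branch on that result; the agent construction is omitted because its constructor is not shown in this hunk:

from autogen_core.base import CancellationToken

from autogen_agentchat.agents import StopMessage, TextMessage

async def run_step(agent, messages):
    response = await agent.on_messages(messages, CancellationToken())
    if isinstance(response, StopMessage):
        # The assistant replied with "TERMINATE"; the conversation is over.
        return None
    assert isinstance(response, TextMessage)
    return response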
@@ -3,8 +3,8 @@ from typing import List
 from autogen_core.base import MessageContext
 from autogen_core.components import DefaultTopicId, RoutedAgent, event
 
-from ...agents import BaseChatAgent, ChatMessage
-from ._messages import ContentPublishEvent, ContentRequestEvent
+from ...agents import BaseChatAgent, MultiModalMessage, StopMessage, TextMessage
+from ._events import ContentPublishEvent, ContentRequestEvent
 
 
 class BaseChatAgentContainer(RoutedAgent):
@@ -21,20 +21,26 @@ class BaseChatAgentContainer(RoutedAgent):
         super().__init__(description=agent.description)
         self._parent_topic_type = parent_topic_type
         self._agent = agent
-        self._message_buffer: List[ChatMessage] = []
+        self._message_buffer: List[TextMessage | MultiModalMessage | StopMessage] = []
 
     @event
     async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None:
         """Handle a content publish event by appending the content to the buffer."""
-        self._message_buffer.append(ChatMessage(content=message.content, request_pause=message.request_pause))
+        if not isinstance(message.agent_message, TextMessage | MultiModalMessage | StopMessage):
+            raise ValueError(
+                f"Unexpected message type: {type(message.agent_message)}. "
+                "The message must be a text, multimodal, or stop message."
+            )
+        self._message_buffer.append(message.agent_message)
 
     @event
     async def handle_content_request(self, message: ContentRequestEvent, ctx: MessageContext) -> None:
         """Handle a content request event by passing the messages in the buffer
         to the delegate agent and publish the response."""
         response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
+        # TODO: handle tool call messages.
+        assert isinstance(response, TextMessage | MultiModalMessage | StopMessage)
         self._message_buffer.clear()
         await self.publish_message(
-            ContentPublishEvent(content=response.content, request_pause=response.request_pause),
-            topic_id=DefaultTopicId(type=self._parent_topic_type),
+            ContentPublishEvent(agent_message=response), topic_id=DefaultTopicId(type=self._parent_topic_type)
         )
@@ -3,9 +3,9 @@ from typing import List
 
 from autogen_core.base import MessageContext, TopicId
 from autogen_core.components import RoutedAgent, event
-from autogen_core.components.models import AssistantMessage, UserMessage
 
-from ._messages import ContentPublishEvent, ContentRequestEvent
+from ...agents import MultiModalMessage, StopMessage, TextMessage
+from ._events import ContentPublishEvent, ContentRequestEvent
 
 
 class BaseGroupChatManager(RoutedAgent):
@@ -47,7 +47,7 @@ class BaseGroupChatManager(RoutedAgent):
             raise ValueError("The group topic type must not be the same as the parent topic type.")
         self._participant_topic_types = participant_topic_types
         self._participant_descriptions = participant_descriptions
-        self._message_thread: List[UserMessage | AssistantMessage] = []
+        self._message_thread: List[TextMessage | MultiModalMessage | StopMessage] = []
 
     @event
     async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None:
@@ -61,23 +61,27 @@ class BaseGroupChatManager(RoutedAgent):
         group_chat_topic_id = TopicId(type=self._group_topic_type, source=ctx.topic_id.source)
 
         # TODO: use something else other than print.
-        assert isinstance(message.content, UserMessage) or isinstance(message.content, AssistantMessage)
-        sys.stdout.write(f"{'-'*80}\n{message.content.source}:\n{message.content.content}\n")
+        sys.stdout.write(f"{'-'*80}\n{message.agent_message.source}:\n{message.agent_message.content}\n")
 
         # Process event from parent.
         if ctx.topic_id.type == self._parent_topic_type:
-            self._message_thread.append(message.content)
+            self._message_thread.append(message.agent_message)
             await self.publish_message(message, topic_id=group_chat_topic_id)
             return
 
         # Process event from the group chat this agent manages.
         assert ctx.topic_id.type == self._group_topic_type
-        self._message_thread.append(message.content)
+        self._message_thread.append(message.agent_message)
 
-        if message.request_pause:
+        # If the message is a stop message, publish the last message as a TextMessage to the parent topic.
+        # TODO: custom handling the final message.
+        if isinstance(message.agent_message, StopMessage):
             parent_topic_id = TopicId(type=self._parent_topic_type, source=ctx.topic_id.source)
             await self.publish_message(
-                ContentPublishEvent(content=message.content, request_pause=True), topic_id=parent_topic_id
+                ContentPublishEvent(
+                    agent_message=TextMessage(content=message.agent_message.content, source=self.metadata["type"])
+                ),
+                topic_id=parent_topic_id,
             )
             return
 
@@ -100,7 +104,7 @@ class BaseGroupChatManager(RoutedAgent):
         participant_topic_id = TopicId(type=speaker_topic_type, source=ctx.topic_id.source)
         await self.publish_message(ContentRequestEvent(), topic_id=participant_topic_id)
 
-    async def select_speaker(self, thread: List[UserMessage | AssistantMessage]) -> str:
+    async def select_speaker(self, thread: List[TextMessage | MultiModalMessage | StopMessage]) -> str:
         """Select a speaker from the participants and return the
         topic type of the selected speaker."""
         raise NotImplementedError("Method not implemented")
@@ -0,0 +1,21 @@
+from pydantic import BaseModel
+
+from ...agents import MultiModalMessage, StopMessage, TextMessage
+
+
+class ContentPublishEvent(BaseModel):
+    """An event for sharing some data. Agents receive this event should
+    update their internal state (e.g., append to message history) with the
+    content of the event.
+    """
+
+    agent_message: TextMessage | MultiModalMessage | StopMessage
+    """The message published by the agent."""
+
+
+class ContentRequestEvent(BaseModel):
+    """An event for requesting to publish a content event.
+    Upon receiving this event, the agent should publish a ContentPublishEvent.
+    """
+
+    ...
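A standalone sketch of how the new event payload is meant to be used. ContentPublishEvent is re-declared locally here purely for illustration, since the real class lives in the internal _events module added above; the autogen_agentchat.agents import path is assumed:

from pydantic import BaseModel

from autogen_agentchat.agents import MultiModalMessage, StopMessage, TextMessage

class ContentPublishEvent(BaseModel):  # local stand-in mirroring the module above
    agent_message: TextMessage | MultiModalMessage | StopMessage

# The event now carries an agentchat message directly, instead of a
# UserMessage/AssistantMessage plus a request_pause flag.
event = ContentPublishEvent(agent_message=TextMessage(content="hello", source="coder"))
print(type(event.agent_message).__name__, event.agent_message.content)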
@@ -1,25 +0,0 @@
-from autogen_core.components.models import AssistantMessage, UserMessage
-from pydantic import BaseModel
-
-
-class ContentPublishEvent(BaseModel):
-    """An event message for sharing some data. Agents receive this message should
-    update their internal state (e.g., append to message history) with the
-    content of the message.
-    """
-
-    content: UserMessage | AssistantMessage
-    """The content of the message."""
-
-    request_pause: bool
-    """A flag indicating whether the current conversation session should be
-    paused after processing this message."""
-
-
-class ContentRequestEvent(BaseModel):
-    """An event message for requesting to publish a content message.
-    Upon receiving this message, the agent should publish a ContentPublishEvent
-    message.
-    """
-
-    ...
@@ -5,12 +5,11 @@ from typing import Callable, List
 from autogen_core.application import SingleThreadedAgentRuntime
 from autogen_core.base import AgentId, AgentInstantiationContext, AgentRuntime, AgentType, MessageContext, TopicId
 from autogen_core.components import ClosureAgent, TypeSubscription
-from autogen_core.components.models import UserMessage
 
-from ...agents import BaseChatAgent
+from ...agents import BaseChatAgent, TextMessage
 from .._base_team import BaseTeam, TeamRunResult
 from ._base_chat_agent_container import BaseChatAgentContainer
-from ._messages import ContentPublishEvent, ContentRequestEvent
+from ._events import ContentPublishEvent, ContentRequestEvent
 from ._round_robin_group_chat_manager import RoundRobinGroupChatManager
 
 
@@ -106,7 +105,7 @@ class RoundRobinGroupChat(BaseTeam):
         team_topic_id = TopicId(type=team_topic_type, source=self._team_id)
         group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
         await runtime.publish_message(
-            ContentPublishEvent(content=UserMessage(content=task, source="user"), request_pause=False),
+            ContentPublishEvent(agent_message=TextMessage(content=task, source="user")),
             topic_id=team_topic_id,
         )
         await runtime.publish_message(ContentRequestEvent(), topic_id=group_chat_manager_topic_id)
@@ -121,7 +120,7 @@ class RoundRobinGroupChat(BaseTeam):
 
         assert (
             last_message is not None
-            and isinstance(last_message.content, UserMessage)
-            and isinstance(last_message.content.content, str)
+            and isinstance(last_message.agent_message, TextMessage)
+            and isinstance(last_message.agent_message.content, str)
         )
-        return TeamRunResult(last_message.content.content)
+        return TeamRunResult(last_message.agent_message.content)
@@ -1,7 +1,6 @@
 from typing import List
 
-from autogen_core.components.models import AssistantMessage, UserMessage
-
+from ...agents import MultiModalMessage, StopMessage, TextMessage
 from ._base_group_chat_manager import BaseGroupChatManager
 
 
@@ -23,7 +22,7 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
         )
         self._next_speaker_index = 0
 
-    async def select_speaker(self, thread: List[UserMessage | AssistantMessage]) -> str:
+    async def select_speaker(self, thread: List[TextMessage | MultiModalMessage | StopMessage]) -> str:
         """Select a speaker from the participants in a round-robin fashion."""
         current_speaker_index = self._next_speaker_index
         self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_topic_types)
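The selection policy itself is unchanged by this commit; only the element type of the thread changes. A minimal standalone sketch of the round-robin behaviour, with illustrative participant topic types (the class below is a stand-in, not the manager in the diff):

from typing import List

class RoundRobinSelector:
    """Illustrative stand-in for the manager's round-robin policy."""

    def __init__(self, participant_topic_types: List[str]) -> None:
        self._participant_topic_types = participant_topic_types
        self._next_speaker_index = 0

    def select_speaker(self) -> str:
        current = self._next_speaker_index
        self._next_speaker_index = (current + 1) % len(self._participant_topic_types)
        return self._participant_topic_types[current]

selector = RoundRobinSelector(["coding_assistant", "code_executor"])
print([selector.select_speaker() for _ in range(5)])
# ['coding_assistant', 'code_executor', 'coding_assistant', 'code_executor', 'coding_assistant']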