Refine types in agentchat (#4802)

* Refine types in agentchat

* import

* fix mypy
This commit is contained in:
Eric Zhu 2024-12-23 16:10:46 -08:00 committed by GitHub
parent 2c76ff9fcc
commit 150a54c4f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 429 additions and 409 deletions

View File

@ -11,6 +11,7 @@ from typing import (
List,
Mapping,
Sequence,
Tuple,
)
from autogen_core import CancellationToken, FunctionCall
@ -294,14 +295,14 @@ class AssistantAgent(BaseChatAgent):
self._is_running = False
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
"""The types of messages that the assistant agent produces."""
message_types: List[type[ChatMessage]] = [TextMessage]
if self._handoffs:
message_types.append(HandoffMessage)
if self._tools:
message_types.append(ToolCallSummaryMessage)
return message_types
return tuple(message_types)
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):

View File

@ -1,11 +1,12 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, List, Mapping, Sequence, get_args
from typing import Any, AsyncGenerator, List, Mapping, Sequence, Tuple
from autogen_core import CancellationToken
from ..base import ChatAgent, Response, TaskResult
from ..messages import (
AgentEvent,
BaseChatMessage,
ChatMessage,
TextMessage,
)
@ -36,8 +37,9 @@ class BaseChatAgent(ChatAgent, ABC):
@property
@abstractmethod
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the agent produces."""
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
"""The types of messages that the agent produces in the
:attr:`Response.chat_message` field. They must be :class:`ChatMessage` types."""
...
@abstractmethod
@ -82,7 +84,7 @@ class BaseChatAgent(ChatAgent, ABC):
async def run(
self,
*,
task: str | ChatMessage | List[ChatMessage] | None = None,
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the agent with the given task and return the result."""
@ -96,18 +98,19 @@ class BaseChatAgent(ChatAgent, ABC):
text_msg = TextMessage(content=task, source="user")
input_messages.append(text_msg)
output_messages.append(text_msg)
elif isinstance(task, list):
for msg in task:
if isinstance(msg, get_args(ChatMessage)[0]):
input_messages.append(msg)
output_messages.append(msg)
else:
raise ValueError(f"Invalid message type in list: {type(msg)}")
elif isinstance(task, get_args(ChatMessage)[0]):
elif isinstance(task, BaseChatMessage):
input_messages.append(task)
output_messages.append(task)
else:
raise ValueError(f"Invalid task type: {type(task)}")
if not task:
raise ValueError("Task list cannot be empty.")
# Task is a sequence of messages.
for msg in task:
if isinstance(msg, BaseChatMessage):
input_messages.append(msg)
output_messages.append(msg)
else:
raise ValueError(f"Invalid message type in sequence: {type(msg)}")
response = await self.on_messages(input_messages, cancellation_token)
if response.inner_messages is not None:
output_messages += response.inner_messages
@ -117,7 +120,7 @@ class BaseChatAgent(ChatAgent, ABC):
async def run_stream(
self,
*,
task: str | ChatMessage | List[ChatMessage] | None = None,
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
"""Run the agent with the given task and return a stream of messages
@ -133,20 +136,20 @@ class BaseChatAgent(ChatAgent, ABC):
input_messages.append(text_msg)
output_messages.append(text_msg)
yield text_msg
elif isinstance(task, list):
for msg in task:
if isinstance(msg, get_args(ChatMessage)[0]):
input_messages.append(msg)
output_messages.append(msg)
yield msg
else:
raise ValueError(f"Invalid message type in list: {type(msg)}")
elif isinstance(task, get_args(ChatMessage)[0]):
elif isinstance(task, BaseChatMessage):
input_messages.append(task)
output_messages.append(task)
yield task
else:
raise ValueError(f"Invalid task type: {type(task)}")
if not task:
raise ValueError("Task list cannot be empty.")
for msg in task:
if isinstance(msg, BaseChatMessage):
input_messages.append(msg)
output_messages.append(msg)
yield msg
else:
raise ValueError(f"Invalid message type in sequence: {type(msg)}")
async for message in self.on_messages_stream(input_messages, cancellation_token):
if isinstance(message, Response):
yield message.chat_message

View File

@ -1,5 +1,5 @@
import re
from typing import List, Sequence
from typing import List, Sequence, Tuple
from autogen_core import CancellationToken
from autogen_core.code_executor import CodeBlock, CodeExecutor
@ -80,9 +80,9 @@ class CodeExecutorAgent(BaseChatAgent):
self._code_executor = code_executor
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
"""The types of messages that the code executor agent produces."""
return [TextMessage]
return (TextMessage,)
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
# Extract code blocks from the messages.

View File

@ -1,4 +1,4 @@
from typing import Any, AsyncGenerator, List, Mapping, Sequence
from typing import Any, AsyncGenerator, List, Mapping, Sequence, Tuple
from autogen_core import CancellationToken
from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
@ -9,10 +9,8 @@ from autogen_agentchat.state import SocietyOfMindAgentState
from ..base import TaskResult, Team
from ..messages import (
AgentEvent,
BaseChatMessage,
ChatMessage,
HandoffMessage,
MultiModalMessage,
StopMessage,
TextMessage,
)
from ._base_chat_agent import BaseChatAgent
@ -105,8 +103,8 @@ class SocietyOfMindAgent(BaseChatAgent):
self._response_prompt = response_prompt
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
return [TextMessage]
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
return (TextMessage,)
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
# Call the stream method and collect the messages.
@ -150,7 +148,7 @@ class SocietyOfMindAgent(BaseChatAgent):
[
UserMessage(content=message.content, source=message.source)
for message in inner_messages
if isinstance(message, TextMessage | MultiModalMessage | StopMessage | HandoffMessage)
if isinstance(message, BaseChatMessage)
]
)
llm_messages.append(SystemMessage(content=self._response_prompt))

View File

@ -1,6 +1,6 @@
import asyncio
from inspect import iscoroutinefunction
from typing import Awaitable, Callable, List, Optional, Sequence, Union, cast
from typing import Awaitable, Callable, Optional, Sequence, Tuple, Union, cast
from aioconsole import ainput # type: ignore
from autogen_core import CancellationToken
@ -122,9 +122,9 @@ class UserProxyAgent(BaseChatAgent):
self._is_async = iscoroutinefunction(self.input_func)
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
"""Message types this agent can produce."""
return [TextMessage, HandoffMessage]
return (TextMessage, HandoffMessage)
def _get_latest_handoff(self, messages: Sequence[ChatMessage]) -> Optional[HandoffMessage]:
"""Find the HandoffMessage in the message sequence that addresses this agent."""

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, AsyncGenerator, List, Mapping, Protocol, Sequence, runtime_checkable
from typing import Any, AsyncGenerator, Mapping, Protocol, Sequence, Tuple, runtime_checkable
from autogen_core import CancellationToken
@ -14,8 +14,9 @@ class Response:
chat_message: ChatMessage
"""A chat message produced by the agent as the response."""
inner_messages: List[AgentEvent | ChatMessage] | None = None
"""Inner messages produced by the agent."""
inner_messages: Sequence[AgentEvent | ChatMessage] | None = None
"""Inner messages produced by the agent, they can be :class:`AgentEvent`
or :class:`ChatMessage`."""
@runtime_checkable
@ -36,8 +37,9 @@ class ChatAgent(TaskRunner, Protocol):
...
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the agent produces."""
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
"""The types of messages that the agent produces in the
:attr:`Response.chat_message` field. They must be :class:`ChatMessage` types."""
...
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import AsyncGenerator, List, Protocol, Sequence
from typing import AsyncGenerator, Protocol, Sequence
from autogen_core import CancellationToken
@ -23,11 +23,13 @@ class TaskRunner(Protocol):
async def run(
self,
*,
task: str | ChatMessage | List[ChatMessage] | None = None,
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the task and return the result.
The task can be a string, a single message, or a sequence of messages.
The runner is stateful and a subsequent call to this method will continue
from where the previous call left off. If the task is not specified,
the runner will continue with the current task."""
@ -36,12 +38,14 @@ class TaskRunner(Protocol):
def run_stream(
self,
*,
task: str | ChatMessage | List[ChatMessage] | None = None,
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
"""Run the task and produce a stream of messages and the final result
:class:`TaskResult` as the last item in the stream.
The task can be a string, a single message, or a sequence of messages.
The runner is stateful and a subsequent call to this method will continue
from where the previous call left off. If the task is not specified,
the runner will continue with the current task."""

View File

@ -2,7 +2,7 @@ import time
from typing import List, Sequence
from ..base import TerminatedException, TerminationCondition
from ..messages import AgentEvent, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
from ..messages import AgentEvent, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage
class StopMessageTermination(TerminationCondition):
@ -77,7 +77,7 @@ class TextMentionTermination(TerminationCondition):
if self._terminated:
raise TerminatedException("Termination condition has already been reached")
for message in messages:
if isinstance(message, TextMessage | StopMessage) and self._text in message.content:
if isinstance(message.content, str) and self._text in message.content:
self._terminated = True
return StopMessage(content=f"Text '{self._text}' mentioned", source="TextMentionTermination")
elif isinstance(message, MultiModalMessage):

View File

@ -1,9 +1,10 @@
"""
This module defines various message types used for agent-to-agent communication.
Each message type inherits from the BaseMessage class and includes specific fields
relevant to the type of message being sent.
Each message type inherits either from the BaseChatMessage class or BaseAgentEvent
class and includes specific fields relevant to the type of message being sent.
"""
from abc import ABC
from typing import List, Literal
from autogen_core import FunctionCall, Image
@ -12,8 +13,8 @@ from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated, deprecated
class BaseMessage(BaseModel):
"""A base message."""
class BaseMessage(BaseModel, ABC):
"""Base class for all message types."""
source: str
"""The name of the agent that sent this message."""
@ -24,7 +25,19 @@ class BaseMessage(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
class TextMessage(BaseMessage):
class BaseChatMessage(BaseMessage, ABC):
"""Base class for chat messages."""
pass
class BaseAgentEvent(BaseMessage, ABC):
"""Base class for agent events."""
pass
class TextMessage(BaseChatMessage):
"""A text message."""
content: str
@ -33,7 +46,7 @@ class TextMessage(BaseMessage):
type: Literal["TextMessage"] = "TextMessage"
class MultiModalMessage(BaseMessage):
class MultiModalMessage(BaseChatMessage):
"""A multimodal message."""
content: List[str | Image]
@ -42,7 +55,7 @@ class MultiModalMessage(BaseMessage):
type: Literal["MultiModalMessage"] = "MultiModalMessage"
class StopMessage(BaseMessage):
class StopMessage(BaseChatMessage):
"""A message requesting stop of a conversation."""
content: str
@ -51,7 +64,7 @@ class StopMessage(BaseMessage):
type: Literal["StopMessage"] = "StopMessage"
class HandoffMessage(BaseMessage):
class HandoffMessage(BaseChatMessage):
"""A message requesting handoff of a conversation to another agent."""
target: str
@ -83,7 +96,7 @@ class ToolCallResultMessage(BaseMessage):
type: Literal["ToolCallResultMessage"] = "ToolCallResultMessage"
class ToolCallRequestEvent(BaseMessage):
class ToolCallRequestEvent(BaseAgentEvent):
"""An event signaling a request to use tools."""
content: List[FunctionCall]
@ -92,7 +105,7 @@ class ToolCallRequestEvent(BaseMessage):
type: Literal["ToolCallRequestEvent"] = "ToolCallRequestEvent"
class ToolCallExecutionEvent(BaseMessage):
class ToolCallExecutionEvent(BaseAgentEvent):
"""An event signaling the execution of tool calls."""
content: List[FunctionExecutionResult]
@ -101,7 +114,7 @@ class ToolCallExecutionEvent(BaseMessage):
type: Literal["ToolCallExecutionEvent"] = "ToolCallExecutionEvent"
class ToolCallSummaryMessage(BaseMessage):
class ToolCallSummaryMessage(BaseChatMessage):
"""A message signaling the summary of tool call results."""
content: str

View File

@ -2,7 +2,7 @@ import asyncio
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, List, Mapping, get_args
from typing import Any, AsyncGenerator, Callable, List, Mapping, Sequence
from autogen_core import (
AgentId,
@ -19,7 +19,7 @@ from autogen_core._closure_agent import ClosureContext
from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import AgentEvent, ChatMessage, TextMessage
from ...messages import AgentEvent, BaseChatMessage, ChatMessage, TextMessage
from ...state import TeamState
from ._chat_agent_container import ChatAgentContainer
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
@ -172,7 +172,7 @@ class BaseGroupChat(Team, ABC):
async def run(
self,
*,
task: str | ChatMessage | List[ChatMessage] | None = None,
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the team and return the result. The base implementation uses
@ -180,7 +180,7 @@ class BaseGroupChat(Team, ABC):
Once the team is stopped, the termination condition is reset.
Args:
task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
task (str | ChatMessage | Sequence[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage`, or a list of :class:`ChatMessage`.
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
Setting the cancellation token potentially puts the team in an inconsistent state,
and it may not reset the termination condition.
@ -271,7 +271,7 @@ class BaseGroupChat(Team, ABC):
async def run_stream(
self,
*,
task: str | ChatMessage | List[ChatMessage] | None = None,
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
"""Run the team and produces a stream of messages and the final result
@ -279,7 +279,7 @@ class BaseGroupChat(Team, ABC):
team is stopped, the termination condition is reset.
Args:
task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
task (str | ChatMessage | Sequence[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage`, or a list of :class:`ChatMessage`.
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
Setting the cancellation token potentially puts the team in an inconsistent state,
and it may not reset the termination condition.
@ -368,14 +368,16 @@ class BaseGroupChat(Team, ABC):
pass
elif isinstance(task, str):
messages = [TextMessage(content=task, source="user")]
elif isinstance(task, get_args(ChatMessage)[0]):
messages = [task] # type: ignore
elif isinstance(task, list):
elif isinstance(task, BaseChatMessage):
messages = [task]
else:
if not task:
raise ValueError("Task list cannot be empty")
if not all(isinstance(msg, get_args(ChatMessage)[0]) for msg in task):
raise ValueError("All messages in task list must be valid ChatMessage types")
messages = task
raise ValueError("Task list cannot be empty.")
messages = []
for msg in task:
if not isinstance(msg, BaseChatMessage):
raise ValueError("All messages in task list must be valid ChatMessage types")
messages.append(msg)
if self._is_running:
raise ValueError("The team is already running, it cannot run again until it is stopped.")

View File

@ -8,14 +8,9 @@ from ... import TRACE_LOGGER_NAME
from ...base import ChatAgent, TerminationCondition
from ...messages import (
AgentEvent,
BaseAgentEvent,
ChatMessage,
HandoffMessage,
MultiModalMessage,
StopMessage,
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from ...state import SelectorManagerState
from ._base_group_chat import BaseGroupChat
@ -96,12 +91,12 @@ class SelectorGroupChatManager(BaseGroupChatManager):
# Construct the history of the conversation.
history_messages: List[str] = []
for msg in thread:
if isinstance(msg, ToolCallRequestEvent | ToolCallExecutionEvent):
# Ignore tool call messages.
if isinstance(msg, BaseAgentEvent):
# Ignore agent events.
continue
# The agent type must be the same as the topic type, which we use as the agent name.
message = f"{msg.source}:"
if isinstance(msg, TextMessage | StopMessage | HandoffMessage | ToolCallSummaryMessage):
if isinstance(msg.content, str):
message += f" {msg.content}"
elif isinstance(msg, MultiModalMessage):
for item in msg.content:

View File

@ -2,7 +2,7 @@ import asyncio
import json
import logging
import tempfile
from typing import Any, AsyncGenerator, List, Sequence
from typing import Any, AsyncGenerator, List, Sequence, Tuple
import pytest
from autogen_agentchat import EVENT_LOGGER_NAME
@ -75,8 +75,8 @@ class _EchoAgent(BaseChatAgent):
self._total_messages = 0
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
return [TextMessage]
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
return (TextMessage,)
@property
def total_messages(self) -> int:
@ -104,8 +104,8 @@ class _StopAgent(_EchoAgent):
self._stop_at = stop_at
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
return [TextMessage, StopMessage]
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
return (TextMessage, StopMessage)
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
self._count += 1
@ -797,8 +797,8 @@ class _HandOffAgent(BaseChatAgent):
self._next_agent = next_agent
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
return [HandoffMessage]
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
return (HandoffMessage,)
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
return Response(

View File

@ -1,7 +1,7 @@
import asyncio
import json
import logging
from typing import List, Sequence
from typing import Sequence, Tuple
import pytest
from autogen_agentchat import EVENT_LOGGER_NAME
@ -33,8 +33,8 @@ class _EchoAgent(BaseChatAgent):
self._total_messages = 0
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
return [TextMessage]
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
return (TextMessage,)
@property
def total_messages(self) -> int:

View File

@ -1,313 +1,313 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Custom Agents\n",
"\n",
"You may have agents with behaviors that do not fall into a preset. \n",
"In such cases, you can build custom agents.\n",
"\n",
"All agents in AgentChat inherit from {py:class}`~autogen_agentchat.agents.BaseChatAgent` \n",
"class and implement the following abstract methods and attributes:\n",
"\n",
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages`: The abstract method that defines the behavior of the agent in response to messages. This method is called when the agent is asked to provide a response in {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run`. It returns a {py:class}`~autogen_agentchat.base.Response` object.\n",
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_reset`: The abstract method that resets the agent to its initial state. This method is called when the agent is asked to reset itself.\n",
"- {py:attr}`~autogen_agentchat.agents.BaseChatAgent.produced_message_types`: The list of possible {py:class}`~autogen_agentchat.messages.ChatMessage` message types the agent can produce in its response.\n",
"\n",
"Optionally, you can implement the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream` method to stream messages as they are generated by the agent. If this method is not implemented, the agent\n",
"uses the default implementation of {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream`\n",
"that calls the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages` method and\n",
"yields all messages in the response."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## CountDownAgent\n",
"\n",
"In this example, we create a simple agent that counts down from a given number to zero,\n",
"and produces a stream of messages with the current count."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3...\n",
"2...\n",
"1...\n",
"Done!\n"
]
}
],
"source": [
"from typing import AsyncGenerator, List, Sequence\n",
"\n",
"from autogen_agentchat.agents import BaseChatAgent\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage\n",
"from autogen_core import CancellationToken\n",
"\n",
"\n",
"class CountDownAgent(BaseChatAgent):\n",
" def __init__(self, name: str, count: int = 3):\n",
" super().__init__(name, \"A simple agent that counts down.\")\n",
" self._count = count\n",
"\n",
" @property\n",
" def produced_message_types(self) -> List[type[ChatMessage]]:\n",
" return [TextMessage]\n",
"\n",
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
" # Calls the on_messages_stream.\n",
" response: Response | None = None\n",
" async for message in self.on_messages_stream(messages, cancellation_token):\n",
" if isinstance(message, Response):\n",
" response = message\n",
" assert response is not None\n",
" return response\n",
"\n",
" async def on_messages_stream(\n",
" self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken\n",
" ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:\n",
" inner_messages: List[AgentEvent | ChatMessage] = []\n",
" for i in range(self._count, 0, -1):\n",
" msg = TextMessage(content=f\"{i}...\", source=self.name)\n",
" inner_messages.append(msg)\n",
" yield msg\n",
" # The response is returned at the end of the stream.\n",
" # It contains the final message and all the inner messages.\n",
" yield Response(chat_message=TextMessage(content=\"Done!\", source=self.name), inner_messages=inner_messages)\n",
"\n",
" async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
" pass\n",
"\n",
"\n",
"async def run_countdown_agent() -> None:\n",
" # Create a countdown agent.\n",
" countdown_agent = CountDownAgent(\"countdown\")\n",
"\n",
" # Run the agent with a given task and stream the response.\n",
" async for message in countdown_agent.on_messages_stream([], CancellationToken()):\n",
" if isinstance(message, Response):\n",
" print(message.chat_message.content)\n",
" else:\n",
" print(message.content)\n",
"\n",
"\n",
"# Use asyncio.run(run_countdown_agent()) when running in a script.\n",
"await run_countdown_agent()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ArithmeticAgent\n",
"\n",
"In this example, we create an agent class that can perform simple arithmetic operations\n",
"on a given integer. Then, we will use different instances of this agent class\n",
"in a {py:class}`~autogen_agentchat.teams.SelectorGroupChat`\n",
"to transform a given integer into another integer by applying a sequence of arithmetic operations.\n",
"\n",
"The `ArithmeticAgent` class takes an `operator_func` that takes an integer and returns an integer,\n",
"after applying an arithmetic operation to the integer.\n",
"In its `on_messages` method, it applies the `operator_func` to the integer in the input message,\n",
"and returns a response with the result."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable, List, Sequence\n",
"\n",
"from autogen_agentchat.agents import BaseChatAgent\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.conditions import MaxMessageTermination\n",
"from autogen_agentchat.messages import ChatMessage\n",
"from autogen_agentchat.teams import SelectorGroupChat\n",
"from autogen_agentchat.ui import Console\n",
"from autogen_core import CancellationToken\n",
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
"\n",
"\n",
"class ArithmeticAgent(BaseChatAgent):\n",
" def __init__(self, name: str, description: str, operator_func: Callable[[int], int]) -> None:\n",
" super().__init__(name, description=description)\n",
" self._operator_func = operator_func\n",
" self._message_history: List[ChatMessage] = []\n",
"\n",
" @property\n",
" def produced_message_types(self) -> List[type[ChatMessage]]:\n",
" return [TextMessage]\n",
"\n",
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
" # Update the message history.\n",
" # NOTE: it is possible the messages is an empty list, which means the agent was selected previously.\n",
" self._message_history.extend(messages)\n",
" # Parse the number in the last message.\n",
" assert isinstance(self._message_history[-1], TextMessage)\n",
" number = int(self._message_history[-1].content)\n",
" # Apply the operator function to the number.\n",
" result = self._operator_func(number)\n",
" # Create a new message with the result.\n",
" response_message = TextMessage(content=str(result), source=self.name)\n",
" # Update the message history.\n",
" self._message_history.append(response_message)\n",
" # Return the response.\n",
" return Response(chat_message=response_message)\n",
"\n",
" async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
" pass"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{note}\n",
"The `on_messages` method may be called with an empty list of messages, in which\n",
"case it means the agent was called previously and is now being called again,\n",
"without any new messages from the caller. So it is important to keep a history\n",
"of the previous messages received by the agent, and use that history to generate\n",
"the response.\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can create a {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with 5 instances of `ArithmeticAgent`:\n",
"\n",
"- one that adds 1 to the input integer,\n",
"- one that subtracts 1 from the input integer,\n",
"- one that multiplies the input integer by 2,\n",
"- one that divides the input integer by 2 and rounds down to the nearest integer, and\n",
"- one that returns the input integer unchanged.\n",
"\n",
"We then create a {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with these agents,\n",
"and set the appropriate selector settings:\n",
"\n",
"- allow the same agent to be selected consecutively to allow for repeated operations, and\n",
"- customize the selector prompt to tailor the model's response to the specific task."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- user ----------\n",
"Apply the operations to turn the given number into 25.\n",
"---------- user ----------\n",
"10\n",
"---------- multiply_agent ----------\n",
"20\n",
"---------- add_agent ----------\n",
"21\n",
"---------- multiply_agent ----------\n",
"42\n",
"---------- divide_agent ----------\n",
"21\n",
"---------- add_agent ----------\n",
"22\n",
"---------- add_agent ----------\n",
"23\n",
"---------- add_agent ----------\n",
"24\n",
"---------- add_agent ----------\n",
"25\n",
"---------- Summary ----------\n",
"Number of messages: 10\n",
"Finish reason: Maximum number of messages 10 reached, current message count: 10\n",
"Total prompt tokens: 0\n",
"Total completion tokens: 0\n",
"Duration: 2.40 seconds\n"
]
}
],
"source": [
"async def run_number_agents() -> None:\n",
" # Create agents for number operations.\n",
" add_agent = ArithmeticAgent(\"add_agent\", \"Adds 1 to the number.\", lambda x: x + 1)\n",
" multiply_agent = ArithmeticAgent(\"multiply_agent\", \"Multiplies the number by 2.\", lambda x: x * 2)\n",
" subtract_agent = ArithmeticAgent(\"subtract_agent\", \"Subtracts 1 from the number.\", lambda x: x - 1)\n",
" divide_agent = ArithmeticAgent(\"divide_agent\", \"Divides the number by 2 and rounds down.\", lambda x: x // 2)\n",
" identity_agent = ArithmeticAgent(\"identity_agent\", \"Returns the number as is.\", lambda x: x)\n",
"\n",
" # The termination condition is to stop after 10 messages.\n",
" termination_condition = MaxMessageTermination(10)\n",
"\n",
" # Create a selector group chat.\n",
" selector_group_chat = SelectorGroupChat(\n",
" [add_agent, multiply_agent, subtract_agent, divide_agent, identity_agent],\n",
" model_client=OpenAIChatCompletionClient(model=\"gpt-4o\"),\n",
" termination_condition=termination_condition,\n",
" allow_repeated_speaker=True, # Allow the same agent to speak multiple times, necessary for this task.\n",
" selector_prompt=(\n",
" \"Available roles:\\n{roles}\\nTheir job descriptions:\\n{participants}\\n\"\n",
" \"Current conversation history:\\n{history}\\n\"\n",
" \"Please select the most appropriate role for the next message, and only return the role name.\"\n",
" ),\n",
" )\n",
"\n",
" # Run the selector group chat with a given task and stream the response.\n",
" task: List[ChatMessage] = [\n",
" TextMessage(content=\"Apply the operations to turn the given number into 25.\", source=\"user\"),\n",
" TextMessage(content=\"10\", source=\"user\"),\n",
" ]\n",
" stream = selector_group_chat.run_stream(task=task)\n",
" await Console(stream)\n",
"\n",
"\n",
"# Use asyncio.run(run_number_agents()) when running in a script.\n",
"await run_number_agents()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the output, we can see that the agents have successfully transformed the input integer\n",
"from 10 to 25 by choosing appropriate agents that apply the arithmetic operations in sequence."
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Custom Agents\n",
"\n",
"You may have agents with behaviors that do not fall into a preset. \n",
"In such cases, you can build custom agents.\n",
"\n",
"All agents in AgentChat inherit from {py:class}`~autogen_agentchat.agents.BaseChatAgent` \n",
"class and implement the following abstract methods and attributes:\n",
"\n",
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages`: The abstract method that defines the behavior of the agent in response to messages. This method is called when the agent is asked to provide a response in {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run`. It returns a {py:class}`~autogen_agentchat.base.Response` object.\n",
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_reset`: The abstract method that resets the agent to its initial state. This method is called when the agent is asked to reset itself.\n",
"- {py:attr}`~autogen_agentchat.agents.BaseChatAgent.produced_message_types`: The possible {py:class}`~autogen_agentchat.messages.ChatMessage` message types the agent can produce in its response.\n",
"\n",
"Optionally, you can implement the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream` method to stream messages as they are generated by the agent. If this method is not implemented, the agent\n",
"uses the default implementation of {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream`\n",
"that calls the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages` method and\n",
"yields all messages in the response."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## CountDownAgent\n",
"\n",
"In this example, we create a simple agent that counts down from a given number to zero,\n",
"and produces a stream of messages with the current count."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3...\n",
"2...\n",
"1...\n",
"Done!\n"
]
}
],
"source": [
"from typing import AsyncGenerator, List, Sequence, Tuple\n",
"\n",
"from autogen_agentchat.agents import BaseChatAgent\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage\n",
"from autogen_core import CancellationToken\n",
"\n",
"\n",
"class CountDownAgent(BaseChatAgent):\n",
" def __init__(self, name: str, count: int = 3):\n",
" super().__init__(name, \"A simple agent that counts down.\")\n",
" self._count = count\n",
"\n",
" @property\n",
" def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:\n",
" return (TextMessage,)\n",
"\n",
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
" # Calls the on_messages_stream.\n",
" response: Response | None = None\n",
" async for message in self.on_messages_stream(messages, cancellation_token):\n",
" if isinstance(message, Response):\n",
" response = message\n",
" assert response is not None\n",
" return response\n",
"\n",
" async def on_messages_stream(\n",
" self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken\n",
" ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:\n",
" inner_messages: List[AgentEvent | ChatMessage] = []\n",
" for i in range(self._count, 0, -1):\n",
" msg = TextMessage(content=f\"{i}...\", source=self.name)\n",
" inner_messages.append(msg)\n",
" yield msg\n",
" # The response is returned at the end of the stream.\n",
" # It contains the final message and all the inner messages.\n",
" yield Response(chat_message=TextMessage(content=\"Done!\", source=self.name), inner_messages=inner_messages)\n",
"\n",
" async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
" pass\n",
"\n",
"\n",
"async def run_countdown_agent() -> None:\n",
" # Create a countdown agent.\n",
" countdown_agent = CountDownAgent(\"countdown\")\n",
"\n",
" # Run the agent with a given task and stream the response.\n",
" async for message in countdown_agent.on_messages_stream([], CancellationToken()):\n",
" if isinstance(message, Response):\n",
" print(message.chat_message.content)\n",
" else:\n",
" print(message.content)\n",
"\n",
"\n",
"# Use asyncio.run(run_countdown_agent()) when running in a script.\n",
"await run_countdown_agent()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ArithmeticAgent\n",
"\n",
"In this example, we create an agent class that can perform simple arithmetic operations\n",
"on a given integer. Then, we will use different instances of this agent class\n",
"in a {py:class}`~autogen_agentchat.teams.SelectorGroupChat`\n",
"to transform a given integer into another integer by applying a sequence of arithmetic operations.\n",
"\n",
"The `ArithmeticAgent` class takes an `operator_func` that takes an integer and returns an integer,\n",
"after applying an arithmetic operation to the integer.\n",
"In its `on_messages` method, it applies the `operator_func` to the integer in the input message,\n",
"and returns a response with the result."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable, List, Sequence, Tuple\n",
"\n",
"from autogen_agentchat.agents import BaseChatAgent\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.conditions import MaxMessageTermination\n",
"from autogen_agentchat.messages import ChatMessage, TextMessage\n",
"from autogen_agentchat.teams import SelectorGroupChat\n",
"from autogen_agentchat.ui import Console\n",
"from autogen_core import CancellationToken\n",
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
"\n",
"\n",
"class ArithmeticAgent(BaseChatAgent):\n",
" def __init__(self, name: str, description: str, operator_func: Callable[[int], int]) -> None:\n",
" super().__init__(name, description=description)\n",
" self._operator_func = operator_func\n",
" self._message_history: List[ChatMessage] = []\n",
"\n",
" @property\n",
" def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:\n",
" return (TextMessage,)\n",
"\n",
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
" # Update the message history.\n",
" # NOTE: it is possible the messages is an empty list, which means the agent was selected previously.\n",
" self._message_history.extend(messages)\n",
" # Parse the number in the last message.\n",
" assert isinstance(self._message_history[-1], TextMessage)\n",
" number = int(self._message_history[-1].content)\n",
" # Apply the operator function to the number.\n",
" result = self._operator_func(number)\n",
" # Create a new message with the result.\n",
" response_message = TextMessage(content=str(result), source=self.name)\n",
" # Update the message history.\n",
" self._message_history.append(response_message)\n",
" # Return the response.\n",
" return Response(chat_message=response_message)\n",
"\n",
" async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
" pass"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{note}\n",
"The `on_messages` method may be called with an empty list of messages, in which\n",
"case it means the agent was called previously and is now being called again,\n",
"without any new messages from the caller. So it is important to keep a history\n",
"of the previous messages received by the agent, and use that history to generate\n",
"the response.\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can create a {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with 5 instances of `ArithmeticAgent`:\n",
"\n",
"- one that adds 1 to the input integer,\n",
"- one that subtracts 1 from the input integer,\n",
"- one that multiplies the input integer by 2,\n",
"- one that divides the input integer by 2 and rounds down to the nearest integer, and\n",
"- one that returns the input integer unchanged.\n",
"\n",
"We then create a {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with these agents,\n",
"and set the appropriate selector settings:\n",
"\n",
"- allow the same agent to be selected consecutively to allow for repeated operations, and\n",
"- customize the selector prompt to tailor the model's response to the specific task."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- user ----------\n",
"Apply the operations to turn the given number into 25.\n",
"---------- user ----------\n",
"10\n",
"---------- multiply_agent ----------\n",
"20\n",
"---------- add_agent ----------\n",
"21\n",
"---------- multiply_agent ----------\n",
"42\n",
"---------- divide_agent ----------\n",
"21\n",
"---------- add_agent ----------\n",
"22\n",
"---------- add_agent ----------\n",
"23\n",
"---------- add_agent ----------\n",
"24\n",
"---------- add_agent ----------\n",
"25\n",
"---------- Summary ----------\n",
"Number of messages: 10\n",
"Finish reason: Maximum number of messages 10 reached, current message count: 10\n",
"Total prompt tokens: 0\n",
"Total completion tokens: 0\n",
"Duration: 2.40 seconds\n"
]
}
],
"source": [
"async def run_number_agents() -> None:\n",
" # Create agents for number operations.\n",
" add_agent = ArithmeticAgent(\"add_agent\", \"Adds 1 to the number.\", lambda x: x + 1)\n",
" multiply_agent = ArithmeticAgent(\"multiply_agent\", \"Multiplies the number by 2.\", lambda x: x * 2)\n",
" subtract_agent = ArithmeticAgent(\"subtract_agent\", \"Subtracts 1 from the number.\", lambda x: x - 1)\n",
" divide_agent = ArithmeticAgent(\"divide_agent\", \"Divides the number by 2 and rounds down.\", lambda x: x // 2)\n",
" identity_agent = ArithmeticAgent(\"identity_agent\", \"Returns the number as is.\", lambda x: x)\n",
"\n",
" # The termination condition is to stop after 10 messages.\n",
" termination_condition = MaxMessageTermination(10)\n",
"\n",
" # Create a selector group chat.\n",
" selector_group_chat = SelectorGroupChat(\n",
" [add_agent, multiply_agent, subtract_agent, divide_agent, identity_agent],\n",
" model_client=OpenAIChatCompletionClient(model=\"gpt-4o\"),\n",
" termination_condition=termination_condition,\n",
" allow_repeated_speaker=True, # Allow the same agent to speak multiple times, necessary for this task.\n",
" selector_prompt=(\n",
" \"Available roles:\\n{roles}\\nTheir job descriptions:\\n{participants}\\n\"\n",
" \"Current conversation history:\\n{history}\\n\"\n",
" \"Please select the most appropriate role for the next message, and only return the role name.\"\n",
" ),\n",
" )\n",
"\n",
" # Run the selector group chat with a given task and stream the response.\n",
" task: List[ChatMessage] = [\n",
" TextMessage(content=\"Apply the operations to turn the given number into 25.\", source=\"user\"),\n",
" TextMessage(content=\"10\", source=\"user\"),\n",
" ]\n",
" stream = selector_group_chat.run_stream(task=task)\n",
" await Console(stream)\n",
"\n",
"\n",
"# Use asyncio.run(run_number_agents()) when running in a script.\n",
"await run_number_agents()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the output, we can see that the agents have successfully transformed the input integer\n",
"from 10 to 25 by choosing appropriate agents that apply the arithmetic operations in sequence."
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -63,8 +63,8 @@ class FileSurfer(BaseChatAgent):
self._browser = MarkdownFileBrowser(viewport_size=1024 * 5)
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
return [TextMessage]
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
return (TextMessage,)
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
for chat_message in messages:

View File

@ -15,6 +15,7 @@ from typing import (
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
@ -298,9 +299,9 @@ class OpenAIAssistantAgent(BaseChatAgent):
self._initial_message_ids = initial_message_ids
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
"""The types of messages that the assistant agent produces."""
return [TextMessage]
return (TextMessage,)
@property
def threads(self) -> AsyncThreads:

View File

@ -15,6 +15,7 @@ from typing import (
List,
Optional,
Sequence,
Tuple,
cast,
)
from urllib.parse import quote_plus
@ -321,8 +322,8 @@ class MultimodalWebSurfer(BaseChatAgent):
)
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
return [MultiModalMessage]
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
return (MultiModalMessage,)
async def on_reset(self, cancellation_token: CancellationToken) -> None:
if not self.did_lazy_init: