mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-27 15:09:41 +00:00
Refine types in agentchat (#4802)
* Refine types in agentchat * importg * fix mypy
This commit is contained in:
parent
2c76ff9fcc
commit
150a54c4f5
@ -11,6 +11,7 @@ from typing import (
|
||||
List,
|
||||
Mapping,
|
||||
Sequence,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from autogen_core import CancellationToken, FunctionCall
|
||||
@ -294,14 +295,14 @@ class AssistantAgent(BaseChatAgent):
|
||||
self._is_running = False
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
|
||||
"""The types of messages that the assistant agent produces."""
|
||||
message_types: List[type[ChatMessage]] = [TextMessage]
|
||||
if self._handoffs:
|
||||
message_types.append(HandoffMessage)
|
||||
if self._tools:
|
||||
message_types.append(ToolCallSummaryMessage)
|
||||
return message_types
|
||||
return tuple(message_types)
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
async for message in self.on_messages_stream(messages, cancellation_token):
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, List, Mapping, Sequence, get_args
|
||||
from typing import Any, AsyncGenerator, List, Mapping, Sequence, Tuple
|
||||
|
||||
from autogen_core import CancellationToken
|
||||
|
||||
from ..base import ChatAgent, Response, TaskResult
|
||||
from ..messages import (
|
||||
AgentEvent,
|
||||
BaseChatMessage,
|
||||
ChatMessage,
|
||||
TextMessage,
|
||||
)
|
||||
@ -36,8 +37,9 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
"""The types of messages that the agent produces."""
|
||||
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
|
||||
"""The types of messages that the agent produces in the
|
||||
:attr:`Response.chat_message` field. They must be :class:`ChatMessage` types."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
@ -82,7 +84,7 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | List[ChatMessage] | None = None,
|
||||
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> TaskResult:
|
||||
"""Run the agent with the given task and return the result."""
|
||||
@ -96,18 +98,19 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||
text_msg = TextMessage(content=task, source="user")
|
||||
input_messages.append(text_msg)
|
||||
output_messages.append(text_msg)
|
||||
elif isinstance(task, list):
|
||||
for msg in task:
|
||||
if isinstance(msg, get_args(ChatMessage)[0]):
|
||||
input_messages.append(msg)
|
||||
output_messages.append(msg)
|
||||
else:
|
||||
raise ValueError(f"Invalid message type in list: {type(msg)}")
|
||||
elif isinstance(task, get_args(ChatMessage)[0]):
|
||||
elif isinstance(task, BaseChatMessage):
|
||||
input_messages.append(task)
|
||||
output_messages.append(task)
|
||||
else:
|
||||
raise ValueError(f"Invalid task type: {type(task)}")
|
||||
if not task:
|
||||
raise ValueError("Task list cannot be empty.")
|
||||
# Task is a sequence of messages.
|
||||
for msg in task:
|
||||
if isinstance(msg, BaseChatMessage):
|
||||
input_messages.append(msg)
|
||||
output_messages.append(msg)
|
||||
else:
|
||||
raise ValueError(f"Invalid message type in sequence: {type(msg)}")
|
||||
response = await self.on_messages(input_messages, cancellation_token)
|
||||
if response.inner_messages is not None:
|
||||
output_messages += response.inner_messages
|
||||
@ -117,7 +120,7 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||
async def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | List[ChatMessage] | None = None,
|
||||
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
|
||||
"""Run the agent with the given task and return a stream of messages
|
||||
@ -133,20 +136,20 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||
input_messages.append(text_msg)
|
||||
output_messages.append(text_msg)
|
||||
yield text_msg
|
||||
elif isinstance(task, list):
|
||||
for msg in task:
|
||||
if isinstance(msg, get_args(ChatMessage)[0]):
|
||||
input_messages.append(msg)
|
||||
output_messages.append(msg)
|
||||
yield msg
|
||||
else:
|
||||
raise ValueError(f"Invalid message type in list: {type(msg)}")
|
||||
elif isinstance(task, get_args(ChatMessage)[0]):
|
||||
elif isinstance(task, BaseChatMessage):
|
||||
input_messages.append(task)
|
||||
output_messages.append(task)
|
||||
yield task
|
||||
else:
|
||||
raise ValueError(f"Invalid task type: {type(task)}")
|
||||
if not task:
|
||||
raise ValueError("Task list cannot be empty.")
|
||||
for msg in task:
|
||||
if isinstance(msg, BaseChatMessage):
|
||||
input_messages.append(msg)
|
||||
output_messages.append(msg)
|
||||
yield msg
|
||||
else:
|
||||
raise ValueError(f"Invalid message type in sequence: {type(msg)}")
|
||||
async for message in self.on_messages_stream(input_messages, cancellation_token):
|
||||
if isinstance(message, Response):
|
||||
yield message.chat_message
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import re
|
||||
from typing import List, Sequence
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
from autogen_core import CancellationToken
|
||||
from autogen_core.code_executor import CodeBlock, CodeExecutor
|
||||
@ -80,9 +80,9 @@ class CodeExecutorAgent(BaseChatAgent):
|
||||
self._code_executor = code_executor
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
|
||||
"""The types of messages that the code executor agent produces."""
|
||||
return [TextMessage]
|
||||
return (TextMessage,)
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
# Extract code blocks from the messages.
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, AsyncGenerator, List, Mapping, Sequence
|
||||
from typing import Any, AsyncGenerator, List, Mapping, Sequence, Tuple
|
||||
|
||||
from autogen_core import CancellationToken
|
||||
from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
|
||||
@ -9,10 +9,8 @@ from autogen_agentchat.state import SocietyOfMindAgentState
|
||||
from ..base import TaskResult, Team
|
||||
from ..messages import (
|
||||
AgentEvent,
|
||||
BaseChatMessage,
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
)
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
@ -105,8 +103,8 @@ class SocietyOfMindAgent(BaseChatAgent):
|
||||
self._response_prompt = response_prompt
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
return [TextMessage]
|
||||
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
|
||||
return (TextMessage,)
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
# Call the stream method and collect the messages.
|
||||
@ -150,7 +148,7 @@ class SocietyOfMindAgent(BaseChatAgent):
|
||||
[
|
||||
UserMessage(content=message.content, source=message.source)
|
||||
for message in inner_messages
|
||||
if isinstance(message, TextMessage | MultiModalMessage | StopMessage | HandoffMessage)
|
||||
if isinstance(message, BaseChatMessage)
|
||||
]
|
||||
)
|
||||
llm_messages.append(SystemMessage(content=self._response_prompt))
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from inspect import iscoroutinefunction
|
||||
from typing import Awaitable, Callable, List, Optional, Sequence, Union, cast
|
||||
from typing import Awaitable, Callable, Optional, Sequence, Tuple, Union, cast
|
||||
|
||||
from aioconsole import ainput # type: ignore
|
||||
from autogen_core import CancellationToken
|
||||
@ -122,9 +122,9 @@ class UserProxyAgent(BaseChatAgent):
|
||||
self._is_async = iscoroutinefunction(self.input_func)
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
|
||||
"""Message types this agent can produce."""
|
||||
return [TextMessage, HandoffMessage]
|
||||
return (TextMessage, HandoffMessage)
|
||||
|
||||
def _get_latest_handoff(self, messages: Sequence[ChatMessage]) -> Optional[HandoffMessage]:
|
||||
"""Find the HandoffMessage in the message sequence that addresses this agent."""
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, List, Mapping, Protocol, Sequence, runtime_checkable
|
||||
from typing import Any, AsyncGenerator, Mapping, Protocol, Sequence, Tuple, runtime_checkable
|
||||
|
||||
from autogen_core import CancellationToken
|
||||
|
||||
@ -14,8 +14,9 @@ class Response:
|
||||
chat_message: ChatMessage
|
||||
"""A chat message produced by the agent as the response."""
|
||||
|
||||
inner_messages: List[AgentEvent | ChatMessage] | None = None
|
||||
"""Inner messages produced by the agent."""
|
||||
inner_messages: Sequence[AgentEvent | ChatMessage] | None = None
|
||||
"""Inner messages produced by the agent, they can be :class:`AgentEvent`
|
||||
or :class:`ChatMessage`."""
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@ -36,8 +37,9 @@ class ChatAgent(TaskRunner, Protocol):
|
||||
...
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
"""The types of messages that the agent produces."""
|
||||
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
|
||||
"""The types of messages that the agent produces in the
|
||||
:attr:`Response.chat_message` field. They must be :class:`ChatMessage` types."""
|
||||
...
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator, List, Protocol, Sequence
|
||||
from typing import AsyncGenerator, Protocol, Sequence
|
||||
|
||||
from autogen_core import CancellationToken
|
||||
|
||||
@ -23,11 +23,13 @@ class TaskRunner(Protocol):
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | List[ChatMessage] | None = None,
|
||||
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> TaskResult:
|
||||
"""Run the task and return the result.
|
||||
|
||||
The task can be a string, a single message, or a sequence of messages.
|
||||
|
||||
The runner is stateful and a subsequent call to this method will continue
|
||||
from where the previous call left off. If the task is not specified,
|
||||
the runner will continue with the current task."""
|
||||
@ -36,12 +38,14 @@ class TaskRunner(Protocol):
|
||||
def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | List[ChatMessage] | None = None,
|
||||
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
|
||||
"""Run the task and produces a stream of messages and the final result
|
||||
:class:`TaskResult` as the last item in the stream.
|
||||
|
||||
The task can be a string, a single message, or a sequence of messages.
|
||||
|
||||
The runner is stateful and a subsequent call to this method will continue
|
||||
from where the previous call left off. If the task is not specified,
|
||||
the runner will continue with the current task."""
|
||||
|
||||
@ -2,7 +2,7 @@ import time
|
||||
from typing import List, Sequence
|
||||
|
||||
from ..base import TerminatedException, TerminationCondition
|
||||
from ..messages import AgentEvent, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
|
||||
from ..messages import AgentEvent, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage
|
||||
|
||||
|
||||
class StopMessageTermination(TerminationCondition):
|
||||
@ -77,7 +77,7 @@ class TextMentionTermination(TerminationCondition):
|
||||
if self._terminated:
|
||||
raise TerminatedException("Termination condition has already been reached")
|
||||
for message in messages:
|
||||
if isinstance(message, TextMessage | StopMessage) and self._text in message.content:
|
||||
if isinstance(message.content, str) and self._text in message.content:
|
||||
self._terminated = True
|
||||
return StopMessage(content=f"Text '{self._text}' mentioned", source="TextMentionTermination")
|
||||
elif isinstance(message, MultiModalMessage):
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
"""
|
||||
This module defines various message types used for agent-to-agent communication.
|
||||
Each message type inherits from the BaseMessage class and includes specific fields
|
||||
relevant to the type of message being sent.
|
||||
Each message type inherits either from the BaseChatMessage class or BaseAgentEvent
|
||||
class and includes specific fields relevant to the type of message being sent.
|
||||
"""
|
||||
|
||||
from abc import ABC
|
||||
from typing import List, Literal
|
||||
|
||||
from autogen_core import FunctionCall, Image
|
||||
@ -12,8 +13,8 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Annotated, deprecated
|
||||
|
||||
|
||||
class BaseMessage(BaseModel):
|
||||
"""A base message."""
|
||||
class BaseMessage(BaseModel, ABC):
|
||||
"""Base class for all message types."""
|
||||
|
||||
source: str
|
||||
"""The name of the agent that sent this message."""
|
||||
@ -24,7 +25,19 @@ class BaseMessage(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class TextMessage(BaseMessage):
|
||||
class BaseChatMessage(BaseMessage, ABC):
|
||||
"""Base class for chat messages."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BaseAgentEvent(BaseMessage, ABC):
|
||||
"""Base class for agent events."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TextMessage(BaseChatMessage):
|
||||
"""A text message."""
|
||||
|
||||
content: str
|
||||
@ -33,7 +46,7 @@ class TextMessage(BaseMessage):
|
||||
type: Literal["TextMessage"] = "TextMessage"
|
||||
|
||||
|
||||
class MultiModalMessage(BaseMessage):
|
||||
class MultiModalMessage(BaseChatMessage):
|
||||
"""A multimodal message."""
|
||||
|
||||
content: List[str | Image]
|
||||
@ -42,7 +55,7 @@ class MultiModalMessage(BaseMessage):
|
||||
type: Literal["MultiModalMessage"] = "MultiModalMessage"
|
||||
|
||||
|
||||
class StopMessage(BaseMessage):
|
||||
class StopMessage(BaseChatMessage):
|
||||
"""A message requesting stop of a conversation."""
|
||||
|
||||
content: str
|
||||
@ -51,7 +64,7 @@ class StopMessage(BaseMessage):
|
||||
type: Literal["StopMessage"] = "StopMessage"
|
||||
|
||||
|
||||
class HandoffMessage(BaseMessage):
|
||||
class HandoffMessage(BaseChatMessage):
|
||||
"""A message requesting handoff of a conversation to another agent."""
|
||||
|
||||
target: str
|
||||
@ -83,7 +96,7 @@ class ToolCallResultMessage(BaseMessage):
|
||||
type: Literal["ToolCallResultMessage"] = "ToolCallResultMessage"
|
||||
|
||||
|
||||
class ToolCallRequestEvent(BaseMessage):
|
||||
class ToolCallRequestEvent(BaseAgentEvent):
|
||||
"""An event signaling a request to use tools."""
|
||||
|
||||
content: List[FunctionCall]
|
||||
@ -92,7 +105,7 @@ class ToolCallRequestEvent(BaseMessage):
|
||||
type: Literal["ToolCallRequestEvent"] = "ToolCallRequestEvent"
|
||||
|
||||
|
||||
class ToolCallExecutionEvent(BaseMessage):
|
||||
class ToolCallExecutionEvent(BaseAgentEvent):
|
||||
"""An event signaling the execution of tool calls."""
|
||||
|
||||
content: List[FunctionExecutionResult]
|
||||
@ -101,7 +114,7 @@ class ToolCallExecutionEvent(BaseMessage):
|
||||
type: Literal["ToolCallExecutionEvent"] = "ToolCallExecutionEvent"
|
||||
|
||||
|
||||
class ToolCallSummaryMessage(BaseMessage):
|
||||
class ToolCallSummaryMessage(BaseChatMessage):
|
||||
"""A message signaling the summary of tool call results."""
|
||||
|
||||
content: str
|
||||
|
||||
@ -2,7 +2,7 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Callable, List, Mapping, get_args
|
||||
from typing import Any, AsyncGenerator, Callable, List, Mapping, Sequence
|
||||
|
||||
from autogen_core import (
|
||||
AgentId,
|
||||
@ -19,7 +19,7 @@ from autogen_core._closure_agent import ClosureContext
|
||||
|
||||
from ... import EVENT_LOGGER_NAME
|
||||
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
|
||||
from ...messages import AgentEvent, ChatMessage, TextMessage
|
||||
from ...messages import AgentEvent, BaseChatMessage, ChatMessage, TextMessage
|
||||
from ...state import TeamState
|
||||
from ._chat_agent_container import ChatAgentContainer
|
||||
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
|
||||
@ -172,7 +172,7 @@ class BaseGroupChat(Team, ABC):
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | List[ChatMessage] | None = None,
|
||||
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> TaskResult:
|
||||
"""Run the team and return the result. The base implementation uses
|
||||
@ -180,7 +180,7 @@ class BaseGroupChat(Team, ABC):
|
||||
Once the team is stopped, the termination condition is reset.
|
||||
|
||||
Args:
|
||||
task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
|
||||
task (str | ChatMessage | Sequence[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
|
||||
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
|
||||
Setting the cancellation token potentially put the team in an inconsistent state,
|
||||
and it may not reset the termination condition.
|
||||
@ -271,7 +271,7 @@ class BaseGroupChat(Team, ABC):
|
||||
async def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | List[ChatMessage] | None = None,
|
||||
task: str | ChatMessage | Sequence[ChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
|
||||
"""Run the team and produces a stream of messages and the final result
|
||||
@ -279,7 +279,7 @@ class BaseGroupChat(Team, ABC):
|
||||
team is stopped, the termination condition is reset.
|
||||
|
||||
Args:
|
||||
task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
|
||||
task (str | ChatMessage | Sequence[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
|
||||
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
|
||||
Setting the cancellation token potentially put the team in an inconsistent state,
|
||||
and it may not reset the termination condition.
|
||||
@ -368,14 +368,16 @@ class BaseGroupChat(Team, ABC):
|
||||
pass
|
||||
elif isinstance(task, str):
|
||||
messages = [TextMessage(content=task, source="user")]
|
||||
elif isinstance(task, get_args(ChatMessage)[0]):
|
||||
messages = [task] # type: ignore
|
||||
elif isinstance(task, list):
|
||||
elif isinstance(task, BaseChatMessage):
|
||||
messages = [task]
|
||||
else:
|
||||
if not task:
|
||||
raise ValueError("Task list cannot be empty")
|
||||
if not all(isinstance(msg, get_args(ChatMessage)[0]) for msg in task):
|
||||
raise ValueError("All messages in task list must be valid ChatMessage types")
|
||||
messages = task
|
||||
raise ValueError("Task list cannot be empty.")
|
||||
messages = []
|
||||
for msg in task:
|
||||
if not isinstance(msg, BaseChatMessage):
|
||||
raise ValueError("All messages in task list must be valid ChatMessage types")
|
||||
messages.append(msg)
|
||||
|
||||
if self._is_running:
|
||||
raise ValueError("The team is already running, it cannot run again until it is stopped.")
|
||||
|
||||
@ -8,14 +8,9 @@ from ... import TRACE_LOGGER_NAME
|
||||
from ...base import ChatAgent, TerminationCondition
|
||||
from ...messages import (
|
||||
AgentEvent,
|
||||
BaseAgentEvent,
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
ToolCallExecutionEvent,
|
||||
ToolCallRequestEvent,
|
||||
ToolCallSummaryMessage,
|
||||
)
|
||||
from ...state import SelectorManagerState
|
||||
from ._base_group_chat import BaseGroupChat
|
||||
@ -96,12 +91,12 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
# Construct the history of the conversation.
|
||||
history_messages: List[str] = []
|
||||
for msg in thread:
|
||||
if isinstance(msg, ToolCallRequestEvent | ToolCallExecutionEvent):
|
||||
# Ignore tool call messages.
|
||||
if isinstance(msg, BaseAgentEvent):
|
||||
# Ignore agent events.
|
||||
continue
|
||||
# The agent type must be the same as the topic type, which we use as the agent name.
|
||||
message = f"{msg.source}:"
|
||||
if isinstance(msg, TextMessage | StopMessage | HandoffMessage | ToolCallSummaryMessage):
|
||||
if isinstance(msg.content, str):
|
||||
message += f" {msg.content}"
|
||||
elif isinstance(msg, MultiModalMessage):
|
||||
for item in msg.content:
|
||||
|
||||
@ -2,7 +2,7 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from typing import Any, AsyncGenerator, List, Sequence
|
||||
from typing import Any, AsyncGenerator, List, Sequence, Tuple
|
||||
|
||||
import pytest
|
||||
from autogen_agentchat import EVENT_LOGGER_NAME
|
||||
@ -75,8 +75,8 @@ class _EchoAgent(BaseChatAgent):
|
||||
self._total_messages = 0
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
return [TextMessage]
|
||||
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
|
||||
return (TextMessage,)
|
||||
|
||||
@property
|
||||
def total_messages(self) -> int:
|
||||
@ -104,8 +104,8 @@ class _StopAgent(_EchoAgent):
|
||||
self._stop_at = stop_at
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
return [TextMessage, StopMessage]
|
||||
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
|
||||
return (TextMessage, StopMessage)
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
self._count += 1
|
||||
@ -797,8 +797,8 @@ class _HandOffAgent(BaseChatAgent):
|
||||
self._next_agent = next_agent
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
return [HandoffMessage]
|
||||
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
|
||||
return (HandoffMessage,)
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
return Response(
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Sequence
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
import pytest
|
||||
from autogen_agentchat import EVENT_LOGGER_NAME
|
||||
@ -33,8 +33,8 @@ class _EchoAgent(BaseChatAgent):
|
||||
self._total_messages = 0
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
return [TextMessage]
|
||||
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
|
||||
return (TextMessage,)
|
||||
|
||||
@property
|
||||
def total_messages(self) -> int:
|
||||
|
||||
@ -1,313 +1,313 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Custom Agents\n",
|
||||
"\n",
|
||||
"You may have agents with behaviors that do not fall into a preset. \n",
|
||||
"In such cases, you can build custom agents.\n",
|
||||
"\n",
|
||||
"All agents in AgentChat inherit from {py:class}`~autogen_agentchat.agents.BaseChatAgent` \n",
|
||||
"class and implement the following abstract methods and attributes:\n",
|
||||
"\n",
|
||||
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages`: The abstract method that defines the behavior of the agent in response to messages. This method is called when the agent is asked to provide a response in {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run`. It returns a {py:class}`~autogen_agentchat.base.Response` object.\n",
|
||||
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_reset`: The abstract method that resets the agent to its initial state. This method is called when the agent is asked to reset itself.\n",
|
||||
"- {py:attr}`~autogen_agentchat.agents.BaseChatAgent.produced_message_types`: The list of possible {py:class}`~autogen_agentchat.messages.ChatMessage` message types the agent can produce in its response.\n",
|
||||
"\n",
|
||||
"Optionally, you can implement the the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream` method to stream messages as they are generated by the agent. If this method is not implemented, the agent\n",
|
||||
"uses the default implementation of {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream`\n",
|
||||
"that calls the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages` method and\n",
|
||||
"yields all messages in the response."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## CountDownAgent\n",
|
||||
"\n",
|
||||
"In this example, we create a simple agent that counts down from a given number to zero,\n",
|
||||
"and produces a stream of messages with the current count."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"3...\n",
|
||||
"2...\n",
|
||||
"1...\n",
|
||||
"Done!\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from typing import AsyncGenerator, List, Sequence\n",
|
||||
"\n",
|
||||
"from autogen_agentchat.agents import BaseChatAgent\n",
|
||||
"from autogen_agentchat.base import Response\n",
|
||||
"from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage\n",
|
||||
"from autogen_core import CancellationToken\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class CountDownAgent(BaseChatAgent):\n",
|
||||
" def __init__(self, name: str, count: int = 3):\n",
|
||||
" super().__init__(name, \"A simple agent that counts down.\")\n",
|
||||
" self._count = count\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def produced_message_types(self) -> List[type[ChatMessage]]:\n",
|
||||
" return [TextMessage]\n",
|
||||
"\n",
|
||||
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
|
||||
" # Calls the on_messages_stream.\n",
|
||||
" response: Response | None = None\n",
|
||||
" async for message in self.on_messages_stream(messages, cancellation_token):\n",
|
||||
" if isinstance(message, Response):\n",
|
||||
" response = message\n",
|
||||
" assert response is not None\n",
|
||||
" return response\n",
|
||||
"\n",
|
||||
" async def on_messages_stream(\n",
|
||||
" self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken\n",
|
||||
" ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:\n",
|
||||
" inner_messages: List[AgentEvent | ChatMessage] = []\n",
|
||||
" for i in range(self._count, 0, -1):\n",
|
||||
" msg = TextMessage(content=f\"{i}...\", source=self.name)\n",
|
||||
" inner_messages.append(msg)\n",
|
||||
" yield msg\n",
|
||||
" # The response is returned at the end of the stream.\n",
|
||||
" # It contains the final message and all the inner messages.\n",
|
||||
" yield Response(chat_message=TextMessage(content=\"Done!\", source=self.name), inner_messages=inner_messages)\n",
|
||||
"\n",
|
||||
" async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def run_countdown_agent() -> None:\n",
|
||||
" # Create a countdown agent.\n",
|
||||
" countdown_agent = CountDownAgent(\"countdown\")\n",
|
||||
"\n",
|
||||
" # Run the agent with a given task and stream the response.\n",
|
||||
" async for message in countdown_agent.on_messages_stream([], CancellationToken()):\n",
|
||||
" if isinstance(message, Response):\n",
|
||||
" print(message.chat_message.content)\n",
|
||||
" else:\n",
|
||||
" print(message.content)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Use asyncio.run(run_countdown_agent()) when running in a script.\n",
|
||||
"await run_countdown_agent()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## ArithmeticAgent\n",
|
||||
"\n",
|
||||
"In this example, we create an agent class that can perform simple arithmetic operations\n",
|
||||
"on a given integer. Then, we will use different instances of this agent class\n",
|
||||
"in a {py:class}`~autogen_agentchat.teams.SelectorGroupChat`\n",
|
||||
"to transform a given integer into another integer by applying a sequence of arithmetic operations.\n",
|
||||
"\n",
|
||||
"The `ArithmeticAgent` class takes an `operator_func` that takes an integer and returns an integer,\n",
|
||||
"after applying an arithmetic operation to the integer.\n",
|
||||
"In its `on_messages` method, it applies the `operator_func` to the integer in the input message,\n",
|
||||
"and returns a response with the result."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Callable, List, Sequence\n",
|
||||
"\n",
|
||||
"from autogen_agentchat.agents import BaseChatAgent\n",
|
||||
"from autogen_agentchat.base import Response\n",
|
||||
"from autogen_agentchat.conditions import MaxMessageTermination\n",
|
||||
"from autogen_agentchat.messages import ChatMessage\n",
|
||||
"from autogen_agentchat.teams import SelectorGroupChat\n",
|
||||
"from autogen_agentchat.ui import Console\n",
|
||||
"from autogen_core import CancellationToken\n",
|
||||
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class ArithmeticAgent(BaseChatAgent):\n",
|
||||
" def __init__(self, name: str, description: str, operator_func: Callable[[int], int]) -> None:\n",
|
||||
" super().__init__(name, description=description)\n",
|
||||
" self._operator_func = operator_func\n",
|
||||
" self._message_history: List[ChatMessage] = []\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def produced_message_types(self) -> List[type[ChatMessage]]:\n",
|
||||
" return [TextMessage]\n",
|
||||
"\n",
|
||||
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
|
||||
" # Update the message history.\n",
|
||||
" # NOTE: it is possible the messages is an empty list, which means the agent was selected previously.\n",
|
||||
" self._message_history.extend(messages)\n",
|
||||
" # Parse the number in the last message.\n",
|
||||
" assert isinstance(self._message_history[-1], TextMessage)\n",
|
||||
" number = int(self._message_history[-1].content)\n",
|
||||
" # Apply the operator function to the number.\n",
|
||||
" result = self._operator_func(number)\n",
|
||||
" # Create a new message with the result.\n",
|
||||
" response_message = TextMessage(content=str(result), source=self.name)\n",
|
||||
" # Update the message history.\n",
|
||||
" self._message_history.append(response_message)\n",
|
||||
" # Return the response.\n",
|
||||
" return Response(chat_message=response_message)\n",
|
||||
"\n",
|
||||
" async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
|
||||
" pass"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```{note}\n",
|
||||
"The `on_messages` method may be called with an empty list of messages, in which\n",
|
||||
"case it means the agent was called previously and is now being called again,\n",
|
||||
"without any new messages from the caller. So it is important to keep a history\n",
|
||||
"of the previous messages received by the agent, and use that history to generate\n",
|
||||
"the response.\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we can create a {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with 5 instances of `ArithmeticAgent`:\n",
|
||||
"\n",
|
||||
"- one that adds 1 to the input integer,\n",
|
||||
"- one that subtracts 1 from the input integer,\n",
|
||||
"- one that multiplies the input integer by 2,\n",
|
||||
"- one that divides the input integer by 2 and rounds down to the nearest integer, and\n",
|
||||
"- one that returns the input integer unchanged.\n",
|
||||
"\n",
|
||||
"We then create a {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with these agents,\n",
|
||||
"and set the appropriate selector settings:\n",
|
||||
"\n",
|
||||
"- allow the same agent to be selected consecutively to allow for repeated operations, and\n",
|
||||
"- customize the selector prompt to tailor the model's response to the specific task."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"---------- user ----------\n",
|
||||
"Apply the operations to turn the given number into 25.\n",
|
||||
"---------- user ----------\n",
|
||||
"10\n",
|
||||
"---------- multiply_agent ----------\n",
|
||||
"20\n",
|
||||
"---------- add_agent ----------\n",
|
||||
"21\n",
|
||||
"---------- multiply_agent ----------\n",
|
||||
"42\n",
|
||||
"---------- divide_agent ----------\n",
|
||||
"21\n",
|
||||
"---------- add_agent ----------\n",
|
||||
"22\n",
|
||||
"---------- add_agent ----------\n",
|
||||
"23\n",
|
||||
"---------- add_agent ----------\n",
|
||||
"24\n",
|
||||
"---------- add_agent ----------\n",
|
||||
"25\n",
|
||||
"---------- Summary ----------\n",
|
||||
"Number of messages: 10\n",
|
||||
"Finish reason: Maximum number of messages 10 reached, current message count: 10\n",
|
||||
"Total prompt tokens: 0\n",
|
||||
"Total completion tokens: 0\n",
|
||||
"Duration: 2.40 seconds\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"async def run_number_agents() -> None:\n",
|
||||
" # Create agents for number operations.\n",
|
||||
" add_agent = ArithmeticAgent(\"add_agent\", \"Adds 1 to the number.\", lambda x: x + 1)\n",
|
||||
" multiply_agent = ArithmeticAgent(\"multiply_agent\", \"Multiplies the number by 2.\", lambda x: x * 2)\n",
|
||||
" subtract_agent = ArithmeticAgent(\"subtract_agent\", \"Subtracts 1 from the number.\", lambda x: x - 1)\n",
|
||||
" divide_agent = ArithmeticAgent(\"divide_agent\", \"Divides the number by 2 and rounds down.\", lambda x: x // 2)\n",
|
||||
" identity_agent = ArithmeticAgent(\"identity_agent\", \"Returns the number as is.\", lambda x: x)\n",
|
||||
"\n",
|
||||
" # The termination condition is to stop after 10 messages.\n",
|
||||
" termination_condition = MaxMessageTermination(10)\n",
|
||||
"\n",
|
||||
" # Create a selector group chat.\n",
|
||||
" selector_group_chat = SelectorGroupChat(\n",
|
||||
" [add_agent, multiply_agent, subtract_agent, divide_agent, identity_agent],\n",
|
||||
" model_client=OpenAIChatCompletionClient(model=\"gpt-4o\"),\n",
|
||||
" termination_condition=termination_condition,\n",
|
||||
" allow_repeated_speaker=True, # Allow the same agent to speak multiple times, necessary for this task.\n",
|
||||
" selector_prompt=(\n",
|
||||
" \"Available roles:\\n{roles}\\nTheir job descriptions:\\n{participants}\\n\"\n",
|
||||
" \"Current conversation history:\\n{history}\\n\"\n",
|
||||
" \"Please select the most appropriate role for the next message, and only return the role name.\"\n",
|
||||
" ),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Run the selector group chat with a given task and stream the response.\n",
|
||||
" task: List[ChatMessage] = [\n",
|
||||
" TextMessage(content=\"Apply the operations to turn the given number into 25.\", source=\"user\"),\n",
|
||||
" TextMessage(content=\"10\", source=\"user\"),\n",
|
||||
" ]\n",
|
||||
" stream = selector_group_chat.run_stream(task=task)\n",
|
||||
" await Console(stream)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Use asyncio.run(run_number_agents()) when running in a script.\n",
|
||||
"await run_number_agents()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"From the output, we can see that the agents have successfully transformed the input integer\n",
|
||||
"from 10 to 25 by choosing appropriate agents that apply the arithmetic operations in sequence."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Custom Agents\n",
|
||||
"\n",
|
||||
"You may have agents with behaviors that do not fall into a preset. \n",
|
||||
"In such cases, you can build custom agents.\n",
|
||||
"\n",
|
||||
"All agents in AgentChat inherit from {py:class}`~autogen_agentchat.agents.BaseChatAgent` \n",
|
||||
"class and implement the following abstract methods and attributes:\n",
|
||||
"\n",
|
||||
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages`: The abstract method that defines the behavior of the agent in response to messages. This method is called when the agent is asked to provide a response in {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run`. It returns a {py:class}`~autogen_agentchat.base.Response` object.\n",
|
||||
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_reset`: The abstract method that resets the agent to its initial state. This method is called when the agent is asked to reset itself.\n",
|
||||
"- {py:attr}`~autogen_agentchat.agents.BaseChatAgent.produced_message_types`: The list of possible {py:class}`~autogen_agentchat.messages.ChatMessage` message types the agent can produce in its response.\n",
|
||||
"\n",
|
||||
"Optionally, you can implement the the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream` method to stream messages as they are generated by the agent. If this method is not implemented, the agent\n",
|
||||
"uses the default implementation of {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream`\n",
|
||||
"that calls the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages` method and\n",
|
||||
"yields all messages in the response."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## CountDownAgent\n",
|
||||
"\n",
|
||||
"In this example, we create a simple agent that counts down from a given number to zero,\n",
|
||||
"and produces a stream of messages with the current count."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"3...\n",
|
||||
"2...\n",
|
||||
"1...\n",
|
||||
"Done!\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from typing import AsyncGenerator, List, Sequence, Tuple\n",
|
||||
"\n",
|
||||
"from autogen_agentchat.agents import BaseChatAgent\n",
|
||||
"from autogen_agentchat.base import Response\n",
|
||||
"from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage\n",
|
||||
"from autogen_core import CancellationToken\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class CountDownAgent(BaseChatAgent):\n",
|
||||
" def __init__(self, name: str, count: int = 3):\n",
|
||||
" super().__init__(name, \"A simple agent that counts down.\")\n",
|
||||
" self._count = count\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:\n",
|
||||
" return (TextMessage,)\n",
|
||||
"\n",
|
||||
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
|
||||
" # Calls the on_messages_stream.\n",
|
||||
" response: Response | None = None\n",
|
||||
" async for message in self.on_messages_stream(messages, cancellation_token):\n",
|
||||
" if isinstance(message, Response):\n",
|
||||
" response = message\n",
|
||||
" assert response is not None\n",
|
||||
" return response\n",
|
||||
"\n",
|
||||
" async def on_messages_stream(\n",
|
||||
" self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken\n",
|
||||
" ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:\n",
|
||||
" inner_messages: List[AgentEvent | ChatMessage] = []\n",
|
||||
" for i in range(self._count, 0, -1):\n",
|
||||
" msg = TextMessage(content=f\"{i}...\", source=self.name)\n",
|
||||
" inner_messages.append(msg)\n",
|
||||
" yield msg\n",
|
||||
" # The response is returned at the end of the stream.\n",
|
||||
" # It contains the final message and all the inner messages.\n",
|
||||
" yield Response(chat_message=TextMessage(content=\"Done!\", source=self.name), inner_messages=inner_messages)\n",
|
||||
"\n",
|
||||
" async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def run_countdown_agent() -> None:\n",
|
||||
" # Create a countdown agent.\n",
|
||||
" countdown_agent = CountDownAgent(\"countdown\")\n",
|
||||
"\n",
|
||||
" # Run the agent with a given task and stream the response.\n",
|
||||
" async for message in countdown_agent.on_messages_stream([], CancellationToken()):\n",
|
||||
" if isinstance(message, Response):\n",
|
||||
" print(message.chat_message.content)\n",
|
||||
" else:\n",
|
||||
" print(message.content)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Use asyncio.run(run_countdown_agent()) when running in a script.\n",
|
||||
"await run_countdown_agent()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## ArithmeticAgent\n",
|
||||
"\n",
|
||||
"In this example, we create an agent class that can perform simple arithmetic operations\n",
|
||||
"on a given integer. Then, we will use different instances of this agent class\n",
|
||||
"in a {py:class}`~autogen_agentchat.teams.SelectorGroupChat`\n",
|
||||
"to transform a given integer into another integer by applying a sequence of arithmetic operations.\n",
|
||||
"\n",
|
||||
"The `ArithmeticAgent` class takes an `operator_func` that takes an integer and returns an integer,\n",
|
||||
"after applying an arithmetic operation to the integer.\n",
|
||||
"In its `on_messages` method, it applies the `operator_func` to the integer in the input message,\n",
|
||||
"and returns a response with the result."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Callable, Sequence, Tuple\n",
|
||||
"\n",
|
||||
"from autogen_agentchat.agents import BaseChatAgent\n",
|
||||
"from autogen_agentchat.base import Response\n",
|
||||
"from autogen_agentchat.conditions import MaxMessageTermination\n",
|
||||
"from autogen_agentchat.messages import ChatMessage\n",
|
||||
"from autogen_agentchat.teams import SelectorGroupChat\n",
|
||||
"from autogen_agentchat.ui import Console\n",
|
||||
"from autogen_core import CancellationToken\n",
|
||||
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class ArithmeticAgent(BaseChatAgent):\n",
|
||||
" def __init__(self, name: str, description: str, operator_func: Callable[[int], int]) -> None:\n",
|
||||
" super().__init__(name, description=description)\n",
|
||||
" self._operator_func = operator_func\n",
|
||||
" self._message_history: List[ChatMessage] = []\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:\n",
|
||||
" return (TextMessage,)\n",
|
||||
"\n",
|
||||
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
|
||||
" # Update the message history.\n",
|
||||
" # NOTE: it is possible the messages is an empty list, which means the agent was selected previously.\n",
|
||||
" self._message_history.extend(messages)\n",
|
||||
" # Parse the number in the last message.\n",
|
||||
" assert isinstance(self._message_history[-1], TextMessage)\n",
|
||||
" number = int(self._message_history[-1].content)\n",
|
||||
" # Apply the operator function to the number.\n",
|
||||
" result = self._operator_func(number)\n",
|
||||
" # Create a new message with the result.\n",
|
||||
" response_message = TextMessage(content=str(result), source=self.name)\n",
|
||||
" # Update the message history.\n",
|
||||
" self._message_history.append(response_message)\n",
|
||||
" # Return the response.\n",
|
||||
" return Response(chat_message=response_message)\n",
|
||||
"\n",
|
||||
" async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
|
||||
" pass"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```{note}\n",
|
||||
"The `on_messages` method may be called with an empty list of messages, in which\n",
|
||||
"case it means the agent was called previously and is now being called again,\n",
|
||||
"without any new messages from the caller. So it is important to keep a history\n",
|
||||
"of the previous messages received by the agent, and use that history to generate\n",
|
||||
"the response.\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we can create a {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with 5 instances of `ArithmeticAgent`:\n",
|
||||
"\n",
|
||||
"- one that adds 1 to the input integer,\n",
|
||||
"- one that subtracts 1 from the input integer,\n",
|
||||
"- one that multiplies the input integer by 2,\n",
|
||||
"- one that divides the input integer by 2 and rounds down to the nearest integer, and\n",
|
||||
"- one that returns the input integer unchanged.\n",
|
||||
"\n",
|
||||
"We then create a {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with these agents,\n",
|
||||
"and set the appropriate selector settings:\n",
|
||||
"\n",
|
||||
"- allow the same agent to be selected consecutively to allow for repeated operations, and\n",
|
||||
"- customize the selector prompt to tailor the model's response to the specific task."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"---------- user ----------\n",
|
||||
"Apply the operations to turn the given number into 25.\n",
|
||||
"---------- user ----------\n",
|
||||
"10\n",
|
||||
"---------- multiply_agent ----------\n",
|
||||
"20\n",
|
||||
"---------- add_agent ----------\n",
|
||||
"21\n",
|
||||
"---------- multiply_agent ----------\n",
|
||||
"42\n",
|
||||
"---------- divide_agent ----------\n",
|
||||
"21\n",
|
||||
"---------- add_agent ----------\n",
|
||||
"22\n",
|
||||
"---------- add_agent ----------\n",
|
||||
"23\n",
|
||||
"---------- add_agent ----------\n",
|
||||
"24\n",
|
||||
"---------- add_agent ----------\n",
|
||||
"25\n",
|
||||
"---------- Summary ----------\n",
|
||||
"Number of messages: 10\n",
|
||||
"Finish reason: Maximum number of messages 10 reached, current message count: 10\n",
|
||||
"Total prompt tokens: 0\n",
|
||||
"Total completion tokens: 0\n",
|
||||
"Duration: 2.40 seconds\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"async def run_number_agents() -> None:\n",
|
||||
" # Create agents for number operations.\n",
|
||||
" add_agent = ArithmeticAgent(\"add_agent\", \"Adds 1 to the number.\", lambda x: x + 1)\n",
|
||||
" multiply_agent = ArithmeticAgent(\"multiply_agent\", \"Multiplies the number by 2.\", lambda x: x * 2)\n",
|
||||
" subtract_agent = ArithmeticAgent(\"subtract_agent\", \"Subtracts 1 from the number.\", lambda x: x - 1)\n",
|
||||
" divide_agent = ArithmeticAgent(\"divide_agent\", \"Divides the number by 2 and rounds down.\", lambda x: x // 2)\n",
|
||||
" identity_agent = ArithmeticAgent(\"identity_agent\", \"Returns the number as is.\", lambda x: x)\n",
|
||||
"\n",
|
||||
" # The termination condition is to stop after 10 messages.\n",
|
||||
" termination_condition = MaxMessageTermination(10)\n",
|
||||
"\n",
|
||||
" # Create a selector group chat.\n",
|
||||
" selector_group_chat = SelectorGroupChat(\n",
|
||||
" [add_agent, multiply_agent, subtract_agent, divide_agent, identity_agent],\n",
|
||||
" model_client=OpenAIChatCompletionClient(model=\"gpt-4o\"),\n",
|
||||
" termination_condition=termination_condition,\n",
|
||||
" allow_repeated_speaker=True, # Allow the same agent to speak multiple times, necessary for this task.\n",
|
||||
" selector_prompt=(\n",
|
||||
" \"Available roles:\\n{roles}\\nTheir job descriptions:\\n{participants}\\n\"\n",
|
||||
" \"Current conversation history:\\n{history}\\n\"\n",
|
||||
" \"Please select the most appropriate role for the next message, and only return the role name.\"\n",
|
||||
" ),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Run the selector group chat with a given task and stream the response.\n",
|
||||
" task: List[ChatMessage] = [\n",
|
||||
" TextMessage(content=\"Apply the operations to turn the given number into 25.\", source=\"user\"),\n",
|
||||
" TextMessage(content=\"10\", source=\"user\"),\n",
|
||||
" ]\n",
|
||||
" stream = selector_group_chat.run_stream(task=task)\n",
|
||||
" await Console(stream)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Use asyncio.run(run_number_agents()) when running in a script.\n",
|
||||
"await run_number_agents()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"From the output, we can see that the agents have successfully transformed the input integer\n",
|
||||
"from 10 to 25 by choosing appropriate agents that apply the arithmetic operations in sequence."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
|
||||
@ -63,8 +63,8 @@ class FileSurfer(BaseChatAgent):
|
||||
self._browser = MarkdownFileBrowser(viewport_size=1024 * 5)
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
return [TextMessage]
|
||||
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
|
||||
return (TextMessage,)
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
for chat_message in messages:
|
||||
|
||||
@ -15,6 +15,7 @@ from typing import (
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -298,9 +299,9 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
||||
self._initial_message_ids = initial_message_ids
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
|
||||
"""The types of messages that the assistant agent produces."""
|
||||
return [TextMessage]
|
||||
return (TextMessage,)
|
||||
|
||||
@property
|
||||
def threads(self) -> AsyncThreads:
|
||||
|
||||
@ -15,6 +15,7 @@ from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import quote_plus
|
||||
@ -321,8 +322,8 @@ class MultimodalWebSurfer(BaseChatAgent):
|
||||
)
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||
return [MultiModalMessage]
|
||||
def produced_message_types(self) -> Tuple[type[ChatMessage], ...]:
|
||||
return (MultiModalMessage,)
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
if not self.did_lazy_init:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user