Refactor agentchat +implement base chat agent run method (#3913)

This commit is contained in:
Eric Zhu 2024-10-24 05:36:33 -07:00 committed by GitHub
parent 8f6dc4e1dd
commit 1812cc068d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 176 additions and 85 deletions

View File

@ -1,8 +1,11 @@
from ._base_chat_agent import BaseChatAgent, BaseToolUseChatAgent
from ._code_executor_agent import CodeExecutorAgent
from ._coding_assistant_agent import CodingAssistantAgent
from ._tool_use_assistant_agent import ToolUseAssistantAgent
__all__ = [
"BaseChatAgent",
"BaseToolUseChatAgent",
"CodeExecutorAgent",
"CodingAssistantAgent",
"ToolUseAssistantAgent",

View File

@ -4,12 +4,13 @@ from typing import List, Sequence
from autogen_core.base import CancellationToken
from autogen_core.components.tools import Tool
from ..base import ChatAgent, TaskResult, TerminationCondition, ToolUseChatAgent
from ..messages import ChatMessage
from ._base_task import TaskResult, TaskRunner
from ..teams import RoundRobinGroupChat
class BaseChatAgent(TaskRunner, ABC):
"""Base class for a chat agent that can participant in a team."""
class BaseChatAgent(ChatAgent, ABC):
"""Base class for a chat agent."""
def __init__(self, name: str, description: str) -> None:
self._name = name
@ -36,13 +37,23 @@ class BaseChatAgent(TaskRunner, ABC):
...
async def run(
self, task: str, *, source: str = "user", cancellation_token: CancellationToken | None = None
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
# TODO: Implement this method.
raise NotImplementedError
"""Run the agent with the given task and return the result."""
group_chat = RoundRobinGroupChat(participants=[self])
result = await group_chat.run(
task=task,
cancellation_token=cancellation_token,
termination_condition=termination_condition,
)
return result
class BaseToolUseChatAgent(BaseChatAgent):
class BaseToolUseChatAgent(BaseChatAgent, ToolUseChatAgent):
"""Base class for a chat agent that can use tools.
Subclass this base class to create an agent class that uses tools by returning

View File

@ -3,8 +3,8 @@ from typing import List, Sequence
from autogen_core.base import CancellationToken
from autogen_core.components.code_executor import CodeBlock, CodeExecutor, extract_markdown_code_blocks
from ..base import BaseChatAgent
from ..messages import ChatMessage, TextMessage
from ._base_chat_agent import BaseChatAgent
class CodeExecutorAgent(BaseChatAgent):

View File

@ -9,8 +9,8 @@ from autogen_core.components.models import (
UserMessage,
)
from ..base import BaseChatAgent
from ..messages import ChatMessage, MultiModalMessage, StopMessage, TextMessage
from ._base_chat_agent import BaseChatAgent
class CodingAssistantAgent(BaseChatAgent):

View File

@ -12,7 +12,6 @@ from autogen_core.components.models import (
)
from autogen_core.components.tools import FunctionTool, Tool
from ..base import BaseToolUseChatAgent
from ..messages import (
ChatMessage,
MultiModalMessage,
@ -21,6 +20,7 @@ from ..messages import (
ToolCallMessage,
ToolCallResultMessage,
)
from ._base_chat_agent import BaseToolUseChatAgent
class ToolUseAssistantAgent(BaseToolUseChatAgent):

View File

@ -1,11 +1,11 @@
from ._base_chat_agent import BaseChatAgent, BaseToolUseChatAgent
from ._base_task import TaskResult, TaskRunner
from ._base_team import Team
from ._base_termination import TerminatedException, TerminationCondition
from ._chat_agent import ChatAgent, ToolUseChatAgent
from ._task import TaskResult, TaskRunner
from ._team import Team
from ._termination import TerminatedException, TerminationCondition
__all__ = [
"BaseChatAgent",
"BaseToolUseChatAgent",
"ChatAgent",
"ToolUseChatAgent",
"Team",
"TerminatedException",
"TerminationCondition",

View File

@ -1,10 +0,0 @@
from typing import Protocol
from ._base_task import TaskResult, TaskRunner
from ._base_termination import TerminationCondition
class Team(TaskRunner, Protocol):
async def run(self, task: str, *, termination_condition: TerminationCondition | None = None) -> TaskResult:
"""Run the team on a given task until the termination condition is met."""
...

View File

@ -0,0 +1,50 @@
from typing import List, Protocol, Sequence, runtime_checkable
from autogen_core.base import CancellationToken
from autogen_core.components.tools import Tool
from ..messages import ChatMessage
from ._task import TaskResult, TaskRunner
from ._termination import TerminationCondition
@runtime_checkable
class ChatAgent(TaskRunner, Protocol):
"""Protocol for a chat agent."""
@property
def name(self) -> str:
"""The name of the agent. This is used by team to uniquely identify
the agent. It should be unique within the team."""
...
@property
def description(self) -> str:
"""The description of the agent. This is used by team to
make decisions about which agents to use. The description should
describe the agent's capabilities and how to interact with it."""
...
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
"""Handle incoming messages and return a response message."""
...
async def run(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the agent with the given task and return the result."""
...
@runtime_checkable
class ToolUseChatAgent(ChatAgent, Protocol):
"""Protocol for a chat agent that can use tools."""
@property
def registered_tools(self) -> List[Tool]:
"""The list of tools that the agent can use."""
...

View File

@ -1,7 +1,10 @@
from dataclasses import dataclass
from typing import Protocol, Sequence
from autogen_core.base import CancellationToken
from ..messages import ChatMessage
from ._termination import TerminationCondition
@dataclass
@ -15,6 +18,12 @@ class TaskResult:
class TaskRunner(Protocol):
"""A task runner."""
async def run(self, task: str) -> TaskResult:
async def run(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the task."""
...

View File

@ -0,0 +1,18 @@
from typing import Protocol
from autogen_core.base import CancellationToken
from ._task import TaskResult, TaskRunner
from ._termination import TerminationCondition
class Team(TaskRunner, Protocol):
async def run(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the team on a given task until the termination condition is met."""
...

View File

@ -0,0 +1,7 @@
from ._terminations import MaxMessageTermination, StopMessageTermination, TextMentionTermination
__all__ = [
"MaxMessageTermination",
"TextMentionTermination",
"StopMessageTermination",
]

View File

@ -1,11 +1,7 @@
from ._group_chat._round_robin_group_chat import RoundRobinGroupChat
from ._group_chat._selector_group_chat import SelectorGroupChat
from ._terminations import MaxMessageTermination, StopMessageTermination, TextMentionTermination
__all__ = [
"MaxMessageTermination",
"TextMentionTermination",
"StopMessageTermination",
"RoundRobinGroupChat",
"SelectorGroupChat",
]

View File

@ -8,7 +8,7 @@ from autogen_core.components.models import FunctionExecutionResult
from autogen_core.components.tool_agent import ToolException
from ... import EVENT_LOGGER_NAME
from ...base import BaseChatAgent
from ...base import ChatAgent
from ...messages import MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
from .._events import ContentPublishEvent, ContentRequestEvent, ToolCallEvent, ToolCallResultEvent
from ._sequential_routed_agent import SequentialRoutedAgent
@ -27,7 +27,7 @@ class BaseChatAgentContainer(SequentialRoutedAgent):
tool_agent_type (AgentType, optional): The agent type of the tool agent. Defaults to None.
"""
def __init__(self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType | None = None) -> None:
def __init__(self, parent_topic_type: str, agent: ChatAgent, tool_agent_type: AgentType | None = None) -> None:
super().__init__(description=agent.description)
self._parent_topic_type = parent_topic_type
self._agent = agent

View File

@ -3,12 +3,20 @@ from abc import ABC, abstractmethod
from typing import Callable, List
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, AgentInstantiationContext, AgentRuntime, AgentType, MessageContext, TopicId
from autogen_core.base import (
AgentId,
AgentInstantiationContext,
AgentRuntime,
AgentType,
CancellationToken,
MessageContext,
TopicId,
)
from autogen_core.components import ClosureAgent, TypeSubscription
from autogen_core.components.tool_agent import ToolAgent
from autogen_core.components.tools import Tool
from ...base import BaseChatAgent, BaseToolUseChatAgent, TaskResult, Team, TerminationCondition
from ...base import ChatAgent, TaskResult, Team, TerminationCondition, ToolUseChatAgent
from ...messages import ChatMessage, TextMessage
from .._events import ContentPublishEvent, ContentRequestEvent
from ._base_chat_agent_container import BaseChatAgentContainer
@ -22,13 +30,13 @@ class BaseGroupChat(Team, ABC):
create a subclass of :class:`BaseGroupChat` that uses the group chat manager.
"""
def __init__(self, participants: List[BaseChatAgent], group_chat_manager_class: type[BaseGroupChatManager]):
def __init__(self, participants: List[ChatAgent], group_chat_manager_class: type[BaseGroupChatManager]):
if len(participants) == 0:
raise ValueError("At least one participant is required.")
if len(participants) != len(set(participant.name for participant in participants)):
raise ValueError("The participant names must be unique.")
for participant in participants:
if isinstance(participant, BaseToolUseChatAgent) and not participant.registered_tools:
if isinstance(participant, ToolUseChatAgent) and not participant.registered_tools:
raise ValueError(
f"Participant '{participant.name}' is a tool use agent so it must have registered tools."
)
@ -47,7 +55,7 @@ class BaseGroupChat(Team, ABC):
) -> Callable[[], BaseGroupChatManager]: ...
def _create_participant_factory(
self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType | None
self, parent_topic_type: str, agent: ChatAgent, tool_agent_type: AgentType | None
) -> Callable[[], BaseChatAgentContainer]:
def _factory() -> BaseChatAgentContainer:
id = AgentInstantiationContext.current_agent_id()
@ -68,7 +76,13 @@ class BaseGroupChat(Team, ABC):
return _factory
async def run(self, task: str, *, termination_condition: TerminationCondition | None = None) -> TaskResult:
async def run(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the team and return the result."""
# Create intervention handler for termination.
@ -85,7 +99,7 @@ class BaseGroupChat(Team, ABC):
participant_topic_types: List[str] = []
participant_descriptions: List[str] = []
for participant in self._participants:
if isinstance(participant, BaseToolUseChatAgent):
if isinstance(participant, ToolUseChatAgent):
assert participant.registered_tools is not None and len(participant.registered_tools) > 0
# Register the tool agent.
tool_agent_type = await ToolAgent.register(

View File

@ -1,6 +1,6 @@
from typing import Callable, List
from ...base import BaseChatAgent, TerminationCondition
from ...base import ChatAgent, TerminationCondition
from .._events import ContentPublishEvent
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
@ -73,7 +73,7 @@ class RoundRobinGroupChat(BaseGroupChat):
"""
def __init__(self, participants: List[BaseChatAgent]):
def __init__(self, participants: List[ChatAgent]):
super().__init__(participants, group_chat_manager_class=RoundRobinGroupChatManager)
def _create_group_chat_manager_factory(

View File

@ -5,7 +5,7 @@ from typing import Callable, Dict, List
from autogen_core.components.models import ChatCompletionClient, SystemMessage
from ... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
from ...base import BaseChatAgent, TerminationCondition
from ...base import ChatAgent, TerminationCondition
from ...messages import MultiModalMessage, StopMessage, TextMessage
from .._events import ContentPublishEvent, SelectSpeakerEvent
from ._base_group_chat import BaseGroupChat
@ -178,7 +178,7 @@ class SelectorGroupChat(BaseGroupChat):
def __init__(
self,
participants: List[BaseChatAgent],
participants: List[ChatAgent],
model_client: ChatCompletionClient,
*,
selector_prompt: str = """You are in a role play game. The following roles are available:

View File

@ -7,17 +7,17 @@ from typing import Any, AsyncGenerator, List, Sequence
import pytest
from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import (
BaseChatAgent,
CodeExecutorAgent,
CodingAssistantAgent,
ToolUseAssistantAgent,
)
from autogen_agentchat.base import BaseChatAgent
from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage
from autogen_agentchat.task import StopMessageTermination
from autogen_agentchat.teams import (
RoundRobinGroupChat,
SelectorGroupChat,
StopMessageTermination,
)
from autogen_core.base import CancellationToken
from autogen_core.components import FunctionCall

View File

@ -1,6 +1,6 @@
import pytest
from autogen_agentchat.messages import StopMessage, TextMessage
from autogen_agentchat.teams import MaxMessageTermination, StopMessageTermination, TextMentionTermination
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination, TextMentionTermination
@pytest.mark.asyncio

View File

@ -4,13 +4,7 @@ from typing import Any, AsyncGenerator, List
import pytest
from autogen_agentchat.agents import ToolUseAssistantAgent
from autogen_agentchat.messages import (
TextMessage,
ToolCallMessage,
ToolCallResultMessage,
)
from autogen_core.base import CancellationToken
from autogen_core.components.models import FunctionExecutionResult, OpenAIChatCompletionClient
from autogen_core.components.models import OpenAIChatCompletionClient
from autogen_core.components.tools import FunctionTool
from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
@ -63,8 +57,8 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
id="1",
type="function",
function=Function(
name="pass",
arguments=json.dumps({"input": "pass"}),
name="_pass_function",
arguments=json.dumps({"input": "task"}),
),
)
],
@ -107,14 +101,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
registered_tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
)
response = await tool_use_agent.on_messages(
messages=[TextMessage(content="Test", source="user")], cancellation_token=CancellationToken()
)
assert isinstance(response, ToolCallMessage)
tool_call_results = [FunctionExecutionResult(content="", call_id=call.id) for call in response.content]
response = await tool_use_agent.on_messages(
messages=[ToolCallResultMessage(content=tool_call_results, source="test")],
cancellation_token=CancellationToken(),
)
assert isinstance(response, TextMessage)
result = await tool_use_agent.run("task")
assert len(result.messages) == 3
# assert isinstance(result.messages[1], ToolCallMessage)
# assert isinstance(result.messages[2], TextMessage)

View File

@ -23,7 +23,8 @@
"outputs": [],
"source": [
"from autogen_agentchat.agents import CodingAssistantAgent, ToolUseAssistantAgent\n",
"from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination\n",
"from autogen_agentchat.task import StopMessageTermination\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_core.components.tools import FunctionTool\n",
"from autogen_ext.models import OpenAIChatCompletionClient"
]

View File

@ -23,7 +23,8 @@
"outputs": [],
"source": [
"from autogen_agentchat.agents import CodingAssistantAgent, ToolUseAssistantAgent\n",
"from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination\n",
"from autogen_agentchat.task import StopMessageTermination\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_core.components.tools import FunctionTool\n",
"from autogen_ext.models import OpenAIChatCompletionClient"
]

View File

@ -18,7 +18,8 @@
"outputs": [],
"source": [
"from autogen_agentchat.agents import CodingAssistantAgent\n",
"from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination\n",
"from autogen_agentchat.task import StopMessageTermination\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_ext.models import OpenAIChatCompletionClient"
]
},

View File

@ -58,7 +58,8 @@
"from autogen_agentchat import EVENT_LOGGER_NAME\n",
"from autogen_agentchat.agents import ToolUseAssistantAgent\n",
"from autogen_agentchat.logging import ConsoleLogHandler\n",
"from autogen_agentchat.teams import MaxMessageTermination, RoundRobinGroupChat\n",
"from autogen_agentchat.task import MaxMessageTermination\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_core.components.tools import FunctionTool\n",
"from autogen_ext.models import OpenAIChatCompletionClient\n",
"\n",
@ -126,7 +127,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
"version": "3.12.6"
}
},
"nbformat": 4,

View File

@ -45,10 +45,9 @@
"import logging\n",
"\n",
"from autogen_agentchat import EVENT_LOGGER_NAME\n",
"from autogen_agentchat.agents import CodingAssistantAgent, ToolUseAssistantAgent\n",
"from autogen_agentchat.agents import ToolUseAssistantAgent\n",
"from autogen_agentchat.logging import ConsoleLogHandler\n",
"from autogen_agentchat.messages import TextMessage\n",
"from autogen_agentchat.teams import MaxMessageTermination, RoundRobinGroupChat, SelectorGroupChat\n",
"from autogen_core.base import CancellationToken\n",
"from autogen_core.components.models import OpenAIChatCompletionClient\n",
"from autogen_core.components.tools import FunctionTool\n",
@ -251,7 +250,7 @@
"import asyncio\n",
"from typing import Sequence\n",
"\n",
"from autogen_agentchat.base import BaseChatAgent\n",
"from autogen_agentchat.agents import BaseChatAgent\n",
"from autogen_agentchat.messages import (\n",
" ChatMessage,\n",
" StopMessage,\n",
@ -313,7 +312,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.6"
}
},
"nbformat": 4,

View File

@ -98,7 +98,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "agnext",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@ -112,7 +112,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.6"
}
},
"nbformat": 4,

View File

@ -41,12 +41,13 @@
"from typing import Sequence\n",
"\n",
"from autogen_agentchat.agents import (\n",
" BaseChatAgent,\n",
" CodingAssistantAgent,\n",
" ToolUseAssistantAgent,\n",
")\n",
"from autogen_agentchat.base import BaseChatAgent\n",
"from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage\n",
"from autogen_agentchat.teams import SelectorGroupChat, StopMessageTermination\n",
"from autogen_agentchat.task import StopMessageTermination\n",
"from autogen_agentchat.teams import SelectorGroupChat\n",
"from autogen_core.base import CancellationToken\n",
"from autogen_core.components.tools import FunctionTool\n",
"from autogen_ext.models import OpenAIChatCompletionClient"
@ -268,7 +269,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.6"
}
},
"nbformat": 4,

View File

@ -28,7 +28,8 @@
"from autogen_agentchat import EVENT_LOGGER_NAME\n",
"from autogen_agentchat.agents import CodingAssistantAgent, ToolUseAssistantAgent\n",
"from autogen_agentchat.logging import ConsoleLogHandler\n",
"from autogen_agentchat.teams import MaxMessageTermination, RoundRobinGroupChat, SelectorGroupChat\n",
"from autogen_agentchat.task import MaxMessageTermination\n",
"from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat\n",
"from autogen_core.components.models import OpenAIChatCompletionClient\n",
"from autogen_core.components.tools import FunctionTool\n",
"\n",
@ -208,7 +209,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.6"
}
},
"nbformat": 4,

View File

@ -37,7 +37,8 @@
"from autogen_agentchat import EVENT_LOGGER_NAME\n",
"from autogen_agentchat.agents import CodingAssistantAgent\n",
"from autogen_agentchat.logging import ConsoleLogHandler\n",
"from autogen_agentchat.teams import MaxMessageTermination, RoundRobinGroupChat, StopMessageTermination\n",
"from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_core.components.models import OpenAIChatCompletionClient\n",
"\n",
"logger = logging.getLogger(EVENT_LOGGER_NAME)\n",
@ -198,7 +199,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.6"
}
},
"nbformat": 4,