Add AssistantAgent, deprecate CodingAssistantAgent and ToolUseAssistantAgent (#3960)

* Add AssistantAgent, deprecate CodingAssistantAgent and ToolUseAssistantAgent

* Rename

* Add note

* Update uv

* uv lock

* Merge branch 'main' into assistant-agent

* Update uv
Author: Eric Zhu, 2024-10-25 23:17:06 -07:00 (committed by GitHub)
Parent: 69fc742537
Commit: 3fe0f9e97d
9 changed files with 1609 additions and 1279 deletions

View File

@@ -1,3 +1,4 @@
+from ._assistant_agent import AssistantAgent
 from ._base_chat_agent import BaseChatAgent
 from ._code_executor_agent import CodeExecutorAgent
 from ._coding_assistant_agent import CodingAssistantAgent
@@ -5,6 +6,7 @@ from ._tool_use_assistant_agent import ToolUseAssistantAgent
 
 __all__ = [
     "BaseChatAgent",
+    "AssistantAgent",
     "CodeExecutorAgent",
     "CodingAssistantAgent",
     "ToolUseAssistantAgent",

View File

@@ -0,0 +1,140 @@
import asyncio
import json
import logging
from typing import Any, Awaitable, Callable, List, Sequence

from autogen_core.base import CancellationToken
from autogen_core.components import FunctionCall
from autogen_core.components.models import (
    AssistantMessage,
    ChatCompletionClient,
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
    LLMMessage,
    SystemMessage,
    UserMessage,
)
from autogen_core.components.tools import FunctionTool, Tool
from pydantic import BaseModel, ConfigDict

from .. import EVENT_LOGGER_NAME
from ..messages import (
    ChatMessage,
    StopMessage,
    TextMessage,
)
from ._base_chat_agent import BaseChatAgent

event_logger = logging.getLogger(EVENT_LOGGER_NAME)


class ToolCallEvent(BaseModel):
    """A tool call event."""

    tool_calls: List[FunctionCall]
    """The tool call message."""

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ToolCallResultEvent(BaseModel):
    """A tool call result event."""

    tool_call_results: List[FunctionExecutionResult]
    """The tool call result message."""

    model_config = ConfigDict(arbitrary_types_allowed=True)


class AssistantAgent(BaseChatAgent):
    """An agent that provides assistance with tool use.

    It responds with a StopMessage when 'terminate' is detected in the response.

    Args:
        name (str): The name of the agent.
        model_client (ChatCompletionClient): The model client to use for inference.
        tools (List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional): The tools to register with the agent.
        description (str, optional): The description of the agent.
        system_message (str, optional): The system message for the model.
    """

    def __init__(
        self,
        name: str,
        model_client: ChatCompletionClient,
        *,
        tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
        description: str = "An agent that provides assistance with ability to use tools.",
        system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed.",
    ):
        super().__init__(name=name, description=description)
        self._model_client = model_client
        self._system_messages = [SystemMessage(content=system_message)]
        self._tools: List[Tool] = []
        if tools is not None:
            for tool in tools:
                if isinstance(tool, Tool):
                    self._tools.append(tool)
                elif callable(tool):
                    if hasattr(tool, "__doc__") and tool.__doc__ is not None:
                        description = tool.__doc__
                    else:
                        description = ""
                    self._tools.append(FunctionTool(tool, description=description))
                else:
                    raise ValueError(f"Unsupported tool type: {type(tool)}")
        self._model_context: List[LLMMessage] = []

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
        # Add messages to the model context.
        for msg in messages:
            # TODO: add special handling for handoff messages
            self._model_context.append(UserMessage(content=msg.content, source=msg.source))
        # Generate an inference result based on the current model context.
        llm_messages = self._system_messages + self._model_context
        result = await self._model_client.create(llm_messages, tools=self._tools, cancellation_token=cancellation_token)
        # Add the response to the model context.
        self._model_context.append(AssistantMessage(content=result.content, source=self.name))
        # Run tool calls until the model produces a string response.
        while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
            event_logger.debug(ToolCallEvent(tool_calls=result.content))
            # Execute the tool calls.
            results = await asyncio.gather(
                *[self._execute_tool_call(call, cancellation_token) for call in result.content]
            )
            event_logger.debug(ToolCallResultEvent(tool_call_results=results))
            self._model_context.append(FunctionExecutionResultMessage(content=results))
            # Generate an inference result based on the current model context.
            result = await self._model_client.create(
                self._model_context, tools=self._tools, cancellation_token=cancellation_token
            )
            self._model_context.append(AssistantMessage(content=result.content, source=self.name))
        assert isinstance(result.content, str)
        # Detect stop request.
        request_stop = "terminate" in result.content.strip().lower()
        if request_stop:
            return StopMessage(content=result.content, source=self.name)
        return TextMessage(content=result.content, source=self.name)

    async def _execute_tool_call(
        self, tool_call: FunctionCall, cancellation_token: CancellationToken
    ) -> FunctionExecutionResult:
        """Execute a tool call and return the result."""
        try:
            if not self._tools:
                raise ValueError("No tools are available.")
            tool = next((t for t in self._tools if t.name == tool_call.name), None)
            if tool is None:
                raise ValueError(f"The tool '{tool_call.name}' is not available.")
            arguments = json.loads(tool_call.arguments)
            result = await tool.run_json(arguments, cancellation_token)
            result_as_str = tool.return_value_as_string(result)
            return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id)
        except Exception as e:
            return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id)
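For reference, a minimal usage sketch of the new agent (not part of the diff): the model name, API key, and the get_weather tool are placeholder assumptions; the constructor and run() usage follow the code above and the updated tests below.

import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_ext.models import OpenAIChatCompletionClient


async def get_weather(city: str) -> str:
    """Get the weather for a city (hypothetical example tool)."""
    return f"The weather in {city} is sunny."


async def main() -> None:
    # Plain callables are wrapped in FunctionTool automatically; the
    # docstring becomes the tool description.
    agent = AssistantAgent(
        "assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-05-13", api_key="sk-placeholder"),
        tools=[get_weather],
    )
    result = await agent.run("What is the weather in Seattle?")
    print(result.messages)


asyncio.run(main())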

View File

@@ -1,20 +1,14 @@
-from typing import List, Sequence
+import warnings
 
-from autogen_core.base import CancellationToken
 from autogen_core.components.models import (
-    AssistantMessage,
     ChatCompletionClient,
-    LLMMessage,
-    SystemMessage,
-    UserMessage,
 )
-from ..messages import ChatMessage, MultiModalMessage, StopMessage, TextMessage
-from ._base_chat_agent import BaseChatAgent
+from ._assistant_agent import AssistantAgent
 
 
-class CodingAssistantAgent(BaseChatAgent):
-    """An agent that provides coding assistance using an LLM model client.
+class CodingAssistantAgent(AssistantAgent):
+    """[DEPRECATED] An agent that provides coding assistance using an LLM model client.
 
     It responds with a StopMessage when 'terminate' is detected in the response.
     """
@@ -37,29 +31,10 @@ If the result indicates there is an error, fix the error and output the code aga
 When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
 Reply "TERMINATE" in the end when code has been executed and task is complete.""",
     ):
-        super().__init__(name=name, description=description)
-        self._model_client = model_client
-        self._system_messages = [SystemMessage(content=system_message)]
-        self._model_context: List[LLMMessage] = []
-
-    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
-        # Add messages to the model context and detect stopping.
-        for msg in messages:
-            if not isinstance(msg, TextMessage | MultiModalMessage | StopMessage):
-                raise ValueError(f"Unsupported message type: {type(msg)}")
-            self._model_context.append(UserMessage(content=msg.content, source=msg.source))
-        # Generate an inference result based on the current model context.
-        llm_messages = self._system_messages + self._model_context
-        result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
-        assert isinstance(result.content, str)
-        # Add the response to the model context.
-        self._model_context.append(AssistantMessage(content=result.content, source=self.name))
-        # Detect stop request.
-        request_stop = "terminate" in result.content.strip().lower()
-        if request_stop:
-            return StopMessage(content=result.content, source=self.name)
-        return TextMessage(content=result.content, source=self.name)
+        # Deprecation warning.
+        warnings.warn(
+            "CodingAssistantAgent is deprecated. Use AssistantAgent instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(name, model_client, description=description, system_message=system_message)
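Since the deprecated class now only warns and forwards its arguments, migration is a rename (a sketch, where client stands for any ChatCompletionClient instance):

# Before (now emits DeprecationWarning):
agent = CodingAssistantAgent("coding_assistant", model_client=client)
# After:
agent = AssistantAgent("coding_assistant", model_client=client)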

View File

@@ -1,53 +1,20 @@
-import asyncio
-import json
-import logging
-from typing import Any, Awaitable, Callable, List, Sequence
+import warnings
+from typing import Any, Awaitable, Callable, List
 
-from autogen_core.base import CancellationToken
-from autogen_core.components import FunctionCall
 from autogen_core.components.models import (
-    AssistantMessage,
     ChatCompletionClient,
-    FunctionExecutionResult,
-    FunctionExecutionResultMessage,
-    LLMMessage,
-    SystemMessage,
-    UserMessage,
 )
-from autogen_core.components.tools import FunctionTool, Tool
-from pydantic import BaseModel, ConfigDict
+from autogen_core.components.tools import Tool
 
-from .. import EVENT_LOGGER_NAME
-from ..messages import (
-    ChatMessage,
-    StopMessage,
-    TextMessage,
-)
-from ._base_chat_agent import BaseChatAgent
+from ._assistant_agent import AssistantAgent
 
-event_logger = logging.getLogger(EVENT_LOGGER_NAME)
-
-
-class ToolCallEvent(BaseModel):
-    """A tool call event."""
-
-    tool_calls: List[FunctionCall]
-    """The tool call message."""
-
-    model_config = ConfigDict(arbitrary_types_allowed=True)
-
-
-class ToolCallResultEvent(BaseModel):
-    """A tool call result event."""
-
-    tool_call_results: List[FunctionExecutionResult]
-    """The tool call result message."""
-
-    model_config = ConfigDict(arbitrary_types_allowed=True)
 
-
-class ToolUseAssistantAgent(BaseChatAgent):
-    """An agent that provides assistance with tool use.
+class ToolUseAssistantAgent(AssistantAgent):
+    """[DEPRECATED] An agent that provides assistance with tool use.
 
     It responds with a StopMessage when 'terminate' is detected in the response.
@@ -68,72 +35,12 @@ class ToolUseAssistantAgent(BaseChatAgent):
         description: str = "An agent that provides assistance with ability to use tools.",
         system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed.",
     ):
-        super().__init__(name=name, description=description)
-        self._model_client = model_client
-        self._system_messages = [SystemMessage(content=system_message)]
-        self._tools: List[Tool] = []
-        for tool in registered_tools:
-            if isinstance(tool, Tool):
-                self._tools.append(tool)
-            elif callable(tool):
-                if hasattr(tool, "__doc__") and tool.__doc__ is not None:
-                    description = tool.__doc__
-                else:
-                    description = ""
-                self._tools.append(FunctionTool(tool, description=description))
-            else:
-                raise ValueError(f"Unsupported tool type: {type(tool)}")
-        self._model_context: List[LLMMessage] = []
-
-    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
-        # Add messages to the model context.
-        for msg in messages:
-            # TODO: add special handling for handoff messages
-            self._model_context.append(UserMessage(content=msg.content, source=msg.source))
-        # Generate an inference result based on the current model context.
-        llm_messages = self._system_messages + self._model_context
-        result = await self._model_client.create(llm_messages, tools=self._tools, cancellation_token=cancellation_token)
-        # Add the response to the model context.
-        self._model_context.append(AssistantMessage(content=result.content, source=self.name))
-        # Run tool calls until the model produces a string response.
-        while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
-            event_logger.debug(ToolCallEvent(tool_calls=result.content))
-            # Execute the tool calls.
-            results = await asyncio.gather(
-                *[self._execute_tool_call(call, cancellation_token) for call in result.content]
-            )
-            event_logger.debug(ToolCallResultEvent(tool_call_results=results))
-            self._model_context.append(FunctionExecutionResultMessage(content=results))
-            # Generate an inference result based on the current model context.
-            result = await self._model_client.create(
-                self._model_context, tools=self._tools, cancellation_token=cancellation_token
-            )
-            self._model_context.append(AssistantMessage(content=result.content, source=self.name))
-        assert isinstance(result.content, str)
-        # Detect stop request.
-        request_stop = "terminate" in result.content.strip().lower()
-        if request_stop:
-            return StopMessage(content=result.content, source=self.name)
-        return TextMessage(content=result.content, source=self.name)
-
-    async def _execute_tool_call(
-        self, tool_call: FunctionCall, cancellation_token: CancellationToken
-    ) -> FunctionExecutionResult:
-        """Execute a tool call and return the result."""
-        try:
-            if not self._tools:
-                raise ValueError("No tools are available.")
-            tool = next((t for t in self._tools if t.name == tool_call.name), None)
-            if tool is None:
-                raise ValueError(f"The tool '{tool_call.name}' is not available.")
-            arguments = json.loads(tool_call.arguments)
-            result = await tool.run_json(arguments, cancellation_token)
-            result_as_str = tool.return_value_as_string(result)
-            return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id)
-        except Exception as e:
-            return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id)
+        # Deprecation warning.
+        warnings.warn(
+            "ToolUseAssistantAgent is deprecated. Use AssistantAgent instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(
+            name, model_client, tools=registered_tools, description=description, system_message=system_message
+        )
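Migration here additionally renames the registered_tools argument (a sketch; client and my_tool are placeholders):

# Before (now emits DeprecationWarning):
agent = ToolUseAssistantAgent("tool_user", model_client=client, registered_tools=[my_tool])
# After (tools is keyword-only on AssistantAgent):
agent = AssistantAgent("tool_user", model_client=client, tools=[my_tool])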

View File

@@ -3,7 +3,7 @@ import logging
 import sys
 from datetime import datetime
 
-from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent
+from ..agents._assistant_agent import ToolCallEvent, ToolCallResultEvent
 from ..messages import ChatMessage, StopMessage, TextMessage
 from ..teams._events import (
     GroupChatPublishEvent,

View File

@@ -4,7 +4,7 @@ from dataclasses import asdict, is_dataclass
 from datetime import datetime
 from typing import Any
 
-from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent
+from ..agents._assistant_agent import ToolCallEvent, ToolCallResultEvent
 from ..teams._events import (
     GroupChatPublishEvent,
     GroupChatSelectSpeakerEvent,

View File

@@ -3,10 +3,10 @@ import json
 from typing import Any, AsyncGenerator, List
 
 import pytest
-from autogen_agentchat.agents import ToolUseAssistantAgent
+from autogen_agentchat.agents import AssistantAgent
 from autogen_agentchat.messages import StopMessage, TextMessage
-from autogen_core.components.models import OpenAIChatCompletionClient
 from autogen_core.components.tools import FunctionTool
+from autogen_ext.models import OpenAIChatCompletionClient
 from openai.resources.chat.completions import AsyncCompletions
 from openai.types.chat.chat_completion import ChatCompletion, Choice
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -42,7 +42,7 @@ async def _echo_function(input: str) -> str:
 
 @pytest.mark.asyncio
-async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
+async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     model = "gpt-4o-2024-05-13"
     chat_completions = [
         ChatCompletion(
@@ -97,10 +97,10 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
     ]
     mock = _MockChatCompletion(chat_completions)
     monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
-    tool_use_agent = ToolUseAssistantAgent(
+    tool_use_agent = AssistantAgent(
         "tool_use_agent",
         model_client=OpenAIChatCompletionClient(model=model, api_key=""),
-        registered_tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
+        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
     )
     result = await tool_use_agent.run("task")
     assert len(result.messages) == 3

View File

@@ -7,10 +7,9 @@ from typing import Any, AsyncGenerator, List, Sequence
 
 import pytest
 from autogen_agentchat import EVENT_LOGGER_NAME
 from autogen_agentchat.agents import (
+    AssistantAgent,
     BaseChatAgent,
     CodeExecutorAgent,
-    CodingAssistantAgent,
-    ToolUseAssistantAgent,
 )
 from autogen_agentchat.logging import FileLogHandler
 from autogen_agentchat.messages import (
@@ -131,7 +130,7 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
         code_executor_agent = CodeExecutorAgent(
             "code_executor", code_executor=LocalCommandLineCodeExecutor(work_dir=temp_dir)
         )
-        coding_assistant_agent = CodingAssistantAgent(
+        coding_assistant_agent = AssistantAgent(
             "coding_assistant", model_client=OpenAIChatCompletionClient(model=model, api_key="")
         )
         team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent])
@@ -211,10 +210,10 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
     mock = _MockChatCompletion(chat_completions)
     monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
     tool = FunctionTool(_pass_function, name="pass", description="pass function")
-    tool_use_agent = ToolUseAssistantAgent(
+    tool_use_agent = AssistantAgent(
         "tool_use_agent",
         model_client=OpenAIChatCompletionClient(model=model, api_key=""),
-        registered_tools=[tool],
+        tools=[tool],
     )
     echo_agent = _EchoAgent("echo_agent", description="echo agent")
     team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent])

python/uv.lock (generated): diff suppressed because it is too large (2553 changed lines).