mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-10 14:31:12 +00:00
Add AssistantAgent, deprecate CodingAssistantAgent and ToolUseAssistantAgent (#3960)
* Add AssistantAgent, deprecate CodingAssistantAgent and ToolUseAssistantAgent * Rename * Add note * Update uv * uf lock * Merge branch 'main' into assistant-agent * Update uv
This commit is contained in:
parent
69fc742537
commit
3fe0f9e97d
@ -1,3 +1,4 @@
|
||||
from ._assistant_agent import AssistantAgent
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
from ._code_executor_agent import CodeExecutorAgent
|
||||
from ._coding_assistant_agent import CodingAssistantAgent
|
||||
@ -5,6 +6,7 @@ from ._tool_use_assistant_agent import ToolUseAssistantAgent
|
||||
|
||||
__all__ = [
|
||||
"BaseChatAgent",
|
||||
"AssistantAgent",
|
||||
"CodeExecutorAgent",
|
||||
"CodingAssistantAgent",
|
||||
"ToolUseAssistantAgent",
|
||||
|
||||
@ -0,0 +1,140 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, List, Sequence
|
||||
|
||||
from autogen_core.base import CancellationToken
|
||||
from autogen_core.components import FunctionCall
|
||||
from autogen_core.components.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
FunctionExecutionResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from autogen_core.components.tools import FunctionTool, Tool
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from .. import EVENT_LOGGER_NAME
|
||||
from ..messages import (
|
||||
ChatMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
)
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class ToolCallEvent(BaseModel):
|
||||
"""A tool call event."""
|
||||
|
||||
tool_calls: List[FunctionCall]
|
||||
"""The tool call message."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ToolCallResultEvent(BaseModel):
|
||||
"""A tool call result event."""
|
||||
|
||||
tool_call_results: List[FunctionExecutionResult]
|
||||
"""The tool call result message."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class AssistantAgent(BaseChatAgent):
|
||||
"""An agent that provides assistance with tool use.
|
||||
|
||||
It responds with a StopMessage when 'terminate' is detected in the response.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent.
|
||||
model_client (ChatCompletionClient): The model client to use for inference.
|
||||
tools (List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional): The tools to register with the agent.
|
||||
description (str, optional): The description of the agent.
|
||||
system_message (str, optional): The system message for the model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
model_client: ChatCompletionClient,
|
||||
*,
|
||||
tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
|
||||
description: str = "An agent that provides assistance with ability to use tools.",
|
||||
system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed.",
|
||||
):
|
||||
super().__init__(name=name, description=description)
|
||||
self._model_client = model_client
|
||||
self._system_messages = [SystemMessage(content=system_message)]
|
||||
self._tools: List[Tool] = []
|
||||
if tools is not None:
|
||||
for tool in tools:
|
||||
if isinstance(tool, Tool):
|
||||
self._tools.append(tool)
|
||||
elif callable(tool):
|
||||
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
|
||||
description = tool.__doc__
|
||||
else:
|
||||
description = ""
|
||||
self._tools.append(FunctionTool(tool, description=description))
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool type: {type(tool)}")
|
||||
self._model_context: List[LLMMessage] = []
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||
# Add messages to the model context.
|
||||
for msg in messages:
|
||||
# TODO: add special handling for handoff messages
|
||||
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
|
||||
|
||||
# Generate an inference result based on the current model context.
|
||||
llm_messages = self._system_messages + self._model_context
|
||||
result = await self._model_client.create(llm_messages, tools=self._tools, cancellation_token=cancellation_token)
|
||||
|
||||
# Add the response to the model context.
|
||||
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
|
||||
|
||||
# Run tool calls until the model produces a string response.
|
||||
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
|
||||
event_logger.debug(ToolCallEvent(tool_calls=result.content))
|
||||
# Execute the tool calls.
|
||||
results = await asyncio.gather(
|
||||
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
|
||||
)
|
||||
event_logger.debug(ToolCallResultEvent(tool_call_results=results))
|
||||
self._model_context.append(FunctionExecutionResultMessage(content=results))
|
||||
# Generate an inference result based on the current model context.
|
||||
result = await self._model_client.create(
|
||||
self._model_context, tools=self._tools, cancellation_token=cancellation_token
|
||||
)
|
||||
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
|
||||
|
||||
assert isinstance(result.content, str)
|
||||
# Detect stop request.
|
||||
request_stop = "terminate" in result.content.strip().lower()
|
||||
if request_stop:
|
||||
return StopMessage(content=result.content, source=self.name)
|
||||
|
||||
return TextMessage(content=result.content, source=self.name)
|
||||
|
||||
async def _execute_tool_call(
|
||||
self, tool_call: FunctionCall, cancellation_token: CancellationToken
|
||||
) -> FunctionExecutionResult:
|
||||
"""Execute a tool call and return the result."""
|
||||
try:
|
||||
if not self._tools:
|
||||
raise ValueError("No tools are available.")
|
||||
tool = next((t for t in self._tools if t.name == tool_call.name), None)
|
||||
if tool is None:
|
||||
raise ValueError(f"The tool '{tool_call.name}' is not available.")
|
||||
arguments = json.loads(tool_call.arguments)
|
||||
result = await tool.run_json(arguments, cancellation_token)
|
||||
result_as_str = tool.return_value_as_string(result)
|
||||
return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id)
|
||||
except Exception as e:
|
||||
return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id)
|
||||
@ -1,20 +1,14 @@
|
||||
from typing import List, Sequence
|
||||
import warnings
|
||||
|
||||
from autogen_core.base import CancellationToken
|
||||
from autogen_core.components.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
LLMMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from ..messages import ChatMessage, MultiModalMessage, StopMessage, TextMessage
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
from ._assistant_agent import AssistantAgent
|
||||
|
||||
|
||||
class CodingAssistantAgent(BaseChatAgent):
|
||||
"""An agent that provides coding assistance using an LLM model client.
|
||||
class CodingAssistantAgent(AssistantAgent):
|
||||
"""[DEPRECATED] An agent that provides coding assistance using an LLM model client.
|
||||
|
||||
It responds with a StopMessage when 'terminate' is detected in the response.
|
||||
"""
|
||||
@ -37,29 +31,10 @@ If the result indicates there is an error, fix the error and output the code aga
|
||||
When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
|
||||
Reply "TERMINATE" in the end when code has been executed and task is complete.""",
|
||||
):
|
||||
super().__init__(name=name, description=description)
|
||||
self._model_client = model_client
|
||||
self._system_messages = [SystemMessage(content=system_message)]
|
||||
self._model_context: List[LLMMessage] = []
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||
# Add messages to the model context and detect stopping.
|
||||
for msg in messages:
|
||||
if not isinstance(msg, TextMessage | MultiModalMessage | StopMessage):
|
||||
raise ValueError(f"Unsupported message type: {type(msg)}")
|
||||
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
|
||||
|
||||
# Generate an inference result based on the current model context.
|
||||
llm_messages = self._system_messages + self._model_context
|
||||
result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
# Add the response to the model context.
|
||||
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
|
||||
|
||||
# Detect stop request.
|
||||
request_stop = "terminate" in result.content.strip().lower()
|
||||
if request_stop:
|
||||
return StopMessage(content=result.content, source=self.name)
|
||||
|
||||
return TextMessage(content=result.content, source=self.name)
|
||||
# Deprecation warning.
|
||||
warnings.warn(
|
||||
"CodingAssistantAgent is deprecated. Use AssistantAgent instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
super().__init__(name, model_client, description=description, system_message=system_message)
|
||||
|
||||
@ -1,53 +1,20 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, List, Sequence
|
||||
import warnings
|
||||
from typing import Any, Awaitable, Callable, List
|
||||
|
||||
from autogen_core.base import CancellationToken
|
||||
from autogen_core.components import FunctionCall
|
||||
from autogen_core.components.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
FunctionExecutionResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from autogen_core.components.tools import FunctionTool, Tool
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from autogen_core.components.tools import Tool
|
||||
|
||||
from .. import EVENT_LOGGER_NAME
|
||||
from ..messages import (
|
||||
ChatMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
)
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
from ._assistant_agent import AssistantAgent
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class ToolCallEvent(BaseModel):
|
||||
"""A tool call event."""
|
||||
|
||||
tool_calls: List[FunctionCall]
|
||||
"""The tool call message."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ToolCallResultEvent(BaseModel):
|
||||
"""A tool call result event."""
|
||||
|
||||
tool_call_results: List[FunctionExecutionResult]
|
||||
"""The tool call result message."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ToolUseAssistantAgent(BaseChatAgent):
|
||||
"""An agent that provides assistance with tool use.
|
||||
class ToolUseAssistantAgent(AssistantAgent):
|
||||
"""[DEPRECATED] An agent that provides assistance with tool use.
|
||||
|
||||
It responds with a StopMessage when 'terminate' is detected in the response.
|
||||
|
||||
@ -68,72 +35,12 @@ class ToolUseAssistantAgent(BaseChatAgent):
|
||||
description: str = "An agent that provides assistance with ability to use tools.",
|
||||
system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed.",
|
||||
):
|
||||
super().__init__(name=name, description=description)
|
||||
self._model_client = model_client
|
||||
self._system_messages = [SystemMessage(content=system_message)]
|
||||
self._tools: List[Tool] = []
|
||||
for tool in registered_tools:
|
||||
if isinstance(tool, Tool):
|
||||
self._tools.append(tool)
|
||||
elif callable(tool):
|
||||
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
|
||||
description = tool.__doc__
|
||||
else:
|
||||
description = ""
|
||||
self._tools.append(FunctionTool(tool, description=description))
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool type: {type(tool)}")
|
||||
self._model_context: List[LLMMessage] = []
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||
# Add messages to the model context.
|
||||
for msg in messages:
|
||||
# TODO: add special handling for handoff messages
|
||||
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
|
||||
|
||||
# Generate an inference result based on the current model context.
|
||||
llm_messages = self._system_messages + self._model_context
|
||||
result = await self._model_client.create(llm_messages, tools=self._tools, cancellation_token=cancellation_token)
|
||||
|
||||
# Add the response to the model context.
|
||||
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
|
||||
|
||||
# Run tool calls until the model produces a string response.
|
||||
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
|
||||
event_logger.debug(ToolCallEvent(tool_calls=result.content))
|
||||
# Execute the tool calls.
|
||||
results = await asyncio.gather(
|
||||
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
|
||||
)
|
||||
event_logger.debug(ToolCallResultEvent(tool_call_results=results))
|
||||
self._model_context.append(FunctionExecutionResultMessage(content=results))
|
||||
# Generate an inference result based on the current model context.
|
||||
result = await self._model_client.create(
|
||||
self._model_context, tools=self._tools, cancellation_token=cancellation_token
|
||||
)
|
||||
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
|
||||
|
||||
assert isinstance(result.content, str)
|
||||
# Detect stop request.
|
||||
request_stop = "terminate" in result.content.strip().lower()
|
||||
if request_stop:
|
||||
return StopMessage(content=result.content, source=self.name)
|
||||
|
||||
return TextMessage(content=result.content, source=self.name)
|
||||
|
||||
async def _execute_tool_call(
|
||||
self, tool_call: FunctionCall, cancellation_token: CancellationToken
|
||||
) -> FunctionExecutionResult:
|
||||
"""Execute a tool call and return the result."""
|
||||
try:
|
||||
if not self._tools:
|
||||
raise ValueError("No tools are available.")
|
||||
tool = next((t for t in self._tools if t.name == tool_call.name), None)
|
||||
if tool is None:
|
||||
raise ValueError(f"The tool '{tool_call.name}' is not available.")
|
||||
arguments = json.loads(tool_call.arguments)
|
||||
result = await tool.run_json(arguments, cancellation_token)
|
||||
result_as_str = tool.return_value_as_string(result)
|
||||
return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id)
|
||||
except Exception as e:
|
||||
return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id)
|
||||
# Deprecation warning.
|
||||
warnings.warn(
|
||||
"ToolUseAssistantAgent is deprecated. Use AssistantAgent instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
super().__init__(
|
||||
name, model_client, tools=registered_tools, description=description, system_message=system_message
|
||||
)
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent
|
||||
from ..agents._assistant_agent import ToolCallEvent, ToolCallResultEvent
|
||||
from ..messages import ChatMessage, StopMessage, TextMessage
|
||||
from ..teams._events import (
|
||||
GroupChatPublishEvent,
|
||||
|
||||
@ -4,7 +4,7 @@ from dataclasses import asdict, is_dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent
|
||||
from ..agents._assistant_agent import ToolCallEvent, ToolCallResultEvent
|
||||
from ..teams._events import (
|
||||
GroupChatPublishEvent,
|
||||
GroupChatSelectSpeakerEvent,
|
||||
|
||||
@ -3,10 +3,10 @@ import json
|
||||
from typing import Any, AsyncGenerator, List
|
||||
|
||||
import pytest
|
||||
from autogen_agentchat.agents import ToolUseAssistantAgent
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.messages import StopMessage, TextMessage
|
||||
from autogen_core.components.models import OpenAIChatCompletionClient
|
||||
from autogen_core.components.tools import FunctionTool
|
||||
from autogen_ext.models import OpenAIChatCompletionClient
|
||||
from openai.resources.chat.completions import AsyncCompletions
|
||||
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
@ -42,7 +42,7 @@ async def _echo_function(input: str) -> str:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
@ -97,10 +97,10 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
tool_use_agent = ToolUseAssistantAgent(
|
||||
tool_use_agent = AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
registered_tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
||||
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
||||
)
|
||||
result = await tool_use_agent.run("task")
|
||||
assert len(result.messages) == 3
|
||||
@ -7,10 +7,9 @@ from typing import Any, AsyncGenerator, List, Sequence
|
||||
import pytest
|
||||
from autogen_agentchat import EVENT_LOGGER_NAME
|
||||
from autogen_agentchat.agents import (
|
||||
AssistantAgent,
|
||||
BaseChatAgent,
|
||||
CodeExecutorAgent,
|
||||
CodingAssistantAgent,
|
||||
ToolUseAssistantAgent,
|
||||
)
|
||||
from autogen_agentchat.logging import FileLogHandler
|
||||
from autogen_agentchat.messages import (
|
||||
@ -131,7 +130,7 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
code_executor_agent = CodeExecutorAgent(
|
||||
"code_executor", code_executor=LocalCommandLineCodeExecutor(work_dir=temp_dir)
|
||||
)
|
||||
coding_assistant_agent = CodingAssistantAgent(
|
||||
coding_assistant_agent = AssistantAgent(
|
||||
"coding_assistant", model_client=OpenAIChatCompletionClient(model=model, api_key="")
|
||||
)
|
||||
team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent])
|
||||
@ -211,10 +210,10 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
tool = FunctionTool(_pass_function, name="pass", description="pass function")
|
||||
tool_use_agent = ToolUseAssistantAgent(
|
||||
tool_use_agent = AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
registered_tools=[tool],
|
||||
tools=[tool],
|
||||
)
|
||||
echo_agent = _EchoAgent("echo_agent", description="echo agent")
|
||||
team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent])
|
||||
|
||||
2553
python/uv.lock
generated
2553
python/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user