Add tool_agent_caller_loop and group chat notebook. (#405)

* Add tool_agent_caller_loop and group chat notebook.

* Fix types

* fix ref

---------

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
Eric Zhu 2024-08-27 12:11:48 -07:00 committed by GitHub
parent c8f6f3bb38
commit 12cf331e71
9 changed files with 924 additions and 354 deletions

View File

@ -116,7 +116,7 @@ implementation of the contracts determines how agents handle messages.
The behavior contract is sometimes referred to as the message protocol. The behavior contract is sometimes referred to as the message protocol.
It is the developer's responsibility to implement the behavior contract. It is the developer's responsibility to implement the behavior contract.
Multi-agent patterns are design patterns that emerge from behavior contracts Multi-agent patterns are design patterns that emerge from behavior contracts
(see [Multi-Agent Design Patterns](../getting-started/multi-agent-design-patterns.ipynb)). (see [Multi-Agent Design Patterns](../getting-started/multi-agent-design-patterns.md)).
### An Example Application ### An Example Application

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,18 @@
# Multi-Agent Design Patterns
Agents can work together in a variety of ways to solve problems.
Research such as [AutoGen](https://aka.ms/autogen-paper),
[MetaGPT](https://arxiv.org/abs/2308.00352),
and [ChatDev](https://arxiv.org/abs/2307.07924) has shown
multi-agent systems outperforming single-agent systems on complex tasks
such as software development.
A multi-agent design pattern is a structure that emerges from message protocols:
it describes how agents interact with each other to solve problems.
For example, the [tool-equipped agent](./tools.ipynb#tool-equipped-agent) in
the previous section employs a design pattern called ReAct,
which involves an agent interacting with tools.
You can implement any multi-agent design pattern using AGNext agents.
In the next two sections, we will discuss two common design patterns:
group chat for task decomposition, and reflection for robustness.
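
To make the idea of a pattern emerging from a message protocol concrete, here is a minimal sketch in plain Python rather than AGNext code. The message types `Draft` and `Review` and the functions `write`, `review`, and `run_protocol` are hypothetical names chosen for illustration, and the two "agents" are stand-ins for LLM calls.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Draft:
    """Hypothetical message sent from the writer to the reviewer."""

    text: str


@dataclass
class Review:
    """Hypothetical message sent from the reviewer back to the writer."""

    approved: bool
    feedback: str


def write(task: str, feedback: Optional[str]) -> Draft:
    # Stand-in for an LLM call: revise only when feedback is present.
    text = task if feedback is None else f"{task} (revised: {feedback})"
    return Draft(text=text)


def review(draft: Draft) -> Review:
    # Stand-in for a second LLM call that critiques the draft.
    if "revised" in draft.text:
        return Review(approved=True, feedback="")
    return Review(approved=False, feedback="add error handling")


def run_protocol(task: str, max_rounds: int = 3) -> Draft:
    """Alternate Draft and Review messages until approval or the round limit."""
    feedback: Optional[str] = None
    draft = write(task, feedback)
    for _ in range(max_rounds):
        result = review(draft)
        if result.approved:
            break
        feedback = result.feedback
        draft = write(task, feedback)
    return draft
```

The pattern is defined entirely by which message types exist and who answers which message, so swapping the stand-in functions for real agents would leave the structure unchanged.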

View File

@ -4,30 +4,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Multi-Agent Design Patterns\n", "# Reflection\n",
"\n",
"Agents can work together in a variety of ways to solve problems.\n",
"Research works like [AutoGen](https://aka.ms/autogen-paper),\n",
"[MetaGPT](https://arxiv.org/abs/2308.00352)\n",
"and [ChatDev](https://arxiv.org/abs/2307.07924) have shown\n",
"multi-agent systems out-performing single agent systems at complex tasks\n",
"like software development.\n",
"\n",
"A multi-agent design pattern is a structure that emerges from message protocols:\n",
"it describes how agents interact with each other to solve problems.\n",
"For example, the [tool-equiped agent](./tools.ipynb#tool-equipped-agent) in\n",
"the previous section employs a design pattern called ReAct,\n",
"which involves an agent interacting with tools.\n",
"\n",
"You can implement any multi-agent design pattern using AGNext agents.\n",
"In this section, we use the reflection pattern as an example."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Reflection\n",
"\n", "\n",
"Reflection is a design pattern where an LLM generation is followed by a reflection,\n", "Reflection is a design pattern where an LLM generation is followed by a reflection,\n",
"which in itself is another LLM generation conditioned on the output of the first one.\n", "which in itself is another LLM generation conditioned on the output of the first one.\n",
@ -50,7 +27,7 @@
"will generate a code snippet, and the reviewer agent will generate a critique\n", "will generate a code snippet, and the reviewer agent will generate a critique\n",
"of the code snippet.\n", "of the code snippet.\n",
"\n", "\n",
"### Message Protocol\n", "## Message Protocol\n",
"\n", "\n",
"Before we define the agents, we need to first define the message protocol for the agents." "Before we define the agents, we need to first define the message protocol for the agents."
] ]
@ -107,7 +84,7 @@
"\n", "\n",
"![coder-reviewer data flow](coder-reviewer-data-flow.svg)\n", "![coder-reviewer data flow](coder-reviewer-data-flow.svg)\n",
"\n", "\n",
"### Agents\n", "## Agents\n",
"\n", "\n",
"Now, let's define the agents for the reflection design pattern." "Now, let's define the agents for the reflection design pattern."
] ]
@ -376,7 +353,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### Logging\n", "## Logging\n",
"\n", "\n",
"Turn on logging to see the messages exchanged between the agents." "Turn on logging to see the messages exchanged between the agents."
] ]
@ -397,7 +374,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### Running the Design Pattern\n", "## Running the Design Pattern\n",
"\n", "\n",
"Let's test the design pattern with a coding task." "Let's test the design pattern with a coding task."
] ]

View File

@ -30,9 +30,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello, world!\n",
"\n"
]
}
],
"source": [ "source": [
"from agnext.components.code_executor import LocalCommandLineCodeExecutor\n", "from agnext.components.code_executor import LocalCommandLineCodeExecutor\n",
"from agnext.components.tools import PythonCodeExecutionTool\n", "from agnext.components.tools import PythonCodeExecutionTool\n",
@ -77,14 +86,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"138.75280591295171\n" "194.71306528148511\n"
] ]
} }
], ],
@ -126,27 +135,23 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import asyncio\n",
"from dataclasses import dataclass\n", "from dataclasses import dataclass\n",
"from typing import List\n", "from typing import List\n",
"\n", "\n",
"from agnext.application import SingleThreadedAgentRuntime\n", "from agnext.application import SingleThreadedAgentRuntime\n",
"from agnext.components import FunctionCall, RoutedAgent, message_handler\n", "from agnext.components import RoutedAgent, message_handler\n",
"from agnext.components.models import (\n", "from agnext.components.models import (\n",
" AssistantMessage,\n",
" ChatCompletionClient,\n", " ChatCompletionClient,\n",
" FunctionExecutionResult,\n",
" FunctionExecutionResultMessage,\n",
" LLMMessage,\n", " LLMMessage,\n",
" OpenAIChatCompletionClient,\n", " OpenAIChatCompletionClient,\n",
" SystemMessage,\n", " SystemMessage,\n",
" UserMessage,\n", " UserMessage,\n",
")\n", ")\n",
"from agnext.components.tool_agent import ToolAgent, ToolException\n", "from agnext.components.tool_agent import ToolAgent, tool_agent_caller_loop\n",
"from agnext.components.tools import FunctionTool, Tool, ToolSchema\n", "from agnext.components.tools import FunctionTool, Tool, ToolSchema\n",
"from agnext.core import AgentId, AgentInstantiationContext, MessageContext\n", "from agnext.core import AgentId, AgentInstantiationContext, MessageContext\n",
"\n", "\n",
@ -168,47 +173,27 @@
" async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n", " async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n",
" # Create a session of messages.\n", " # Create a session of messages.\n",
" session: List[LLMMessage] = [UserMessage(content=message.content, source=\"user\")]\n", " session: List[LLMMessage] = [UserMessage(content=message.content, source=\"user\")]\n",
" # Get a response from the model.\n", " # Run the caller loop to handle tool calls.\n",
" response = await self._model_client.create(\n", " messages = await tool_agent_caller_loop(\n",
" self._system_messages + session, tools=self._tool_schema, cancellation_token=cancellation_token\n", " self,\n",
" tool_agent_id=self._tool_agent,\n",
" model_client=self._model_client,\n",
" input_messages=session,\n",
" tool_schema=self._tool_schema,\n",
" cancellation_token=ctx.cancellation_token,\n",
" )\n", " )\n",
" # Add the response to the session.\n",
" session.append(AssistantMessage(content=response.content, source=\"assistant\"))\n",
"\n",
" # Keep iterating until the model stops generating tool calls.\n",
" while isinstance(response.content, list) and all(isinstance(item, FunctionCall) for item in response.content):\n",
" # Execute functions called by the model by sending messages to itself.\n",
" results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(\n",
" *[self.send_message(call, self._tool_agent) for call in response.content],\n",
" return_exceptions=True,\n",
" )\n",
" # Combine the results into a single response and handle exceptions.\n",
" function_results: List[FunctionExecutionResult] = []\n",
" for result in results:\n",
" if isinstance(result, FunctionExecutionResult):\n",
" function_results.append(result)\n",
" elif isinstance(result, ToolException):\n",
" function_results.append(FunctionExecutionResult(content=f\"Error: {result}\", call_id=result.call_id))\n",
" elif isinstance(result, BaseException):\n",
" raise result # Unexpected exception.\n",
" session.append(FunctionExecutionResultMessage(content=function_results))\n",
" # Query the model again with the new response.\n",
" response = await self._model_client.create(\n",
" self._system_messages + session, tools=self._tool_schema, cancellation_token=cancellation_token\n",
" )\n",
" session.append(AssistantMessage(content=response.content, source=self.metadata[\"type\"]))\n",
"\n",
" # Return the final response.\n", " # Return the final response.\n",
" assert isinstance(response.content, str)\n", " assert isinstance(messages[-1].content, str)\n",
" return Message(content=response.content)" " return Message(content=messages[-1].content)"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"The `ToolUseAgent` class is a bit involved, however,\n", "The `ToolUseAgent` class uses a convenience function {py:meth}`agnext.components.tool_agent.tool_agent_caller_loop`, \n",
"the core idea can be described using a simple control flow graph:\n", "to handle the interaction between the model and the tool agent.\n",
"The core idea can be described using a simple control flow graph:\n",
"\n", "\n",
"![ToolUseAgent control flow graph](tool-use-agent-cfg.svg)\n", "![ToolUseAgent control flow graph](tool-use-agent-cfg.svg)\n",
"\n", "\n",
@ -230,9 +215,20 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"data": {
"text/plain": [
"AgentType(type='tool_use_agent')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"# Create a runtime.\n", "# Create a runtime.\n",
"runtime = SingleThreadedAgentRuntime()\n", "runtime = SingleThreadedAgentRuntime()\n",
@ -240,18 +236,18 @@
"tools: List[Tool] = [FunctionTool(get_stock_price, description=\"Get the stock price.\")]\n", "tools: List[Tool] = [FunctionTool(get_stock_price, description=\"Get the stock price.\")]\n",
"# Register the agents.\n", "# Register the agents.\n",
"await runtime.register(\n", "await runtime.register(\n",
" \"tool-executor-agent\",\n", " \"tool_executor_agent\",\n",
" lambda: ToolAgent(\n", " lambda: ToolAgent(\n",
" description=\"Tool Executor Agent\",\n", " description=\"Tool Executor Agent\",\n",
" tools=tools,\n", " tools=tools,\n",
" ),\n", " ),\n",
")\n", ")\n",
"await runtime.register(\n", "await runtime.register(\n",
" \"tool-use-agent\",\n", " \"tool_use_agent\",\n",
" lambda: ToolUseAgent(\n", " lambda: ToolUseAgent(\n",
" OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n", " OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n",
" tool_schema=[tool.schema for tool in tools],\n", " tool_schema=[tool.schema for tool in tools],\n",
" tool_agent=AgentId(\"tool-executor-agent\", AgentInstantiationContext.current_agent_id().key),\n", " tool_agent=AgentId(\"tool_executor_agent\", AgentInstantiationContext.current_agent_id().key),\n",
" ),\n", " ),\n",
")" ")"
] ]
@ -267,14 +263,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"The stock price of NVDA on June 1, 2024, is approximately $49.28.\n" "The stock price of NVIDIA (NVDA) on June 1, 2024, was approximately $148.86.\n"
] ]
} }
], ],
@ -282,7 +278,7 @@
"# Start processing messages.\n", "# Start processing messages.\n",
"runtime.start()\n", "runtime.start()\n",
"# Send a direct message to the tool agent.\n", "# Send a direct message to the tool agent.\n",
"tool_use_agent = AgentId(\"tool-use-agent\", \"default\")\n", "tool_use_agent = AgentId(\"tool_use_agent\", \"default\")\n",
"response = await runtime.send_message(Message(\"What is the stock price of NVDA on 2024/06/01?\"), tool_use_agent)\n", "response = await runtime.send_message(Message(\"What is the stock price of NVDA on 2024/06/01?\"), tool_use_agent)\n",
"print(response.content)\n", "print(response.content)\n",
"# Stop processing messages.\n", "# Stop processing messages.\n",

View File

@ -30,6 +30,8 @@ To learn about the core concepts of AGNext, read the `overview <core-concepts/ov
getting-started/model-clients getting-started/model-clients
getting-started/tools getting-started/tools
getting-started/multi-agent-design-patterns getting-started/multi-agent-design-patterns
getting-started/group-chat
getting-started/reflection
.. toctree:: .. toctree::
:caption: Guides :caption: Guides

View File

@ -1,3 +1,4 @@
from ._caller_loop import tool_agent_caller_loop
from ._tool_agent import ( from ._tool_agent import (
InvalidToolArgumentsException, InvalidToolArgumentsException,
ToolAgent, ToolAgent,
@ -12,4 +13,5 @@ __all__ = [
"ToolNotFoundException", "ToolNotFoundException",
"InvalidToolArgumentsException", "InvalidToolArgumentsException",
"ToolExecutionException", "ToolExecutionException",
"tool_agent_caller_loop",
] ]

View File

@ -0,0 +1,77 @@
import asyncio
from typing import List

from ...components import FunctionCall
from ...core import AgentId, AgentRuntime, BaseAgent, CancellationToken
from ..models import (
    AssistantMessage,
    ChatCompletionClient,
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
    LLMMessage,
)
from ..tools import Tool, ToolSchema
from ._tool_agent import ToolException


async def tool_agent_caller_loop(
    caller: BaseAgent | AgentRuntime,
    tool_agent_id: AgentId,
    model_client: ChatCompletionClient,
    input_messages: List[LLMMessage],
    tool_schema: List[ToolSchema] | List[Tool],
    cancellation_token: CancellationToken | None = None,
    caller_source: str = "assistant",
) -> List[LLMMessage]:
    """Start a caller loop for a tool agent. This function sends messages to the tool agent
    and the model client in an alternating fashion until the model client stops generating tool calls.

    Args:
        caller (BaseAgent | AgentRuntime): The agent or runtime used to send messages to the tool agent.
        tool_agent_id (AgentId): The Agent ID of the tool agent.
        model_client (ChatCompletionClient): The model client to use for the model API.
        input_messages (List[LLMMessage]): The list of input messages.
        tool_schema (List[ToolSchema] | List[Tool]): The list of tools that the model can use.
        cancellation_token (CancellationToken | None): Optional token passed to the model client and the tool agent calls.
        caller_source (str): The source set on the assistant messages generated in the loop.

    Returns:
        List[LLMMessage]: The list of output messages created in the caller loop.
    """
    generated_messages: List[LLMMessage] = []

    # Get a response from the model.
    response = await model_client.create(input_messages, tools=tool_schema, cancellation_token=cancellation_token)
    # Add the response to the generated messages.
    generated_messages.append(AssistantMessage(content=response.content, source=caller_source))

    # Keep iterating until the model stops generating tool calls.
    while isinstance(response.content, list) and all(isinstance(item, FunctionCall) for item in response.content):
        # Execute functions called by the model by sending messages to tool agent.
        results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
            *[
                caller.send_message(
                    message=call,
                    recipient=tool_agent_id,
                    cancellation_token=cancellation_token,
                )
                for call in response.content
            ],
            return_exceptions=True,
        )
        # Combine the results into a single response and handle exceptions.
        function_results: List[FunctionExecutionResult] = []
        for result in results:
            if isinstance(result, FunctionExecutionResult):
                function_results.append(result)
            elif isinstance(result, ToolException):
                function_results.append(FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id))
            elif isinstance(result, BaseException):
                raise result  # Unexpected exception.
        generated_messages.append(FunctionExecutionResultMessage(content=function_results))
        # Query the model again with the new response.
        response = await model_client.create(
            input_messages + generated_messages, tools=tool_schema, cancellation_token=cancellation_token
        )
        generated_messages.append(AssistantMessage(content=response.content, source=caller_source))

    # Return the generated messages.
    return generated_messages
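
The control flow of the loop above can be tried without AGNext or a model API. The following is a rough sketch only: `ToolCall`, `ToolResult`, `fake_model`, `run_tool`, and `caller_loop` are hypothetical stand-ins written for illustration, not part of the library. It mirrors the same steps: query the model, run every requested tool call concurrently, turn known tool errors into results, re-raise anything unexpected, and repeat until the model replies with text.

```python
import asyncio
from dataclasses import dataclass
from typing import Callable, Dict, List, Union


@dataclass
class ToolCall:
    """Hypothetical stand-in for a model-issued function call."""

    call_id: str
    name: str
    argument: str


@dataclass
class ToolResult:
    """Hypothetical stand-in for a function execution result."""

    call_id: str
    content: str


# A message is plain text, a batch of tool calls, or a batch of tool results.
Message = Union[str, List[ToolCall], List[ToolResult]]


async def fake_model(history: List[Message]) -> Message:
    # Stub model: request one tool call on the first turn, then answer with text.
    if len(history) == 1:
        return [ToolCall(call_id="1", name="echo", argument="hi")]
    return "done"


async def run_tool(call: ToolCall, tools: Dict[str, Callable[[str], str]]) -> ToolResult:
    # Raises KeyError for an unknown tool, standing in for a tool exception.
    return ToolResult(call_id=call.call_id, content=tools[call.name](call.argument))


async def caller_loop(history: List[Message], tools: Dict[str, Callable[[str], str]]) -> List[Message]:
    generated: List[Message] = []
    reply = await fake_model(history)
    generated.append(reply)
    # Keep iterating until the model stops returning tool calls.
    while isinstance(reply, list):
        outcomes = await asyncio.gather(
            *[run_tool(call, tools) for call in reply], return_exceptions=True
        )
        results: List[ToolResult] = []
        for call, outcome in zip(reply, outcomes):
            if isinstance(outcome, ToolResult):
                results.append(outcome)
            elif isinstance(outcome, KeyError):
                # Known tool failure: report it back to the model as a result.
                results.append(ToolResult(call_id=call.call_id, content=f"Error: {outcome}"))
            elif isinstance(outcome, BaseException):
                raise outcome  # Unexpected exception.
        generated.append(results)
        reply = await fake_model(history + generated)
        generated.append(reply)
    return generated
```

Calling `asyncio.run(caller_loop(["Hello"], {"echo": lambda s: s}))` yields three generated messages: the tool call batch, the tool result batch, and the final text reply.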

View File

@ -1,19 +1,34 @@
import asyncio import asyncio
import json import json
from typing import Any, AsyncGenerator, List
import pytest import pytest
from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
from openai.types.completion_usage import CompletionUsage
from agnext.application import SingleThreadedAgentRuntime from agnext.application import SingleThreadedAgentRuntime
from agnext.components import FunctionCall from agnext.components import FunctionCall
from agnext.components.models import FunctionExecutionResult
from agnext.components.tool_agent import ( from agnext.components.tool_agent import (
InvalidToolArgumentsException, InvalidToolArgumentsException,
ToolAgent, ToolAgent,
ToolExecutionException, ToolExecutionException,
ToolNotFoundException, ToolNotFoundException,
tool_agent_caller_loop,
)
from agnext.components.tools import FunctionTool, Tool
from agnext.core import CancellationToken, AgentId
from agnext.components.models import (
AssistantMessage,
FunctionExecutionResult,
FunctionExecutionResultMessage,
OpenAIChatCompletionClient,
UserMessage,
) )
from agnext.components.tools import FunctionTool from agnext.components.tools import FunctionTool
from agnext.core import CancellationToken
from agnext.core import AgentId
def _pass_function(input: str) -> str: def _pass_function(input: str) -> str:
@ -29,6 +44,60 @@ async def _async_sleep_function(input: str) -> str:
return "pass" return "pass"
class _MockChatCompletion:
    def __init__(self, model: str = "gpt-4o") -> None:
        self._saved_chat_completions: List[ChatCompletion] = [
            ChatCompletion(
                id="id1",
                choices=[
                    Choice(
                        finish_reason="tool_calls",
                        index=0,
                        message=ChatCompletionMessage(
                            content=None,
                            tool_calls=[
                                ChatCompletionMessageToolCall(
                                    id="1",
                                    type="function",
                                    function=Function(
                                        name="pass",
                                        arguments=json.dumps({"input": "pass"}),
                                    ),
                                )
                            ],
                            role="assistant",
                        ),
                    )
                ],
                created=0,
                model=model,
                object="chat.completion",
                usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
            ),
            ChatCompletion(
                id="id2",
                choices=[
                    Choice(
                        finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant")
                    )
                ],
                created=0,
                model=model,
                object="chat.completion",
                usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
            ),
        ]
        self._curr_index = 0

    async def mock_create(
        self, *args: Any, **kwargs: Any
    ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
        await asyncio.sleep(0.1)
        completion = self._saved_chat_completions[self._curr_index]
        self._curr_index += 1
        return completion


@pytest.mark.asyncio @pytest.mark.asyncio
async def test_tool_agent() -> None: async def test_tool_agent() -> None:
runtime = SingleThreadedAgentRuntime() runtime = SingleThreadedAgentRuntime()
@ -74,3 +143,33 @@ async def test_tool_agent() -> None:
await result_future await result_future
await runtime.stop() await runtime.stop()


@pytest.mark.asyncio
async def test_caller_loop(monkeypatch: pytest.MonkeyPatch) -> None:
    mock = _MockChatCompletion(model="gpt-4o-2024-05-13")
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    client = OpenAIChatCompletionClient(model="gpt-4o-2024-05-13", api_key="api_key")
    tools: List[Tool] = [FunctionTool(_pass_function, name="pass", description="Pass function")]
    runtime = SingleThreadedAgentRuntime()
    await runtime.register(
        "tool_agent",
        lambda: ToolAgent(
            description="Tool agent",
            tools=tools,
        ),
    )
    agent = AgentId("tool_agent", "default")
    runtime.start()
    messages = await tool_agent_caller_loop(
        runtime,
        agent,
        client,
        [UserMessage(content="Hello", source="user")],
        tool_schema=tools,
    )
    assert len(messages) == 3
    assert isinstance(messages[0], AssistantMessage)
    assert isinstance(messages[1], FunctionExecutionResultMessage)
    assert isinstance(messages[2], AssistantMessage)
    await runtime.stop()