Update custom agent doc with arithmetic agent example (#4720)

* Update custom agent doc with arithmetic agent example

* fix mypy

* fix model capabilities usage
This commit is contained in:
Eric Zhu 2024-12-16 09:12:26 -08:00 committed by GitHub
parent 7eaffa83a7
commit 43eed01bbc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 185 additions and 33 deletions

View File

@ -34,9 +34,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3...\n",
"2...\n",
"1...\n",
"Done!\n"
]
}
],
"source": [
"from typing import AsyncGenerator, List, Sequence\n",
"\n",
@ -100,12 +111,17 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## UserProxyAgent \n",
"## ArithmeticAgent\n",
"\n",
"A common use case for building a custom agent is to create an agent that acts as a proxy for the user.\n",
"In this example, we create an agent class that can perform simple arithmetic operations\n",
"on a given integer. Then, we will use different instances of this agent class\n",
"in a {py:class}`~autogen_agentchat.teams.SelectorGroupChat`\n",
"to transform a given integer into another integer by applying a sequence of arithmetic operations.\n",
"\n",
"In the example below we show how to implement a `UserProxyAgent` - an agent that asks the user to enter\n",
"some text through console and then returns that message as a response."
"The `ArithmeticAgent` class takes an `operator_func` that takes an integer and returns an integer,\n",
"after applying an arithmetic operation to the integer.\n",
"In its `on_messages` method, it applies the `operator_func` to the integer in the input message,\n",
"and returns a response with the result."
]
},
{
@ -114,39 +130,162 @@
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"from typing import List, Sequence\n",
"from typing import Callable, List, Sequence\n",
"\n",
"from autogen_agentchat.agents import BaseChatAgent\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.conditions import MaxMessageTermination\n",
"from autogen_agentchat.messages import ChatMessage\n",
"from autogen_agentchat.teams import SelectorGroupChat\n",
"from autogen_agentchat.ui import Console\n",
"from autogen_core import CancellationToken\n",
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
"\n",
"\n",
"class UserProxyAgent(BaseChatAgent):\n",
" def __init__(self, name: str) -> None:\n",
" super().__init__(name, \"A human user.\")\n",
"class ArithmeticAgent(BaseChatAgent):\n",
" def __init__(self, name: str, description: str, operator_func: Callable[[int], int]) -> None:\n",
" super().__init__(name, description=description)\n",
" self._operator_func = operator_func\n",
" self._message_history: List[ChatMessage] = []\n",
"\n",
" @property\n",
" def produced_message_types(self) -> List[type[ChatMessage]]:\n",
" return [TextMessage]\n",
"\n",
" async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
" user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n",
" return Response(chat_message=TextMessage(content=user_input, source=self.name))\n",
" # Update the message history.\n",
" # NOTE: it is possible the messages is an empty list, which means the agent was selected previously.\n",
" self._message_history.extend(messages)\n",
" # Parse the number in the last message.\n",
" assert isinstance(self._message_history[-1], TextMessage)\n",
" number = int(self._message_history[-1].content)\n",
" # Apply the operator function to the number.\n",
" result = self._operator_func(number)\n",
" # Create a new message with the result.\n",
" response_message = TextMessage(content=str(result), source=self.name)\n",
" # Update the message history.\n",
" self._message_history.append(response_message)\n",
" # Return the response.\n",
" return Response(chat_message=response_message)\n",
"\n",
" async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
" pass\n",
" pass"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{note}\n",
"The `on_messages` method may be called with an empty list of messages, in which\n",
"case it means the agent was called previously and is now being called again,\n",
"without any new messages from the caller. So it is important to keep a history\n",
"of the previous messages received by the agent, and use that history to generate\n",
"the response.\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can create a {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with 5 instances of `ArithmeticAgent`:\n",
"\n",
"- one that adds 1 to the input integer,\n",
"- one that subtracts 1 from the input integer,\n",
"- one that multiplies the input integer by 2,\n",
"- one that divides the input integer by 2 and rounds down to the nearest integer, and\n",
"- one that returns the input integer unchanged.\n",
"\n",
"We then create a {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with these agents,\n",
"and set the appropriate selector settings:\n",
"\n",
"- allow the same agent to be selected consecutively to allow for repeated operations, and\n",
"- customize the selector prompt to tailor the model's response to the specific task."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- user ----------\n",
"Apply the operations to turn the given number into 25.\n",
"---------- user ----------\n",
"10\n",
"---------- multiply_agent ----------\n",
"20\n",
"---------- add_agent ----------\n",
"21\n",
"---------- multiply_agent ----------\n",
"42\n",
"---------- divide_agent ----------\n",
"21\n",
"---------- add_agent ----------\n",
"22\n",
"---------- add_agent ----------\n",
"23\n",
"---------- add_agent ----------\n",
"24\n",
"---------- add_agent ----------\n",
"25\n",
"---------- Summary ----------\n",
"Number of messages: 10\n",
"Finish reason: Maximum number of messages 10 reached, current message count: 10\n",
"Total prompt tokens: 0\n",
"Total completion tokens: 0\n",
"Duration: 2.40 seconds\n"
]
}
],
"source": [
"async def run_number_agents() -> None:\n",
" # Create agents for number operations.\n",
" add_agent = ArithmeticAgent(\"add_agent\", \"Adds 1 to the number.\", lambda x: x + 1)\n",
" multiply_agent = ArithmeticAgent(\"multiply_agent\", \"Multiplies the number by 2.\", lambda x: x * 2)\n",
" subtract_agent = ArithmeticAgent(\"subtract_agent\", \"Subtracts 1 from the number.\", lambda x: x - 1)\n",
" divide_agent = ArithmeticAgent(\"divide_agent\", \"Divides the number by 2 and rounds down.\", lambda x: x // 2)\n",
" identity_agent = ArithmeticAgent(\"identity_agent\", \"Returns the number as is.\", lambda x: x)\n",
"\n",
" # The termination condition is to stop after 10 messages.\n",
" termination_condition = MaxMessageTermination(10)\n",
"\n",
" # Create a selector group chat.\n",
" selector_group_chat = SelectorGroupChat(\n",
" [add_agent, multiply_agent, subtract_agent, divide_agent, identity_agent],\n",
" model_client=OpenAIChatCompletionClient(model=\"gpt-4o\"),\n",
" termination_condition=termination_condition,\n",
" allow_repeated_speaker=True, # Allow the same agent to speak multiple times, necessary for this task.\n",
" selector_prompt=(\n",
" \"Available roles:\\n{roles}\\nTheir job descriptions:\\n{participants}\\n\"\n",
" \"Current conversation history:\\n{history}\\n\"\n",
" \"Please select the most appropriate role for the next message, and only return the role name.\"\n",
" ),\n",
" )\n",
"\n",
" # Run the selector group chat with a given task and stream the response.\n",
" task: List[ChatMessage] = [\n",
" TextMessage(content=\"Apply the operations to turn the given number into 25.\", source=\"user\"),\n",
" TextMessage(content=\"10\", source=\"user\"),\n",
" ]\n",
" stream = selector_group_chat.run_stream(task=task)\n",
" await Console(stream)\n",
"\n",
"\n",
"async def run_user_proxy_agent() -> None:\n",
" user_proxy_agent = UserProxyAgent(name=\"user_proxy_agent\")\n",
" response = await user_proxy_agent.on_messages([], CancellationToken())\n",
" print(response.chat_message.content)\n",
"\n",
"\n",
"# Use asyncio.run(run_user_proxy_agent()) when running in a script.\n",
"await run_user_proxy_agent()"
"# Use asyncio.run(run_number_agents()) when running in a script.\n",
"await run_number_agents()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the output, we can see that the agents have successfully transformed the input integer\n",
"from 10 to 25 by choosing appropriate agents that apply the arithmetic operations in sequence."
]
}
],
@ -157,8 +296,16 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": "3.11.5"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,

View File

@ -15,7 +15,7 @@
"source": [
"## OpenAI\n",
"\n",
"To access OpenAI models, install the `openai` extension, which allows you to use the {py:class}`~autogen_ext.models.OpenAIChatCompletionClient`."
"To access OpenAI models, install the `openai` extension, which allows you to use the {py:class}`~autogen_ext.models.openai.OpenAIChatCompletionClient`."
]
},
{
@ -85,7 +85,7 @@
"source": [
"```{note}\n",
"You can use this client with models hosted on OpenAI-compatible endpoints, however, we have not tested this functionality.\n",
"See {py:class}`~autogen_ext.models.OpenAIChatCompletionClient` for more information.\n",
"See {py:class}`~autogen_ext.models.openai.OpenAIChatCompletionClient` for more information.\n",
"```"
]
},
@ -95,7 +95,7 @@
"source": [
"## Azure OpenAI\n",
"\n",
"Similarly, install the `azure` and `openai` extensions to use the {py:class}`~autogen_ext.models.AzureOpenAIChatCompletionClient`."
"Similarly, install the `azure` and `openai` extensions to use the {py:class}`~autogen_ext.models.openai.AzureOpenAIChatCompletionClient`."
]
},
{
@ -179,7 +179,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.7"
}
},
"nbformat": 4,

View File

@ -15,10 +15,10 @@
"```\n",
"curl -fsSL https://ollama.com/install.sh | sh\n",
"\n",
"ollama pull llama3:instruct\n",
"ollama pull llama3.2:1b\n",
"\n",
"pip install 'litellm[proxy]'\n",
"litellm --model ollama/llama3:instruct\n",
"litellm --model ollama/llama3.2:1b\n",
"``` \n",
"\n",
"This will run the proxy server and it will be available at 'http://0.0.0.0:4000/'."
@ -48,7 +48,7 @@
" default_subscription,\n",
" message_handler,\n",
")\n",
"from autogen_core.components.model_context import BufferedChatCompletionContext\n",
"from autogen_core.model_context import BufferedChatCompletionContext\n",
"from autogen_core.models import (\n",
" AssistantMessage,\n",
" ChatCompletionClient,\n",
@ -74,9 +74,14 @@
"def get_model_client() -> OpenAIChatCompletionClient: # type: ignore\n",
" \"Mimic OpenAI API using Local LLM Server.\"\n",
" return OpenAIChatCompletionClient(\n",
" model=\"gpt-4o\", # Need to use one of the OpenAI models as a placeholder for now.\n",
" model=\"llama3.2:1b\",\n",
" api_key=\"NotRequiredSinceWeAreLocal\",\n",
" base_url=\"http://0.0.0.0:4000\",\n",
" model_capabilities={\n",
" \"json_output\": False,\n",
" \"vision\": False,\n",
" \"function_calling\": True,\n",
" },\n",
" )"
]
},
@ -225,7 +230,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "pyautogen",
"display_name": ".venv",
"language": "python",
"name": "python3"
},

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 11 KiB

After

Width:  |  Height:  |  Size: 11 KiB