{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Custom Agents\n",
"\n",
"You may have agents with behaviors that do not fall into a preset.\n",
"In such cases, you can build custom agents.\n",
"\n",
"All agents in AgentChat inherit from the {py:class}`~autogen_agentchat.agents.BaseChatAgent`\n",
"class and implement the following abstract methods and attributes:\n",
"\n",
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages`: The abstract method that defines the behavior of the agent in response to messages. This method is called when the agent is asked to provide a response in {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run`. It returns a {py:class}`~autogen_agentchat.base.Response` object.\n",
"- {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_reset`: The abstract method that resets the agent to its initial state. This method is called when the agent is asked to reset itself.\n",
"- {py:attr}`~autogen_agentchat.agents.BaseChatAgent.produced_message_types`: The list of possible {py:class}`~autogen_agentchat.messages.BaseChatMessage` message types the agent can produce in its response.\n",
"\n",
"Optionally, you can implement the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream` method to stream messages as they are generated by the agent.\n",
"This method is called by {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run_stream` to stream messages.\n",
"If this method is not implemented, the agent\n",
"uses the default implementation of {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream`\n",
"that calls the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages` method and\n",
"yields all messages in the response."
]
},
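{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of these requirements, the `EchoAgent` below is an illustrative (not built-in) agent that replies with the content of the last message it receives, relying on the default `on_messages_stream` implementation described above:\n",
"\n",
"```python\n",
"from typing import Sequence\n",
"\n",
"from autogen_agentchat.agents import BaseChatAgent\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.messages import BaseChatMessage, TextMessage\n",
"from autogen_core import CancellationToken\n",
"\n",
"\n",
"class EchoAgent(BaseChatAgent):\n",
"    def __init__(self, name: str) -> None:\n",
"        super().__init__(name, \"An illustrative agent that echoes the last message.\")\n",
"\n",
"    @property\n",
"    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:\n",
"        # The only message type this agent produces.\n",
"        return (TextMessage,)\n",
"\n",
"    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:\n",
"        # Echo the last incoming text message, or a placeholder if none was received.\n",
"        last = messages[-1] if messages else None\n",
"        content = last.content if isinstance(last, TextMessage) else \"(nothing to echo)\"\n",
"        return Response(chat_message=TextMessage(content=content, source=self.name))\n",
"\n",
"    async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
"        pass  # This agent keeps no state to reset.\n",
"```"
]
},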
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## CountDownAgent\n",
"\n",
"In this example, we create a simple agent that counts down from a given number to zero,\n",
"and produces a stream of messages with the current count."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3...\n",
"2...\n",
"1...\n",
"Done!\n"
]
}
],
"source": [
"from typing import AsyncGenerator, List, Sequence\n",
"\n",
"from autogen_agentchat.agents import BaseChatAgent\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, TextMessage\n",
"from autogen_core import CancellationToken\n",
"\n",
"\n",
"class CountDownAgent(BaseChatAgent):\n",
"    def __init__(self, name: str, count: int = 3):\n",
"        super().__init__(name, \"A simple agent that counts down.\")\n",
"        self._count = count\n",
"\n",
"    @property\n",
"    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:\n",
"        return (TextMessage,)\n",
"\n",
"    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:\n",
"        # Call on_messages_stream and return the final Response it yields.\n",
"        response: Response | None = None\n",
"        async for message in self.on_messages_stream(messages, cancellation_token):\n",
"            if isinstance(message, Response):\n",
"                response = message\n",
"        assert response is not None\n",
"        return response\n",
"\n",
"    async def on_messages_stream(\n",
"        self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken\n",
"    ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:\n",
"        inner_messages: List[BaseAgentEvent | BaseChatMessage] = []\n",
"        for i in range(self._count, 0, -1):\n",
"            msg = TextMessage(content=f\"{i}...\", source=self.name)\n",
"            inner_messages.append(msg)\n",
"            yield msg\n",
"        # The response is returned at the end of the stream.\n",
"        # It contains the final message and all the inner messages.\n",
"        yield Response(chat_message=TextMessage(content=\"Done!\", source=self.name), inner_messages=inner_messages)\n",
"\n",
"    async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
"        pass\n",
"\n",
"\n",
"async def run_countdown_agent() -> None:\n",
"    # Create a countdown agent.\n",
"    countdown_agent = CountDownAgent(\"countdown\")\n",
"\n",
"    # Run the agent with a given task and stream the response.\n",
"    async for message in countdown_agent.on_messages_stream([], CancellationToken()):\n",
"        if isinstance(message, Response):\n",
"            print(message.chat_message)\n",
"        else:\n",
"            print(message)\n",
"\n",
"\n",
"# Use asyncio.run(run_countdown_agent()) when running in a script.\n",
"await run_countdown_agent()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ArithmeticAgent\n",
"\n",
"In this example, we create an agent class that can perform simple arithmetic operations\n",
"on a given integer. Then, we will use different instances of this agent class\n",
"in a {py:class}`~autogen_agentchat.teams.SelectorGroupChat`\n",
"to transform a given integer into another integer by applying a sequence of arithmetic operations.\n",
"\n",
"The `ArithmeticAgent` class takes an `operator_func` that takes an integer and returns an integer\n",
"after applying an arithmetic operation to it.\n",
"In its `on_messages` method, it applies the `operator_func` to the integer in the input message,\n",
"and returns a response with the result."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable, List, Sequence\n",
"\n",
"from autogen_agentchat.agents import BaseChatAgent\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.conditions import MaxMessageTermination\n",
"from autogen_agentchat.messages import BaseChatMessage, TextMessage\n",
"from autogen_agentchat.teams import SelectorGroupChat\n",
"from autogen_agentchat.ui import Console\n",
"from autogen_core import CancellationToken\n",
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
"\n",
"\n",
"class ArithmeticAgent(BaseChatAgent):\n",
"    def __init__(self, name: str, description: str, operator_func: Callable[[int], int]) -> None:\n",
"        super().__init__(name, description=description)\n",
"        self._operator_func = operator_func\n",
"        self._message_history: List[BaseChatMessage] = []\n",
"\n",
"    @property\n",
"    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:\n",
"        return (TextMessage,)\n",
"\n",
"    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:\n",
"        # Update the message history.\n",
"        # NOTE: messages may be an empty list, which means the agent was selected previously\n",
"        # and is now asked to respond again without any new input.\n",
"        self._message_history.extend(messages)\n",
"        # Parse the number in the last message.\n",
"        assert isinstance(self._message_history[-1], TextMessage)\n",
"        number = int(self._message_history[-1].content)\n",
"        # Apply the operator function to the number.\n",
"        result = self._operator_func(number)\n",
"        # Create a new message with the result.\n",
"        response_message = TextMessage(content=str(result), source=self.name)\n",
"        # Update the message history.\n",
"        self._message_history.append(response_message)\n",
"        # Return the response.\n",
"        return Response(chat_message=response_message)\n",
"\n",
"    async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
"        pass"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{note}\n",
"The `on_messages` method may be called with an empty list of messages, in which\n",
"case it means the agent was called previously and is now being called again,\n",
"without any new messages from the caller. So it is important to keep a history\n",
"of the previous messages received by the agent, and use that history to generate\n",
"the response, as the sketch below illustrates.\n",
"```"
]
},
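{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of this behavior (assuming the `ArithmeticAgent` defined above), the second call below passes no new messages, so the agent must work from its stored history:\n",
"\n",
"```python\n",
"agent = ArithmeticAgent(\"add_agent\", \"Adds 1 to the number.\", lambda x: x + 1)\n",
"\n",
"# First call: the agent parses \"5\" from the new message and replies \"6\".\n",
"response = await agent.on_messages([TextMessage(content=\"5\", source=\"user\")], CancellationToken())\n",
"print(response.chat_message.content)  # 6\n",
"\n",
"# Second call with no new messages: the agent applies the operation\n",
"# to its own last response from the stored history and replies \"7\".\n",
"response = await agent.on_messages([], CancellationToken())\n",
"print(response.chat_message.content)  # 7\n",
"```"
]
},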
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can create 5 instances of `ArithmeticAgent`:\n",
"\n",
"- one that adds 1 to the input integer,\n",
"- one that subtracts 1 from the input integer,\n",
"- one that multiplies the input integer by 2,\n",
"- one that divides the input integer by 2 and rounds down to the nearest integer, and\n",
"- one that returns the input integer unchanged.\n",
"\n",
"We then create a {py:class}`~autogen_agentchat.teams.SelectorGroupChat` with these agents,\n",
"and set the appropriate selector settings:\n",
"\n",
"- allow the same agent to be selected consecutively to allow for repeated operations, and\n",
"- customize the selector prompt to tailor the model's response to the specific task."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- user ----------\n",
"Apply the operations to turn the given number into 25.\n",
"---------- user ----------\n",
"10\n",
"---------- multiply_agent ----------\n",
"20\n",
"---------- add_agent ----------\n",
"21\n",
"---------- multiply_agent ----------\n",
"42\n",
"---------- divide_agent ----------\n",
"21\n",
"---------- add_agent ----------\n",
"22\n",
"---------- add_agent ----------\n",
"23\n",
"---------- add_agent ----------\n",
"24\n",
"---------- add_agent ----------\n",
"25\n",
"---------- Summary ----------\n",
"Number of messages: 10\n",
"Finish reason: Maximum number of messages 10 reached, current message count: 10\n",
"Total prompt tokens: 0\n",
"Total completion tokens: 0\n",
"Duration: 2.40 seconds\n"
]
}
],
"source": [
"async def run_number_agents() -> None:\n",
"    # Create agents for number operations.\n",
"    add_agent = ArithmeticAgent(\"add_agent\", \"Adds 1 to the number.\", lambda x: x + 1)\n",
"    multiply_agent = ArithmeticAgent(\"multiply_agent\", \"Multiplies the number by 2.\", lambda x: x * 2)\n",
"    subtract_agent = ArithmeticAgent(\"subtract_agent\", \"Subtracts 1 from the number.\", lambda x: x - 1)\n",
"    divide_agent = ArithmeticAgent(\"divide_agent\", \"Divides the number by 2 and rounds down.\", lambda x: x // 2)\n",
"    identity_agent = ArithmeticAgent(\"identity_agent\", \"Returns the number as is.\", lambda x: x)\n",
"\n",
"    # The termination condition is to stop after 10 messages.\n",
"    termination_condition = MaxMessageTermination(10)\n",
"\n",
"    # Create a selector group chat.\n",
"    selector_group_chat = SelectorGroupChat(\n",
"        [add_agent, multiply_agent, subtract_agent, divide_agent, identity_agent],\n",
"        model_client=OpenAIChatCompletionClient(model=\"gpt-4o\"),\n",
"        termination_condition=termination_condition,\n",
"        allow_repeated_speaker=True,  # Allow the same agent to speak multiple times, necessary for this task.\n",
"        selector_prompt=(\n",
"            \"Available roles:\\n{roles}\\nTheir job descriptions:\\n{participants}\\n\"\n",
"            \"Current conversation history:\\n{history}\\n\"\n",
"            \"Please select the most appropriate role for the next message, and only return the role name.\"\n",
"        ),\n",
"    )\n",
"\n",
"    # Run the selector group chat with a given task and stream the response.\n",
"    task: List[BaseChatMessage] = [\n",
"        TextMessage(content=\"Apply the operations to turn the given number into 25.\", source=\"user\"),\n",
"        TextMessage(content=\"10\", source=\"user\"),\n",
"    ]\n",
"    stream = selector_group_chat.run_stream(task=task)\n",
"    await Console(stream)\n",
"\n",
"\n",
"# Use asyncio.run(run_number_agents()) when running in a script.\n",
"await run_number_agents()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the output, we can see that the team successfully transformed the input integer\n",
"from 10 to 25 by selecting agents that applied the appropriate arithmetic operations in sequence."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using Custom Model Clients in Custom Agents\n",
"\n",
"One of the key features of the {py:class}`~autogen_agentchat.agents.AssistantAgent` preset in AgentChat is that it takes a `model_client` argument and can use it in responding to messages. However, in some cases, you may want your agent to use a custom model client that is not currently supported (see [supported model clients](https://microsoft.github.io/autogen/dev/user-guide/core-user-guide/components/model-clients.html)) or custom model behaviors.\n",
"\n",
"You can accomplish this with a custom agent that implements *your custom model client*.\n",
"\n",
"Below, we walk through an example of a custom agent that uses the [Google Gemini SDK](https://github.com/googleapis/python-genai) directly to respond to messages.\n",
"\n",
"> **Note:** You will need to install the [Google Gemini SDK](https://github.com/googleapis/python-genai) to run this example. You can install it using the following command:\n",
"\n",
"```bash\n",
"pip install google-genai\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# !pip install google-genai\n",
"import os\n",
"from typing import AsyncGenerator, Sequence\n",
"\n",
"from autogen_agentchat.agents import BaseChatAgent\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, TextMessage\n",
"from autogen_core import CancellationToken\n",
"from autogen_core.model_context import UnboundedChatCompletionContext\n",
"from autogen_core.models import AssistantMessage, RequestUsage, UserMessage\n",
"from google import genai\n",
"from google.genai import types\n",
"\n",
"\n",
"class GeminiAssistantAgent(BaseChatAgent):\n",
"    def __init__(\n",
"        self,\n",
"        name: str,\n",
"        description: str = \"An agent that provides assistance with ability to use tools.\",\n",
"        model: str = \"gemini-1.5-flash-002\",\n",
"        api_key: str = os.environ[\"GEMINI_API_KEY\"],\n",
"        system_message: str\n",
"        | None = \"You are a helpful assistant that can respond to messages. Reply with TERMINATE when the task has been completed.\",\n",
"    ):\n",
"        super().__init__(name=name, description=description)\n",
"        self._model_context = UnboundedChatCompletionContext()\n",
"        self._model_client = genai.Client(api_key=api_key)\n",
"        self._system_message = system_message\n",
"        self._model = model\n",
"\n",
"    @property\n",
"    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:\n",
"        return (TextMessage,)\n",
"\n",
"    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:\n",
"        final_response = None\n",
"        async for message in self.on_messages_stream(messages, cancellation_token):\n",
"            if isinstance(message, Response):\n",
"                final_response = message\n",
"\n",
"        if final_response is None:\n",
"            raise AssertionError(\"The stream should have returned the final result.\")\n",
"\n",
"        return final_response\n",
"\n",
"    async def on_messages_stream(\n",
"        self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken\n",
"    ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:\n",
"        # Add messages to the model context.\n",
"        for msg in messages:\n",
"            await self._model_context.add_message(msg.to_model_message())\n",
"\n",
"        # Get the conversation history.\n",
"        history = [\n",
"            (msg.source if hasattr(msg, \"source\") else \"system\")\n",
"            + \": \"\n",
"            + (msg.content if isinstance(msg.content, str) else \"\")\n",
"            + \"\\n\"\n",
"            for msg in await self._model_context.get_messages()\n",
"        ]\n",
"        # Generate the response using Gemini.\n",
"        response = self._model_client.models.generate_content(\n",
"            model=self._model,\n",
"            contents=f\"History: {history}\\nGiven the history, please provide a response\",\n",
"            config=types.GenerateContentConfig(\n",
"                system_instruction=self._system_message,\n",
"                temperature=0.3,\n",
"            ),\n",
"        )\n",
"\n",
"        # Create usage metadata.\n",
"        usage = RequestUsage(\n",
"            prompt_tokens=response.usage_metadata.prompt_token_count,\n",
"            completion_tokens=response.usage_metadata.candidates_token_count,\n",
"        )\n",
"\n",
"        # Add the response to the model context.\n",
"        await self._model_context.add_message(AssistantMessage(content=response.text, source=self.name))\n",
"\n",
"        # Yield the final response.\n",
"        yield Response(\n",
"            chat_message=TextMessage(content=response.text, source=self.name, models_usage=usage),\n",
"            inner_messages=[],\n",
"        )\n",
"\n",
"    async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
"        \"\"\"Reset the assistant by clearing the model context.\"\"\"\n",
"        await self._model_context.clear()"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- user ----------\n",
"What is the capital of New York?\n",
"---------- gemini_assistant ----------\n",
"Albany\n",
"TERMINATE\n",
"\n"
]
},
{
"data": {
"text/plain": [
"TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the capital of New York?', type='TextMessage'), TextMessage(source='gemini_assistant', models_usage=RequestUsage(prompt_tokens=46, completion_tokens=5), content='Albany\\nTERMINATE\\n', type='TextMessage')], stop_reason=None)"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gemini_assistant = GeminiAssistantAgent(\"gemini_assistant\")\n",
"await Console(gemini_assistant.run_stream(task=\"What is the capital of New York?\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the example above, we have chosen to provide `model`, `api_key` and `system_message` as arguments; you can provide any other arguments required by the model client you are using or that fit your application design.\n",
"\n",
"Now, let us explore how to use this custom agent as part of a team in AgentChat."
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- user ----------\n",
"Write a Haiku poem with 4 lines about the fall season.\n",
"---------- primary ----------\n",
"Crimson leaves cascade, \n",
"Whispering winds sing of change, \n",
"Chill wraps the fading, \n",
"Nature's quilt, rich and warm.\n",
"---------- gemini_critic ----------\n",
"The poem is good, but it has four lines instead of three. A haiku must have three lines with a 5-7-5 syllable structure. The content is evocative of autumn, but the form is incorrect. Please revise to adhere to the haiku's syllable structure.\n",
"\n",
"---------- primary ----------\n",
"Thank you for your feedback! Here’s a revised haiku that follows the 5-7-5 syllable structure:\n",
"\n",
"Crimson leaves drift down, \n",
"Chill winds whisper through the gold, \n",
"Autumn’s breath is near.\n",
"---------- gemini_critic ----------\n",
"The revised haiku is much improved. It correctly follows the 5-7-5 syllable structure and maintains the evocative imagery of autumn. APPROVE\n",
"\n"
]
},
{
"data": {
"text/plain": [
"TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write a Haiku poem with 4 lines about the fall season.', type='TextMessage'), TextMessage(source='primary', models_usage=RequestUsage(prompt_tokens=33, completion_tokens=31), content=\"Crimson leaves cascade, \\nWhispering winds sing of change, \\nChill wraps the fading, \\nNature's quilt, rich and warm.\", type='TextMessage'), TextMessage(source='gemini_critic', models_usage=RequestUsage(prompt_tokens=86, completion_tokens=60), content=\"The poem is good, but it has four lines instead of three. A haiku must have three lines with a 5-7-5 syllable structure. The content is evocative of autumn, but the form is incorrect. Please revise to adhere to the haiku's syllable structure.\\n\", type='TextMessage'), TextMessage(source='primary', models_usage=RequestUsage(prompt_tokens=141, completion_tokens=49), content='Thank you for your feedback! Here’s a revised haiku that follows the 5-7-5 syllable structure:\\n\\nCrimson leaves drift down, \\nChill winds whisper through the gold, \\nAutumn’s breath is near.', type='TextMessage'), TextMessage(source='gemini_critic', models_usage=RequestUsage(prompt_tokens=211, completion_tokens=32), content='The revised haiku is much improved. It correctly follows the 5-7-5 syllable structure and maintains the evocative imagery of autumn. APPROVE\\n', type='TextMessage')], stop_reason=\"Text 'APPROVE' mentioned\")"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from autogen_agentchat.agents import AssistantAgent\n",
"from autogen_agentchat.conditions import TextMentionTermination\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_agentchat.ui import Console\n",
"\n",
"model_client = OpenAIChatCompletionClient(model=\"gpt-4o-mini\")\n",
"\n",
"# Create the primary agent.\n",
"primary_agent = AssistantAgent(\n",
"    \"primary\",\n",
"    model_client=model_client,\n",
"    system_message=\"You are a helpful AI assistant.\",\n",
")\n",
"\n",
"# Create a critic agent based on our new GeminiAssistantAgent.\n",
"gemini_critic_agent = GeminiAssistantAgent(\n",
"    \"gemini_critic\",\n",
"    system_message=\"Provide constructive feedback. Respond with 'APPROVE' when your feedback has been addressed.\",\n",
")\n",
"\n",
"\n",
"# Define a termination condition that stops the task if the critic approves or after 10 messages.\n",
"termination = TextMentionTermination(\"APPROVE\") | MaxMessageTermination(10)\n",
"\n",
"# Create a team with the primary and critic agents.\n",
"team = RoundRobinGroupChat([primary_agent, gemini_critic_agent], termination_condition=termination)\n",
"\n",
"await Console(team.run_stream(task=\"Write a Haiku poem with 4 lines about the fall season.\"))\n",
"await model_client.close()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The section above demonstrates several important concepts:\n",
"- We developed a custom agent that uses the Google Gemini SDK to respond to messages.\n",
"- We showed that this custom agent can be used as part of the broader AgentChat ecosystem, in this case as a participant in a {py:class}`~autogen_agentchat.teams.RoundRobinGroupChat`, as long as it inherits from {py:class}`~autogen_agentchat.agents.BaseChatAgent`.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Making the Custom Agent Declarative\n",
"\n",
"AutoGen provides a [Component](https://microsoft.github.io/autogen/dev/user-guide/core-user-guide/framework/component-config.html) interface for making the configuration of components serializable to a declarative format. This is useful for saving and loading configurations, and for sharing configurations with others.\n",
"\n",
"We accomplish this by inheriting from the `Component` class and implementing the `_from_config` and `_to_config` methods.\n",
"The declarative class can be serialized to a JSON format using the `dump_component` method, and deserialized from JSON using the `load_component` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from typing import AsyncGenerator, Sequence\n",
"\n",
"from autogen_agentchat.agents import BaseChatAgent\n",
"from autogen_agentchat.base import Response\n",
"from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, TextMessage\n",
"from autogen_core import CancellationToken, Component\n",
"from autogen_core.model_context import UnboundedChatCompletionContext\n",
"from autogen_core.models import AssistantMessage, RequestUsage\n",
"from google import genai\n",
"from google.genai import types\n",
"from pydantic import BaseModel\n",
"from typing_extensions import Self\n",
"\n",
"\n",
"class GeminiAssistantAgentConfig(BaseModel):\n",
"    name: str\n",
"    description: str = \"An agent that provides assistance with ability to use tools.\"\n",
"    model: str = \"gemini-1.5-flash-002\"\n",
"    system_message: str | None = None\n",
"\n",
"\n",
"class GeminiAssistantAgent(BaseChatAgent, Component[GeminiAssistantAgentConfig]):  # type: ignore[no-redef]\n",
"    component_config_schema = GeminiAssistantAgentConfig\n",
"    # component_provider_override = \"mypackage.agents.GeminiAssistantAgent\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        name: str,\n",
"        description: str = \"An agent that provides assistance with ability to use tools.\",\n",
"        model: str = \"gemini-1.5-flash-002\",\n",
"        api_key: str = os.environ[\"GEMINI_API_KEY\"],\n",
"        system_message: str\n",
"        | None = \"You are a helpful assistant that can respond to messages. Reply with TERMINATE when the task has been completed.\",\n",
"    ):\n",
"        super().__init__(name=name, description=description)\n",
"        self._model_context = UnboundedChatCompletionContext()\n",
"        self._model_client = genai.Client(api_key=api_key)\n",
"        self._system_message = system_message\n",
"        self._model = model\n",
"\n",
"    @property\n",
"    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:\n",
"        return (TextMessage,)\n",
"\n",
"    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:\n",
"        final_response = None\n",
"        async for message in self.on_messages_stream(messages, cancellation_token):\n",
"            if isinstance(message, Response):\n",
"                final_response = message\n",
"\n",
"        if final_response is None:\n",
"            raise AssertionError(\"The stream should have returned the final result.\")\n",
"\n",
"        return final_response\n",
"\n",
"    async def on_messages_stream(\n",
"        self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken\n",
"    ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:\n",
"        # Add messages to the model context.\n",
"        for msg in messages:\n",
"            await self._model_context.add_message(msg.to_model_message())\n",
"\n",
"        # Get the conversation history.\n",
"        history = [\n",
"            (msg.source if hasattr(msg, \"source\") else \"system\")\n",
"            + \": \"\n",
"            + (msg.content if isinstance(msg.content, str) else \"\")\n",
"            + \"\\n\"\n",
"            for msg in await self._model_context.get_messages()\n",
"        ]\n",
"\n",
"        # Generate the response using Gemini.\n",
"        response = self._model_client.models.generate_content(\n",
"            model=self._model,\n",
"            contents=f\"History: {history}\\nGiven the history, please provide a response\",\n",
"            config=types.GenerateContentConfig(\n",
"                system_instruction=self._system_message,\n",
"                temperature=0.3,\n",
"            ),\n",
"        )\n",
"\n",
"        # Create usage metadata.\n",
"        usage = RequestUsage(\n",
"            prompt_tokens=response.usage_metadata.prompt_token_count,\n",
"            completion_tokens=response.usage_metadata.candidates_token_count,\n",
"        )\n",
"\n",
"        # Add the response to the model context.\n",
"        await self._model_context.add_message(AssistantMessage(content=response.text, source=self.name))\n",
"\n",
"        # Yield the final response.\n",
"        yield Response(\n",
"            chat_message=TextMessage(content=response.text, source=self.name, models_usage=usage),\n",
"            inner_messages=[],\n",
"        )\n",
"\n",
"    async def on_reset(self, cancellation_token: CancellationToken) -> None:\n",
"        \"\"\"Reset the assistant by clearing the model context.\"\"\"\n",
"        await self._model_context.clear()\n",
"\n",
"    @classmethod\n",
"    def _from_config(cls, config: GeminiAssistantAgentConfig) -> Self:\n",
"        return cls(\n",
"            name=config.name, description=config.description, model=config.model, system_message=config.system_message\n",
"        )\n",
"\n",
"    def _to_config(self) -> GeminiAssistantAgentConfig:\n",
"        return GeminiAssistantAgentConfig(\n",
"            name=self.name,\n",
"            description=self.description,\n",
"            model=self._model,\n",
"            system_message=self._system_message,\n",
"        )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have the required methods implemented, we can dump the custom agent to a JSON format and then load the agent back from that JSON.\n",
"\n",
"> **Note:** You should set the `component_provider_override` class variable to the full path of the module containing the custom agent class (e.g., `mypackage.agents.GeminiAssistantAgent`). This is used by the `load_component` method to determine how to instantiate the class."
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
"  \"provider\": \"__main__.GeminiAssistantAgent\",\n",
"  \"component_type\": \"agent\",\n",
"  \"version\": 1,\n",
"  \"component_version\": 1,\n",
"  \"description\": null,\n",
"  \"label\": \"GeminiAssistantAgent\",\n",
"  \"config\": {\n",
"    \"name\": \"gemini_assistant\",\n",
"    \"description\": \"An agent that provides assistance with ability to use tools.\",\n",
"    \"model\": \"gemini-1.5-flash-002\",\n",
"    \"system_message\": \"You are a helpful assistant that can respond to messages. Reply with TERMINATE when the task has been completed.\"\n",
"  }\n",
"}\n",
"<__main__.GeminiAssistantAgent object at 0x11a5c5a90>\n"
]
}
],
"source": [
"gemini_assistant = GeminiAssistantAgent(\"gemini_assistant\")\n",
"config = gemini_assistant.dump_component()\n",
"print(config.model_dump_json(indent=2))\n",
"loaded_agent = GeminiAssistantAgent.load_component(config)\n",
"print(loaded_agent)"
]
},
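{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the dumped configuration is plain JSON, it can also be persisted and shared. A minimal sketch (assuming the hypothetical `gemini_assistant.json` file name and the `config` object from the cell above, and that `load_component` accepts a plain dict in addition to a `ComponentModel`):\n",
"\n",
"```python\n",
"import json\n",
"\n",
"# Save the component configuration to disk.\n",
"with open(\"gemini_assistant.json\", \"w\") as f:\n",
"    f.write(config.model_dump_json(indent=2))\n",
"\n",
"# Load the configuration back and reconstruct the agent from it.\n",
"with open(\"gemini_assistant.json\") as f:\n",
"    loaded_agent = GeminiAssistantAgent.load_component(json.load(f))\n",
"```"
]
},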
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Next Steps\n",
"\n",
"So far, we have seen how to create custom agents, add custom model clients to agents, and make custom agents declarative. There are a few ways in which this basic sample can be extended:\n",
"\n",
"- Extend the Gemini model client to handle function calling similar to the {py:class}`~autogen_agentchat.agents.AssistantAgent` class (see the [Gemini function calling documentation](https://ai.google.dev/gemini-api/docs/function-calling)).\n",
"- Implement a package with a custom agent and experiment with using its declarative format in a tool like [AutoGen Studio](https://microsoft.github.io/autogen/stable/user-guide/autogenstudio-user-guide/index.html)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}