{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
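"# Handling a Long Context via `TransformChatHistory`\n",
"\n",
"This notebook illustrates how to use the `TransformChatHistory` capability to give any `ConversableAgent` the ability to handle a long context. The capability shortens the chat history that the agent sends to its LLM by limiting the number of tokens per message, the number of messages, and the total number of tokens."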
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"## Uncomment to install pyautogen if you don't have it already\n",
"#! pip install pyautogen"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import autogen\n",
"from autogen.agentchat.contrib.capabilities import context_handling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
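"To add this ability to an agent, instantiate the capability and call its `add_to_agent` method with the agent. In the example below, the capability keeps at most 10 messages, truncates each message to 50 tokens, and caps the retained history at 1000 tokens in total."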
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"plot and save a graph of x^2 from -10 to 10\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"Here's a Python code snippet to plot and save a graph of x^2 from -10 to 10:\n",
"\n",
"```python\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Generate x values from -10 to 10\n",
"x = np.linspace(-10, 10, 100)\n",
"\n",
"# Evaluate y values using x^2\n",
"y = x**2\n",
"\n",
"# Plot the graph\n",
"plt.plot(x, y)\n",
"\n",
"# Set labels and title\n",
"plt.xlabel('x')\n",
"plt.ylabel('y')\n",
"plt.title('Graph of y = x^2')\n",
"\n",
"# Save the graph as an image file (e.g., PNG)\n",
"plt.savefig('x_squared_graph.png')\n",
"\n",
"# Show the graph\n",
"plt.show()\n",
"```\n",
"\n",
"Please make sure to have the `matplotlib` library installed in your Python environment. After executing the code, the graph will be saved as \"x_squared_graph.png\" in the current directory. You can change the filename if desired.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"Figure(640x480)\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"Great! The code executed successfully and generated a graph of x^2 from -10 to 10. You can save the graph by adding the following code snippet:\n",
"\n",
"```python\n",
"# Save the graph as a file\n",
"plt.savefig('graph.png')\n",
"```\n",
"\n",
"This will save the graph as a PNG file named \"graph.png\" in your current working directory.\n",
"\n",
"Now, you can check the saved graph in your current directory. Let me know if you need any further assistance.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Traceback (most recent call last):\n",
" File \"\", line 2, in <module>\n",
" plt.savefig('graph.png')\n",
"NameError: name 'plt' is not defined\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"I apologize for the mistake. It seems that the `plt` module was not imported correctly. Please make sure that you have the `matplotlib` library installed. You can install it using the following command:\n",
"\n",
"```sh\n",
"pip install matplotlib\n",
"```\n",
"\n",
"Once you have the library installed, please try running the code again. If you are still facing any issues, please let me know.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is sh)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"Defaulting to user installation because normal site-packages is not writeable\n",
"Requirement already satisfied: matplotlib in /home/vscode/.local/lib/python3.10/site-packages (3.8.2)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (3.1.1)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (4.47.2)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (2.8.2)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/site-packages (from matplotlib) (23.2)\n",
"Requirement already satisfied: pillow>=8 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (10.2.0)\n",
"Requirement already satisfied: cycler>=0.10 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (0.12.1)\n",
"Requirement already satisfied: kiwisolver>=1.3.1 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (1.4.5)\n",
"Requirement already satisfied: numpy<2,>=1.21 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (1.26.3)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /home/vscode/.local/lib/python3.10/site-packages (from matplotlib) (1.2.0)\n",
"Requirement already satisfied: six>=1.5 in /home/vscode/.local/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"Thank you for installing the `matplotlib` library. Let's try running the code again:\n",
"\n",
"```python\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Generate x values from -10 to 10\n",
"x = np.linspace(-10, 10, 100)\n",
"\n",
"# Compute the y values (x^2)\n",
"y = x**2\n",
"\n",
"# Plot the graph\n",
"plt.plot(x, y)\n",
"\n",
"# Save the graph as a file\n",
"plt.savefig('graph.png')\n",
"\n",
"# Show the graph\n",
"plt.show()\n",
"```\n",
"\n",
"Please give it a try and let me know if it works for you.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"Figure(640x480)\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"It seems that the code executed successfully and generated a figure of x^2 from -10 to 10. To save the graph as an image file, you can add the following code snippet:\n",
"\n",
"```python\n",
"# Save the graph as a file\n",
"plt.savefig('graph.png')\n",
"```\n",
"\n",
"This will save the graph as a PNG file named \"graph.png\" in the current directory. You can change the file name or modify the file format as needed.\n",
"Remember to replace `plt.savefig('graph.png')` with your desired file name if necessary.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Traceback (most recent call last):\n",
" File \"\", line 2, in <module>\n",
" plt.savefig('graph.png')\n",
"NameError: name 'plt' is not defined\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 1 messages.\u001b[0m\n",
"\u001b[33mTruncated 15 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"My apologies for the mistake again. It seems that the `plt` module was not imported correctly. Please make sure that you have imported the `matplotlib.pyplot` module.\n",
"\n",
"Here is the correct code snippet:\n",
"\n",
"```python\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Generate x values from -10 to 10\n",
"x = np.linspace(-10, 10, 100)\n",
"y = x ** 2\n",
"\n",
"# Create a plot of x^2\n",
"plt.plot(x, y)\n",
"\n",
"# Save the graph as a file\n",
"plt.savefig('graph.png')\n",
"```\n",
"\n",
"Please try running this code. It should generate the graph and save it as `graph.png` in the current working directory.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 3 messages.\u001b[0m\n",
"\u001b[33mTruncated 84 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"Great! The code executed successfully and generated a graph of x^2 from -10 to 10. You can save the graph by adding the following code snippet:\n",
"\n",
"```python\n",
"# Save the graph as a file\n",
"plt.savefig('graph.png')\n",
"```\n",
"\n",
"This will save the graph as a PNG image file named \"graph.png\" in the same directory as your Python script. You can change the filename and format (e.g., 'graph.jpg', 'graph.pdf') as needed.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Traceback (most recent call last):\n",
" File \"\", line 2, in <module>\n",
" plt.savefig('graph.png')\n",
"NameError: name 'plt' is not defined\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 5 messages.\u001b[0m\n",
"\u001b[33mTruncated 184 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"I apologize for the confusion. It looks like the code didn't execute as expected. \n",
"\n",
"To save the graph in a file, try the following code:\n",
"\n",
"```python\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Generate x values from -10 to 10\n",
"x = np.linspace(-10, 10, 100)\n",
"\n",
"# Calculate the y values (square of x)\n",
"y = x ** 2\n",
"\n",
"# Plot the graph\n",
"plt.plot(x, y)\n",
"\n",
"# Save the graph as a file\n",
"plt.savefig('graph.png')\n",
"\n",
"# Show the graph\n",
"plt.show()\n",
"```\n",
"\n",
"Make sure you have the `matplotlib` library installed. Let me know if you encounter any further issues.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"Figure(640x480)\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 7 messages.\u001b[0m\n",
"\u001b[33mTruncated 284 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"Great! The code executed successfully and generated a graph of x^2 from -10 to 10. You can save the graph by adding the following code snippet:\n",
"\n",
"```python\n",
"# Save the graph as a file\n",
"plt.savefig('graph.png')\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 9 messages.\u001b[0m\n",
"\u001b[33mTruncated 353 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"Sorry for the mistake again. The code snippet provided was incomplete. To correctly save the graph as an image file, you need to add a line of code to close the plot after saving. Here's the updated code:\n",
"\n",
"```python\n",
"# Save the graph as a file\n",
"plt.savefig('graph.png')\n",
"plt.close()\n",
"```\n",
"\n",
"Please try running the complete code again and it should save the graph as \"graph.png\" in the current directory. Let me know if you encounter any further issues.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Traceback (most recent call last):\n",
" File \"\", line 2, in <module>\n",
" plt.savefig('graph.png')\n",
"NameError: name 'plt' is not defined\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 11 messages.\u001b[0m\n",
"\u001b[33mTruncated 453 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"I apologize for the mistake once again. It seems that the matplotlib module is not installed on your system. To install it, you can use pip, the package manager for Python.\n",
"\n",
"Here are the steps to install matplotlib:\n",
"\n",
"1. Open a command prompt or terminal.\n",
"2. Type the following command and press Enter:\n",
"\n",
"```sh\n",
"pip install matplotlib\n",
"```\n",
"\n",
"3. Wait for the installation to complete.\n",
"\n",
"Once you have installed matplotlib, you can try running the code again to save the graph as an image file.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"assistant = autogen.AssistantAgent(\n",
"    \"assistant\",\n",
"    llm_config={\n",
"        \"config_list\": [{\"model\": \"gpt-3.5-turbo\", \"api_key\": \"YOUR_API_KEY\"}],\n",
"    },\n",
")\n",
"\n",
"# Instantiate the capability to manage chat history\n",
"manage_chat_history = context_handling.TransformChatHistory(max_tokens_per_message=50, max_messages=10, max_tokens=1000)\n",
"# Add the capability to the assistant\n",
"manage_chat_history.add_to_agent(assistant)\n",
"\n",
"user_proxy = autogen.UserProxyAgent(\n",
"    \"user_proxy\",\n",
"    human_input_mode=\"NEVER\",\n",
"    is_termination_msg=lambda x: \"TERMINATE\" in x.get(\"content\", \"\"),\n",
"    code_execution_config={\n",
"        \"work_dir\": \"coding\",\n",
"        \"use_docker\": False,\n",
"    },\n",
"    max_consecutive_auto_reply=10,\n",
")\n",
"\n",
"user_proxy.initiate_chat(assistant, message=\"plot and save a graph of x^2 from -10 to 10\")"
]
},
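{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate how `max_tokens_per_message`, `max_messages`, and `max_tokens` can interact, the next cell is a simplified, hypothetical sketch of such a truncation policy in plain Python. It is not the `TransformChatHistory` implementation. It assumes that one whitespace-separated word counts as one token (the real capability counts model tokens). The helper names `truncate_content` and `transform_history` are invented for illustration. The sketch keeps the newest messages when the `max_tokens` budget runs out (the library's exact order may differ) and does not special-case system messages."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch of a chat-history truncation policy (not the autogen implementation).\n",
"# Assumption: one whitespace-separated word counts as one token.\n",
"\n",
"\n",
"def truncate_content(text, max_tokens):\n",
"    \"\"\"Keep at most `max_tokens` words of `text`.\"\"\"\n",
"    return \" \".join(text.split()[:max_tokens])\n",
"\n",
"\n",
"def transform_history(messages, max_tokens_per_message, max_messages, max_tokens):\n",
"    \"\"\"Return a shortened copy of `messages`; the input list is not modified.\"\"\"\n",
"    # 1. Cap every message at `max_tokens_per_message` words (copying each dict).\n",
"    capped = [{**m, \"content\": truncate_content(m[\"content\"], max_tokens_per_message)} for m in messages]\n",
"    # 2. Keep only the `max_messages` most recent messages.\n",
"    recent = capped[-max_messages:]\n",
"    # 3. Walk back from the newest message and stop once `max_tokens` would be exceeded.\n",
"    kept, used = [], 0\n",
"    for m in reversed(recent):\n",
"        n = len(m[\"content\"].split())\n",
"        if used + n > max_tokens:\n",
"            break\n",
"        kept.append(m)\n",
"        used += n\n",
"    # Restore chronological order.\n",
"    return kept[::-1]\n",
"\n",
"\n",
"# Usage: a fake history shaped like the one built later in this notebook.\n",
"history = []\n",
"for _ in range(1000):\n",
"    history.append({\"role\": \"assistant\", \"content\": \"test \" * 1000})\n",
"    history.append({\"role\": \"user\", \"content\": \"\"})\n",
"\n",
"short = transform_history(history, max_tokens_per_message=50, max_messages=10, max_tokens=1000)\n",
"print(len(history), \"->\", len(short), \"messages\")  # 2000 -> 10 messages\n",
"print(sum(len(m[\"content\"].split()) for m in short), \"words kept\")  # 250 words kept"
]
},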
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Why is this important?\n",
"This capability is especially useful if you expect an agent's chat history to grow so large that it exceeds the context length of your LLM.\n",
"In the example below, we define two agents: one without this capability and one with it. Both are given the same very long fake chat history.\n",
"\n",
"The agent without the capability fails because its request exceeds the model's context length, whereas the agent with the capability truncates its history and completes the task."
]
},
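{
"cell_type": "markdown",
"metadata": {},
"source": [
"A back-of-the-envelope estimate shows why the two agents behave differently. The next cell is a rough sketch that approximates tokens by whitespace-separated words, so its numbers are estimates rather than the exact counts reported by the model. It also assumes that, with the settings used here, the capability forwards at most `max_messages` messages of at most `max_tokens_per_message` tokens each, and never more than `max_tokens` in total (excluding any system message)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rough size estimate; one whitespace-separated word stands in for one token (an approximation).\n",
"CONTEXT_LENGTH = 4097  # gpt-3.5-turbo limit quoted in the error message of the next cell\n",
"\n",
"# The fake history in the next cell: 1000 assistant messages of 1000 words each\n",
"# (the 1000 empty user messages add no words).\n",
"raw_history_tokens = 1000 * len((\"test \" * 1000).split())\n",
"\n",
"# Assumed upper bound on the history forwarded by the capability with these settings.\n",
"max_tokens_per_message, max_messages, max_tokens = 50, 10, 1000\n",
"capped_history_tokens = min(max_messages * max_tokens_per_message, max_tokens)\n",
"\n",
"print(f\"without the capability: ~{raw_history_tokens} tokens (limit {CONTEXT_LENGTH})\")\n",
"print(f\"with the capability: at most {capped_history_tokens} tokens of history\")"
]
},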
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"plot and save a graph of x^2 from -10 to 10\n",
"\n",
"--------------------------------------------------------------------------------\n",
"Encountered an error with the base assistant\n",
"Error code: 400 - {'error': {'message': \"This model's maximum context length is 4097 tokens. However, your messages resulted in 1009487 tokens. Please reduce the length of the messages.\", 'type': 'invalid_request_error', 'param': 'messages', 'code': 'context_length_exceeded'}}\n",
"\n",
"\n",
"\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"plot and save a graph of x^2 from -10 to 10\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 1991 messages.\u001b[0m\n",
"\u001b[33mTruncated 49800 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"Here is the Python code to plot and save a graph of x^2 from -10 to 10:\n",
"\n",
"```python\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Generate values for x from -10 to 10\n",
"x = np.linspace(-10, 10, 100)\n",
"\n",
"# Calculate y values for x^2\n",
"y = x ** 2\n",
"\n",
"# Plot the graph\n",
"plt.plot(x, y)\n",
"plt.xlabel('x')\n",
"plt.ylabel('y = x^2')\n",
"plt.title('Graph of y = x^2')\n",
"\n",
"# Save the graph as a PNG file\n",
"plt.savefig('graph.png')\n",
"\n",
"# Close the plot\n",
"plt.close()\n",
"\n",
"print('Graph saved as graph.png')\n",
"```\n",
"\n",
"Please make sure you have the `matplotlib` library installed before running this code. You can install it by running `pip install matplotlib` in your terminal.\n",
"\n",
"After executing the code, a graph of y = x^2 will be saved as `graph.png` in your current working directory.\n",
"\n",
"Let me know if you need any further assistance!\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"Graph saved as graph.png\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 1993 messages.\u001b[0m\n",
"\u001b[33mTruncated 49850 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"Great! The code executed successfully and the graph has been saved as \"graph.png\". You can now view the graph to see the plot of x^2 from -10 to 10. If you have any more questions or need further assistance, feel free to ask.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mTruncated 1995 messages.\u001b[0m\n",
"\u001b[33mTruncated 49900 tokens.\u001b[0m\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"assistant_base = autogen.AssistantAgent(\n",
"    \"assistant\",\n",
"    llm_config={\n",
"        \"config_list\": [{\"model\": \"gpt-3.5-turbo\", \"api_key\": \"YOUR_API_KEY\"}],\n",
"    },\n",
")\n",
"\n",
"assistant_with_context_handling = autogen.AssistantAgent(\n",
"    \"assistant\",\n",
"    llm_config={\n",
"        \"config_list\": [{\"model\": \"gpt-3.5-turbo\", \"api_key\": \"YOUR_API_KEY\"}],\n",
"    },\n",
")\n",
"# Add the context-handling capability to the second assistant only\n",
"manage_chat_history = context_handling.TransformChatHistory(max_tokens_per_message=50, max_messages=10, max_tokens=1000)\n",
"manage_chat_history.add_to_agent(assistant_with_context_handling)\n",
"\n",
"user_proxy = autogen.UserProxyAgent(\n",
"    \"user_proxy\",\n",
"    human_input_mode=\"NEVER\",\n",
"    is_termination_msg=lambda x: \"TERMINATE\" in x.get(\"content\", \"\"),\n",
"    code_execution_config={\n",
"        \"work_dir\": \"coding\",\n",
"        \"use_docker\": False,\n",
"    },\n",
"    max_consecutive_auto_reply=2,\n",
")\n",
"\n",
"# Build a very long fake chat history (roughly one million tokens) that far exceeds\n",
"# the context length of gpt-3.5-turbo\n",
"for _ in range(1000):\n",
"    # Define a fake, very long assistant message and an empty user reply\n",
"    assistant_msg = {\"role\": \"assistant\", \"content\": \"test \" * 1000}\n",
"    user_msg = {\"role\": \"user\", \"content\": \"\"}\n",
"\n",
"    # Record the same fake exchange in both assistants' histories without requesting replies\n",
"    assistant_base.send(assistant_msg, user_proxy, request_reply=False, silent=True)\n",
"    assistant_with_context_handling.send(assistant_msg, user_proxy, request_reply=False, silent=True)\n",
"    user_proxy.send(user_msg, assistant_base, request_reply=False, silent=True)\n",
"    user_proxy.send(user_msg, assistant_with_context_handling, request_reply=False, silent=True)\n",
"\n",
"# The base assistant is expected to fail with a context-length error\n",
"try:\n",
"    user_proxy.initiate_chat(assistant_base, message=\"plot and save a graph of x^2 from -10 to 10\", clear_history=False)\n",
"except Exception as e:\n",
"    print(\"Encountered an error with the base assistant\")\n",
"    print(e)\n",
"    print(\"\\n\\n\")\n",
"\n",
"# The assistant with context handling truncates its history before calling the LLM\n",
"try:\n",
"    user_proxy.initiate_chat(\n",
"        assistant_with_context_handling, message=\"plot and save a graph of x^2 from -10 to 10\", clear_history=False\n",
"    )\n",
"except Exception as e:\n",
"    print(e)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}