Implement User Defined Functions for Local CLI Executor (#2102)

* Implement user defined functions feature for local cli exec, add docs

* add tests, update docs

* fixes

* fix test

* add pandas test dep

* install test

* provide template as func

* formatting

* undo change

* address comments

* add test deps

* formatting

* test only in 1 env

* formatting

* remove test for local only

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Jack Gerrits 2024-03-27 19:45:17 -04:00 committed by GitHub
parent d3db7db67f
commit 5ef2dfc104
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 837 additions and 7 deletions

View File

@ -43,10 +43,10 @@ jobs:
python -c "import autogen"
pip install pytest mock
- name: Install optional dependencies for code executors
# code executors auto skip without deps, so only run for python 3.11
# code executors and udfs auto skip without deps, so only run for python 3.11
if: matrix.python-version == '3.11'
run: |
pip install -e ".[jupyter-executor]"
pip install -e ".[jupyter-executor,test]"
python -m ipykernel install --user --name python3
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash

View File

@ -0,0 +1,128 @@
from __future__ import annotations
import inspect
import functools
from typing import Any, Callable, List, TypeVar, Generic, Union
from typing_extensions import ParamSpec
from textwrap import indent, dedent
from dataclasses import dataclass, field
T = TypeVar("T")
P = ParamSpec("P")
def _to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T]]) -> str:
code = inspect.getsource(func)
# Strip the decorator
if code.startswith("@"):
code = code[code.index("\n") + 1 :]
return code
@dataclass
class Alias:
name: str
alias: str
@dataclass
class ImportFromModule:
module: str
imports: List[Union[str, Alias]]
Import = Union[str, ImportFromModule, Alias]
def _import_to_str(im: Import) -> str:
if isinstance(im, str):
return f"import {im}"
elif isinstance(im, Alias):
return f"import {im.name} as {im.alias}"
else:
def to_str(i: Union[str, Alias]) -> str:
if isinstance(i, str):
return i
else:
return f"{i.name} as {i.alias}"
imports = ", ".join(map(to_str, im.imports))
return f"from {im.module} import {imports}"
@dataclass
class FunctionWithRequirements(Generic[T, P]):
func: Callable[P, T]
python_packages: List[str] = field(default_factory=list)
global_imports: List[Import] = field(default_factory=list)
@classmethod
def from_callable(
cls, func: Callable[P, T], python_packages: List[str] = [], global_imports: List[Import] = []
) -> FunctionWithRequirements[T, P]:
return cls(python_packages=python_packages, global_imports=global_imports, func=func)
# Type this based on F
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
return self.func(*args, **kwargs)
def with_requirements(
python_packages: List[str] = [], global_imports: List[Import] = []
) -> Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]:
"""Decorate a function with package and import requirements
Args:
python_packages (List[str], optional): Packages required to function. Can include version info.. Defaults to [].
global_imports (List[Import], optional): Required imports. Defaults to [].
Returns:
Callable[[Callable[P, T]], FunctionWithRequirements[T, P]]: The decorated function
"""
def wrapper(func: Callable[P, T]) -> FunctionWithRequirements[T, P]:
func_with_reqs = FunctionWithRequirements(
python_packages=python_packages, global_imports=global_imports, func=func
)
functools.update_wrapper(func_with_reqs, func)
return func_with_reqs
return wrapper
def _build_python_functions_file(funcs: List[Union[FunctionWithRequirements[Any, P], Callable[..., Any]]]) -> str:
# First collect all global imports
global_imports = set()
for func in funcs:
if isinstance(func, FunctionWithRequirements):
global_imports.update(func.global_imports)
content = "\n".join(map(_import_to_str, global_imports)) + "\n\n"
for func in funcs:
content += _to_code(func) + "\n\n"
return content
def to_stub(func: Callable[..., Any]) -> str:
"""Generate a stub for a function as a string
Args:
func (Callable[..., Any]): The function to generate a stub for
Returns:
str: The stub for the function
"""
content = f"def {func.__name__}{inspect.signature(func)}:\n"
docstring = func.__doc__
if docstring:
docstring = dedent(docstring)
docstring = '"""' + docstring + '"""'
docstring = indent(docstring, " ")
content += docstring + "\n"
content += " ..."
return content

View File

@ -1,14 +1,14 @@
from hashlib import md5
import os
from pathlib import Path
import re
from string import Template
import sys
import uuid
import warnings
from typing import ClassVar, List, Union
from typing import Any, Callable, ClassVar, List, TypeVar, Union, cast
from typing_extensions import ParamSpec
from autogen.coding.func_with_reqs import FunctionWithRequirements, _build_python_functions_file, to_stub
from ..agentchat.agent import LLMAgent
from ..code_utils import TIMEOUT_MSG, WIN32, _cmd, execute_code
from ..code_utils import TIMEOUT_MSG, WIN32, _cmd
from .base import CodeBlock, CodeExecutor, CodeExtractor, CommandLineCodeResult
from .markdown_code_extractor import MarkdownCodeExtractor
@ -16,16 +16,30 @@ from .utils import _get_file_name_from_content, silence_pip
import subprocess
import logging
__all__ = ("LocalCommandLineCodeExecutor",)
A = ParamSpec("A")
class LocalCommandLineCodeExecutor(CodeExecutor):
SUPPORTED_LANGUAGES: ClassVar[List[str]] = ["bash", "shell", "sh", "pwsh", "powershell", "ps1", "python"]
FUNCTIONS_MODULE: ClassVar[str] = "functions"
FUNCTIONS_FILENAME: ClassVar[str] = "functions.py"
FUNCTION_PROMPT_TEMPLATE: ClassVar[
str
] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names.
For example, if there was a function called `foo` you could import it by writing `from $module_name import foo`
$functions"""
def __init__(
self,
timeout: int = 60,
work_dir: Union[Path, str] = Path("."),
functions: List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]] = [],
):
"""(Experimental) A code executor class that executes code through a local command line
environment.
@ -48,6 +62,7 @@ class LocalCommandLineCodeExecutor(CodeExecutor):
work_dir (str): The working directory for the code execution. If None,
a default working directory will be used. The default working
directory is the current directory ".".
functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list.
"""
if timeout < 1:
@ -62,6 +77,38 @@ class LocalCommandLineCodeExecutor(CodeExecutor):
self._timeout = timeout
self._work_dir: Path = work_dir
self._functions = functions
# Setup could take some time so we intentionally wait for the first code block to do it.
if len(functions) > 0:
self._setup_functions_complete = False
else:
self._setup_functions_complete = True
def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str:
"""(Experimental) Format the functions for a prompt.
The template includes two variables:
- `$module_name`: The module name.
- `$functions`: The functions formatted as stubs with two newlines between each function.
Args:
prompt_template (str): The prompt template. Default is the class default.
Returns:
str: The formatted prompt.
"""
template = Template(prompt_template)
return template.substitute(
module_name=self.FUNCTIONS_MODULE,
functions="\n\n".join([to_stub(func) for func in self._functions]),
)
@property
def functions(self) -> List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]:
"""(Experimental) The functions that are available to the code executor."""
return self._functions
@property
def timeout(self) -> int:
"""(Experimental) The timeout for code execution."""
@ -99,6 +146,39 @@ class LocalCommandLineCodeExecutor(CodeExecutor):
if re.search(pattern, code):
raise ValueError(f"Potentially dangerous command detected: {message}")
def _setup_functions(self) -> None:
func_file_content = _build_python_functions_file(self._functions)
func_file = self._work_dir / self.FUNCTIONS_FILENAME
func_file.write_text(func_file_content)
# Collect requirements
lists_of_packages = [x.python_packages for x in self._functions if isinstance(x, FunctionWithRequirements)]
flattened_packages = [item for sublist in lists_of_packages for item in sublist]
required_packages = list(set(flattened_packages))
if len(required_packages) > 0:
logging.info("Ensuring packages are installed in executor.")
cmd = [sys.executable, "-m", "pip", "install"]
cmd.extend(required_packages)
try:
result = subprocess.run(
cmd, cwd=self._work_dir, capture_output=True, text=True, timeout=float(self._timeout)
)
except subprocess.TimeoutExpired as e:
raise ValueError("Pip install timed out") from e
if result.returncode != 0:
raise ValueError(f"Pip install failed. {result.stdout}, {result.stderr}")
# Attempt to load the function file to check for syntax errors, imports etc.
exec_result = self._execute_code_dont_check_setup([CodeBlock(code=func_file_content, language="python")])
if exec_result.exit_code != 0:
raise ValueError(f"Functions failed to load: {exec_result.output}")
self._setup_functions_complete = True
def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CommandLineCodeResult:
"""(Experimental) Execute the code blocks and return the result.
@ -107,6 +187,13 @@ class LocalCommandLineCodeExecutor(CodeExecutor):
Returns:
CommandLineCodeResult: The result of the code execution."""
if not self._setup_functions_complete:
self._setup_functions()
return self._execute_code_dont_check_setup(code_blocks)
def _execute_code_dont_check_setup(self, code_blocks: List[CodeBlock]) -> CommandLineCodeResult:
logs_all = ""
file_names = []
for code_block in code_blocks:

View File

@ -55,6 +55,7 @@ setuptools.setup(
"pre-commit",
"pytest-asyncio",
"pytest>=6.1.1,<8",
"pandas",
],
"blendsearch": ["flaml[blendsearch]"],
"mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"],

View File

@ -0,0 +1,151 @@
import tempfile
import pytest
from autogen.coding.base import CodeBlock
from autogen.coding.local_commandline_code_executor import LocalCommandLineCodeExecutor
try:
import pandas
except ImportError:
skip = True
else:
skip = False
from autogen.coding.func_with_reqs import with_requirements
classes_to_test = [LocalCommandLineCodeExecutor]
def add_two_numbers(a: int, b: int) -> int:
"""Add two numbers together."""
return a + b
@with_requirements(python_packages=["pandas"], global_imports=["pandas"])
def load_data() -> "pandas.DataFrame":
"""Load some sample data.
Returns:
pandas.DataFrame: A DataFrame with the following columns: name(str), location(str), age(int)
"""
data = {
"name": ["John", "Anna", "Peter", "Linda"],
"location": ["New York", "Paris", "Berlin", "London"],
"age": [24, 13, 53, 33],
}
return pandas.DataFrame(data)
@with_requirements(global_imports=["NOT_A_REAL_PACKAGE"])
def function_incorrect_import() -> "pandas.DataFrame":
return pandas.DataFrame()
@with_requirements(python_packages=["NOT_A_REAL_PACKAGE"])
def function_incorrect_dep() -> "pandas.DataFrame":
return pandas.DataFrame()
def function_missing_reqs() -> "pandas.DataFrame":
return pandas.DataFrame()
@pytest.mark.parametrize("cls", classes_to_test)
@pytest.mark.skipif(skip, reason="pandas not installed")
def test_can_load_function_with_reqs(cls) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = cls(work_dir=temp_dir, functions=[load_data])
code = f"""from {cls.FUNCTIONS_MODULE} import load_data
import pandas
# Get first row's name
print(load_data().iloc[0]['name'])"""
result = executor.execute_code_blocks(
code_blocks=[
CodeBlock(language="python", code=code),
]
)
assert result.output == "John\n"
assert result.exit_code == 0
@pytest.mark.parametrize("cls", classes_to_test)
@pytest.mark.skipif(skip, reason="pandas not installed")
def test_can_load_function(cls) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = cls(work_dir=temp_dir, functions=[add_two_numbers])
code = f"""from {cls.FUNCTIONS_MODULE} import add_two_numbers
print(add_two_numbers(1, 2))"""
result = executor.execute_code_blocks(
code_blocks=[
CodeBlock(language="python", code=code),
]
)
assert result.output == "3\n"
assert result.exit_code == 0
# TODO - only run this test for containerized executors, as the environment is not guaranteed to have pandas installed
# It is common for the local environment to have pandas installed, so this test will not work as expected
# @pytest.mark.parametrize("cls", classes_to_test)
# @pytest.mark.skipif(skip, reason="pandas not installed")
# def test_fails_for_missing_reqs(cls) -> None:
# with tempfile.TemporaryDirectory() as temp_dir:
# executor = cls(work_dir=temp_dir, functions=[function_missing_reqs])
# code = f"""from {cls.FUNCTIONS_MODULE} import function_missing_reqs
# function_missing_reqs()"""
# with pytest.raises(ValueError):
# executor.execute_code_blocks(
# code_blocks=[
# CodeBlock(language="python", code=code),
# ]
# )
@pytest.mark.parametrize("cls", classes_to_test)
@pytest.mark.skipif(skip, reason="pandas not installed")
def test_fails_for_function_incorrect_import(cls) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = cls(work_dir=temp_dir, functions=[function_incorrect_import])
code = f"""from {cls.FUNCTIONS_MODULE} import function_incorrect_import
function_incorrect_import()"""
with pytest.raises(ValueError):
executor.execute_code_blocks(
code_blocks=[
CodeBlock(language="python", code=code),
]
)
@pytest.mark.parametrize("cls", classes_to_test)
@pytest.mark.skipif(skip, reason="pandas not installed")
def test_fails_for_function_incorrect_dep(cls) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = cls(work_dir=temp_dir, functions=[function_incorrect_dep])
code = f"""from {cls.FUNCTIONS_MODULE} import function_incorrect_dep
function_incorrect_dep()"""
with pytest.raises(ValueError):
executor.execute_code_blocks(
code_blocks=[
CodeBlock(language="python", code=code),
]
)
@pytest.mark.parametrize("cls", classes_to_test)
@pytest.mark.skipif(skip, reason="pandas not installed")
def test_formatted_prompt(cls) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = cls(work_dir=temp_dir, functions=[add_two_numbers])
result = executor.format_functions_for_prompt()
assert (
'''def add_two_numbers(a: int, b: int) -> int:
"""Add two numbers together."""
'''
in result
)

View File

@ -0,0 +1,463 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# User Defined Functions\n",
"\n",
"```mdx\n",
":::note\n",
"This is experimental and not *yet* supported by all executors. At this stage only [`LocalCommandLineCodeExecutor`](/docs/reference/coding/local_commandline_code_executor#localcommandlinecodeexecutor) is supported.\n",
"\n",
"\n",
"Currently, the method of registering tools and using this feature are different. We would like to unify them. See Github issue [here](https://github.com/microsoft/autogen/issues/2101)\n",
":::\n",
"```\n",
"\n",
"User defined functions allow you to define Python functions in your AutoGen program and then provide these to be used by your executor. This allows you to provide your agents with tools without using traditional tool calling APIs. Currently only Python is supported for this feature.\n",
"\n",
"There are several steps involved:\n",
"\n",
"1. Define the function\n",
"2. Provide the function to the executor\n",
"3. Explain to the code writing agent how to use the function\n",
"\n",
"\n",
"## Define the function\n",
"\n",
"```mdx\n",
":::warning\n",
"Keep in mind that the entire source code of these functions will be available to the executor. This means that you should not include any sensitive information in the function as an LLM agent may be able to access it.\n",
":::\n",
"```\n",
"\n",
"If the function does not require any external imports or dependencies then you can simply use the function. For example:\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"def add_two_numbers(a: int, b: int) -> int:\n",
" \"\"\"Add two numbers together.\"\"\"\n",
" return a + b"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This would be a valid standalone function.\n",
"\n",
"```mdx\n",
":::tip\n",
"Using type hints and docstrings are not required but are highly recommended. They will help the code writing agent understand the function and how to use it.\n",
":::\n",
"```\n",
"\n",
"If the function requires external imports or dependencies then you can use the `@with_requirements` decorator to specify the requirements. For example:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pandas\n",
"from autogen.coding.func_with_reqs import with_requirements\n",
"\n",
"\n",
"@with_requirements(python_packages=[\"pandas\"], global_imports=[\"pandas\"])\n",
"def load_data() -> pandas.DataFrame:\n",
" \"\"\"Load some sample data.\n",
"\n",
" Returns:\n",
" pandas.DataFrame: A DataFrame with the following columns: name(str), location(str), age(int)\n",
" \"\"\"\n",
" data = {\n",
" \"name\": [\"John\", \"Anna\", \"Peter\", \"Linda\"],\n",
" \"location\": [\"New York\", \"Paris\", \"Berlin\", \"London\"],\n",
" \"age\": [24, 13, 53, 33],\n",
" }\n",
" return pandas.DataFrame(data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you wanted to rename `pandas` to `pd` or import `DataFrame` directly you could do the following:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from pandas import DataFrame as df\n",
"\n",
"from autogen.coding.func_with_reqs import Alias, ImportFromModule, with_requirements\n",
"\n",
"\n",
"@with_requirements(python_packages=[\"pandas\"], global_imports=[Alias(\"pandas\", \"pd\")])\n",
"def some_func1() -> pd.DataFrame: ...\n",
"\n",
"\n",
"@with_requirements(python_packages=[\"pandas\"], global_imports=[ImportFromModule(\"pandas\", \"DataFrame\")])\n",
"def some_func2() -> DataFrame: ...\n",
"\n",
"\n",
"@with_requirements(python_packages=[\"pandas\"], global_imports=[ImportFromModule(\"pandas\", Alias(\"DataFrame\", \"df\"))])\n",
"def some_func3() -> df: ..."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Provide the function to the executor\n",
"\n",
"Functions can be loaded into the executor in its constructor. For example:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from autogen.coding import CodeBlock\n",
"from autogen.coding import LocalCommandLineCodeExecutor\n",
"\n",
"work_dir = Path(\"coding\")\n",
"work_dir.mkdir(exist_ok=True)\n",
"\n",
"executor = LocalCommandLineCodeExecutor(work_dir=work_dir, functions=[add_two_numbers, load_data])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before we get an agent involved, we can sanity check that when the agent writes code that looks like this the executor will be able to handle it."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"exit_code=0 output='3\\n' code_file='/Users/jackgerrits/w/autogen/website/docs/topics/code-execution/coding/tmp_code_1958fe3aea3e8e3c6e907fe951b5f6ab.py'\n"
]
}
],
"source": [
"code = f\"\"\"\n",
"from {LocalCommandLineCodeExecutor.FUNCTIONS_MODULE} import add_two_numbers\n",
"\n",
"print(add_two_numbers(1, 2))\n",
"\"\"\"\n",
"\n",
"print(\n",
" executor.execute_code_blocks(\n",
" code_blocks=[\n",
" CodeBlock(language=\"python\", code=code),\n",
" ]\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And we can try the function that required a dependency and import too."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" name location age\n",
"0 John New York 24\n",
"1 Anna Paris 13\n",
"2 Peter Berlin 53\n",
"3 Linda London 33\n",
"\n"
]
}
],
"source": [
"code = f\"\"\"\n",
"from {LocalCommandLineCodeExecutor.FUNCTIONS_MODULE} import load_data\n",
"\n",
"print(load_data())\n",
"\"\"\"\n",
"\n",
"result = executor.execute_code_blocks(\n",
" code_blocks=[\n",
" CodeBlock(language=\"python\", code=code),\n",
" ]\n",
")\n",
"\n",
"print(result.output)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Limitations\n",
"\n",
"- Only Python is supported currently\n",
"- The function must not depend on any globals or external state as it is loaded as source code \n",
"\n",
"## Explain to the code writing agent how to use the function\n",
"\n",
"Now that the function is available to be called by the executor, you can explain to the code writing agent how to use the function. This step is very important as by default it will not know about it.\n",
"\n",
"There is a utility function that you can use to generate a default prompt that describes the available functions and how to use them. This function can have its template overridden to provide a custom message, or you can use a different prompt all together.\n",
"\n",
"For example, you could extend the system message from the page about local execution with a new section that describes the functions available."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"You have been given coding capability to solve tasks using Python code.\n",
"In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.\n",
" 1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.\n",
" 2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly.\n",
"Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.\n",
"When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.\n",
"If you want the user to save the code in a file before executing it, put # filename: <filename> inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.\n",
"You have access to the following user defined functions. They can be accessed from the module called `functions` by their function names.\n",
"\n",
"For example, if there was a function called `foo` you could import it by writing `from functions import foo`\n",
"\n",
"def add_two_numbers(a: int, b: int) -> int:\n",
" \"\"\"Add two numbers together.\"\"\"\n",
" ...\n",
"\n",
"def load_data() -> pandas.core.frame.DataFrame:\n",
" \"\"\"Load some sample data.\n",
"\n",
" Returns:\n",
" pandas.DataFrame: A DataFrame with the following columns: name(str), location(str), age(int)\n",
" \"\"\"\n",
" ...\n",
"\n"
]
}
],
"source": [
"nlnl = \"\\n\\n\"\n",
"code_writer_system_message = \"\"\"\n",
"You have been given coding capability to solve tasks using Python code.\n",
"In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.\n",
" 1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.\n",
" 2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly.\n",
"Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.\n",
"When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.\n",
"If you want the user to save the code in a file before executing it, put # filename: <filename> inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.\n",
"\"\"\"\n",
"\n",
"# Add on the new functions\n",
"code_writer_system_message += executor.format_functions_for_prompt()\n",
"\n",
"print(code_writer_system_message)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then you can use this system message for your code writing agent."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from autogen import ConversableAgent\n",
"\n",
"import os\n",
"\n",
"code_writer_agent = ConversableAgent(\n",
" \"code_writer\",\n",
" system_message=code_writer_system_message,\n",
" llm_config={\"config_list\": [{\"model\": \"gpt-4\", \"api_key\": os.environ[\"OPENAI_API_KEY\"]}]},\n",
" code_execution_config=False, # Turn off code execution for this agent.\n",
" max_consecutive_auto_reply=2,\n",
" human_input_mode=\"NEVER\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can setup the code execution agent using the local command line executor we defined earlier."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"code_executor_agent = ConversableAgent(\n",
" name=\"code_executor_agent\",\n",
" llm_config=False,\n",
" code_execution_config={\n",
" \"executor\": executor,\n",
" },\n",
" human_input_mode=\"NEVER\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, we can start the conversation and get the agent to process the dataframe we provided."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mcode_executor_agent\u001b[0m (to code_writer):\n",
"\n",
"Please use the load_data function to load the data and please calculate the average age of all people.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mcode_writer\u001b[0m (to code_executor_agent):\n",
"\n",
"Below is the python code to load the data using the `load_data()` function and calculate the average age of all people. \n",
"\n",
"```python\n",
"# python code\n",
"from functions import load_data\n",
"import numpy as np\n",
"\n",
"# Load the data\n",
"df = load_data()\n",
"\n",
"# Calculate the average age\n",
"avg_age = np.mean(df['age'])\n",
"\n",
"print(\"The average age is\", avg_age)\n",
"```\n",
"\n",
"This code starts by importing the `load_data()` function. It then uses this function to load the data into a variable `df`. Afterwards, it calculates the average (mean) of the 'age' column in the DataFrame, before printing the result.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> EXECUTING CODE BLOCK (inferred language is python)...\u001b[0m\n",
"\u001b[33mcode_executor_agent\u001b[0m (to code_writer):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: The average age is 30.75\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mcode_writer\u001b[0m (to code_executor_agent):\n",
"\n",
"Great! The code worked fine. So, the average age of all people in the dataset is 30.75 years.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mcode_executor_agent\u001b[0m (to code_writer):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"chat_result = code_executor_agent.initiate_chat(\n",
" code_writer_agent,\n",
" message=\"Please use the load_data function to load the data and please calculate the average age of all people.\",\n",
" summary_method=\"reflection_with_llm\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see the summary of the calculation:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The average age of all people in the dataset is 30.75 years.\n"
]
}
],
"source": [
"print(chat_result.summary)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "autogen",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}