Code executors (#1405)

* code executor

* test

* revert to main conversable agent

* prepare for pr

* kernel

* run open ai tests only when it's out of draft status

* update workflow file

* revert workflow changes

* ipython executor

* check kernel installed; fix tests

* fix tests

* fix tests

* update system prompt

* Update notebook, more tests

* notebook

* raise instead of return None

* allow user provided code executor.

* fixing types

* wip

* refactoring

* polishing

* fixed failing tests

* resolved merge conflict

* fixing failing test

* wip

* local command line executor and embedded ipython executor

* revert notebook

* fix format

* fix merged error

* fix lmm test

* fix lmm test

* move warning

* name and description should be part of the agent protocol, reset is not as it is only used for ConversableAgent; removing accidentally commited file

* version for dependency

* Update autogen/agentchat/conversable_agent.py

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>

* ordering of protocol

* description

* fix tests

* make ipython executor dependency optional

* update document optional dependencies

* Remove exclude from Agent protocol

* Make ConversableAgent consistent with Agent

* fix tests

* add doc string

* add doc string

* fix notebook

* fix interface

* merge and update agents

* disable config usage in reply function

* description field setter

* customize system message update

* update doc

---------

Co-authored-by: Davor Runje <davor@airt.ai>
Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
Co-authored-by: Aaron <aaronlaptop12@hotmail.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
Eric Zhu 2024-02-09 20:52:16 -08:00 committed by GitHub
parent 5d81ed43f3
commit 609ba7c649
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 1421 additions and 131 deletions

View File

@ -42,6 +42,14 @@ jobs:
pip install -e .
python -c "import autogen"
pip install pytest mock
pip install jupyter-client ipykernel
python -m ipykernel install --user --name python3
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |
if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
fi
- name: Test with pytest skipping openai tests
if: matrix.python-version != '3.10' && matrix.os == 'ubuntu-latest'
run: |

View File

@ -1,70 +1,136 @@
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Protocol, Union, runtime_checkable
class Agent:
"""(In preview) An abstract class for AI agent.
@runtime_checkable
class Agent(Protocol):
"""(In preview) A protocol for Agent.
An agent can communicate with other agents and perform actions.
Different agents can differ in what actions they perform in the `receive` method.
"""
def __init__(
self,
name: str,
):
"""
Args:
name (str): name of the agent.
"""
# a dictionary of conversations, default value is list
self._name = name
@property
def name(self) -> str:
"""The name of the agent."""
...
@property
def name(self):
"""Get the name of the agent."""
return self._name
def description(self) -> str:
"""The description of the agent. Used for the agent's introduction in
a group chat setting."""
...
def send(self, message: Union[Dict, str], recipient: "Agent", request_reply: Optional[bool] = None):
"""(Abstract method) Send a message to another agent."""
def send(
self,
message: Union[Dict[str, Any], str],
recipient: "Agent",
request_reply: Optional[bool] = None,
) -> None:
"""Send a message to another agent.
async def a_send(self, message: Union[Dict, str], recipient: "Agent", request_reply: Optional[bool] = None):
"""(Abstract async method) Send a message to another agent."""
Args:
message (dict or str): the message to send. If a dict, it should be
a JSON-serializable and follows the OpenAI's ChatCompletion schema.
recipient (Agent): the recipient of the message.
request_reply (bool): whether to request a reply from the recipient.
"""
...
def receive(self, message: Union[Dict, str], sender: "Agent", request_reply: Optional[bool] = None):
"""(Abstract method) Receive a message from another agent."""
async def a_send(
self,
message: Union[Dict[str, Any], str],
recipient: "Agent",
request_reply: Optional[bool] = None,
) -> None:
"""(Async) Send a message to another agent.
async def a_receive(self, message: Union[Dict, str], sender: "Agent", request_reply: Optional[bool] = None):
"""(Abstract async method) Receive a message from another agent."""
Args:
message (dict or str): the message to send. If a dict, it should be
a JSON-serializable and follows the OpenAI's ChatCompletion schema.
recipient (Agent): the recipient of the message.
request_reply (bool): whether to request a reply from the recipient.
"""
...
def reset(self):
"""(Abstract method) Reset the agent."""
def receive(
self,
message: Union[Dict[str, Any], str],
sender: "Agent",
request_reply: Optional[bool] = None,
) -> None:
"""Receive a message from another agent.
Args:
message (dict or str): the message received. If a dict, it should be
a JSON-serializable and follows the OpenAI's ChatCompletion schema.
sender (Agent): the sender of the message.
request_reply (bool): whether the sender requests a reply.
"""
async def a_receive(
self,
message: Union[Dict[str, Any], str],
sender: "Agent",
request_reply: Optional[bool] = None,
) -> None:
"""(Async) Receive a message from another agent.
Args:
message (dict or str): the message received. If a dict, it should be
a JSON-serializable and follows the OpenAI's ChatCompletion schema.
sender (Agent): the sender of the message.
request_reply (bool): whether the sender requests a reply.
"""
...
def generate_reply(
self,
messages: Optional[List[Dict]] = None,
messages: Optional[List[Dict[str, Any]]] = None,
sender: Optional["Agent"] = None,
**kwargs,
) -> Union[str, Dict, None]:
"""(Abstract method) Generate a reply based on the received messages.
**kwargs: Any,
) -> Union[str, Dict[str, Any], None]:
"""Generate a reply based on the received messages.
Args:
messages (list[dict]): a list of messages received.
messages (list[dict]): a list of messages received from other agents.
The messages are dictionaries that are JSON-serializable and
follows the OpenAI's ChatCompletion schema.
sender: sender of an Agent instance.
Returns:
str or dict or None: the generated reply. If None, no reply is generated.
"""
async def a_generate_reply(
self,
messages: Optional[List[Dict]] = None,
messages: Optional[List[Dict[str, Any]]] = None,
sender: Optional["Agent"] = None,
**kwargs,
) -> Union[str, Dict, None]:
"""(Abstract async method) Generate a reply based on the received messages.
**kwargs: Any,
) -> Union[str, Dict[str, Any], None]:
"""(Async) Generate a reply based on the received messages.
Args:
messages (list[dict]): a list of messages received.
messages (list[dict]): a list of messages received from other agents.
The messages are dictionaries that are JSON-serializable and
follows the OpenAI's ChatCompletion schema.
sender: sender of an Agent instance.
Returns:
str or dict or None: the generated reply. If None, no reply is generated.
"""
@runtime_checkable
class LLMAgent(Agent, Protocol):
"""(In preview) A protocol for an LLM agent."""
@property
def system_message(self) -> str:
"""The system message of this agent."""
def update_system_message(self, system_message: str) -> None:
"""Update this agent's system message.
Args:
system_message (str): system message for inference.
"""

View File

@ -6,14 +6,16 @@ import json
import logging
import re
from collections import defaultdict
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
import warnings
from openai import BadRequestError
from ..coding.base import CodeExecutor
from ..coding.factory import CodeExecutorFactory
from ..oai.client import OpenAIWrapper, ModelClient
from ..cache.cache import Cache
from ..code_utils import (
DEFAULT_MODEL,
UNKNOWN,
content_str,
check_can_use_docker_or_throw,
@ -27,7 +29,7 @@ from .chat import ChatResult
from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
from .agent import Agent
from .agent import Agent, LLMAgent
from .._pydantic import model_dump
try:
@ -45,7 +47,7 @@ logger = logging.getLogger(__name__)
F = TypeVar("F", bound=Callable[..., Any])
class ConversableAgent(Agent):
class ConversableAgent(LLMAgent):
"""(In preview) A class for generic conversable agents which can be configured as assistant or user proxy.
After receiving each message, the agent will send a reply to the sender unless the msg is a termination msg.
@ -122,11 +124,11 @@ class ConversableAgent(Agent):
description (str): a short description of the agent. This description is used by other agents
(e.g. the GroupChatManager) to decide when to call upon this agent. (Default: system_message)
"""
super().__init__(name)
self._name = name
# a dictionary of conversations, default value is list
self._oai_messages = defaultdict(list)
self._oai_system_message = [{"content": system_message, "role": "system"}]
self.description = description if description is not None else system_message
self._description = description if description is not None else system_message
self._is_termination_msg = (
is_termination_msg
if is_termination_msg is not None
@ -145,23 +147,6 @@ class ConversableAgent(Agent):
# Initialize standalone client cache object.
self.client_cache = None
if code_execution_config is None:
warnings.warn(
"Using None to signal a default code_execution_config is deprecated. "
"Use {} to use default or False to disable code execution.",
stacklevel=2,
)
self._code_execution_config: Union[Dict, Literal[False]] = (
{} if code_execution_config is None else code_execution_config
)
if isinstance(self._code_execution_config, dict):
use_docker = self._code_execution_config.get("use_docker", None)
use_docker = decide_use_docker(use_docker)
check_can_use_docker_or_throw(use_docker)
self._code_execution_config["use_docker"] = use_docker
self.human_input_mode = human_input_mode
self._max_consecutive_auto_reply = (
max_consecutive_auto_reply if max_consecutive_auto_reply is not None else self.MAX_CONSECUTIVE_AUTO_REPLY
@ -180,7 +165,39 @@ class ConversableAgent(Agent):
self.reply_at_receive = defaultdict(bool)
self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply, ignore_async_in_sync_chat=True)
self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
# Setting up code execution.
# Do not register code execution reply if code execution is disabled.
if code_execution_config is not False:
# If code_execution_config is None, set it to an empty dict.
if code_execution_config is None:
warnings.warn(
"Using None to signal a default code_execution_config is deprecated. "
"Use {} to use default or False to disable code execution.",
stacklevel=2,
)
code_execution_config = {}
if not isinstance(code_execution_config, dict):
raise ValueError("code_execution_config must be a dict or False.")
# We have got a valid code_execution_config.
self._code_execution_config = code_execution_config
if self._code_execution_config.get("executor") is not None:
# Use the new code executor.
self._code_executor = CodeExecutorFactory.create(self._code_execution_config)
self.register_reply([Agent, None], ConversableAgent._generate_code_execution_reply_using_executor)
else:
# Legacy code execution using code_utils.
use_docker = self._code_execution_config.get("use_docker", None)
use_docker = decide_use_docker(use_docker)
check_can_use_docker_or_throw(use_docker)
self._code_execution_config["use_docker"] = use_docker
self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
else:
# Code execution is disabled.
self._code_execution_config = False
self.register_reply([Agent, None], ConversableAgent.generate_tool_calls_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_tool_calls_reply, ignore_async_in_sync_chat=True)
self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
@ -196,6 +213,31 @@ class ConversableAgent(Agent):
# New hookable methods should be added to this list as required to support new agent capabilities.
self.hook_lists = {self.process_last_message: [], self.process_all_messages: []}
@property
def name(self) -> str:
"""Get the name of the agent."""
return self._name
@property
def description(self) -> str:
"""Get the description of the agent."""
return self._description
@description.setter
def description(self, description: str):
"""Set the description of the agent."""
self._description = description
@property
def code_executor(self) -> CodeExecutor:
"""The code executor used by this agent. Raise if code execution is disabled."""
if not hasattr(self, "_code_executor"):
raise ValueError(
"No code executor as code execution is disabled. "
"To enable code execution, set code_execution_config."
)
return self._code_executor
def register_reply(
self,
trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
@ -268,15 +310,15 @@ class ConversableAgent(Agent):
self._ignore_async_func_in_sync_chat_list.append(reply_func)
@property
def system_message(self) -> Union[str, List]:
def system_message(self) -> str:
"""Return the system message."""
return self._oai_system_message[0]["content"]
def update_system_message(self, system_message: Union[str, List]):
def update_system_message(self, system_message: str) -> None:
"""Update the system message.
Args:
system_message (str or List): system message for the ChatCompletion inference.
system_message (str): system message for the ChatCompletion inference.
"""
self._oai_system_message[0]["content"] = system_message
@ -1062,6 +1104,54 @@ class ConversableAgent(Agent):
None, functools.partial(self.generate_oai_reply, messages=messages, sender=sender, config=config)
)
def _generate_code_execution_reply_using_executor(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
config: Optional[Union[Dict, Literal[False]]] = None,
):
"""Generate a reply using code executor."""
if config is not None:
raise ValueError("config is not supported for _generate_code_execution_reply_using_executor.")
if self._code_execution_config is False:
return False, None
if messages is None:
messages = self._oai_messages[sender]
last_n_messages = self._code_execution_config.get("last_n_messages", "auto")
if not (isinstance(last_n_messages, (int, float)) and last_n_messages >= 0) and last_n_messages != "auto":
raise ValueError("last_n_messages must be either a non-negative integer, or the string 'auto'.")
num_messages_to_scan = last_n_messages
if last_n_messages == "auto":
# Find when the agent last spoke
num_messages_to_scan = 0
for message in reversed(messages):
if "role" not in message:
break
elif message["role"] != "user":
break
else:
num_messages_to_scan += 1
num_messages_to_scan = min(len(messages), num_messages_to_scan)
messages_to_scan = messages[-num_messages_to_scan:]
# iterate through the last n messages in reverse
# if code blocks are found, execute the code blocks and return the output
# if no code blocks are found, continue
for message in reversed(messages_to_scan):
if not message["content"]:
continue
code_blocks = self._code_executor.code_extractor.extract_code_blocks(message["content"])
if len(code_blocks) == 0:
continue
# found code blocks, execute code.
code_result = self._code_executor.execute_code_blocks(code_blocks)
exitcode2str = "execution succeeded" if code_result.exit_code == 0 else "execution failed"
return True, f"exitcode: {code_result.exit_code} ({exitcode2str})\nCode output: {code_result.output}"
return False, None
def generate_code_execution_reply(
self,
messages: Optional[List[Dict]] = None,
@ -1490,9 +1580,9 @@ class ConversableAgent(Agent):
def generate_reply(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
exclude: Optional[List[Callable]] = None,
messages: Optional[List[Dict[str, Any]]] = None,
sender: Optional["Agent"] = None,
**kwargs: Any,
) -> Union[str, Dict, None]:
"""Reply based on the conversation history and the sender.
@ -1513,9 +1603,10 @@ class ConversableAgent(Agent):
Args:
messages: a list of messages in the conversation history.
default_reply (str or dict): default reply.
sender: sender of an Agent instance.
exclude: a list of functions to exclude.
Additional keyword arguments:
exclude (List[Callable]): a list of reply functions to be excluded.
Returns:
str or dict or None: reply. None if no reply is generated.
@ -1538,7 +1629,7 @@ class ConversableAgent(Agent):
for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
if exclude and reply_func in exclude:
if "exclude" in kwargs and reply_func in kwargs["exclude"]:
continue
if inspect.iscoroutinefunction(reply_func):
continue
@ -1550,10 +1641,10 @@ class ConversableAgent(Agent):
async def a_generate_reply(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
exclude: Optional[List[Callable]] = None,
) -> Union[str, Dict, None]:
messages: Optional[List[Dict[str, Any]]] = None,
sender: Optional["Agent"] = None,
**kwargs: Any,
) -> Union[str, Dict[str, Any], None]:
"""(async) Reply based on the conversation history and the sender.
Either messages or sender must be provided.
@ -1573,9 +1664,10 @@ class ConversableAgent(Agent):
Args:
messages: a list of messages in the conversation history.
default_reply (str or dict): default reply.
sender: sender of an Agent instance.
exclude: a list of functions to exclude.
Additional keyword arguments:
exclude (List[Callable]): a list of reply functions to be excluded.
Returns:
str or dict or None: reply. None if no reply is generated.
@ -1598,7 +1690,7 @@ class ConversableAgent(Agent):
for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
if exclude and reply_func in exclude:
if "exclude" in kwargs and reply_func in kwargs["exclude"]:
continue
if self._match_trigger(reply_func_tuple["trigger"], sender):
if inspect.iscoroutinefunction(reply_func):

View File

@ -8,7 +8,7 @@ import sys
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from hashlib import md5
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from autogen import oai
@ -37,7 +37,7 @@ PATH_SEPARATOR = WIN32 and "\\" or "/"
logger = logging.getLogger(__name__)
def content_str(content: Union[str, List, None]) -> str:
def content_str(content: Union[str, List[Dict[str, Any]], None]) -> str:
"""Converts `content` into a string format.
This function processes content that may be a string, a list of mixed text and image URLs, or None,
@ -78,7 +78,7 @@ def content_str(content: Union[str, List, None]) -> str:
return rst
def infer_lang(code):
def infer_lang(code: str) -> str:
"""infer the language for the code.
TODO: make it robust.
"""
@ -223,7 +223,7 @@ def _cmd(lang):
raise NotImplementedError(f"{lang} not recognized in code execution")
def is_docker_running():
def is_docker_running() -> bool:
"""Check if docker is running.
Returns:
@ -237,7 +237,7 @@ def is_docker_running():
return False
def in_docker_container():
def in_docker_container() -> bool:
"""Check if the code is running in a docker container.
Returns:
@ -315,7 +315,7 @@ def execute_code(
work_dir: Optional[str] = None,
use_docker: Union[List[str], str, bool] = SENTINEL,
lang: Optional[str] = "python",
) -> Tuple[int, str, str]:
) -> Tuple[int, str, Optional[str]]:
"""Execute code in a docker container.
This function is not tested on MacOS.

View File

@ -0,0 +1,5 @@
from .base import CodeBlock, CodeExecutor, CodeExtractor, CodeResult
from .factory import CodeExecutorFactory
from .markdown_code_extractor import MarkdownCodeExtractor
__all__ = ("CodeBlock", "CodeResult", "CodeExtractor", "CodeExecutor", "CodeExecutorFactory", "MarkdownCodeExtractor")

94
autogen/coding/base.py Normal file
View File

@ -0,0 +1,94 @@
from typing import Any, Dict, List, Protocol, Union, runtime_checkable
from pydantic import BaseModel, Field
from ..agentchat.agent import LLMAgent
__all__ = ("CodeBlock", "CodeResult", "CodeExtractor", "CodeExecutor")
class CodeBlock(BaseModel):
"""A class that represents a code block."""
code: str = Field(description="The code to execute.")
language: str = Field(description="The language of the code.")
class CodeResult(BaseModel):
"""A class that represents the result of a code execution."""
exit_code: int = Field(description="The exit code of the code execution.")
output: str = Field(description="The output of the code execution.")
class CodeExtractor(Protocol):
"""A code extractor class that extracts code blocks from a message."""
def extract_code_blocks(self, message: Union[str, List[Dict[str, Any]], None]) -> List[CodeBlock]:
"""Extract code blocks from a message.
Args:
message (str): The message to extract code blocks from.
Returns:
List[CodeBlock]: The extracted code blocks.
"""
... # pragma: no cover
@runtime_checkable
class CodeExecutor(Protocol):
"""A code executor class that executes code blocks and returns the result."""
class UserCapability(Protocol):
"""An AgentCapability class that gives agent ability use this code executor."""
def add_to_agent(self, agent: LLMAgent) -> None:
... # pragma: no cover
@property
def user_capability(self) -> "CodeExecutor.UserCapability":
"""Capability to use this code executor.
The exported capability can be added to an agent to allow it to use this
code executor:
```python
code_executor = CodeExecutor()
agent = ConversableAgent("agent", ...)
code_executor.user_capability.add_to_agent(agent)
```
A typical implementation is to update the system message of the agent with
instructions for how to use this code executor.
"""
... # pragma: no cover
@property
def code_extractor(self) -> CodeExtractor:
"""The code extractor used by this code executor."""
... # pragma: no cover
def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CodeResult:
"""Execute code blocks and return the result.
This method should be implemented by the code executor.
Args:
code_blocks (List[CodeBlock]): The code blocks to execute.
Returns:
CodeResult: The result of the code execution.
"""
... # pragma: no cover
def restart(self) -> None:
"""Restart the code executor.
This method should be implemented by the code executor.
This method is called when the agent is reset.
"""
... # pragma: no cover

View File

@ -0,0 +1,244 @@
import base64
import json
import os
import re
import uuid
from queue import Empty
from typing import Any, ClassVar, List
from jupyter_client import KernelManager # type: ignore[attr-defined]
from jupyter_client.kernelspec import KernelSpecManager
from pydantic import BaseModel, Field, field_validator
from ..agentchat.agent import LLMAgent
from .base import CodeBlock, CodeExtractor, CodeResult
from .markdown_code_extractor import MarkdownCodeExtractor
__all__ = ("EmbeddedIPythonCodeExecutor",)
class IPythonCodeResult(CodeResult):
"""A code result class for IPython code executor."""
output_files: List[str] = Field(
default_factory=list,
description="The list of files that the executed code blocks generated.",
)
class EmbeddedIPythonCodeExecutor(BaseModel):
"""A code executor class that executes code statefully using an embedded
IPython kernel managed by this class.
**This will execute LLM generated code on the local machine.**
Each execution is stateful and can access variables created from previous
executions in the same session. The kernel must be installed before using
this class. The kernel can be installed using the following command:
`python -m ipykernel install --user --name {kernel_name}`
where `kernel_name` is the name of the kernel to install.
Args:
timeout (int): The timeout for code execution, by default 60.
kernel_name (str): The kernel name to use. Make sure it is installed.
By default, it is "python3".
output_dir (str): The directory to save output files, by default ".".
system_message_update (str): The system message update to add to the
agent that produces code. By default it is
`EmbeddedIPythonCodeExecutor.DEFAULT_SYSTEM_MESSAGE_UPDATE`.
"""
DEFAULT_SYSTEM_MESSAGE_UPDATE: ClassVar[
str
] = """
# IPython Coding Capability
You have been given coding capability to solve tasks using Python code in a stateful IPython kernel.
You are responsible for writing the code, and the user is responsible for executing the code.
When you write Python code, put the code in a markdown code block with the language set to Python.
For example:
```python
x = 3
```
You can use the variable `x` in subsequent code blocks.
```python
print(x)
```
Write code incrementally and leverage the statefulness of the kernel to avoid repeating code.
Import libraries in a separate code block.
Define a function or a class in a separate code block.
Run code that produces output in a separate code block.
Run code that involves expensive operations like download, upload, and call external APIs in a separate code block.
When your code produces an output, the output will be returned to you.
Because you have limited conversation memory, if your code creates an image,
the output will be a path to the image instead of the image itself.
"""
timeout: int = Field(default=60, ge=1, description="The timeout for code execution.")
kernel_name: str = Field(default="python3", description="The kernel name to use. Make sure it is installed.")
output_dir: str = Field(default=".", description="The directory to save output files.")
system_message_update: str = Field(
default=DEFAULT_SYSTEM_MESSAGE_UPDATE,
description="The system message update to the agent that produces code to be executed by this executor.",
)
class UserCapability:
"""An AgentCapability class that gives agent ability use a stateful
IPython code executor. This capability can be added to an agent using
the `add_to_agent` method which append a system message update to the
agent's system message."""
def __init__(self, system_message_update: str):
self.system_message_update = system_message_update
def add_to_agent(self, agent: LLMAgent) -> None:
"""Add this capability to an agent by appending a system message
update to the agent's system message.
**Currently we do not check for conflicts with existing content in
the agent's system message.**
Args:
agent (LLMAgent): The agent to add the capability to.
"""
agent.update_system_message(agent.system_message + self.system_message_update)
@field_validator("output_dir")
@classmethod
def _output_dir_must_exist(cls, value: str) -> str:
if not os.path.exists(value):
raise ValueError(f"Output directory {value} does not exist.")
return value
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
# Check if the kernel is installed.
if self.kernel_name not in KernelSpecManager().find_kernel_specs():
raise ValueError(
f"Kernel {self.kernel_name} is not installed. "
"Please first install it with "
f"`python -m ipykernel install --user --name {self.kernel_name}`."
)
self._kernel_manager = KernelManager(kernel_name=self.kernel_name)
self._kernel_manager.start_kernel()
self._kernel_client = self._kernel_manager.client()
self._kernel_client.start_channels()
self._timeout = self.timeout
@property
def user_capability(self) -> "EmbeddedIPythonCodeExecutor.UserCapability":
"""Export a user capability for this executor that can be added to
an agent using the `add_to_agent` method."""
return EmbeddedIPythonCodeExecutor.UserCapability(self.system_message_update)
@property
def code_extractor(self) -> CodeExtractor:
"""Export a code extractor that can be used by an agent."""
return MarkdownCodeExtractor()
def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> IPythonCodeResult:
"""Execute a list of code blocks and return the result.
This method executes a list of code blocks as cells in an IPython kernel
managed by this class.
See: https://jupyter-client.readthedocs.io/en/stable/messaging.html
for the message protocol.
Args:
code_blocks (List[CodeBlock]): A list of code blocks to execute.
Returns:
IPythonCodeResult: The result of the code execution.
"""
self._kernel_client.wait_for_ready()
outputs = []
output_files = []
for code_block in code_blocks:
code = self._process_code(code_block.code)
self._kernel_client.execute(code, store_history=True)
while True:
try:
msg = self._kernel_client.get_iopub_msg(timeout=self._timeout)
msg_type = msg["msg_type"]
content = msg["content"]
if msg_type in ["execute_result", "display_data"]:
for data_type, data in content["data"].items():
if data_type == "text/plain":
# Output is a text.
outputs.append(data)
elif data_type.startswith("image/"):
# Output is an image.
path = self._save_image(data)
outputs.append(f"Image data saved to {path}")
output_files.append(path)
elif data_type == "text/html":
# Output is an html.
path = self._save_html(data)
outputs.append(f"HTML data saved to {path}")
output_files.append(path)
else:
# Output raw data.
outputs.append(json.dumps(data))
elif msg_type == "stream":
# Output is a text.
outputs.append(content["text"])
elif msg_type == "error":
# Output is an error.
return IPythonCodeResult(
exit_code=1,
output=f"ERROR: {content['ename']}: {content['evalue']}\n{content['traceback']}",
)
if msg_type == "status" and content["execution_state"] == "idle":
break
# handle time outs.
except Empty:
return IPythonCodeResult(
exit_code=1,
output=f"ERROR: Timeout waiting for output from code block: {code_block.code}",
)
# We return the full output.
return IPythonCodeResult(
exit_code=0, output="\n".join([str(output) for output in outputs]), output_files=output_files
)
def restart(self) -> None:
"""Restart a new session."""
self._kernel_client.stop_channels()
self._kernel_manager.shutdown_kernel()
self._kernel_manager = KernelManager(kernel_name=self.kernel_name)
self._kernel_manager.start_kernel()
self._kernel_client = self._kernel_manager.client()
self._kernel_client.start_channels()
def _save_image(self, image_data_base64: str) -> str:
"""Save image data to a file."""
image_data = base64.b64decode(image_data_base64)
# Randomly generate a filename.
filename = f"{uuid.uuid4().hex}.png"
path = os.path.join(self.output_dir, filename)
with open(path, "wb") as f:
f.write(image_data)
return os.path.abspath(path)
def _save_html(self, html_data: str) -> str:
"""Save html data to a file."""
# Randomly generate a filename.
filename = f"{uuid.uuid4().hex}.html"
path = os.path.join(self.output_dir, filename)
with open(path, "w") as f:
f.write(html_data)
return os.path.abspath(path)
def _process_code(self, code: str) -> str:
"""Process code before execution."""
# Find lines that start with `! pip install` and make sure "-qqq" flag is added.
lines = code.split("\n")
for i, line in enumerate(lines):
# use regex to find lines that start with `! pip install` or `!pip install`.
match = re.search(r"^! ?pip install", line)
if match is not None:
if "-qqq" not in line:
lines[i] = line.replace(match.group(0), match.group(0) + " -qqq")
return "\n".join(lines)

41
autogen/coding/factory.py Normal file
View File

@ -0,0 +1,41 @@
from typing import Any, Dict
from .base import CodeExecutor
__all__ = ("CodeExecutorFactory",)
class CodeExecutorFactory:
"""A factory class for creating code executors."""
@staticmethod
def create(code_execution_config: Dict[str, Any]) -> CodeExecutor:
"""Get a code executor based on the code execution config.
Args:
code_execution_config (Dict): The code execution config,
which is a dictionary that must contain the key "executor".
The value of the key "executor" can be either a string
or an instance of CodeExecutor, in which case the code
executor is returned directly.
Returns:
CodeExecutor: The code executor.
Raises:
ValueError: If the code executor is unknown or not specified.
"""
executor = code_execution_config.get("executor")
if isinstance(executor, CodeExecutor):
# If the executor is already an instance of CodeExecutor, return it.
return executor
if executor == "ipython-embedded":
from .embedded_ipython_code_executor import EmbeddedIPythonCodeExecutor
return EmbeddedIPythonCodeExecutor(**code_execution_config.get("ipython-embedded", {}))
elif executor == "commandline-local":
from .local_commandline_code_executor import LocalCommandlineCodeExecutor
return LocalCommandlineCodeExecutor(**code_execution_config.get("commandline-local", {}))
else:
raise ValueError(f"Unknown code executor {executor}")

View File

@ -0,0 +1,162 @@
import os
import uuid
import warnings
from typing import Any, ClassVar, List, Optional
from pydantic import BaseModel, Field, field_validator
from ..agentchat.agent import LLMAgent
from ..code_utils import execute_code
from .base import CodeBlock, CodeExtractor, CodeResult
from .markdown_code_extractor import MarkdownCodeExtractor
try:
from termcolor import colored
except ImportError:
def colored(x: Any, *args: Any, **kwargs: Any) -> str: # type: ignore[misc]
return x # type: ignore[no-any-return]
__all__ = (
"LocalCommandlineCodeExecutor",
"CommandlineCodeResult",
)
class CommandlineCodeResult(CodeResult):
"""A code result class for command line code executor."""
code_file: Optional[str] = Field(
default=None,
description="The file that the executed code block was saved to.",
)
class LocalCommandlineCodeExecutor(BaseModel):
"""A code executor class that executes code through a local command line
environment.
**This will execute LLM generated code on the local machine.**
Each code block is saved as a file and executed in a separate process in
the working directory, and a unique file is generated and saved in the
working directory for each code block.
The code blocks are executed in the order they are received.
Currently the only supported languages is Python and shell scripts.
For Python code, use the language "python" for the code block.
For shell scripts, use the language "bash", "shell", or "sh" for the code
block.
Args:
timeout (int): The timeout for code execution. Default is 60.
work_dir (str): The working directory for the code execution. If None,
a default working directory will be used. The default working
directory is the current directory ".".
system_message_update (str): The system message update for agent that
produces code to run on this executor.
Default is `LocalCommandlineCodeExecutor.DEFAULT_SYSTEM_MESSAGE_UPDATE`.
"""
DEFAULT_SYSTEM_MESSAGE_UPDATE: ClassVar[
str
] = """
You have been given coding capability to solve tasks using Python code.
In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.
1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.
2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly.
Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.
When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.
If you want the user to save the code in a file before executing it, put # filename: <filename> inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.
"""
timeout: int = Field(default=60, ge=1, description="The timeout for code execution.")
work_dir: str = Field(default=".", description="The working directory for the code execution.")
system_message_update: str = Field(
default=DEFAULT_SYSTEM_MESSAGE_UPDATE,
description="The system message update for agent that produces code to run on this executor.",
)
class UserCapability:
"""An AgentCapability class that gives agent ability use a command line
code executor via a system message update. This capability can be added
to an agent using the `add_to_agent` method."""
def __init__(self, system_message_update: str) -> None:
self.system_message_update = system_message_update
def add_to_agent(self, agent: LLMAgent) -> None:
"""Add this capability to an agent by updating the agent's system
message."""
agent.update_system_message(agent.system_message + self.system_message_update)
@field_validator("work_dir")
@classmethod
def _check_work_dir(cls, v: str) -> str:
if os.path.exists(v):
return v
raise ValueError(f"Working directory {v} does not exist.")
@property
def user_capability(self) -> "LocalCommandlineCodeExecutor.UserCapability":
"""Export a user capability for this executor that can be added to
an agent that produces code to be executed by this executor."""
return LocalCommandlineCodeExecutor.UserCapability(self.system_message_update)
@property
def code_extractor(self) -> CodeExtractor:
"""Export a code extractor that can be used by an agent."""
return MarkdownCodeExtractor()
def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CommandlineCodeResult:
"""Execute the code blocks and return the result.
Args:
code_blocks (List[CodeBlock]): The code blocks to execute.
Returns:
CommandlineCodeResult: The result of the code execution."""
logs_all = ""
for i, code_block in enumerate(code_blocks):
lang, code = code_block.language, code_block.code
print(
colored(
f"\n>>>>>>>> EXECUTING CODE BLOCK {i} (inferred language is {lang})...",
"red",
),
flush=True,
)
filename_uuid = uuid.uuid4().hex
filename = None
if lang in ["bash", "shell", "sh"]:
filename = f"{filename_uuid}.{lang}"
exitcode, logs, _ = execute_code(
code=code,
lang=lang,
timeout=self.timeout,
work_dir=self.work_dir,
filename=filename,
use_docker=False,
)
elif lang in ["python", "Python"]:
filename = f"{filename_uuid}.py"
exitcode, logs, _ = execute_code(
code=code,
lang="python",
timeout=self.timeout,
work_dir=self.work_dir,
filename=filename,
use_docker=False,
)
else:
# In case the language is not supported, we return an error message.
exitcode, logs, _ = (1, f"unknown language {lang}", None)
logs_all += "\n" + logs
if exitcode != 0:
break
code_filename = os.path.join(self.work_dir, filename) if filename is not None else None
return CommandlineCodeResult(exit_code=exitcode, output=logs_all, code_file=code_filename)
def restart(self) -> None:
"""Restart the code executor."""
warnings.warn("Restarting local command line code executor is not supported. No action is taken.")

View File

@ -0,0 +1,35 @@
import re
from typing import Any, Dict, List, Optional, Union
from ..code_utils import CODE_BLOCK_PATTERN, UNKNOWN, content_str, infer_lang
from .base import CodeBlock
__all__ = ("MarkdownCodeExtractor",)
class MarkdownCodeExtractor:
"""A class that extracts code blocks from a message using Markdown syntax."""
def extract_code_blocks(self, message: Union[str, List[Dict[str, Any]], None]) -> List[CodeBlock]:
"""Extract code blocks from a message. If no code blocks are found,
return an empty list.
Args:
message (str): The message to extract code blocks from.
Returns:
List[CodeBlock]: The extracted code blocks or an empty list.
"""
text = content_str(message)
match = re.findall(CODE_BLOCK_PATTERN, text, flags=re.DOTALL)
if not match:
return []
code_blocks = []
for lang, code in match:
if lang == "":
lang = infer_lang(code)
if lang == UNKNOWN:
lang = ""
code_blocks.append(CodeBlock(code=code, language=lang))
return code_blocks

View File

@ -54,8 +54,9 @@
"import networkx as nx # noqa E402\n",
"\n",
"import autogen # noqa E402\n",
"from autogen.agentchat.conversable_agent import ConversableAgent # noqa E402\n",
"from autogen.agentchat.assistant_agent import AssistantAgent # noqa E402\n",
"from autogen.agentchat.groupchat import GroupChat, Agent # noqa E402\n",
"from autogen.agentchat.groupchat import GroupChat # noqa E402\n",
"from autogen.graph_utils import visualize_speaker_transitions_dict # noqa E402"
]
},
@ -119,7 +120,7 @@
}
],
"source": [
"agents = [Agent(name=f\"Agent{i}\") for i in range(5)]\n",
"agents = [ConversableAgent(name=f\"Agent{i}\", llm_config=False) for i in range(5)]\n",
"allowed_speaker_transitions_dict = {agent: [other_agent for other_agent in agents] for agent in agents}\n",
"\n",
"visualize_speaker_transitions_dict(allowed_speaker_transitions_dict, agents)"
@ -152,7 +153,7 @@
}
],
"source": [
"agents = [Agent(name=f\"Agent{i}\") for i in range(5)]\n",
"agents = [ConversableAgent(name=f\"Agent{i}\", llm_config=False) for i in range(5)]\n",
"allowed_speaker_transitions_dict = {\n",
" agents[0]: [agents[1], agents[2], agents[3], agents[4]],\n",
" agents[1]: [agents[0]],\n",
@ -196,14 +197,14 @@
"team_size = 5\n",
"\n",
"\n",
"def get_agent_of_name(agents, name) -> Agent:\n",
"def get_agent_of_name(agents, name) -> ConversableAgent:\n",
" for agent in agents:\n",
" if agent.name == name:\n",
" return agent\n",
"\n",
"\n",
"# Create a list of 15 agents 3 teams x 5 agents\n",
"agents = [Agent(name=f\"{team}{i}\") for team in teams for i in range(team_size)]\n",
"agents = [ConversableAgent(name=f\"{team}{i}\", llm_config=False) for team in teams for i in range(team_size)]\n",
"\n",
"# Loop through each team and add members and their connections\n",
"for team in teams:\n",
@ -239,7 +240,7 @@
}
],
"source": [
"agents = [Agent(name=f\"Agent{i}\") for i in range(2)]\n",
"agents = [ConversableAgent(name=f\"Agent{i}\", llm_config=False) for i in range(2)]\n",
"allowed_speaker_transitions_dict = {\n",
" agents[0]: [agents[0], agents[1]],\n",
" agents[1]: [agents[0], agents[1]],\n",

View File

@ -41,7 +41,10 @@ exclude = [
"_build",
"build",
"dist",
"docs"
"docs",
# This file needs to be either upgraded or removed and therefore should be
# ignore from type checking for now
"math_utils\\.py$",
]
ignore-init-module-imports = true
unfixable = ["F401"]
@ -75,9 +78,3 @@ warn_unused_ignores = true
disallow_incomplete_defs = true
disallow_untyped_decorators = true
disallow_any_unimported = true
exclude = [
# This file needs to be either upgraded or removed and therefore should be
# ignore from type checking for now
"math_utils\\.py$",
]

View File

@ -55,6 +55,7 @@ setuptools.setup(
"graph": ["networkx", "matplotlib"],
"websurfer": ["beautifulsoup4", "markdownify", "pdfminer.six", "pathvalidate"],
"redis": ["redis"],
"ipython": ["jupyter-client>=8.6.0", "ipykernel>=6.29.0"],
},
classifiers=[
"Programming Language :: Python :: 3",

View File

@ -4,7 +4,7 @@ from unittest.mock import MagicMock
import pytest
import autogen
from autogen.agentchat.agent import Agent
from autogen.agentchat.conversable_agent import ConversableAgent
try:
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
@ -72,7 +72,7 @@ class TestMultimodalConversableAgent(unittest.TestCase):
self.assertDictEqual(self.agent._message_to_dict(message_dict), message_dict)
def test_print_received_message(self):
sender = Agent(name="SenderAgent")
sender = ConversableAgent(name="SenderAgent", llm_config=False, code_execution_config=False)
message_str = "Hello"
self.agent._print_received_message = MagicMock() # Mocking print method to avoid actual print
self.agent._print_received_message(message_str, sender)

View File

@ -489,7 +489,7 @@ def test_selection_helpers():
def test_init_default_parameters():
agents = [Agent(name=f"Agent{i}") for i in range(3)]
agents = [autogen.ConversableAgent(name=f"Agent{i}", llm_config=False) for i in range(3)]
group_chat = GroupChat(agents=agents, messages=[], max_round=3)
for agent in agents:
assert set([a.name for a in group_chat.allowed_speaker_transitions_dict[agent]]) == set(
@ -498,7 +498,7 @@ def test_init_default_parameters():
def test_graph_parameters():
agents = [Agent(name=f"Agent{i}") for i in range(3)]
agents = [autogen.ConversableAgent(name=f"Agent{i}", llm_config=False) for i in range(3)]
with pytest.raises(ValueError):
GroupChat(
agents=agents,

View File

@ -194,9 +194,17 @@ def test_update_tool():
def test_multi_tool_call():
class FakeAgent(autogen.Agent):
def __init__(self, name):
super().__init__(name)
self._name = name
self.received = []
@property
def name(self):
return self._name
@property
def description(self):
return self._name
def receive(
self,
message,
@ -281,9 +289,17 @@ def test_multi_tool_call():
async def test_async_multi_tool_call():
class FakeAgent(autogen.Agent):
def __init__(self, name):
super().__init__(name)
self._name = name
self.received = []
@property
def name(self):
return self._name
@property
def description(self):
return self._name
async def a_receive(
self,
message,

View File

@ -0,0 +1,179 @@
import sys
import tempfile
import pytest
from autogen.agentchat.conversable_agent import ConversableAgent
from autogen.coding.base import CodeBlock, CodeExecutor
from autogen.coding.factory import CodeExecutorFactory
from autogen.coding.local_commandline_code_executor import LocalCommandlineCodeExecutor
from autogen.oai.openai_utils import config_list_from_json
from conftest import skip_openai
def test_create() -> None:
config = {"executor": "commandline-local"}
executor = CodeExecutorFactory.create(config)
assert isinstance(executor, LocalCommandlineCodeExecutor)
config = {"executor": LocalCommandlineCodeExecutor()}
executor = CodeExecutorFactory.create(config)
assert executor is config["executor"]
def test_local_commandline_executor_init() -> None:
executor = LocalCommandlineCodeExecutor(timeout=10, work_dir=".")
assert executor.timeout == 10 and executor.work_dir == "."
# Try invalid working directory.
with pytest.raises(ValueError, match="Working directory .* does not exist."):
executor = LocalCommandlineCodeExecutor(timeout=111, work_dir="/invalid/directory")
def test_local_commandline_executor_execute_code() -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = LocalCommandlineCodeExecutor(work_dir=temp_dir)
_test_execute_code(executor=executor)
def _test_execute_code(executor: CodeExecutor) -> None:
# Test single code block.
code_blocks = [CodeBlock(code="import sys; print('hello world!')", language="python")]
code_result = executor.execute_code_blocks(code_blocks)
assert code_result.exit_code == 0 and "hello world!" in code_result.output and code_result.code_file is not None
# Test multiple code blocks.
code_blocks = [
CodeBlock(code="import sys; print('hello world!')", language="python"),
CodeBlock(code="a = 100 + 100; print(a)", language="python"),
]
code_result = executor.execute_code_blocks(code_blocks)
assert (
code_result.exit_code == 0
and "hello world!" in code_result.output
and "200" in code_result.output
and code_result.code_file is not None
)
# Test bash script.
if sys.platform not in ["win32"]:
code_blocks = [CodeBlock(code="echo 'hello world!'", language="bash")]
code_result = executor.execute_code_blocks(code_blocks)
assert code_result.exit_code == 0 and "hello world!" in code_result.output and code_result.code_file is not None
# Test running code.
file_lines = ["import sys", "print('hello world!')", "a = 100 + 100", "print(a)"]
code_blocks = [CodeBlock(code="\n".join(file_lines), language="python")]
code_result = executor.execute_code_blocks(code_blocks)
assert (
code_result.exit_code == 0
and "hello world!" in code_result.output
and "200" in code_result.output
and code_result.code_file is not None
)
# Check saved code file.
with open(code_result.code_file) as f:
code_lines = f.readlines()
for file_line, code_line in zip(file_lines, code_lines):
assert file_line.strip() == code_line.strip()
@pytest.mark.skipif(sys.platform in ["win32"], reason="do not run on windows")
def test_local_commandline_code_executor_timeout() -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = LocalCommandlineCodeExecutor(timeout=1, work_dir=temp_dir)
_test_timeout(executor)
def _test_timeout(executor: CodeExecutor) -> None:
code_blocks = [CodeBlock(code="import time; time.sleep(10); print('hello world!')", language="python")]
code_result = executor.execute_code_blocks(code_blocks)
assert code_result.exit_code and "Timeout" in code_result.output
def test_local_commandline_code_executor_restart() -> None:
executor = LocalCommandlineCodeExecutor()
_test_restart(executor)
def _test_restart(executor: CodeExecutor) -> None:
# Check warning.
with pytest.warns(UserWarning, match=r".*No action is taken."):
executor.restart()
@pytest.mark.skipif(skip_openai, reason="requested to skip openai tests")
def test_local_commandline_executor_conversable_agent_capability() -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = LocalCommandlineCodeExecutor(work_dir=temp_dir)
_test_conversable_agent_capability(executor=executor)
def _test_conversable_agent_capability(executor: CodeExecutor) -> None:
KEY_LOC = "notebook"
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
config_list = config_list_from_json(
OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={
"model": {
"gpt-3.5-turbo",
"gpt-35-turbo",
},
},
)
llm_config = {"config_list": config_list}
agent = ConversableAgent(
"coding_agent",
llm_config=llm_config,
code_execution_config=False,
)
executor.user_capability.add_to_agent(agent)
# Test updated system prompt.
assert executor.DEFAULT_SYSTEM_MESSAGE_UPDATE in agent.system_message
# Test code generation.
reply = agent.generate_reply(
[{"role": "user", "content": "write a python script to print 'hello world' to the console"}],
sender=ConversableAgent(name="user", llm_config=False, code_execution_config=False),
)
# Test code extraction.
code_blocks = executor.code_extractor.extract_code_blocks(reply) # type: ignore[arg-type]
assert len(code_blocks) == 1 and code_blocks[0].language == "python"
# Test code execution.
code_result = executor.execute_code_blocks(code_blocks)
assert code_result.exit_code == 0 and "hello world" in code_result.output.lower().replace(",", "")
def test_local_commandline_executor_conversable_agent_code_execution() -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = LocalCommandlineCodeExecutor(work_dir=temp_dir)
with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock")
_test_conversable_agent_code_execution(executor)
def _test_conversable_agent_code_execution(executor: CodeExecutor) -> None:
agent = ConversableAgent(
"user_proxy",
code_execution_config={"executor": executor},
llm_config=False,
)
assert agent.code_executor is executor
message = """
Example:
```python
print("hello extract code")
```
"""
reply = agent.generate_reply(
[{"role": "user", "content": message}],
sender=ConversableAgent("user", llm_config=False, code_execution_config=False),
)
assert "hello extract code" in reply # type: ignore[operator]

View File

@ -0,0 +1,219 @@
import os
import tempfile
from typing import Dict, Union
import uuid
import pytest
from autogen.agentchat.conversable_agent import ConversableAgent
from autogen.coding.base import CodeBlock, CodeExecutor
from autogen.coding.factory import CodeExecutorFactory
from autogen.oai.openai_utils import config_list_from_json
from conftest import skip_openai # noqa: E402
try:
from autogen.coding.embedded_ipython_code_executor import EmbeddedIPythonCodeExecutor
skip = False
skip_reason = ""
except ImportError:
skip = True
skip_reason = "Dependencies for EmbeddedIPythonCodeExecutor not installed."
@pytest.mark.skipif(skip, reason=skip_reason)
def test_create() -> None:
config: Dict[str, Union[str, CodeExecutor]] = {"executor": "ipython-embedded"}
executor = CodeExecutorFactory.create(config)
assert isinstance(executor, EmbeddedIPythonCodeExecutor)
config = {"executor": EmbeddedIPythonCodeExecutor()}
executor = CodeExecutorFactory.create(config)
assert executor is config["executor"]
@pytest.mark.skipif(skip, reason=skip_reason)
def test_init() -> None:
executor = EmbeddedIPythonCodeExecutor(timeout=10, kernel_name="python3", output_dir=".")
assert executor.timeout == 10 and executor.kernel_name == "python3" and executor.output_dir == "."
# Try invalid output directory.
with pytest.raises(ValueError, match="Output directory .* does not exist."):
executor = EmbeddedIPythonCodeExecutor(timeout=111, kernel_name="python3", output_dir="/invalid/directory")
# Try invalid kernel name.
with pytest.raises(ValueError, match="Kernel .* is not installed."):
executor = EmbeddedIPythonCodeExecutor(timeout=111, kernel_name="invalid_kernel_name", output_dir=".")
@pytest.mark.skipif(skip, reason=skip_reason)
def test_execute_code_single_code_block() -> None:
executor = EmbeddedIPythonCodeExecutor()
code_blocks = [CodeBlock(code="import sys\nprint('hello world!')", language="python")]
code_result = executor.execute_code_blocks(code_blocks)
assert code_result.exit_code == 0 and "hello world!" in code_result.output
@pytest.mark.skipif(skip, reason=skip_reason)
def test_execute_code_multiple_code_blocks() -> None:
executor = EmbeddedIPythonCodeExecutor()
code_blocks = [
CodeBlock(code="import sys\na = 123 + 123\n", language="python"),
CodeBlock(code="print(a)", language="python"),
]
code_result = executor.execute_code_blocks(code_blocks)
assert code_result.exit_code == 0 and "246" in code_result.output
msg = """
def test_function(a, b):
return a + b
"""
code_blocks = [
CodeBlock(code=msg, language="python"),
CodeBlock(code="test_function(431, 423)", language="python"),
]
code_result = executor.execute_code_blocks(code_blocks)
assert code_result.exit_code == 0 and "854" in code_result.output
@pytest.mark.skipif(skip, reason=skip_reason)
def test_execute_code_bash_script() -> None:
executor = EmbeddedIPythonCodeExecutor()
# Test bash script.
code_blocks = [CodeBlock(code='!echo "hello world!"', language="bash")]
code_result = executor.execute_code_blocks(code_blocks)
assert code_result.exit_code == 0 and "hello world!" in code_result.output
@pytest.mark.skipif(skip, reason=skip_reason)
def test_timeout() -> None:
executor = EmbeddedIPythonCodeExecutor(timeout=1)
code_blocks = [CodeBlock(code="import time; time.sleep(10); print('hello world!')", language="python")]
code_result = executor.execute_code_blocks(code_blocks)
assert code_result.exit_code and "Timeout" in code_result.output
@pytest.mark.skipif(skip, reason=skip_reason)
def test_silent_pip_install() -> None:
executor = EmbeddedIPythonCodeExecutor()
code_blocks = [CodeBlock(code="!pip install matplotlib numpy", language="python")]
code_result = executor.execute_code_blocks(code_blocks)
assert code_result.exit_code == 0 and code_result.output.strip() == ""
none_existing_package = uuid.uuid4().hex
code_blocks = [CodeBlock(code=f"!pip install matplotlib_{none_existing_package}", language="python")]
code_result = executor.execute_code_blocks(code_blocks)
assert code_result.exit_code == 0 and "ERROR: " in code_result.output
@pytest.mark.skipif(skip, reason=skip_reason)
def test_restart() -> None:
executor = EmbeddedIPythonCodeExecutor()
code_blocks = [CodeBlock(code="x = 123", language="python")]
code_result = executor.execute_code_blocks(code_blocks)
assert code_result.exit_code == 0 and code_result.output.strip() == ""
executor.restart()
code_blocks = [CodeBlock(code="print(x)", language="python")]
code_result = executor.execute_code_blocks(code_blocks)
assert code_result.exit_code and "NameError" in code_result.output
@pytest.mark.skipif(skip, reason=skip_reason)
def test_save_image() -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = EmbeddedIPythonCodeExecutor(output_dir=temp_dir)
# Install matplotlib.
code_blocks = [CodeBlock(code="!pip install matplotlib", language="python")]
code_result = executor.execute_code_blocks(code_blocks)
assert code_result.exit_code == 0 and code_result.output.strip() == ""
# Test saving image.
code_blocks = [
CodeBlock(code="import matplotlib.pyplot as plt\nplt.plot([1, 2, 3, 4])\nplt.show()", language="python")
]
code_result = executor.execute_code_blocks(code_blocks)
assert code_result.exit_code == 0
assert os.path.exists(code_result.output_files[0])
assert f"Image data saved to {code_result.output_files[0]}" in code_result.output
@pytest.mark.skipif(skip, reason=skip_reason)
def test_save_html() -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = EmbeddedIPythonCodeExecutor(output_dir=temp_dir)
# Test saving html.
code_blocks = [
CodeBlock(code="from IPython.display import HTML\nHTML('<h1>Hello, world!</h1>')", language="python")
]
code_result = executor.execute_code_blocks(code_blocks)
assert code_result.exit_code == 0
assert os.path.exists(code_result.output_files[0])
assert f"HTML data saved to {code_result.output_files[0]}" in code_result.output
@pytest.mark.skipif(skip, reason=skip_reason)
@pytest.mark.skipif(skip_openai, reason="openai not installed OR requested to skip")
def test_conversable_agent_capability() -> None:
KEY_LOC = "notebook"
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
config_list = config_list_from_json(
OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={
"model": {
"gpt-3.5-turbo",
"gpt-35-turbo",
},
},
)
llm_config = {"config_list": config_list}
agent = ConversableAgent(
"coding_agent",
llm_config=llm_config,
code_execution_config=False,
)
executor = EmbeddedIPythonCodeExecutor()
executor.user_capability.add_to_agent(agent)
# Test updated system prompt.
assert executor.DEFAULT_SYSTEM_MESSAGE_UPDATE in agent.system_message
# Test code generation.
reply = agent.generate_reply(
[{"role": "user", "content": "print 'hello world' to the console in a single python code block"}],
sender=ConversableAgent("user", llm_config=False, code_execution_config=False),
)
# Test code extraction.
code_blocks = executor.code_extractor.extract_code_blocks(reply) # type: ignore[arg-type]
assert len(code_blocks) == 1 and code_blocks[0].language == "python"
# Test code execution.
code_result = executor.execute_code_blocks(code_blocks)
assert code_result.exit_code == 0 and "hello world" in code_result.output.lower()
@pytest.mark.skipif(skip, reason=skip_reason)
def test_conversable_agent_code_execution() -> None:
agent = ConversableAgent(
"user_proxy",
llm_config=False,
code_execution_config={"executor": "ipython-embedded"},
)
msg = """
Run this code:
```python
def test_function(a, b):
return a * b
```
And then this:
```python
print(test_function(123, 4))
```
"""
with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock")
reply = agent.generate_reply(
[{"role": "user", "content": msg}],
sender=ConversableAgent("user", llm_config=False, code_execution_config=False),
)
assert "492" in reply # type: ignore[operator]

View File

@ -0,0 +1,12 @@
import pytest
from autogen.coding.factory import CodeExecutorFactory
def test_create_unknown() -> None:
config = {"executor": "unknown"}
with pytest.raises(ValueError, match="Unknown code executor unknown"):
CodeExecutorFactory.create(config)
config = {}
with pytest.raises(ValueError, match="Unknown code executor None"):
CodeExecutorFactory.create(config)

View File

@ -0,0 +1,115 @@
from autogen.coding import MarkdownCodeExtractor
_message_1 = """
Example:
```
print("hello extract code")
```
"""
_message_2 = """Example:
```python
def scrape(url):
import requests
from bs4 import BeautifulSoup
response = requests.get(url)
soup = BeautifulSoup(response.text, "html.parser")
title = soup.find("title").text
text = soup.find("div", {"id": "bodyContent"}).text
return title, text
```
Test:
```python
url = "https://en.wikipedia.org/wiki/Web_scraping"
title, text = scrape(url)
print(f"Title: {title}")
print(f"Text: {text}")
```
"""
_message_3 = """
Example:
```python
def scrape(url):
import requests
from bs4 import BeautifulSoup
response = requests.get(url)
soup = BeautifulSoup(response.text, "html.parser")
title = soup.find("title").text
text = soup.find("div", {"id": "bodyContent"}).text
return title, text
```
"""
_message_4 = """
Example:
``` python
def scrape(url):
import requests
from bs4 import BeautifulSoup
response = requests.get(url)
soup = BeautifulSoup(response.text, "html.parser")
title = soup.find("title").text
text = soup.find("div", {"id": "bodyContent"}).text
return title, text
```
""".replace(
"\n", "\r\n"
)
_message_5 = """
Test bash script:
```bash
echo 'hello world!'
```
"""
_message_6 = """
Test some C# code, expecting ""
```
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace ConsoleApplication1
{
class Program
{
static void Main(string[] args)
{
Console.WriteLine("Hello World");
}
}
}
```
"""
_message_7 = """
Test some message that has no code block.
"""
def test_extract_code() -> None:
extractor = MarkdownCodeExtractor()
code_blocks = extractor.extract_code_blocks(_message_1)
assert len(code_blocks) == 1 and code_blocks[0].language == "python"
code_blocks = extractor.extract_code_blocks(_message_2)
assert len(code_blocks) == 2 and code_blocks[0].language == "python" and code_blocks[1].language == "python"
code_blocks = extractor.extract_code_blocks(_message_3)
assert len(code_blocks) == 1 and code_blocks[0].language == "python"
code_blocks = extractor.extract_code_blocks(_message_4)
assert len(code_blocks) == 1 and code_blocks[0].language == "python"
code_blocks = extractor.extract_code_blocks(_message_5)
assert len(code_blocks) == 1 and code_blocks[0].language == "bash"
code_blocks = extractor.extract_code_blocks(_message_6)
assert len(code_blocks) == 1 and code_blocks[0].language == ""
code_blocks = extractor.extract_code_blocks(_message_7)
assert len(code_blocks) == 0

View File

@ -1,13 +1,23 @@
from typing import Any
import pytest
import logging
from autogen.agentchat import Agent
import autogen.graph_utils as gru
class FakeAgent(Agent):
def __init__(self, name) -> None:
self._name = name
@property
def name(self) -> str:
return self._name
class TestHelpers:
def test_has_self_loops(self):
# Setup test data
agents = [Agent(name=f"Agent{i}") for i in range(3)]
agents = [FakeAgent(name=f"Agent{i}") for i in range(3)]
allowed_speaker_transitions = {
agents[0]: [agents[1], agents[2]],
agents[1]: [agents[2]],
@ -26,19 +36,19 @@ class TestHelpers:
class TestGraphUtilCheckGraphValidity:
def test_valid_structure(self):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
agents = [FakeAgent("agent1"), FakeAgent("agent2"), FakeAgent("agent3")]
valid_speaker_transitions_dict = {agent: [other_agent for other_agent in agents] for agent in agents}
gru.check_graph_validity(allowed_speaker_transitions_dict=valid_speaker_transitions_dict, agents=agents)
def test_graph_with_invalid_structure(self):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
unseen_agent = Agent("unseen_agent")
agents = [FakeAgent("agent1"), FakeAgent("agent2"), FakeAgent("agent3")]
unseen_agent = FakeAgent("unseen_agent")
invalid_speaker_transitions_dict = {unseen_agent: ["stranger"]}
with pytest.raises(ValueError):
gru.check_graph_validity(invalid_speaker_transitions_dict, agents)
def test_graph_with_invalid_string(self):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
agents = [FakeAgent("agent1"), FakeAgent("agent2"), FakeAgent("agent3")]
invalid_speaker_transitions_dict = {
agent: ["agent1"] for agent in agents
} # 'agent1' is a string, not an Agent. Therefore raises an error.
@ -46,13 +56,13 @@ class TestGraphUtilCheckGraphValidity:
gru.check_graph_validity(invalid_speaker_transitions_dict, agents)
def test_graph_with_invalid_key(self):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
agents = [FakeAgent("agent1"), FakeAgent("agent2"), FakeAgent("agent3")]
with pytest.raises(ValueError):
gru.check_graph_validity({1: 1}, agents)
# Test for Warning 1: Isolated agent nodes
def test_isolated_agent_nodes_warning(self, caplog):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
agents = [FakeAgent("agent1"), FakeAgent("agent2"), FakeAgent("agent3")]
# Create a speaker_transitions_dict where at least one agent is isolated
speaker_transitions_dict_with_isolation = {agents[0]: [agents[0], agents[1]], agents[1]: [agents[0]]}
# Add an isolated agent
@ -66,14 +76,14 @@ class TestGraphUtilCheckGraphValidity:
# Test for Warning 2: Warning if the set of agents in allowed_speaker_transitions do not match agents
def test_warning_for_mismatch_in_agents(self, caplog):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
agents = [FakeAgent("agent1"), FakeAgent("agent2"), FakeAgent("agent3")]
# Test with missing agents in allowed_speaker_transitions_dict
unknown_agent_dict = {
agents[0]: [agents[0], agents[1], agents[2]],
agents[1]: [agents[0], agents[1], agents[2]],
agents[2]: [agents[0], agents[1], agents[2], Agent("unknown_agent")],
agents[2]: [agents[0], agents[1], agents[2], FakeAgent("unknown_agent")],
}
with caplog.at_level(logging.WARNING):
@ -83,7 +93,7 @@ class TestGraphUtilCheckGraphValidity:
# Test for Warning 3: Warning if there is duplicated agents in allowed_speaker_transitions_dict
def test_warning_for_duplicate_agents(self, caplog):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
agents = [FakeAgent("agent1"), FakeAgent("agent2"), FakeAgent("agent3")]
# Construct an `allowed_speaker_transitions_dict` with duplicated agents
duplicate_agents_dict = {
@ -100,7 +110,7 @@ class TestGraphUtilCheckGraphValidity:
class TestGraphUtilInvertDisallowedToAllowed:
def test_basic_functionality(self):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
agents = [FakeAgent("agent1"), FakeAgent("agent2"), FakeAgent("agent3")]
disallowed_graph = {agents[0]: [agents[1]], agents[1]: [agents[0], agents[2]], agents[2]: []}
expected_allowed_graph = {
agents[0]: [agents[0], agents[2]],
@ -113,7 +123,7 @@ class TestGraphUtilInvertDisallowedToAllowed:
assert inverted == expected_allowed_graph
def test_empty_disallowed_graph(self):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
agents = [FakeAgent("agent1"), FakeAgent("agent2"), FakeAgent("agent3")]
disallowed_graph = {}
expected_allowed_graph = {
agents[0]: [agents[0], agents[1], agents[2]],
@ -126,7 +136,7 @@ class TestGraphUtilInvertDisallowedToAllowed:
assert inverted == expected_allowed_graph
def test_fully_disallowed_graph(self):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
agents = [FakeAgent("agent1"), FakeAgent("agent2"), FakeAgent("agent3")]
disallowed_graph = {
agents[0]: [agents[0], agents[1], agents[2]],
@ -140,9 +150,9 @@ class TestGraphUtilInvertDisallowedToAllowed:
assert inverted == expected_allowed_graph
def test_disallowed_graph_with_nonexistent_agent(self):
agents = [Agent("agent1"), Agent("agent2"), Agent("agent3")]
agents = [FakeAgent("agent1"), FakeAgent("agent2"), FakeAgent("agent3")]
disallowed_graph = {agents[0]: [Agent("nonexistent_agent")]}
disallowed_graph = {agents[0]: [FakeAgent("nonexistent_agent")]}
# In this case, the function should ignore the nonexistent agent and proceed with the inversion
expected_allowed_graph = {
agents[0]: [agents[0], agents[1], agents[2]],

View File

@ -11,28 +11,21 @@ pip install "pyautogen[redis]"
See [LLM Caching](Use-Cases/agent_chat.md#llm-caching) for details.
## Docker
## IPython Code Executor
Even if you install AutoGen locally, we highly recommend using Docker for [code execution](FAQ.md#enable-python-3-docker-image).
To use docker for code execution, you also need to install the python package `docker`:
To use the IPython code executor, you need to install the `jupyter-client`
and `ipykernel` packages:
```bash
pip install docker
pip install "pyautogen[ipython]"
```
You might want to override the default docker image used for code execution. To do that set `use_docker` key of `code_execution_config` property to the name of the image. E.g.:
To use the IPython code executor:
```python
user_proxy = autogen.UserProxyAgent(
name="agent",
human_input_mode="TERMINATE",
max_consecutive_auto_reply=10,
code_execution_config={"work_dir":"_output", "use_docker":"python:3"},
llm_config=llm_config,
system_message=""""Reply TERMINATE if the task has been solved at full satisfaction.
Otherwise, reply CONTINUE, or the reason why the task is not solved yet."""
)
from autogen import UserProxyAgent
proxy = UserProxyAgent(name="proxy", code_execution_config={"executor": "ipython-embedded"})
```
## blendsearch