mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-26 14:38:50 +00:00
Code execute cancellation (#299)
* Hook cancellation token into code execution * Add unit test for code cancellation * Actually save the merge
This commit is contained in:
parent
ec654253d2
commit
136af65b74
@ -169,7 +169,7 @@ class Executor(TypeRoutedAgent):
|
||||
code = self._extract_execution_request(message.execution_request)
|
||||
if code is not None:
|
||||
execution_requests = [CodeBlock(code=code, language="python")]
|
||||
result = await self._executor.execute_code_blocks(execution_requests)
|
||||
result = await self._executor.execute_code_blocks(execution_requests, cancellation_token)
|
||||
await self.publish_message(
|
||||
CodeExecutionResultMessage(
|
||||
output=result.output,
|
||||
|
||||
@ -155,7 +155,9 @@ class Executor(TypeRoutedAgent):
|
||||
)
|
||||
return
|
||||
# Execute code blocks.
|
||||
result = await self._executor.execute_code_blocks(code_blocks=code_blocks)
|
||||
result = await self._executor.execute_code_blocks(
|
||||
code_blocks=code_blocks, cancellation_token=cancellation_token
|
||||
)
|
||||
# Publish the code execution result.
|
||||
await self.publish_message(
|
||||
CodeExecutionTaskResult(output=result.output, exit_code=result.exit_code, session_id=message.session_id),
|
||||
|
||||
@ -6,6 +6,8 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Protocol, runtime_checkable
|
||||
|
||||
from agnext.core import CancellationToken
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeBlock:
|
||||
@ -27,7 +29,9 @@ class CodeResult:
|
||||
class CodeExecutor(Protocol):
|
||||
"""Executes code blocks and returns the result."""
|
||||
|
||||
async def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CodeResult:
|
||||
async def execute_code_blocks(
|
||||
self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
|
||||
) -> CodeResult:
|
||||
"""Execute code blocks and return the result.
|
||||
|
||||
This method should be implemented by the code executor.
|
||||
|
||||
@ -8,7 +8,7 @@ import warnings
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from string import Template
|
||||
from typing import Any, Callable, ClassVar, List, Sequence, Union, Optional
|
||||
from typing import Any, Callable, ClassVar, List, Sequence, Union
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
@ -22,7 +22,6 @@ from .._func_with_reqs import (
|
||||
)
|
||||
from .command_line_code_result import CommandLineCodeResult
|
||||
from .utils import PYTHON_VARIANTS, get_file_name_from_content, lang_to_cmd, silence_pip # type: ignore
|
||||
from ....core import CancellationToken
|
||||
|
||||
__all__ = ("LocalCommandLineCodeExecutor",)
|
||||
|
||||
@ -146,7 +145,7 @@ $functions"""
|
||||
"""(Experimental) The working directory for the code execution."""
|
||||
return self._work_dir
|
||||
|
||||
async def _setup_functions(self, cancellation_token: Optional[CancellationToken]) -> None:
|
||||
async def _setup_functions(self, cancellation_token: CancellationToken) -> None:
|
||||
func_file_content = build_python_functions_file(self._functions)
|
||||
func_file = self._work_dir / f"{self._functions_module}.py"
|
||||
func_file.write_text(func_file_content)
|
||||
@ -170,8 +169,7 @@ $functions"""
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
)
|
||||
if cancellation_token:
|
||||
cancellation_token.link_future(task)
|
||||
cancellation_token.link_future(task)
|
||||
try:
|
||||
proc = await task
|
||||
stdout, stderr = await asyncio.wait_for(proc.communicate(), self._timeout)
|
||||
@ -194,13 +192,13 @@ $functions"""
|
||||
self._setup_functions_complete = True
|
||||
|
||||
async def execute_code_blocks(
|
||||
self, code_blocks: List[CodeBlock], cancellation_token: Optional[CancellationToken] = None
|
||||
self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
|
||||
) -> CommandLineCodeResult:
|
||||
"""(Experimental) Execute the code blocks and return the result.
|
||||
|
||||
Args:
|
||||
code_blocks (List[CodeBlock]): The code blocks to execute.
|
||||
cancellation_token (CancellationToken|None): an optional token to cancel the operation
|
||||
cancellation_token (CancellationToken): a token to cancel the operation
|
||||
|
||||
Returns:
|
||||
CommandLineCodeResult: The result of the code execution."""
|
||||
@ -211,7 +209,7 @@ $functions"""
|
||||
return await self._execute_code_dont_check_setup(code_blocks, cancellation_token)
|
||||
|
||||
async def _execute_code_dont_check_setup(
|
||||
self, code_blocks: List[CodeBlock], cancellation_token: Optional[CancellationToken]
|
||||
self, code_blocks: List[CodeBlock], cancellation_token: CancellationToken
|
||||
) -> CommandLineCodeResult:
|
||||
logs_all: str = ""
|
||||
file_names: List[Path] = []
|
||||
@ -262,8 +260,7 @@ $functions"""
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
)
|
||||
if cancellation_token:
|
||||
cancellation_token.link_future(task)
|
||||
cancellation_token.link_future(task)
|
||||
try:
|
||||
proc = await task
|
||||
stdout, stderr = await asyncio.wait_for(proc.communicate(), self._timeout)
|
||||
|
||||
@ -25,6 +25,8 @@ class PythonCodeExecutionTool(BaseTool[CodeExecutionInput, CodeExecutionResult])
|
||||
|
||||
async def run(self, args: CodeExecutionInput, cancellation_token: CancellationToken) -> CodeExecutionResult:
|
||||
code_blocks = [CodeBlock(code=args.code, language="python")]
|
||||
result = await self._executor.execute_code_blocks(code_blocks=code_blocks)
|
||||
result = await self._executor.execute_code_blocks(
|
||||
code_blocks=code_blocks, cancellation_token=cancellation_token
|
||||
)
|
||||
|
||||
return CodeExecutionResult(success=result.exit_code == 0, output=result.output)
|
||||
|
||||
@ -80,7 +80,7 @@ class Executor(BaseWorker):
|
||||
code = self._extract_execution_request(message_content_to_str(message.content))
|
||||
if code is not None:
|
||||
execution_requests = [CodeBlock(code=code, language="python")]
|
||||
result = await self._executor.execute_code_blocks(execution_requests)
|
||||
result = await self._executor.execute_code_blocks(execution_requests, cancellation_token)
|
||||
|
||||
if result.output.strip() == "":
|
||||
# Sometimes agents forget to print(). Remind them to print something.
|
||||
|
||||
@ -1,12 +1,14 @@
|
||||
# File based on: https://github.com/microsoft/autogen/blob/main/test/coding/test_commandline_code_executor.py
|
||||
# Credit to original authors
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from agnext.components.code_executor import CodeBlock, LocalCommandLineCodeExecutor
|
||||
from agnext.core import CancellationToken
|
||||
|
||||
UNIX_SHELLS = ["bash", "sh", "shell"]
|
||||
WINDOWS_SHELLS = ["ps1", "pwsh", "powershell"]
|
||||
@ -15,12 +17,12 @@ PYTHON_VARIANTS = ["python", "Python", "py"]
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_code() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
cancellation_token = CancellationToken()
|
||||
executor = LocalCommandLineCodeExecutor(work_dir=temp_dir)
|
||||
|
||||
|
||||
# Test single code block.
|
||||
code_blocks = [CodeBlock(code="import sys; print('hello world!')", language="python")]
|
||||
code_result = await executor.execute_code_blocks(code_blocks)
|
||||
code_result = await executor.execute_code_blocks(code_blocks, cancellation_token)
|
||||
assert code_result.exit_code == 0 and "hello world!" in code_result.output and code_result.code_file is not None
|
||||
|
||||
# Test multiple code blocks.
|
||||
@ -28,7 +30,7 @@ async def test_execute_code() -> None:
|
||||
CodeBlock(code="import sys; print('hello world!')", language="python"),
|
||||
CodeBlock(code="a = 100 + 100; print(a)", language="python"),
|
||||
]
|
||||
code_result = await executor.execute_code_blocks(code_blocks)
|
||||
code_result = await executor.execute_code_blocks(code_blocks, cancellation_token)
|
||||
assert (
|
||||
code_result.exit_code == 0
|
||||
and "hello world!" in code_result.output
|
||||
@ -39,13 +41,13 @@ async def test_execute_code() -> None:
|
||||
# Test bash script.
|
||||
if sys.platform not in ["win32"]:
|
||||
code_blocks = [CodeBlock(code="echo 'hello world!'", language="bash")]
|
||||
code_result = await executor.execute_code_blocks(code_blocks)
|
||||
code_result = await executor.execute_code_blocks(code_blocks, cancellation_token)
|
||||
assert code_result.exit_code == 0 and "hello world!" in code_result.output and code_result.code_file is not None
|
||||
|
||||
# Test running code.
|
||||
file_lines = ["import sys", "print('hello world!')", "a = 100 + 100", "print(a)"]
|
||||
code_blocks = [CodeBlock(code="\n".join(file_lines), language="python")]
|
||||
code_result = await executor.execute_code_blocks(code_blocks)
|
||||
code_result = await executor.execute_code_blocks(code_blocks, cancellation_token)
|
||||
assert (
|
||||
code_result.exit_code == 0
|
||||
and "hello world!" in code_result.output
|
||||
@ -62,11 +64,26 @@ async def test_execute_code() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_commandline_code_executor_timeout() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
cancellation_token = CancellationToken()
|
||||
executor = LocalCommandLineCodeExecutor(timeout=1, work_dir=temp_dir)
|
||||
code_blocks = [CodeBlock(code="import time; time.sleep(10); print('hello world!')", language="python")]
|
||||
code_result = await executor.execute_code_blocks(code_blocks)
|
||||
code_result = await executor.execute_code_blocks(code_blocks, cancellation_token)
|
||||
assert code_result.exit_code and "Timeout" in code_result.output
|
||||
|
||||
@pytest.mark.asyncio
async def test_commandline_code_executor_cancellation() -> None:
    """Verify that a cancelled token produces a failing, "Cancelled" result.

    A long-sleeping Python snippet is handed to the executor together with a
    cancellation token; the token is cancelled and the awaited result must
    carry a truthy exit code and mention "Cancelled" in its output.
    """
    with tempfile.TemporaryDirectory() as work_dir:
        token = CancellationToken()
        runner = LocalCommandLineCodeExecutor(work_dir=work_dir)
        blocks = [CodeBlock(code="import time; time.sleep(10); print('hello world!')", language="python")]

        # Calling the async method only builds a coroutine object; nothing
        # runs until it is awaited further down.
        # NOTE(review): this means cancel() fires before execution starts. If
        # the intent is to cancel mid-run, the call would need to be scheduled
        # as a task -- confirm against CancellationToken.link_future semantics.
        pending = runner.execute_code_blocks(blocks, token)

        await asyncio.sleep(1)
        token.cancel()
        outcome = await pending

        # Non-zero exit code and a cancellation marker in the captured output.
        assert outcome.exit_code
        assert "Cancelled" in outcome.output
|
||||
|
||||
def test_local_commandline_code_executor_restart() -> None:
|
||||
executor = LocalCommandLineCodeExecutor()
|
||||
@ -77,24 +94,26 @@ def test_local_commandline_code_executor_restart() -> None:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_relative_path() -> None:
|
||||
cancellation_token = CancellationToken()
|
||||
executor = LocalCommandLineCodeExecutor()
|
||||
code = """# filename: /tmp/test.py
|
||||
|
||||
print("hello world")
|
||||
"""
|
||||
result = await executor.execute_code_blocks([CodeBlock(code=code, language="python")])
|
||||
result = await executor.execute_code_blocks([CodeBlock(code=code, language="python")], cancellation_token=cancellation_token)
|
||||
assert result.exit_code == 1 and "Filename is not in the workspace" in result.output
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_relative_path() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir_str:
|
||||
cancellation_token = CancellationToken()
|
||||
temp_dir = Path(temp_dir_str)
|
||||
executor = LocalCommandLineCodeExecutor(work_dir=temp_dir)
|
||||
code = """# filename: test.py
|
||||
|
||||
print("hello world")
|
||||
"""
|
||||
result = await executor.execute_code_blocks([CodeBlock(code=code, language="python")])
|
||||
result = await executor.execute_code_blocks([CodeBlock(code=code, language="python")], cancellation_token=cancellation_token)
|
||||
assert result.exit_code == 0
|
||||
assert "hello world" in result.output
|
||||
assert result.code_file is not None
|
||||
|
||||
@ -11,7 +11,7 @@ from agnext.components.code_executor import (
|
||||
LocalCommandLineCodeExecutor,
|
||||
with_requirements,
|
||||
)
|
||||
|
||||
from agnext.core import CancellationToken
|
||||
|
||||
def add_two_numbers(a: int, b: int) -> int:
|
||||
"""Add two numbers together."""
|
||||
@ -50,6 +50,7 @@ def function_missing_reqs() -> "polars.DataFrame":
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_load_function_with_reqs() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
cancellation_token = CancellationToken()
|
||||
executor = LocalCommandLineCodeExecutor(
|
||||
work_dir=temp_dir, functions=[load_data]
|
||||
)
|
||||
@ -63,7 +64,8 @@ print(data['name'][0])"""
|
||||
result = await executor.execute_code_blocks(
|
||||
code_blocks=[
|
||||
CodeBlock(language="python", code=code),
|
||||
]
|
||||
],
|
||||
cancellation_token=cancellation_token
|
||||
)
|
||||
assert result.output == "John\n"
|
||||
assert result.exit_code == 0
|
||||
@ -72,6 +74,7 @@ print(data['name'][0])"""
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_load_function() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
cancellation_token = CancellationToken()
|
||||
executor = LocalCommandLineCodeExecutor(
|
||||
work_dir=temp_dir, functions=[add_two_numbers]
|
||||
)
|
||||
@ -81,7 +84,8 @@ print(add_two_numbers(1, 2))"""
|
||||
result = await executor.execute_code_blocks(
|
||||
code_blocks=[
|
||||
CodeBlock(language="python", code=code),
|
||||
]
|
||||
],
|
||||
cancellation_token=cancellation_token
|
||||
)
|
||||
assert result.output == "3\n"
|
||||
assert result.exit_code == 0
|
||||
@ -90,6 +94,7 @@ print(add_two_numbers(1, 2))"""
|
||||
@pytest.mark.asyncio
|
||||
async def test_fails_for_function_incorrect_import() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
cancellation_token = CancellationToken()
|
||||
executor = LocalCommandLineCodeExecutor(
|
||||
work_dir=temp_dir, functions=[function_incorrect_import]
|
||||
)
|
||||
@ -100,13 +105,15 @@ function_incorrect_import()"""
|
||||
await executor.execute_code_blocks(
|
||||
code_blocks=[
|
||||
CodeBlock(language="python", code=code),
|
||||
]
|
||||
],
|
||||
cancellation_token=cancellation_token
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fails_for_function_incorrect_dep() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
cancellation_token = CancellationToken()
|
||||
executor = LocalCommandLineCodeExecutor(
|
||||
work_dir=temp_dir, functions=[function_incorrect_dep]
|
||||
)
|
||||
@ -117,7 +124,8 @@ function_incorrect_dep()"""
|
||||
await executor.execute_code_blocks(
|
||||
code_blocks=[
|
||||
CodeBlock(language="python", code=code),
|
||||
]
|
||||
],
|
||||
cancellation_token=cancellation_token
|
||||
)
|
||||
|
||||
|
||||
@ -159,6 +167,7 @@ def add_two_numbers(a: int, b: int) -> int:
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_load_str_function_with_reqs() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
cancellation_token = CancellationToken()
|
||||
func = FunctionWithRequirements.from_str(
|
||||
'''
|
||||
def add_two_numbers(a: int, b: int) -> int:
|
||||
@ -174,7 +183,8 @@ print(add_two_numbers(1, 2))"""
|
||||
result = await executor.execute_code_blocks(
|
||||
code_blocks=[
|
||||
CodeBlock(language="python", code=code),
|
||||
]
|
||||
],
|
||||
cancellation_token=cancellation_token
|
||||
)
|
||||
assert result.output == "3\n"
|
||||
assert result.exit_code == 0
|
||||
@ -195,6 +205,7 @@ invaliddef add_two_numbers(a: int, b: int) -> int:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cant_run_broken_str_function_with_reqs() -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
cancellation_token = CancellationToken()
|
||||
func = FunctionWithRequirements.from_str(
|
||||
'''
|
||||
def add_two_numbers(a: int, b: int) -> int:
|
||||
@ -210,7 +221,8 @@ print(add_two_numbers(object(), False))"""
|
||||
result = await executor.execute_code_blocks(
|
||||
code_blocks=[
|
||||
CodeBlock(language="python", code=code),
|
||||
]
|
||||
],
|
||||
cancellation_token=cancellation_token
|
||||
)
|
||||
assert "TypeError: unsupported operand type(s) for +:" in result.output
|
||||
assert result.exit_code == 1
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user