Added an optional sources parameter to CodeExecutorAgent (#5259)

This PR adds a `sources` optional parameter to CodeExecutorAgent
(similar to the termination conditions), that allows finer-grained
control on which agents can provide code for execution.

It also moves the `_extract_markdown_code_blocks` subroutine to a member
method, so that it can be overridden by subclasses. I've found this to
be very important to support benchmarks like HumanEval, where we need to
add a test harness around the implementation.
This commit is contained in:
afourney 2025-01-29 23:28:57 -08:00 committed by GitHub
parent 403844ef2b
commit fff201f813
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -9,22 +9,18 @@ from ..messages import ChatMessage, TextMessage
from ._base_chat_agent import BaseChatAgent
def _extract_markdown_code_blocks(markdown_text: str) -> List[CodeBlock]:
pattern = re.compile(r"```(?:\s*([\w\+\-]+))?\n([\s\S]*?)```")
matches = pattern.findall(markdown_text)
code_blocks: List[CodeBlock] = []
for match in matches:
language = match[0].strip() if match[0] else ""
code_content = match[1]
code_blocks.append(CodeBlock(code=code_content, language=language))
return code_blocks
class CodeExecutorAgent(BaseChatAgent):
"""An agent that extracts and executes code snippets found in received messages and returns the output.
It is typically used within a team with another agent that generates code snippets to be executed.
Args:
name: The name of the agent.
code_executor: The CodeExecutor responsible for executing code received in messages (:py:class:`~autogen_ext.code_executors.docker.DockerCommandLineCodeExecutor` recommended. See example below)
description (optional): The description of the agent.
sources (optional): Check only messages from the specified agents for the code to execute.
.. note::
It is recommended that the `CodeExecutorAgent` agent uses a Docker container to execute code. This ensures that model-generated code is executed in an isolated environment. To use Docker, your environment must have Docker installed and running.
@ -75,9 +71,11 @@ class CodeExecutorAgent(BaseChatAgent):
code_executor: CodeExecutor,
*,
description: str = "A computer terminal that performs no other action than running Python scripts (provided to it quoted in ```python code blocks), or sh shell scripts (provided to it quoted in ```sh code blocks).",
sources: Sequence[str] | None = None,
) -> None:
super().__init__(name=name, description=description)
self._code_executor = code_executor
self._sources = sources
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
@ -89,7 +87,8 @@ class CodeExecutorAgent(BaseChatAgent):
code_blocks: List[CodeBlock] = []
for msg in messages:
if isinstance(msg, TextMessage):
code_blocks.extend(_extract_markdown_code_blocks(msg.content))
if self._sources is None or msg.source in self._sources:
code_blocks.extend(self._extract_markdown_code_blocks(msg.content))
if code_blocks:
# Execute the code blocks.
result = await self._code_executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token)
@ -114,3 +113,13 @@ class CodeExecutorAgent(BaseChatAgent):
async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""It it's a no-op as the code executor agent has no mutable state."""
pass
def _extract_markdown_code_blocks(self, markdown_text: str) -> List[CodeBlock]:
pattern = re.compile(r"```(?:\s*([\w\+\-]+))?\n([\s\S]*?)```")
matches = pattern.findall(markdown_text)
code_blocks: List[CodeBlock] = []
for match in matches:
language = match[0].strip() if match[0] else ""
code_content = match[1]
code_blocks.append(CodeBlock(code=code_content, language=language))
return code_blocks