2024-07-02 10:58:49 -07:00
import asyncio
import json
import re
import uuid
from dataclasses import dataclass
from typing import Any , Dict , List , Tuple , Union
from agnext . application import SingleThreadedAgentRuntime
from agnext . components import FunctionCall , TypeRoutedAgent , message_handler
2024-07-02 16:18:48 -07:00
from agnext . components . code_executor import (
CodeBlock ,
CodeExecutor ,
LocalCommandLineCodeExecutor ,
)
2024-07-02 10:58:49 -07:00
from agnext . components . models import (
AssistantMessage ,
AzureOpenAIChatCompletionClient ,
ChatCompletionClient ,
FunctionExecutionResult ,
FunctionExecutionResultMessage ,
LLMMessage ,
ModelCapabilities ,
OpenAIChatCompletionClient ,
SystemMessage ,
UserMessage ,
)
from agnext . components . tools import CodeExecutionResult , PythonCodeExecutionTool
from agnext . core import AgentId , CancellationToken
2024-07-02 16:18:48 -07:00
# from azure.identity import DefaultAzureCredential, get_bearer_token_provider
2024-07-02 10:58:49 -07:00
@dataclass
class TaskMessage:
    """A task submitted to the coder agent."""

    # Natural-language description of the task to solve.
    content: str
2024-07-02 16:18:48 -07:00
2024-07-02 10:58:49 -07:00
@dataclass
class CodeExecutionRequestMessage:
    """A request to execute the code contained in a model response."""

    # Identifier of the conversation session this request belongs to.
    session_id: str
    # Raw text (expected to hold a Markdown code block) to be executed.
    execution_request: str
2024-07-02 16:18:48 -07:00
2024-07-02 10:58:49 -07:00
@dataclass
class CodeExecutionResultMessage:
    """The outcome of a code execution request."""

    # Identifier of the conversation session this result belongs to.
    session_id: str
    # Output text reported for the execution.
    output: str
    # Exit code reported for the execution.
    exit_code: int
2024-07-02 16:18:48 -07:00
2024-07-02 10:58:49 -07:00
class Coder(TypeRoutedAgent):
    """An agent that uses tools to write, execute, and debug Python code.

    Each incoming ``TaskMessage`` opens a new session (keyed by a fresh UUID)
    whose message history is kept in memory. The model's replies are published
    as ``CodeExecutionRequestMessage``s; every ``CodeExecutionResultMessage``
    received for a session is appended to that session's history and triggers
    another model call, until the model replies with "TERMINATE" or
    ``max_turns`` assistant replies have been produced.
    """

    DEFAULT_DESCRIPTION = "A Python coder assistant."

    DEFAULT_SYSTEM_MESSAGES = [
        SystemMessage(
            """You are a helpful AI assistant. Solve tasks using your Python coding skills. The code you output must be formatted in Markdown code blocks demarcated by triple backticks (```). As an example:
```python
def main():
    print("Hello world.")
if __name__ == "__main__":
    main()
```
The user cannot provide any feedback or perform any other action beyond executing the code you suggest. In particular, the user can't modify your code, and can't copy and paste anything, and can't fill in missing values. Thus, do not suggest incomplete code which requires users to perform any of these actions.
Check the execution result returned by the user. If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes -- code blocks must stand alone and be ready to execute without modification. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, and think of a different approach to try.
If the code has executed successfully, and the problem is solved, reply "TERMINATE".
"""
        )
    ]

    def __init__(
        self,
        model_client: ChatCompletionClient,
        description: str = DEFAULT_DESCRIPTION,
        system_messages: List[SystemMessage] = DEFAULT_SYSTEM_MESSAGES,
        max_turns: int | None = None,
    ) -> None:
        """Create the coder agent.

        Args:
            model_client: Chat completion client used for every inference.
            description: Agent description passed to the base class.
            system_messages: Messages prepended to the session history on
                every model call.
            max_turns: Maximum number of assistant replies per session, or
                ``None`` for no limit.
        """
        super().__init__(description)
        self._model_client = model_client
        # Copy the list so this instance never aliases the shared class-level default.
        self._system_messages = list(system_messages)
        self._session_memory: Dict[str, List[LLMMessage]] = {}
        self._max_turns = max_turns

    @message_handler
    async def handle_user_message(
        self, message: TaskMessage, cancellation_token: CancellationToken
    ) -> None:
        """Start a new session for a task and publish the model's first reply for execution."""
        # Create a new session.
        session_id = str(uuid.uuid4())
        self._session_memory.setdefault(session_id, []).append(
            UserMessage(content=message.content, source="user")
        )

        reply = await self._generate_reply(session_id)
        await self._request_execution(session_id, reply, cancellation_token)

    @message_handler
    async def handle_code_execution_result(
        self, message: CodeExecutionResultMessage, cancellation_token: CancellationToken
    ) -> None:
        """Feed an execution result back to the model and continue or end the session."""
        session_id = message.session_id
        history = self._session_memory[session_id]

        # Store the code execution output.
        execution_result = f"The script ran, then exited with Unix exit code: {message.exit_code}\nIts output was:\n{message.output}"
        history.append(UserMessage(content=execution_result, source="user"))

        # Count the number of rounds so far; stop silently once the limit is reached.
        if self._max_turns is not None:
            n_turns = sum(isinstance(entry, AssistantMessage) for entry in history)
            if n_turns >= self._max_turns:
                return

        reply = await self._generate_reply(session_id)
        if "TERMINATE" in reply:
            return
        await self._request_execution(session_id, reply, cancellation_token)

    async def _generate_reply(self, session_id: str) -> str:
        """Run one model inference for the session, record the reply, and return its text."""
        history = self._session_memory[session_id]
        response = await self._model_client.create([*self._system_messages, *history])
        assert isinstance(response.content, str)
        history.append(
            AssistantMessage(content=response.content, source=self.metadata["name"])
        )
        return response.content

    async def _request_execution(
        self, session_id: str, content: str, cancellation_token: CancellationToken
    ) -> None:
        """Publish a request to execute the code contained in ``content``."""
        await self.publish_message(
            CodeExecutionRequestMessage(execution_request=content, session_id=session_id),
            cancellation_token=cancellation_token,
        )
2024-07-02 10:58:49 -07:00
class Executor(TypeRoutedAgent):
    """An agent that executes the Python code block found in an execution request.

    The first Markdown code block in the request is run through the configured
    ``CodeExecutor`` and the outcome is published as a
    ``CodeExecutionResultMessage`` carrying the same session id.
    """

    # A fenced block: opening fence with a language tag, the body, closing fence.
    # The language tag is required; a bare fence is treated as "no code block".
    _CODE_BLOCK_PATTERN = re.compile(r"```(\w+)\n(.*?)\n```", re.DOTALL)

    def __init__(self, description: str, executor: CodeExecutor) -> None:
        """Create the executor agent.

        Args:
            description: Agent description passed to the base class.
            executor: Code executor used to run extracted code blocks.
        """
        super().__init__(description)
        self._executor = executor

    @message_handler
    async def handle_code_execution(
        self,
        message: CodeExecutionRequestMessage,
        cancellation_token: CancellationToken,
    ) -> None:
        """Execute the requested code and publish the result for the same session."""
        # Extract code block from the message.
        code = self._extract_execution_request(message.execution_request)
        if code is None:
            output = "No code block detected. Please provide a markdown-encoded code block to execute."
            exit_code = 1
        else:
            # The executor call is synchronous, so run it in the default
            # thread pool to avoid blocking the event loop.
            future = asyncio.get_running_loop().run_in_executor(
                None,
                self._executor.execute_code_blocks,
                [CodeBlock(code=code, language="python")],
            )
            cancellation_token.link_future(future)
            result = await future
            output, exit_code = result.output, result.exit_code

        await self.publish_message(
            CodeExecutionResultMessage(
                output=output,
                exit_code=exit_code,
                session_id=message.session_id,
            ),
            cancellation_token=cancellation_token,
        )

    def _extract_execution_request(self, markdown_text: str) -> Union[str, None]:
        """Return the body of the first fenced code block in ``markdown_text``.

        Returns ``None`` when no fenced block with a language tag is found.
        The language tag itself is matched but not returned.
        """
        match = self._CODE_BLOCK_PATTERN.search(markdown_text)
        if match is None:
            return None
        return match.group(2)
2024-07-02 16:18:48 -07:00
2024-07-02 10:58:49 -07:00
async def main() -> None:
    """Set up the runtime and agents, send the coding task, and run until idle.

    The task text is built from the contents of ``prompt.txt`` and the
    ``entry_point`` name defined below.
    """
    # Create the runtime.
    runtime = SingleThreadedAgentRuntime()

    # Create the AzureOpenAI client, with AAD auth
    # token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
    client = AzureOpenAIChatCompletionClient(
        api_version="2024-02-15-preview",
        azure_endpoint="https://aif-complex-tasks-west-us-3.openai.azure.com/",
        model="gpt-4o-2024-05-13",
        model_capabilities=ModelCapabilities(
            function_calling=True, json_output=True, vision=True
        ),
        # azure_ad_token_provider=token_provider
    )

    # Register agents.
    coder = runtime.register_and_get(
        "Coder",
        lambda: Coder(model_client=client),
    )
    runtime.register(
        "Executor",
        lambda: Executor(
            "A agent for executing code", executor=LocalCommandLineCodeExecutor()
        ),
    )

    # Read the prompt with an explicit encoding so decoding does not depend
    # on the platform's locale.
    with open("prompt.txt", "rt", encoding="utf-8") as fh:
        prompt = fh.read()

    entry_point = "__ENTRY_POINT__"

    task = TaskMessage(
        f"""
The following python code imports the `run_tests` function from unit_tests.py, and runs
it on the function `{entry_point}`. This will run a set of automated unit tests to verify the
correct implementation of `{entry_point}`. However, `{entry_point}` is only partially
implemented in the code below. Complete the implementation of `{entry_point}` and then execute
a new stand-alone code block that contains everything needed to run the tests, including: importing
`unit_tests`, calling `run_tests({entry_point})`, as well as {entry_point}'s complete definition,
such that this code block can be run directly in Python.
```python
from unit_tests import run_tests
{prompt}
# Run the unit tests
run_tests({entry_point})
```
""".strip()
    )

    # Start the runtime, send the task to the coder, then wait until the
    # runtime has no more messages to process.
    run_context = runtime.start()
    await runtime.send_message(task, coder)
    await run_context.stop_when_idle()
2024-07-02 10:58:49 -07:00
if __name__ == "__main__":
    import logging

    # Root logger at WARNING; the "agnext" logger is raised to DEBUG.
    logging.basicConfig(level=logging.WARNING)
    agnext_logger = logging.getLogger("agnext")
    agnext_logger.setLevel(logging.DEBUG)

    asyncio.run(main())