mirror of
https://github.com/microsoft/autogen.git
synced 2025-07-14 12:31:28 +00:00
145 lines
5.1 KiB
Python
145 lines
5.1 KiB
Python
![]() |
import asyncio
|
||
|
import json
|
||
|
from dataclasses import dataclass
|
||
|
from typing import List
|
||
|
|
||
|
from agnext.application import SingleThreadedAgentRuntime
|
||
|
from agnext.components import FunctionCall, TypeRoutedAgent, message_handler
|
||
|
from agnext.components.code_executor import LocalCommandLineCodeExecutor
|
||
|
from agnext.components.models import (
|
||
|
AssistantMessage,
|
||
|
ChatCompletionClient,
|
||
|
FunctionExecutionResult,
|
||
|
FunctionExecutionResultMessage,
|
||
|
LLMMessage,
|
||
|
OpenAI,
|
||
|
SystemMessage,
|
||
|
UserMessage,
|
||
|
)
|
||
|
from agnext.components.tools import PythonCodeExecutionTool, Tool
|
||
|
from agnext.core import CancellationToken
|
||
|
|
||
|
|
||
|
@dataclass
class ToolExecutionTask:
    """Message asking an agent to execute a single tool (function) call."""

    # The function call, as requested by the model, that should be executed.
    function_call: FunctionCall
|
||
|
|
||
|
|
||
|
@dataclass
class ToolExecutionTaskResult:
    """Message carrying the outcome of one executed tool (function) call."""

    # The execution result that is handed back to the model.
    result: FunctionExecutionResult
|
||
|
|
||
|
|
||
|
@dataclass
class UserRequest:
    """Message holding the text of a request made by the user."""

    # The user's request text.
    content: str
|
||
|
|
||
|
|
||
|
@dataclass
class AIResponse:
    """Message holding the text of the agent's final answer."""

    # The answer text produced by the model.
    content: str
|
||
|
|
||
|
|
||
|
class ToolEnabledAgent(TypeRoutedAgent):
    """An agent that uses tools to perform tasks.

    The agent executes the tools by itself: every function call requested by
    the model is wrapped in a ``ToolExecutionTask`` that the agent sends to
    its own id, where it is handled by ``handle_tool_call``.
    """

    def __init__(
        self,
        description: str,
        system_messages: List[SystemMessage],
        model_client: ChatCompletionClient,
        tools: List[Tool],
    ) -> None:
        """Store the collaborators used while handling messages.

        Args:
            description: Description forwarded to the base agent class.
            system_messages: Messages prepended to every model request.
            model_client: Client used to generate model responses.
            tools: Tools offered to the model and executed on its request.
        """
        super().__init__(description)
        self._model_client = model_client
        self._system_messages = system_messages
        self._tools = tools

    @message_handler
    async def handle_user_message(self, message: UserRequest, cancellation_token: CancellationToken) -> AIResponse:
        """Handle a user message, run the model and tools, and return the response.

        The model is queried repeatedly: as long as its response is a list of
        function calls, the calls are executed and their results are appended
        to the session before the model is queried again.

        Args:
            message: The user's request.
            cancellation_token: Cancellation token supplied with the message.

        Returns:
            The model's final text answer.

        Raises:
            TypeError: If the model's final response content is not a string.
        """
        session: List[LLMMessage] = [UserMessage(content=message.content, source="User")]

        while True:
            # Query the model with the full conversation and record its answer.
            response = await self._model_client.create(self._system_messages + session, tools=self._tools)
            content = response.content
            session.append(AssistantMessage(content=content, source=self.metadata["name"]))

            # Keep executing the tools until the response is not a list of function calls.
            if not (isinstance(content, list) and all(isinstance(item, FunctionCall) for item in content)):
                break

            # Run all requested calls concurrently by sending the tasks to this agent itself.
            task_results = await asyncio.gather(
                *[self.send_message(ToolExecutionTask(function_call=call), self.id) for call in content]
            )
            # Combine the results into a single message for the next model request.
            session.append(FunctionExecutionResultMessage(content=[task.result for task in task_results]))

        # Explicit check instead of ``assert`` so the guard survives ``python -O``.
        if not isinstance(content, str):
            raise TypeError(f"Expected a text response from the model, got {type(content).__name__}")
        return AIResponse(content=content)

    @message_handler
    async def handle_tool_call(
        self, message: ToolExecutionTask, cancellation_token: CancellationToken
    ) -> ToolExecutionTaskResult:
        """Handle a tool execution task: execute the tool and return the result.

        Failures (unknown tool, arguments that are not valid JSON, or an
        exception raised while running the tool) are not raised; they are
        reported as an ``"Error: ..."`` string in the returned result.

        Args:
            message: The task describing which function to call and with what arguments.
            cancellation_token: Token forwarded to the tool execution.

        Returns:
            The execution result, tagged with the id of the originating call.
        """
        call = message.function_call
        # Find the tool by name; the first match wins.
        tool = next((candidate for candidate in self._tools if candidate.name == call.name), None)
        if tool is None:
            result_as_str = f"Error: Tool not found: {call.name}"
        else:
            try:
                arguments = json.loads(call.arguments)
                result = await tool.run_json(args=arguments, cancellation_token=cancellation_token)
                result_as_str = tool.return_value_as_string(result)
            except json.JSONDecodeError:
                result_as_str = f"Error: Invalid arguments: {call.arguments}"
            except Exception as e:
                # Deliberately broad: any tool failure is turned into text for the model.
                result_as_str = f"Error: {e}"
        return ToolExecutionTaskResult(
            result=FunctionExecutionResult(content=result_as_str, call_id=call.id),
        )
|
||
|
|
||
|
|
||
|
async def main() -> None:
    """Send one code-execution request to a tool-enabled agent and print its answer."""
    # The runtime that hosts the agent.
    runtime = SingleThreadedAgentRuntime()

    # The single tool on offer: run Python code via a local command line executor.
    python_tool = PythonCodeExecutionTool(
        LocalCommandLineCodeExecutor(),
    )
    tools: List[Tool] = [python_tool]

    def build_agent() -> ToolEnabledAgent:
        """Build the tool-enabled agent; handed to the runtime as a factory callable."""
        return ToolEnabledAgent(
            description="Tool Use Agent",
            system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")],
            model_client=OpenAI(model="gpt-3.5-turbo"),
            tools=tools,
        )

    # Register the agent with the runtime.
    tool_agent = runtime.register_and_get("tool_enabled_agent", build_agent)

    # Hand the task to the tool user.
    request = UserRequest("Run the following Python code: print('Hello, World!')")
    pending = runtime.send_message(request, tool_agent)

    # Drive the runtime until the request has been answered.
    while not pending.done():
        await runtime.process_next()

    # Show the answer.
    ai_response = pending.result()
    assert isinstance(ai_response, AIResponse)
    print(ai_response.content)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    import logging

    # Warnings and above for everything, full debug output for the agnext logger.
    logging.basicConfig(level=logging.WARNING)
    agnext_logger = logging.getLogger("agnext")
    agnext_logger.setLevel(logging.DEBUG)

    # Run the sample.
    asyncio.run(main())
|