2024-06-26 06:32:33 -07:00
|
|
|
"""
|
|
|
|
This example implements a tool-enabled agent that uses tools to perform tasks.
|
|
|
|
1. The agent receives a user message and makes an inference using a model.
|
|
|
|
If the response is a list of function calls, the agent executes the tools by
|
|
|
|
sending the tool execution tasks to itself.
|
|
|
|
2. The agent executes the tools, sends the results back to itself, and
|
|
|
|
makes an inference using the model again.
|
|
|
|
3. The agent keeps executing the tools until the inference response is not a
|
|
|
|
list of function calls.
|
|
|
|
4. The agent returns the final response to the user.
|
|
|
|
"""
|
|
|
|
|
2024-06-24 15:05:47 -07:00
|
|
|
import asyncio
|
2024-06-28 23:15:46 -07:00
|
|
|
import os
|
|
|
|
import sys
|
2024-06-24 15:05:47 -07:00
|
|
|
from dataclasses import dataclass
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
from agnext.application import SingleThreadedAgentRuntime
|
2024-07-25 11:53:59 -07:00
|
|
|
from agnext.components import FunctionCall, message_handler
|
2024-06-24 15:05:47 -07:00
|
|
|
from agnext.components.code_executor import LocalCommandLineCodeExecutor
|
|
|
|
from agnext.components.models import (
|
|
|
|
AssistantMessage,
|
|
|
|
ChatCompletionClient,
|
|
|
|
FunctionExecutionResult,
|
|
|
|
FunctionExecutionResultMessage,
|
|
|
|
LLMMessage,
|
|
|
|
SystemMessage,
|
|
|
|
UserMessage,
|
|
|
|
)
|
2024-07-25 11:53:59 -07:00
|
|
|
from agnext.components.tool_agent import ToolAgent, ToolException
|
2024-06-24 15:05:47 -07:00
|
|
|
from agnext.components.tools import PythonCodeExecutionTool, Tool
|
|
|
|
from agnext.core import CancellationToken
|
|
|
|
|
2024-06-28 23:15:46 -07:00
|
|
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
|
|
|
|
|
|
|
from common.utils import get_chat_completion_client_from_envs
|
|
|
|
|
2024-06-24 15:05:47 -07:00
|
|
|
|
|
|
|
@dataclass()
class Message:
    """A plain-text message exchanged between the user and the agent."""

    # The text of the message.
    content: str
|
|
|
|
|
|
|
|
|
2024-07-25 11:53:59 -07:00
|
|
|
class ToolEnabledAgent(ToolAgent):
    """An agent that uses tools to perform tasks.

    It executes the tools by itself: each ``FunctionCall`` requested by the
    model is sent as a message back to this agent's own id. No handler for
    ``FunctionCall`` is defined here, so execution is presumably handled by the
    inherited ``ToolAgent`` handlers.
    """

    def __init__(
        self,
        description: str,
        system_messages: List[SystemMessage],
        model_client: ChatCompletionClient,
        tools: List[Tool],
    ) -> None:
        """Create the agent.

        Args:
            description: Human-readable description of the agent.
            system_messages: System messages prepended to every model request.
            model_client: Client used to run chat-completion inference.
            tools: Tools the model may call; forwarded to the ``ToolAgent`` base class.
        """
        super().__init__(description, tools)
        self._model_client = model_client
        self._system_messages = system_messages

    async def _infer(self, session: List[LLMMessage]) -> "str | List[FunctionCall]":
        """Run one model inference over the system messages followed by ``session``.

        The model's reply is appended to ``session`` as an ``AssistantMessage`` so
        later inferences see it, and the reply's content is returned.
        """
        response = await self._model_client.create(self._system_messages + session, tools=self.tools)
        session.append(AssistantMessage(content=response.content, source=self.metadata["name"]))
        return response.content

    @staticmethod
    def _to_function_results(
        results: "List[FunctionExecutionResult | BaseException]",
    ) -> List[FunctionExecutionResult]:
        """Convert gathered tool outcomes into results that can be sent back to the model.

        ``ToolException``s become ``"Error: ..."`` results tied to the failing call's
        id, so the model can see the failure. Any other exception is re-raised.

        Raises:
            BaseException: the first non-``ToolException`` error found in ``results``.
            TypeError: if an entry is neither a result nor an exception.
        """
        function_results: List[FunctionExecutionResult] = []
        for result in results:
            if isinstance(result, FunctionExecutionResult):
                function_results.append(result)
            elif isinstance(result, ToolException):
                function_results.append(FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id))
            elif isinstance(result, BaseException):
                raise result
            else:
                # Dropping it would leave a function call without a matching result.
                raise TypeError(f"Unexpected tool execution result of type {type(result).__name__}.")
        return function_results

    @message_handler
    async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message:
        """Handle a user message, execute the model and tools, and return the final response.

        The model is re-invoked after every round of tool execution, until its
        reply is no longer a list of function calls.

        Raises:
            TypeError: if the model's final reply is not a string, or a tool
                produces an unexpected result type.
            BaseException: any non-``ToolException`` error raised while executing a tool.
        """
        session: List[LLMMessage] = [UserMessage(content=message.content, source="User")]
        content = await self._infer(session)

        # Keep executing the tools until the response is not a list of function calls.
        while isinstance(content, list) and all(isinstance(item, FunctionCall) for item in content):
            # Route every call back to this agent concurrently. return_exceptions=True
            # collects each failure as a result instead of propagating the first one,
            # so every call gets an entry in the results.
            results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
                *[self.send_message(call, self.id) for call in content],
                return_exceptions=True,
            )
            session.append(FunctionExecutionResultMessage(content=self._to_function_results(results)))
            # Execute the model again with the tool results in the session.
            content = await self._infer(session)

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not isinstance(content, str):
            raise TypeError(f"Expected the final model response to be a string, got {type(content).__name__}.")
        return Message(content=content)
|
2024-06-24 15:05:47 -07:00
|
|
|
|
|
|
|
|
|
|
|
async def main() -> None:
    """Run the example: register a tool-enabled agent, send it one task, and print its reply.

    The runtime is stopped in a ``finally`` block, so it is stopped even if
    sending the task raises.
    """
    # Create the runtime.
    runtime = SingleThreadedAgentRuntime()

    # Define the tools.
    tools: List[Tool] = [
        # A tool that executes Python code.
        PythonCodeExecutionTool(
            LocalCommandLineCodeExecutor(),
        )
    ]

    # Register agents; the lambda is a factory that builds the agent instance.
    tool_agent = await runtime.register_and_get(
        "tool_enabled_agent",
        lambda: ToolEnabledAgent(
            description="Tool Use Agent",
            system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")],
            model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
            tools=tools,
        ),
    )

    # Start the runtime.
    run_context = runtime.start()

    try:
        # Send a task to the tool-enabled agent and wait for its final reply.
        response = await runtime.send_message(Message("Run the following Python code: print('Hello, World!')"), tool_agent)
        print(response.content)
    finally:
        # Stop the runtime even if the task failed, so it is not left running.
        await run_context.stop()
|
2024-06-24 15:05:47 -07:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import logging

    # Root logging at WARNING; the "agnext" logger (and its children) at DEBUG.
    logging.basicConfig(level=logging.WARNING)
    agnext_logger = logging.getLogger("agnext")
    agnext_logger.setLevel(logging.DEBUG)

    asyncio.run(main())
|