2024-06-26 06:32:33 -07:00
|
|
|
"""
|
|
|
|
This example shows how to use custom function tools with a tool-enabled agent.
|
|
|
|
"""
|
|
|
|
|
2024-06-24 15:05:47 -07:00
|
|
|
import asyncio
|
|
|
|
import os
|
|
|
|
import random
|
|
|
|
import sys
|
|
|
|
|
|
|
|
from agnext.application import SingleThreadedAgentRuntime
|
|
|
|
from agnext.components.models import (
|
|
|
|
SystemMessage,
|
|
|
|
)
|
|
|
|
from agnext.components.tools import FunctionTool
|
|
|
|
from typing_extensions import Annotated
|
|
|
|
|
|
|
|
# Make sibling example modules (this directory) and shared helpers (the parent
# directory) importable, appended in that order.
sys.path.extend(
    os.path.abspath(candidate)
    for candidate in (
        os.path.dirname(__file__),
        os.path.join(os.path.dirname(__file__), ".."),
    )
)
|
2024-06-24 15:05:47 -07:00
|
|
|
|
|
|
|
from coding_one_agent_direct import AIResponse, ToolEnabledAgent, UserRequest
|
2024-06-28 23:15:46 -07:00
|
|
|
from common.utils import get_chat_completion_client_from_envs
|
2024-06-24 15:05:47 -07:00
|
|
|
|
|
|
|
|
|
|
|
async def get_stock_price(
    ticker: Annotated[str, "The stock ticker symbol, e.g. NVDA."],
    date: Annotated[str, "The date in YYYY/MM/DD format."],
) -> float:
    """Get the stock price of a company.

    This is a placeholder used to demonstrate function tools: it does not
    query any market data and ignores both of its arguments.

    Args:
        ticker: The stock ticker symbol of the company.
        date: The date to look up, in YYYY/MM/DD format.

    Returns:
        A pseudo-random price drawn uniformly from the range [10, 100].
    """
    # Both parameters carry Annotated descriptions so the generated tool schema
    # tells the model what each argument should look like.
    return random.uniform(10, 100)
|
|
|
|
|
|
|
|
|
|
|
|
async def main() -> None:
    """Ask a tool-enabled agent for a stock price and print its reply.

    Registers a ``ToolEnabledAgent`` equipped with the ``get_stock_price``
    function tool, sends it one ``UserRequest``, waits for the reply, stops
    the runtime, and prints the content of the returned ``AIResponse``.

    Raises:
        TypeError: If the agent's reply is not an ``AIResponse``.
    """
    # Create the runtime.
    runtime = SingleThreadedAgentRuntime()

    # Register the tool-enabled agent and get its id.
    tool_agent = await runtime.register_and_get(
        "tool_enabled_agent",
        lambda: ToolEnabledAgent(
            description="Tool Use Agent",
            system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")],
            model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo"),
            tools=[
                # Define a tool that gets the stock price.
                FunctionTool(
                    get_stock_price,
                    description="Get the stock price of a company given the ticker and date.",
                    name="get_stock_price",
                )
            ],
        ),
    )

    # Start processing messages.
    run_context = runtime.start()
    try:
        # Send a task to the tool user. The return value is future-like (the
        # original code called ``.result()`` on it).
        result = await runtime.send_message(UserRequest("What is the stock price of NVDA on 2024/06/01"), tool_agent)
        # Wait for the reply while the runtime is still running: stopping first
        # may halt message processing before the reply is delivered, and
        # ``Future.result()`` raises if the future is not done yet.
        # NOTE(review): assumes ``result`` is an asyncio Future — confirm.
        ai_response = await result
    finally:
        # Always stop the runtime, even if sending or waiting failed.
        await run_context.stop()

    # Print the result. Use an explicit check rather than ``assert`` so the
    # validation survives ``python -O``.
    if not isinstance(ai_response, AIResponse):
        raise TypeError(f"Expected an AIResponse, got {type(ai_response).__name__}")
    print(ai_response.content)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    from logging import DEBUG, WARNING, basicConfig, getLogger

    # Root logger reports warnings and above; the "agnext" logger is raised to
    # DEBUG so the runtime's own activity is visible.
    basicConfig(level=WARNING)
    getLogger("agnext").setLevel(DEBUG)

    asyncio.run(main())
|