2024-05-30 09:01:35 -07:00
|
|
|
import asyncio
|
|
|
|
import json
|
2024-06-05 15:48:14 -04:00
|
|
|
from typing import Any, Coroutine, Dict, List, Mapping, Sequence, Tuple
|
2024-05-28 23:18:28 -07:00
|
|
|
|
|
|
|
from agnext.chat.agents.base import BaseChatAgent
|
2024-05-30 09:01:35 -07:00
|
|
|
from agnext.chat.types import (
|
|
|
|
FunctionCallMessage,
|
|
|
|
Message,
|
|
|
|
Reset,
|
|
|
|
RespondNow,
|
|
|
|
ResponseFormat,
|
|
|
|
TextMessage,
|
|
|
|
)
|
2024-05-28 23:18:28 -07:00
|
|
|
from agnext.chat.utils import convert_messages_to_llm_messages
|
2024-06-05 15:48:14 -04:00
|
|
|
from agnext.components import (
|
2024-06-04 10:00:05 -04:00
|
|
|
FunctionCall,
|
2024-06-05 15:48:14 -04:00
|
|
|
TypeRoutedAgent,
|
|
|
|
message_handler,
|
2024-06-04 10:00:05 -04:00
|
|
|
)
|
2024-06-05 15:51:40 -04:00
|
|
|
from agnext.components.models import (
|
|
|
|
ChatCompletionClient,
|
|
|
|
FunctionExecutionResult,
|
|
|
|
FunctionExecutionResultMessage,
|
|
|
|
SystemMessage,
|
|
|
|
)
|
2024-06-05 15:48:14 -04:00
|
|
|
from agnext.components.tools import Tool
|
2024-05-28 23:18:28 -07:00
|
|
|
from agnext.core import AgentRuntime, CancellationToken
|
|
|
|
|
|
|
|
|
|
|
|
class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
    """Chat agent that replies using a chat-completion model, optionally calling tools.

    Incoming text messages are accumulated in an in-memory chat history. On a
    ``RespondNow`` request, the agent sends its system messages followed by the
    converted history to the model client. When tools are configured and the
    model answers with a list of function calls, the agent sends those calls to
    itself (handled by ``on_tool_call_message``). That handler executes the calls
    and records the results, and then the model is queried again. This repeats
    until the model returns something other than a list of function calls.
    """

    def __init__(
        self,
        name: str,
        description: str,
        runtime: AgentRuntime,
        system_messages: List[SystemMessage],
        model_client: ChatCompletionClient,
        tools: Sequence[Tool] = (),
    ) -> None:
        """Initialize the agent.

        Args:
            name: Agent name; also used as the ``source`` of messages the agent produces.
            description: Description forwarded to the base agent.
            runtime: Runtime forwarded to the base agent.
            system_messages: Messages prepended to every model request.
            model_client: Client used to create chat completions.
            tools: Tools the model may call. Empty by default, which disables tool use.
        """
        super().__init__(name, description, runtime)
        self._system_messages = system_messages
        self._client = model_client
        self._chat_messages: List[Message] = []
        # Keep a private copy. The default is an immutable tuple rather than a
        # shared mutable list, and later changes to the caller's sequence do not
        # leak into this agent.
        self._tools: List[Tool] = list(tools)

    @message_handler()
    async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
        """Append an incoming text message to the chat history. No reply is produced."""
        self._chat_messages.append(message)

    @message_handler()
    async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
        """Clear the chat history. System messages and tools are kept."""
        self._chat_messages = []

    @message_handler()
    async def on_respond_now(
        self, message: RespondNow, cancellation_token: CancellationToken
    ) -> TextMessage | FunctionCallMessage:
        """Generate a reply from the model based on the current chat history.

        If tools are configured, each round of function calls returned by the model
        is executed (by sending it to this agent), and the model is queried again
        until it returns something else.

        Args:
            message: The request. Its ``response_format`` selects JSON output mode.
            cancellation_token: Token forwarded with the self-sent function-call messages.

        Returns:
            A ``TextMessage`` when the model returns a string. A ``FunctionCallMessage``
            when the model requests function calls that this agent does not execute
            itself, which only happens when no tools are configured. The reply is also
            appended to the chat history.

        Raises:
            ValueError: If the model content is neither a string nor a list of
                ``FunctionCall`` objects.
        """
        # The requested output format stays the same for every request in this exchange.
        json_output = message.response_format == ResponseFormat.json_object

        while True:
            response = await self._client.create(
                self._system_messages + convert_messages_to_llm_messages(self._chat_messages, self.name),
                tools=self._tools,
                json_output=json_output,
            )
            is_function_calls = isinstance(response.content, list) and all(
                isinstance(x, FunctionCall) for x in response.content
            )
            if not (self._tools and is_function_calls):
                break
            # Run the calls through this agent's own FunctionCallMessage handler. That
            # handler records the calls and their results in the chat history, so the
            # next request sees them.
            await self._send_message(
                message=FunctionCallMessage(content=response.content, source=self.name),
                recipient=self,
                cancellation_token=cancellation_token,
            )

        final_response: TextMessage | FunctionCallMessage
        if isinstance(response.content, str):
            final_response = TextMessage(content=response.content, source=self.name)
        elif is_function_calls:
            final_response = FunctionCallMessage(content=response.content, source=self.name)
        else:
            raise ValueError(f"Unexpected response: {response.content}")

        self._chat_messages.append(final_response)
        return final_response

    @message_handler()
    async def on_tool_call_message(
        self, message: FunctionCallMessage, cancellation_token: CancellationToken
    ) -> FunctionExecutionResultMessage:
        """Execute the requested function calls concurrently and record the results.

        Both the incoming call message and the resulting
        ``FunctionExecutionResultMessage`` are appended to the chat history. Results
        appear in the same order as the calls in ``message.content``.

        Raises:
            ValueError: If the agent has no tools configured.
        """
        if not self._tools:
            raise ValueError("No tools available")

        self._chat_messages.append(message)

        async def run_call(function_call: FunctionCall) -> FunctionExecutionResult:
            # Arguments that are not valid JSON are reported back to the model as an
            # error result instead of aborting the whole batch.
            try:
                arguments = json.loads(function_call.arguments)
            except json.JSONDecodeError:
                return FunctionExecutionResult(
                    content=f"Error: Could not parse arguments for function {function_call.name}.",
                    call_id=function_call.id,
                )
            content, call_id = await self.execute_function(
                function_call.name,
                arguments,
                function_call.id,
                cancellation_token=cancellation_token,
            )
            return FunctionExecutionResult(content=content, call_id=call_id)

        calls: List[Coroutine[Any, Any, FunctionExecutionResult]] = [run_call(c) for c in message.content]
        # gather() returns results in argument order, so result i belongs to call i.
        results: List[FunctionExecutionResult] = list(await asyncio.gather(*calls)) if calls else []

        tool_call_result_msg = FunctionExecutionResultMessage(content=results)
        self._chat_messages.append(tool_call_result_msg)
        return tool_call_result_msg

    async def execute_function(
        self,
        name: str,
        args: Dict[str, Any],
        call_id: str,
        cancellation_token: CancellationToken,
    ) -> Tuple[str, str]:
        """Run the configured tool named ``name`` with the parsed arguments ``args``.

        Args:
            name: Name of the tool to run. It is matched against ``Tool.name``.
            args: Arguments decoded from the model's JSON.
            call_id: Identifier of the function call. It is echoed back unchanged.
            cancellation_token: Token forwarded to the tool.

        Returns:
            A ``(result_text, call_id)`` tuple. An unknown tool name, and any
            ``Exception`` raised while running the tool or converting its return
            value, are reported as ``"Error: ..."`` text rather than raised.
        """
        tool = next((t for t in self._tools if t.name == name), None)
        if tool is None:
            return (f"Error: tool {name} not found.", call_id)
        try:
            result = await tool.run_json(args, cancellation_token)
            result_as_str = tool.return_value_as_string(result)
        except Exception as e:
            # Tool failures go back to the model as text. asyncio.CancelledError is a
            # BaseException, so cancellation is not caught here.
            result_as_str = f"Error: {str(e)}"
        return (result_as_str, call_id)

    def save_state(self) -> Mapping[str, Any]:
        """Return a snapshot of the agent's description, chat history and system messages.

        The lists are copied, so later activity on the agent does not change a
        snapshot that was already taken.
        """
        return {
            "description": self.description,
            "chat_messages": list(self._chat_messages),
            "system_messages": list(self._system_messages),
        }

    def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore state produced by ``save_state``.

        The lists are copied so the agent does not share them with ``state``.
        """
        self._chat_messages = list(state["chat_messages"])
        self._system_messages = list(state["system_messages"])
        # NOTE(review): presumably the attribute backing the base class's
        # ``description`` property (read by save_state). Verify against BaseChatAgent.
        self._description = state["description"]
|