2024-05-30 09:01:35 -07:00
|
|
|
import asyncio
|
|
|
|
import json
|
2024-06-05 15:48:14 -04:00
|
|
|
from typing import Any, Coroutine, Dict, List, Mapping, Sequence, Tuple
|
2024-05-28 23:18:28 -07:00
|
|
|
|
2024-06-10 21:43:06 -07:00
|
|
|
from tqdm.asyncio import tqdm
|
|
|
|
|
2024-06-08 16:29:27 -07:00
|
|
|
from ...components import (
|
2024-06-04 10:00:05 -04:00
|
|
|
FunctionCall,
|
2024-06-05 15:48:14 -04:00
|
|
|
TypeRoutedAgent,
|
|
|
|
message_handler,
|
2024-06-04 10:00:05 -04:00
|
|
|
)
|
2024-06-08 16:29:27 -07:00
|
|
|
from ...components.models import (
|
2024-06-05 15:51:40 -04:00
|
|
|
ChatCompletionClient,
|
|
|
|
FunctionExecutionResult,
|
|
|
|
FunctionExecutionResultMessage,
|
|
|
|
SystemMessage,
|
|
|
|
)
|
2024-06-08 16:29:27 -07:00
|
|
|
from ...components.tools import Tool
|
|
|
|
from ...core import AgentRuntime, CancellationToken
|
|
|
|
from ..memory import ChatMemory
|
|
|
|
from ..types import (
|
|
|
|
FunctionCallMessage,
|
|
|
|
Message,
|
2024-06-10 19:51:51 -07:00
|
|
|
PublishNow,
|
2024-06-08 16:29:27 -07:00
|
|
|
Reset,
|
|
|
|
RespondNow,
|
|
|
|
ResponseFormat,
|
|
|
|
TextMessage,
|
|
|
|
)
|
|
|
|
from ..utils import convert_messages_to_llm_messages
|
2024-05-28 23:18:28 -07:00
|
|
|
|
|
|
|
|
2024-06-09 12:11:36 -07:00
|
|
|
class ChatCompletionAgent(TypeRoutedAgent):
    """An agent implementation that uses the ChatCompletion API to generate
    responses and execute tools.

    Args:
        name (str): The name of the agent.
        description (str): The description of the agent.
        runtime (AgentRuntime): The runtime in which to register the agent.
        system_messages (List[SystemMessage]): The system messages to use for
            the ChatCompletion API.
        memory (ChatMemory): The memory to store and retrieve messages.
        model_client (ChatCompletionClient): The client to use for the
            ChatCompletion API.
        tools (Sequence[Tool], optional): The tools used by the agent. Defaults
            to no tools. If no tools are provided, the agent cannot handle tool
            calls. If tools are provided, and the response from the model is a
            list of tool calls, the agent will call itself with the tool calls
            until it gets a response that is not a list of tool calls, and then
            use that response as the final response.
    """

    def __init__(
        self,
        name: str,
        description: str,
        runtime: AgentRuntime,
        system_messages: List[SystemMessage],
        memory: ChatMemory,
        model_client: ChatCompletionClient,
        tools: Sequence[Tool] = (),
    ) -> None:
        super().__init__(name, description, runtime)
        self._description = description
        self._system_messages = system_messages
        self._client = model_client
        self._memory = memory
        # Copy into a private list: avoids sharing a default instance and keeps
        # later mutation of the caller's sequence from changing the agent's tools.
        self._tools: List[Tool] = list(tools)

    @message_handler()
    async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
        """Handle a text message. This method adds the message to the memory and
        does not generate any message."""
        # Add a user message.
        await self._memory.add_message(message)

    @message_handler()
    async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
        """Handle a reset message. This method clears the memory."""
        # Reset the chat messages.
        await self._memory.clear()

    @message_handler()
    async def on_respond_now(
        self, message: RespondNow, cancellation_token: CancellationToken
    ) -> TextMessage | FunctionCallMessage:
        """Handle a respond now message. This method generates a response and
        returns it to the sender."""
        # Show an elapsed-time indicator while generating; the context manager
        # closes the bar on exit, so no explicit close() is needed.
        with tqdm(desc=f"{self.name} is thinking...", bar_format="{desc}: {elapsed_s}"):
            response = await self._generate_response(message.response_format, cancellation_token)
        # Return the response.
        return response

    @message_handler()
    async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None:
        """Handle a publish now message. This method generates a response and
        publishes it."""
        # TODO: refactor this to use message_handler decorator.
        # leave=False removes the indicator once generation finishes.
        with tqdm(desc=f"{self.name} is thinking...", bar_format="{desc}: {elapsed_s}", leave=False):
            response = await self._generate_response(message.response_format, cancellation_token)
        # Publish the response.
        await self._publish_message(response)

    @message_handler()
    async def on_tool_call_message(
        self, message: FunctionCallMessage, cancellation_token: CancellationToken
    ) -> FunctionExecutionResultMessage:
        """Handle a tool call message. This method executes the tools and
        returns the results.

        Both the incoming message and the resulting
        FunctionExecutionResultMessage are added to memory. Results are listed
        in the same order as the function calls in ``message.content``. A call
        whose arguments are not valid JSON gets an error result instead of
        being executed.

        Raises:
            ValueError: If the agent has no tools.
        """
        if not self._tools:
            raise ValueError("No tools available")

        # Add a tool call message.
        await self._memory.add_message(message)

        # One slot per function call, so the results keep the order of the calls
        # regardless of which ones failed to parse.
        results: List[FunctionExecutionResult | None] = [None] * len(message.content)
        pending_indices: List[int] = []
        execution_futures: List[Coroutine[Any, Any, Tuple[str, str]]] = []
        for index, function_call in enumerate(message.content):
            # Parse the arguments.
            try:
                arguments = json.loads(function_call.arguments)
            except json.JSONDecodeError:
                results[index] = FunctionExecutionResult(
                    content=f"Error: Could not parse arguments for function {function_call.name}.",
                    call_id=function_call.id,
                )
                continue
            # Schedule the function; it runs when gathered below.
            pending_indices.append(index)
            execution_futures.append(
                self._execute_function(
                    function_call.name,
                    arguments,
                    function_call.id,
                    cancellation_token=cancellation_token,
                )
            )

        if execution_futures:
            # Run all parsable calls concurrently; gather returns results in input order.
            execution_results = await asyncio.gather(*execution_futures)
            for index, (execution_result, call_id) in zip(pending_indices, execution_results):
                results[index] = FunctionExecutionResult(content=execution_result, call_id=call_id)

        # Every slot is filled above; the filter only narrows the element type.
        tool_call_result_msg = FunctionExecutionResultMessage(
            content=[result for result in results if result is not None]
        )

        # Add tool call result message.
        await self._memory.add_message(tool_call_result_msg)

        # Return the results.
        return tool_call_result_msg

    async def _generate_response(
        self,
        response_format: ResponseFormat,
        cancellation_token: CancellationToken,
    ) -> TextMessage | FunctionCallMessage:
        """Generate a response from the model, resolving tool calls along the way.

        While the agent has tools and the model replies with a list of function
        calls, the calls are sent back to this agent as a FunctionCallMessage.
        It is presumably dispatched by type to on_tool_call_message, which
        records the results in memory. The model is then queried again. The
        final response is added to memory and returned.

        Raises:
            ValueError: If the model's response content is neither a string nor
                a list of function calls.
        """
        response = await self._call_model(response_format)

        # If the agent has tools and the response is a list of tool calls,
        # iterate with itself until we get a response that is not a list of
        # tool calls.
        while self._tools and self._is_function_call_list(response.content):
            # Send a function call message to itself.
            await self._send_message(
                message=FunctionCallMessage(content=response.content, source=self.name),
                recipient=self,
                cancellation_token=cancellation_token,
            )
            # Ask the model again, now with the tool results in memory.
            response = await self._call_model(response_format)

        final_response: TextMessage | FunctionCallMessage
        if isinstance(response.content, str):
            # If the response is a string, return a text message.
            final_response = TextMessage(content=response.content, source=self.name)
        elif self._is_function_call_list(response.content):
            # If the response is a list of function calls, return a function call message.
            final_response = FunctionCallMessage(content=response.content, source=self.name)
        else:
            raise ValueError(f"Unexpected response: {response.content}")

        # Add the response to the chat messages.
        await self._memory.add_message(final_response)

        return final_response

    async def _call_model(self, response_format: ResponseFormat) -> Any:
        """Query the model with the system messages followed by the messages in memory.

        JSON output is requested when ``response_format`` is
        ``ResponseFormat.json_object``. Returns the client's result object
        unchanged.
        """
        historical_messages = await self._memory.get_messages()
        return await self._client.create(
            self._system_messages + convert_messages_to_llm_messages(historical_messages, self.name),
            tools=self._tools,
            json_output=response_format == ResponseFormat.json_object,
        )

    @staticmethod
    def _is_function_call_list(content: Any) -> bool:
        """Return True if ``content`` is a list whose items are all FunctionCall objects.

        An empty list also returns True, because ``all()`` of nothing is True.
        """
        return isinstance(content, list) and all(isinstance(x, FunctionCall) for x in content)

    async def _execute_function(
        self,
        name: str,
        args: Dict[str, Any],
        call_id: str,
        cancellation_token: CancellationToken,
    ) -> Tuple[str, str]:
        """Run the tool named ``name`` with ``args`` and return ``(result_text, call_id)``.

        Errors are reported in the returned text rather than raised. This covers
        both an unknown tool name and any ``Exception`` raised while running the
        tool or converting its return value to a string.
        """
        # Find tool
        tool = next((t for t in self._tools if t.name == name), None)
        if tool is None:
            return (f"Error: tool {name} not found.", call_id)
        try:
            result = await tool.run_json(args, cancellation_token)
            result_as_str = tool.return_value_as_string(result)
        except Exception as e:
            result_as_str = f"Error: {str(e)}"
        return (result_as_str, call_id)

    def save_state(self) -> Mapping[str, Any]:
        """Return a mapping with the description, memory state, and system messages."""
        return {
            # NOTE(review): reads the public `description` while load_state writes
            # `_description`; assumes the base class exposes `_description` through
            # `description` so the two round-trip — confirm.
            "description": self.description,
            "memory": self._memory.save_state(),
            "system_messages": self._system_messages,
        }

    def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore state produced by save_state (keys: memory, system_messages, description)."""
        self._memory.load_state(state["memory"])
        self._system_messages = state["system_messages"]
        self._description = state["description"]
|