autogen/python/samples/common/agents/_chat_completion_agent.py
Jack Gerrits 853b00b0f0 Add message context to message handler (#367)
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
2024-08-17 03:14:09 +00:00

262 lines
10 KiB
Python

import asyncio
import json
from typing import Any, Coroutine, Dict, List, Mapping, Sequence, Tuple
from agnext.components import (
FunctionCall,
TypeRoutedAgent,
message_handler,
)
from agnext.components.memory import ChatMemory
from agnext.components.models import (
ChatCompletionClient,
FunctionExecutionResult,
FunctionExecutionResultMessage,
SystemMessage,
)
from agnext.components.tools import Tool
from agnext.core import AgentId, CancellationToken, MessageContext
from ..types import (
FunctionCallMessage,
Message,
MultiModalMessage,
PublishNow,
Reset,
RespondNow,
ResponseFormat,
TextMessage,
ToolApprovalRequest,
ToolApprovalResponse,
)
from ..utils import convert_messages_to_llm_messages
class ChatCompletionAgent(TypeRoutedAgent):
"""An agent implementation that uses the ChatCompletion API to gnenerate
responses and execute tools.
Args:
description (str): The description of the agent.
system_messages (List[SystemMessage]): The system messages to use for
the ChatCompletion API.
memory (ChatMemory[Message]): The memory to store and retrieve messages.
model_client (ChatCompletionClient): The client to use for the
ChatCompletion API.
tools (Sequence[Tool], optional): The tools used by the agent. Defaults
to []. If no tools are provided, the agent cannot handle tool calls.
If tools are provided, and the response from the model is a list of
tool calls, the agent will call itselfs with the tool calls until it
gets a response that is not a list of tool calls, and then use that
response as the final response.
tool_approver (Agent | None, optional): The agent that approves tool
calls. Defaults to None. If no tool approver is provided, the agent
will execute the tools without approval. If a tool approver is
provided, the agent will send a request to the tool approver before
executing the tools.
"""
def __init__(
self,
description: str,
system_messages: List[SystemMessage],
memory: ChatMemory[Message],
model_client: ChatCompletionClient,
tools: Sequence[Tool] = [],
tool_approver: AgentId | None = None,
) -> None:
super().__init__(description)
self._description = description
self._system_messages = system_messages
self._client = model_client
self._memory = memory
self._tools = tools
self._tool_approver = tool_approver
@message_handler()
async def on_text_message(self, message: TextMessage, ctx: MessageContext) -> None:
"""Handle a text message. This method adds the message to the memory and
does not generate any message."""
# Add a user message.
await self._memory.add_message(message)
@message_handler()
async def on_multi_modal_message(self, message: MultiModalMessage, ctx: MessageContext) -> None:
"""Handle a multimodal message. This method adds the message to the memory
and does not generate any message."""
# Add a user message.
await self._memory.add_message(message)
@message_handler()
async def on_reset(self, message: Reset, ctx: MessageContext) -> None:
"""Handle a reset message. This method clears the memory."""
# Reset the chat messages.
await self._memory.clear()
@message_handler()
async def on_respond_now(self, message: RespondNow, ctx: MessageContext) -> TextMessage | FunctionCallMessage:
"""Handle a respond now message. This method generates a response and
returns it to the sender."""
# Generate a response.
response = await self._generate_response(message.response_format, ctx)
# Return the response.
return response
@message_handler()
async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None:
"""Handle a publish now message. This method generates a response and
publishes it."""
# Generate a response.
response = await self._generate_response(message.response_format, ctx)
# Publish the response.
await self.publish_message(response)
@message_handler()
async def on_tool_call_message(
self, message: FunctionCallMessage, ctx: MessageContext
) -> FunctionExecutionResultMessage:
"""Handle a tool call message. This method executes the tools and
returns the results."""
if len(self._tools) == 0:
raise ValueError("No tools available")
# Add a tool call message.
await self._memory.add_message(message)
# Execute the tool calls.
results: List[FunctionExecutionResult] = []
execution_futures: List[Coroutine[Any, Any, Tuple[str, str]]] = []
for function_call in message.content:
# Parse the arguments.
try:
arguments = json.loads(function_call.arguments)
except json.JSONDecodeError:
results.append(
FunctionExecutionResult(
content=f"Error: Could not parse arguments for function {function_call.name}.",
call_id=function_call.id,
)
)
continue
# Execute the function.
future = self._execute_function(
function_call.name,
arguments,
function_call.id,
cancellation_token=ctx.cancellation_token,
)
# Append the async result.
execution_futures.append(future)
if execution_futures:
# Wait for all async results.
execution_results = await asyncio.gather(*execution_futures)
# Add the results.
for execution_result, call_id in execution_results:
results.append(FunctionExecutionResult(content=execution_result, call_id=call_id))
# Create a tool call result message.
tool_call_result_msg = FunctionExecutionResultMessage(content=results)
# Add tool call result message.
await self._memory.add_message(tool_call_result_msg)
# Return the results.
return tool_call_result_msg
async def _generate_response(
self,
response_format: ResponseFormat,
ctx: MessageContext,
) -> TextMessage | FunctionCallMessage:
# Get a response from the model.
hisorical_messages = await self._memory.get_messages()
response = await self._client.create(
self._system_messages + convert_messages_to_llm_messages(hisorical_messages, self.metadata["type"]),
tools=self._tools,
json_output=response_format == ResponseFormat.json_object,
)
# If the agent has function executor, and the response is a list of
# tool calls, iterate with itself until we get a response that is not a
# list of tool calls.
while (
len(self._tools) > 0
and isinstance(response.content, list)
and all(isinstance(x, FunctionCall) for x in response.content)
):
# Send a function call message to itself.
response = await self.send_message(
message=FunctionCallMessage(content=response.content, source=self.metadata["type"]),
recipient=self.id,
cancellation_token=ctx.cancellation_token,
)
# Make an assistant message from the response.
hisorical_messages = await self._memory.get_messages()
response = await self._client.create(
self._system_messages + convert_messages_to_llm_messages(hisorical_messages, self.metadata["type"]),
tools=self._tools,
json_output=response_format == ResponseFormat.json_object,
)
final_response: Message
if isinstance(response.content, str):
# If the response is a string, return a text message.
final_response = TextMessage(content=response.content, source=self.metadata["type"])
elif isinstance(response.content, list) and all(isinstance(x, FunctionCall) for x in response.content):
# If the response is a list of function calls, return a function call message.
final_response = FunctionCallMessage(content=response.content, source=self.metadata["type"])
else:
raise ValueError(f"Unexpected response: {response.content}")
# Add the response to the chat messages.
await self._memory.add_message(final_response)
return final_response
async def _execute_function(
self,
name: str,
args: Dict[str, Any],
call_id: str,
cancellation_token: CancellationToken,
) -> Tuple[str, str]:
# Find tool
tool = next((t for t in self._tools if t.name == name), None)
if tool is None:
return (f"Error: tool {name} not found.", call_id)
# Check if the tool needs approval
if self._tool_approver is not None:
# Send a tool approval request.
approval_request = ToolApprovalRequest(
tool_call=FunctionCall(id=call_id, arguments=json.dumps(args), name=name)
)
approval_response = await self.send_message(
message=approval_request,
recipient=self._tool_approver,
cancellation_token=cancellation_token,
)
if not isinstance(approval_response, ToolApprovalResponse):
raise ValueError(f"Expecting {ToolApprovalResponse.__name__}, received: {type(approval_response)}")
if not approval_response.approved:
return (f"Error: tool {name} approved, reason: {approval_response.reason}", call_id)
try:
result = await tool.run_json(args, cancellation_token)
result_as_str = tool.return_value_as_string(result)
except Exception as e:
result_as_str = f"Error: {str(e)}"
return (result_as_str, call_id)
def save_state(self) -> Mapping[str, Any]:
return {
"memory": self._memory.save_state(),
"system_messages": self._system_messages,
}
def load_state(self, state: Mapping[str, Any]) -> None:
self._memory.load_state(state["memory"])
self._system_messages = state["system_messages"]