Add examples to showcase patterns (#55)

* add chess example

* wip

* wip

* fix tool schema generation

* fixes

* Agent handle exception

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>

* format

* mypy

* fix test for annotated

---------

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
Eric Zhu 2024-06-07 13:33:51 -07:00 committed by GitHub
parent c6360feeb6
commit b4ade8b735
12 changed files with 420 additions and 52 deletions

119
examples/chess_game.py Normal file
View File

@ -0,0 +1,119 @@
"""This is an example of simulating a chess game with two agents
that play against each other, using tools to reason about the game state
and make moves."""
import argparse
import asyncio
import logging
from typing import Annotated
from agnext.application import SingleThreadedAgentRuntime
from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
from agnext.chat.types import TextMessage
from agnext.components.models import OpenAI, SystemMessage
from agnext.components.tools import FunctionTool
from agnext.core import AgentRuntime
from chess import SQUARE_NAMES, Board, Move
from chess import piece_name as get_piece_name
logging.basicConfig(level=logging.WARNING)
logging.getLogger("agnext").setLevel(logging.DEBUG)
class ChessGameOutput(GroupChatOutput): # type: ignore
def on_message_received(self, message: TextMessage) -> None: # type: ignore
pass
def get_output(self) -> None:
pass
def reset(self) -> None:
pass
def chess_game(runtime: AgentRuntime) -> GroupChat: # type: ignore
"""Create agents for a chess game and return the group chat."""
# Create the board.
board = Board()
# Create shared tools.
def get_legal_moves() -> Annotated[str, "A list of legal moves in UCI format."]:
return "Possible moves are: " + ", ".join([str(move) for move in board.legal_moves])
get_legal_moves_tool = FunctionTool(get_legal_moves, description="Get legal moves.")
def make_move(input: Annotated[str, "A move in UCI format."]) -> Annotated[str, "Result of the move."]:
move = Move.from_uci(input)
board.push(move)
print(board.unicode(borders=True))
# Get the piece name.
piece = board.piece_at(move.to_square)
assert piece is not None
piece_symbol = piece.unicode_symbol()
piece_name = get_piece_name(piece.piece_type)
if piece_symbol.isupper():
piece_name = piece_name.capitalize()
return f"Moved {piece_name} ({piece_symbol}) from {SQUARE_NAMES[move.from_square]} to {SQUARE_NAMES[move.to_square]}."
make_move_tool = FunctionTool(make_move, description="Call this tool to make a move.")
tools = [get_legal_moves_tool, make_move_tool]
black = ChatCompletionAgent(
name="PlayerBlack",
description="Player playing black.",
runtime=runtime,
system_messages=[
SystemMessage(
content="You are a chess player and you play as black. "
"First call get_legal_moves() first, to get list of legal moves. "
"Then call make_move(move) to make a move."
),
],
model_client=OpenAI(model="gpt-4-turbo"),
tools=tools,
)
white = ChatCompletionAgent(
name="PlayerWhite",
description="Player playing white.",
runtime=runtime,
system_messages=[
SystemMessage(
content="You are a chess player and you play as white. "
"First call get_legal_moves() first, to get list of legal moves. "
"Then call make_move(move) to make a move."
),
],
model_client=OpenAI(model="gpt-4-turbo"),
tools=tools,
)
game_chat = GroupChat(
name="ChessGame",
description="A chess game between two agents.",
runtime=runtime,
agents=[white, black],
num_rounds=10,
output=ChessGameOutput(),
)
return game_chat
async def main(message: str) -> None:
runtime = SingleThreadedAgentRuntime()
game_chat = chess_game(runtime)
future = runtime.send_message(TextMessage(content=message, source="Human"), game_chat)
while not future.done():
await runtime.process_next()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run a chess game between two agents.")
parser.add_argument(
"--initial-message",
default="Please make a move.",
help="The initial message to send to the agent playing white.",
)
args = parser.parse_args()
asyncio.run(main(args.initial_message))

View File

@ -31,6 +31,9 @@ dev = [
"pytest-xdist",
"types-Pillow",
"polars",
# Dependencies for the examples.
"chess",
"tavily-python",
]
docs = ["sphinx", "furo", "sphinxcontrib-apidoc", "myst-parser"]

View File

@ -10,6 +10,7 @@ from ..core.exceptions import MessageDroppedException
from ..core.intervention import DropMessage, InterventionHandler
logger = logging.getLogger("agnext")
event_logger = logging.getLogger("agnext.events")
@dataclass(kw_only=True)
@ -67,7 +68,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
return list(self._agents)
@property
def unprocessed_messages(self) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]:
def unprocessed_messages(
self,
) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]:
return self._message_queue
# Returns the response of the message
@ -82,6 +85,18 @@ class SingleThreadedAgentRuntime(AgentRuntime):
if cancellation_token is None:
cancellation_token = CancellationToken()
logger.info(f"Sending message of type {type(message).__name__} to {recipient.name}: {message.__dict__}")
# event_logger.info(
# MessageEvent(
# payload=message,
# sender=sender,
# receiver=recipient,
# kind=MessageKind.DIRECT,
# delivery_stage=DeliveryStage.SEND,
# )
# )
future = asyncio.get_event_loop().create_future()
if recipient not in self._agents:
future.set_exception(Exception("Recipient not found"))
@ -108,6 +123,18 @@ class SingleThreadedAgentRuntime(AgentRuntime):
if cancellation_token is None:
cancellation_token = CancellationToken()
logger.info(f"Publishing message of type {type(message).__name__} to all subscribers: {message.__dict__}")
# event_logger.info(
# MessageEvent(
# payload=message,
# sender=sender,
# receiver=None,
# kind=MessageKind.PUBLISH,
# delivery_stage=DeliveryStage.SEND,
# )
# )
self._message_queue.append(
PublishMessageEnvelope(
message=message,
@ -137,8 +164,17 @@ class SingleThreadedAgentRuntime(AgentRuntime):
try:
sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown"
logger.info(
f"Calling message handler for {recipient.name} with message type {type(message_envelope.message).__name__} from {sender_name}"
f"Calling message handler for {recipient.name} with message type {type(message_envelope.message).__name__} sent by {sender_name}"
)
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=recipient,
# kind=MessageKind.DIRECT,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
response = await recipient.on_message(
message_envelope.message,
cancellation_token=message_envelope.cancellation_token,
@ -162,9 +198,19 @@ class SingleThreadedAgentRuntime(AgentRuntime):
if message_envelope.sender is not None and agent.name == message_envelope.sender.name:
continue
sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown"
logger.info(
f"Calling message handler for {agent.name} with message type {type(message_envelope.message).__name__} published by {message_envelope.sender.name if message_envelope.sender is not None else 'Unknown'}"
f"Calling message handler for {agent.name} with message type {type(message_envelope.message).__name__} published by {sender_name}"
)
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=agent,
# kind=MessageKind.PUBLISH,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
future = agent.on_message(
message_envelope.message,
@ -182,9 +228,23 @@ class SingleThreadedAgentRuntime(AgentRuntime):
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
recipient_name = message_envelope.recipient.name if message_envelope.recipient is not None else "Unknown"
logger.info(
f"Resolving response for recipient {recipient_name} from {message_envelope.sender.name} with message type {type(message_envelope.message).__name__}"
content = (
message_envelope.message.__dict__
if hasattr(message_envelope.message, "__dict__")
else message_envelope.message
)
logger.info(
f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {recipient_name} from {message_envelope.sender.name}: {content}"
)
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=message_envelope.recipient,
# kind=MessageKind.RESPOND,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
message_envelope.future.set_result(message_envelope.message)
async def process_next(self) -> None:

View File

@ -1,6 +1,13 @@
from ._events import LLMCallEvent
from ._events import DeliveryStage, LLMCallEvent, MessageEvent, MessageKind
from ._llm_usage import LLMUsageTracker
EVENT_LOGGER_NAME = "agnext.events"
__all__ = ["LLMCallEvent", "EVENT_LOGGER_NAME", "LLMUsageTracker"]
__all__ = [
"LLMCallEvent",
"EVENT_LOGGER_NAME",
"LLMUsageTracker",
"MessageEvent",
"MessageKind",
"DeliveryStage",
]

View File

@ -1,6 +1,9 @@
import json
from enum import Enum
from typing import Any, cast
from ...core import Agent
class LLMCallEvent:
def __init__(self, *, prompt_tokens: int, completion_tokens: int, **kwargs: Any) -> None:
@ -23,6 +26,50 @@ class LLMCallEvent:
self.kwargs = kwargs
self.kwargs["prompt_tokens"] = prompt_tokens
self.kwargs["completion_tokens"] = completion_tokens
self.kwargs["type"] = "LLMCall"
@property
def prompt_tokens(self) -> int:
return cast(int, self.kwargs["prompt_tokens"])
@property
def completion_tokens(self) -> int:
return cast(int, self.kwargs["completion_tokens"])
# This must output the event in a json serializable format
def __str__(self) -> str:
return json.dumps(self.kwargs)
class MessageKind(Enum):
DIRECT = 1
PUBLISH = 2
RESPOND = 3
class DeliveryStage(Enum):
SEND = 1
DELIVER = 2
class MessageEvent:
def __init__(
self,
*,
payload: Any,
sender: Agent | None,
receiver: Agent | None,
kind: MessageKind,
delivery_stage: DeliveryStage,
**kwargs: Any,
) -> None:
self.kwargs = kwargs
self.kwargs["payload"] = payload
self.kwargs["sender"] = None if sender is None else sender.name
self.kwargs["receiver"] = None if receiver is None else receiver.name
self.kwargs["kind"] = kind
self.kwargs["delivery_stage"] = delivery_stage
self.kwargs["type"] = "Message"
@property
def prompt_tokens(self) -> int:

View File

@ -128,7 +128,10 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
continue
# Execute the function.
future = self.execute_function(
function_call.name, arguments, function_call.id, cancellation_token=cancellation_token
function_call.name,
arguments,
function_call.id,
cancellation_token=cancellation_token,
)
# Append the async result.
execution_futures.append(future)
@ -149,24 +152,22 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
return tool_call_result_msg
async def execute_function(
self, name: str, args: Dict[str, Any], call_id: str, cancellation_token: CancellationToken
self,
name: str,
args: Dict[str, Any],
call_id: str,
cancellation_token: CancellationToken,
) -> Tuple[str, str]:
# Find tool
tool = next((t for t in self._tools if t.name == name), None)
if tool is None:
raise ValueError(f"Tool {name} not found.")
return (f"Error: tool {name} not found.", call_id)
try:
result = await tool.run_json(args, cancellation_token)
result_json_or_str = result.model_dump()
if isinstance(result, dict):
result_str = json.dumps(result_json_or_str)
elif isinstance(result_json_or_str, str):
result_str = result_json_or_str
else:
raise ValueError(f"Unexpected result type: {type(result)}")
result_as_str = tool.return_value_as_string(result)
except Exception as e:
result_str = f"Error: {str(e)}"
return (result_str, call_id)
result_as_str = f"Error: {str(e)}"
return (result_as_str, call_id)
def save_state(self) -> Mapping[str, Any]:
return {

View File

@ -73,17 +73,17 @@ def convert_messages_to_llm_messages(
for message in messages:
match message:
case (
TextMessage(_, source=source)
| MultiModalMessage(_, source=source)
| FunctionCallMessage(_, source=source)
TextMessage(content=_, source=source)
| MultiModalMessage(content=_, source=source)
| FunctionCallMessage(content=_, source=source)
) if source == self_name:
converted_message_1 = convert_content_message_to_assistant_message(message, handle_unrepresentable)
if converted_message_1 is not None:
result.append(converted_message_1)
case (
TextMessage(_, source=source)
| MultiModalMessage(_, source=source)
| FunctionCallMessage(_, source=source)
TextMessage(content=_, source=source)
| MultiModalMessage(content=_, source=source)
| FunctionCallMessage(content=_, source=source)
) if source != self_name:
converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
if converted_message_2 is not None:

View File

@ -105,18 +105,18 @@ def message_handler(
@wraps(func)
async def wrapper(self: Any, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT:
if strict:
if type(message) not in target_types:
if type(message) not in target_types:
if strict:
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
else:
logger.warning(f"Message type {type(message)} not in target types {target_types}")
return_value = await func(self, message, cancellation_token)
if strict:
if return_value is not AnyType and type(return_value) not in return_types:
if AnyType not in return_types and type(return_value) not in return_types:
if strict:
raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
elif return_value is not AnyType:
else:
logger.warning(f"Return type {type(return_value)} not in return types {return_types}")
return return_value
@ -134,7 +134,10 @@ def message_handler(
class TypeRoutedAgent(BaseAgent):
def __init__(self, name: str, router: AgentRuntime) -> None:
# Self is already bound to the handlers
self._handlers: Dict[Type[Any], Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]]] = {}
self._handlers: Dict[
Type[Any],
Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]],
] = {}
for attr in dir(self):
if callable(getattr(self, attr, None)):

View File

@ -1,5 +1,6 @@
import inspect
import logging
import re
import warnings
from typing import (
Any,
@ -11,6 +12,7 @@ from typing import (
Sequence,
Set,
Union,
cast,
)
from openai import AsyncAzureOpenAI, AsyncOpenAI
@ -27,6 +29,7 @@ from openai.types.chat import (
ChatCompletionUserMessageParam,
completion_create_params,
)
from openai.types.shared_params import FunctionDefinition, FunctionParameters
from typing_extensions import Unpack
from ...application.logging import EVENT_LOGGER_NAME, LLMCallEvent
@ -205,15 +208,47 @@ def convert_tools(
) -> List[ChatCompletionToolParam]:
result: List[ChatCompletionToolParam] = []
for tool in tools:
tool_schema = tool.schema
result.append(
{
"type": "function",
"function": tool.schema, # type: ignore
}
ChatCompletionToolParam(
type="function",
function=FunctionDefinition(
name=tool_schema["name"],
description=tool_schema["description"] if "description" in tool_schema else "",
parameters=cast(FunctionParameters, tool_schema["parameters"])
if "parameters" in tool_schema
else {},
),
)
)
# Check if all tools have valid names.
for tool_param in result:
assert_valid_name(tool_param["function"]["name"])
return result
def normalize_name(name: str) -> str:
"""
LLMs sometimes ask functions while ignoring their own format requirements, this function should be used to replace invalid characters with "_".
Prefer _assert_valid_name for validating user configuration or input
"""
return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]
def assert_valid_name(name: str) -> str:
"""
Ensure that configured names are valid, raises ValueError if not.
For munging LLM responses use _normalize_name to ensure LLM specified names don't break the API.
"""
if not re.match(r"^[a-zA-Z0-9_-]+$", name):
raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
if len(name) > 64:
raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
return name
class BaseOpenAI(ChatCompletionClient):
def __init__(
self,
@ -293,7 +328,10 @@ class BaseOpenAI(ChatCompletionClient):
if len(tools) > 0:
converted_tools = convert_tools(tools)
result = await self._client.chat.completions.create(
messages=oai_messages, stream=False, tools=converted_tools, **create_args
messages=oai_messages,
stream=False,
tools=converted_tools,
**create_args,
)
else:
result = await self._client.chat.completions.create(messages=oai_messages, stream=False, **create_args)
@ -331,7 +369,11 @@ class BaseOpenAI(ChatCompletionClient):
# NOTE: If OAI response type changes, this will need to be updated
content = [
FunctionCall(id=x.id, arguments=x.function.arguments, name=x.function.name)
FunctionCall(
id=x.id,
arguments=x.function.arguments,
name=normalize_name(x.function.name),
)
for x in choice.message.tool_calls
]
finish_reason = "function_calls"

View File

@ -1,14 +1,29 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypeVar
from collections.abc import Sequence
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar
from pydantic import BaseModel
from typing_extensions import NotRequired
from ...core import CancellationToken
from .._function_utils import normalize_annotated_type
T = TypeVar("T", bound=BaseModel, contravariant=True)
class ParametersSchema(TypedDict):
type: str
properties: Dict[str, Any]
required: NotRequired[Sequence[str]]
class ToolSchema(TypedDict):
parameters: NotRequired[ParametersSchema]
name: str
description: NotRequired[str]
class Tool(Protocol):
@property
def name(self) -> str: ...
@ -17,7 +32,7 @@ class Tool(Protocol):
def description(self) -> str: ...
@property
def schema(self) -> Mapping[str, Any]: ...
def schema(self) -> ToolSchema: ...
def args_type(self) -> Type[BaseModel]: ...
@ -40,20 +55,36 @@ StateT = TypeVar("StateT", bound=BaseModel)
class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):
def __init__(self, args_type: Type[ArgsT], return_type: Type[ReturnT], name: str, description: str) -> None:
def __init__(
self,
args_type: Type[ArgsT],
return_type: Type[ReturnT],
name: str,
description: str,
) -> None:
self._args_type = args_type
self._return_type = return_type
# Normalize Annotated to the base type.
self._return_type = normalize_annotated_type(return_type)
self._name = name
self._description = description
@property
def schema(self) -> Mapping[str, Any]:
def schema(self) -> ToolSchema:
model_schema = self._args_type.model_json_schema()
parameter_schema: Dict[str, Any] = dict()
parameter_schema["parameters"] = model_schema["properties"]
parameter_schema["name"] = self._name
parameter_schema["description"] = self._description
return parameter_schema
tool_schema = ToolSchema(
name=self._name,
description=self._description,
parameters=ParametersSchema(
type="object",
properties=model_schema["properties"],
),
)
if "required" in model_schema:
assert "parameters" in tool_schema
tool_schema["parameters"]["required"] = model_schema["required"]
return tool_schema
@property
def name(self) -> str:
@ -97,7 +128,12 @@ class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):
class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT]):
def __init__(
self, args_type: Type[ArgsT], return_type: Type[ReturnT], state_type: Type[StateT], name: str, description: str
self,
args_type: Type[ArgsT],
return_type: Type[ReturnT],
state_type: Type[StateT],
name: str,
description: str,
) -> None:
super().__init__(args_type, return_type, name, description)
self._state_type = state_type

View File

@ -32,7 +32,12 @@ class FunctionTool(BaseTool[BaseModel, BaseModel]):
else:
if self._has_cancellation_support:
result = await asyncio.get_event_loop().run_in_executor(
None, functools.partial(self._func, **args.model_dump(), cancellation_token=cancellation_token)
None,
functools.partial(
self._func,
**args.model_dump(),
cancellation_token=cancellation_token,
),
)
else:
future = asyncio.get_event_loop().run_in_executor(

View File

@ -36,9 +36,54 @@ def test_tool_schema_generation() -> None:
schema = MyTool().schema
assert schema["name"] == "TestTool"
assert "description" in schema
assert schema["description"] == "Description of test tool."
assert schema["parameters"]["query"]["description"] == "The description."
assert len(schema["parameters"]) == 1
assert "parameters" in schema
assert schema["parameters"]["type"] == "object"
assert "properties" in schema["parameters"]
assert schema["parameters"]["properties"]["query"]["description"] == "The description."
assert schema["parameters"]["properties"]["query"]["type"] == "string"
assert "required" in schema["parameters"]
assert schema["parameters"]["required"] == ["query"]
assert len(schema["parameters"]["properties"]) == 1
def test_func_tool_schema_generation() -> None:
def my_function(arg: str, other: Annotated[int, "int arg"], nonrequired: int = 5) -> MyResult:
return MyResult(result="test")
tool = FunctionTool(my_function, description="Function tool.")
schema = tool.schema
assert schema["name"] == "my_function"
assert "description" in schema
assert schema["description"] == "Function tool."
assert "parameters" in schema
assert schema["parameters"]["type"] == "object"
assert schema["parameters"]["properties"].keys() == {"arg", "other", "nonrequired"}
assert schema["parameters"]["properties"]["arg"]["type"] == "string"
assert schema["parameters"]["properties"]["arg"]["description"] == "arg"
assert schema["parameters"]["properties"]["other"]["type"] == "integer"
assert schema["parameters"]["properties"]["other"]["description"] == "int arg"
assert schema["parameters"]["properties"]["nonrequired"]["type"] == "integer"
assert schema["parameters"]["properties"]["nonrequired"]["description"] == "nonrequired"
assert "required" in schema["parameters"]
assert schema["parameters"]["required"] == ["arg", "other"]
assert len(schema["parameters"]["properties"]) == 3
def test_func_tool_schema_generation_only_default_arg() -> None:
def my_function(arg: str = "default") -> MyResult:
return MyResult(result="test")
tool = FunctionTool(my_function, description="Function tool.")
schema = tool.schema
assert schema["name"] == "my_function"
assert "description" in schema
assert schema["description"] == "Function tool."
assert "parameters" in schema
assert len(schema["parameters"]["properties"]) == 1
assert schema["parameters"]["properties"]["arg"]["type"] == "string"
assert schema["parameters"]["properties"]["arg"]["description"] == "arg"
assert "required" not in schema["parameters"]
@pytest.mark.asyncio
async def test_tool_run()-> None:
@ -128,7 +173,7 @@ def test_func_tool_return_annotated()-> None:
assert tool.name == "my_function"
assert tool.description == "Function tool."
assert issubclass(tool.args_type(), BaseModel)
assert tool.return_type() == Annotated[str, "test description"]
assert tool.return_type() == str
assert tool.state_type() is None
def test_func_tool_no_args()-> None: