Mirror of https://github.com/microsoft/autogen.git (synced 2025-08-15 20:21:10 +00:00)
Add examples to showcase patterns (#55)
* add chess example
* wip
* wip
* fix tool schema generation
* fixes
* Agent handle exception

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>

* format
* mypy
* fix test for annotated

---------

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
Parent: c6360feeb6
Commit: b4ade8b735

New file: examples/chess_game.py (119 lines)
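The example can be tried by running the script directly, e.g. python examples/chess_game.py --initial-message "Please make a move." (a usage sketch only: it assumes the package and the example dependencies added below are installed and that an OpenAI API key is available in the environment; the flag and its default come from the script's argparse setup).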
examples/chess_game.py
@@ -0,0 +1,119 @@
"""This is an example of simulating a chess game with two agents
that play against each other, using tools to reason about the game state
and make moves."""

import argparse
import asyncio
import logging
from typing import Annotated

from agnext.application import SingleThreadedAgentRuntime
from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
from agnext.chat.types import TextMessage
from agnext.components.models import OpenAI, SystemMessage
from agnext.components.tools import FunctionTool
from agnext.core import AgentRuntime
from chess import SQUARE_NAMES, Board, Move
from chess import piece_name as get_piece_name

logging.basicConfig(level=logging.WARNING)
logging.getLogger("agnext").setLevel(logging.DEBUG)


class ChessGameOutput(GroupChatOutput):  # type: ignore
    def on_message_received(self, message: TextMessage) -> None:  # type: ignore
        pass

    def get_output(self) -> None:
        pass

    def reset(self) -> None:
        pass


def chess_game(runtime: AgentRuntime) -> GroupChat:  # type: ignore
    """Create agents for a chess game and return the group chat."""

    # Create the board.
    board = Board()

    # Create shared tools.
    def get_legal_moves() -> Annotated[str, "A list of legal moves in UCI format."]:
        return "Possible moves are: " + ", ".join([str(move) for move in board.legal_moves])

    get_legal_moves_tool = FunctionTool(get_legal_moves, description="Get legal moves.")

    def make_move(input: Annotated[str, "A move in UCI format."]) -> Annotated[str, "Result of the move."]:
        move = Move.from_uci(input)
        board.push(move)
        print(board.unicode(borders=True))
        # Get the piece name.
        piece = board.piece_at(move.to_square)
        assert piece is not None
        piece_symbol = piece.unicode_symbol()
        piece_name = get_piece_name(piece.piece_type)
        if piece_symbol.isupper():
            piece_name = piece_name.capitalize()
        return f"Moved {piece_name} ({piece_symbol}) from {SQUARE_NAMES[move.from_square]} to {SQUARE_NAMES[move.to_square]}."

    make_move_tool = FunctionTool(make_move, description="Call this tool to make a move.")

    tools = [get_legal_moves_tool, make_move_tool]

    black = ChatCompletionAgent(
        name="PlayerBlack",
        description="Player playing black.",
        runtime=runtime,
        system_messages=[
            SystemMessage(
                content="You are a chess player and you play as black. "
                "First call get_legal_moves() first, to get list of legal moves. "
                "Then call make_move(move) to make a move."
            ),
        ],
        model_client=OpenAI(model="gpt-4-turbo"),
        tools=tools,
    )
    white = ChatCompletionAgent(
        name="PlayerWhite",
        description="Player playing white.",
        runtime=runtime,
        system_messages=[
            SystemMessage(
                content="You are a chess player and you play as white. "
                "First call get_legal_moves() first, to get list of legal moves. "
                "Then call make_move(move) to make a move."
            ),
        ],
        model_client=OpenAI(model="gpt-4-turbo"),
        tools=tools,
    )
    game_chat = GroupChat(
        name="ChessGame",
        description="A chess game between two agents.",
        runtime=runtime,
        agents=[white, black],
        num_rounds=10,
        output=ChessGameOutput(),
    )
    return game_chat


async def main(message: str) -> None:
    runtime = SingleThreadedAgentRuntime()
    game_chat = chess_game(runtime)
    future = runtime.send_message(TextMessage(content=message, source="Human"), game_chat)
    while not future.done():
        await runtime.process_next()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a chess game between two agents.")
    parser.add_argument(
        "--initial-message",
        default="Please make a move.",
        help="The initial message to send to the agent playing white.",
    )
    args = parser.parse_args()
    asyncio.run(main(args.initial_message))
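As a quick, agent-free sanity check of the tool functions above, here is a small sketch that only assumes the python-chess dependency added to pyproject.toml below:

# Illustrative only: exercises the same Board API the example's tools wrap.
from chess import Board

board = Board()
# A fresh board has exactly 20 legal opening moves for white,
# so get_legal_moves() would list 20 UCI strings.
assert len(list(board.legal_moves)) == 20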
@@ -31,6 +31,9 @@ dev = [
     "pytest-xdist",
     "types-Pillow",
     "polars",
+    # Dependencies for the examples.
+    "chess",
+    "tavily-python",
 ]
 docs = ["sphinx", "furo", "sphinxcontrib-apidoc", "myst-parser"]

@@ -10,6 +10,7 @@ from ..core.exceptions import MessageDroppedException
 from ..core.intervention import DropMessage, InterventionHandler

 logger = logging.getLogger("agnext")
+event_logger = logging.getLogger("agnext.events")


 @dataclass(kw_only=True)
@@ -67,7 +68,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
         return list(self._agents)

     @property
-    def unprocessed_messages(self) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]:
+    def unprocessed_messages(
+        self,
+    ) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]:
         return self._message_queue

     # Returns the response of the message
@@ -82,6 +85,18 @@ class SingleThreadedAgentRuntime(AgentRuntime):
         if cancellation_token is None:
             cancellation_token = CancellationToken()

+        logger.info(f"Sending message of type {type(message).__name__} to {recipient.name}: {message.__dict__}")
+
+        # event_logger.info(
+        #     MessageEvent(
+        #         payload=message,
+        #         sender=sender,
+        #         receiver=recipient,
+        #         kind=MessageKind.DIRECT,
+        #         delivery_stage=DeliveryStage.SEND,
+        #     )
+        # )
+
         future = asyncio.get_event_loop().create_future()
         if recipient not in self._agents:
             future.set_exception(Exception("Recipient not found"))
@@ -108,6 +123,18 @@ class SingleThreadedAgentRuntime(AgentRuntime):
         if cancellation_token is None:
             cancellation_token = CancellationToken()

+        logger.info(f"Publishing message of type {type(message).__name__} to all subscribers: {message.__dict__}")
+
+        # event_logger.info(
+        #     MessageEvent(
+        #         payload=message,
+        #         sender=sender,
+        #         receiver=None,
+        #         kind=MessageKind.PUBLISH,
+        #         delivery_stage=DeliveryStage.SEND,
+        #     )
+        # )
+
         self._message_queue.append(
             PublishMessageEnvelope(
                 message=message,
@@ -137,8 +164,17 @@ class SingleThreadedAgentRuntime(AgentRuntime):
         try:
             sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown"
             logger.info(
-                f"Calling message handler for {recipient.name} with message type {type(message_envelope.message).__name__} from {sender_name}"
+                f"Calling message handler for {recipient.name} with message type {type(message_envelope.message).__name__} sent by {sender_name}"
             )
+            # event_logger.info(
+            #     MessageEvent(
+            #         payload=message_envelope.message,
+            #         sender=message_envelope.sender,
+            #         receiver=recipient,
+            #         kind=MessageKind.DIRECT,
+            #         delivery_stage=DeliveryStage.DELIVER,
+            #     )
+            # )
             response = await recipient.on_message(
                 message_envelope.message,
                 cancellation_token=message_envelope.cancellation_token,
@@ -162,9 +198,19 @@ class SingleThreadedAgentRuntime(AgentRuntime):
             if message_envelope.sender is not None and agent.name == message_envelope.sender.name:
                 continue

+            sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown"
             logger.info(
-                f"Calling message handler for {agent.name} with message type {type(message_envelope.message).__name__} published by {message_envelope.sender.name if message_envelope.sender is not None else 'Unknown'}"
+                f"Calling message handler for {agent.name} with message type {type(message_envelope.message).__name__} published by {sender_name}"
             )
+            # event_logger.info(
+            #     MessageEvent(
+            #         payload=message_envelope.message,
+            #         sender=message_envelope.sender,
+            #         receiver=agent,
+            #         kind=MessageKind.PUBLISH,
+            #         delivery_stage=DeliveryStage.DELIVER,
+            #     )
+            # )
             future = agent.on_message(
                 message_envelope.message,
@@ -182,9 +228,23 @@ class SingleThreadedAgentRuntime(AgentRuntime):

     async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
         recipient_name = message_envelope.recipient.name if message_envelope.recipient is not None else "Unknown"
-        logger.info(
-            f"Resolving response for recipient {recipient_name} from {message_envelope.sender.name} with message type {type(message_envelope.message).__name__}"
+        content = (
+            message_envelope.message.__dict__
+            if hasattr(message_envelope.message, "__dict__")
+            else message_envelope.message
         )
+        logger.info(
+            f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {recipient_name} from {message_envelope.sender.name}: {content}"
+        )
+        # event_logger.info(
+        #     MessageEvent(
+        #         payload=message_envelope.message,
+        #         sender=message_envelope.sender,
+        #         receiver=message_envelope.recipient,
+        #         kind=MessageKind.RESPOND,
+        #         delivery_stage=DeliveryStage.DELIVER,
+        #     )
+        # )
         message_envelope.future.set_result(message_envelope.message)

     async def process_next(self) -> None:
@@ -1,6 +1,13 @@
-from ._events import LLMCallEvent
+from ._events import DeliveryStage, LLMCallEvent, MessageEvent, MessageKind
 from ._llm_usage import LLMUsageTracker

 EVENT_LOGGER_NAME = "agnext.events"

-__all__ = ["LLMCallEvent", "EVENT_LOGGER_NAME", "LLMUsageTracker"]
+__all__ = [
+    "LLMCallEvent",
+    "EVENT_LOGGER_NAME",
+    "LLMUsageTracker",
+    "MessageEvent",
+    "MessageKind",
+    "DeliveryStage",
+]
@@ -1,6 +1,9 @@
 import json
+from enum import Enum
 from typing import Any, cast

+from ...core import Agent
+

 class LLMCallEvent:
     def __init__(self, *, prompt_tokens: int, completion_tokens: int, **kwargs: Any) -> None:
@@ -23,6 +26,50 @@ class LLMCallEvent:
         self.kwargs = kwargs
         self.kwargs["prompt_tokens"] = prompt_tokens
         self.kwargs["completion_tokens"] = completion_tokens
+        self.kwargs["type"] = "LLMCall"
+
+    @property
+    def prompt_tokens(self) -> int:
+        return cast(int, self.kwargs["prompt_tokens"])
+
+    @property
+    def completion_tokens(self) -> int:
+        return cast(int, self.kwargs["completion_tokens"])
+
+    # This must output the event in a json serializable format
+    def __str__(self) -> str:
+        return json.dumps(self.kwargs)
+
+
+class MessageKind(Enum):
+    DIRECT = 1
+    PUBLISH = 2
+    RESPOND = 3
+
+
+class DeliveryStage(Enum):
+    SEND = 1
+    DELIVER = 2
+
+
+class MessageEvent:
+    def __init__(
+        self,
+        *,
+        payload: Any,
+        sender: Agent | None,
+        receiver: Agent | None,
+        kind: MessageKind,
+        delivery_stage: DeliveryStage,
+        **kwargs: Any,
+    ) -> None:
+        self.kwargs = kwargs
+        self.kwargs["payload"] = payload
+        self.kwargs["sender"] = None if sender is None else sender.name
+        self.kwargs["receiver"] = None if receiver is None else receiver.name
+        self.kwargs["kind"] = kind
+        self.kwargs["delivery_stage"] = delivery_stage
+        self.kwargs["type"] = "Message"

     @property
     def prompt_tokens(self) -> int:
@@ -128,7 +128,10 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
                 continue
             # Execute the function.
             future = self.execute_function(
-                function_call.name, arguments, function_call.id, cancellation_token=cancellation_token
+                function_call.name,
+                arguments,
+                function_call.id,
+                cancellation_token=cancellation_token,
             )
             # Append the async result.
             execution_futures.append(future)
@@ -149,24 +152,22 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
         return tool_call_result_msg

     async def execute_function(
-        self, name: str, args: Dict[str, Any], call_id: str, cancellation_token: CancellationToken
+        self,
+        name: str,
+        args: Dict[str, Any],
+        call_id: str,
+        cancellation_token: CancellationToken,
     ) -> Tuple[str, str]:
         # Find tool
         tool = next((t for t in self._tools if t.name == name), None)
         if tool is None:
-            raise ValueError(f"Tool {name} not found.")
+            return (f"Error: tool {name} not found.", call_id)
         try:
             result = await tool.run_json(args, cancellation_token)
-            result_json_or_str = result.model_dump()
-            if isinstance(result, dict):
-                result_str = json.dumps(result_json_or_str)
-            elif isinstance(result_json_or_str, str):
-                result_str = result_json_or_str
-            else:
-                raise ValueError(f"Unexpected result type: {type(result)}")
+            result_as_str = tool.return_value_as_string(result)
         except Exception as e:
-            result_str = f"Error: {str(e)}"
-        return (result_str, call_id)
+            result_as_str = f"Error: {str(e)}"
+        return (result_as_str, call_id)

     def save_state(self) -> Mapping[str, Any]:
         return {
@@ -73,17 +73,17 @@ def convert_messages_to_llm_messages(
     for message in messages:
         match message:
             case (
-                TextMessage(_, source=source)
-                | MultiModalMessage(_, source=source)
-                | FunctionCallMessage(_, source=source)
+                TextMessage(content=_, source=source)
+                | MultiModalMessage(content=_, source=source)
+                | FunctionCallMessage(content=_, source=source)
             ) if source == self_name:
                 converted_message_1 = convert_content_message_to_assistant_message(message, handle_unrepresentable)
                 if converted_message_1 is not None:
                     result.append(converted_message_1)
             case (
-                TextMessage(_, source=source)
-                | MultiModalMessage(_, source=source)
-                | FunctionCallMessage(_, source=source)
+                TextMessage(content=_, source=source)
+                | MultiModalMessage(content=_, source=source)
+                | FunctionCallMessage(content=_, source=source)
             ) if source != self_name:
                 converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
                 if converted_message_2 is not None:
@@ -105,18 +105,18 @@ def message_handler(

     @wraps(func)
     async def wrapper(self: Any, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT:
-        if strict:
-            if type(message) not in target_types:
+        if type(message) not in target_types:
+            if strict:
                 raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
             else:
                 logger.warning(f"Message type {type(message)} not in target types {target_types}")

         return_value = await func(self, message, cancellation_token)

+        if AnyType not in return_types and type(return_value) not in return_types:
             if strict:
-            if return_value is not AnyType and type(return_value) not in return_types:
                 raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
-        elif return_value is not AnyType:
+            else:
                 logger.warning(f"Return type {type(return_value)} not in return types {return_types}")

         return return_value
@@ -134,7 +134,10 @@ def message_handler(
 class TypeRoutedAgent(BaseAgent):
     def __init__(self, name: str, router: AgentRuntime) -> None:
         # Self is already bound to the handlers
-        self._handlers: Dict[Type[Any], Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]]] = {}
+        self._handlers: Dict[
+            Type[Any],
+            Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]],
+        ] = {}

         for attr in dir(self):
             if callable(getattr(self, attr, None)):
@@ -1,5 +1,6 @@
 import inspect
 import logging
+import re
 import warnings
 from typing import (
     Any,
@@ -11,6 +12,7 @@ from typing import (
     Sequence,
     Set,
     Union,
+    cast,
 )

 from openai import AsyncAzureOpenAI, AsyncOpenAI
@@ -27,6 +29,7 @@ from openai.types.chat import (
     ChatCompletionUserMessageParam,
     completion_create_params,
 )
+from openai.types.shared_params import FunctionDefinition, FunctionParameters
 from typing_extensions import Unpack

 from ...application.logging import EVENT_LOGGER_NAME, LLMCallEvent
@@ -205,15 +208,47 @@ def convert_tools(
 ) -> List[ChatCompletionToolParam]:
     result: List[ChatCompletionToolParam] = []
     for tool in tools:
+        tool_schema = tool.schema
         result.append(
-            {
-                "type": "function",
-                "function": tool.schema,  # type: ignore
-            }
+            ChatCompletionToolParam(
+                type="function",
+                function=FunctionDefinition(
+                    name=tool_schema["name"],
+                    description=tool_schema["description"] if "description" in tool_schema else "",
+                    parameters=cast(FunctionParameters, tool_schema["parameters"])
+                    if "parameters" in tool_schema
+                    else {},
+                ),
             )
+        )
+    # Check if all tools have valid names.
+    for tool_param in result:
+        assert_valid_name(tool_param["function"]["name"])
     return result


+def normalize_name(name: str) -> str:
+    """
+    LLMs sometimes ask functions while ignoring their own format requirements, this function should be used to replace invalid characters with "_".
+
+    Prefer _assert_valid_name for validating user configuration or input
+    """
+    return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]
+
+
+def assert_valid_name(name: str) -> str:
+    """
+    Ensure that configured names are valid, raises ValueError if not.
+
+    For munging LLM responses use _normalize_name to ensure LLM specified names don't break the API.
+    """
+    if not re.match(r"^[a-zA-Z0-9_-]+$", name):
+        raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
+    if len(name) > 64:
+        raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
+    return name
+
+
 class BaseOpenAI(ChatCompletionClient):
     def __init__(
         self,
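A rough sketch of how the two new helpers behave, using illustrative inputs that follow directly from the regular expressions above (assumes the functions are importable from this module):

# normalize_name munges whatever the model produced so the OpenAI API accepts it.
assert normalize_name("my tool!") == "my_tool_"
assert len(normalize_name("x" * 100)) == 64
# assert_valid_name is for configured names and raises ValueError instead of munging.
assert assert_valid_name("get_legal_moves") == "get_legal_moves"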
@@ -293,7 +328,10 @@ class BaseOpenAI(ChatCompletionClient):
         if len(tools) > 0:
             converted_tools = convert_tools(tools)
             result = await self._client.chat.completions.create(
-                messages=oai_messages, stream=False, tools=converted_tools, **create_args
+                messages=oai_messages,
+                stream=False,
+                tools=converted_tools,
+                **create_args,
             )
         else:
             result = await self._client.chat.completions.create(messages=oai_messages, stream=False, **create_args)
@@ -331,7 +369,11 @@ class BaseOpenAI(ChatCompletionClient):

             # NOTE: If OAI response type changes, this will need to be updated
             content = [
-                FunctionCall(id=x.id, arguments=x.function.arguments, name=x.function.name)
+                FunctionCall(
+                    id=x.id,
+                    arguments=x.function.arguments,
+                    name=normalize_name(x.function.name),
+                )
                 for x in choice.message.tool_calls
             ]
             finish_reason = "function_calls"
@@ -1,14 +1,29 @@
 import json
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypeVar
+from collections.abc import Sequence
+from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar

 from pydantic import BaseModel
+from typing_extensions import NotRequired

 from ...core import CancellationToken
+from .._function_utils import normalize_annotated_type

 T = TypeVar("T", bound=BaseModel, contravariant=True)


+class ParametersSchema(TypedDict):
+    type: str
+    properties: Dict[str, Any]
+    required: NotRequired[Sequence[str]]
+
+
+class ToolSchema(TypedDict):
+    parameters: NotRequired[ParametersSchema]
+    name: str
+    description: NotRequired[str]
+
+
 class Tool(Protocol):
     @property
     def name(self) -> str: ...
@@ -17,7 +32,7 @@ class Tool(Protocol):
     def description(self) -> str: ...

     @property
-    def schema(self) -> Mapping[str, Any]: ...
+    def schema(self) -> ToolSchema: ...

     def args_type(self) -> Type[BaseModel]: ...

@@ -40,20 +55,36 @@ StateT = TypeVar("StateT", bound=BaseModel)


 class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):
-    def __init__(self, args_type: Type[ArgsT], return_type: Type[ReturnT], name: str, description: str) -> None:
+    def __init__(
+        self,
+        args_type: Type[ArgsT],
+        return_type: Type[ReturnT],
+        name: str,
+        description: str,
+    ) -> None:
         self._args_type = args_type
-        self._return_type = return_type
+        # Normalize Annotated to the base type.
+        self._return_type = normalize_annotated_type(return_type)
         self._name = name
         self._description = description

     @property
-    def schema(self) -> Mapping[str, Any]:
+    def schema(self) -> ToolSchema:
         model_schema = self._args_type.model_json_schema()
-        parameter_schema: Dict[str, Any] = dict()
-        parameter_schema["parameters"] = model_schema["properties"]
-        parameter_schema["name"] = self._name
-        parameter_schema["description"] = self._description
-        return parameter_schema
+
+        tool_schema = ToolSchema(
+            name=self._name,
+            description=self._description,
+            parameters=ParametersSchema(
+                type="object",
+                properties=model_schema["properties"],
+            ),
+        )
+        if "required" in model_schema:
+            assert "parameters" in tool_schema
+            tool_schema["parameters"]["required"] = model_schema["required"]
+
+        return tool_schema

     @property
     def name(self) -> str:
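For reference, a sketch of the shape schema now returns, mirroring the assertions in the updated tests further below (the contents of "properties" come from pydantic's model_json_schema(), so exact keys depend on the args model; the values here are illustrative only):

# Illustrative only: a ToolSchema for an args model with one required string field.
example_schema: ToolSchema = {
    "name": "TestTool",
    "description": "Description of test tool.",
    "parameters": {
        "type": "object",
        "properties": {"query": {"type": "string", "description": "The description."}},
        "required": ["query"],
    },
}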
@@ -97,7 +128,12 @@ class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):

 class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT]):
     def __init__(
-        self, args_type: Type[ArgsT], return_type: Type[ReturnT], state_type: Type[StateT], name: str, description: str
+        self,
+        args_type: Type[ArgsT],
+        return_type: Type[ReturnT],
+        state_type: Type[StateT],
+        name: str,
+        description: str,
     ) -> None:
         super().__init__(args_type, return_type, name, description)
         self._state_type = state_type
@@ -32,7 +32,12 @@ class FunctionTool(BaseTool[BaseModel, BaseModel]):
         else:
             if self._has_cancellation_support:
                 result = await asyncio.get_event_loop().run_in_executor(
-                    None, functools.partial(self._func, **args.model_dump(), cancellation_token=cancellation_token)
+                    None,
+                    functools.partial(
+                        self._func,
+                        **args.model_dump(),
+                        cancellation_token=cancellation_token,
+                    ),
                 )
             else:
                 future = asyncio.get_event_loop().run_in_executor(
|
@ -36,9 +36,54 @@ def test_tool_schema_generation() -> None:
|
|||||||
schema = MyTool().schema
|
schema = MyTool().schema
|
||||||
|
|
||||||
assert schema["name"] == "TestTool"
|
assert schema["name"] == "TestTool"
|
||||||
|
assert "description" in schema
|
||||||
assert schema["description"] == "Description of test tool."
|
assert schema["description"] == "Description of test tool."
|
||||||
assert schema["parameters"]["query"]["description"] == "The description."
|
assert "parameters" in schema
|
||||||
assert len(schema["parameters"]) == 1
|
assert schema["parameters"]["type"] == "object"
|
||||||
|
assert "properties" in schema["parameters"]
|
||||||
|
assert schema["parameters"]["properties"]["query"]["description"] == "The description."
|
||||||
|
assert schema["parameters"]["properties"]["query"]["type"] == "string"
|
||||||
|
assert "required" in schema["parameters"]
|
||||||
|
assert schema["parameters"]["required"] == ["query"]
|
||||||
|
assert len(schema["parameters"]["properties"]) == 1
|
||||||
|
|
||||||
|
def test_func_tool_schema_generation() -> None:
|
||||||
|
def my_function(arg: str, other: Annotated[int, "int arg"], nonrequired: int = 5) -> MyResult:
|
||||||
|
return MyResult(result="test")
|
||||||
|
tool = FunctionTool(my_function, description="Function tool.")
|
||||||
|
schema = tool.schema
|
||||||
|
|
||||||
|
assert schema["name"] == "my_function"
|
||||||
|
assert "description" in schema
|
||||||
|
assert schema["description"] == "Function tool."
|
||||||
|
assert "parameters" in schema
|
||||||
|
assert schema["parameters"]["type"] == "object"
|
||||||
|
assert schema["parameters"]["properties"].keys() == {"arg", "other", "nonrequired"}
|
||||||
|
assert schema["parameters"]["properties"]["arg"]["type"] == "string"
|
||||||
|
assert schema["parameters"]["properties"]["arg"]["description"] == "arg"
|
||||||
|
assert schema["parameters"]["properties"]["other"]["type"] == "integer"
|
||||||
|
assert schema["parameters"]["properties"]["other"]["description"] == "int arg"
|
||||||
|
assert schema["parameters"]["properties"]["nonrequired"]["type"] == "integer"
|
||||||
|
assert schema["parameters"]["properties"]["nonrequired"]["description"] == "nonrequired"
|
||||||
|
assert "required" in schema["parameters"]
|
||||||
|
assert schema["parameters"]["required"] == ["arg", "other"]
|
||||||
|
assert len(schema["parameters"]["properties"]) == 3
|
||||||
|
|
||||||
|
def test_func_tool_schema_generation_only_default_arg() -> None:
|
||||||
|
def my_function(arg: str = "default") -> MyResult:
|
||||||
|
return MyResult(result="test")
|
||||||
|
tool = FunctionTool(my_function, description="Function tool.")
|
||||||
|
schema = tool.schema
|
||||||
|
|
||||||
|
assert schema["name"] == "my_function"
|
||||||
|
assert "description" in schema
|
||||||
|
assert schema["description"] == "Function tool."
|
||||||
|
assert "parameters" in schema
|
||||||
|
assert len(schema["parameters"]["properties"]) == 1
|
||||||
|
assert schema["parameters"]["properties"]["arg"]["type"] == "string"
|
||||||
|
assert schema["parameters"]["properties"]["arg"]["description"] == "arg"
|
||||||
|
assert "required" not in schema["parameters"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_tool_run()-> None:
|
async def test_tool_run()-> None:
|
||||||
@@ -128,7 +173,7 @@ def test_func_tool_return_annotated()-> None:
     assert tool.name == "my_function"
     assert tool.description == "Function tool."
     assert issubclass(tool.args_type(), BaseModel)
-    assert tool.return_type() == Annotated[str, "test description"]
+    assert tool.return_type() == str
     assert tool.state_type() is None


 def test_func_tool_no_args()-> None: