mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-15 12:11:30 +00:00
Add examples to showcase patterns (#55)
* add chess example * wip * wip * fix tool schema generation * fixes * Agent handle exception Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com> * format * mypy * fix test for annotated --------- Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
parent
c6360feeb6
commit
b4ade8b735
119
examples/chess_game.py
Normal file
119
examples/chess_game.py
Normal file
@ -0,0 +1,119 @@
|
||||
"""This is an example of simulating a chess game with two agents
|
||||
that play against each other, using tools to reason about the game state
|
||||
and make moves."""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent
|
||||
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
|
||||
from agnext.chat.types import TextMessage
|
||||
from agnext.components.models import OpenAI, SystemMessage
|
||||
from agnext.components.tools import FunctionTool
|
||||
from agnext.core import AgentRuntime
|
||||
from chess import SQUARE_NAMES, Board, Move
|
||||
from chess import piece_name as get_piece_name
|
||||
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("agnext").setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
class ChessGameOutput(GroupChatOutput): # type: ignore
|
||||
def on_message_received(self, message: TextMessage) -> None: # type: ignore
|
||||
pass
|
||||
|
||||
def get_output(self) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def chess_game(runtime: AgentRuntime) -> GroupChat: # type: ignore
|
||||
"""Create agents for a chess game and return the group chat."""
|
||||
|
||||
# Create the board.
|
||||
board = Board()
|
||||
|
||||
# Create shared tools.
|
||||
def get_legal_moves() -> Annotated[str, "A list of legal moves in UCI format."]:
|
||||
return "Possible moves are: " + ", ".join([str(move) for move in board.legal_moves])
|
||||
|
||||
get_legal_moves_tool = FunctionTool(get_legal_moves, description="Get legal moves.")
|
||||
|
||||
def make_move(input: Annotated[str, "A move in UCI format."]) -> Annotated[str, "Result of the move."]:
|
||||
move = Move.from_uci(input)
|
||||
board.push(move)
|
||||
print(board.unicode(borders=True))
|
||||
# Get the piece name.
|
||||
piece = board.piece_at(move.to_square)
|
||||
assert piece is not None
|
||||
piece_symbol = piece.unicode_symbol()
|
||||
piece_name = get_piece_name(piece.piece_type)
|
||||
if piece_symbol.isupper():
|
||||
piece_name = piece_name.capitalize()
|
||||
return f"Moved {piece_name} ({piece_symbol}) from {SQUARE_NAMES[move.from_square]} to {SQUARE_NAMES[move.to_square]}."
|
||||
|
||||
make_move_tool = FunctionTool(make_move, description="Call this tool to make a move.")
|
||||
|
||||
tools = [get_legal_moves_tool, make_move_tool]
|
||||
|
||||
black = ChatCompletionAgent(
|
||||
name="PlayerBlack",
|
||||
description="Player playing black.",
|
||||
runtime=runtime,
|
||||
system_messages=[
|
||||
SystemMessage(
|
||||
content="You are a chess player and you play as black. "
|
||||
"First call get_legal_moves() first, to get list of legal moves. "
|
||||
"Then call make_move(move) to make a move."
|
||||
),
|
||||
],
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
tools=tools,
|
||||
)
|
||||
white = ChatCompletionAgent(
|
||||
name="PlayerWhite",
|
||||
description="Player playing white.",
|
||||
runtime=runtime,
|
||||
system_messages=[
|
||||
SystemMessage(
|
||||
content="You are a chess player and you play as white. "
|
||||
"First call get_legal_moves() first, to get list of legal moves. "
|
||||
"Then call make_move(move) to make a move."
|
||||
),
|
||||
],
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
tools=tools,
|
||||
)
|
||||
game_chat = GroupChat(
|
||||
name="ChessGame",
|
||||
description="A chess game between two agents.",
|
||||
runtime=runtime,
|
||||
agents=[white, black],
|
||||
num_rounds=10,
|
||||
output=ChessGameOutput(),
|
||||
)
|
||||
return game_chat
|
||||
|
||||
|
||||
async def main(message: str) -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
game_chat = chess_game(runtime)
|
||||
future = runtime.send_message(TextMessage(content=message, source="Human"), game_chat)
|
||||
while not future.done():
|
||||
await runtime.process_next()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run a chess game between two agents.")
|
||||
parser.add_argument(
|
||||
"--initial-message",
|
||||
default="Please make a move.",
|
||||
help="The initial message to send to the agent playing white.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
asyncio.run(main(args.initial_message))
|
@ -31,6 +31,9 @@ dev = [
|
||||
"pytest-xdist",
|
||||
"types-Pillow",
|
||||
"polars",
|
||||
# Dependencies for the examples.
|
||||
"chess",
|
||||
"tavily-python",
|
||||
]
|
||||
docs = ["sphinx", "furo", "sphinxcontrib-apidoc", "myst-parser"]
|
||||
|
||||
|
@ -10,6 +10,7 @@ from ..core.exceptions import MessageDroppedException
|
||||
from ..core.intervention import DropMessage, InterventionHandler
|
||||
|
||||
logger = logging.getLogger("agnext")
|
||||
event_logger = logging.getLogger("agnext.events")
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -67,7 +68,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
return list(self._agents)
|
||||
|
||||
@property
|
||||
def unprocessed_messages(self) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]:
|
||||
def unprocessed_messages(
|
||||
self,
|
||||
) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]:
|
||||
return self._message_queue
|
||||
|
||||
# Returns the response of the message
|
||||
@ -82,6 +85,18 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
logger.info(f"Sending message of type {type(message).__name__} to {recipient.name}: {message.__dict__}")
|
||||
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message,
|
||||
# sender=sender,
|
||||
# receiver=recipient,
|
||||
# kind=MessageKind.DIRECT,
|
||||
# delivery_stage=DeliveryStage.SEND,
|
||||
# )
|
||||
# )
|
||||
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
if recipient not in self._agents:
|
||||
future.set_exception(Exception("Recipient not found"))
|
||||
@ -108,6 +123,18 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
logger.info(f"Publishing message of type {type(message).__name__} to all subscribers: {message.__dict__}")
|
||||
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message,
|
||||
# sender=sender,
|
||||
# receiver=None,
|
||||
# kind=MessageKind.PUBLISH,
|
||||
# delivery_stage=DeliveryStage.SEND,
|
||||
# )
|
||||
# )
|
||||
|
||||
self._message_queue.append(
|
||||
PublishMessageEnvelope(
|
||||
message=message,
|
||||
@ -137,8 +164,17 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
try:
|
||||
sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown"
|
||||
logger.info(
|
||||
f"Calling message handler for {recipient.name} with message type {type(message_envelope.message).__name__} from {sender_name}"
|
||||
f"Calling message handler for {recipient.name} with message type {type(message_envelope.message).__name__} sent by {sender_name}"
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message_envelope.message,
|
||||
# sender=message_envelope.sender,
|
||||
# receiver=recipient,
|
||||
# kind=MessageKind.DIRECT,
|
||||
# delivery_stage=DeliveryStage.DELIVER,
|
||||
# )
|
||||
# )
|
||||
response = await recipient.on_message(
|
||||
message_envelope.message,
|
||||
cancellation_token=message_envelope.cancellation_token,
|
||||
@ -162,9 +198,19 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
if message_envelope.sender is not None and agent.name == message_envelope.sender.name:
|
||||
continue
|
||||
|
||||
sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown"
|
||||
logger.info(
|
||||
f"Calling message handler for {agent.name} with message type {type(message_envelope.message).__name__} published by {message_envelope.sender.name if message_envelope.sender is not None else 'Unknown'}"
|
||||
f"Calling message handler for {agent.name} with message type {type(message_envelope.message).__name__} published by {sender_name}"
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message_envelope.message,
|
||||
# sender=message_envelope.sender,
|
||||
# receiver=agent,
|
||||
# kind=MessageKind.PUBLISH,
|
||||
# delivery_stage=DeliveryStage.DELIVER,
|
||||
# )
|
||||
# )
|
||||
|
||||
future = agent.on_message(
|
||||
message_envelope.message,
|
||||
@ -182,9 +228,23 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
|
||||
recipient_name = message_envelope.recipient.name if message_envelope.recipient is not None else "Unknown"
|
||||
logger.info(
|
||||
f"Resolving response for recipient {recipient_name} from {message_envelope.sender.name} with message type {type(message_envelope.message).__name__}"
|
||||
content = (
|
||||
message_envelope.message.__dict__
|
||||
if hasattr(message_envelope.message, "__dict__")
|
||||
else message_envelope.message
|
||||
)
|
||||
logger.info(
|
||||
f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {recipient_name} from {message_envelope.sender.name}: {content}"
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message_envelope.message,
|
||||
# sender=message_envelope.sender,
|
||||
# receiver=message_envelope.recipient,
|
||||
# kind=MessageKind.RESPOND,
|
||||
# delivery_stage=DeliveryStage.DELIVER,
|
||||
# )
|
||||
# )
|
||||
message_envelope.future.set_result(message_envelope.message)
|
||||
|
||||
async def process_next(self) -> None:
|
||||
|
@ -1,6 +1,13 @@
|
||||
from ._events import LLMCallEvent
|
||||
from ._events import DeliveryStage, LLMCallEvent, MessageEvent, MessageKind
|
||||
from ._llm_usage import LLMUsageTracker
|
||||
|
||||
EVENT_LOGGER_NAME = "agnext.events"
|
||||
|
||||
__all__ = ["LLMCallEvent", "EVENT_LOGGER_NAME", "LLMUsageTracker"]
|
||||
__all__ = [
|
||||
"LLMCallEvent",
|
||||
"EVENT_LOGGER_NAME",
|
||||
"LLMUsageTracker",
|
||||
"MessageEvent",
|
||||
"MessageKind",
|
||||
"DeliveryStage",
|
||||
]
|
||||
|
@ -1,6 +1,9 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, cast
|
||||
|
||||
from ...core import Agent
|
||||
|
||||
|
||||
class LLMCallEvent:
|
||||
def __init__(self, *, prompt_tokens: int, completion_tokens: int, **kwargs: Any) -> None:
|
||||
@ -23,6 +26,50 @@ class LLMCallEvent:
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["prompt_tokens"] = prompt_tokens
|
||||
self.kwargs["completion_tokens"] = completion_tokens
|
||||
self.kwargs["type"] = "LLMCall"
|
||||
|
||||
@property
|
||||
def prompt_tokens(self) -> int:
|
||||
return cast(int, self.kwargs["prompt_tokens"])
|
||||
|
||||
@property
|
||||
def completion_tokens(self) -> int:
|
||||
return cast(int, self.kwargs["completion_tokens"])
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class MessageKind(Enum):
|
||||
DIRECT = 1
|
||||
PUBLISH = 2
|
||||
RESPOND = 3
|
||||
|
||||
|
||||
class DeliveryStage(Enum):
|
||||
SEND = 1
|
||||
DELIVER = 2
|
||||
|
||||
|
||||
class MessageEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
payload: Any,
|
||||
sender: Agent | None,
|
||||
receiver: Agent | None,
|
||||
kind: MessageKind,
|
||||
delivery_stage: DeliveryStage,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["payload"] = payload
|
||||
self.kwargs["sender"] = None if sender is None else sender.name
|
||||
self.kwargs["receiver"] = None if receiver is None else receiver.name
|
||||
self.kwargs["kind"] = kind
|
||||
self.kwargs["delivery_stage"] = delivery_stage
|
||||
self.kwargs["type"] = "Message"
|
||||
|
||||
@property
|
||||
def prompt_tokens(self) -> int:
|
||||
|
@ -128,7 +128,10 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
|
||||
continue
|
||||
# Execute the function.
|
||||
future = self.execute_function(
|
||||
function_call.name, arguments, function_call.id, cancellation_token=cancellation_token
|
||||
function_call.name,
|
||||
arguments,
|
||||
function_call.id,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
# Append the async result.
|
||||
execution_futures.append(future)
|
||||
@ -149,24 +152,22 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
|
||||
return tool_call_result_msg
|
||||
|
||||
async def execute_function(
|
||||
self, name: str, args: Dict[str, Any], call_id: str, cancellation_token: CancellationToken
|
||||
self,
|
||||
name: str,
|
||||
args: Dict[str, Any],
|
||||
call_id: str,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Tuple[str, str]:
|
||||
# Find tool
|
||||
tool = next((t for t in self._tools if t.name == name), None)
|
||||
if tool is None:
|
||||
raise ValueError(f"Tool {name} not found.")
|
||||
return (f"Error: tool {name} not found.", call_id)
|
||||
try:
|
||||
result = await tool.run_json(args, cancellation_token)
|
||||
result_json_or_str = result.model_dump()
|
||||
if isinstance(result, dict):
|
||||
result_str = json.dumps(result_json_or_str)
|
||||
elif isinstance(result_json_or_str, str):
|
||||
result_str = result_json_or_str
|
||||
else:
|
||||
raise ValueError(f"Unexpected result type: {type(result)}")
|
||||
result_as_str = tool.return_value_as_string(result)
|
||||
except Exception as e:
|
||||
result_str = f"Error: {str(e)}"
|
||||
return (result_str, call_id)
|
||||
result_as_str = f"Error: {str(e)}"
|
||||
return (result_as_str, call_id)
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
|
@ -73,17 +73,17 @@ def convert_messages_to_llm_messages(
|
||||
for message in messages:
|
||||
match message:
|
||||
case (
|
||||
TextMessage(_, source=source)
|
||||
| MultiModalMessage(_, source=source)
|
||||
| FunctionCallMessage(_, source=source)
|
||||
TextMessage(content=_, source=source)
|
||||
| MultiModalMessage(content=_, source=source)
|
||||
| FunctionCallMessage(content=_, source=source)
|
||||
) if source == self_name:
|
||||
converted_message_1 = convert_content_message_to_assistant_message(message, handle_unrepresentable)
|
||||
if converted_message_1 is not None:
|
||||
result.append(converted_message_1)
|
||||
case (
|
||||
TextMessage(_, source=source)
|
||||
| MultiModalMessage(_, source=source)
|
||||
| FunctionCallMessage(_, source=source)
|
||||
TextMessage(content=_, source=source)
|
||||
| MultiModalMessage(content=_, source=source)
|
||||
| FunctionCallMessage(content=_, source=source)
|
||||
) if source != self_name:
|
||||
converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
|
||||
if converted_message_2 is not None:
|
||||
|
@ -105,18 +105,18 @@ def message_handler(
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self: Any, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT:
|
||||
if strict:
|
||||
if type(message) not in target_types:
|
||||
if type(message) not in target_types:
|
||||
if strict:
|
||||
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
|
||||
else:
|
||||
logger.warning(f"Message type {type(message)} not in target types {target_types}")
|
||||
|
||||
return_value = await func(self, message, cancellation_token)
|
||||
|
||||
if strict:
|
||||
if return_value is not AnyType and type(return_value) not in return_types:
|
||||
if AnyType not in return_types and type(return_value) not in return_types:
|
||||
if strict:
|
||||
raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
elif return_value is not AnyType:
|
||||
else:
|
||||
logger.warning(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
|
||||
return return_value
|
||||
@ -134,7 +134,10 @@ def message_handler(
|
||||
class TypeRoutedAgent(BaseAgent):
|
||||
def __init__(self, name: str, router: AgentRuntime) -> None:
|
||||
# Self is already bound to the handlers
|
||||
self._handlers: Dict[Type[Any], Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]]] = {}
|
||||
self._handlers: Dict[
|
||||
Type[Any],
|
||||
Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]],
|
||||
] = {}
|
||||
|
||||
for attr in dir(self):
|
||||
if callable(getattr(self, attr, None)):
|
||||
|
@ -1,5 +1,6 @@
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
@ -11,6 +12,7 @@ from typing import (
|
||||
Sequence,
|
||||
Set,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
@ -27,6 +29,7 @@ from openai.types.chat import (
|
||||
ChatCompletionUserMessageParam,
|
||||
completion_create_params,
|
||||
)
|
||||
from openai.types.shared_params import FunctionDefinition, FunctionParameters
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from ...application.logging import EVENT_LOGGER_NAME, LLMCallEvent
|
||||
@ -205,15 +208,47 @@ def convert_tools(
|
||||
) -> List[ChatCompletionToolParam]:
|
||||
result: List[ChatCompletionToolParam] = []
|
||||
for tool in tools:
|
||||
tool_schema = tool.schema
|
||||
result.append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": tool.schema, # type: ignore
|
||||
}
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
function=FunctionDefinition(
|
||||
name=tool_schema["name"],
|
||||
description=tool_schema["description"] if "description" in tool_schema else "",
|
||||
parameters=cast(FunctionParameters, tool_schema["parameters"])
|
||||
if "parameters" in tool_schema
|
||||
else {},
|
||||
),
|
||||
)
|
||||
)
|
||||
# Check if all tools have valid names.
|
||||
for tool_param in result:
|
||||
assert_valid_name(tool_param["function"]["name"])
|
||||
return result
|
||||
|
||||
|
||||
def normalize_name(name: str) -> str:
|
||||
"""
|
||||
LLMs sometimes ask functions while ignoring their own format requirements, this function should be used to replace invalid characters with "_".
|
||||
|
||||
Prefer _assert_valid_name for validating user configuration or input
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]
|
||||
|
||||
|
||||
def assert_valid_name(name: str) -> str:
|
||||
"""
|
||||
Ensure that configured names are valid, raises ValueError if not.
|
||||
|
||||
For munging LLM responses use _normalize_name to ensure LLM specified names don't break the API.
|
||||
"""
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", name):
|
||||
raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
|
||||
if len(name) > 64:
|
||||
raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
|
||||
return name
|
||||
|
||||
|
||||
class BaseOpenAI(ChatCompletionClient):
|
||||
def __init__(
|
||||
self,
|
||||
@ -293,7 +328,10 @@ class BaseOpenAI(ChatCompletionClient):
|
||||
if len(tools) > 0:
|
||||
converted_tools = convert_tools(tools)
|
||||
result = await self._client.chat.completions.create(
|
||||
messages=oai_messages, stream=False, tools=converted_tools, **create_args
|
||||
messages=oai_messages,
|
||||
stream=False,
|
||||
tools=converted_tools,
|
||||
**create_args,
|
||||
)
|
||||
else:
|
||||
result = await self._client.chat.completions.create(messages=oai_messages, stream=False, **create_args)
|
||||
@ -331,7 +369,11 @@ class BaseOpenAI(ChatCompletionClient):
|
||||
|
||||
# NOTE: If OAI response type changes, this will need to be updated
|
||||
content = [
|
||||
FunctionCall(id=x.id, arguments=x.function.arguments, name=x.function.name)
|
||||
FunctionCall(
|
||||
id=x.id,
|
||||
arguments=x.function.arguments,
|
||||
name=normalize_name(x.function.name),
|
||||
)
|
||||
for x in choice.message.tool_calls
|
||||
]
|
||||
finish_reason = "function_calls"
|
||||
|
@ -1,14 +1,29 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypeVar
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from ...core import CancellationToken
|
||||
from .._function_utils import normalize_annotated_type
|
||||
|
||||
T = TypeVar("T", bound=BaseModel, contravariant=True)
|
||||
|
||||
|
||||
class ParametersSchema(TypedDict):
|
||||
type: str
|
||||
properties: Dict[str, Any]
|
||||
required: NotRequired[Sequence[str]]
|
||||
|
||||
|
||||
class ToolSchema(TypedDict):
|
||||
parameters: NotRequired[ParametersSchema]
|
||||
name: str
|
||||
description: NotRequired[str]
|
||||
|
||||
|
||||
class Tool(Protocol):
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
@ -17,7 +32,7 @@ class Tool(Protocol):
|
||||
def description(self) -> str: ...
|
||||
|
||||
@property
|
||||
def schema(self) -> Mapping[str, Any]: ...
|
||||
def schema(self) -> ToolSchema: ...
|
||||
|
||||
def args_type(self) -> Type[BaseModel]: ...
|
||||
|
||||
@ -40,20 +55,36 @@ StateT = TypeVar("StateT", bound=BaseModel)
|
||||
|
||||
|
||||
class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):
|
||||
def __init__(self, args_type: Type[ArgsT], return_type: Type[ReturnT], name: str, description: str) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
args_type: Type[ArgsT],
|
||||
return_type: Type[ReturnT],
|
||||
name: str,
|
||||
description: str,
|
||||
) -> None:
|
||||
self._args_type = args_type
|
||||
self._return_type = return_type
|
||||
# Normalize Annotated to the base type.
|
||||
self._return_type = normalize_annotated_type(return_type)
|
||||
self._name = name
|
||||
self._description = description
|
||||
|
||||
@property
|
||||
def schema(self) -> Mapping[str, Any]:
|
||||
def schema(self) -> ToolSchema:
|
||||
model_schema = self._args_type.model_json_schema()
|
||||
parameter_schema: Dict[str, Any] = dict()
|
||||
parameter_schema["parameters"] = model_schema["properties"]
|
||||
parameter_schema["name"] = self._name
|
||||
parameter_schema["description"] = self._description
|
||||
return parameter_schema
|
||||
|
||||
tool_schema = ToolSchema(
|
||||
name=self._name,
|
||||
description=self._description,
|
||||
parameters=ParametersSchema(
|
||||
type="object",
|
||||
properties=model_schema["properties"],
|
||||
),
|
||||
)
|
||||
if "required" in model_schema:
|
||||
assert "parameters" in tool_schema
|
||||
tool_schema["parameters"]["required"] = model_schema["required"]
|
||||
|
||||
return tool_schema
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@ -97,7 +128,12 @@ class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):
|
||||
|
||||
class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT]):
|
||||
def __init__(
|
||||
self, args_type: Type[ArgsT], return_type: Type[ReturnT], state_type: Type[StateT], name: str, description: str
|
||||
self,
|
||||
args_type: Type[ArgsT],
|
||||
return_type: Type[ReturnT],
|
||||
state_type: Type[StateT],
|
||||
name: str,
|
||||
description: str,
|
||||
) -> None:
|
||||
super().__init__(args_type, return_type, name, description)
|
||||
self._state_type = state_type
|
||||
|
@ -32,7 +32,12 @@ class FunctionTool(BaseTool[BaseModel, BaseModel]):
|
||||
else:
|
||||
if self._has_cancellation_support:
|
||||
result = await asyncio.get_event_loop().run_in_executor(
|
||||
None, functools.partial(self._func, **args.model_dump(), cancellation_token=cancellation_token)
|
||||
None,
|
||||
functools.partial(
|
||||
self._func,
|
||||
**args.model_dump(),
|
||||
cancellation_token=cancellation_token,
|
||||
),
|
||||
)
|
||||
else:
|
||||
future = asyncio.get_event_loop().run_in_executor(
|
||||
|
@ -36,9 +36,54 @@ def test_tool_schema_generation() -> None:
|
||||
schema = MyTool().schema
|
||||
|
||||
assert schema["name"] == "TestTool"
|
||||
assert "description" in schema
|
||||
assert schema["description"] == "Description of test tool."
|
||||
assert schema["parameters"]["query"]["description"] == "The description."
|
||||
assert len(schema["parameters"]) == 1
|
||||
assert "parameters" in schema
|
||||
assert schema["parameters"]["type"] == "object"
|
||||
assert "properties" in schema["parameters"]
|
||||
assert schema["parameters"]["properties"]["query"]["description"] == "The description."
|
||||
assert schema["parameters"]["properties"]["query"]["type"] == "string"
|
||||
assert "required" in schema["parameters"]
|
||||
assert schema["parameters"]["required"] == ["query"]
|
||||
assert len(schema["parameters"]["properties"]) == 1
|
||||
|
||||
def test_func_tool_schema_generation() -> None:
|
||||
def my_function(arg: str, other: Annotated[int, "int arg"], nonrequired: int = 5) -> MyResult:
|
||||
return MyResult(result="test")
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
schema = tool.schema
|
||||
|
||||
assert schema["name"] == "my_function"
|
||||
assert "description" in schema
|
||||
assert schema["description"] == "Function tool."
|
||||
assert "parameters" in schema
|
||||
assert schema["parameters"]["type"] == "object"
|
||||
assert schema["parameters"]["properties"].keys() == {"arg", "other", "nonrequired"}
|
||||
assert schema["parameters"]["properties"]["arg"]["type"] == "string"
|
||||
assert schema["parameters"]["properties"]["arg"]["description"] == "arg"
|
||||
assert schema["parameters"]["properties"]["other"]["type"] == "integer"
|
||||
assert schema["parameters"]["properties"]["other"]["description"] == "int arg"
|
||||
assert schema["parameters"]["properties"]["nonrequired"]["type"] == "integer"
|
||||
assert schema["parameters"]["properties"]["nonrequired"]["description"] == "nonrequired"
|
||||
assert "required" in schema["parameters"]
|
||||
assert schema["parameters"]["required"] == ["arg", "other"]
|
||||
assert len(schema["parameters"]["properties"]) == 3
|
||||
|
||||
def test_func_tool_schema_generation_only_default_arg() -> None:
|
||||
def my_function(arg: str = "default") -> MyResult:
|
||||
return MyResult(result="test")
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
schema = tool.schema
|
||||
|
||||
assert schema["name"] == "my_function"
|
||||
assert "description" in schema
|
||||
assert schema["description"] == "Function tool."
|
||||
assert "parameters" in schema
|
||||
assert len(schema["parameters"]["properties"]) == 1
|
||||
assert schema["parameters"]["properties"]["arg"]["type"] == "string"
|
||||
assert schema["parameters"]["properties"]["arg"]["description"] == "arg"
|
||||
assert "required" not in schema["parameters"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_run()-> None:
|
||||
@ -128,7 +173,7 @@ def test_func_tool_return_annotated()-> None:
|
||||
assert tool.name == "my_function"
|
||||
assert tool.description == "Function tool."
|
||||
assert issubclass(tool.args_type(), BaseModel)
|
||||
assert tool.return_type() == Annotated[str, "test description"]
|
||||
assert tool.return_type() == str
|
||||
assert tool.state_type() is None
|
||||
|
||||
def test_func_tool_no_args()-> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user