From b4ade8b735bca17f0794ccecb16e3db07905ebbf Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Fri, 7 Jun 2024 13:33:51 -0700 Subject: [PATCH] Add examples to showcase patterns (#55) * add chess example * wip * wip * fix tool schema generation * fixes * Agent handle exception Co-authored-by: Jack Gerrits * format * mypy * fix test for annotated --------- Co-authored-by: Jack Gerrits --- examples/chess_game.py | 119 ++++++++++++++++++ pyproject.toml | 3 + .../_single_threaded_agent_runtime.py | 70 ++++++++++- src/agnext/application/logging/__init__.py | 11 +- src/agnext/application/logging/_events.py | 47 +++++++ .../chat/agents/chat_completion_agent.py | 25 ++-- src/agnext/chat/utils.py | 12 +- src/agnext/components/_type_routed_agent.py | 15 ++- .../components/models/_openai_client.py | 54 +++++++- src/agnext/components/tools/_base.py | 58 +++++++-- src/agnext/components/tools/_function_tool.py | 7 +- tests/test_tools.py | 51 +++++++- 12 files changed, 420 insertions(+), 52 deletions(-) create mode 100644 examples/chess_game.py diff --git a/examples/chess_game.py b/examples/chess_game.py new file mode 100644 index 000000000..119c1abb0 --- /dev/null +++ b/examples/chess_game.py @@ -0,0 +1,119 @@ +"""This is an example of simulating a chess game with two agents +that play against each other, using tools to reason about the game state +and make moves.""" + +import argparse +import asyncio +import logging +from typing import Annotated + +from agnext.application import SingleThreadedAgentRuntime +from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent +from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput +from agnext.chat.types import TextMessage +from agnext.components.models import OpenAI, SystemMessage +from agnext.components.tools import FunctionTool +from agnext.core import AgentRuntime +from chess import SQUARE_NAMES, Board, Move +from chess import piece_name as get_piece_name + +logging.basicConfig(level=logging.WARNING) +logging.getLogger("agnext").setLevel(logging.DEBUG) + + +class ChessGameOutput(GroupChatOutput): # type: ignore + def on_message_received(self, message: TextMessage) -> None: # type: ignore + pass + + def get_output(self) -> None: + pass + + def reset(self) -> None: + pass + + +def chess_game(runtime: AgentRuntime) -> GroupChat: # type: ignore + """Create agents for a chess game and return the group chat.""" + + # Create the board. + board = Board() + + # Create shared tools. + def get_legal_moves() -> Annotated[str, "A list of legal moves in UCI format."]: + return "Possible moves are: " + ", ".join([str(move) for move in board.legal_moves]) + + get_legal_moves_tool = FunctionTool(get_legal_moves, description="Get legal moves.") + + def make_move(input: Annotated[str, "A move in UCI format."]) -> Annotated[str, "Result of the move."]: + move = Move.from_uci(input) + board.push(move) + print(board.unicode(borders=True)) + # Get the piece name. + piece = board.piece_at(move.to_square) + assert piece is not None + piece_symbol = piece.unicode_symbol() + piece_name = get_piece_name(piece.piece_type) + if piece_symbol.isupper(): + piece_name = piece_name.capitalize() + return f"Moved {piece_name} ({piece_symbol}) from {SQUARE_NAMES[move.from_square]} to {SQUARE_NAMES[move.to_square]}." 
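
The make_move and get_legal_moves helpers above are wrapped with FunctionTool just below; the Annotated metadata on their parameters and return values is what ends up in the generated tool schema. As a standalone sketch of that round trip (the roll_dice tool is invented for illustration; the FunctionTool API and schema layout follow what tests/test_tools.py asserts later in this patch):

    from typing import Annotated

    from agnext.components.tools import FunctionTool


    def roll_dice(sides: Annotated[int, "Number of sides on the die."] = 6) -> Annotated[str, "The result of the roll."]:
        # Deterministic stand-in so the sketch has no extra dependencies.
        return f"The die came up {max(1, sides // 2)} out of {sides}."


    dice_tool = FunctionTool(roll_dice, description="Roll a die.")
    print(dice_tool.schema)
    # Roughly:
    # {"name": "roll_dice", "description": "Roll a die.",
    #  "parameters": {"type": "object",
    #                 "properties": {"sides": {"type": "integer",
    #                                          "description": "Number of sides on the die.", ...}}}}
    # "required" is omitted because "sides" has a default value.
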
+ + make_move_tool = FunctionTool(make_move, description="Call this tool to make a move.") + + tools = [get_legal_moves_tool, make_move_tool] + + black = ChatCompletionAgent( + name="PlayerBlack", + description="Player playing black.", + runtime=runtime, + system_messages=[ + SystemMessage( + content="You are a chess player and you play as black. " + "First call get_legal_moves() first, to get list of legal moves. " + "Then call make_move(move) to make a move." + ), + ], + model_client=OpenAI(model="gpt-4-turbo"), + tools=tools, + ) + white = ChatCompletionAgent( + name="PlayerWhite", + description="Player playing white.", + runtime=runtime, + system_messages=[ + SystemMessage( + content="You are a chess player and you play as white. " + "First call get_legal_moves() first, to get list of legal moves. " + "Then call make_move(move) to make a move." + ), + ], + model_client=OpenAI(model="gpt-4-turbo"), + tools=tools, + ) + game_chat = GroupChat( + name="ChessGame", + description="A chess game between two agents.", + runtime=runtime, + agents=[white, black], + num_rounds=10, + output=ChessGameOutput(), + ) + return game_chat + + +async def main(message: str) -> None: + runtime = SingleThreadedAgentRuntime() + game_chat = chess_game(runtime) + future = runtime.send_message(TextMessage(content=message, source="Human"), game_chat) + while not future.done(): + await runtime.process_next() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run a chess game between two agents.") + parser.add_argument( + "--initial-message", + default="Please make a move.", + help="The initial message to send to the agent playing white.", + ) + args = parser.parse_args() + asyncio.run(main(args.initial_message)) diff --git a/pyproject.toml b/pyproject.toml index 0a6dc535a..e046e6dda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,9 @@ dev = [ "pytest-xdist", "types-Pillow", "polars", + # Dependencies for the examples. 
+ "chess", + "tavily-python", ] docs = ["sphinx", "furo", "sphinxcontrib-apidoc", "myst-parser"] diff --git a/src/agnext/application/_single_threaded_agent_runtime.py b/src/agnext/application/_single_threaded_agent_runtime.py index 4f4260905..2ce46df2e 100644 --- a/src/agnext/application/_single_threaded_agent_runtime.py +++ b/src/agnext/application/_single_threaded_agent_runtime.py @@ -10,6 +10,7 @@ from ..core.exceptions import MessageDroppedException from ..core.intervention import DropMessage, InterventionHandler logger = logging.getLogger("agnext") +event_logger = logging.getLogger("agnext.events") @dataclass(kw_only=True) @@ -67,7 +68,9 @@ class SingleThreadedAgentRuntime(AgentRuntime): return list(self._agents) @property - def unprocessed_messages(self) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]: + def unprocessed_messages( + self, + ) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]: return self._message_queue # Returns the response of the message @@ -82,6 +85,18 @@ class SingleThreadedAgentRuntime(AgentRuntime): if cancellation_token is None: cancellation_token = CancellationToken() + logger.info(f"Sending message of type {type(message).__name__} to {recipient.name}: {message.__dict__}") + + # event_logger.info( + # MessageEvent( + # payload=message, + # sender=sender, + # receiver=recipient, + # kind=MessageKind.DIRECT, + # delivery_stage=DeliveryStage.SEND, + # ) + # ) + future = asyncio.get_event_loop().create_future() if recipient not in self._agents: future.set_exception(Exception("Recipient not found")) @@ -108,6 +123,18 @@ class SingleThreadedAgentRuntime(AgentRuntime): if cancellation_token is None: cancellation_token = CancellationToken() + logger.info(f"Publishing message of type {type(message).__name__} to all subscribers: {message.__dict__}") + + # event_logger.info( + # MessageEvent( + # payload=message, + # sender=sender, + # receiver=None, + # kind=MessageKind.PUBLISH, + # delivery_stage=DeliveryStage.SEND, + # ) + # ) + self._message_queue.append( PublishMessageEnvelope( message=message, @@ -137,8 +164,17 @@ class SingleThreadedAgentRuntime(AgentRuntime): try: sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown" logger.info( - f"Calling message handler for {recipient.name} with message type {type(message_envelope.message).__name__} from {sender_name}" + f"Calling message handler for {recipient.name} with message type {type(message_envelope.message).__name__} sent by {sender_name}" ) + # event_logger.info( + # MessageEvent( + # payload=message_envelope.message, + # sender=message_envelope.sender, + # receiver=recipient, + # kind=MessageKind.DIRECT, + # delivery_stage=DeliveryStage.DELIVER, + # ) + # ) response = await recipient.on_message( message_envelope.message, cancellation_token=message_envelope.cancellation_token, @@ -162,9 +198,19 @@ class SingleThreadedAgentRuntime(AgentRuntime): if message_envelope.sender is not None and agent.name == message_envelope.sender.name: continue + sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown" logger.info( - f"Calling message handler for {agent.name} with message type {type(message_envelope.message).__name__} published by {message_envelope.sender.name if message_envelope.sender is not None else 'Unknown'}" + f"Calling message handler for {agent.name} with message type {type(message_envelope.message).__name__} published by {sender_name}" ) + # event_logger.info( + 
# MessageEvent( + # payload=message_envelope.message, + # sender=message_envelope.sender, + # receiver=agent, + # kind=MessageKind.PUBLISH, + # delivery_stage=DeliveryStage.DELIVER, + # ) + # ) future = agent.on_message( message_envelope.message, @@ -182,9 +228,23 @@ class SingleThreadedAgentRuntime(AgentRuntime): async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None: recipient_name = message_envelope.recipient.name if message_envelope.recipient is not None else "Unknown" - logger.info( - f"Resolving response for recipient {recipient_name} from {message_envelope.sender.name} with message type {type(message_envelope.message).__name__}" + content = ( + message_envelope.message.__dict__ + if hasattr(message_envelope.message, "__dict__") + else message_envelope.message ) + logger.info( + f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {recipient_name} from {message_envelope.sender.name}: {content}" + ) + # event_logger.info( + # MessageEvent( + # payload=message_envelope.message, + # sender=message_envelope.sender, + # receiver=message_envelope.recipient, + # kind=MessageKind.RESPOND, + # delivery_stage=DeliveryStage.DELIVER, + # ) + # ) message_envelope.future.set_result(message_envelope.message) async def process_next(self) -> None: diff --git a/src/agnext/application/logging/__init__.py b/src/agnext/application/logging/__init__.py index e9940925e..1696599c5 100644 --- a/src/agnext/application/logging/__init__.py +++ b/src/agnext/application/logging/__init__.py @@ -1,6 +1,13 @@ -from ._events import LLMCallEvent +from ._events import DeliveryStage, LLMCallEvent, MessageEvent, MessageKind from ._llm_usage import LLMUsageTracker EVENT_LOGGER_NAME = "agnext.events" -__all__ = ["LLMCallEvent", "EVENT_LOGGER_NAME", "LLMUsageTracker"] +__all__ = [ + "LLMCallEvent", + "EVENT_LOGGER_NAME", + "LLMUsageTracker", + "MessageEvent", + "MessageKind", + "DeliveryStage", +] diff --git a/src/agnext/application/logging/_events.py b/src/agnext/application/logging/_events.py index ae473b1f2..cec85b43a 100644 --- a/src/agnext/application/logging/_events.py +++ b/src/agnext/application/logging/_events.py @@ -1,6 +1,9 @@ import json +from enum import Enum from typing import Any, cast +from ...core import Agent + class LLMCallEvent: def __init__(self, *, prompt_tokens: int, completion_tokens: int, **kwargs: Any) -> None: @@ -23,6 +26,50 @@ class LLMCallEvent: self.kwargs = kwargs self.kwargs["prompt_tokens"] = prompt_tokens self.kwargs["completion_tokens"] = completion_tokens + self.kwargs["type"] = "LLMCall" + + @property + def prompt_tokens(self) -> int: + return cast(int, self.kwargs["prompt_tokens"]) + + @property + def completion_tokens(self) -> int: + return cast(int, self.kwargs["completion_tokens"]) + + # This must output the event in a json serializable format + def __str__(self) -> str: + return json.dumps(self.kwargs) + + +class MessageKind(Enum): + DIRECT = 1 + PUBLISH = 2 + RESPOND = 3 + + +class DeliveryStage(Enum): + SEND = 1 + DELIVER = 2 + + +class MessageEvent: + def __init__( + self, + *, + payload: Any, + sender: Agent | None, + receiver: Agent | None, + kind: MessageKind, + delivery_stage: DeliveryStage, + **kwargs: Any, + ) -> None: + self.kwargs = kwargs + self.kwargs["payload"] = payload + self.kwargs["sender"] = None if sender is None else sender.name + self.kwargs["receiver"] = None if receiver is None else receiver.name + self.kwargs["kind"] = kind + self.kwargs["delivery_stage"] = delivery_stage + 
self.kwargs["type"] = "Message" @property def prompt_tokens(self) -> int: diff --git a/src/agnext/chat/agents/chat_completion_agent.py b/src/agnext/chat/agents/chat_completion_agent.py index 22ae7c754..c6a1fcda3 100644 --- a/src/agnext/chat/agents/chat_completion_agent.py +++ b/src/agnext/chat/agents/chat_completion_agent.py @@ -128,7 +128,10 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent): continue # Execute the function. future = self.execute_function( - function_call.name, arguments, function_call.id, cancellation_token=cancellation_token + function_call.name, + arguments, + function_call.id, + cancellation_token=cancellation_token, ) # Append the async result. execution_futures.append(future) @@ -149,24 +152,22 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent): return tool_call_result_msg async def execute_function( - self, name: str, args: Dict[str, Any], call_id: str, cancellation_token: CancellationToken + self, + name: str, + args: Dict[str, Any], + call_id: str, + cancellation_token: CancellationToken, ) -> Tuple[str, str]: # Find tool tool = next((t for t in self._tools if t.name == name), None) if tool is None: - raise ValueError(f"Tool {name} not found.") + return (f"Error: tool {name} not found.", call_id) try: result = await tool.run_json(args, cancellation_token) - result_json_or_str = result.model_dump() - if isinstance(result, dict): - result_str = json.dumps(result_json_or_str) - elif isinstance(result_json_or_str, str): - result_str = result_json_or_str - else: - raise ValueError(f"Unexpected result type: {type(result)}") + result_as_str = tool.return_value_as_string(result) except Exception as e: - result_str = f"Error: {str(e)}" - return (result_str, call_id) + result_as_str = f"Error: {str(e)}" + return (result_as_str, call_id) def save_state(self) -> Mapping[str, Any]: return { diff --git a/src/agnext/chat/utils.py b/src/agnext/chat/utils.py index 19cba873b..cfee683a7 100644 --- a/src/agnext/chat/utils.py +++ b/src/agnext/chat/utils.py @@ -73,17 +73,17 @@ def convert_messages_to_llm_messages( for message in messages: match message: case ( - TextMessage(_, source=source) - | MultiModalMessage(_, source=source) - | FunctionCallMessage(_, source=source) + TextMessage(content=_, source=source) + | MultiModalMessage(content=_, source=source) + | FunctionCallMessage(content=_, source=source) ) if source == self_name: converted_message_1 = convert_content_message_to_assistant_message(message, handle_unrepresentable) if converted_message_1 is not None: result.append(converted_message_1) case ( - TextMessage(_, source=source) - | MultiModalMessage(_, source=source) - | FunctionCallMessage(_, source=source) + TextMessage(content=_, source=source) + | MultiModalMessage(content=_, source=source) + | FunctionCallMessage(content=_, source=source) ) if source != self_name: converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable) if converted_message_2 is not None: diff --git a/src/agnext/components/_type_routed_agent.py b/src/agnext/components/_type_routed_agent.py index 74fc0360c..88bbdb35f 100644 --- a/src/agnext/components/_type_routed_agent.py +++ b/src/agnext/components/_type_routed_agent.py @@ -105,18 +105,18 @@ def message_handler( @wraps(func) async def wrapper(self: Any, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT: - if strict: - if type(message) not in target_types: + if type(message) not in target_types: + if strict: raise CantHandleException(f"Message type {type(message)} not in 
target types {target_types}") else: logger.warning(f"Message type {type(message)} not in target types {target_types}") return_value = await func(self, message, cancellation_token) - if strict: - if return_value is not AnyType and type(return_value) not in return_types: + if AnyType not in return_types and type(return_value) not in return_types: + if strict: raise ValueError(f"Return type {type(return_value)} not in return types {return_types}") - elif return_value is not AnyType: + else: logger.warning(f"Return type {type(return_value)} not in return types {return_types}") return return_value @@ -134,7 +134,10 @@ def message_handler( class TypeRoutedAgent(BaseAgent): def __init__(self, name: str, router: AgentRuntime) -> None: # Self is already bound to the handlers - self._handlers: Dict[Type[Any], Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]]] = {} + self._handlers: Dict[ + Type[Any], + Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]], + ] = {} for attr in dir(self): if callable(getattr(self, attr, None)): diff --git a/src/agnext/components/models/_openai_client.py b/src/agnext/components/models/_openai_client.py index 101aa8bdf..34cd5e933 100644 --- a/src/agnext/components/models/_openai_client.py +++ b/src/agnext/components/models/_openai_client.py @@ -1,5 +1,6 @@ import inspect import logging +import re import warnings from typing import ( Any, @@ -11,6 +12,7 @@ from typing import ( Sequence, Set, Union, + cast, ) from openai import AsyncAzureOpenAI, AsyncOpenAI @@ -27,6 +29,7 @@ from openai.types.chat import ( ChatCompletionUserMessageParam, completion_create_params, ) +from openai.types.shared_params import FunctionDefinition, FunctionParameters from typing_extensions import Unpack from ...application.logging import EVENT_LOGGER_NAME, LLMCallEvent @@ -205,15 +208,47 @@ def convert_tools( ) -> List[ChatCompletionToolParam]: result: List[ChatCompletionToolParam] = [] for tool in tools: + tool_schema = tool.schema result.append( - { - "type": "function", - "function": tool.schema, # type: ignore - } + ChatCompletionToolParam( + type="function", + function=FunctionDefinition( + name=tool_schema["name"], + description=tool_schema["description"] if "description" in tool_schema else "", + parameters=cast(FunctionParameters, tool_schema["parameters"]) + if "parameters" in tool_schema + else {}, + ), + ) ) + # Check if all tools have valid names. + for tool_param in result: + assert_valid_name(tool_param["function"]["name"]) return result +def normalize_name(name: str) -> str: + """ + LLMs sometimes ask functions while ignoring their own format requirements, this function should be used to replace invalid characters with "_". + + Prefer _assert_valid_name for validating user configuration or input + """ + return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64] + + +def assert_valid_name(name: str) -> str: + """ + Ensure that configured names are valid, raises ValueError if not. + + For munging LLM responses use _normalize_name to ensure LLM specified names don't break the API. + """ + if not re.match(r"^[a-zA-Z0-9_-]+$", name): + raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.") + if len(name) > 64: + raise ValueError(f"Invalid name: {name}. 
Name must be less than 64 characters.") + return name + + class BaseOpenAI(ChatCompletionClient): def __init__( self, @@ -293,7 +328,10 @@ class BaseOpenAI(ChatCompletionClient): if len(tools) > 0: converted_tools = convert_tools(tools) result = await self._client.chat.completions.create( - messages=oai_messages, stream=False, tools=converted_tools, **create_args + messages=oai_messages, + stream=False, + tools=converted_tools, + **create_args, ) else: result = await self._client.chat.completions.create(messages=oai_messages, stream=False, **create_args) @@ -331,7 +369,11 @@ class BaseOpenAI(ChatCompletionClient): # NOTE: If OAI response type changes, this will need to be updated content = [ - FunctionCall(id=x.id, arguments=x.function.arguments, name=x.function.name) + FunctionCall( + id=x.id, + arguments=x.function.arguments, + name=normalize_name(x.function.name), + ) for x in choice.message.tool_calls ] finish_reason = "function_calls" diff --git a/src/agnext/components/tools/_base.py b/src/agnext/components/tools/_base.py index ff922c51b..36d7c83c4 100644 --- a/src/agnext/components/tools/_base.py +++ b/src/agnext/components/tools/_base.py @@ -1,14 +1,29 @@ import json from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypeVar +from collections.abc import Sequence +from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar from pydantic import BaseModel +from typing_extensions import NotRequired from ...core import CancellationToken +from .._function_utils import normalize_annotated_type T = TypeVar("T", bound=BaseModel, contravariant=True) +class ParametersSchema(TypedDict): + type: str + properties: Dict[str, Any] + required: NotRequired[Sequence[str]] + + +class ToolSchema(TypedDict): + parameters: NotRequired[ParametersSchema] + name: str + description: NotRequired[str] + + class Tool(Protocol): @property def name(self) -> str: ... @@ -17,7 +32,7 @@ class Tool(Protocol): def description(self) -> str: ... @property - def schema(self) -> Mapping[str, Any]: ... + def schema(self) -> ToolSchema: ... def args_type(self) -> Type[BaseModel]: ... @@ -40,20 +55,36 @@ StateT = TypeVar("StateT", bound=BaseModel) class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]): - def __init__(self, args_type: Type[ArgsT], return_type: Type[ReturnT], name: str, description: str) -> None: + def __init__( + self, + args_type: Type[ArgsT], + return_type: Type[ReturnT], + name: str, + description: str, + ) -> None: self._args_type = args_type - self._return_type = return_type + # Normalize Annotated to the base type. 
+ self._return_type = normalize_annotated_type(return_type) self._name = name self._description = description @property - def schema(self) -> Mapping[str, Any]: + def schema(self) -> ToolSchema: model_schema = self._args_type.model_json_schema() - parameter_schema: Dict[str, Any] = dict() - parameter_schema["parameters"] = model_schema["properties"] - parameter_schema["name"] = self._name - parameter_schema["description"] = self._description - return parameter_schema + + tool_schema = ToolSchema( + name=self._name, + description=self._description, + parameters=ParametersSchema( + type="object", + properties=model_schema["properties"], + ), + ) + if "required" in model_schema: + assert "parameters" in tool_schema + tool_schema["parameters"]["required"] = model_schema["required"] + + return tool_schema @property def name(self) -> str: @@ -97,7 +128,12 @@ class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]): class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT]): def __init__( - self, args_type: Type[ArgsT], return_type: Type[ReturnT], state_type: Type[StateT], name: str, description: str + self, + args_type: Type[ArgsT], + return_type: Type[ReturnT], + state_type: Type[StateT], + name: str, + description: str, ) -> None: super().__init__(args_type, return_type, name, description) self._state_type = state_type diff --git a/src/agnext/components/tools/_function_tool.py b/src/agnext/components/tools/_function_tool.py index caf4575de..f0536b9bb 100644 --- a/src/agnext/components/tools/_function_tool.py +++ b/src/agnext/components/tools/_function_tool.py @@ -32,7 +32,12 @@ class FunctionTool(BaseTool[BaseModel, BaseModel]): else: if self._has_cancellation_support: result = await asyncio.get_event_loop().run_in_executor( - None, functools.partial(self._func, **args.model_dump(), cancellation_token=cancellation_token) + None, + functools.partial( + self._func, + **args.model_dump(), + cancellation_token=cancellation_token, + ), ) else: future = asyncio.get_event_loop().run_in_executor( diff --git a/tests/test_tools.py b/tests/test_tools.py index 6f7c76309..c5e52f19a 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -36,9 +36,54 @@ def test_tool_schema_generation() -> None: schema = MyTool().schema assert schema["name"] == "TestTool" + assert "description" in schema assert schema["description"] == "Description of test tool." - assert schema["parameters"]["query"]["description"] == "The description." - assert len(schema["parameters"]) == 1 + assert "parameters" in schema + assert schema["parameters"]["type"] == "object" + assert "properties" in schema["parameters"] + assert schema["parameters"]["properties"]["query"]["description"] == "The description." + assert schema["parameters"]["properties"]["query"]["type"] == "string" + assert "required" in schema["parameters"] + assert schema["parameters"]["required"] == ["query"] + assert len(schema["parameters"]["properties"]) == 1 + +def test_func_tool_schema_generation() -> None: + def my_function(arg: str, other: Annotated[int, "int arg"], nonrequired: int = 5) -> MyResult: + return MyResult(result="test") + tool = FunctionTool(my_function, description="Function tool.") + schema = tool.schema + + assert schema["name"] == "my_function" + assert "description" in schema + assert schema["description"] == "Function tool." 
+ assert "parameters" in schema + assert schema["parameters"]["type"] == "object" + assert schema["parameters"]["properties"].keys() == {"arg", "other", "nonrequired"} + assert schema["parameters"]["properties"]["arg"]["type"] == "string" + assert schema["parameters"]["properties"]["arg"]["description"] == "arg" + assert schema["parameters"]["properties"]["other"]["type"] == "integer" + assert schema["parameters"]["properties"]["other"]["description"] == "int arg" + assert schema["parameters"]["properties"]["nonrequired"]["type"] == "integer" + assert schema["parameters"]["properties"]["nonrequired"]["description"] == "nonrequired" + assert "required" in schema["parameters"] + assert schema["parameters"]["required"] == ["arg", "other"] + assert len(schema["parameters"]["properties"]) == 3 + +def test_func_tool_schema_generation_only_default_arg() -> None: + def my_function(arg: str = "default") -> MyResult: + return MyResult(result="test") + tool = FunctionTool(my_function, description="Function tool.") + schema = tool.schema + + assert schema["name"] == "my_function" + assert "description" in schema + assert schema["description"] == "Function tool." + assert "parameters" in schema + assert len(schema["parameters"]["properties"]) == 1 + assert schema["parameters"]["properties"]["arg"]["type"] == "string" + assert schema["parameters"]["properties"]["arg"]["description"] == "arg" + assert "required" not in schema["parameters"] + @pytest.mark.asyncio async def test_tool_run()-> None: @@ -128,7 +173,7 @@ def test_func_tool_return_annotated()-> None: assert tool.name == "my_function" assert tool.description == "Function tool." assert issubclass(tool.args_type(), BaseModel) - assert tool.return_type() == Annotated[str, "test description"] + assert tool.return_type() == str assert tool.state_type() is None def test_func_tool_no_args()-> None:
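
The runtime changes in this patch mean that every direct send, delivery, and response is now logged at INFO level on the "agnext" logger (the structured MessageEvent emission is added but still commented out). A rough sketch of driving the single-threaded runtime directly, using the same loop as examples/chess_game.py; the "Assistant" agent and its system message are invented here, an OPENAI_API_KEY is assumed, and whether a bare TextMessage produces a reply depends on the agent's handlers, so treat the final print as illustrative:

    import asyncio
    import logging

    from agnext.application import SingleThreadedAgentRuntime
    from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent
    from agnext.chat.types import TextMessage
    from agnext.components.models import OpenAI, SystemMessage

    logging.basicConfig(level=logging.WARNING)
    logging.getLogger("agnext").setLevel(logging.DEBUG)  # surfaces the new send/deliver/response logs


    async def main() -> None:
        runtime = SingleThreadedAgentRuntime()
        assistant = ChatCompletionAgent(
            name="Assistant",
            description="A plain assistant with no tools.",
            runtime=runtime,
            system_messages=[SystemMessage(content="You are a helpful assistant.")],
            model_client=OpenAI(model="gpt-4-turbo"),
            tools=[],
        )
        future = runtime.send_message(TextMessage(content="Say hello.", source="Human"), assistant)
        while not future.done():  # drive the runtime until the response future resolves
            await runtime.process_next()
        print(future.result())


    asyncio.run(main())
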
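The new "type" field on LLMCallEvent and the MessageEvent/MessageKind/DeliveryStage additions make the "agnext.events" logger (EVENT_LOGGER_NAME) the hook for structured telemetry; the package also exports an LLMUsageTracker that appears to serve this purpose. The hand-rolled handler below just makes the mechanism explicit, assuming the OpenAI client logs one LLMCallEvent per completion on that logger, as its imports indicate (TokenCounter is not part of the package):

    import logging

    from agnext.application.logging import EVENT_LOGGER_NAME, LLMCallEvent


    class TokenCounter(logging.Handler):
        """Accumulates token usage from LLMCallEvent records."""

        def __init__(self) -> None:
            super().__init__()
            self.prompt_tokens = 0
            self.completion_tokens = 0

        def emit(self, record: logging.LogRecord) -> None:
            # LLMCallEvent instances are logged as the record's msg object.
            if isinstance(record.msg, LLMCallEvent):
                self.prompt_tokens += record.msg.prompt_tokens
                self.completion_tokens += record.msg.completion_tokens


    counter = TokenCounter()
    events_logger = logging.getLogger(EVENT_LOGGER_NAME)
    events_logger.setLevel(logging.INFO)
    events_logger.addHandler(counter)

    # ... run any model-backed agent interaction here ...

    print(f"prompt: {counter.prompt_tokens}, completion: {counter.completion_tokens}")
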
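The new normalize_name and assert_valid_name helpers enforce the OpenAI tool-name rules in two directions: names produced by the model are munged, while configured names are rejected outright. A short sketch of both behaviors (the example names are invented, and the functions currently live in the private agnext.components.models._openai_client module, so this import path may change):

    from agnext.components.models._openai_client import assert_valid_name, normalize_name

    # Applied to tool names coming back from the model: invalid characters become "_",
    # and the result is truncated to 64 characters.
    print(normalize_name("repo.search tool!"))  # -> "repo_search_tool_"

    # Applied to configured tool names via convert_tools: valid names pass through,
    # anything else raises.
    assert_valid_name("get_legal_moves")
    try:
        assert_valid_name("repo.search")
    except ValueError as exc:
        print(exc)  # Invalid name: repo.search. Only letters, numbers, '_' and '-' are allowed.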