Muhammad Junaid 0b9a622b56
Fix: Auto-Convert Pydantic and Dataclass Arguments in AutoGen Tool Calls (#5737)
AutoGen was passing raw dictionaries to functions instead of
constructing Pydantic model or dataclass instances. If a tool function’s
parameter was a Pydantic BaseModel or a dataclass, the function would
receive a dict and likely throw an error or behave incorrectly (since it
expected an object of that type).

This PR addresses a problem in AutoGen where tool functions expecting
structured inputs (Pydantic models or dataclasses) were receiving raw
dictionaries. It ensures that structured inputs are automatically
validated and instantiated before function calls. Complete details are
in Issue #5736.

[Reproducible Example Code - Failing
Case](https://colab.research.google.com/drive/1hgoP-cGdSZ1-OqQLpwYmlmcExgftDqlO?usp=sharing)
 
<!-- Please give a short summary of the change and the problem this
solves. -->
## Changes Made:
- Inspect function signatures for Pydantic BaseModel and dataclass
  annotations.
- Convert input dictionaries into properly instantiated objects using
  BaseModel.model_validate() for Pydantic models or standard instantiation
  for dataclasses.
- Raise descriptive errors when validation or instantiation fails.
- Unit tests have been added to cover all scenarios.

Now structured inputs are automatically validated and instantiated
before function calls.

- **Updated Conversion Logic:**  
In the `run()` method, we now inspect the function’s signature and
convert input dictionaries to structured objects. For parameters
annotated with a Pydantic model, we use `model_validate()` to create an
instance; for those annotated with a dataclass, we instantiate the
object using the dataclass constructor. For example:

  ```python
  # Get the function signature.
  sig = inspect.signature(self._func)
  raw_kwargs = args.model_dump()
  kwargs = {}

  # Iterate over the parameters expected by the function.
  for name, param in sig.parameters.items():
      if name in raw_kwargs:
          expected_type = param.annotation
          value = raw_kwargs[name]
          # If expected type is a subclass of BaseModel, perform conversion.
          if inspect.isclass(expected_type) and issubclass(expected_type, BaseModel):
              try:
                  kwargs[name] = expected_type.model_validate(value)
              except ValidationError as e:
                  raise ValueError(
                      f"Error validating parameter '{name}' for function '{self._func.__name__}': {e}"
                  ) from e
          # If it's a dataclass, instantiate it.
          elif is_dataclass(expected_type):
              try:
                  cls = expected_type if isinstance(expected_type, type) else type(expected_type)
                  kwargs[name] = cls(**value)
              except Exception as e:
                  raise ValueError(
                      f"Error instantiating dataclass parameter '{name}' for function '{self._func.__name__}': {e}"
                  ) from e
          else:
              kwargs[name] = value
  ```

- **Error Handling Improvements:**  
Conversion steps are wrapped in try/except blocks to raise descriptive
errors when instantiation fails, aiding in debugging invalid inputs.

- **Testing:**  
Unit tests have been added to simulate tool calls (e.g., an `add` tool)
to ensure that with input like:
  ```json
  {"input": {"x": 2, "y": 3}}
  ```
The tool function receives an instance of the expected type and returns
the correct result.


## Related issue number
Closes #5736
 
## Checks
- [x] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [x] I've made sure all auto checks have passed.
2025-03-03 16:35:27 -08:00

592 lines
21 KiB
Python

import inspect
from dataclasses import dataclass
from functools import partial
from typing import Annotated, List
import pytest
from autogen_core import CancellationToken
from autogen_core._function_utils import get_typed_signature
from autogen_core.tools import BaseTool, FunctionTool
from autogen_core.tools._base import ToolSchema
from pydantic import BaseModel, Field, ValidationError, model_serializer
from pydantic_core import PydanticUndefined
# Argument model for MyTool: a single required string field named `query`.
# Documented with comments instead of a docstring so the model definition itself is untouched.
class MyArgs(BaseModel):
    # The description text is asserted verbatim by test_tool_schema_generation.
    query: str = Field(description="The description.")
# Argument model for MyNestedTool: wraps a MyArgs instance in the `arg` field.
# Documented with comments instead of a docstring so the model definition itself is untouched.
class MyNestedArgs(BaseModel):
    # Nested model field; test_nested_tool_schema_generation expects an object-typed property for it.
    arg: MyArgs = Field(description="The nested description.")
# Result model returned by the test tools and by several wrapped functions in this module.
# Documented with comments instead of a docstring so the model definition itself is untouched.
class MyResult(BaseModel):
    # Required string payload carried by the result.
    result: str = Field(description="The other description.")
# Minimal BaseTool subclass used by the tests: takes MyArgs, returns MyResult,
# and records how many times run() has been invoked.
class MyTool(BaseTool[MyArgs, MyResult]):
    def __init__(self) -> None:
        super().__init__(
            name="TestTool",
            description="Description of test tool.",
            args_type=MyArgs,
            return_type=MyResult,
        )
        # Invocation counter inspected by test_tool_run.
        self.called_count = 0

    async def run(self, args: MyArgs, cancellation_token: CancellationToken) -> MyResult:
        """Count the call and return a fixed ``MyResult``; both arguments are ignored."""
        self.called_count = self.called_count + 1
        outcome = MyResult(result="value")
        return outcome
# BaseTool subclass whose argument model (MyNestedArgs) contains another model,
# used to exercise nested schema generation and nested argument parsing.
class MyNestedTool(BaseTool[MyNestedArgs, MyResult]):
    def __init__(self) -> None:
        super().__init__(
            name="TestNestedTool",
            description="Description of test nested tool.",
            args_type=MyNestedArgs,
            return_type=MyResult,
        )
        # Invocation counter inspected by test_nested_tool_run.
        self.called_count = 0

    async def run(self, args: MyNestedArgs, cancellation_token: CancellationToken) -> MyResult:
        """Count the call and return a fixed ``MyResult``; both arguments are ignored."""
        self.called_count = self.called_count + 1
        outcome = MyResult(result="value")
        return outcome
def test_tool_schema_generation() -> None:
    """Check the schema that MyTool exposes for its MyArgs argument model."""
    schema = MyTool().schema

    assert schema["name"] == "TestTool"
    assert "description" in schema
    assert schema["description"] == "Description of test tool."

    assert "parameters" in schema
    parameters = schema["parameters"]
    assert parameters["type"] == "object"
    assert "properties" in parameters
    properties = parameters["properties"]

    query_schema = properties["query"]
    assert query_schema["description"] == "The description."
    assert query_schema["type"] == "string"

    assert "required" in parameters
    assert parameters["required"] == ["query"]
    assert len(properties) == 1
def test_func_tool_schema_generation() -> None:
    """Check the schema FunctionTool builds from a plain function signature."""

    def my_function(arg: str, other: Annotated[int, "int arg"], nonrequired: int = 5) -> MyResult:
        return MyResult(result="test")

    schema = FunctionTool(my_function, description="Function tool.").schema

    assert schema["name"] == "my_function"
    assert "description" in schema
    assert schema["description"] == "Function tool."

    assert "parameters" in schema
    parameters = schema["parameters"]
    assert parameters["type"] == "object"
    properties = parameters["properties"]
    assert properties.keys() == {"arg", "other", "nonrequired"}

    # (property, JSON type, description) in the order the properties are checked.
    # Parameters without an Annotated description are expected to carry their own name.
    expectations = (
        ("arg", "string", "arg"),
        ("other", "integer", "int arg"),
        ("nonrequired", "integer", "nonrequired"),
    )
    for prop_name, json_type, text in expectations:
        assert properties[prop_name]["type"] == json_type
        assert properties[prop_name]["description"] == text

    assert "required" in parameters
    assert parameters["required"] == ["arg", "other"]
    assert len(properties) == 3
def test_func_tool_schema_generation_strict() -> None:
    """Strict mode: a defaulted parameter is rejected; an all-required signature yields a closed schema."""

    def my_function1(arg: str, other: Annotated[int, "int arg"], nonrequired: int = 5) -> MyResult:
        return MyResult(result="test")

    # NOTE(review): which of the two statements raises is not visible from this file,
    # so both stay inside the pytest.raises block.
    with pytest.raises(ValueError, match="Strict mode is enabled"):
        tool = FunctionTool(my_function1, description="Function tool.", strict=True)
        schema = tool.schema

    def my_function2(arg: str, other: Annotated[int, "int arg"]) -> MyResult:
        return MyResult(result="test")

    tool = FunctionTool(my_function2, description="Function tool.", strict=True)
    schema = tool.schema

    assert schema["name"] == "my_function2"
    assert "description" in schema
    assert schema["description"] == "Function tool."

    assert "parameters" in schema
    parameters = schema["parameters"]
    assert parameters["type"] == "object"
    properties = parameters["properties"]
    assert properties.keys() == {"arg", "other"}

    for prop_name, json_type, text in (("arg", "string", "arg"), ("other", "integer", "int arg")):
        assert properties[prop_name]["type"] == json_type
        assert properties[prop_name]["description"] == text

    assert "required" in parameters
    assert parameters["required"] == ["arg", "other"]
    assert len(properties) == 2
    assert "additionalProperties" in parameters
    assert parameters["additionalProperties"] is False
def test_func_tool_schema_generation_only_default_arg() -> None:
    """A function whose only parameter has a default yields an empty ``required`` list."""

    def my_function(arg: str = "default") -> MyResult:
        return MyResult(result="test")

    schema = FunctionTool(my_function, description="Function tool.").schema

    assert schema["name"] == "my_function"
    assert "description" in schema
    assert schema["description"] == "Function tool."

    assert "parameters" in schema
    parameters = schema["parameters"]
    properties = parameters["properties"]
    assert len(properties) == 1

    arg_schema = properties["arg"]
    assert arg_schema["type"] == "string"
    assert arg_schema["description"] == "arg"

    assert "required" in parameters
    assert parameters["required"] == []
def test_func_tool_schema_generation_only_default_arg_strict() -> None:
    """Strict mode must raise for a function whose only parameter has a default."""

    def my_function(arg: str = "default") -> MyResult:
        return MyResult(result="test")

    # Construction and schema access both stay inside the block; either may be the raising step.
    with pytest.raises(ValueError, match="Strict mode is enabled"):
        strict_tool = FunctionTool(my_function, description="Function tool.", strict=True)
        _ = strict_tool.schema
def test_func_tool_with_partial_positional_arguments_schema_generation() -> None:
    """An argument bound positionally through ``partial`` must be absent from the schema."""

    def get_weather(country: str, city: str) -> str:
        return f"The temperature in {city}, {country} is 75°"

    bound = partial(get_weather, "Germany")
    schema = FunctionTool(bound, description="Partial function tool.").schema

    # The tool is expected to be named after the wrapped function.
    assert schema["name"] == "get_weather"
    assert "description" in schema
    assert schema["description"] == "Partial function tool."

    assert "parameters" in schema
    parameters = schema["parameters"]
    assert parameters["type"] == "object"
    properties = parameters["properties"]
    assert properties.keys() == {"city"}

    city_schema = properties["city"]
    assert city_schema["type"] == "string"
    assert city_schema["description"] == "city"

    assert "required" in parameters
    assert parameters["required"] == ["city"]
    assert "country" not in properties  # check country not in schema params
    assert len(properties) == 1
def test_func_call_tool_with_kwargs_schema_generation() -> None:
    """An argument bound by keyword through ``partial`` stays in the schema but is not required."""

    def get_weather(country: str, city: str) -> str:
        return f"The temperature in {city}, {country} is 75°"

    bound = partial(get_weather, country="Germany")
    schema = FunctionTool(bound, description="Partial function tool.").schema

    assert schema["name"] == "get_weather"
    assert "description" in schema
    assert schema["description"] == "Partial function tool."

    assert "parameters" in schema
    parameters = schema["parameters"]
    assert parameters["type"] == "object"
    properties = parameters["properties"]
    assert properties.keys() == {"country", "city"}

    for prop_name in ("city", "country"):
        assert properties[prop_name]["type"] == "string"

    assert "required" in parameters
    assert parameters["required"] == ["city"]  # only city is required
    assert len(properties) == 2
@pytest.mark.asyncio
async def test_run_func_call_tool_with_kwargs_and_args() -> None:
    """Run a partial that binds one positional and one keyword argument; JSON supplies the rest."""

    def get_weather(country: str, city: str, unit: str = "Celsius") -> str:
        return f"The temperature in {city}, {country} is 75° {unit}"

    bound = partial(get_weather, "Germany", unit="Fahrenheit")
    tool = FunctionTool(bound, description="Partial function tool.")

    outcome = await tool.run_json({"city": "Berlin"}, CancellationToken())

    assert isinstance(outcome, str)
    assert outcome == "The temperature in Berlin, Germany is 75° Fahrenheit"
@pytest.mark.asyncio
async def test_tool_run() -> None:
    """Run MyTool three times and check both the result and the invocation counter."""
    tool = MyTool()

    first = await tool.run_json({"query": "test"}, CancellationToken())
    assert isinstance(first, MyResult)
    assert first.result == "value"
    assert tool.called_count == 1

    # Two further runs; only the side effect on the counter matters here.
    for _ in range(2):
        await tool.run_json({"query": "test"}, CancellationToken())
    assert tool.called_count == 3
def test_tool_properties() -> None:
    """MyTool must report the metadata and types handed to ``BaseTool.__init__``."""
    subject = MyTool()

    assert subject.name == "TestTool"
    assert subject.description == "Description of test tool."

    args_model = subject.args_type()
    assert args_model == MyArgs
    result_model = subject.return_type()
    assert result_model == MyResult
    assert subject.state_type() is None
def test_get_typed_signature() -> None:
    """A parameterless function yields an empty signature with a ``str`` return annotation."""

    def my_function() -> str:
        return "result"

    signature = get_typed_signature(my_function)

    assert isinstance(signature, inspect.Signature)
    parameters = signature.parameters
    assert len(parameters) == 0
    assert signature.return_annotation is str
def test_get_typed_signature_annotated() -> None:
    """An ``Annotated`` return type is preserved as-is in the typed signature."""

    def my_function() -> Annotated[str, "The return type"]:
        return "result"

    signature = get_typed_signature(my_function)

    assert isinstance(signature, inspect.Signature)
    parameters = signature.parameters
    assert len(parameters) == 0
    assert signature.return_annotation == Annotated[str, "The return type"]
def test_get_typed_signature_string() -> None:
    """A string return annotation is resolved to the actual ``str`` type."""

    def my_function() -> "str":
        return "result"

    signature = get_typed_signature(my_function)

    assert isinstance(signature, inspect.Signature)
    parameters = signature.parameters
    assert len(parameters) == 0
    assert signature.return_annotation is str
def test_get_typed_signature_params() -> None:
    """One annotated parameter and a ``None`` return are both reported with concrete types."""

    def my_function(arg: str) -> None:
        return None

    signature = get_typed_signature(my_function)

    assert isinstance(signature, inspect.Signature)
    # The expected return annotation is NoneType, not the None literal.
    assert signature.return_annotation is type(None)
    parameters = signature.parameters
    assert len(parameters) == 1
    assert parameters["arg"].annotation is str
def test_get_typed_signature_two_params() -> None:
    """Both parameters of a two-argument function keep their annotations."""

    def my_function(arg: str, arg2: int) -> None:
        return None

    signature = get_typed_signature(my_function)

    assert isinstance(signature, inspect.Signature)
    parameters = signature.parameters
    assert len(parameters) == 2
    assert parameters["arg"].annotation is str
    assert parameters["arg2"].annotation is int
def test_get_typed_signature_param_str() -> None:
    """A string parameter annotation is resolved to the actual ``str`` type."""

    def my_function(arg: "str") -> None:
        return None

    signature = get_typed_signature(my_function)

    assert isinstance(signature, inspect.Signature)
    parameters = signature.parameters
    assert len(parameters) == 1
    assert parameters["arg"].annotation is str
def test_get_typed_signature_param_annotated() -> None:
    """An ``Annotated`` parameter annotation is preserved as-is in the typed signature."""

    def my_function(arg: Annotated[str, "An arg"]) -> None:
        return None

    signature = get_typed_signature(my_function)

    assert isinstance(signature, inspect.Signature)
    parameters = signature.parameters
    assert len(parameters) == 1
    assert parameters["arg"].annotation == Annotated[str, "An arg"]
def test_func_tool() -> None:
    """FunctionTool takes its name from the function and exposes model/return types."""

    def my_function() -> str:
        return "result"

    tool = FunctionTool(my_function, description="Function tool.")

    assert tool.name == "my_function"
    assert tool.description == "Function tool."
    args_model = tool.args_type()
    assert issubclass(args_model, BaseModel)
    result_type = tool.return_type()
    assert issubclass(result_type, str)
    assert tool.state_type() is None
def test_func_tool_annotated_arg() -> None:
    """The ``Annotated`` string on a parameter becomes the generated field's description."""

    def my_function(my_arg: Annotated[str, "test description"]) -> str:
        return "result"

    tool = FunctionTool(my_function, description="Function tool.")

    assert tool.name == "my_function"
    assert tool.description == "Function tool."
    args_model = tool.args_type()
    assert issubclass(args_model, BaseModel)
    assert issubclass(tool.return_type(), str)

    # Inspect the single generated field of the argument model.
    fields = args_model.model_fields
    my_arg_field = fields["my_arg"]
    assert my_arg_field.description == "test description"
    assert my_arg_field.annotation is str
    assert my_arg_field.is_required() is True
    assert my_arg_field.default is PydanticUndefined
    assert len(fields) == 1

    assert tool.return_type() is str
    assert tool.state_type() is None
def test_func_tool_return_annotated() -> None:
    """An ``Annotated`` return annotation is unwrapped to the underlying ``str`` type."""

    def my_function() -> Annotated[str, "test description"]:
        return "result"

    tool = FunctionTool(my_function, description="Function tool.")

    assert tool.name == "my_function"
    assert tool.description == "Function tool."
    args_model = tool.args_type()
    assert issubclass(args_model, BaseModel)
    result_type = tool.return_type()
    assert result_type is str
    assert tool.state_type() is None
def test_func_tool_no_args() -> None:
    """A parameterless function produces an argument model with no fields."""

    def my_function() -> str:
        return "result"

    tool = FunctionTool(my_function, description="Function tool.")

    assert tool.name == "my_function"
    assert tool.description == "Function tool."
    args_model = tool.args_type()
    assert issubclass(args_model, BaseModel)
    assert len(args_model.model_fields) == 0
    result_type = tool.return_type()
    assert result_type is str
    assert tool.state_type() is None
def test_func_tool_return_none() -> None:
    """A ``-> None`` function reports ``NoneType`` as the tool's return type."""

    def my_function() -> None:
        return None

    tool = FunctionTool(my_function, description="Function tool.")

    assert tool.name == "my_function"
    assert tool.description == "Function tool."
    args_model = tool.args_type()
    assert issubclass(args_model, BaseModel)
    result_type = tool.return_type()
    assert result_type is type(None)
    assert tool.state_type() is None
def test_func_tool_return_base_model() -> None:
    """A function returning a Pydantic model reports that model as the tool's return type."""

    def my_function() -> MyResult:
        return MyResult(result="value")

    tool = FunctionTool(my_function, description="Function tool.")

    assert tool.name == "my_function"
    assert tool.description == "Function tool."
    args_model = tool.args_type()
    assert issubclass(args_model, BaseModel)
    result_type = tool.return_type()
    assert result_type is MyResult
    assert tool.state_type() is None
@pytest.mark.asyncio
async def test_func_call_tool() -> None:
    """Running a parameterless function tool with empty JSON returns the function's value."""

    def my_function() -> str:
        return "result"

    tool = FunctionTool(my_function, description="Function tool.")
    outcome = await tool.run_json({}, CancellationToken())

    assert outcome == "result"
@pytest.mark.asyncio
async def test_func_call_tool_base_model() -> None:
    """A model instance returned by the wrapped function is passed through unchanged in type."""

    def my_function() -> MyResult:
        return MyResult(result="value")

    tool = FunctionTool(my_function, description="Function tool.")
    outcome = await tool.run_json({}, CancellationToken())

    assert isinstance(outcome, MyResult)
    assert outcome.result == "value"
@pytest.mark.asyncio
async def test_func_call_tool_with_arg_base_model() -> None:
    """A function taking a string argument and returning a model runs from JSON input."""

    def my_function(arg: str) -> MyResult:
        return MyResult(result="value")

    tool = FunctionTool(my_function, description="Function tool.")
    outcome = await tool.run_json({"arg": "test"}, CancellationToken())

    assert isinstance(outcome, MyResult)
    assert outcome.result == "value"
@pytest.mark.asyncio
async def test_func_str_res() -> None:
    """A string result is rendered verbatim by ``return_value_as_string``."""

    def my_function(arg: str) -> str:
        return "test"

    tool = FunctionTool(my_function, description="Function tool.")
    outcome = await tool.run_json({"arg": "test"}, CancellationToken())
    rendered = tool.return_value_as_string(outcome)

    assert rendered == "test"
@pytest.mark.asyncio
async def test_func_base_model_res() -> None:
    """A model result is rendered as its JSON object text by ``return_value_as_string``."""

    def my_function(arg: str) -> MyResult:
        return MyResult(result="test")

    tool = FunctionTool(my_function, description="Function tool.")
    outcome = await tool.run_json({"arg": "test"}, CancellationToken())
    rendered = tool.return_value_as_string(outcome)

    assert rendered == '{"result": "test"}'
@pytest.mark.asyncio
async def test_func_base_model_custom_dump_res() -> None:
    """A model-level custom serializer determines the string form of the tool result."""

    # Result model whose whole-model serializer emits a plain string instead of a dict.
    class MyResultCustomDump(BaseModel):
        result: str = Field(description="The other description.")

        @model_serializer
        def ser_model(self) -> str:
            return "custom: " + self.result

    def my_function(arg: str) -> MyResultCustomDump:
        return MyResultCustomDump(result="test")

    tool = FunctionTool(my_function, description="Function tool.")
    outcome = await tool.run_json({"arg": "test"}, CancellationToken())
    rendered = tool.return_value_as_string(outcome)

    assert rendered == "custom: test"
@pytest.mark.asyncio
async def test_func_int_res() -> None:
    """An integer result is rendered as its decimal text by ``return_value_as_string``."""

    def my_function(arg: int) -> int:
        return arg

    tool = FunctionTool(my_function, description="Function tool.")
    outcome = await tool.run_json({"arg": 5}, CancellationToken())
    rendered = tool.return_value_as_string(outcome)

    assert rendered == "5"
@pytest.mark.asyncio
async def test_func_tool_return_list() -> None:
    """A list result is returned as a list and rendered as ``[1, 2]``."""

    def my_function() -> List[int]:
        return [1, 2]

    tool = FunctionTool(my_function, description="Function tool.")
    outcome = await tool.run_json({}, CancellationToken())

    assert isinstance(outcome, list)
    assert outcome == [1, 2]
    rendered = tool.return_value_as_string(outcome)
    assert rendered == "[1, 2]"
def test_nested_tool_schema_generation() -> None:
    """Check that MyNestedTool's schema contains the nested MyArgs object schema."""
    schema: ToolSchema = MyNestedTool().schema

    # Structural checks first: every key accessed further down must be present.
    assert "description" in schema
    assert "parameters" in schema
    parameters = schema["parameters"]
    assert "type" in parameters
    properties = parameters["properties"]
    assert "arg" in properties
    arg_schema = properties["arg"]
    for key in ("type", "title", "properties"):
        assert key in arg_schema
    arg_properties = arg_schema["properties"]
    assert "query" in arg_properties
    query_schema = arg_properties["query"]
    for key in ("type", "description"):
        assert key in query_schema
    assert "required" in parameters

    # Value checks.
    assert schema["description"] == "Description of test nested tool."
    assert parameters["type"] == "object"
    assert arg_schema["type"] == "object"
    assert arg_schema["title"] == "MyArgs"
    assert query_schema["type"] == "string"
    assert query_schema["description"] == "The description."
    assert arg_schema["required"] == ["query"]
    assert parameters["required"] == ["arg"]
    assert len(properties) == 1
@pytest.mark.asyncio
async def test_nested_tool_run() -> None:
    """Run MyNestedTool three times with nested JSON and check result and invocation counter."""
    tool = MyNestedTool()

    first = await tool.run_json({"arg": {"query": "test"}}, CancellationToken())
    assert isinstance(first, MyResult)
    assert first.result == "value"
    assert tool.called_count == 1

    # Two further runs; only the side effect on the counter matters here.
    for _ in range(2):
        await tool.run_json({"arg": {"query": "test"}}, CancellationToken())
    assert tool.called_count == 3
def test_nested_tool_properties() -> None:
    """MyNestedTool must report the metadata and types handed to ``BaseTool.__init__``."""
    subject = MyNestedTool()

    assert subject.name == "TestNestedTool"
    assert subject.description == "Description of test nested tool."

    args_model = subject.args_type()
    assert args_model == MyNestedArgs
    result_model = subject.return_type()
    assert result_model == MyResult
    assert subject.state_type() is None
# --- Define a sample Pydantic model and tool function ---
# Input model for add_tool: two required integer operands.
# Documented with comments instead of a docstring so the model definition itself is untouched.
class AddInput(BaseModel):
    x: int
    y: int
# Tool function that expects an AddInput instance (not a raw dict) and returns x + y.
# The parameter name `input` is the JSON key the conversion tests send, so it must stay as is.
def add_tool(input: AddInput) -> int:
    left, right = input.x, input.y
    return left + right
@pytest.mark.asyncio
async def test_func_tool_with_pydantic_model_conversion_success() -> None:
    """A nested dict for a Pydantic-typed parameter must reach add_tool as a usable model."""
    tool = FunctionTool(add_tool, description="Tool to add two numbers.")
    payload = {"input": {"x": 2, "y": 3}}

    outcome = await tool.run_json(payload, CancellationToken())

    assert outcome == 5
    rendered = tool.return_value_as_string(outcome)
    assert rendered == "5"
@pytest.mark.asyncio
async def test_func_tool_with_pydantic_model_conversion_failure() -> None:
    """Omitting a required model field must surface as a pydantic ``ValidationError``."""
    tool = FunctionTool(add_tool, description="Tool to add two numbers.")
    # Field 'y' is deliberately left out.
    incomplete_payload = {"input": {"x": 2}}

    with pytest.raises(ValidationError, match="Field required"):
        await tool.run_json(incomplete_payload, CancellationToken())
# --- Additional test using a dataclass ---
# Input dataclass for multiply_tool: two integer factors, both without defaults.
@dataclass
class MultiplyInput:
    a: int
    b: int
# Tool function that expects a MultiplyInput instance (not a raw dict) and returns a * b.
# The parameter name `input` is the JSON key the conversion tests send, so it must stay as is.
def multiply_tool(input: MultiplyInput) -> int:
    first, second = input.a, input.b
    return first * second
@pytest.mark.asyncio
async def test_func_tool_with_dataclass_conversion_success() -> None:
    """A nested dict for a dataclass-typed parameter must reach multiply_tool as a usable object."""
    tool = FunctionTool(multiply_tool, description="Tool to multiply two numbers.")
    payload = {"input": {"a": 4, "b": 5}}

    outcome = await tool.run_json(payload, CancellationToken())

    assert outcome == 20
    rendered = tool.return_value_as_string(outcome)
    assert rendered == "20"
@pytest.mark.asyncio
async def test_func_tool_with_dataclass_conversion_failure() -> None:
    """Omitting a required dataclass field must surface as a pydantic ``ValidationError``."""
    tool = FunctionTool(multiply_tool, description="Tool to multiply two numbers.")
    # Missing field 'b'
    incomplete_payload = {"input": {"a": 4}}

    # NOTE(review): the expected error is pydantic's ValidationError, not the ValueError the
    # conversion code wraps failures in - presumably the arguments are validated before the
    # function-level conversion runs; confirm against FunctionTool.
    with pytest.raises(ValidationError, match="Field required"):
        await tool.run_json(incomplete_payload, CancellationToken())