mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-02 10:50:03 +00:00
fix: handle non-string function arguments in tool calls and add corresponding warnings (#5260)
This commit is contained in:
parent
aa23093f36
commit
44db2cc1fb
@ -571,14 +571,24 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
stacklevel=2,
|
||||
)
|
||||
# NOTE: If OAI response type changes, this will need to be updated
|
||||
content = [
|
||||
FunctionCall(
|
||||
id=x.id,
|
||||
arguments=x.function.arguments,
|
||||
name=normalize_name(x.function.name),
|
||||
content = []
|
||||
for tool_call in choice.message.tool_calls:
|
||||
if not isinstance(tool_call.function.arguments, str):
|
||||
warnings.warn(
|
||||
f"Tool call function arguments field is not a string: {tool_call.function.arguments}."
|
||||
"This is unexpected and may due to the API used not returning the correct type. "
|
||||
"Attempting to convert it to string.",
|
||||
stacklevel=2,
|
||||
)
|
||||
if isinstance(tool_call.function.arguments, dict):
|
||||
tool_call.function.arguments = json.dumps(tool_call.function.arguments)
|
||||
content.append(
|
||||
FunctionCall(
|
||||
id=tool_call.id,
|
||||
arguments=tool_call.function.arguments,
|
||||
name=normalize_name(tool_call.function.name),
|
||||
)
|
||||
)
|
||||
for x in choice.message.tool_calls
|
||||
]
|
||||
finish_reason = "tool_calls"
|
||||
else:
|
||||
finish_reason = choice.finish_reason
|
||||
|
||||
@ -619,6 +619,31 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
# Should raise warning when function arguments is not a string.
|
||||
ChatCompletion(
|
||||
id="id6",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="tool_calls",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
id="1",
|
||||
type="function",
|
||||
function=Function.construct(name="_pass_function", arguments={"input": "task"}), # type: ignore
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
@ -676,8 +701,16 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert create_result.content == "I should make a tool call."
|
||||
assert create_result.finish_reason == "stop"
|
||||
|
||||
# Should raise warning when function arguments is not a string.
|
||||
with pytest.warns(UserWarning, match="Tool call function arguments field is not a string"):
|
||||
create_result = await model_client.create(
|
||||
messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
|
||||
)
|
||||
assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
|
||||
assert create_result.finish_reason == "function_calls"
|
||||
|
||||
async def _test_model_client(model_client: OpenAIChatCompletionClient) -> None:
|
||||
|
||||
async def _test_model_client_basic_completion(model_client: OpenAIChatCompletionClient) -> None:
|
||||
# Test basic completion
|
||||
create_result = await model_client.create(
|
||||
messages=[
|
||||
@ -688,6 +721,8 @@ async def _test_model_client(model_client: OpenAIChatCompletionClient) -> None:
|
||||
assert isinstance(create_result.content, str)
|
||||
assert len(create_result.content) > 0
|
||||
|
||||
|
||||
async def _test_model_client_with_function_calling(model_client: OpenAIChatCompletionClient) -> None:
|
||||
# Test tool calling
|
||||
pass_tool = FunctionTool(_pass_function, name="pass_tool", description="pass session.")
|
||||
fail_tool = FunctionTool(_fail_function, name="fail_tool", description="fail session.")
|
||||
@ -755,7 +790,8 @@ async def test_openai() -> None:
|
||||
model="gpt-4o-mini",
|
||||
api_key=api_key,
|
||||
)
|
||||
await _test_model_client(model_client)
|
||||
await _test_model_client_basic_completion(model_client)
|
||||
await _test_model_client_with_function_calling(model_client)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -775,7 +811,29 @@ async def test_gemini() -> None:
|
||||
"family": ModelFamily.UNKNOWN,
|
||||
},
|
||||
)
|
||||
await _test_model_client(model_client)
|
||||
await _test_model_client_basic_completion(model_client)
|
||||
await _test_model_client_with_function_calling(model_client)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hugging_face() -> None:
|
||||
api_key = os.getenv("HF_TOKEN")
|
||||
if not api_key:
|
||||
pytest.skip("HF_TOKEN not found in environment variables")
|
||||
|
||||
model_client = OpenAIChatCompletionClient(
|
||||
model="microsoft/Phi-3.5-mini-instruct",
|
||||
api_key=api_key,
|
||||
base_url="https://api-inference.huggingface.co/v1/",
|
||||
model_info={
|
||||
"function_calling": False,
|
||||
"json_output": False,
|
||||
"vision": False,
|
||||
"family": ModelFamily.UNKNOWN,
|
||||
},
|
||||
)
|
||||
|
||||
await _test_model_client_basic_completion(model_client)
|
||||
|
||||
|
||||
# TODO: add integration tests for Azure OpenAI using AAD token.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user