mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-03 11:20:35 +00:00
fix: handle non-string function arguments in tool calls and add corresponding warnings (#5260)
This commit is contained in:
parent
aa23093f36
commit
44db2cc1fb
@ -571,14 +571,24 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
|||||||
stacklevel=2,
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
# NOTE: If OAI response type changes, this will need to be updated
|
# NOTE: If OAI response type changes, this will need to be updated
|
||||||
content = [
|
content = []
|
||||||
FunctionCall(
|
for tool_call in choice.message.tool_calls:
|
||||||
id=x.id,
|
if not isinstance(tool_call.function.arguments, str):
|
||||||
arguments=x.function.arguments,
|
warnings.warn(
|
||||||
name=normalize_name(x.function.name),
|
f"Tool call function arguments field is not a string: {tool_call.function.arguments}."
|
||||||
|
"This is unexpected and may due to the API used not returning the correct type. "
|
||||||
|
"Attempting to convert it to string.",
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
if isinstance(tool_call.function.arguments, dict):
|
||||||
|
tool_call.function.arguments = json.dumps(tool_call.function.arguments)
|
||||||
|
content.append(
|
||||||
|
FunctionCall(
|
||||||
|
id=tool_call.id,
|
||||||
|
arguments=tool_call.function.arguments,
|
||||||
|
name=normalize_name(tool_call.function.name),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
for x in choice.message.tool_calls
|
|
||||||
]
|
|
||||||
finish_reason = "tool_calls"
|
finish_reason = "tool_calls"
|
||||||
else:
|
else:
|
||||||
finish_reason = choice.finish_reason
|
finish_reason = choice.finish_reason
|
||||||
|
|||||||
@ -619,6 +619,31 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
object="chat.completion",
|
object="chat.completion",
|
||||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||||
),
|
),
|
||||||
|
# Should raise warning when function arguments is not a string.
|
||||||
|
ChatCompletion(
|
||||||
|
id="id6",
|
||||||
|
choices=[
|
||||||
|
Choice(
|
||||||
|
finish_reason="tool_calls",
|
||||||
|
index=0,
|
||||||
|
message=ChatCompletionMessage(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
ChatCompletionMessageToolCall(
|
||||||
|
id="1",
|
||||||
|
type="function",
|
||||||
|
function=Function.construct(name="_pass_function", arguments={"input": "task"}), # type: ignore
|
||||||
|
)
|
||||||
|
],
|
||||||
|
role="assistant",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
created=0,
|
||||||
|
model=model,
|
||||||
|
object="chat.completion",
|
||||||
|
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
mock = _MockChatCompletion(chat_completions)
|
mock = _MockChatCompletion(chat_completions)
|
||||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||||
@ -676,8 +701,16 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
assert create_result.content == "I should make a tool call."
|
assert create_result.content == "I should make a tool call."
|
||||||
assert create_result.finish_reason == "stop"
|
assert create_result.finish_reason == "stop"
|
||||||
|
|
||||||
|
# Should raise warning when function arguments is not a string.
|
||||||
|
with pytest.warns(UserWarning, match="Tool call function arguments field is not a string"):
|
||||||
|
create_result = await model_client.create(
|
||||||
|
messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
|
||||||
|
)
|
||||||
|
assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
|
||||||
|
assert create_result.finish_reason == "function_calls"
|
||||||
|
|
||||||
async def _test_model_client(model_client: OpenAIChatCompletionClient) -> None:
|
|
||||||
|
async def _test_model_client_basic_completion(model_client: OpenAIChatCompletionClient) -> None:
|
||||||
# Test basic completion
|
# Test basic completion
|
||||||
create_result = await model_client.create(
|
create_result = await model_client.create(
|
||||||
messages=[
|
messages=[
|
||||||
@ -688,6 +721,8 @@ async def _test_model_client(model_client: OpenAIChatCompletionClient) -> None:
|
|||||||
assert isinstance(create_result.content, str)
|
assert isinstance(create_result.content, str)
|
||||||
assert len(create_result.content) > 0
|
assert len(create_result.content) > 0
|
||||||
|
|
||||||
|
|
||||||
|
async def _test_model_client_with_function_calling(model_client: OpenAIChatCompletionClient) -> None:
|
||||||
# Test tool calling
|
# Test tool calling
|
||||||
pass_tool = FunctionTool(_pass_function, name="pass_tool", description="pass session.")
|
pass_tool = FunctionTool(_pass_function, name="pass_tool", description="pass session.")
|
||||||
fail_tool = FunctionTool(_fail_function, name="fail_tool", description="fail session.")
|
fail_tool = FunctionTool(_fail_function, name="fail_tool", description="fail session.")
|
||||||
@ -755,7 +790,8 @@ async def test_openai() -> None:
|
|||||||
model="gpt-4o-mini",
|
model="gpt-4o-mini",
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
await _test_model_client(model_client)
|
await _test_model_client_basic_completion(model_client)
|
||||||
|
await _test_model_client_with_function_calling(model_client)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -775,7 +811,29 @@ async def test_gemini() -> None:
|
|||||||
"family": ModelFamily.UNKNOWN,
|
"family": ModelFamily.UNKNOWN,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
await _test_model_client(model_client)
|
await _test_model_client_basic_completion(model_client)
|
||||||
|
await _test_model_client_with_function_calling(model_client)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hugging_face() -> None:
|
||||||
|
api_key = os.getenv("HF_TOKEN")
|
||||||
|
if not api_key:
|
||||||
|
pytest.skip("HF_TOKEN not found in environment variables")
|
||||||
|
|
||||||
|
model_client = OpenAIChatCompletionClient(
|
||||||
|
model="microsoft/Phi-3.5-mini-instruct",
|
||||||
|
api_key=api_key,
|
||||||
|
base_url="https://api-inference.huggingface.co/v1/",
|
||||||
|
model_info={
|
||||||
|
"function_calling": False,
|
||||||
|
"json_output": False,
|
||||||
|
"vision": False,
|
||||||
|
"family": ModelFamily.UNKNOWN,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
await _test_model_client_basic_completion(model_client)
|
||||||
|
|
||||||
|
|
||||||
# TODO: add integration tests for Azure OpenAI using AAD token.
|
# TODO: add integration tests for Azure OpenAI using AAD token.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user