mirror of
				https://github.com/microsoft/autogen.git
				synced 2025-11-04 03:39:52 +00:00 
			
		
		
		
	fix: handle non-string function arguments in tool calls and add corresponding warnings (#5260)
This commit is contained in:
		
							parent
							
								
									aa23093f36
								
							
						
					
					
						commit
						44db2cc1fb
					
				@ -571,14 +571,24 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
 | 
				
			|||||||
                    stacklevel=2,
 | 
					                    stacklevel=2,
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
            # NOTE: If OAI response type changes, this will need to be updated
 | 
					            # NOTE: If OAI response type changes, this will need to be updated
 | 
				
			||||||
            content = [
 | 
					            content = []
 | 
				
			||||||
                FunctionCall(
 | 
					            for tool_call in choice.message.tool_calls:
 | 
				
			||||||
                    id=x.id,
 | 
					                if not isinstance(tool_call.function.arguments, str):
 | 
				
			||||||
                    arguments=x.function.arguments,
 | 
					                    warnings.warn(
 | 
				
			||||||
                    name=normalize_name(x.function.name),
 | 
					                        f"Tool call function arguments field is not a string: {tool_call.function.arguments}."
 | 
				
			||||||
 | 
					                        "This is unexpected and may due to the API used not returning the correct type. "
 | 
				
			||||||
 | 
					                        "Attempting to convert it to string.",
 | 
				
			||||||
 | 
					                        stacklevel=2,
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                    if isinstance(tool_call.function.arguments, dict):
 | 
				
			||||||
 | 
					                        tool_call.function.arguments = json.dumps(tool_call.function.arguments)
 | 
				
			||||||
 | 
					                content.append(
 | 
				
			||||||
 | 
					                    FunctionCall(
 | 
				
			||||||
 | 
					                        id=tool_call.id,
 | 
				
			||||||
 | 
					                        arguments=tool_call.function.arguments,
 | 
				
			||||||
 | 
					                        name=normalize_name(tool_call.function.name),
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
                for x in choice.message.tool_calls
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
            finish_reason = "tool_calls"
 | 
					            finish_reason = "tool_calls"
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            finish_reason = choice.finish_reason
 | 
					            finish_reason = choice.finish_reason
 | 
				
			||||||
 | 
				
			|||||||
@ -619,6 +619,31 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
 | 
				
			|||||||
            object="chat.completion",
 | 
					            object="chat.completion",
 | 
				
			||||||
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
 | 
					            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
 | 
				
			||||||
        ),
 | 
					        ),
 | 
				
			||||||
 | 
					        # Should raise warning when function arguments is not a string.
 | 
				
			||||||
 | 
					        ChatCompletion(
 | 
				
			||||||
 | 
					            id="id6",
 | 
				
			||||||
 | 
					            choices=[
 | 
				
			||||||
 | 
					                Choice(
 | 
				
			||||||
 | 
					                    finish_reason="tool_calls",
 | 
				
			||||||
 | 
					                    index=0,
 | 
				
			||||||
 | 
					                    message=ChatCompletionMessage(
 | 
				
			||||||
 | 
					                        content=None,
 | 
				
			||||||
 | 
					                        tool_calls=[
 | 
				
			||||||
 | 
					                            ChatCompletionMessageToolCall(
 | 
				
			||||||
 | 
					                                id="1",
 | 
				
			||||||
 | 
					                                type="function",
 | 
				
			||||||
 | 
					                                function=Function.construct(name="_pass_function", arguments={"input": "task"}),  # type: ignore
 | 
				
			||||||
 | 
					                            )
 | 
				
			||||||
 | 
					                        ],
 | 
				
			||||||
 | 
					                        role="assistant",
 | 
				
			||||||
 | 
					                    ),
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
 | 
					            created=0,
 | 
				
			||||||
 | 
					            model=model,
 | 
				
			||||||
 | 
					            object="chat.completion",
 | 
				
			||||||
 | 
					            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
    ]
 | 
					    ]
 | 
				
			||||||
    mock = _MockChatCompletion(chat_completions)
 | 
					    mock = _MockChatCompletion(chat_completions)
 | 
				
			||||||
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
 | 
					    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
 | 
				
			||||||
@ -676,8 +701,16 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
 | 
				
			|||||||
    assert create_result.content == "I should make a tool call."
 | 
					    assert create_result.content == "I should make a tool call."
 | 
				
			||||||
    assert create_result.finish_reason == "stop"
 | 
					    assert create_result.finish_reason == "stop"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Should raise warning when function arguments is not a string.
 | 
				
			||||||
 | 
					    with pytest.warns(UserWarning, match="Tool call function arguments field is not a string"):
 | 
				
			||||||
 | 
					        create_result = await model_client.create(
 | 
				
			||||||
 | 
					            messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
 | 
				
			||||||
 | 
					        assert create_result.finish_reason == "function_calls"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def _test_model_client(model_client: OpenAIChatCompletionClient) -> None:
 | 
					
 | 
				
			||||||
 | 
					async def _test_model_client_basic_completion(model_client: OpenAIChatCompletionClient) -> None:
 | 
				
			||||||
    # Test basic completion
 | 
					    # Test basic completion
 | 
				
			||||||
    create_result = await model_client.create(
 | 
					    create_result = await model_client.create(
 | 
				
			||||||
        messages=[
 | 
					        messages=[
 | 
				
			||||||
@ -688,6 +721,8 @@ async def _test_model_client(model_client: OpenAIChatCompletionClient) -> None:
 | 
				
			|||||||
    assert isinstance(create_result.content, str)
 | 
					    assert isinstance(create_result.content, str)
 | 
				
			||||||
    assert len(create_result.content) > 0
 | 
					    assert len(create_result.content) > 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def _test_model_client_with_function_calling(model_client: OpenAIChatCompletionClient) -> None:
 | 
				
			||||||
    # Test tool calling
 | 
					    # Test tool calling
 | 
				
			||||||
    pass_tool = FunctionTool(_pass_function, name="pass_tool", description="pass session.")
 | 
					    pass_tool = FunctionTool(_pass_function, name="pass_tool", description="pass session.")
 | 
				
			||||||
    fail_tool = FunctionTool(_fail_function, name="fail_tool", description="fail session.")
 | 
					    fail_tool = FunctionTool(_fail_function, name="fail_tool", description="fail session.")
 | 
				
			||||||
@ -755,7 +790,8 @@ async def test_openai() -> None:
 | 
				
			|||||||
        model="gpt-4o-mini",
 | 
					        model="gpt-4o-mini",
 | 
				
			||||||
        api_key=api_key,
 | 
					        api_key=api_key,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    await _test_model_client(model_client)
 | 
					    await _test_model_client_basic_completion(model_client)
 | 
				
			||||||
 | 
					    await _test_model_client_with_function_calling(model_client)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.mark.asyncio
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
@ -775,7 +811,29 @@ async def test_gemini() -> None:
 | 
				
			|||||||
            "family": ModelFamily.UNKNOWN,
 | 
					            "family": ModelFamily.UNKNOWN,
 | 
				
			||||||
        },
 | 
					        },
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    await _test_model_client(model_client)
 | 
					    await _test_model_client_basic_completion(model_client)
 | 
				
			||||||
 | 
					    await _test_model_client_with_function_calling(model_client)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_hugging_face() -> None:
 | 
				
			||||||
 | 
					    api_key = os.getenv("HF_TOKEN")
 | 
				
			||||||
 | 
					    if not api_key:
 | 
				
			||||||
 | 
					        pytest.skip("HF_TOKEN not found in environment variables")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model_client = OpenAIChatCompletionClient(
 | 
				
			||||||
 | 
					        model="microsoft/Phi-3.5-mini-instruct",
 | 
				
			||||||
 | 
					        api_key=api_key,
 | 
				
			||||||
 | 
					        base_url="https://api-inference.huggingface.co/v1/",
 | 
				
			||||||
 | 
					        model_info={
 | 
				
			||||||
 | 
					            "function_calling": False,
 | 
				
			||||||
 | 
					            "json_output": False,
 | 
				
			||||||
 | 
					            "vision": False,
 | 
				
			||||||
 | 
					            "family": ModelFamily.UNKNOWN,
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await _test_model_client_basic_completion(model_client)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TODO: add integration tests for Azure OpenAI using AAD token.
 | 
					# TODO: add integration tests for Azure OpenAI using AAD token.
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user