Add support for thought field in AzureAIChatCompletionClient (#6062)

An earlier change added support for the thought process in tool calls for
`OpenAIChatCompletionClient`, allowing additional text produced by a
model alongside tool calls to be preserved in the `thought` field of
`CreateResult`. This PR extends the same functionality to
`AzureAIChatCompletionClient` for consistency across model clients.

Resolves #5650.
Co-authored-by: Jay Prakash Thakur <jathakur@microsoft.com>
This commit is contained in:
Jay Prakash Thakur 2025-03-24 17:33:10 -07:00 committed by GitHub
parent 47ffaccba1
commit 7047fb8b8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 176 additions and 6 deletions

View File

@ -408,9 +408,10 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
)
choice = result.choices[0]
thought = None
if choice.finish_reason == CompletionsFinishReason.TOOL_CALLS:
assert choice.message.tool_calls is not None
content: Union[str, List[FunctionCall]] = [
FunctionCall(
id=x.id,
@ -420,6 +421,9 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
for x in choice.message.tool_calls
]
finish_reason = "function_calls"
if choice.message.content:
thought = choice.message.content
else:
if isinstance(choice.finish_reason, CompletionsFinishReason):
finish_reason = choice.finish_reason.value
@ -429,8 +433,6 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
thought, content = parse_r1_content(content)
else:
thought = None
response = CreateResult(
finish_reason=finish_reason, # type: ignore
@ -486,6 +488,8 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
chunk: Optional[StreamingChatCompletionsUpdate] = None
choice: Optional[StreamingChatChoiceUpdate] = None
first_chunk = True
thought = None
async for chunk in await task: # type: ignore
if first_chunk:
first_chunk = False
@ -545,6 +549,9 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
else:
content = list(full_tool_calls.values())
if len(content_deltas) > 0:
thought = "".join(content_deltas)
usage = RequestUsage(
completion_tokens=completion_tokens,
prompt_tokens=prompt_tokens,
@ -552,8 +559,6 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
thought, content = parse_r1_content(content)
else:
thought = None
result = CreateResult(
finish_reason=finish_reason,

View File

@ -2,7 +2,8 @@ import asyncio
import logging
import os
from datetime import datetime
from typing import Any, AsyncGenerator, List
from typing import Any, AsyncGenerator, List, Type, Union
from unittest.mock import MagicMock
import pytest
from autogen_core import CancellationToken, FunctionCall, Image
@ -458,3 +459,167 @@ async def test_r1_content(monkeypatch: pytest.MonkeyPatch) -> None:
assert isinstance(chunks[-1], CreateResult)
assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
assert chunks[-1].thought == "Thought"
@pytest.fixture
def thought_with_tool_call_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient:
    """Build a client whose mocked completion carries both a tool call and
    accompanying assistant text (the text that should surface as ``thought``)."""

    async def _fake_complete(*args: Any, **kwargs: Any) -> ChatCompletions:
        # Small delay to mimic a network round-trip.
        await asyncio.sleep(0.01)
        tool_call = ChatCompletionsToolCall(
            id="tool_call_id",
            function=AzureFunctionCall(name="some_function", arguments='{"foo": "bar"}'),
        )
        message = ChatResponseMessage(
            role="assistant",
            content="Let me think about what function to call.",
            tool_calls=[tool_call],
        )
        return ChatCompletions(
            id="id",
            created=datetime.now(),
            model="model",
            choices=[
                ChatChoice(
                    index=0,
                    finish_reason=CompletionsFinishReason.TOOL_CALLS,
                    message=message,
                )
            ],
            usage=CompletionsUsage(prompt_tokens=8, completion_tokens=5, total_tokens=13),
        )

    monkeypatch.setattr(ChatCompletionsClient, "complete", _fake_complete)
    return AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "json_output": False,
            "function_calling": True,
            "vision": False,
            "family": "function_calling_model",
            "structured_output": False,
        },
        model="model",
    )
@pytest.mark.asyncio
async def test_thought_field_with_tool_calls(thought_with_tool_call_client: AzureAIChatCompletionClient) -> None:
    """Verify that assistant text returned alongside a tool call is preserved
    in ``CreateResult.thought`` while the tool call becomes the content."""
    result = await thought_with_tool_call_client.create(
        messages=[UserMessage(content="Please call a function", source="user")],
        tools=[{"name": "test_tool"}],
    )
    assert result.finish_reason == "function_calls"
    content = result.content
    assert isinstance(content, list)
    first_call = content[0]
    assert isinstance(first_call, FunctionCall)
    assert first_call.name == "some_function"
    assert first_call.arguments == '{"foo": "bar"}'
    assert result.thought == "Let me think about what function to call."
@pytest.fixture
def thought_with_tool_call_stream_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient:
    """Build a client whose mocked stream yields assistant text first, then a
    tool-call chunk, so the text should surface as the final ``thought``."""
    # First streamed chunk: plain assistant text, no finish reason yet.
    text_choice = MagicMock()
    text_choice.delta = MagicMock()
    text_choice.delta.content = "Let me think about what function to call."
    text_choice.finish_reason = None

    # Second streamed chunk: one tool call plus the terminal finish reason.
    fake_call = MagicMock()
    fake_call.id = "tool_call_id"
    fake_call.function = MagicMock()
    fake_call.function.name = "some_function"
    fake_call.function.arguments = '{"foo": "bar"}'

    call_choice = MagicMock()
    call_choice.delta = MagicMock()
    call_choice.delta.content = None
    call_choice.delta.tool_calls = [fake_call]
    call_choice.finish_reason = "function_calls"

    async def _fake_stream(
        *args: Any, **kwargs: Any
    ) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]:
        yield StreamingChatCompletionsUpdate(
            id="id",
            choices=[text_choice],
            created=datetime.now(),
            model="model",
        )
        await asyncio.sleep(0.01)
        yield StreamingChatCompletionsUpdate(
            id="id",
            choices=[call_choice],
            created=datetime.now(),
            model="model",
            usage=CompletionsUsage(prompt_tokens=8, completion_tokens=5, total_tokens=13),
        )

    async def _fake_complete(*args: Any, **kwargs: Any) -> Any:
        # The client is expected to request streaming with ``stream=True``;
        # non-streaming calls fall through to ``None``.
        if kwargs.get("stream", False):
            return _fake_stream(*args, **kwargs)
        return None

    stub_client = MagicMock()
    stub_client.close = MagicMock()
    stub_client.complete = _fake_complete

    def _stub_new(cls: Type[ChatCompletionsClient], *args: Any, **kwargs: Any) -> MagicMock:
        return stub_client

    # Intercept construction of the underlying SDK client entirely.
    monkeypatch.setattr(ChatCompletionsClient, "__new__", _stub_new)
    return AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "json_output": False,
            "function_calling": True,
            "vision": False,
            "family": "function_calling_model",
            "structured_output": False,
        },
        model="model",
    )
@pytest.mark.asyncio
async def test_thought_field_with_tool_calls_streaming(
    thought_with_tool_call_stream_client: AzureAIChatCompletionClient,
) -> None:
    """Verify the streaming path: text streamed before a tool call must end up
    in the ``thought`` field of the final ``CreateResult``."""
    received: List[Union[str, CreateResult]] = []
    async for item in thought_with_tool_call_stream_client.create_stream(
        messages=[UserMessage(content="Please call a function", source="user")],
        tools=[{"name": "test_tool"}],
    ):
        received.append(item)
    final = received[-1]
    assert isinstance(final, CreateResult)
    assert final.finish_reason == "function_calls"
    assert isinstance(final.content, list)
    call = final.content[0]
    assert isinstance(call, FunctionCall)
    assert call.name == "some_function"
    assert call.arguments == '{"foo": "bar"}'
    assert final.thought == "Let me think about what function to call."