mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-27 15:09:41 +00:00
Add support for thought field in AzureAIChatCompletionClient (#6062)
A previous change added support for preserving the thought process in tool calls for `OpenAIChatCompletionClient`, allowing additional text produced by a model alongside tool calls to be preserved in the `thought` field of `CreateResult`. This PR extends the same functionality to `AzureAIChatCompletionClient` for consistency across model clients. Closes #5650. Co-authored-by: Jay Prakash Thakur <jathakur@microsoft.com>
This commit is contained in:
parent
47ffaccba1
commit
7047fb8b8d
@ -408,9 +408,10 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
)
|
||||
|
||||
choice = result.choices[0]
|
||||
thought = None
|
||||
|
||||
if choice.finish_reason == CompletionsFinishReason.TOOL_CALLS:
|
||||
assert choice.message.tool_calls is not None
|
||||
|
||||
content: Union[str, List[FunctionCall]] = [
|
||||
FunctionCall(
|
||||
id=x.id,
|
||||
@ -420,6 +421,9 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
for x in choice.message.tool_calls
|
||||
]
|
||||
finish_reason = "function_calls"
|
||||
|
||||
if choice.message.content:
|
||||
thought = choice.message.content
|
||||
else:
|
||||
if isinstance(choice.finish_reason, CompletionsFinishReason):
|
||||
finish_reason = choice.finish_reason.value
|
||||
@ -429,8 +433,6 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
|
||||
if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
|
||||
thought, content = parse_r1_content(content)
|
||||
else:
|
||||
thought = None
|
||||
|
||||
response = CreateResult(
|
||||
finish_reason=finish_reason, # type: ignore
|
||||
@ -486,6 +488,8 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
chunk: Optional[StreamingChatCompletionsUpdate] = None
|
||||
choice: Optional[StreamingChatChoiceUpdate] = None
|
||||
first_chunk = True
|
||||
thought = None
|
||||
|
||||
async for chunk in await task: # type: ignore
|
||||
if first_chunk:
|
||||
first_chunk = False
|
||||
@ -545,6 +549,9 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
else:
|
||||
content = list(full_tool_calls.values())
|
||||
|
||||
if len(content_deltas) > 0:
|
||||
thought = "".join(content_deltas)
|
||||
|
||||
usage = RequestUsage(
|
||||
completion_tokens=completion_tokens,
|
||||
prompt_tokens=prompt_tokens,
|
||||
@ -552,8 +559,6 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
|
||||
if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
|
||||
thought, content = parse_r1_content(content)
|
||||
else:
|
||||
thought = None
|
||||
|
||||
result = CreateResult(
|
||||
finish_reason=finish_reason,
|
||||
|
||||
@ -2,7 +2,8 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, List
|
||||
from typing import Any, AsyncGenerator, List, Type, Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from autogen_core import CancellationToken, FunctionCall, Image
|
||||
@ -458,3 +459,167 @@ async def test_r1_content(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert isinstance(chunks[-1], CreateResult)
|
||||
assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
|
||||
assert chunks[-1].thought == "Thought"
|
||||
|
||||
|
||||
@pytest.fixture
def thought_with_tool_call_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient:
    """
    Build a client whose (mocked) completion carries both a tool call and
    accompanying assistant text, so the text should surface as the thought.
    """

    async def _fake_complete(*args: Any, **kwargs: Any) -> ChatCompletions:
        # Mimic a little network latency before the canned response arrives.
        await asyncio.sleep(0.01)
        # A message that pairs free text with a single tool call.
        message = ChatResponseMessage(
            role="assistant",
            content="Let me think about what function to call.",
            tool_calls=[
                ChatCompletionsToolCall(
                    id="tool_call_id",
                    function=AzureFunctionCall(name="some_function", arguments='{"foo": "bar"}'),
                )
            ],
        )
        return ChatCompletions(
            id="id",
            created=datetime.now(),
            model="model",
            choices=[
                ChatChoice(
                    index=0,
                    finish_reason=CompletionsFinishReason.TOOL_CALLS,
                    message=message,
                )
            ],
            usage=CompletionsUsage(prompt_tokens=8, completion_tokens=5, total_tokens=13),
        )

    # Route every completion call on the SDK client through the fake.
    monkeypatch.setattr(ChatCompletionsClient, "complete", _fake_complete)
    return AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "json_output": False,
            "function_calling": True,
            "vision": False,
            "family": "function_calling_model",
            "structured_output": False,
        },
        model="model",
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_thought_field_with_tool_calls(thought_with_tool_call_client: AzureAIChatCompletionClient) -> None:
    """
    When the model emits tool calls together with text content, the text must
    be preserved in the ``thought`` field of the ``CreateResult``.
    """
    result = await thought_with_tool_call_client.create(
        messages=[UserMessage(content="Please call a function", source="user")],
        tools=[{"name": "test_tool"}],
    )

    # The tool calls become the structured content of the result.
    assert result.finish_reason == "function_calls"
    assert isinstance(result.content, list)
    first_call = result.content[0]
    assert isinstance(first_call, FunctionCall)
    assert first_call.name == "some_function"
    assert first_call.arguments == '{"foo": "bar"}'

    # The accompanying text is carried over as the thought.
    assert result.thought == "Let me think about what function to call."
|
||||
|
||||
|
||||
@pytest.fixture
def thought_with_tool_call_stream_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient:
    """
    Returns a client that simulates a streaming response with both tool calls and thought content.
    """
    # First streamed chunk: a plain text delta (the eventual "thought") and no
    # finish reason yet.  NOTE(review): `delta.tool_calls` is left unset, so
    # MagicMock auto-creates it; presumably the client treats that as empty —
    # confirm against the create_stream implementation.
    first_choice = MagicMock()
    first_choice.delta = MagicMock()
    first_choice.delta.content = "Let me think about what function to call."
    first_choice.finish_reason = None

    # A single fake tool call.  Assigning `.name` after construction is safe:
    # the `name` kwarg is only special in the MagicMock constructor.
    mock_tool_call = MagicMock()
    mock_tool_call.id = "tool_call_id"
    mock_tool_call.function = MagicMock()
    mock_tool_call.function.name = "some_function"
    mock_tool_call.function.arguments = '{"foo": "bar"}'

    # Second streamed chunk: the tool-call delta, no text, and the terminal
    # finish reason.
    tool_call_choice = MagicMock()
    tool_call_choice.delta = MagicMock()
    tool_call_choice.delta.content = None
    tool_call_choice.delta.tool_calls = [mock_tool_call]
    tool_call_choice.finish_reason = "function_calls"

    async def _mock_thought_with_tool_call_stream(
        *args: Any, **kwargs: Any
    ) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]:
        # Emit the thought-text chunk first...
        yield StreamingChatCompletionsUpdate(
            id="id",
            choices=[first_choice],
            created=datetime.now(),
            model="model",
        )

        await asyncio.sleep(0.01)

        # ...then the tool-call chunk, which also carries the usage totals.
        yield StreamingChatCompletionsUpdate(
            id="id",
            choices=[tool_call_choice],
            created=datetime.now(),
            model="model",
            usage=CompletionsUsage(prompt_tokens=8, completion_tokens=5, total_tokens=13),
        )

    # Stand-in for the underlying azure-ai-inference client instance.
    mock_client = MagicMock()
    mock_client.close = MagicMock()

    async def mock_complete(*args: Any, **kwargs: Any) -> Any:
        # Only streaming requests get the fake generator; non-streaming calls
        # are not exercised by this fixture.
        if kwargs.get("stream", False):
            return _mock_thought_with_tool_call_stream(*args, **kwargs)
        return None

    mock_client.complete = mock_complete

    def mock_new(cls: Type[ChatCompletionsClient], *args: Any, **kwargs: Any) -> MagicMock:
        return mock_client

    # Intercept instantiation itself so any ChatCompletionsClient the model
    # client constructs is replaced with the mock.
    monkeypatch.setattr(ChatCompletionsClient, "__new__", mock_new)

    return AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "json_output": False,
            "function_calling": True,
            "vision": False,
            "family": "function_calling_model",
            "structured_output": False,
        },
        model="model",
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_thought_field_with_tool_calls_streaming(
    thought_with_tool_call_stream_client: AzureAIChatCompletionClient,
) -> None:
    """
    Streaming variant: text deltas that arrive alongside tool-call deltas must
    end up in the ``thought`` field of the final ``CreateResult``.
    """
    collected: List[Union[str, CreateResult]] = []
    stream = thought_with_tool_call_stream_client.create_stream(
        messages=[UserMessage(content="Please call a function", source="user")],
        tools=[{"name": "test_tool"}],
    )
    async for item in stream:
        collected.append(item)

    # The stream terminates with the aggregated CreateResult.
    final = collected[-1]
    assert isinstance(final, CreateResult)

    assert final.finish_reason == "function_calls"
    assert isinstance(final.content, list)
    call = final.content[0]
    assert isinstance(call, FunctionCall)
    assert call.name == "some_function"
    assert call.arguments == '{"foo": "bar"}'

    # The text delta from the first chunk is preserved as the thought.
    assert final.thought == "Let me think about what function to call."
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user