fix: update SK adapter stream tool call processing. (#5449)

## Why are these changes needed?

The current stream processing in the SK model adapter returns on the first
function call chunk, which is incorrect and results in an incomplete function
call. In practice, the function name and arguments arrive split across
separate chunks; this update accumulates those chunks and only yields the
function calls once the model signals that the tool calls are finished.
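
For illustration, the merging behaves roughly like the minimal sketch below. This is a standalone example, not the adapter's actual code; `ToolCallChunk` and `accumulate` are invented names for this sketch and are not part of the SK or autogen APIs.

```python
# Illustrative sketch only: accumulate streamed tool-call fragments keyed by call ID.
# ToolCallChunk and accumulate() are invented for this example.
from dataclasses import dataclass


@dataclass
class ToolCallChunk:
    id: str | None = None
    name: str | None = None
    arguments: str = ""


def accumulate(chunks: list[ToolCallChunk]) -> dict[str, ToolCallChunk]:
    calls: dict[str, ToolCallChunk] = {}
    last_id: str | None = None
    for chunk in chunks:
        if chunk.id:
            # A chunk carrying an ID starts (or continues) that call.
            last_id = chunk.id
            calls.setdefault(last_id, ToolCallChunk(id=last_id))
        if last_id is None:
            continue  # fragment with no call in progress; nothing to merge into
        call = calls[last_id]
        if chunk.name:
            call.name = chunk.name
        call.arguments += chunk.arguments
    return calls


# The name arrives in the first chunk; the JSON arguments arrive split across later chunks.
merged = accumulate(
    [
        ToolCallChunk(id="call_1", name="calculator", arguments=""),
        ToolCallChunk(arguments='{"a": 2, '),
        ToolCallChunk(arguments='"b": 2}'),
    ]
)
assert merged["call_1"].arguments == '{"a": 2, "b": 2}'
```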

## Related issue number

Fixes the issue reported in the replies to #5420

## Checks

- [ ] I've included any doc changes needed for
https://microsoft.github.io/autogen/. See
https://microsoft.github.io/autogen/docs/Contribute#documentation to
build and test documentation locally.
- [ ] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [ ] I've made sure all auto checks have passed.

---------

Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
Commit b868e32b05 (parent b5eaab8501), authored by Leonardo Pinheiro on 2025-02-09 14:39:19 +10:00, committed by GitHub.
2 changed files with 397 additions and 44 deletions

View File

@@ -1,5 +1,6 @@
import json
from typing import Any, Literal, Mapping, Optional, Sequence
import warnings
from autogen_core import FunctionCall
from autogen_core._cancellation_token import CancellationToken
@@ -18,7 +19,6 @@ from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecut
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.functions.kernel_plugin import KernelPlugin
from semantic_kernel.kernel import Kernel
from typing_extensions import AsyncGenerator, Union
@@ -427,6 +427,28 @@ class SKChatCompletionAdapter(ChatCompletionClient):
                thought=thought,
            )

    @staticmethod
    def _merge_function_call_content(existing_call: FunctionCallContent, new_chunk: FunctionCallContent) -> None:
        """Helper to merge partial argument chunks from new_chunk into existing_call."""
        if isinstance(existing_call.arguments, str) and isinstance(new_chunk.arguments, str):
            existing_call.arguments += new_chunk.arguments
        elif isinstance(existing_call.arguments, dict) and isinstance(new_chunk.arguments, dict):
            existing_call.arguments.update(new_chunk.arguments)
        elif not existing_call.arguments or existing_call.arguments in ("{}", ""):
            # If existing had no arguments yet, just take the new one
            existing_call.arguments = new_chunk.arguments
        else:
            # If there's a mismatch (str vs dict), handle as needed
            warnings.warn("Mismatch in argument types during merge. Existing arguments retained.", stacklevel=2)

        # Optionally update name/function_name if newly provided
        if new_chunk.name:
            existing_call.name = new_chunk.name
        if new_chunk.plugin_name:
            existing_call.plugin_name = new_chunk.plugin_name
        if new_chunk.function_name:
            existing_call.function_name = new_chunk.function_name
    async def create_stream(
        self,
        messages: Sequence[LLMMessage],
@@ -460,6 +482,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
        Yields:
            Union[str, CreateResult]: Either a string chunk of the response or a CreateResult containing function calls.
        """
        kernel = self._get_kernel(extra_create_args)
        chat_history = self._convert_to_chat_history(messages)
        user_settings = self._get_prompt_settings(extra_create_args)
@@ -468,54 +491,105 @@ class SKChatCompletionAdapter(ChatCompletionClient):
        prompt_tokens = 0
        completion_tokens = 0
        accumulated_content = ""
        accumulated_text = ""
        # Keep track of in-progress function calls. Keyed by ID
        # because partial chunks for the same function call might arrive separately.
        function_calls_in_progress: dict[str, FunctionCallContent] = {}
        # Track the ID of the last function call we saw so we can continue
        # accumulating chunk arguments for that call if new items have id=None
        last_function_call_id: Optional[str] = None

        async for streaming_messages in self._sk_client.get_streaming_chat_message_contents(
            chat_history, settings=settings, kernel=kernel
        ):
            for msg in streaming_messages:
                if not isinstance(msg, StreamingChatMessageContent):
                    continue
                # Track token usage
                if msg.metadata and "usage" in msg.metadata:
                    usage = msg.metadata["usage"]
                    prompt_tokens = getattr(usage, "prompt_tokens", 0)
                    completion_tokens = getattr(usage, "completion_tokens", 0)
                # Check for function calls
                if any(isinstance(item, FunctionCallContent) for item in msg.items):
                    function_calls = self._process_tool_calls(msg)
                # Process function call deltas
                for item in msg.items:
                    if isinstance(item, FunctionCallContent):
                        # If the chunk has a valid ID, we start or continue that ID explicitly
                        if item.id:
                            last_function_call_id = item.id
                            if last_function_call_id not in function_calls_in_progress:
                                function_calls_in_progress[last_function_call_id] = item
                            else:
                                # Merge partial arguments into existing call
                                existing_call = function_calls_in_progress[last_function_call_id]
                                self._merge_function_call_content(existing_call, item)
                        else:
                            # item.id is None, so we assume it belongs to the last known ID
                            if not last_function_call_id:
                                # No call in progress means we can't merge
                                # You could either skip or raise an error here
                                warnings.warn(
                                    "Received function call chunk with no ID and no call in progress.", stacklevel=2
                                )
                                continue
                            existing_call = function_calls_in_progress[last_function_call_id]
                            # Merge partial chunk
                            self._merge_function_call_content(existing_call, item)
                # Check if the model signaled tool_calls finished
                if msg.finish_reason == "tool_calls" and function_calls_in_progress:
                    calls_to_yield: list[FunctionCall] = []
                    for _, call_content in function_calls_in_progress.items():
                        plugin_name = call_content.plugin_name or ""
                        function_name = call_content.function_name
                        if plugin_name:
                            full_name = f"{plugin_name}-{function_name}"
                        else:
                            full_name = function_name
                        if isinstance(call_content.arguments, dict):
                            arguments = json.dumps(call_content.arguments)
                        else:
                            assert isinstance(call_content.arguments, str)
                            arguments = call_content.arguments or "{}"
                        calls_to_yield.append(
                            FunctionCall(
                                id=call_content.id or "unknown_id",
                                name=full_name,
                                arguments=arguments,
                            )
                        )
                    # Yield all function calls in progress
                    yield CreateResult(
                        content=function_calls,
                        content=calls_to_yield,
                        finish_reason="function_calls",
                        usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
                        cached=False,
                    )
                    return
                # Handle text content
                # Handle any plain text in the message
                if msg.content:
                    accumulated_content += msg.content
                    accumulated_text += msg.content
                    yield msg.content

        # Final yield if there was text content
        if accumulated_content:
            self._total_prompt_tokens += prompt_tokens
            self._total_completion_tokens += completion_tokens
        # If we exit the loop without tool calls finishing, yield whatever text was accumulated
        self._total_prompt_tokens += prompt_tokens
        self._total_completion_tokens += completion_tokens
        if isinstance(accumulated_content, str) and self._model_info["family"] == ModelFamily.R1:
            thought, accumulated_content = parse_r1_content(accumulated_content)
        else:
            thought = None
        thought = None
        if isinstance(accumulated_text, str) and self._model_info["family"] == ModelFamily.R1:
            thought, accumulated_text = parse_r1_content(accumulated_text)
        yield CreateResult(
            content=accumulated_content,
            finish_reason="stop",
            usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
            cached=False,
            thought=thought,
        )
        yield CreateResult(
            content=accumulated_text,
            finish_reason="stop",
            usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
            cached=False,
            thought=thought,
        )

    def actual_usage(self) -> RequestUsage:
        return RequestUsage(prompt_tokens=self._total_prompt_tokens, completion_tokens=self._total_completion_tokens)

View File

@@ -7,7 +7,13 @@ from autogen_core import CancellationToken
from autogen_core.models import CreateResult, LLMMessage, ModelFamily, ModelInfo, SystemMessage, UserMessage
from autogen_core.tools import BaseTool
from autogen_ext.models.semantic_kernel import SKChatCompletionAdapter
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta
from openai.types.chat.chat_completion_chunk import (
    ChatCompletionChunk,
    Choice,
    ChoiceDelta,
    ChoiceDeltaToolCall,
    ChoiceDeltaToolCallFunction,
)
from openai.types.completion_usage import CompletionUsage
from pydantic import BaseModel
from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import AzureChatCompletion
@@ -72,7 +78,7 @@ def sk_client() -> AzureChatCompletion:
                            id="call_UwVVI0iGEmcPwmKUigJcuuuF",
                            function_name="calculator",
                            plugin_name=None,
                            arguments="{}",
                            arguments='{"a": 2, "b": 2}',
                        )
                    ],
                    finish_reason=FinishReason.TOOL_CALLS,
@@ -96,30 +102,89 @@
        **kwargs: Any,
    ) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]:
        if "What is 2 + 2?" in str(chat_history):
            # Mock response for calculator tool test - single message with function call
            # Initial chunk with function call setup
            yield [
                StreamingChatMessageContent(
                    choice_index=0,
                    inner_content=None,
                    inner_content=ChatCompletionChunk(
                        id="chatcmpl-123",
                        choices=[
                            Choice(
                                delta=ChoiceDelta(
                                    role="assistant",
                                    tool_calls=[
                                        ChoiceDeltaToolCall(
                                            index=0,
                                            id="call_UwVVI0iGEmcPwmKUigJcuuuF",
                                            function=ChoiceDeltaToolCallFunction(name="calculator", arguments=""),
                                            type="function",
                                        )
                                    ],
                                ),
                                finish_reason=None,
                                index=0,
                            )
                        ],
                        created=1736673679,
                        model="gpt-4o-mini",
                        object="chat.completion.chunk",
                    ),
                    ai_model_id="gpt-4o-mini",
                    metadata={
                        "logprobs": None,
                        "id": "chatcmpl-AooRjGxKtdTke46keWkBQBKg033XW",
                        "created": 1736673679,
                        "usage": {"prompt_tokens": 53, "completion_tokens": 13},
                    },
                    role=AuthorRole.ASSISTANT,
                    items=[ # type: ignore
                    items=[
                        FunctionCallContent(
                            id="call_n8135GXc2kbiaaDdpImsB1VW",
                            function_name="calculator",
                            plugin_name=None,
                            arguments="",
                            content_type="function_call", # type: ignore
                            id="call_UwVVI0iGEmcPwmKUigJcuuuF", function_name="calculator", arguments=""
                        )
                    ],
                    finish_reason=None,
                    function_invoke_attempt=0,
                )
            ]
            # Arguments chunks
            for arg_chunk in ["{", '"a"', ":", " ", "2", ",", " ", '"b"', ":", " ", "2", "}"]:
                yield [
                    StreamingChatMessageContent(
                        choice_index=0,
                        inner_content=ChatCompletionChunk(
                            id="chatcmpl-123",
                            choices=[
                                Choice(
                                    delta=ChoiceDelta(
                                        tool_calls=[
                                            ChoiceDeltaToolCall(
                                                index=0, function=ChoiceDeltaToolCallFunction(arguments=arg_chunk)
                                            )
                                        ]
                                    ),
                                    finish_reason=None,
                                    index=0,
                                )
                            ],
                            created=1736673679,
                            model="gpt-4o-mini",
                            object="chat.completion.chunk",
                        ),
                        ai_model_id="gpt-4o-mini",
                        role=AuthorRole.ASSISTANT,
                        items=[FunctionCallContent(function_name="calculator", arguments=arg_chunk)],
                    )
                ]
            # Final chunk with finish reason
            yield [
                StreamingChatMessageContent( # type: ignore
                    choice_index=0,
                    inner_content=ChatCompletionChunk(
                        id="chatcmpl-123",
                        choices=[Choice(delta=ChoiceDelta(), finish_reason="tool_calls", index=0)],
                        created=1736673679,
                        model="gpt-4o-mini",
                        object="chat.completion.chunk",
                        usage=CompletionUsage(prompt_tokens=53, completion_tokens=13, total_tokens=66),
                    ),
                    ai_model_id="gpt-4o-mini",
                    role=AuthorRole.ASSISTANT,
                    finish_reason=FinishReason.TOOL_CALLS,
                    metadata={"usage": {"prompt_tokens": 53, "completion_tokens": 13}},
                )
            ]
        else:
@@ -449,3 +514,217 @@ async def test_sk_chat_completion_r1_content() -> None:
    assert response_chunks[-1].finish_reason == "stop"
    assert response_chunks[-1].content == "Hello!"
    assert response_chunks[-1].thought == "Reasoning..."


@pytest.mark.asyncio
async def test_sk_chat_completion_stream_with_multiple_function_calls() -> None:
    """
    This test returns two distinct function calls via streaming, each one arriving in pieces.
    We intentionally set name, plugin_name, and function_name in the later partial chunks so
    that _merge_function_call_content is triggered to update them.
    """

    async def mock_get_streaming_chat_message_contents(
        chat_history: ChatHistory,
        settings: PromptExecutionSettings,
        **kwargs: Any,
    ) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]:
        # First partial chunk for call_1
        yield [
            StreamingChatMessageContent(
                choice_index=0,
                inner_content=ChatCompletionChunk(
                    id="chunk-id-1",
                    choices=[
                        Choice(
                            delta=ChoiceDelta(
                                role="assistant",
                                tool_calls=[
                                    ChoiceDeltaToolCall(
                                        index=0,
                                        id="call_1",
                                        function=ChoiceDeltaToolCallFunction(name=None, arguments='{"arg1":'),
                                        type="function",
                                    )
                                ],
                            ),
                            finish_reason=None,
                            index=0,
                        )
                    ],
                    created=1736679999,
                    model="gpt-4o-mini",
                    object="chat.completion.chunk",
                ),
                ai_model_id="gpt-4o-mini",
                role=AuthorRole.ASSISTANT,
                items=[
                    FunctionCallContent(
                        id="call_1",
                        # no plugin_name/function_name yet
                        name=None,
                        arguments='{"arg1":',
                    )
                ],
            )
        ]

        # Second partial chunk for call_1 (updates plugin_name/function_name)
        yield [
            StreamingChatMessageContent(
                choice_index=0,
                inner_content=ChatCompletionChunk(
                    id="chunk-id-2",
                    choices=[
                        Choice(
                            delta=ChoiceDelta(
                                tool_calls=[
                                    ChoiceDeltaToolCall(
                                        index=0,
                                        function=ChoiceDeltaToolCallFunction(
                                            # Provide the rest of the arguments
                                            arguments='"value1"}',
                                            name="firstFunction",
                                        ),
                                    )
                                ]
                            ),
                            finish_reason=None,
                            index=0,
                        )
                    ],
                    created=1736679999,
                    model="gpt-4o-mini",
                    object="chat.completion.chunk",
                ),
                ai_model_id="gpt-4o-mini",
                role=AuthorRole.ASSISTANT,
                items=[
                    FunctionCallContent(
                        id="call_1", plugin_name="myPlugin", function_name="firstFunction", arguments='"value1"}'
                    )
                ],
            )
        ]

        # Now partial chunk for a second call, call_2
        yield [
            StreamingChatMessageContent(
                choice_index=0,
                inner_content=ChatCompletionChunk(
                    id="chunk-id-3",
                    choices=[
                        Choice(
                            delta=ChoiceDelta(
                                tool_calls=[
                                    ChoiceDeltaToolCall(
                                        index=0,
                                        id="call_2",
                                        function=ChoiceDeltaToolCallFunction(name=None, arguments='{"arg2":"another"}'),
                                        type="function",
                                    )
                                ],
                            ),
                            finish_reason=None,
                            index=0,
                        )
                    ],
                    created=1736679999,
                    model="gpt-4o-mini",
                    object="chat.completion.chunk",
                ),
                ai_model_id="gpt-4o-mini",
                role=AuthorRole.ASSISTANT,
                items=[FunctionCallContent(id="call_2", arguments='{"arg2":"another"}')],
            )
        ]

        # Next partial chunk updates name, plugin_name, function_name for call_2
        yield [
            StreamingChatMessageContent(
                choice_index=0,
                inner_content=ChatCompletionChunk(
                    id="chunk-id-4",
                    choices=[
                        Choice(
                            delta=ChoiceDelta(
                                tool_calls=[
                                    ChoiceDeltaToolCall(
                                        index=0, function=ChoiceDeltaToolCallFunction(name="secondFunction")
                                    )
                                ]
                            ),
                            finish_reason=None,
                            index=0,
                        )
                    ],
                    created=1736679999,
                    model="gpt-4o-mini",
                    object="chat.completion.chunk",
                ),
                ai_model_id="gpt-4o-mini",
                role=AuthorRole.ASSISTANT,
                items=[
                    FunctionCallContent(
                        id="call_2",
                        name="someFancyName",
                        plugin_name="anotherPlugin",
                        function_name="secondFunction",
                    )
                ],
            )
        ]

        # Final chunk signals finish with tool_calls
        yield [
            StreamingChatMessageContent( # type: ignore
                choice_index=0,
                inner_content=ChatCompletionChunk(
                    id="chunk-id-5",
                    choices=[Choice(delta=ChoiceDelta(), finish_reason="tool_calls", index=0)],
                    created=1736679999,
                    model="gpt-4o-mini",
                    object="chat.completion.chunk",
                    usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
                ),
                ai_model_id="gpt-4o-mini",
                role=AuthorRole.ASSISTANT,
                finish_reason=FinishReason.TOOL_CALLS,
                metadata={"usage": {"prompt_tokens": 10, "completion_tokens": 5}},
            )
        ]

    # Mock SK client
    mock_client = AsyncMock(spec=AzureChatCompletion)
    mock_client.get_streaming_chat_message_contents = mock_get_streaming_chat_message_contents

    # Create adapter and kernel
    kernel = Kernel(memory=NullMemory())
    adapter = SKChatCompletionAdapter(mock_client, kernel=kernel)

    # Call create_stream with no actual tools (we just test the multiple calls)
    messages: list[LLMMessage] = [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="Call two different plugin functions", source="user"),
    ]

    # Collect streaming outputs
    response_chunks: list[CreateResult | str] = []
    async for chunk in adapter.create_stream(messages=messages):
        response_chunks.append(chunk)

    # The final chunk should be a CreateResult with function_calls
    assert len(response_chunks) > 0
    final_chunk = response_chunks[-1]
    assert isinstance(final_chunk, CreateResult)
    assert final_chunk.finish_reason == "function_calls"
    assert isinstance(final_chunk.content, list)
    assert len(final_chunk.content) == 2 # We expect 2 calls

    # Verify first call merged name + arguments
    first_call = final_chunk.content[0]
    assert first_call.id == "call_1"
    assert first_call.name == "myPlugin-firstFunction" # pluginName-functionName
    assert '{"arg1":"value1"}' in first_call.arguments

    # Verify second call also merged everything
    second_call = final_chunk.content[1]
    assert second_call.id == "call_2"
    assert second_call.name == "anotherPlugin-secondFunction"
    assert '{"arg2":"another"}' in second_call.arguments