fix: update SK adapter stream tool call processing. (#5449)
## Why are these changes needed?

The current stream processing in the SK model adapter returns on the first function call chunk, which is incorrect and ends up returning an incomplete function call. In practice, the function name and arguments arrive split across several chunks; this update accumulates and merges those chunks so the complete call is yielded.

## Related issue number

Fixes the reply in #5420

## Checks

- [ ] I've included any doc changes needed for https://microsoft.github.io/autogen/. See https://microsoft.github.io/autogen/docs/Contribute#documentation to build and test documentation locally.
- [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR.
- [ ] I've made sure all auto checks have passed.

---------

Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
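For context on the technique the fix relies on: instead of returning on the first tool-call chunk, the adapter must accumulate deltas until the model signals completion. Below is a minimal, hypothetical sketch (the `ToolCallDelta` type and `accumulate` helper are illustrative stand-ins, not the adapter's real classes), assuming an OpenAI-style stream where the first chunk carries the call id and name and later chunks carry argument fragments:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToolCallDelta:
    """Hypothetical streamed tool-call chunk; most fields are None/empty per chunk."""

    id: Optional[str] = None
    name: Optional[str] = None
    arguments: str = ""


def accumulate(chunks: list[ToolCallDelta]) -> dict[str, ToolCallDelta]:
    """Merge partial chunks into complete calls, keyed by call id."""
    calls: dict[str, ToolCallDelta] = {}
    last_id: Optional[str] = None
    for chunk in chunks:
        if chunk.id:
            last_id = chunk.id  # a chunk with an id starts (or continues) that call
        if last_id is None:
            continue  # argument fragment arrived before any call id; skip it
        call = calls.setdefault(last_id, ToolCallDelta(id=last_id))
        call.name = chunk.name or call.name
        call.arguments += chunk.arguments  # concatenate argument fragments
    return calls


# Returning on the first chunk would yield name="calculator" with empty arguments;
# accumulating across all chunks recovers the complete call:
merged = accumulate(
    [
        ToolCallDelta(id="call_1", name="calculator"),
        ToolCallDelta(arguments='{"a": 2,'),
        ToolCallDelta(arguments=' "b": 2}'),
    ]
)
assert merged["call_1"].arguments == '{"a": 2, "b": 2}'
```

The actual change in the diff below follows this pattern: it keys in-progress calls by id and additionally attributes chunks whose id is None to the last call seen.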
parent b5eaab8501
commit b868e32b05
@@ -1,5 +1,6 @@
+import json
 from typing import Any, Literal, Mapping, Optional, Sequence
 import warnings
 
 from autogen_core import FunctionCall
 from autogen_core._cancellation_token import CancellationToken
@@ -18,7 +19,6 @@ from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
 from semantic_kernel.contents.chat_history import ChatHistory
 from semantic_kernel.contents.chat_message_content import ChatMessageContent
 from semantic_kernel.contents.function_call_content import FunctionCallContent
 from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
 from semantic_kernel.functions.kernel_plugin import KernelPlugin
 from semantic_kernel.kernel import Kernel
 from typing_extensions import AsyncGenerator, Union
@@ -427,6 +427,28 @@ class SKChatCompletionAdapter(ChatCompletionClient):
             thought=thought,
         )
 
+    @staticmethod
+    def _merge_function_call_content(existing_call: FunctionCallContent, new_chunk: FunctionCallContent) -> None:
+        """Helper to merge partial argument chunks from new_chunk into existing_call."""
+        if isinstance(existing_call.arguments, str) and isinstance(new_chunk.arguments, str):
+            existing_call.arguments += new_chunk.arguments
+        elif isinstance(existing_call.arguments, dict) and isinstance(new_chunk.arguments, dict):
+            existing_call.arguments.update(new_chunk.arguments)
+        elif not existing_call.arguments or existing_call.arguments in ("{}", ""):
+            # If existing had no arguments yet, just take the new one
+            existing_call.arguments = new_chunk.arguments
+        else:
+            # If there's a mismatch (str vs dict), handle as needed
+            warnings.warn("Mismatch in argument types during merge. Existing arguments retained.", stacklevel=2)
+
+        # Optionally update name/function_name if newly provided
+        if new_chunk.name:
+            existing_call.name = new_chunk.name
+        if new_chunk.plugin_name:
+            existing_call.plugin_name = new_chunk.plugin_name
+        if new_chunk.function_name:
+            existing_call.function_name = new_chunk.function_name
+
     async def create_stream(
         self,
         messages: Sequence[LLMMessage],
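SK's `FunctionCallContent.arguments` can be either a JSON-fragment string or a dict, so `_merge_function_call_content` above has to branch on both representations. Here is a standalone sketch of the same merge rules, using a hypothetical stand-in class rather than SK's actual type:

```python
import warnings
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class FakeCallContent:
    """Hypothetical stand-in for semantic_kernel's FunctionCallContent."""

    id: Optional[str] = None
    plugin_name: Optional[str] = None
    function_name: Optional[str] = None
    arguments: Union[str, dict, None] = ""


def merge(existing: FakeCallContent, chunk: FakeCallContent) -> None:
    # str + str: concatenate JSON fragments; dict + dict: shallow-merge keys
    if isinstance(existing.arguments, str) and isinstance(chunk.arguments, str):
        existing.arguments += chunk.arguments
    elif isinstance(existing.arguments, dict) and isinstance(chunk.arguments, dict):
        existing.arguments.update(chunk.arguments)
    elif not existing.arguments or existing.arguments in ("{}", ""):
        existing.arguments = chunk.arguments  # nothing accumulated yet; adopt new
    else:
        warnings.warn("Mismatched argument types; keeping existing.", stacklevel=2)
    # Names may arrive only in later chunks, so fill them in as they appear
    existing.function_name = chunk.function_name or existing.function_name
    existing.plugin_name = chunk.plugin_name or existing.plugin_name


call = FakeCallContent(id="call_1", arguments='{"a": ')
merge(call, FakeCallContent(arguments="2}", function_name="calculator"))
assert call.arguments == '{"a": 2}' and call.function_name == "calculator"
```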
@@ -460,6 +482,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
         Yields:
             Union[str, CreateResult]: Either a string chunk of the response or a CreateResult containing function calls.
         """
 
         kernel = self._get_kernel(extra_create_args)
         chat_history = self._convert_to_chat_history(messages)
         user_settings = self._get_prompt_settings(extra_create_args)
@@ -468,54 +491,105 @@ class SKChatCompletionAdapter(ChatCompletionClient):
 
         prompt_tokens = 0
         completion_tokens = 0
-        accumulated_content = ""
+        accumulated_text = ""
+
+        # Keep track of in-progress function calls. Keyed by ID
+        # because partial chunks for the same function call might arrive separately.
+        function_calls_in_progress: dict[str, FunctionCallContent] = {}
+
+        # Track the ID of the last function call we saw so we can continue
+        # accumulating chunk arguments for that call if new items have id=None
+        last_function_call_id: Optional[str] = None
 
         async for streaming_messages in self._sk_client.get_streaming_chat_message_contents(
             chat_history, settings=settings, kernel=kernel
         ):
             for msg in streaming_messages:
                 if not isinstance(msg, StreamingChatMessageContent):
                     continue
 
                 # Track token usage
                 if msg.metadata and "usage" in msg.metadata:
                     usage = msg.metadata["usage"]
                     prompt_tokens = getattr(usage, "prompt_tokens", 0)
                     completion_tokens = getattr(usage, "completion_tokens", 0)
 
                 # Check for function calls
                 if any(isinstance(item, FunctionCallContent) for item in msg.items):
-                    function_calls = self._process_tool_calls(msg)
+                    # Process function call deltas
+                    for item in msg.items:
+                        if isinstance(item, FunctionCallContent):
+                            # If the chunk has a valid ID, we start or continue that ID explicitly
+                            if item.id:
+                                last_function_call_id = item.id
+                                if last_function_call_id not in function_calls_in_progress:
+                                    function_calls_in_progress[last_function_call_id] = item
+                                else:
+                                    # Merge partial arguments into existing call
+                                    existing_call = function_calls_in_progress[last_function_call_id]
+                                    self._merge_function_call_content(existing_call, item)
+                            else:
+                                # item.id is None, so we assume it belongs to the last known ID
+                                if not last_function_call_id:
+                                    # No call in progress means we can't merge
+                                    # You could either skip or raise an error here
+                                    warnings.warn(
+                                        "Received function call chunk with no ID and no call in progress.", stacklevel=2
+                                    )
+                                    continue
+
+                                existing_call = function_calls_in_progress[last_function_call_id]
+                                # Merge partial chunk
+                                self._merge_function_call_content(existing_call, item)
+
+                # Check if the model signaled tool_calls finished
+                if msg.finish_reason == "tool_calls" and function_calls_in_progress:
+                    calls_to_yield: list[FunctionCall] = []
+                    for _, call_content in function_calls_in_progress.items():
+                        plugin_name = call_content.plugin_name or ""
+                        function_name = call_content.function_name
+                        if plugin_name:
+                            full_name = f"{plugin_name}-{function_name}"
+                        else:
+                            full_name = function_name
+
+                        if isinstance(call_content.arguments, dict):
+                            arguments = json.dumps(call_content.arguments)
+                        else:
+                            assert isinstance(call_content.arguments, str)
+                            arguments = call_content.arguments or "{}"
+
+                        calls_to_yield.append(
+                            FunctionCall(
+                                id=call_content.id or "unknown_id",
+                                name=full_name,
+                                arguments=arguments,
+                            )
+                        )
+                    # Yield all function calls in progress
                     yield CreateResult(
-                        content=function_calls,
+                        content=calls_to_yield,
                         finish_reason="function_calls",
                         usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
                         cached=False,
                     )
                     return
 
-                # Handle text content
+                # Handle any plain text in the message
                 if msg.content:
-                    accumulated_content += msg.content
+                    accumulated_text += msg.content
                     yield msg.content
 
-        # Final yield if there was text content
-        if accumulated_content:
-            self._total_prompt_tokens += prompt_tokens
-            self._total_completion_tokens += completion_tokens
+        # If we exit the loop without tool calls finishing, yield whatever text was accumulated
+        self._total_prompt_tokens += prompt_tokens
+        self._total_completion_tokens += completion_tokens
 
-            if isinstance(accumulated_content, str) and self._model_info["family"] == ModelFamily.R1:
-                thought, accumulated_content = parse_r1_content(accumulated_content)
-            else:
-                thought = None
+        thought = None
+        if isinstance(accumulated_text, str) and self._model_info["family"] == ModelFamily.R1:
+            thought, accumulated_text = parse_r1_content(accumulated_text)
 
-            yield CreateResult(
-                content=accumulated_content,
-                finish_reason="stop",
-                usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
-                cached=False,
-                thought=thought,
-            )
+        yield CreateResult(
+            content=accumulated_text,
+            finish_reason="stop",
+            usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
+            cached=False,
+            thought=thought,
+        )
 
     def actual_usage(self) -> RequestUsage:
         return RequestUsage(prompt_tokens=self._total_prompt_tokens, completion_tokens=self._total_completion_tokens)
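Two conventions in the finalization step above are worth noting: the SK `plugin_name`/`function_name` pair is flattened into a single dash-separated `FunctionCall.name`, and accumulated arguments are normalized to a JSON string. A small sketch of just those conversions (the helper names here are illustrative, not part of the adapter):

```python
import json
from typing import Optional, Union


def to_full_name(plugin_name: Optional[str], function_name: str) -> str:
    """Flatten SK's plugin/function pair into a single dashed name."""
    return f"{plugin_name}-{function_name}" if plugin_name else function_name


def to_arguments_json(arguments: Union[str, dict, None]) -> str:
    """Normalize accumulated arguments to a JSON string, defaulting to '{}'."""
    if isinstance(arguments, dict):
        return json.dumps(arguments)
    return arguments or "{}"


assert to_full_name("myPlugin", "firstFunction") == "myPlugin-firstFunction"
assert to_full_name(None, "calculator") == "calculator"
assert to_arguments_json({"a": 2, "b": 2}) == '{"a": 2, "b": 2}'
assert to_arguments_json("") == "{}"
```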
@@ -7,7 +7,13 @@ from autogen_core import CancellationToken
 from autogen_core.models import CreateResult, LLMMessage, ModelFamily, ModelInfo, SystemMessage, UserMessage
 from autogen_core.tools import BaseTool
 from autogen_ext.models.semantic_kernel import SKChatCompletionAdapter
-from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta
+from openai.types.chat.chat_completion_chunk import (
+    ChatCompletionChunk,
+    Choice,
+    ChoiceDelta,
+    ChoiceDeltaToolCall,
+    ChoiceDeltaToolCallFunction,
+)
 from openai.types.completion_usage import CompletionUsage
 from pydantic import BaseModel
 from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import AzureChatCompletion
@@ -72,7 +78,7 @@ def sk_client() -> AzureChatCompletion:
                 id="call_UwVVI0iGEmcPwmKUigJcuuuF",
                 function_name="calculator",
                 plugin_name=None,
-                arguments="{}",
+                arguments='{"a": 2, "b": 2}',
             )
         ],
         finish_reason=FinishReason.TOOL_CALLS,
@@ -96,30 +102,89 @@ def sk_client() -> AzureChatCompletion:
         **kwargs: Any,
     ) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]:
         if "What is 2 + 2?" in str(chat_history):
-            # Mock response for calculator tool test - single message with function call
+            # Initial chunk with function call setup
             yield [
                 StreamingChatMessageContent(
                     choice_index=0,
-                    inner_content=None,
+                    inner_content=ChatCompletionChunk(
+                        id="chatcmpl-123",
+                        choices=[
+                            Choice(
+                                delta=ChoiceDelta(
+                                    role="assistant",
+                                    tool_calls=[
+                                        ChoiceDeltaToolCall(
+                                            index=0,
+                                            id="call_UwVVI0iGEmcPwmKUigJcuuuF",
+                                            function=ChoiceDeltaToolCallFunction(name="calculator", arguments=""),
+                                            type="function",
+                                        )
+                                    ],
+                                ),
+                                finish_reason=None,
+                                index=0,
+                            )
+                        ],
+                        created=1736673679,
+                        model="gpt-4o-mini",
+                        object="chat.completion.chunk",
+                    ),
                     ai_model_id="gpt-4o-mini",
                     metadata={
                         "logprobs": None,
                         "id": "chatcmpl-AooRjGxKtdTke46keWkBQBKg033XW",
                         "created": 1736673679,
                         "usage": {"prompt_tokens": 53, "completion_tokens": 13},
                     },
                     role=AuthorRole.ASSISTANT,
-                    items=[  # type: ignore
+                    items=[
                         FunctionCallContent(
-                            id="call_n8135GXc2kbiaaDdpImsB1VW",
-                            function_name="calculator",
-                            plugin_name=None,
-                            arguments="",
-                            content_type="function_call",  # type: ignore
+                            id="call_UwVVI0iGEmcPwmKUigJcuuuF", function_name="calculator", arguments=""
                         )
                     ],
                     finish_reason=None,
                     function_invoke_attempt=0,
                 )
             ]
+
+            # Arguments chunks
+            for arg_chunk in ["{", '"a"', ":", " ", "2", ",", " ", '"b"', ":", " ", "2", "}"]:
+                yield [
+                    StreamingChatMessageContent(
+                        choice_index=0,
+                        inner_content=ChatCompletionChunk(
+                            id="chatcmpl-123",
+                            choices=[
+                                Choice(
+                                    delta=ChoiceDelta(
+                                        tool_calls=[
+                                            ChoiceDeltaToolCall(
+                                                index=0, function=ChoiceDeltaToolCallFunction(arguments=arg_chunk)
+                                            )
+                                        ]
+                                    ),
+                                    finish_reason=None,
+                                    index=0,
+                                )
+                            ],
+                            created=1736673679,
+                            model="gpt-4o-mini",
+                            object="chat.completion.chunk",
+                        ),
+                        ai_model_id="gpt-4o-mini",
+                        role=AuthorRole.ASSISTANT,
+                        items=[FunctionCallContent(function_name="calculator", arguments=arg_chunk)],
+                    )
+                ]
+
+            # Final chunk with finish reason
+            yield [
+                StreamingChatMessageContent(  # type: ignore
+                    choice_index=0,
+                    inner_content=ChatCompletionChunk(
+                        id="chatcmpl-123",
+                        choices=[Choice(delta=ChoiceDelta(), finish_reason="tool_calls", index=0)],
+                        created=1736673679,
+                        model="gpt-4o-mini",
+                        object="chat.completion.chunk",
+                        usage=CompletionUsage(prompt_tokens=53, completion_tokens=13, total_tokens=66),
+                    ),
+                    ai_model_id="gpt-4o-mini",
+                    role=AuthorRole.ASSISTANT,
+                    finish_reason=FinishReason.TOOL_CALLS,
+                    metadata={"usage": {"prompt_tokens": 53, "completion_tokens": 13}},
+                )
+            ]
         else:
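Joined in order, the mocked argument fragments reconstruct exactly the JSON that the fixture's non-streaming response carries in one piece, which is what the merge logic must reassemble:

```python
# Concatenating the mocked fragments yields the complete JSON arguments string:
fragments = ["{", '"a"', ":", " ", "2", ",", " ", '"b"', ":", " ", "2", "}"]
assert "".join(fragments) == '{"a": 2, "b": 2}'
```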
@@ -449,3 +514,217 @@ async def test_sk_chat_completion_r1_content() -> None:
     assert response_chunks[-1].finish_reason == "stop"
     assert response_chunks[-1].content == "Hello!"
     assert response_chunks[-1].thought == "Reasoning..."
+
+
+@pytest.mark.asyncio
+async def test_sk_chat_completion_stream_with_multiple_function_calls() -> None:
+    """
+    This test returns two distinct function calls via streaming, each one arriving in pieces.
+    We intentionally set name, plugin_name, and function_name in the later partial chunks so
+    that _merge_function_call_content is triggered to update them.
+    """
+
+    async def mock_get_streaming_chat_message_contents(
+        chat_history: ChatHistory,
+        settings: PromptExecutionSettings,
+        **kwargs: Any,
+    ) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]:
+        # First partial chunk for call_1
+        yield [
+            StreamingChatMessageContent(
+                choice_index=0,
+                inner_content=ChatCompletionChunk(
+                    id="chunk-id-1",
+                    choices=[
+                        Choice(
+                            delta=ChoiceDelta(
+                                role="assistant",
+                                tool_calls=[
+                                    ChoiceDeltaToolCall(
+                                        index=0,
+                                        id="call_1",
+                                        function=ChoiceDeltaToolCallFunction(name=None, arguments='{"arg1":'),
+                                        type="function",
+                                    )
+                                ],
+                            ),
+                            finish_reason=None,
+                            index=0,
+                        )
+                    ],
+                    created=1736679999,
+                    model="gpt-4o-mini",
+                    object="chat.completion.chunk",
+                ),
+                ai_model_id="gpt-4o-mini",
+                role=AuthorRole.ASSISTANT,
+                items=[
+                    FunctionCallContent(
+                        id="call_1",
+                        # no plugin_name/function_name yet
+                        name=None,
+                        arguments='{"arg1":',
+                    )
+                ],
+            )
+        ]
+        # Second partial chunk for call_1 (updates plugin_name/function_name)
+        yield [
+            StreamingChatMessageContent(
+                choice_index=0,
+                inner_content=ChatCompletionChunk(
+                    id="chunk-id-2",
+                    choices=[
+                        Choice(
+                            delta=ChoiceDelta(
+                                tool_calls=[
+                                    ChoiceDeltaToolCall(
+                                        index=0,
+                                        function=ChoiceDeltaToolCallFunction(
+                                            # Provide the rest of the arguments
+                                            arguments='"value1"}',
+                                            name="firstFunction",
+                                        ),
+                                    )
+                                ]
+                            ),
+                            finish_reason=None,
+                            index=0,
+                        )
+                    ],
+                    created=1736679999,
+                    model="gpt-4o-mini",
+                    object="chat.completion.chunk",
+                ),
+                ai_model_id="gpt-4o-mini",
+                role=AuthorRole.ASSISTANT,
+                items=[
+                    FunctionCallContent(
+                        id="call_1", plugin_name="myPlugin", function_name="firstFunction", arguments='"value1"}'
+                    )
+                ],
+            )
+        ]
+        # Now partial chunk for a second call, call_2
+        yield [
+            StreamingChatMessageContent(
+                choice_index=0,
+                inner_content=ChatCompletionChunk(
+                    id="chunk-id-3",
+                    choices=[
+                        Choice(
+                            delta=ChoiceDelta(
+                                tool_calls=[
+                                    ChoiceDeltaToolCall(
+                                        index=0,
+                                        id="call_2",
+                                        function=ChoiceDeltaToolCallFunction(name=None, arguments='{"arg2":"another"}'),
+                                        type="function",
+                                    )
+                                ],
+                            ),
+                            finish_reason=None,
+                            index=0,
+                        )
+                    ],
+                    created=1736679999,
+                    model="gpt-4o-mini",
+                    object="chat.completion.chunk",
+                ),
+                ai_model_id="gpt-4o-mini",
+                role=AuthorRole.ASSISTANT,
+                items=[FunctionCallContent(id="call_2", arguments='{"arg2":"another"}')],
+            )
+        ]
+        # Next partial chunk updates name, plugin_name, function_name for call_2
+        yield [
+            StreamingChatMessageContent(
+                choice_index=0,
+                inner_content=ChatCompletionChunk(
+                    id="chunk-id-4",
+                    choices=[
+                        Choice(
+                            delta=ChoiceDelta(
+                                tool_calls=[
+                                    ChoiceDeltaToolCall(
+                                        index=0, function=ChoiceDeltaToolCallFunction(name="secondFunction")
+                                    )
+                                ]
+                            ),
+                            finish_reason=None,
+                            index=0,
+                        )
+                    ],
+                    created=1736679999,
+                    model="gpt-4o-mini",
+                    object="chat.completion.chunk",
+                ),
+                ai_model_id="gpt-4o-mini",
+                role=AuthorRole.ASSISTANT,
+                items=[
+                    FunctionCallContent(
+                        id="call_2",
+                        name="someFancyName",
+                        plugin_name="anotherPlugin",
+                        function_name="secondFunction",
+                    )
+                ],
+            )
+        ]
+        # Final chunk signals finish with tool_calls
+        yield [
+            StreamingChatMessageContent(  # type: ignore
+                choice_index=0,
+                inner_content=ChatCompletionChunk(
+                    id="chunk-id-5",
+                    choices=[Choice(delta=ChoiceDelta(), finish_reason="tool_calls", index=0)],
+                    created=1736679999,
+                    model="gpt-4o-mini",
+                    object="chat.completion.chunk",
+                    usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
+                ),
+                ai_model_id="gpt-4o-mini",
+                role=AuthorRole.ASSISTANT,
+                finish_reason=FinishReason.TOOL_CALLS,
+                metadata={"usage": {"prompt_tokens": 10, "completion_tokens": 5}},
+            )
+        ]
+
+    # Mock SK client
+    mock_client = AsyncMock(spec=AzureChatCompletion)
+    mock_client.get_streaming_chat_message_contents = mock_get_streaming_chat_message_contents
+
+    # Create adapter and kernel
+    kernel = Kernel(memory=NullMemory())
+    adapter = SKChatCompletionAdapter(mock_client, kernel=kernel)
+
+    # Call create_stream with no actual tools (we just test the multiple calls)
+    messages: list[LLMMessage] = [
+        SystemMessage(content="You are a helpful assistant."),
+        UserMessage(content="Call two different plugin functions", source="user"),
+    ]
+
+    # Collect streaming outputs
+    response_chunks: list[CreateResult | str] = []
+    async for chunk in adapter.create_stream(messages=messages):
+        response_chunks.append(chunk)
+
+    # The final chunk should be a CreateResult with function_calls
+    assert len(response_chunks) > 0
+    final_chunk = response_chunks[-1]
+    assert isinstance(final_chunk, CreateResult)
+    assert final_chunk.finish_reason == "function_calls"
+    assert isinstance(final_chunk.content, list)
+    assert len(final_chunk.content) == 2  # We expect 2 calls
+
+    # Verify first call merged name + arguments
+    first_call = final_chunk.content[0]
+    assert first_call.id == "call_1"
+    assert first_call.name == "myPlugin-firstFunction"  # pluginName-functionName
+    assert '{"arg1":"value1"}' in first_call.arguments
+
+    # Verify second call also merged everything
+    second_call = final_chunk.content[1]
+    assert second_call.id == "call_2"
+    assert second_call.name == "anotherPlugin-secondFunction"
+    assert '{"arg2":"another"}' in second_call.arguments