mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-24 13:39:24 +00:00
fix: Update SKChatCompletionAdapter message conversion (#5749)
<!-- Thank you for your contribution! Please review https://microsoft.github.io/autogen/docs/Contribute before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? <!-- Please give a short summary of the change and the problem this solves. --> The PR introduces two changes. The first change is adding a name attribute to `FunctionExecutionResult`. The motivation is that semantic kernel requires it for their function result interface and it seemed like a easy modification as `FunctionExecutionResult` is always created in the context of a `FunctionCall` which will contain the name. I'm unsure if there was a motivation to keep it out but this change makes it easier to trace which tool the result refers to and also increases api compatibility with SK. The second change is an update to how messages are mapped from autogen to semantic kernel, which includes an update/fix in the processing of function results. ## Related issue number <!-- For example: "Closes #1234" --> Related to #5675 but wont fix the underlying issue of anthropic requiring tools during AssistantAgent reflection. ## Checks - [ ] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [ ] I've made sure all auto checks have passed. --------- Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
This commit is contained in:
parent
7e01350d46
commit
906b09e451
@ -1078,6 +1078,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
content=result_as_str,
|
||||
call_id=tool_call.id,
|
||||
is_error=False,
|
||||
name=tool_call.name,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
@ -1087,6 +1088,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
content=f"Error: {e}",
|
||||
call_id=tool_call.id,
|
||||
is_error=True,
|
||||
name=tool_call.name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -399,9 +399,9 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert result.messages[2].models_usage.prompt_tokens == 10
|
||||
assert isinstance(result.messages[3], ToolCallExecutionEvent)
|
||||
expected_content = [
|
||||
FunctionExecutionResult(call_id="1", content="pass", is_error=False),
|
||||
FunctionExecutionResult(call_id="2", content="pass", is_error=False),
|
||||
FunctionExecutionResult(call_id="3", content="task3", is_error=False),
|
||||
FunctionExecutionResult(call_id="1", content="pass", is_error=False, name="_pass_function"),
|
||||
FunctionExecutionResult(call_id="2", content="pass", is_error=False, name="_pass_function"),
|
||||
FunctionExecutionResult(call_id="3", content="task3", is_error=False, name="_echo_function"),
|
||||
]
|
||||
for expected in expected_content:
|
||||
assert expected in result.messages[3].content
|
||||
@ -535,9 +535,9 @@ async def test_run_with_parallel_tools_with_empty_call_ids(monkeypatch: pytest.M
|
||||
assert result.messages[1].models_usage.prompt_tokens == 10
|
||||
assert isinstance(result.messages[2], ToolCallExecutionEvent)
|
||||
expected_content = [
|
||||
FunctionExecutionResult(call_id="", content="pass", is_error=False),
|
||||
FunctionExecutionResult(call_id="", content="pass", is_error=False),
|
||||
FunctionExecutionResult(call_id="", content="task3", is_error=False),
|
||||
FunctionExecutionResult(call_id="", content="pass", is_error=False, name="_pass_function"),
|
||||
FunctionExecutionResult(call_id="", content="pass", is_error=False, name="_pass_function"),
|
||||
FunctionExecutionResult(call_id="", content="task3", is_error=False, name="_echo_function"),
|
||||
]
|
||||
for expected in expected_content:
|
||||
assert expected in result.messages[2].content
|
||||
@ -1018,8 +1018,8 @@ async def test_model_client_stream_with_tool_calls() -> None:
|
||||
FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'),
|
||||
]
|
||||
assert message.messages[2].content == [
|
||||
FunctionExecutionResult(call_id="1", content="pass", is_error=False),
|
||||
FunctionExecutionResult(call_id="3", content="task", is_error=False),
|
||||
FunctionExecutionResult(call_id="1", content="pass", is_error=False, name="_pass_function"),
|
||||
FunctionExecutionResult(call_id="3", content="task", is_error=False, name="_echo_function"),
|
||||
]
|
||||
elif isinstance(message, ModelClientStreamingChunkEvent):
|
||||
chunks.append(message.content)
|
||||
|
||||
@ -1172,8 +1172,8 @@ async def test_swarm_with_parallel_tool_calls(monkeypatch: pytest.MonkeyPatch) -
|
||||
),
|
||||
FunctionExecutionResultMessage(
|
||||
content=[
|
||||
FunctionExecutionResult(content="tool1", call_id="1", is_error=False),
|
||||
FunctionExecutionResult(content="tool2", call_id="2", is_error=False),
|
||||
FunctionExecutionResult(content="tool1", call_id="1", is_error=False, name="tool1"),
|
||||
FunctionExecutionResult(content="tool2", call_id="2", is_error=False, name="tool2"),
|
||||
]
|
||||
),
|
||||
]
|
||||
|
||||
@ -770,6 +770,7 @@ def convert_to_v04_message(message: Dict[str, Any]) -> AgentEvent | ChatMessage:
|
||||
call_id=tool_response["tool_call_id"],
|
||||
content=tool_response["content"],
|
||||
is_error=False,
|
||||
name=tool_response["name"],
|
||||
)
|
||||
)
|
||||
return ToolCallExecutionEvent(source="tools", content=tool_results)
|
||||
|
||||
@ -239,7 +239,12 @@
|
||||
],
|
||||
"source": [
|
||||
"# Create a function execution result\n",
|
||||
"exec_result = FunctionExecutionResult(call_id=create_result.content[0].id, content=tool_result_str, is_error=False) # type: ignore\n",
|
||||
"exec_result = FunctionExecutionResult(\n",
|
||||
" call_id=create_result.content[0].id, # type: ignore\n",
|
||||
" content=tool_result_str,\n",
|
||||
" is_error=False,\n",
|
||||
" name=stock_price_tool.name,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Make another chat completion with the history and function execution result message.\n",
|
||||
"messages = [\n",
|
||||
@ -353,9 +358,11 @@
|
||||
" try:\n",
|
||||
" arguments = json.loads(call.arguments)\n",
|
||||
" result = await tool.run_json(arguments, cancellation_token)\n",
|
||||
" return FunctionExecutionResult(call_id=call.id, content=tool.return_value_as_string(result), is_error=False)\n",
|
||||
" return FunctionExecutionResult(\n",
|
||||
" call_id=call.id, content=tool.return_value_as_string(result), is_error=False, name=tool.name\n",
|
||||
" )\n",
|
||||
" except Exception as e:\n",
|
||||
" return FunctionExecutionResult(call_id=call.id, content=str(e), is_error=True)"
|
||||
" return FunctionExecutionResult(call_id=call.id, content=str(e), is_error=True, name=tool.name)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -140,7 +140,7 @@
|
||||
" f\"Function call: {message.name}\\nArguments: {message.arguments}\\nDo you want to execute the tool? (y/n): \"\n",
|
||||
" )\n",
|
||||
" if user_input.strip().lower() != \"y\":\n",
|
||||
" raise ToolException(content=\"User denied tool execution.\", call_id=message.id)\n",
|
||||
" raise ToolException(content=\"User denied tool execution.\", call_id=message.id, name=message.name)\n",
|
||||
" return message"
|
||||
]
|
||||
},
|
||||
|
||||
@ -184,7 +184,7 @@
|
||||
" result = await self._tools[call.name].run_json(arguments, ctx.cancellation_token)\n",
|
||||
" result_as_str = self._tools[call.name].return_value_as_string(result)\n",
|
||||
" tool_call_results.append(\n",
|
||||
" FunctionExecutionResult(call_id=call.id, content=result_as_str, is_error=False)\n",
|
||||
" FunctionExecutionResult(call_id=call.id, content=result_as_str, is_error=False, name=call.name)\n",
|
||||
" )\n",
|
||||
" elif call.name in self._delegate_tools:\n",
|
||||
" # Execute the tool to get the delegate agent's topic type.\n",
|
||||
@ -199,6 +199,7 @@
|
||||
" call_id=call.id,\n",
|
||||
" content=f\"Transferred to {topic_type}. Adopt persona immediately.\",\n",
|
||||
" is_error=False,\n",
|
||||
" name=call.name,\n",
|
||||
" )\n",
|
||||
" ]\n",
|
||||
" ),\n",
|
||||
|
||||
@ -59,6 +59,9 @@ class FunctionExecutionResult(BaseModel):
|
||||
content: str
|
||||
"""The output of the function call."""
|
||||
|
||||
name: str
|
||||
"""The name of the function that was called."""
|
||||
|
||||
call_id: str
|
||||
"""The ID of the function call. Note this ID may be empty for some models."""
|
||||
|
||||
|
||||
@ -63,7 +63,9 @@ async def tool_agent_caller_loop(
|
||||
function_results.append(result)
|
||||
elif isinstance(result, ToolException):
|
||||
function_results.append(
|
||||
FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id, is_error=True)
|
||||
FunctionExecutionResult(
|
||||
content=f"Error: {result}", call_id=result.call_id, is_error=True, name=result.name
|
||||
)
|
||||
)
|
||||
elif isinstance(result, BaseException):
|
||||
raise result # Unexpected exception.
|
||||
|
||||
@ -19,6 +19,7 @@ __all__ = [
|
||||
class ToolException(BaseException):
|
||||
call_id: str
|
||||
content: str
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -76,7 +77,9 @@ class ToolAgent(RoutedAgent):
|
||||
"""
|
||||
tool = next((tool for tool in self._tools if tool.name == message.name), None)
|
||||
if tool is None:
|
||||
raise ToolNotFoundException(call_id=message.id, content=f"Error: Tool not found: {message.name}")
|
||||
raise ToolNotFoundException(
|
||||
call_id=message.id, content=f"Error: Tool not found: {message.name}", name=message.name
|
||||
)
|
||||
else:
|
||||
try:
|
||||
arguments = json.loads(message.arguments)
|
||||
@ -84,8 +87,8 @@ class ToolAgent(RoutedAgent):
|
||||
result_as_str = tool.return_value_as_string(result)
|
||||
except json.JSONDecodeError as e:
|
||||
raise InvalidToolArgumentsException(
|
||||
call_id=message.id, content=f"Error: Invalid arguments: {message.arguments}"
|
||||
call_id=message.id, content=f"Error: Invalid arguments: {message.arguments}", name=message.name
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise ToolExecutionException(call_id=message.id, content=f"Error: {e}") from e
|
||||
return FunctionExecutionResult(content=result_as_str, call_id=message.id, is_error=False)
|
||||
raise ToolExecutionException(call_id=message.id, content=f"Error: {e}", name=message.name) from e
|
||||
return FunctionExecutionResult(content=result_as_str, call_id=message.id, is_error=False, name=message.name)
|
||||
|
||||
@ -61,7 +61,7 @@ async def test_tool_agent() -> None:
|
||||
result = await runtime.send_message(
|
||||
FunctionCall(id="1", arguments=json.dumps({"input": "pass"}), name="pass"), agent
|
||||
)
|
||||
assert result == FunctionExecutionResult(call_id="1", content="pass", is_error=False)
|
||||
assert result == FunctionExecutionResult(call_id="1", content="pass", is_error=False, name="pass")
|
||||
|
||||
# Test raise function
|
||||
with pytest.raises(ToolExecutionException):
|
||||
|
||||
@ -467,7 +467,9 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
||||
result = f"Error: {e}"
|
||||
is_error = True
|
||||
tool_outputs.append(
|
||||
FunctionExecutionResult(content=result, call_id=tool_call.id, is_error=is_error)
|
||||
FunctionExecutionResult(
|
||||
content=result, call_id=tool_call.id, is_error=is_error, name=tool_call.name
|
||||
)
|
||||
)
|
||||
|
||||
# Add tool result message to inner messages
|
||||
|
||||
@ -0,0 +1,22 @@
|
||||
from typing import Dict
|
||||
|
||||
from autogen_core.models import FinishReasons
|
||||
|
||||
|
||||
def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
|
||||
if stop_reason is None:
|
||||
return "unknown"
|
||||
|
||||
# Convert to lower case
|
||||
stop_reason = stop_reason.lower()
|
||||
|
||||
KNOWN_STOP_MAPPINGS: Dict[str, FinishReasons] = {
|
||||
"stop": "stop",
|
||||
"length": "length",
|
||||
"content_filter": "content_filter",
|
||||
"function_calls": "function_calls",
|
||||
"end_turn": "stop",
|
||||
"tool_calls": "function_calls",
|
||||
}
|
||||
|
||||
return KNOWN_STOP_MAPPINGS.get(stop_reason, "unknown")
|
||||
@ -37,7 +37,6 @@ from autogen_core.models import (
|
||||
ChatCompletionClient,
|
||||
ChatCompletionTokenLogprob,
|
||||
CreateResult,
|
||||
FinishReasons,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
ModelCapabilities, # type: ignore
|
||||
@ -75,6 +74,7 @@ from openai.types.shared_params import FunctionDefinition, FunctionParameters
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self, Unpack
|
||||
|
||||
from .._utils.normalize_stop_reason import normalize_stop_reason
|
||||
from .._utils.parse_r1_content import parse_r1_content
|
||||
from . import _model_info
|
||||
from .config import (
|
||||
@ -349,25 +349,6 @@ def assert_valid_name(name: str) -> str:
|
||||
return name
|
||||
|
||||
|
||||
def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
|
||||
if stop_reason is None:
|
||||
return "unknown"
|
||||
|
||||
# Convert to lower case
|
||||
stop_reason = stop_reason.lower()
|
||||
|
||||
KNOWN_STOP_MAPPINGS: Dict[str, FinishReasons] = {
|
||||
"stop": "stop",
|
||||
"length": "length",
|
||||
"content_filter": "content_filter",
|
||||
"function_calls": "function_calls",
|
||||
"end_turn": "stop",
|
||||
"tool_calls": "function_calls",
|
||||
}
|
||||
|
||||
return KNOWN_STOP_MAPPINGS.get(stop_reason, "unknown")
|
||||
|
||||
|
||||
class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
def __init__(
|
||||
self,
|
||||
@ -1214,7 +1195,7 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
|
||||
UserMessage(content="I am happy.", source="user"),
|
||||
AssistantMessage(content=response1.content, source="assistant"),
|
||||
FunctionExecutionResultMessage(
|
||||
content=[FunctionExecutionResult(content="happy", call_id=response1.content[0].id, is_error=False)]
|
||||
content=[FunctionExecutionResult(content="happy", call_id=response1.content[0].id, is_error=False, name="sentiment_analysis")]
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@ -17,9 +17,13 @@ from autogen_core.tools import BaseTool, Tool, ToolSchema
|
||||
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
|
||||
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
|
||||
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
|
||||
from semantic_kernel.contents.chat_history import ChatHistory
|
||||
from semantic_kernel.contents.chat_message_content import ChatMessageContent
|
||||
from semantic_kernel.contents.function_call_content import FunctionCallContent
|
||||
from semantic_kernel.contents import (
|
||||
ChatHistory,
|
||||
ChatMessageContent,
|
||||
FinishReason,
|
||||
FunctionCallContent,
|
||||
FunctionResultContent,
|
||||
)
|
||||
from semantic_kernel.functions.kernel_plugin import KernelPlugin
|
||||
from semantic_kernel.kernel import Kernel
|
||||
from typing_extensions import AsyncGenerator, Union
|
||||
@ -265,7 +269,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
validate_model_info(self._model_info)
|
||||
self._total_prompt_tokens = 0
|
||||
self._total_completion_tokens = 0
|
||||
self._tools_plugin: Optional[KernelPlugin] = None
|
||||
self._tools_plugin: KernelPlugin = KernelPlugin(name="autogen_tools")
|
||||
|
||||
def _convert_to_chat_history(self, messages: Sequence[LLMMessage]) -> ChatHistory:
|
||||
"""Convert Autogen LLMMessages to SK ChatHistory"""
|
||||
@ -279,19 +283,51 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
if isinstance(msg.content, str):
|
||||
chat_history.add_user_message(msg.content)
|
||||
else:
|
||||
# Handle list of str/Image - would need to convert to SK content types
|
||||
# Handle list of str/Image - convert to string for now
|
||||
chat_history.add_user_message(str(msg.content))
|
||||
|
||||
elif msg.type == "AssistantMessage":
|
||||
if isinstance(msg.content, str):
|
||||
chat_history.add_assistant_message(msg.content)
|
||||
# Check if it's a function-call style message
|
||||
if isinstance(msg.content, list) and all(isinstance(fc, FunctionCall) for fc in msg.content):
|
||||
# If there's a 'thought' field, you can add that as plain assistant text
|
||||
if msg.thought:
|
||||
chat_history.add_assistant_message(msg.thought)
|
||||
|
||||
function_call_contents: list[FunctionCallContent] = []
|
||||
for fc in msg.content:
|
||||
function_call_contents.append(
|
||||
FunctionCallContent(
|
||||
id=fc.id,
|
||||
name=fc.name,
|
||||
plugin_name=self._tools_plugin.name,
|
||||
function_name=fc.name,
|
||||
arguments=fc.arguments,
|
||||
)
|
||||
)
|
||||
|
||||
# Mark the assistant's message as tool-calling
|
||||
chat_history.add_assistant_message(
|
||||
function_call_contents,
|
||||
finish_reason=FinishReason.TOOL_CALLS,
|
||||
)
|
||||
else:
|
||||
# Handle function calls - would need to convert to SK function call format
|
||||
chat_history.add_assistant_message(str(msg.content))
|
||||
# Plain assistant text
|
||||
chat_history.add_assistant_message(msg.content)
|
||||
|
||||
elif msg.type == "FunctionExecutionResultMessage":
|
||||
# Add each function result as a separate tool message
|
||||
tool_results: list[FunctionResultContent] = []
|
||||
for result in msg.content:
|
||||
chat_history.add_tool_message(result.content)
|
||||
tool_results.append(
|
||||
FunctionResultContent(
|
||||
id=result.call_id,
|
||||
plugin_name=self._tools_plugin.name,
|
||||
function_name=result.name,
|
||||
result=result.content,
|
||||
)
|
||||
)
|
||||
# A single "tool" message with one or more results
|
||||
chat_history.add_tool_message(tool_results)
|
||||
|
||||
return chat_history
|
||||
|
||||
@ -323,11 +359,6 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
|
||||
def _sync_tools_with_kernel(self, kernel: Kernel, tools: Sequence[Tool | ToolSchema]) -> None:
|
||||
"""Sync tools with kernel by updating the plugin"""
|
||||
# Create new plugin if none exists
|
||||
if not self._tools_plugin:
|
||||
self._tools_plugin = KernelPlugin(name="autogen_tools")
|
||||
kernel.add_plugin(self._tools_plugin)
|
||||
|
||||
# Get current tool names in plugin
|
||||
current_tool_names = set(self._tools_plugin.functions.keys())
|
||||
|
||||
@ -414,6 +445,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
CreateResult: The result of the chat completion.
|
||||
"""
|
||||
kernel = self._get_kernel(extra_create_args)
|
||||
|
||||
chat_history = self._convert_to_chat_history(messages)
|
||||
user_settings = self._get_prompt_settings(extra_create_args)
|
||||
settings = self._build_execution_settings(user_settings, tools)
|
||||
|
||||
@ -143,7 +143,12 @@ async def test_anthropic_tool_calling() -> None:
|
||||
messages.append(
|
||||
FunctionExecutionResultMessage(
|
||||
content=[
|
||||
FunctionExecutionResult(content="Processed: hello world", call_id=result.content[0].id, is_error=False)
|
||||
FunctionExecutionResult(
|
||||
content="Processed: hello world",
|
||||
call_id=result.content[0].id,
|
||||
is_error=False,
|
||||
name=result.content[0].name,
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@ -334,7 +334,9 @@ async def test_openai_chat_completion_client_count_tokens(monkeypatch: pytest.Mo
|
||||
],
|
||||
source="user",
|
||||
),
|
||||
FunctionExecutionResultMessage(content=[FunctionExecutionResult(content="Hello", call_id="1", is_error=False)]),
|
||||
FunctionExecutionResultMessage(
|
||||
content=[FunctionExecutionResult(content="Hello", call_id="1", is_error=False, name="tool1")]
|
||||
),
|
||||
]
|
||||
|
||||
def tool1(test: str, test2: str) -> str:
|
||||
@ -1230,7 +1232,14 @@ async def _test_model_client_with_function_calling(model_client: OpenAIChatCompl
|
||||
messages.append(AssistantMessage(content=create_result.content, source="assistant"))
|
||||
messages.append(
|
||||
FunctionExecutionResultMessage(
|
||||
content=[FunctionExecutionResult(content="passed", call_id=create_result.content[0].id, is_error=False)]
|
||||
content=[
|
||||
FunctionExecutionResult(
|
||||
content="passed",
|
||||
call_id=create_result.content[0].id,
|
||||
is_error=False,
|
||||
name=create_result.content[0].name,
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
create_result = await model_client.create(messages=messages)
|
||||
@ -1260,8 +1269,12 @@ async def _test_model_client_with_function_calling(model_client: OpenAIChatCompl
|
||||
messages.append(
|
||||
FunctionExecutionResultMessage(
|
||||
content=[
|
||||
FunctionExecutionResult(content="passed", call_id=create_result.content[0].id, is_error=False),
|
||||
FunctionExecutionResult(content="failed", call_id=create_result.content[1].id, is_error=True),
|
||||
FunctionExecutionResult(
|
||||
content="passed", call_id=create_result.content[0].id, is_error=False, name="pass_tool"
|
||||
),
|
||||
FunctionExecutionResult(
|
||||
content="failed", call_id=create_result.content[1].id, is_error=True, name="fail_tool"
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
@ -1380,7 +1393,11 @@ async def test_openai_structured_output_with_tool_calls() -> None:
|
||||
UserMessage(content="I am happy.", source="user"),
|
||||
AssistantMessage(content=response1.content, source="assistant"),
|
||||
FunctionExecutionResultMessage(
|
||||
content=[FunctionExecutionResult(content="happy", call_id=response1.content[0].id, is_error=False)]
|
||||
content=[
|
||||
FunctionExecutionResult(
|
||||
content="happy", call_id=response1.content[0].id, is_error=False, name=tool.name
|
||||
)
|
||||
]
|
||||
),
|
||||
],
|
||||
)
|
||||
@ -1439,7 +1456,11 @@ async def test_openai_structured_output_with_streaming_tool_calls() -> None:
|
||||
UserMessage(content="I am happy.", source="user"),
|
||||
AssistantMessage(content=create_result1.content, source="assistant"),
|
||||
FunctionExecutionResultMessage(
|
||||
content=[FunctionExecutionResult(content="happy", call_id=create_result1.content[0].id, is_error=False)]
|
||||
content=[
|
||||
FunctionExecutionResult(
|
||||
content="happy", call_id=create_result1.content[0].id, is_error=False, name=tool.name
|
||||
)
|
||||
]
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@ -3,8 +3,18 @@ from typing import Any, AsyncGenerator
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from autogen_core import CancellationToken
|
||||
from autogen_core.models import CreateResult, LLMMessage, ModelFamily, ModelInfo, SystemMessage, UserMessage
|
||||
from autogen_core import CancellationToken, FunctionCall
|
||||
from autogen_core.models import (
|
||||
AssistantMessage,
|
||||
CreateResult,
|
||||
FunctionExecutionResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
ModelFamily,
|
||||
ModelInfo,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from autogen_core.tools import BaseTool
|
||||
from autogen_ext.models.semantic_kernel import SKChatCompletionAdapter
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
@ -728,3 +738,111 @@ async def test_sk_chat_completion_stream_with_multiple_function_calls() -> None:
|
||||
assert second_call.id == "call_2"
|
||||
assert second_call.name == "anotherPlugin-secondFunction"
|
||||
assert '{"arg2":"another"}' in second_call.arguments
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sk_chat_completion_with_function_call_and_execution_result_messages() -> None:
|
||||
"""
|
||||
Test that _convert_to_chat_history can properly handle a conversation
|
||||
that includes both an assistant function-call message and a function
|
||||
execution result message in the same sequence.
|
||||
"""
|
||||
# Mock the SK client to return some placeholder response
|
||||
mock_client = AsyncMock(spec=AzureChatCompletion)
|
||||
mock_client.get_chat_message_contents = AsyncMock(
|
||||
return_value=[
|
||||
ChatMessageContent(
|
||||
ai_model_id="test-model",
|
||||
role=AuthorRole.ASSISTANT,
|
||||
items=[TextContent(text="All done!")],
|
||||
finish_reason=FinishReason.STOP,
|
||||
metadata={"usage": {"prompt_tokens": 10, "completion_tokens": 5}},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
adapter = SKChatCompletionAdapter(sk_client=mock_client, kernel=Kernel(memory=NullMemory()))
|
||||
|
||||
# Messages include:
|
||||
# 1) SystemMessage
|
||||
# 2) UserMessage
|
||||
# 3) AssistantMessage with a function call
|
||||
# 4) FunctionExecutionResultMessage
|
||||
# 5) AssistantMessage with plain text
|
||||
|
||||
messages: list[LLMMessage] = [
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
UserMessage(content="What is 3 + 5?", source="user"),
|
||||
AssistantMessage(
|
||||
content=[
|
||||
FunctionCall(
|
||||
id="call_1",
|
||||
name="calculator",
|
||||
arguments='{"a":3,"b":5}',
|
||||
)
|
||||
],
|
||||
thought="Let me call the calculator function",
|
||||
source="assistant",
|
||||
),
|
||||
FunctionExecutionResultMessage(
|
||||
content=[
|
||||
FunctionExecutionResult(
|
||||
call_id="call_1",
|
||||
name="calculator",
|
||||
content="8",
|
||||
)
|
||||
]
|
||||
),
|
||||
AssistantMessage(content="The answer is 8.", source="assistant"),
|
||||
]
|
||||
|
||||
# Run create (which triggers _convert_to_chat_history internally)
|
||||
result = await adapter.create(messages=messages)
|
||||
|
||||
# Verify final CreateResult
|
||||
assert isinstance(result.content, str)
|
||||
assert "All done!" in result.content
|
||||
assert result.finish_reason == "stop"
|
||||
|
||||
# Ensure the underlying client was called with a properly built ChatHistory
|
||||
mock_client.get_chat_message_contents.assert_awaited_once()
|
||||
chat_history_arg = mock_client.get_chat_message_contents.call_args[0][0] # The ChatHistory passed in
|
||||
|
||||
# Expecting 5 messages in the ChatHistory
|
||||
assert len(chat_history_arg) == 6
|
||||
|
||||
# 1) System message
|
||||
assert chat_history_arg[0].role == AuthorRole.SYSTEM
|
||||
assert chat_history_arg[0].items[0].text == "You are a helpful assistant."
|
||||
|
||||
# 2) User message
|
||||
assert chat_history_arg[1].role == AuthorRole.USER
|
||||
assert chat_history_arg[1].items[0].text == "What is 3 + 5?"
|
||||
|
||||
# 3) Assistant message with thought
|
||||
assert chat_history_arg[2].role == AuthorRole.ASSISTANT
|
||||
assert chat_history_arg[2].items[0].text == "Let me call the calculator function"
|
||||
|
||||
# 4) Assistant message with function call
|
||||
assert chat_history_arg[3].role == AuthorRole.ASSISTANT
|
||||
assert chat_history_arg[3].finish_reason == FinishReason.TOOL_CALLS
|
||||
# Should have one FunctionCallContent
|
||||
func_call_contents = chat_history_arg[3].items
|
||||
assert len(func_call_contents) == 1
|
||||
assert func_call_contents[0].id == "call_1"
|
||||
assert func_call_contents[0].function_name == "calculator"
|
||||
assert func_call_contents[0].arguments == '{"a":3,"b":5}'
|
||||
assert func_call_contents[0].plugin_name == "autogen_tools"
|
||||
|
||||
# 5) Function execution result message
|
||||
assert chat_history_arg[4].role == AuthorRole.TOOL
|
||||
tool_contents = chat_history_arg[4].items
|
||||
assert len(tool_contents) == 1
|
||||
assert tool_contents[0].id == "call_1"
|
||||
assert tool_contents[0].result == "8"
|
||||
assert tool_contents[0].function_name == "calculator"
|
||||
assert tool_contents[0].plugin_name == "autogen_tools"
|
||||
|
||||
# 6) Assistant message with plain text
|
||||
assert chat_history_arg[5].role == AuthorRole.ASSISTANT
|
||||
assert chat_history_arg[5].items[0].text == "The answer is 8."
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user