Add tool_choice parameter to ChatCompletionClient interface and all implementations

Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com>
copilot-swe-agent[bot] 2025-06-20 12:42:44 +00:00
parent c980baabca
commit 8a2fb5a3bd
8 changed files with 238 additions and 2 deletions

View File

@@ -211,6 +211,7 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool | type[BaseModel]] = None,
@@ -222,6 +223,7 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
Args:
messages (Sequence[LLMMessage]): The messages to send to the model.
tools (Sequence[Tool | ToolSchema], optional): The tools to use with the model. Defaults to [].
tool_choice (Optional[Sequence[Union[str, Tool]]], optional): A list of tool names or Tool objects to restrict the model's choice to. Defaults to None.
json_output (Optional[bool | type[BaseModel]], optional): Whether to use JSON mode, structured output, or neither.
Defaults to None. If set to a `Pydantic BaseModel <https://docs.pydantic.dev/latest/usage/models/#model>`_ type,
it will be used as the output type for structured output.
@@ -241,6 +243,7 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool | type[BaseModel]] = None,
@@ -252,6 +255,7 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
Args:
messages (Sequence[LLMMessage]): The messages to send to the model.
tools (Sequence[Tool | ToolSchema], optional): The tools to use with the model. Defaults to [].
tool_choice (Optional[Sequence[Union[str, Tool]]], optional): A list of tool names or Tool objects to restrict the model's choice to. Defaults to None.
json_output (Optional[bool | type[BaseModel]], optional): Whether to use JSON mode, structured output, or neither.
Defaults to None. If set to a `Pydantic BaseModel <https://docs.pydantic.dev/latest/usage/models/#model>`_ type,
it will be used as the output type for structured output.
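
For illustration, a hedged sketch of how a caller might use the new parameter (the client instance, message, and tool name here are hypothetical):

# `client` is any ChatCompletionClient implementation; `weather_tool` is a
# hypothetical autogen_core.tools Tool whose schema name is "get_weather".
result = await client.create(
    messages=[UserMessage(content="Weather in Paris?", source="user")],
    tools=[weather_tool],
    tool_choice=["get_weather"],  # restrict the model to this one tool
)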

View File

@@ -149,6 +149,44 @@ def get_mime_type_from_image(image: Image) -> Literal["image/jpeg", "image/png",
return "image/jpeg"
def convert_tool_choice_anthropic(
tool_choice: Optional[Sequence[Union[str, Tool]]]
) -> Any:
"""Convert tool_choice parameter to Anthropic API format.
Args:
tool_choice: List of tool names (strings) or Tool objects to restrict model choice to.
Returns:
Anthropic API compatible tool_choice value or None if not specified.
"""
if tool_choice is None:
return None
if len(tool_choice) == 0:
return None
# Convert Tool objects to names if needed
tool_names = []
for item in tool_choice:
if isinstance(item, str):
tool_names.append(item)
elif isinstance(item, Tool):
tool_names.append(item.schema["name"])
else:
raise ValueError(f"tool_choice items must be strings or Tool objects, got {type(item)}")
# For Anthropic API, if we want to force use of a specific tool:
if len(tool_names) == 1:
return {
"type": "tool",
"name": tool_names[0]
}
else:
# Multiple tools: Anthropic cannot restrict the choice to a subset, so fall back to "any" mode, which forces the model to use one of all provided tools
return {"type": "any"}
@overload
def __empty_content_to_whitespace(content: str) -> str: ...
@@ -504,6 +542,7 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@@ -581,6 +620,34 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
elif has_tool_results:
# Anthropic requires tools to be present whenever the messages contain tool use
request_args["tools"] = self._last_used_tools
# Process tool_choice parameter
if tool_choice is not None:
if len(tools) == 0 and not has_tool_results:
raise ValueError("tool_choice specified but no tools provided")
# Validate that all tool_choice items exist in the provided tools
tool_names_available = []
if len(tools) > 0:
for tool in tools:
if isinstance(tool, Tool):
tool_names_available.append(tool.schema["name"])
else:
tool_names_available.append(tool["name"])
else:
# Use last used tools names if available
for tool in self._last_used_tools:
tool_names_available.append(tool["name"])
for item in tool_choice:
tool_name = item if isinstance(item, str) else item.schema["name"]
if tool_name not in tool_names_available:
raise ValueError(f"tool_choice references '{tool_name}' but it's not in the available tools")
# Convert to Anthropic format and add to request_args
converted_tool_choice = convert_tool_choice_anthropic(tool_choice)
if converted_tool_choice is not None:
request_args["tool_choice"] = converted_tool_choice
# Optional parameters
for param in ["top_p", "top_k", "stop_sequences", "metadata"]:
@@ -667,6 +734,7 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@@ -750,6 +818,34 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
request_args["tools"] = converted_tools
elif has_tool_results:
request_args["tools"] = self._last_used_tools
# Process tool_choice parameter
if tool_choice is not None:
if len(tools) == 0 and not has_tool_results:
raise ValueError("tool_choice specified but no tools provided")
# Validate that all tool_choice items exist in the provided tools
tool_names_available = []
if len(tools) > 0:
for tool in tools:
if isinstance(tool, Tool):
tool_names_available.append(tool.schema["name"])
else:
tool_names_available.append(tool["name"])
else:
# Use last used tools names if available
for tool in self._last_used_tools:
tool_names_available.append(tool["name"])
for item in tool_choice:
tool_name = item if isinstance(item, str) else item.schema["name"]
if tool_name not in tool_names_available:
raise ValueError(f"tool_choice references '{tool_name}' but it's not in the available tools")
# Convert to Anthropic format and add to request_args
converted_tool_choice = convert_tool_choice_anthropic(tool_choice)
if converted_tool_choice is not None:
request_args["tool_choice"] = converted_tool_choice
# Optional parameters
for param in ["top_p", "top_k", "stop_sequences", "metadata"]:
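
A hedged sketch of the validation added above (client and tool names hypothetical): naming a tool that was not passed in tools now fails fast rather than being sent to the API.

# `get_time` is not among the provided tools, so this raises:
# ValueError: tool_choice references 'get_time' but it's not in the available tools
await anthropic_client.create(messages, tools=[weather_tool], tool_choice=["get_time"])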

View File

@@ -3,7 +3,7 @@ import logging
import re
from asyncio import Task
from inspect import getfullargspec
from typing import Any, Dict, List, Mapping, Optional, Sequence, cast
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union, cast
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, FunctionCall, Image
from autogen_core.logging import LLMCallEvent, LLMStreamEndEvent, LLMStreamStartEvent
@@ -356,6 +356,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@@ -373,6 +374,12 @@
azure_messages_nested = [to_azure_message(msg) for msg in messages]
azure_messages = [item for sublist in azure_messages_nested for item in sublist]
# Handle tool_choice parameter - log warning as it might not be supported by Azure AI
if tool_choice is not None:
if len(tools) == 0:
raise ValueError("tool_choice specified but no tools provided")
logger.warning("tool_choice parameter specified but may not be supported by Azure AI Inference API")
task: Task[ChatCompletions]
if len(tools) > 0:
@@ -451,6 +458,7 @@
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@@ -468,6 +476,12 @@
azure_messages_nested = [to_azure_message(msg) for msg in messages]
azure_messages = [item for sublist in azure_messages_nested for item in sublist]
# Handle tool_choice parameter - log warning as it might not be supported by Azure AI
if tool_choice is not None:
if len(tools) == 0:
raise ValueError("tool_choice specified but no tools provided")
logger.warning("tool_choice parameter specified but may not be supported by Azure AI Inference API")
if len(tools) > 0:
converted_tools = convert_tools(tools)
task = asyncio.create_task(

View File

@@ -264,6 +264,7 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool | type[BaseModel]] = None,
@@ -302,6 +303,14 @@
elif json_output is not False and json_output is not None:
raise ValueError("json_output must be a boolean, a BaseModel subclass or None.")
# Handle tool_choice parameter
if tool_choice is not None:
if not self.model_info["function_calling"]:
raise ValueError("tool_choice specified but model does not support function calling")
if len(tools) == 0:
raise ValueError("tool_choice specified but no tools provided")
logger.warning("tool_choice parameter specified but may not be supported by llama-cpp-python")
if self.model_info["function_calling"]:
# Run this in an executor to avoid blocking the event loop.
response_future = asyncio.get_event_loop().run_in_executor(
@@ -397,12 +406,21 @@
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[str, CreateResult], None]:
# Validate tool_choice parameter even though streaming is not implemented
if tool_choice is not None:
if not self.model_info["function_calling"]:
raise ValueError("tool_choice specified but model does not support function calling")
if len(tools) == 0:
raise ValueError("tool_choice specified but no tools provided")
logger.warning("tool_choice parameter specified but may not be supported by llama-cpp-python")
raise NotImplementedError("Stream not yet implemented for LlamaCppChatCompletionClient")
yield ""

View File

@@ -514,6 +514,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool | ToolSchema],
tool_choice: Optional[Sequence[Union[str, Tool]]],
json_output: Optional[bool | type[BaseModel]],
extra_create_args: Mapping[str, Any],
) -> CreateParams:
@@ -585,6 +586,12 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
raise ValueError("Model does not support function calling and tools were provided")
converted_tools = convert_tools(tools)
# Handle tool_choice parameter - log warning as it might not be supported by Ollama
if tool_choice is not None:
if len(tools) == 0:
raise ValueError("tool_choice specified but no tools provided")
trace_logger.warning("tool_choice parameter specified but may not be supported by Ollama API")
return CreateParams(
messages=ollama_messages,
@@ -598,6 +605,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@@ -610,6 +618,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
create_params = self._process_create_args(
messages,
tools,
tool_choice,
json_output,
extra_create_args,
)
@@ -704,6 +713,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@@ -716,6 +726,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
create_params = self._process_create_args(
messages,
tools,
tool_choice,
json_output,
extra_create_args,
)

View File

@@ -265,6 +265,47 @@ def convert_tools(
return result
def convert_tool_choice(
tool_choice: Optional[Sequence[Union[str, Tool]]]
) -> Any:
"""Convert tool_choice parameter to OpenAI API format.
Args:
tool_choice: List of tool names (strings) or Tool objects to restrict model choice to.
Returns:
OpenAI API compatible tool_choice value or None if not specified.
"""
if tool_choice is None:
return None
if len(tool_choice) == 0:
return None
# Convert Tool objects to names if needed
tool_names = []
for item in tool_choice:
if isinstance(item, str):
tool_names.append(item)
elif isinstance(item, Tool):
tool_names.append(item.schema["name"])
else:
raise ValueError(f"tool_choice items must be strings or Tool objects, got {type(item)}")
# OpenAI's tool_choice accepts either a single named function or a mode string,
# so an arbitrary subset of tools cannot be specified directly. If exactly one
# name is given, pin the model to that function; for multiple names, fall back
# to "required" mode, which forces a call to one of all provided tools.
if len(tool_names) == 1:
return {
"type": "function",
"function": {"name": tool_names[0]}
}
else:
# Multiple tools: "required" forces the model to call one of the provided tools
return "required"
def normalize_name(name: str) -> str:
"""
LLMs sometimes call functions while ignoring their own format requirements; this function replaces invalid characters with "_".
@@ -449,6 +490,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool | ToolSchema],
tool_choice: Optional[Sequence[Union[str, Tool]]],
json_output: Optional[bool | type[BaseModel]],
extra_create_args: Mapping[str, Any],
) -> CreateParams:
@@ -574,6 +616,29 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
raise ValueError("Model does not support function calling")
converted_tools = convert_tools(tools)
# Process tool_choice parameter
if tool_choice is not None:
if len(tools) == 0:
raise ValueError("tool_choice specified but no tools provided")
# Validate that all tool_choice items exist in the provided tools
tool_names_available = []
for tool in tools:
if isinstance(tool, Tool):
tool_names_available.append(tool.schema["name"])
else:
tool_names_available.append(tool["name"])
for item in tool_choice:
tool_name = item if isinstance(item, str) else item.schema["name"]
if tool_name not in tool_names_available:
raise ValueError(f"tool_choice references '{tool_name}' but it's not in the provided tools")
# Convert to OpenAI format and add to create_args
converted_tool_choice = convert_tool_choice(tool_choice)
if converted_tool_choice is not None:
create_args["tool_choice"] = converted_tool_choice
return CreateParams(
messages=oai_messages,
@@ -587,6 +652,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@@ -594,6 +660,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
create_params = self._process_create_args(
messages,
tools,
tool_choice,
json_output,
extra_create_args,
)
@@ -736,6 +803,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@@ -767,6 +835,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
create_params = self._process_create_args(
messages,
tools,
tool_choice,
json_output,
extra_create_args,
)

View File

@@ -162,11 +162,16 @@ class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompl
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
"""Return the next completion from the list."""
# Warn if tool_choice is specified since it's ignored in replay mode
if tool_choice is not None:
logger.warning("tool_choice parameter specified but is ignored in replay mode")
if self._current_index >= len(self.chat_completions):
raise ValueError("No more mock responses available")
@@ -201,11 +206,16 @@ class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompl
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[str, CreateResult], None]:
"""Return the next completion as a stream."""
# Warn if tool_choice is specified since it's ignored in replay mode
if tool_choice is not None:
logger.warning("tool_choice parameter specified but is ignored in replay mode")
if self._current_index >= len(self.chat_completions):
raise ValueError("No more mock responses available")

View File

@@ -1,7 +1,7 @@
import json
import logging
import warnings
from typing import Any, Literal, Mapping, Optional, Sequence
from typing import Any, Literal, Mapping, Optional, Sequence, Union
from autogen_core import EVENT_LOGGER_NAME, FunctionCall
from autogen_core._cancellation_token import CancellationToken
@@ -442,6 +442,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@@ -473,6 +474,12 @@ class SKChatCompletionAdapter(ChatCompletionClient):
if isinstance(json_output, type) and issubclass(json_output, BaseModel):
raise ValueError("structured output is not currently supported in SKChatCompletionAdapter")
# Handle tool_choice parameter
if tool_choice is not None:
if len(tools) == 0:
raise ValueError("tool_choice specified but no tools provided")
logger.warning("tool_choice parameter specified but may not be fully supported by Semantic Kernel")
kernel = self._get_kernel(extra_create_args)
chat_history = self._convert_to_chat_history(messages)
@@ -553,6 +560,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@@ -585,6 +593,12 @@ class SKChatCompletionAdapter(ChatCompletionClient):
if isinstance(json_output, type) and issubclass(json_output, BaseModel):
raise ValueError("structured output is not currently supported in SKChatCompletionAdapter")
# Handle tool_choice parameter
if tool_choice is not None:
if len(tools) == 0:
raise ValueError("tool_choice specified but no tools provided")
logger.warning("tool_choice parameter specified but may not be fully supported by Semantic Kernel")
kernel = self._get_kernel(extra_create_args)
chat_history = self._convert_to_chat_history(messages)
user_settings = self._get_prompt_settings(extra_create_args)