mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-28 23:49:13 +00:00
Add tool_choice parameter to ChatCompletionClient interface and all implementations
Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com>
This commit is contained in:
parent
c980baabca
commit
8a2fb5a3bd
@ -211,6 +211,7 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
|
||||
messages: Sequence[LLMMessage],
|
||||
*,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
|
||||
# None means do not override the default
|
||||
# A value means to override the client default - often specified in the constructor
|
||||
json_output: Optional[bool | type[BaseModel]] = None,
|
||||
@ -222,6 +223,7 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
|
||||
Args:
|
||||
messages (Sequence[LLMMessage]): The messages to send to the model.
|
||||
tools (Sequence[Tool | ToolSchema], optional): The tools to use with the model. Defaults to [].
|
||||
tool_choice (Optional[Sequence[Union[str, Tool]]], optional): A list of tool names or Tool objects to restrict the model's choice to. Defaults to None.
|
||||
json_output (Optional[bool | type[BaseModel]], optional): Whether to use JSON mode, structured output, or neither.
|
||||
Defaults to None. If set to a `Pydantic BaseModel <https://docs.pydantic.dev/latest/usage/models/#model>`_ type,
|
||||
it will be used as the output type for structured output.
|
||||
@ -241,6 +243,7 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
|
||||
messages: Sequence[LLMMessage],
|
||||
*,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
|
||||
# None means do not override the default
|
||||
# A value means to override the client default - often specified in the constructor
|
||||
json_output: Optional[bool | type[BaseModel]] = None,
|
||||
@ -252,6 +255,7 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
|
||||
Args:
|
||||
messages (Sequence[LLMMessage]): The messages to send to the model.
|
||||
tools (Sequence[Tool | ToolSchema], optional): The tools to use with the model. Defaults to [].
|
||||
tool_choice (Optional[Sequence[Union[str, Tool]]], optional): A list of tool names or Tool objects to restrict the model's choice to. Defaults to None.
|
||||
json_output (Optional[bool | type[BaseModel]], optional): Whether to use JSON mode, structured output, or neither.
|
||||
Defaults to None. If set to a `Pydantic BaseModel <https://docs.pydantic.dev/latest/usage/models/#model>`_ type,
|
||||
it will be used as the output type for structured output.
|
||||
|
||||
@ -149,6 +149,44 @@ def get_mime_type_from_image(image: Image) -> Literal["image/jpeg", "image/png",
|
||||
return "image/jpeg"
|
||||
|
||||
|
||||
def convert_tool_choice_anthropic(
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]]
|
||||
) -> Any:
|
||||
"""Convert tool_choice parameter to Anthropic API format.
|
||||
|
||||
Args:
|
||||
tool_choice: List of tool names (strings) or Tool objects to restrict model choice to.
|
||||
|
||||
Returns:
|
||||
Anthropic API compatible tool_choice value or None if not specified.
|
||||
"""
|
||||
if tool_choice is None:
|
||||
return None
|
||||
|
||||
if len(tool_choice) == 0:
|
||||
return None
|
||||
|
||||
# Convert Tool objects to names if needed
|
||||
tool_names = []
|
||||
for item in tool_choice:
|
||||
if isinstance(item, str):
|
||||
tool_names.append(item)
|
||||
elif isinstance(item, Tool):
|
||||
tool_names.append(item.schema["name"])
|
||||
else:
|
||||
raise ValueError(f"tool_choice items must be strings or Tool objects, got {type(item)}")
|
||||
|
||||
# For Anthropic API, if we want to force use of a specific tool:
|
||||
if len(tool_names) == 1:
|
||||
return {
|
||||
"type": "tool",
|
||||
"name": tool_names[0]
|
||||
}
|
||||
else:
|
||||
# For multiple tools, use "any" mode which forces the model to use any available tool
|
||||
return {"type": "any"}
|
||||
|
||||
|
||||
@overload
|
||||
def __empty_content_to_whitespace(content: str) -> str: ...
|
||||
|
||||
@ -504,6 +542,7 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
|
||||
messages: Sequence[LLMMessage],
|
||||
*,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
|
||||
json_output: Optional[bool | type[BaseModel]] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
@ -581,6 +620,34 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
|
||||
elif has_tool_results:
|
||||
# anthropic requires tools to be present even if there is any tool use
|
||||
request_args["tools"] = self._last_used_tools
|
||||
|
||||
# Process tool_choice parameter
|
||||
if tool_choice is not None:
|
||||
if len(tools) == 0 and not has_tool_results:
|
||||
raise ValueError("tool_choice specified but no tools provided")
|
||||
|
||||
# Validate that all tool_choice items exist in the provided tools
|
||||
tool_names_available = []
|
||||
if len(tools) > 0:
|
||||
for tool in tools:
|
||||
if isinstance(tool, Tool):
|
||||
tool_names_available.append(tool.schema["name"])
|
||||
else:
|
||||
tool_names_available.append(tool["name"])
|
||||
else:
|
||||
# Use last used tools names if available
|
||||
for tool in self._last_used_tools:
|
||||
tool_names_available.append(tool["name"])
|
||||
|
||||
for item in tool_choice:
|
||||
tool_name = item if isinstance(item, str) else item.schema["name"]
|
||||
if tool_name not in tool_names_available:
|
||||
raise ValueError(f"tool_choice references '{tool_name}' but it's not in the available tools")
|
||||
|
||||
# Convert to Anthropic format and add to request_args
|
||||
converted_tool_choice = convert_tool_choice_anthropic(tool_choice)
|
||||
if converted_tool_choice is not None:
|
||||
request_args["tool_choice"] = converted_tool_choice
|
||||
|
||||
# Optional parameters
|
||||
for param in ["top_p", "top_k", "stop_sequences", "metadata"]:
|
||||
@ -667,6 +734,7 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
|
||||
messages: Sequence[LLMMessage],
|
||||
*,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
|
||||
json_output: Optional[bool | type[BaseModel]] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
@ -750,6 +818,34 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
|
||||
request_args["tools"] = converted_tools
|
||||
elif has_tool_results:
|
||||
request_args["tools"] = self._last_used_tools
|
||||
|
||||
# Process tool_choice parameter
|
||||
if tool_choice is not None:
|
||||
if len(tools) == 0 and not has_tool_results:
|
||||
raise ValueError("tool_choice specified but no tools provided")
|
||||
|
||||
# Validate that all tool_choice items exist in the provided tools
|
||||
tool_names_available = []
|
||||
if len(tools) > 0:
|
||||
for tool in tools:
|
||||
if isinstance(tool, Tool):
|
||||
tool_names_available.append(tool.schema["name"])
|
||||
else:
|
||||
tool_names_available.append(tool["name"])
|
||||
else:
|
||||
# Use last used tools names if available
|
||||
for tool in self._last_used_tools:
|
||||
tool_names_available.append(tool["name"])
|
||||
|
||||
for item in tool_choice:
|
||||
tool_name = item if isinstance(item, str) else item.schema["name"]
|
||||
if tool_name not in tool_names_available:
|
||||
raise ValueError(f"tool_choice references '{tool_name}' but it's not in the available tools")
|
||||
|
||||
# Convert to Anthropic format and add to request_args
|
||||
converted_tool_choice = convert_tool_choice_anthropic(tool_choice)
|
||||
if converted_tool_choice is not None:
|
||||
request_args["tool_choice"] = converted_tool_choice
|
||||
|
||||
# Optional parameters
|
||||
for param in ["top_p", "top_k", "stop_sequences", "metadata"]:
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import re
|
||||
from asyncio import Task
|
||||
from inspect import getfullargspec
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence, cast
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union, cast
|
||||
|
||||
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, FunctionCall, Image
|
||||
from autogen_core.logging import LLMCallEvent, LLMStreamEndEvent, LLMStreamStartEvent
|
||||
@ -356,6 +356,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
messages: Sequence[LLMMessage],
|
||||
*,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
|
||||
json_output: Optional[bool | type[BaseModel]] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
@ -373,6 +374,12 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
azure_messages_nested = [to_azure_message(msg) for msg in messages]
|
||||
azure_messages = [item for sublist in azure_messages_nested for item in sublist]
|
||||
|
||||
# Handle tool_choice parameter - log warning as it might not be supported by Azure AI
|
||||
if tool_choice is not None:
|
||||
if len(tools) == 0:
|
||||
raise ValueError("tool_choice specified but no tools provided")
|
||||
logger.warning("tool_choice parameter specified but may not be supported by Azure AI Inference API")
|
||||
|
||||
task: Task[ChatCompletions]
|
||||
|
||||
if len(tools) > 0:
|
||||
@ -451,6 +458,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
messages: Sequence[LLMMessage],
|
||||
*,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
|
||||
json_output: Optional[bool | type[BaseModel]] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
@ -468,6 +476,12 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
|
||||
azure_messages_nested = [to_azure_message(msg) for msg in messages]
|
||||
azure_messages = [item for sublist in azure_messages_nested for item in sublist]
|
||||
|
||||
# Handle tool_choice parameter - log warning as it might not be supported by Azure AI
|
||||
if tool_choice is not None:
|
||||
if len(tools) == 0:
|
||||
raise ValueError("tool_choice specified but no tools provided")
|
||||
logger.warning("tool_choice parameter specified but may not be supported by Azure AI Inference API")
|
||||
|
||||
if len(tools) > 0:
|
||||
converted_tools = convert_tools(tools)
|
||||
task = asyncio.create_task(
|
||||
|
||||
@ -264,6 +264,7 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
|
||||
messages: Sequence[LLMMessage],
|
||||
*,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
|
||||
# None means do not override the default
|
||||
# A value means to override the client default - often specified in the constructor
|
||||
json_output: Optional[bool | type[BaseModel]] = None,
|
||||
@ -302,6 +303,14 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
|
||||
elif json_output is not False and json_output is not None:
|
||||
raise ValueError("json_output must be a boolean, a BaseModel subclass or None.")
|
||||
|
||||
# Handle tool_choice parameter
|
||||
if tool_choice is not None:
|
||||
if not self.model_info["function_calling"]:
|
||||
raise ValueError("tool_choice specified but model does not support function calling")
|
||||
if len(tools) == 0:
|
||||
raise ValueError("tool_choice specified but no tools provided")
|
||||
logger.warning("tool_choice parameter specified but may not be supported by llama-cpp-python")
|
||||
|
||||
if self.model_info["function_calling"]:
|
||||
# Run this in on the event loop to avoid blocking.
|
||||
response_future = asyncio.get_event_loop().run_in_executor(
|
||||
@ -397,12 +406,21 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
|
||||
messages: Sequence[LLMMessage],
|
||||
*,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
|
||||
# None means do not override the default
|
||||
# A value means to override the client default - often specified in the constructor
|
||||
json_output: Optional[bool | type[BaseModel]] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
) -> AsyncGenerator[Union[str, CreateResult], None]:
|
||||
# Validate tool_choice parameter even though streaming is not implemented
|
||||
if tool_choice is not None:
|
||||
if not self.model_info["function_calling"]:
|
||||
raise ValueError("tool_choice specified but model does not support function calling")
|
||||
if len(tools) == 0:
|
||||
raise ValueError("tool_choice specified but no tools provided")
|
||||
logger.warning("tool_choice parameter specified but may not be supported by llama-cpp-python")
|
||||
|
||||
raise NotImplementedError("Stream not yet implemented for LlamaCppChatCompletionClient")
|
||||
yield ""
|
||||
|
||||
|
||||
@ -514,6 +514,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
tools: Sequence[Tool | ToolSchema],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]],
|
||||
json_output: Optional[bool | type[BaseModel]],
|
||||
extra_create_args: Mapping[str, Any],
|
||||
) -> CreateParams:
|
||||
@ -585,6 +586,12 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
|
||||
raise ValueError("Model does not support function calling and tools were provided")
|
||||
|
||||
converted_tools = convert_tools(tools)
|
||||
|
||||
# Handle tool_choice parameter - log warning as it might not be supported by Ollama
|
||||
if tool_choice is not None:
|
||||
if len(tools) == 0:
|
||||
raise ValueError("tool_choice specified but no tools provided")
|
||||
trace_logger.warning("tool_choice parameter specified but may not be supported by Ollama API")
|
||||
|
||||
return CreateParams(
|
||||
messages=ollama_messages,
|
||||
@ -598,6 +605,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
|
||||
messages: Sequence[LLMMessage],
|
||||
*,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
|
||||
json_output: Optional[bool | type[BaseModel]] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
@ -610,6 +618,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
|
||||
create_params = self._process_create_args(
|
||||
messages,
|
||||
tools,
|
||||
tool_choice,
|
||||
json_output,
|
||||
extra_create_args,
|
||||
)
|
||||
@ -704,6 +713,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
|
||||
messages: Sequence[LLMMessage],
|
||||
*,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
|
||||
json_output: Optional[bool | type[BaseModel]] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
@ -716,6 +726,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
|
||||
create_params = self._process_create_args(
|
||||
messages,
|
||||
tools,
|
||||
tool_choice,
|
||||
json_output,
|
||||
extra_create_args,
|
||||
)
|
||||
|
||||
@ -265,6 +265,47 @@ def convert_tools(
|
||||
return result
|
||||
|
||||
|
||||
def convert_tool_choice(
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]]
|
||||
) -> Any:
|
||||
"""Convert tool_choice parameter to OpenAI API format.
|
||||
|
||||
Args:
|
||||
tool_choice: List of tool names (strings) or Tool objects to restrict model choice to.
|
||||
|
||||
Returns:
|
||||
OpenAI API compatible tool_choice value or None if not specified.
|
||||
"""
|
||||
if tool_choice is None:
|
||||
return None
|
||||
|
||||
if len(tool_choice) == 0:
|
||||
return None
|
||||
|
||||
# Convert Tool objects to names if needed
|
||||
tool_names = []
|
||||
for item in tool_choice:
|
||||
if isinstance(item, str):
|
||||
tool_names.append(item)
|
||||
elif isinstance(item, Tool):
|
||||
tool_names.append(item.schema["name"])
|
||||
else:
|
||||
raise ValueError(f"tool_choice items must be strings or Tool objects, got {type(item)}")
|
||||
|
||||
# For OpenAI API, if we want to restrict to specific tools, we can use the "required" mode
|
||||
# Since OpenAI doesn't support specifying multiple specific tools directly,
|
||||
# we'll return the first tool as a specific choice if only one is provided,
|
||||
# or use "required" mode for multiple tools (letting the model choose among them)
|
||||
if len(tool_names) == 1:
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {"name": tool_names[0]}
|
||||
}
|
||||
else:
|
||||
# For multiple tools, use "required" mode which forces the model to use any available tool
|
||||
return "required"
|
||||
|
||||
|
||||
def normalize_name(name: str) -> str:
|
||||
"""
|
||||
LLMs sometimes ask functions while ignoring their own format requirements, this function should be used to replace invalid characters with "_".
|
||||
@ -449,6 +490,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
tools: Sequence[Tool | ToolSchema],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]],
|
||||
json_output: Optional[bool | type[BaseModel]],
|
||||
extra_create_args: Mapping[str, Any],
|
||||
) -> CreateParams:
|
||||
@ -574,6 +616,29 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
raise ValueError("Model does not support function calling")
|
||||
|
||||
converted_tools = convert_tools(tools)
|
||||
|
||||
# Process tool_choice parameter
|
||||
if tool_choice is not None:
|
||||
if len(tools) == 0:
|
||||
raise ValueError("tool_choice specified but no tools provided")
|
||||
|
||||
# Validate that all tool_choice items exist in the provided tools
|
||||
tool_names_available = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, Tool):
|
||||
tool_names_available.append(tool.schema["name"])
|
||||
else:
|
||||
tool_names_available.append(tool["name"])
|
||||
|
||||
for item in tool_choice:
|
||||
tool_name = item if isinstance(item, str) else item.schema["name"]
|
||||
if tool_name not in tool_names_available:
|
||||
raise ValueError(f"tool_choice references '{tool_name}' but it's not in the provided tools")
|
||||
|
||||
# Convert to OpenAI format and add to create_args
|
||||
converted_tool_choice = convert_tool_choice(tool_choice)
|
||||
if converted_tool_choice is not None:
|
||||
create_args["tool_choice"] = converted_tool_choice
|
||||
|
||||
return CreateParams(
|
||||
messages=oai_messages,
|
||||
@ -587,6 +652,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
messages: Sequence[LLMMessage],
|
||||
*,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
|
||||
json_output: Optional[bool | type[BaseModel]] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
@ -594,6 +660,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
create_params = self._process_create_args(
|
||||
messages,
|
||||
tools,
|
||||
tool_choice,
|
||||
json_output,
|
||||
extra_create_args,
|
||||
)
|
||||
@ -736,6 +803,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
messages: Sequence[LLMMessage],
|
||||
*,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
|
||||
json_output: Optional[bool | type[BaseModel]] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
@ -767,6 +835,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
create_params = self._process_create_args(
|
||||
messages,
|
||||
tools,
|
||||
tool_choice,
|
||||
json_output,
|
||||
extra_create_args,
|
||||
)
|
||||
|
||||
@ -162,11 +162,16 @@ class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompl
|
||||
messages: Sequence[LLMMessage],
|
||||
*,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
|
||||
json_output: Optional[bool | type[BaseModel]] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
) -> CreateResult:
|
||||
"""Return the next completion from the list."""
|
||||
# Warn if tool_choice is specified since it's ignored in replay mode
|
||||
if tool_choice is not None:
|
||||
logger.warning("tool_choice parameter specified but is ignored in replay mode")
|
||||
|
||||
if self._current_index >= len(self.chat_completions):
|
||||
raise ValueError("No more mock responses available")
|
||||
|
||||
@ -201,11 +206,16 @@ class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompl
|
||||
messages: Sequence[LLMMessage],
|
||||
*,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
|
||||
json_output: Optional[bool | type[BaseModel]] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
) -> AsyncGenerator[Union[str, CreateResult], None]:
|
||||
"""Return the next completion as a stream."""
|
||||
# Warn if tool_choice is specified since it's ignored in replay mode
|
||||
if tool_choice is not None:
|
||||
logger.warning("tool_choice parameter specified but is ignored in replay mode")
|
||||
|
||||
if self._current_index >= len(self.chat_completions):
|
||||
raise ValueError("No more mock responses available")
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, Literal, Mapping, Optional, Sequence
|
||||
from typing import Any, Literal, Mapping, Optional, Sequence, Union
|
||||
|
||||
from autogen_core import EVENT_LOGGER_NAME, FunctionCall
|
||||
from autogen_core._cancellation_token import CancellationToken
|
||||
@ -442,6 +442,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
messages: Sequence[LLMMessage],
|
||||
*,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
|
||||
json_output: Optional[bool | type[BaseModel]] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
@ -473,6 +474,12 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
if isinstance(json_output, type) and issubclass(json_output, BaseModel):
|
||||
raise ValueError("structured output is not currently supported in SKChatCompletionAdapter")
|
||||
|
||||
# Handle tool_choice parameter
|
||||
if tool_choice is not None:
|
||||
if len(tools) == 0:
|
||||
raise ValueError("tool_choice specified but no tools provided")
|
||||
logger.warning("tool_choice parameter specified but may not be fully supported by Semantic Kernel")
|
||||
|
||||
kernel = self._get_kernel(extra_create_args)
|
||||
|
||||
chat_history = self._convert_to_chat_history(messages)
|
||||
@ -553,6 +560,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
messages: Sequence[LLMMessage],
|
||||
*,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
tool_choice: Optional[Sequence[Union[str, Tool]]] = None,
|
||||
json_output: Optional[bool | type[BaseModel]] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
@ -585,6 +593,12 @@ class SKChatCompletionAdapter(ChatCompletionClient):
|
||||
if isinstance(json_output, type) and issubclass(json_output, BaseModel):
|
||||
raise ValueError("structured output is not currently supported in SKChatCompletionAdapter")
|
||||
|
||||
# Handle tool_choice parameter
|
||||
if tool_choice is not None:
|
||||
if len(tools) == 0:
|
||||
raise ValueError("tool_choice specified but no tools provided")
|
||||
logger.warning("tool_choice parameter specified but may not be fully supported by Semantic Kernel")
|
||||
|
||||
kernel = self._get_kernel(extra_create_args)
|
||||
chat_history = self._convert_to_chat_history(messages)
|
||||
user_settings = self._get_prompt_settings(extra_create_args)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user