Add tool_choice parameter to ChatCompletionClient create and create_stream methods (#6697)

## Summary

Implements the `tool_choice` parameter for the `ChatCompletionClient`
interface, as requested in #6696. This allows users to restrict which
tools the model can choose from when multiple tools are available.

## Changes

### Core Interface
- Added a `tool_choice: Tool | Literal["auto", "required", "none"] = "auto"` parameter to the `ChatCompletionClient.create()` and `create_stream()` methods.

### Model Implementations
- Updated the client implementations to support the new parameter. For now, only the following model clients are supported:
  - OpenAI
  - Anthropic
  - Azure AI
  - Ollama
- `LlamaCppChatCompletionClient` is currently not supported.

### Features
- `"auto"` (default): Let the model choose whether to use tools. If no tools are provided, this has no effect.
- `"required"`: Force the model to use at least one tool.
- `"none"`: Disable tool usage completely.
- `Tool` object: Force the model to use a specific tool.
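
As a usage sketch (the model name, API-key handling, and the example tool are illustrative; the same call signature applies to the other supported clients):

```python
import asyncio

from autogen_core.models import UserMessage
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient


def add_numbers(a: int, b: int) -> int:
    """Add two numbers together."""
    return a + b


async def main() -> None:
    # Model name is a placeholder; the API key is read from the environment.
    client = OpenAIChatCompletionClient(model="gpt-4o")
    add_tool = FunctionTool(add_numbers, description="Add two numbers together")

    # Force the model to call add_numbers; pass "auto", "required", or "none"
    # instead of a Tool object for the other behaviors.
    result = await client.create(
        messages=[UserMessage(content="What is 2 + 2?", source="user")],
        tools=[add_tool],
        tool_choice=add_tool,
    )
    print(result.content)  # Expect a FunctionCall to add_numbers.


asyncio.run(main())
```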

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Copilot 2025-06-30 14:15:28 +09:00 committed by GitHub
parent 6f15270cb2
commit c150f85044
15 changed files with 1889 additions and 60 deletions

View File

@ -211,8 +211,7 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@ -222,6 +221,7 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
Args:
messages (Sequence[LLMMessage]): The messages to send to the model.
tools (Sequence[Tool | ToolSchema], optional): The tools to use with the model. Defaults to [].
tool_choice (Tool | Literal["auto", "required", "none"], optional): A single Tool object to force the model to use, "auto" to let the model choose any available tool, "required" to force tool usage, or "none" to disable tool usage. Defaults to "auto".
json_output (Optional[bool | type[BaseModel]], optional): Whether to use JSON mode, structured output, or neither.
Defaults to None. If set to a `Pydantic BaseModel <https://docs.pydantic.dev/latest/usage/models/#model>`_ type,
it will be used as the output type for structured output.
@ -241,8 +241,7 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@ -252,6 +251,7 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
Args:
messages (Sequence[LLMMessage]): The messages to send to the model.
tools (Sequence[Tool | ToolSchema], optional): The tools to use with the model. Defaults to [].
tool_choice (Tool | Literal["auto", "required", "none"], optional): A single Tool object to force the model to use, "auto" to let the model choose any available tool, "required" to force tool usage, or "none" to disable tool usage. Defaults to "auto".
json_output (Optional[bool | type[BaseModel]], optional): Whether to use JSON mode, structured output, or neither.
Defaults to None. If set to a `Pydantic BaseModel <https://docs.pydantic.dev/latest/usage/models/#model>`_ type,
it will be used as the output type for structured output.

View File

@ -1,7 +1,7 @@
import asyncio
import json
import logging
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union
from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional, Sequence, Union
import pytest
from autogen_core import EVENT_LOGGER_NAME, AgentId, CancellationToken, FunctionCall, SingleThreadedAgentRuntime
@ -102,6 +102,7 @@ async def test_caller_loop() -> None:
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@ -127,6 +128,7 @@ async def test_caller_loop() -> None:
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,

View File

@ -90,6 +90,7 @@ class ChatCompletionClientRecorder(ChatCompletionClient):
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
) -> CreateResult:
current_messages: List[Mapping[str, Any]] = [msg.model_dump() for msg in messages]
if self.mode == "record":
@ -97,6 +98,7 @@ class ChatCompletionClientRecorder(ChatCompletionClient):
messages,
tools=tools,
json_output=json_output,
tool_choice=tool_choice,
extra_create_args=extra_create_args,
cancellation_token=cancellation_token,
)
@ -157,10 +159,12 @@ class ChatCompletionClientRecorder(ChatCompletionClient):
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
) -> AsyncGenerator[Union[str, CreateResult], None]:
return self.base_client.create_stream(
messages,
tools=tools,
tool_choice=tool_choice,
json_output=json_output,
extra_create_args=extra_create_args,
cancellation_token=cancellation_token,

View File

@ -149,6 +149,31 @@ def get_mime_type_from_image(image: Image) -> Literal["image/jpeg", "image/png",
return "image/jpeg"
def convert_tool_choice_anthropic(tool_choice: Tool | Literal["auto", "required", "none"]) -> Any:
"""Convert tool_choice parameter to Anthropic API format.
Args:
tool_choice: A single Tool object to force the model to use, "auto" to let the model choose any available tool, "required" to force tool usage, or "none" to disable tool usage.
Returns:
Anthropic API compatible tool_choice value.
"""
if tool_choice == "none":
return {"type": "none"}
if tool_choice == "auto":
return {"type": "auto"}
if tool_choice == "required":
return {"type": "any"} # Anthropic uses "any" for required
# Must be a Tool object
if isinstance(tool_choice, Tool):
return {"type": "tool", "name": tool_choice.schema["name"]}
else:
raise ValueError(f"tool_choice must be a Tool object, 'auto', 'required', or 'none', got {type(tool_choice)}")
@overload
def __empty_content_to_whitespace(content: str) -> str: ...
@ -504,6 +529,7 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@ -582,6 +608,36 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
# anthropic requires tools to be present even if there is any tool use
request_args["tools"] = self._last_used_tools
# Process tool_choice parameter
if isinstance(tool_choice, Tool):
if len(tools) == 0 and not has_tool_results:
raise ValueError("tool_choice specified but no tools provided")
# Validate that the tool exists in the provided tools
tool_names_available: List[str] = []
if len(tools) > 0:
for tool in tools:
if isinstance(tool, Tool):
tool_names_available.append(tool.schema["name"])
else:
tool_names_available.append(tool["name"])
else:
# Use last used tools names if available
for tool_param in self._last_used_tools:
tool_names_available.append(tool_param["name"])
# tool_choice is a single Tool object
tool_name = tool_choice.schema["name"]
if tool_name not in tool_names_available:
raise ValueError(f"tool_choice references '{tool_name}' but it's not in the available tools")
# Convert to Anthropic format and add to request_args only if tools are provided
# According to Anthropic API, tool_choice may only be specified while providing tools
if len(tools) > 0 or has_tool_results:
converted_tool_choice = convert_tool_choice_anthropic(tool_choice)
if converted_tool_choice is not None:
request_args["tool_choice"] = converted_tool_choice
# Optional parameters
for param in ["top_p", "top_k", "stop_sequences", "metadata"]:
if param in create_args:
@ -667,6 +723,7 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@ -751,6 +808,36 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
elif has_tool_results:
request_args["tools"] = self._last_used_tools
# Process tool_choice parameter
if isinstance(tool_choice, Tool):
if len(tools) == 0 and not has_tool_results:
raise ValueError("tool_choice specified but no tools provided")
# Validate that the tool exists in the provided tools
tool_names_available: List[str] = []
if len(tools) > 0:
for tool in tools:
if isinstance(tool, Tool):
tool_names_available.append(tool.schema["name"])
else:
tool_names_available.append(tool["name"])
else:
# Use last used tools names if available
for last_used_tool in self._last_used_tools:
tool_names_available.append(last_used_tool["name"])
# tool_choice is a single Tool object
tool_name = tool_choice.schema["name"]
if tool_name not in tool_names_available:
raise ValueError(f"tool_choice references '{tool_name}' but it's not in the available tools")
# Convert to Anthropic format and add to request_args only if tools are provided
# According to Anthropic API, tool_choice may only be specified while providing tools
if len(tools) > 0 or has_tool_results:
converted_tool_choice = convert_tool_choice_anthropic(tool_choice)
if converted_tool_choice is not None:
request_args["tool_choice"] = converted_tool_choice
# Optional parameters
for param in ["top_p", "top_k", "stop_sequences", "metadata"]:
if param in create_args:

View File

@ -3,7 +3,7 @@ import logging
import re
from asyncio import Task
from inspect import getfullargspec
from typing import Any, Dict, List, Mapping, Optional, Sequence, cast
from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Union, cast
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, FunctionCall, Image
from autogen_core.logging import LLMCallEvent, LLMStreamEndEvent, LLMStreamStartEvent
@ -28,6 +28,8 @@ from azure.ai.inference.models import (
)
from azure.ai.inference.models import (
ChatCompletions,
ChatCompletionsNamedToolChoice,
ChatCompletionsNamedToolChoiceFunction,
ChatCompletionsToolCall,
ChatCompletionsToolDefinition,
CompletionsFinishReason,
@ -53,7 +55,7 @@ from azure.ai.inference.models import (
UserMessage as AzureUserMessage,
)
from pydantic import BaseModel
from typing_extensions import AsyncGenerator, Union, Unpack
from typing_extensions import AsyncGenerator, Unpack
from autogen_ext.models.azure.config import (
GITHUB_MODELS_ENDPOINT,
@ -309,7 +311,10 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
@staticmethod
def _create_client(config: AzureAIChatCompletionClientConfig) -> ChatCompletionsClient:
return ChatCompletionsClient(**config)
# Only pass the parameters that ChatCompletionsClient accepts
# Remove 'model_info' and other client-specific parameters
client_config = {k: v for k, v in config.items() if k not in ("model_info",)}
return ChatCompletionsClient(**client_config) # type: ignore
@staticmethod
def _prepare_create_args(config: Mapping[str, Any]) -> Dict[str, Any]:
@ -356,6 +361,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@ -376,6 +382,12 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
task: Task[ChatCompletions]
if len(tools) > 0:
if isinstance(tool_choice, Tool):
create_args["tool_choice"] = ChatCompletionsNamedToolChoice(
function=ChatCompletionsNamedToolChoiceFunction(name=tool_choice.name)
)
else:
create_args["tool_choice"] = tool_choice
converted_tools = convert_tools(tools)
task = asyncio.create_task( # type: ignore
self._client.complete(messages=azure_messages, tools=converted_tools, **create_args) # type: ignore
@ -451,6 +463,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@ -469,6 +482,12 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
azure_messages = [item for sublist in azure_messages_nested for item in sublist]
if len(tools) > 0:
if isinstance(tool_choice, Tool):
create_args["tool_choice"] = ChatCompletionsNamedToolChoice(
function=ChatCompletionsNamedToolChoiceFunction(name=tool_choice.name)
)
else:
create_args["tool_choice"] = tool_choice
converted_tools = convert_tools(tools)
task = asyncio.create_task(
self._client.complete(messages=azure_messages, tools=converted_tools, stream=True, **create_args)

View File

@ -1,7 +1,7 @@
import hashlib
import json
import warnings
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union, cast
from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional, Sequence, Union, cast
from autogen_core import CacheStore, CancellationToken, Component, ComponentModel, InMemoryStore
from autogen_core.models import (
@ -137,6 +137,7 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@ -158,6 +159,7 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
messages,
tools=tools,
json_output=json_output,
tool_choice=tool_choice,
extra_create_args=extra_create_args,
cancellation_token=cancellation_token,
)
@ -169,6 +171,7 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@ -200,6 +203,7 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
messages,
tools=tools,
json_output=json_output,
tool_choice=tool_choice,
extra_create_args=extra_create_args,
cancellation_token=cancellation_token,
)

View File

@ -1,6 +1,7 @@
import asyncio
import logging # added import
import re
import warnings
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Sequence, TypedDict, Union, cast
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, FunctionCall, MessageHandlerContext
@ -264,6 +265,7 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool | type[BaseModel]] = None,
@ -302,6 +304,15 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
elif json_output is not False and json_output is not None:
raise ValueError("json_output must be a boolean, a BaseModel subclass or None.")
# Handle tool_choice parameter
if tool_choice != "auto":
warnings.warn(
"tool_choice parameter is specified but LlamaCppChatCompletionClient does not support it. "
"This parameter will be ignored.",
UserWarning,
stacklevel=2,
)
if self.model_info["function_calling"]:
# Run this in on the event loop to avoid blocking.
response_future = asyncio.get_event_loop().run_in_executor(
@ -397,12 +408,21 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[str, CreateResult], None]:
# Validate tool_choice parameter even though streaming is not implemented
if tool_choice != "auto" and tool_choice != "none":
if not self.model_info["function_calling"]:
raise ValueError("tool_choice specified but model does not support function calling")
if len(tools) == 0:
raise ValueError("tool_choice specified but no tools provided")
logger.warning("tool_choice parameter specified but may not be supported by llama-cpp-python")
raise NotImplementedError("Stream not yet implemented for LlamaCppChatCompletionClient")
yield ""

View File

@ -514,6 +514,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool | ToolSchema],
tool_choice: Tool | Literal["auto", "required", "none"],
json_output: Optional[bool | type[BaseModel]],
extra_create_args: Mapping[str, Any],
) -> CreateParams:
@ -584,7 +585,22 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
if self.model_info["function_calling"] is False and len(tools) > 0:
raise ValueError("Model does not support function calling and tools were provided")
converted_tools = convert_tools(tools)
converted_tools: List[OllamaTool] = []
# Handle tool_choice parameter in a way that is compatible with Ollama API.
if isinstance(tool_choice, Tool):
# If tool_choice is a Tool, convert it to OllamaTool.
converted_tools = convert_tools([tool_choice])
elif tool_choice == "none":
# No tool choice, do not pass tools to the API.
converted_tools = []
elif tool_choice == "required":
# Required tool choice, pass tools to the API.
converted_tools = convert_tools(tools)
if len(converted_tools) == 0:
raise ValueError("tool_choice 'required' specified but no tools provided")
else:
converted_tools = convert_tools(tools)
return CreateParams(
messages=ollama_messages,
@ -598,6 +614,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@ -610,6 +627,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
create_params = self._process_create_args(
messages,
tools,
tool_choice,
json_output,
extra_create_args,
)
@ -704,6 +722,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@ -716,6 +735,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
create_params = self._process_create_args(
messages,
tools,
tool_choice,
json_output,
extra_create_args,
)

View File

@ -15,6 +15,7 @@ from typing import (
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
@ -265,6 +266,31 @@ def convert_tools(
return result
def convert_tool_choice(tool_choice: Tool | Literal["auto", "required", "none"]) -> Any:
"""Convert tool_choice parameter to OpenAI API format.
Args:
tool_choice: A single Tool object to force the model to use, "auto" to let the model choose any available tool, "required" to force tool usage, or "none" to disable tool usage.
Returns:
OpenAI API compatible tool_choice value or None if not specified.
"""
if tool_choice == "none":
return "none"
if tool_choice == "auto":
return "auto"
if tool_choice == "required":
return "required"
# Must be a Tool object
if isinstance(tool_choice, Tool):
return {"type": "function", "function": {"name": tool_choice.schema["name"]}}
else:
raise ValueError(f"tool_choice must be a Tool object, 'auto', 'required', or 'none', got {type(tool_choice)}")
def normalize_name(name: str) -> str:
"""
LLMs sometimes ask functions while ignoring their own format requirements, this function should be used to replace invalid characters with "_".
@ -449,6 +475,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool | ToolSchema],
tool_choice: Tool | Literal["auto", "required", "none"],
json_output: Optional[bool | type[BaseModel]],
extra_create_args: Mapping[str, Any],
) -> CreateParams:
@ -575,6 +602,29 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
converted_tools = convert_tools(tools)
# Process tool_choice parameter
if isinstance(tool_choice, Tool):
if len(tools) == 0:
raise ValueError("tool_choice specified but no tools provided")
# Validate that the tool exists in the provided tools
tool_names_available: List[str] = []
for tool in tools:
if isinstance(tool, Tool):
tool_names_available.append(tool.schema["name"])
else:
tool_names_available.append(tool["name"])
# tool_choice is a single Tool object
tool_name = tool_choice.schema["name"]
if tool_name not in tool_names_available:
raise ValueError(f"tool_choice references '{tool_name}' but it's not in the provided tools")
if len(converted_tools) > 0:
# Convert to OpenAI format and add to create_args
converted_tool_choice = convert_tool_choice(tool_choice)
create_args["tool_choice"] = converted_tool_choice
return CreateParams(
messages=oai_messages,
tools=converted_tools,
@ -587,6 +637,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@ -594,6 +645,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
create_params = self._process_create_args(
messages,
tools,
tool_choice,
json_output,
extra_create_args,
)
@ -738,6 +790,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@ -769,6 +822,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
create_params = self._process_create_args(
messages,
tools,
tool_choice,
json_output,
extra_create_args,
)

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import logging
import warnings
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Sequence, Union
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Sequence, Union
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component
from autogen_core.models import (
@ -162,11 +162,16 @@ class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompl
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
"""Return the next completion from the list."""
# Warn if tool_choice is specified since it's ignored in replay mode
if tool_choice != "auto":
logger.warning("tool_choice parameter specified but is ignored in replay mode")
if self._current_index >= len(self.chat_completions):
raise ValueError("No more mock responses available")
@ -201,11 +206,16 @@ class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompl
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[str, CreateResult], None]:
"""Return the next completion as a stream."""
# Warn if tool_choice is specified since it's ignored in replay mode
if tool_choice != "auto":
logger.warning("tool_choice parameter specified but is ignored in replay mode")
if self._current_index >= len(self.chat_completions):
raise ValueError("No more mock responses available")

View File

@ -1,7 +1,7 @@
import json
import logging
import warnings
from typing import Any, Literal, Mapping, Optional, Sequence
from typing import Any, Literal, Mapping, Optional, Sequence, Union
from autogen_core import EVENT_LOGGER_NAME, FunctionCall
from autogen_core._cancellation_token import CancellationToken
@ -29,7 +29,7 @@ from semantic_kernel.contents import (
)
from semantic_kernel.functions.kernel_plugin import KernelPlugin
from semantic_kernel.kernel import Kernel
from typing_extensions import AsyncGenerator, Union
from typing_extensions import AsyncGenerator
from autogen_ext.tools.semantic_kernel import KernelFunctionFromTool, KernelFunctionFromToolSchema
@ -442,6 +442,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@ -473,6 +474,13 @@ class SKChatCompletionAdapter(ChatCompletionClient):
if isinstance(json_output, type) and issubclass(json_output, BaseModel):
raise ValueError("structured output is not currently supported in SKChatCompletionAdapter")
# Handle tool_choice parameter
if tool_choice != "auto":
warnings.warn(
"tool_choice parameter is specified but may not be fully supported by SKChatCompletionAdapter.",
stacklevel=2,
)
kernel = self._get_kernel(extra_create_args)
chat_history = self._convert_to_chat_history(messages)
@ -553,6 +561,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
@ -585,6 +594,13 @@ class SKChatCompletionAdapter(ChatCompletionClient):
if isinstance(json_output, type) and issubclass(json_output, BaseModel):
raise ValueError("structured output is not currently supported in SKChatCompletionAdapter")
# Handle tool_choice parameter
if tool_choice != "auto":
warnings.warn(
"tool_choice parameter is specified but may not be fully supported by SKChatCompletionAdapter.",
stacklevel=2,
)
kernel = self._get_kernel(extra_create_args)
chat_history = self._convert_to_chat_history(messages)
user_settings = self._get_prompt_settings(extra_create_args)

View File

@ -2,6 +2,7 @@ import asyncio
import logging
import os
from typing import List, Sequence
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from autogen_core import CancellationToken, FunctionCall
@ -34,7 +35,153 @@ def _add_numbers(a: int, b: int) -> int:
@pytest.mark.asyncio
async def test_anthropic_serialization_api_key() -> None:
async def test_mock_tool_choice_specific_tool() -> None:
"""Test tool_choice parameter with a specific tool using mocks."""
# Create mock client and response
mock_client = AsyncMock()
mock_message = MagicMock()
mock_message.content = [MagicMock(type="tool_use", name="process_text", input={"input": "hello"}, id="call_123")]
mock_message.usage.input_tokens = 10
mock_message.usage.output_tokens = 5
mock_client.messages.create.return_value = mock_message
# Create real client but patch the underlying Anthropic client
client = AnthropicChatCompletionClient(
model="claude-3-haiku-20240307",
api_key="test-key",
)
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
messages: List[LLMMessage] = [
UserMessage(content="Process the text 'hello'.", source="user"),
]
with patch.object(client, "_client", mock_client):
await client.create(
messages=messages,
tools=[pass_tool, add_tool],
tool_choice=pass_tool, # Force use of specific tool
)
# Verify the correct API call was made
mock_client.messages.create.assert_called_once()
call_args = mock_client.messages.create.call_args
# Check that tool_choice was set correctly
assert "tool_choice" in call_args.kwargs
assert call_args.kwargs["tool_choice"] == {"type": "tool", "name": "process_text"}
@pytest.mark.asyncio
async def test_mock_tool_choice_auto() -> None:
"""Test tool_choice parameter with 'auto' setting using mocks."""
# Create mock client and response
mock_client = AsyncMock()
mock_message = MagicMock()
mock_message.content = [MagicMock(type="tool_use", name="add_numbers", input={"a": 1, "b": 2}, id="call_123")]
mock_message.usage.input_tokens = 10
mock_message.usage.output_tokens = 5
mock_client.messages.create.return_value = mock_message
# Create real client but patch the underlying Anthropic client
client = AnthropicChatCompletionClient(
model="claude-3-haiku-20240307",
api_key="test-key",
)
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
messages: List[LLMMessage] = [
UserMessage(content="Add 1 and 2.", source="user"),
]
with patch.object(client, "_client", mock_client):
await client.create(
messages=messages,
tools=[pass_tool, add_tool],
tool_choice="auto", # Let model choose
)
# Verify the correct API call was made
mock_client.messages.create.assert_called_once()
call_args = mock_client.messages.create.call_args
# Check that tool_choice was set correctly
assert "tool_choice" in call_args.kwargs
assert call_args.kwargs["tool_choice"] == {"type": "auto"}
@pytest.mark.asyncio
async def test_mock_tool_choice_none() -> None:
"""Test tool_choice parameter when no tools are provided - tool_choice should not be included."""
# Create mock client and response
mock_client = AsyncMock()
mock_message = MagicMock()
mock_message.content = [MagicMock(type="text", text="I can help you with that.")]
mock_message.usage.input_tokens = 10
mock_message.usage.output_tokens = 5
mock_client.messages.create.return_value = mock_message
# Create real client but patch the underlying Anthropic client
client = AnthropicChatCompletionClient(
model="claude-3-haiku-20240307",
api_key="test-key",
)
messages: List[LLMMessage] = [
UserMessage(content="Hello there.", source="user"),
]
with patch.object(client, "_client", mock_client):
await client.create(
messages=messages,
# No tools provided - tool_choice should not be included in API call
)
# Verify the correct API call was made
mock_client.messages.create.assert_called_once()
call_args = mock_client.messages.create.call_args
# Check that tool_choice was not set when no tools are provided
assert "tool_choice" not in call_args.kwargs
@pytest.mark.asyncio
async def test_mock_tool_choice_validation_error() -> None:
"""Test tool_choice validation with invalid tool reference."""
client = AnthropicChatCompletionClient(
model="claude-3-haiku-20240307",
api_key="test-key",
)
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
different_tool = FunctionTool(_pass_function, description="Different tool", name="different_tool")
messages: List[LLMMessage] = [
UserMessage(content="Hello there.", source="user"),
]
# Test with a tool that's not in the tools list
with pytest.raises(ValueError, match="tool_choice references 'different_tool' but it's not in the available tools"):
await client.create(
messages=messages,
tools=[pass_tool, add_tool],
tool_choice=different_tool, # This tool is not in the tools list
)
@pytest.mark.asyncio
async def test_mock_serialization_api_key() -> None:
client = AnthropicChatCompletionClient(
model="claude-3-haiku-20240307", # Use haiku for faster/cheaper testing
api_key="sk-password",
@ -342,7 +489,7 @@ async def test_anthropic_multimodal() -> None:
@pytest.mark.asyncio
async def test_anthropic_serialization() -> None:
async def test_mock_serialization() -> None:
"""Test serialization and deserialization of component."""
client = AnthropicChatCompletionClient(
@ -422,7 +569,7 @@ async def test_anthropic_muliple_system_message() -> None:
assert result_content[-3:] == "BAR"
def test_merge_continuous_system_messages() -> None:
def test_mock_merge_continuous_system_messages() -> None:
"""Tests merging of continuous system messages."""
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
@ -447,7 +594,7 @@ def test_merge_continuous_system_messages() -> None:
assert merged_messages[1].content == "User question"
def test_merge_single_system_message() -> None:
def test_mock_merge_single_system_message() -> None:
"""Tests that a single system message remains unchanged."""
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
@ -467,7 +614,7 @@ def test_merge_single_system_message() -> None:
assert merged_messages[0].content == "Single system instruction"
def test_merge_no_system_messages() -> None:
def test_mock_merge_no_system_messages() -> None:
"""Tests behavior when there are no system messages."""
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
@ -486,7 +633,7 @@ def test_merge_no_system_messages() -> None:
assert merged_messages[0].content == "User question without system"
def test_merge_non_continuous_system_messages() -> None:
def test_mock_merge_non_continuous_system_messages() -> None:
"""Tests that an error is raised for non-continuous system messages."""
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
@ -504,7 +651,7 @@ def test_merge_non_continuous_system_messages() -> None:
# The method is protected, but we need to test it
def test_merge_system_messages_empty() -> None:
def test_mock_merge_system_messages_empty() -> None:
"""Tests that empty message list is handled properly."""
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
@ -513,7 +660,7 @@ def test_merge_system_messages_empty() -> None:
assert len(merged_messages) == 0
def test_merge_system_messages_with_special_characters() -> None:
def test_mock_merge_system_messages_with_special_characters() -> None:
"""Tests system message merging with special characters and formatting."""
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
@ -533,7 +680,7 @@ def test_merge_system_messages_with_special_characters() -> None:
assert system_message.content == "Line 1\nWith newline\nLine 2 with *formatting*\nLine 3 with `code`"
def test_merge_system_messages_with_whitespace() -> None:
def test_mock_merge_system_messages_with_whitespace() -> None:
"""Tests system message merging with extra whitespace."""
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
@ -553,7 +700,7 @@ def test_merge_system_messages_with_whitespace() -> None:
assert system_message.content == " Message with leading spaces \n\nMessage with leading newline"
def test_merge_system_messages_message_order() -> None:
def test_mock_merge_system_messages_message_order() -> None:
"""Tests that message order is preserved after merging."""
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
@ -584,7 +731,7 @@ def test_merge_system_messages_message_order() -> None:
assert merged_messages[3].content == "Answer"
def test_merge_system_messages_multiple_groups() -> None:
def test_mock_merge_system_messages_multiple_groups() -> None:
"""Tests that multiple separate groups of system messages raise an error."""
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
@ -600,7 +747,7 @@ def test_merge_system_messages_multiple_groups() -> None:
# The method is protected, but we need to test it
def test_merge_system_messages_no_duplicates() -> None:
def test_mock_merge_system_messages_no_duplicates() -> None:
"""Tests that identical system messages are still merged properly."""
client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")
@ -621,7 +768,7 @@ def test_merge_system_messages_no_duplicates() -> None:
@pytest.mark.asyncio
async def test_empty_assistant_content_string_with_anthropic() -> None:
async def test_anthropic_empty_assistant_content_string() -> None:
"""Test that an empty assistant content string is handled correctly."""
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
@ -646,7 +793,7 @@ async def test_empty_assistant_content_string_with_anthropic() -> None:
@pytest.mark.asyncio
async def test_claude_trailing_whitespace_at_last_assistant_content() -> None:
async def test_anthropic_trailing_whitespace_at_last_assistant_content() -> None:
"""Test that an empty assistant content string is handled correctly."""
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
@ -667,7 +814,7 @@ async def test_claude_trailing_whitespace_at_last_assistant_content() -> None:
assert isinstance(result.content, str)
def test_rstrip_railing_whitespace_at_last_assistant_content() -> None:
def test_mock_rstrip_trailing_whitespace_at_last_assistant_content() -> None:
messages: list[LLMMessage] = [
UserMessage(content="foo", source="user"),
UserMessage(content="bar", source="user"),
@ -680,3 +827,175 @@ def test_rstrip_railing_whitespace_at_last_assistant_content() -> None:
assert isinstance(result[-1].content, str)
assert result[-1].content == "foobar"
@pytest.mark.asyncio
async def test_anthropic_tool_choice_with_actual_api() -> None:
"""Test tool_choice parameter with actual Anthropic API endpoints."""
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
pytest.skip("ANTHROPIC_API_KEY not found in environment variables")
client = AnthropicChatCompletionClient(
model="claude-3-haiku-20240307",
api_key=api_key,
)
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
# Test 1: tool_choice with specific tool
messages: List[LLMMessage] = [
SystemMessage(content="Use the tools as needed to help the user."),
UserMessage(content="Process the text 'hello world' using the process_text tool.", source="user"),
]
result = await client.create(
messages=messages,
tools=[pass_tool, add_tool],
tool_choice=pass_tool, # Force use of specific tool
)
# Verify we got a tool call for the specified tool
assert isinstance(result.content, list)
assert len(result.content) >= 1
assert isinstance(result.content[0], FunctionCall)
assert result.content[0].name == "process_text"
# Test 2: tool_choice="auto" with tools
auto_messages: List[LLMMessage] = [
SystemMessage(content="Use the tools as needed to help the user."),
UserMessage(content="Add the numbers 5 and 3.", source="user"),
]
auto_result = await client.create(
messages=auto_messages,
tools=[pass_tool, add_tool],
tool_choice="auto", # Let model choose
)
# Should get a tool call, likely for add_numbers
assert isinstance(auto_result.content, list)
assert len(auto_result.content) >= 1
assert isinstance(auto_result.content[0], FunctionCall)
# Model should choose add_numbers for addition task
assert auto_result.content[0].name == "add_numbers"
# Test 3: No tools provided - should not include tool_choice in API call
no_tools_messages: List[LLMMessage] = [
UserMessage(content="What is the capital of France?", source="user"),
]
no_tools_result = await client.create(messages=no_tools_messages)
# Should get a text response without tool calls
assert isinstance(no_tools_result.content, str)
assert "paris" in no_tools_result.content.lower()
# Test 4: tool_choice="required" with tools
required_messages: List[LLMMessage] = [
SystemMessage(content="You must use one of the available tools to help the user."),
UserMessage(content="Help me with something.", source="user"),
]
required_result = await client.create(
messages=required_messages,
tools=[pass_tool, add_tool],
tool_choice="required", # Force tool usage
)
# Should get a tool call (model forced to use a tool)
assert isinstance(required_result.content, list)
assert len(required_result.content) >= 1
assert isinstance(required_result.content[0], FunctionCall)
@pytest.mark.asyncio
async def test_anthropic_tool_choice_streaming_with_actual_api() -> None:
"""Test tool_choice parameter with streaming using actual Anthropic API endpoints."""
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
pytest.skip("ANTHROPIC_API_KEY not found in environment variables")
client = AnthropicChatCompletionClient(
model="claude-3-haiku-20240307",
api_key=api_key,
)
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
# Test streaming with tool_choice
messages: List[LLMMessage] = [
SystemMessage(content="Use the tools as needed to help the user."),
UserMessage(content="Process the text 'streaming test' using the process_text tool.", source="user"),
]
chunks: List[str | CreateResult] = []
async for chunk in client.create_stream(
messages=messages,
tools=[pass_tool, add_tool],
tool_choice=pass_tool, # Force use of specific tool
):
chunks.append(chunk)
# Verify we got chunks and a final result
assert len(chunks) > 0
final_result = chunks[-1]
assert isinstance(final_result, CreateResult)
# Should get a tool call for the specified tool
assert isinstance(final_result.content, list)
assert len(final_result.content) >= 1
assert isinstance(final_result.content[0], FunctionCall)
assert final_result.content[0].name == "process_text"
# Test streaming without tools - should not include tool_choice
no_tools_messages: List[LLMMessage] = [
UserMessage(content="Tell me a short fact about cats.", source="user"),
]
no_tools_chunks: List[str | CreateResult] = []
async for chunk in client.create_stream(messages=no_tools_messages):
no_tools_chunks.append(chunk)
# Should get text response
assert len(no_tools_chunks) > 0
final_no_tools_result = no_tools_chunks[-1]
assert isinstance(final_no_tools_result, CreateResult)
assert isinstance(final_no_tools_result.content, str)
assert len(final_no_tools_result.content) > 0
@pytest.mark.asyncio
async def test_anthropic_tool_choice_none_value_with_actual_api() -> None:
"""Test tool_choice="none" with actual Anthropic API endpoints."""
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
pytest.skip("ANTHROPIC_API_KEY not found in environment variables")
client = AnthropicChatCompletionClient(
model="claude-3-haiku-20240307",
api_key=api_key,
)
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
# Test tool_choice="none" - should not use tools even when available
messages: List[LLMMessage] = [
SystemMessage(content="Answer the user's question directly without using tools."),
UserMessage(content="What is 2 + 2?", source="user"),
]
result = await client.create(
messages=messages,
tools=[pass_tool, add_tool],
tool_choice="none", # Disable tool usage
)
# Should get a text response, not tool calls
assert isinstance(result.content, str)

View File

@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from autogen_core import CancellationToken, FunctionCall, Image
from autogen_core.models import CreateResult, ModelFamily, UserMessage
from autogen_core.tools import FunctionTool
from autogen_ext.models.azure import AzureAIChatCompletionClient
from autogen_ext.models.azure.config import GITHUB_MODELS_ENDPOINT
from azure.ai.inference.aio import (
@ -623,3 +624,352 @@ async def test_thought_field_with_tool_calls_streaming(
assert final_result.content[0].arguments == '{"foo": "bar"}'
assert final_result.thought == "Let me think about what function to call."
def _pass_function(input: str) -> str:
"""Simple passthrough function."""
return f"Processed: {input}"
def _add_numbers(a: int, b: int) -> int:
"""Add two numbers together."""
return a + b
@pytest.fixture
def tool_choice_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient:
"""
Returns a client that supports function calling for tool choice tests.
"""
async def _mock_tool_choice_stream(
*args: Any, **kwargs: Any
) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]:
mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"]
mock_chunks = [
StreamingChatChoiceUpdate(
index=0,
finish_reason="stop",
delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_content),
)
for chunk_content in mock_chunks_content
]
for mock_chunk in mock_chunks:
await asyncio.sleep(0.01)
yield StreamingChatCompletionsUpdate(
id="id",
choices=[mock_chunk],
created=datetime.now(),
model="model",
usage=CompletionsUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
)
mock_client = MagicMock()
mock_client.close = AsyncMock()
async def mock_complete(*args: Any, **kwargs: Any) -> Any:
stream = kwargs.get("stream", False)
if not stream:
await asyncio.sleep(0.01)
return ChatCompletions(
id="id",
created=datetime.now(),
model="model",
choices=[
ChatChoice(
index=0,
finish_reason=CompletionsFinishReason.TOOL_CALLS,
message=ChatResponseMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionsToolCall(
id="call_123",
function=AzureFunctionCall(name="process_text", arguments='{"input": "hello"}'),
)
],
),
)
],
usage=CompletionsUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
)
else:
return _mock_tool_choice_stream(*args, **kwargs)
mock_client.complete = mock_complete
def mock_new(cls: Type[ChatCompletionsClient], *args: Any, **kwargs: Any) -> MagicMock:
return mock_client
monkeypatch.setattr(ChatCompletionsClient, "__new__", mock_new)
return AzureAIChatCompletionClient(
endpoint="endpoint",
credential=AzureKeyCredential("api_key"),
model_info={
"json_output": False,
"function_calling": True,
"vision": False,
"family": "test",
"structured_output": False,
},
model="model",
)
@pytest.mark.asyncio
async def test_azure_ai_tool_choice_specific_tool(tool_choice_client: AzureAIChatCompletionClient) -> None:
"""Test tool_choice parameter with a specific tool using mocks."""
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
messages = [
UserMessage(content="Process the text 'hello'.", source="user"),
]
result = await tool_choice_client.create(
messages=messages,
tools=[pass_tool, add_tool],
tool_choice=pass_tool, # Force use of specific tool
)
# Verify the result
assert result.finish_reason == "function_calls"
assert isinstance(result.content, list)
assert len(result.content) == 1
assert isinstance(result.content[0], FunctionCall)
assert result.content[0].name == "process_text"
assert result.content[0].arguments == '{"input": "hello"}'
@pytest.mark.asyncio
async def test_azure_ai_tool_choice_auto(tool_choice_client: AzureAIChatCompletionClient) -> None:
"""Test tool_choice parameter with 'auto' setting using mocks."""
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
messages = [
UserMessage(content="Add 1 and 2.", source="user"),
]
result = await tool_choice_client.create(
messages=messages,
tools=[pass_tool, add_tool],
tool_choice="auto", # Let the model choose
)
# Verify the result
assert result.finish_reason == "function_calls"
assert isinstance(result.content, list)
assert len(result.content) == 1
assert isinstance(result.content[0], FunctionCall)
assert result.content[0].name == "process_text" # Our mock always returns process_text
assert result.content[0].arguments == '{"input": "hello"}'
@pytest.fixture
def tool_choice_none_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient:
"""
Returns a client that simulates no tool calls for tool_choice='none' tests.
"""
mock_client = MagicMock()
mock_client.close = AsyncMock()
async def mock_complete(*args: Any, **kwargs: Any) -> ChatCompletions:
await asyncio.sleep(0.01)
return ChatCompletions(
id="id",
created=datetime.now(),
model="model",
choices=[
ChatChoice(
index=0,
finish_reason="stop",
message=ChatResponseMessage(role="assistant", content="I can help you with that."),
)
],
usage=CompletionsUsage(prompt_tokens=8, completion_tokens=6, total_tokens=14),
)
mock_client.complete = mock_complete
def mock_new(cls: Type[ChatCompletionsClient], *args: Any, **kwargs: Any) -> MagicMock:
return mock_client
monkeypatch.setattr(ChatCompletionsClient, "__new__", mock_new)
return AzureAIChatCompletionClient(
endpoint="endpoint",
credential=AzureKeyCredential("api_key"),
model_info={
"json_output": False,
"function_calling": True,
"vision": False,
"family": "test",
"structured_output": False,
},
model="model",
)
@pytest.mark.asyncio
async def test_azure_ai_tool_choice_none(tool_choice_none_client: AzureAIChatCompletionClient) -> None:
"""Test tool_choice parameter with 'none' setting using mocks."""
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
messages = [
UserMessage(content="Just say hello.", source="user"),
]
result = await tool_choice_none_client.create(
messages=messages,
tools=[pass_tool, add_tool],
tool_choice="none", # Prevent tool usage
)
# Verify the result
assert result.finish_reason == "stop"
assert isinstance(result.content, str)
assert result.content == "I can help you with that."
@pytest.mark.asyncio
async def test_azure_ai_tool_choice_required(tool_choice_client: AzureAIChatCompletionClient) -> None:
"""Test tool_choice parameter with 'required' setting using mocks."""
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
messages = [
UserMessage(content="Process some text.", source="user"),
]
result = await tool_choice_client.create(
messages=messages,
tools=[pass_tool, add_tool],
tool_choice="required", # Force tool usage
)
# Verify the result
assert result.finish_reason == "function_calls"
assert isinstance(result.content, list)
assert len(result.content) == 1
assert isinstance(result.content[0], FunctionCall)
assert result.content[0].name == "process_text"
assert result.content[0].arguments == '{"input": "hello"}'
@pytest.fixture
def tool_choice_stream_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient:
"""
Returns a client that supports function calling for streaming tool choice tests.
"""
# Mock tool call for streaming
mock_tool_call = MagicMock()
mock_tool_call.id = "call_123"
mock_tool_call.function = MagicMock()
mock_tool_call.function.name = "process_text"
mock_tool_call.function.arguments = '{"input": "hello"}'
# First choice with content
first_choice = MagicMock()
first_choice.delta = MagicMock()
first_choice.delta.content = "Let me process this for you."
first_choice.finish_reason = None
# Tool call choice
tool_call_choice = MagicMock()
tool_call_choice.delta = MagicMock()
tool_call_choice.delta.content = None
tool_call_choice.delta.tool_calls = [mock_tool_call]
tool_call_choice.finish_reason = "function_calls"
async def _mock_tool_choice_stream(
*args: Any, **kwargs: Any
) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]:
yield StreamingChatCompletionsUpdate(
id="id",
choices=[first_choice],
created=datetime.now(),
model="model",
)
await asyncio.sleep(0.01)
yield StreamingChatCompletionsUpdate(
id="id",
choices=[tool_call_choice],
created=datetime.now(),
model="model",
usage=CompletionsUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
)
mock_client = MagicMock()
mock_client.close = AsyncMock()
async def mock_complete(*args: Any, **kwargs: Any) -> Any:
if kwargs.get("stream", False):
return _mock_tool_choice_stream(*args, **kwargs)
return None
mock_client.complete = mock_complete
def mock_new(cls: Type[ChatCompletionsClient], *args: Any, **kwargs: Any) -> MagicMock:
return mock_client
monkeypatch.setattr(ChatCompletionsClient, "__new__", mock_new)
return AzureAIChatCompletionClient(
endpoint="endpoint",
credential=AzureKeyCredential("api_key"),
model_info={
"json_output": False,
"function_calling": True,
"vision": False,
"family": "test",
"structured_output": False,
},
model="model",
)
@pytest.mark.asyncio
async def test_azure_ai_tool_choice_specific_tool_streaming(
tool_choice_stream_client: AzureAIChatCompletionClient,
) -> None:
"""Test tool_choice parameter with streaming and a specific tool using mocks."""
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
messages = [
UserMessage(content="Process the text 'hello'.", source="user"),
]
chunks: List[Union[str, CreateResult]] = []
async for chunk in tool_choice_stream_client.create_stream(
messages=messages,
tools=[pass_tool, add_tool],
tool_choice=pass_tool, # Force use of specific tool
):
chunks.append(chunk)
# Verify that we got some result
final_result = chunks[-1]
assert isinstance(final_result, CreateResult)
assert final_result.finish_reason == "function_calls"
assert isinstance(final_result.content, list)
assert len(final_result.content) == 1
assert isinstance(final_result.content[0], FunctionCall)
assert final_result.content[0].name == "process_text"
assert final_result.content[0].arguments == '{"input": "hello"}'
assert final_result.thought == "Let me process this for you."

View File

@ -590,6 +590,7 @@ async def test_ollama_create_tools(model: str, ollama_client: OllamaChatCompleti
assert len(create_result.content) > 0
assert isinstance(create_result.content[0], FunctionCall)
assert create_result.content[0].name == add_tool.name
assert create_result.content[0].arguments == json.dumps({"x": 2, "y": 2})
assert create_result.finish_reason == "function_calls"
execution_result = FunctionExecutionResult(
@ -679,38 +680,11 @@ async def test_ollama_create_stream_tools(model: str, ollama_client: OllamaChatC
assert len(create_result.content) > 0
assert isinstance(create_result.content[0], FunctionCall)
assert create_result.content[0].name == add_tool.name
assert create_result.content[0].arguments == json.dumps({"x": 2, "y": 2})
assert create_result.finish_reason == "stop"
execution_result = FunctionExecutionResult(
content="4",
name=add_tool.name,
call_id=create_result.content[0].id,
is_error=False,
)
stream = ollama_client.create_stream(
messages=[
UserMessage(
content="What is 2 + 2? Use the add tool.",
source="user",
),
AssistantMessage(
content=create_result.content,
source="assistant",
),
FunctionExecutionResultMessage(
content=[execution_result],
),
],
)
chunks = []
async for chunk in stream:
chunks.append(chunk)
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
create_result = chunks[-1]
assert isinstance(create_result.content, str)
assert len(create_result.content) > 0
assert create_result.finish_reason == "stop"
assert create_result.usage is not None
assert create_result.usage.prompt_tokens == 10
assert create_result.usage.completion_tokens == 12
@pytest.mark.asyncio
@ -883,3 +857,459 @@ async def test_llm_control_params(monkeypatch: pytest.MonkeyPatch) -> None:
assert chat_kwargs_captured["options"]["temperature"] == 0.7
assert chat_kwargs_captured["options"]["top_p"] == 0.9
assert chat_kwargs_captured["options"]["frequency_penalty"] == 1.2
@pytest.mark.asyncio
async def test_tool_choice_auto(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test tool_choice='auto' (default behavior)"""
def add(x: int, y: int) -> str:
return str(x + y)
def multiply(x: int, y: int) -> str:
return str(x * y)
add_tool = FunctionTool(add, description="Add two numbers")
multiply_tool = FunctionTool(multiply, description="Multiply two numbers")
model = "llama3.2"
# Capture the kwargs passed to chat
chat_kwargs_captured: Dict[str, Any] = {}
async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
nonlocal chat_kwargs_captured
chat_kwargs_captured = kwargs
return ChatResponse(
model=model,
done=True,
done_reason="stop",
message=Message(
role="assistant",
content="I'll use the add tool.",
tool_calls=[
Message.ToolCall(
function=Message.ToolCall.Function(
name=add_tool.name,
arguments={"x": 2, "y": 3},
),
),
],
),
prompt_eval_count=10,
eval_count=12,
)
monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
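# AsyncClient.chat is routed to the mock, so no local Ollama server is needed; the captured kwargs expose the tools payload sent to the API.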
client = OllamaChatCompletionClient(model=model)
create_result = await client.create(
messages=[UserMessage(content="What is 2 + 3?", source="user")],
tools=[add_tool, multiply_tool],
tool_choice="auto", # Explicitly set to auto
)
# Verify that all tools are passed to the API when tool_choice is auto
assert "tools" in chat_kwargs_captured
assert chat_kwargs_captured["tools"] is not None
assert len(chat_kwargs_captured["tools"]) == 2
# Verify the response
assert isinstance(create_result.content, list)
assert len(create_result.content) > 0
assert isinstance(create_result.content[0], FunctionCall)
assert create_result.content[0].name == add_tool.name
@pytest.mark.asyncio
async def test_tool_choice_none(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test tool_choice='none' - no tools should be passed to API"""
def add(x: int, y: int) -> str:
return str(x + y)
add_tool = FunctionTool(add, description="Add two numbers")
model = "llama3.2"
content_raw = "I cannot use tools, so I'll calculate manually: 2 + 3 = 5"
# Capture the kwargs passed to chat
chat_kwargs_captured: Dict[str, Any] = {}
async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
nonlocal chat_kwargs_captured
chat_kwargs_captured = kwargs
return ChatResponse(
model=model,
done=True,
done_reason="stop",
message=Message(
role="assistant",
content=content_raw,
),
prompt_eval_count=10,
eval_count=12,
)
monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
client = OllamaChatCompletionClient(model=model)
create_result = await client.create(
messages=[UserMessage(content="What is 2 + 3?", source="user")],
tools=[add_tool],
tool_choice="none",
)
# Verify that no tools are passed to the API when tool_choice is none
assert "tools" in chat_kwargs_captured
assert chat_kwargs_captured["tools"] is None
# Verify the response is text content
assert isinstance(create_result.content, str)
assert create_result.content == content_raw
assert create_result.finish_reason == "stop"
@pytest.mark.asyncio
async def test_tool_choice_required(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test tool_choice='required' - tools must be provided and passed to API"""
def add(x: int, y: int) -> str:
return str(x + y)
def multiply(x: int, y: int) -> str:
return str(x * y)
add_tool = FunctionTool(add, description="Add two numbers")
multiply_tool = FunctionTool(multiply, description="Multiply two numbers")
model = "llama3.2"
# Capture the kwargs passed to chat
chat_kwargs_captured: Dict[str, Any] = {}
async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
nonlocal chat_kwargs_captured
chat_kwargs_captured = kwargs
return ChatResponse(
model=model,
done=True,
done_reason="tool_calls",
message=Message(
role="assistant",
tool_calls=[
Message.ToolCall(
function=Message.ToolCall.Function(
name=add_tool.name,
arguments={"x": 2, "y": 3},
),
),
],
),
prompt_eval_count=10,
eval_count=12,
)
monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
client = OllamaChatCompletionClient(model=model)
create_result = await client.create(
messages=[UserMessage(content="What is 2 + 3?", source="user")],
tools=[add_tool, multiply_tool],
tool_choice="required",
)
# Verify that all tools are passed to the API when tool_choice is required
assert "tools" in chat_kwargs_captured
assert chat_kwargs_captured["tools"] is not None
assert len(chat_kwargs_captured["tools"]) == 2
# Verify the response contains function calls
assert isinstance(create_result.content, list)
assert len(create_result.content) > 0
assert isinstance(create_result.content[0], FunctionCall)
assert create_result.content[0].name == add_tool.name
assert create_result.finish_reason == "function_calls"
@pytest.mark.asyncio
async def test_tool_choice_required_no_tools_error() -> None:
"""Test tool_choice='required' with no tools raises ValueError"""
model = "llama3.2"
client = OllamaChatCompletionClient(model=model)
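# No mocking is needed here: the ValueError is expected to be raised client-side before any request is made.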
with pytest.raises(ValueError, match="tool_choice 'required' specified but no tools provided"):
await client.create(
messages=[UserMessage(content="What is 2 + 3?", source="user")],
tools=[], # No tools provided
tool_choice="required",
)
@pytest.mark.asyncio
async def test_tool_choice_specific_tool(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test tool_choice with a specific tool - only that tool should be passed to API"""
def add(x: int, y: int) -> str:
return str(x + y)
def multiply(x: int, y: int) -> str:
return str(x * y)
add_tool = FunctionTool(add, description="Add two numbers")
multiply_tool = FunctionTool(multiply, description="Multiply two numbers")
model = "llama3.2"
# Capture the kwargs passed to chat
chat_kwargs_captured: Dict[str, Any] = {}
async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
nonlocal chat_kwargs_captured
chat_kwargs_captured = kwargs
return ChatResponse(
model=model,
done=True,
done_reason="tool_calls",
message=Message(
role="assistant",
tool_calls=[
Message.ToolCall(
function=Message.ToolCall.Function(
name=add_tool.name,
arguments={"x": 2, "y": 3},
),
),
],
),
prompt_eval_count=10,
eval_count=12,
)
monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
client = OllamaChatCompletionClient(model=model)
create_result = await client.create(
messages=[UserMessage(content="What is 2 + 3?", source="user")],
tools=[add_tool, multiply_tool], # Multiple tools available
tool_choice=add_tool, # But force specific tool
)
# Verify that only the specified tool is passed to the API
assert "tools" in chat_kwargs_captured
assert chat_kwargs_captured["tools"] is not None
assert len(chat_kwargs_captured["tools"]) == 1
assert chat_kwargs_captured["tools"][0]["function"]["name"] == add_tool.name
# Verify the response contains function calls
assert isinstance(create_result.content, list)
assert len(create_result.content) > 0
assert isinstance(create_result.content[0], FunctionCall)
assert create_result.content[0].name == add_tool.name
assert create_result.finish_reason == "function_calls"
@pytest.mark.asyncio
async def test_tool_choice_stream_auto(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test tool_choice='auto' with streaming"""
def add(x: int, y: int) -> str:
return str(x + y)
add_tool = FunctionTool(add, description="Add two numbers")
model = "llama3.2"
content_raw = "I'll use the add tool."
# Capture the kwargs passed to chat
chat_kwargs_captured: Dict[str, Any] = {}
async def _mock_chat(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
nonlocal chat_kwargs_captured
chat_kwargs_captured = kwargs
assert "stream" in kwargs
assert kwargs["stream"] is True
async def _mock_stream() -> AsyncGenerator[ChatResponse, None]:
chunks = [content_raw[i : i + 5] for i in range(0, len(content_raw), 5)]
# Simulate streaming by yielding chunks of the response
for chunk in chunks[:-1]:
yield ChatResponse(
model=model,
done=False,
message=Message(
role="assistant",
content=chunk,
),
)
yield ChatResponse(
model=model,
done=True,
done_reason="tool_calls",
message=Message(
content=chunks[-1],
role="assistant",
tool_calls=[
Message.ToolCall(
function=Message.ToolCall.Function(
name=add_tool.name,
arguments={"x": 2, "y": 3},
),
),
],
),
prompt_eval_count=10,
eval_count=12,
)
return _mock_stream()
monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
client = OllamaChatCompletionClient(model=model)
stream = client.create_stream(
messages=[UserMessage(content="What is 2 + 3?", source="user")],
tools=[add_tool],
tool_choice="auto",
)
chunks: List[str | CreateResult] = []
async for chunk in stream:
chunks.append(chunk)
# Verify that tools are passed to the API when tool_choice is auto
assert "tools" in chat_kwargs_captured
assert chat_kwargs_captured["tools"] is not None
assert len(chat_kwargs_captured["tools"]) == 1
# Verify the final result
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
assert isinstance(chunks[-1].content, list)
assert len(chunks[-1].content) > 0
assert isinstance(chunks[-1].content[0], FunctionCall)
assert chunks[-1].content[0].name == add_tool.name
assert chunks[-1].finish_reason == "function_calls"
@pytest.mark.asyncio
async def test_tool_choice_stream_none(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test tool_choice='none' with streaming"""
def add(x: int, y: int) -> str:
return str(x + y)
add_tool = FunctionTool(add, description="Add two numbers")
model = "llama3.2"
content_raw = "I cannot use tools, so I'll calculate manually: 2 + 3 = 5"
# Capture the kwargs passed to chat
chat_kwargs_captured: Dict[str, Any] = {}
async def _mock_chat(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
nonlocal chat_kwargs_captured
chat_kwargs_captured = kwargs
assert "stream" in kwargs
assert kwargs["stream"] is True
async def _mock_stream() -> AsyncGenerator[ChatResponse, None]:
chunks = [content_raw[i : i + 10] for i in range(0, len(content_raw), 10)]
# Simulate streaming by yielding chunks of the response
for chunk in chunks[:-1]:
yield ChatResponse(
model=model,
done=False,
message=Message(
role="assistant",
content=chunk,
),
)
yield ChatResponse(
model=model,
done=True,
done_reason="stop",
message=Message(
role="assistant",
content=chunks[-1],
),
prompt_eval_count=10,
eval_count=12,
)
return _mock_stream()
monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
client = OllamaChatCompletionClient(model=model)
stream = client.create_stream(
messages=[UserMessage(content="What is 2 + 3?", source="user")],
tools=[add_tool],
tool_choice="none",
)
chunks: List[str | CreateResult] = []
async for chunk in stream:
chunks.append(chunk)
# Verify that no tools are passed to the API when tool_choice is none
assert "tools" in chat_kwargs_captured
assert chat_kwargs_captured["tools"] is None
# Verify the final result is text content
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
assert isinstance(chunks[-1].content, str)
assert chunks[-1].content == content_raw
assert chunks[-1].finish_reason == "stop"
@pytest.mark.asyncio
async def test_tool_choice_default_behavior(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that default behavior (no tool_choice specified) works like 'auto'"""
def add(x: int, y: int) -> str:
return str(x + y)
add_tool = FunctionTool(add, description="Add two numbers")
model = "llama3.2"
# Capture the kwargs passed to chat
chat_kwargs_captured: Dict[str, Any] = {}
async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
nonlocal chat_kwargs_captured
chat_kwargs_captured = kwargs
return ChatResponse(
model=model,
done=True,
done_reason="stop",
message=Message(
role="assistant",
content="I'll use the add tool.",
tool_calls=[
Message.ToolCall(
function=Message.ToolCall.Function(
name=add_tool.name,
arguments={"x": 2, "y": 3},
),
),
],
),
prompt_eval_count=10,
eval_count=12,
)
monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
client = OllamaChatCompletionClient(model=model)
create_result = await client.create(
messages=[UserMessage(content="What is 2 + 3?", source="user")],
tools=[add_tool],
# tool_choice not specified - should default to "auto"
)
# Verify that tools are passed to the API by default (auto behavior)
assert "tools" in chat_kwargs_captured
assert chat_kwargs_captured["tools"] is not None
assert len(chat_kwargs_captured["tools"]) == 1
# Verify the response
assert isinstance(create_result.content, list)
assert len(create_result.content) > 0
assert isinstance(create_result.content[0], FunctionCall)
assert create_result.content[0].name == add_tool.name

View File

@ -3,7 +3,7 @@ import json
import logging
import os
from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Tuple, TypeVar
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock
import httpx
import pytest
@ -2193,6 +2193,7 @@ async def test_system_message_merge_with_continuous_system_messages_models() ->
tools=[],
json_output=None,
extra_create_args={},
tool_choice="none",
)
# Extract the actual messages from the result
@ -2243,6 +2244,7 @@ async def test_system_message_merge_with_non_continuous_messages() -> None:
tools=[],
json_output=None,
extra_create_args={},
tool_choice="none",
)
@ -2279,6 +2281,7 @@ async def test_system_message_not_merged_for_multiple_system_messages_true() ->
tools=[],
json_output=None,
extra_create_args={},
tool_choice="none",
)
# Extract the actual messages from the result
@ -2322,6 +2325,7 @@ async def test_no_system_messages_for_gemini_model() -> None:
tools=[],
json_output=None,
extra_create_args={},
tool_choice="none",
)
# Extract the actual messages from the result
@ -2369,6 +2373,7 @@ async def test_single_system_message_for_gemini_model() -> None:
tools=[],
json_output=None,
extra_create_args={},
tool_choice="auto",
)
# Extract the actual messages from the result
@ -2561,4 +2566,493 @@ async def test_mistral_remove_name() -> None:
assert ("name" in params[0]) is True
@pytest.mark.asyncio
async def test_mock_tool_choice_specific_tool(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test tool_choice parameter with a specific tool using mocks."""
def _pass_function(input: str) -> str:
"""Simple passthrough function."""
return f"Processed: {input}"
def _add_numbers(a: int, b: int) -> int:
"""Add two numbers together."""
return a + b
model = "gpt-4o"
# Mock successful completion with specific tool call
chat_completion = ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
role="assistant",
content=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
type="function",
function=Function(
name="_pass_function",
arguments=json.dumps({"input": "hello"}),
),
)
],
),
)
],
created=1234567890,
model=model,
object="chat.completion",
usage=CompletionUsage(completion_tokens=10, prompt_tokens=5, total_tokens=15),
)
client = OpenAIChatCompletionClient(model=model, api_key="test")
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="_add_numbers")
# Create mock for the chat completions create method
mock_create = AsyncMock(return_value=chat_completion)
with monkeypatch.context() as mp:
mp.setattr(client._client.chat.completions, "create", mock_create) # type: ignore[reportPrivateUsage]
_ = await client.create(
messages=[UserMessage(content="Process 'hello'", source="user")],
tools=[pass_tool, add_tool],
tool_choice=pass_tool, # Force use of specific tool
)
# Verify the correct API call was made
mock_create.assert_called_once()
call_args = mock_create.call_args
# Check that tool_choice was set correctly
assert "tool_choice" in call_args.kwargs
assert call_args.kwargs["tool_choice"] == {"type": "function", "function": {"name": "_pass_function"}}
@pytest.mark.asyncio
async def test_mock_tool_choice_auto(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test tool_choice parameter with 'auto' setting using mocks."""
def _pass_function(input: str) -> str:
"""Simple passthrough function."""
return f"Processed: {input}"
def _add_numbers(a: int, b: int) -> int:
"""Add two numbers together."""
return a + b
model = "gpt-4o"
# Mock successful completion
chat_completion = ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
role="assistant",
content=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
type="function",
function=Function(
name="_add_numbers",
arguments=json.dumps({"a": 1, "b": 2}),
),
)
],
),
)
],
created=1234567890,
model=model,
object="chat.completion",
usage=CompletionUsage(completion_tokens=10, prompt_tokens=5, total_tokens=15),
)
client = OpenAIChatCompletionClient(model=model, api_key="test")
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="_add_numbers")
# Create mock for the chat completions create method
mock_create = AsyncMock(return_value=chat_completion)
with monkeypatch.context() as mp:
mp.setattr(client._client.chat.completions, "create", mock_create) # type: ignore[reportPrivateUsage]
await client.create(
messages=[UserMessage(content="Add 1 and 2", source="user")],
tools=[pass_tool, add_tool],
tool_choice="auto", # Let model choose
)
# Verify the correct API call was made
mock_create.assert_called_once()
call_args = mock_create.call_args
# Check that tool_choice was set correctly
assert "tool_choice" in call_args.kwargs
assert call_args.kwargs["tool_choice"] == "auto"
@pytest.mark.asyncio
async def test_mock_tool_choice_none(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test tool_choice parameter with None setting using mocks."""
def _pass_function(input: str) -> str:
"""Simple passthrough function."""
return f"Processed: {input}"
model = "gpt-4o"
# Mock successful completion
chat_completion = ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(
role="assistant",
content="I can help you with that!",
tool_calls=None,
),
)
],
created=1234567890,
model=model,
object="chat.completion",
usage=CompletionUsage(completion_tokens=10, prompt_tokens=5, total_tokens=15),
)
client = OpenAIChatCompletionClient(model=model, api_key="test")
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
# Create mock for the chat completions create method
mock_create = AsyncMock(return_value=chat_completion)
with monkeypatch.context() as mp:
mp.setattr(client._client.chat.completions, "create", mock_create) # type: ignore[reportPrivateUsage]
await client.create(
messages=[UserMessage(content="Hello there", source="user")],
tools=[pass_tool],
tool_choice="none",
)
# Verify the correct API call was made
mock_create.assert_called_once()
call_args = mock_create.call_args
# Check that tool_choice was set to "none" (disabling tool usage)
assert "tool_choice" in call_args.kwargs
assert call_args.kwargs["tool_choice"] == "none"
@pytest.mark.asyncio
async def test_mock_tool_choice_validation_error() -> None:
"""Test tool_choice validation with invalid tool reference."""
def _pass_function(input: str) -> str:
"""Simple passthrough function."""
return f"Processed: {input}"
def _add_numbers(a: int, b: int) -> int:
"""Add two numbers together."""
return a + b
def _different_function(text: str) -> str:
"""Different function."""
return text
client = OpenAIChatCompletionClient(model="gpt-4o", api_key="test")
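# No mock is set up: validation of tool_choice against the provided tools should fail before any API request is attempted.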
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="_add_numbers")
different_tool = FunctionTool(_different_function, description="Different tool", name="_different_function")
messages = [UserMessage(content="Hello there", source="user")]
# Test with a tool that's not in the tools list
with pytest.raises(
ValueError, match="tool_choice references '_different_function' but it's not in the provided tools"
):
await client.create(
messages=messages,
tools=[pass_tool, add_tool],
tool_choice=different_tool, # This tool is not in the tools list
)
@pytest.mark.asyncio
async def test_mock_tool_choice_required(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test tool_choice parameter with 'required' setting using mocks."""
def _pass_function(input: str) -> str:
"""Simple passthrough function."""
return f"Processed: {input}"
def _add_numbers(a: int, b: int) -> int:
"""Add two numbers together."""
return a + b
model = "gpt-4o"
# Mock successful completion with tool calls (required forces tool usage)
chat_completion = ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
role="assistant",
content=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
type="function",
function=Function(
name="_pass_function",
arguments=json.dumps({"input": "hello"}),
),
)
],
),
)
],
created=1234567890,
model=model,
object="chat.completion",
usage=CompletionUsage(completion_tokens=10, prompt_tokens=5, total_tokens=15),
)
client = OpenAIChatCompletionClient(model=model, api_key="test")
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="_add_numbers")
# Create mock for the chat completions create method
mock_create = AsyncMock(return_value=chat_completion)
with monkeypatch.context() as mp:
mp.setattr(client._client.chat.completions, "create", mock_create) # type: ignore[reportPrivateUsage]
await client.create(
messages=[UserMessage(content="Process some text", source="user")],
tools=[pass_tool, add_tool],
tool_choice="required", # Force tool usage
)
# Verify the correct API call was made
mock_create.assert_called_once()
call_args = mock_create.call_args
# Check that tool_choice was set correctly
assert "tool_choice" in call_args.kwargs
assert call_args.kwargs["tool_choice"] == "required"
# Integration tests for tool_choice using the actual OpenAI API
@pytest.mark.asyncio
async def test_openai_tool_choice_specific_tool_integration() -> None:
"""Test tool_choice parameter with a specific tool using the actual OpenAI API."""
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY not found in environment variables")
def _pass_function(input: str) -> str:
"""Simple passthrough function."""
return f"Processed: {input}"
def _add_numbers(a: int, b: int) -> int:
"""Add two numbers together."""
return a + b
model = "gpt-4o-mini"
client = OpenAIChatCompletionClient(model=model, api_key=api_key)
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="_add_numbers")
# Test forcing use of specific tool
result = await client.create(
messages=[UserMessage(content="Process the word 'hello'", source="user")],
tools=[pass_tool, add_tool],
tool_choice=pass_tool, # Force use of specific tool
)
assert isinstance(result.content, list)
assert len(result.content) == 1
assert isinstance(result.content[0], FunctionCall)
assert result.content[0].name == "_pass_function"
assert result.finish_reason == "function_calls"
assert result.usage is not None
@pytest.mark.asyncio
async def test_openai_tool_choice_auto_integration() -> None:
"""Test tool_choice parameter with 'auto' setting using the actual OpenAI API."""
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY not found in environment variables")
def _pass_function(input: str) -> str:
"""Simple passthrough function."""
return f"Processed: {input}"
def _add_numbers(a: int, b: int) -> int:
"""Add two numbers together."""
return a + b
model = "gpt-4o-mini"
client = OpenAIChatCompletionClient(model=model, api_key=api_key)
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="_add_numbers")
# Test auto tool choice - model should choose to use add_numbers for math
result = await client.create(
messages=[UserMessage(content="What is 15 plus 27?", source="user")],
tools=[pass_tool, add_tool],
tool_choice="auto", # Let model choose
)
assert isinstance(result.content, list)
assert len(result.content) == 1
assert isinstance(result.content[0], FunctionCall)
assert result.content[0].name == "_add_numbers"
assert result.finish_reason == "function_calls"
assert result.usage is not None
# Parse arguments to verify correct values
args = json.loads(result.content[0].arguments)
assert args["a"] == 15
assert args["b"] == 27
@pytest.mark.asyncio
async def test_openai_tool_choice_none_integration() -> None:
"""Test tool_choice parameter with 'none' setting using the actual OpenAI API."""
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY not found in environment variables")
def _pass_function(input: str) -> str:
"""Simple passthrough function."""
return f"Processed: {input}"
model = "gpt-4o-mini"
client = OpenAIChatCompletionClient(model=model, api_key=api_key)
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
# Test none tool choice - model should not use any tools
result = await client.create(
messages=[UserMessage(content="Hello there, how are you?", source="user")],
tools=[pass_tool],
tool_choice="none", # Disable tool usage
)
assert isinstance(result.content, str)
assert len(result.content) > 0
assert result.finish_reason == "stop"
assert result.usage is not None
@pytest.mark.asyncio
async def test_openai_tool_choice_required_integration() -> None:
"""Test tool_choice parameter with 'required' setting using the actual OpenAI API."""
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY not found in environment variables")
def _pass_function(input: str) -> str:
"""Simple passthrough function."""
return f"Processed: {input}"
def _add_numbers(a: int, b: int) -> int:
"""Add two numbers together."""
return a + b
model = "gpt-4o-mini"
client = OpenAIChatCompletionClient(model=model, api_key=api_key)
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="_add_numbers")
# Test required tool choice - model must use a tool even for general conversation
result = await client.create(
messages=[UserMessage(content="Say hello to me", source="user")],
tools=[pass_tool, add_tool],
tool_choice="required", # Force tool usage
)
assert isinstance(result.content, list)
assert len(result.content) == 1
assert isinstance(result.content[0], FunctionCall)
assert result.content[0].name in ["_pass_function", "_add_numbers"]
assert result.finish_reason == "function_calls"
assert result.usage is not None
@pytest.mark.asyncio
async def test_openai_tool_choice_validation_error_integration() -> None:
"""Test tool_choice validation with invalid tool reference using the actual OpenAI API."""
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY not found in environment variables")
def _pass_function(input: str) -> str:
"""Simple passthrough function."""
return f"Processed: {input}"
def _add_numbers(a: int, b: int) -> int:
"""Add two numbers together."""
return a + b
def _different_function(text: str) -> str:
"""Different function."""
return text
model = "gpt-4o-mini"
client = OpenAIChatCompletionClient(model=model, api_key=api_key)
# Define tools
pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="_add_numbers")
different_tool = FunctionTool(_different_function, description="Different tool", name="_different_function")
messages = [UserMessage(content="Hello there", source="user")]
# Test with a tool that's not in the tools list
with pytest.raises(
ValueError, match="tool_choice references '_different_function' but it's not in the provided tools"
):
await client.create(
messages=messages,
tools=[pass_tool, add_tool],
tool_choice=different_tool, # This tool is not in the tools list
)
# TODO: add integration tests for Azure OpenAI using AAD token.