Mirror of https://github.com/microsoft/autogen.git, synced 2025-12-25 05:59:19 +00:00
Add tool_choice parameter to ChatCompletionClient create and create_stream methods (#6697)
## Summary

Implements the `tool_choice` parameter for the `ChatCompletionClient` interface as requested in #6696. This allows users to restrict which tools the model can choose from when multiple tools are available.

## Changes

### Core Interface

- Core Interface: Added a `tool_choice: Tool | Literal["auto", "required", "none"] = "auto"` parameter to the `ChatCompletionClient.create()` and `create_stream()` methods.
- Model Implementations: Updated client implementations to support the new parameter. For now, only the following model clients are supported:
  - OpenAI
  - Anthropic
  - Azure AI
  - Ollama
  - `LlamaCppChatCompletionClient` is currently not supported.

### Features

- `"auto"` (default): Let the model choose whether to use tools; when no tools are provided, it has no effect.
- `"required"`: Force the model to use at least one tool.
- `"none"`: Disable tool usage completely.
- A `Tool` object: Force the model to use that specific tool.

A usage example follows below.

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
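To make the new parameter concrete, here is a minimal usage sketch against the updated interface. It is illustrative only: the model name, the `get_weather` tool, and the choice of the OpenAI client are assumptions, and an `OPENAI_API_KEY` must be set in the environment.

```python
import asyncio

from autogen_core.models import UserMessage
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient


def get_weather(city: str) -> str:
    """Illustrative stub that would normally call a weather service."""
    return f"Sunny in {city}"


async def main() -> None:
    client = OpenAIChatCompletionClient(model="gpt-4o")
    weather_tool = FunctionTool(get_weather, description="Look up the weather", name="get_weather")

    # tool_choice options:
    #   "auto" (default) - the model decides whether to call a tool
    #   "required"       - the model must call at least one tool
    #   "none"           - tool usage is disabled even though tools are passed
    #   weather_tool     - the model must call this specific tool
    result = await client.create(
        messages=[UserMessage(content="What's the weather in Paris?", source="user")],
        tools=[weather_tool],
        tool_choice=weather_tool,
    )
    print(result.content)  # expected: a FunctionCall to "get_weather"


if __name__ == "__main__":
    asyncio.run(main())
```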
This commit is contained in: parent 6f15270cb2, commit c150f85044
@@ -211,8 +211,7 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
-        # None means do not override the default
-        # A value means to override the client default - often specified in the constructor
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
@@ -222,6 +221,7 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
         Args:
             messages (Sequence[LLMMessage]): The messages to send to the model.
             tools (Sequence[Tool | ToolSchema], optional): The tools to use with the model. Defaults to [].
+            tool_choice (Tool | Literal["auto", "required", "none"], optional): A single Tool object to force the model to use, "auto" to let the model choose any available tool, "required" to force tool usage, or "none" to disable tool usage. Defaults to "auto".
             json_output (Optional[bool | type[BaseModel]], optional): Whether to use JSON mode, structured output, or neither.
                 Defaults to None. If set to a `Pydantic BaseModel <https://docs.pydantic.dev/latest/usage/models/#model>`_ type,
                 it will be used as the output type for structured output.
@@ -241,8 +241,7 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
-        # None means do not override the default
-        # A value means to override the client default - often specified in the constructor
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
@@ -252,6 +251,7 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
         Args:
             messages (Sequence[LLMMessage]): The messages to send to the model.
             tools (Sequence[Tool | ToolSchema], optional): The tools to use with the model. Defaults to [].
+            tool_choice (Tool | Literal["auto", "required", "none"], optional): A single Tool object to force the model to use, "auto" to let the model choose any available tool, "required" to force tool usage, or "none" to disable tool usage. Defaults to "auto".
             json_output (Optional[bool | type[BaseModel]], optional): Whether to use JSON mode, structured output, or neither.
                 Defaults to None. If set to a `Pydantic BaseModel <https://docs.pydantic.dev/latest/usage/models/#model>`_ type,
                 it will be used as the output type for structured output.
@@ -1,7 +1,7 @@
 import asyncio
 import json
 import logging
-from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union
+from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional, Sequence, Union

 import pytest
 from autogen_core import EVENT_LOGGER_NAME, AgentId, CancellationToken, FunctionCall, SingleThreadedAgentRuntime
@@ -102,6 +102,7 @@ async def test_caller_loop() -> None:
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
@@ -127,6 +128,7 @@ async def test_caller_loop() -> None:
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
@@ -90,6 +90,7 @@ class ChatCompletionClientRecorder(ChatCompletionClient):
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
     ) -> CreateResult:
         current_messages: List[Mapping[str, Any]] = [msg.model_dump() for msg in messages]
         if self.mode == "record":
@@ -97,6 +98,7 @@ class ChatCompletionClientRecorder(ChatCompletionClient):
                 messages,
                 tools=tools,
                 json_output=json_output,
+                tool_choice=tool_choice,
                 extra_create_args=extra_create_args,
                 cancellation_token=cancellation_token,
             )
@@ -157,10 +159,12 @@ class ChatCompletionClientRecorder(ChatCompletionClient):
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
     ) -> AsyncGenerator[Union[str, CreateResult], None]:
         return self.base_client.create_stream(
             messages,
             tools=tools,
+            tool_choice=tool_choice,
             json_output=json_output,
             extra_create_args=extra_create_args,
             cancellation_token=cancellation_token,
@@ -149,6 +149,31 @@ def get_mime_type_from_image(image: Image) -> Literal["image/jpeg", "image/png",
     return "image/jpeg"


+def convert_tool_choice_anthropic(tool_choice: Tool | Literal["auto", "required", "none"]) -> Any:
+    """Convert tool_choice parameter to Anthropic API format.
+
+    Args:
+        tool_choice: A single Tool object to force the model to use, "auto" to let the model choose any available tool, "required" to force tool usage, or "none" to disable tool usage.
+
+    Returns:
+        Anthropic API compatible tool_choice value.
+    """
+    if tool_choice == "none":
+        return {"type": "none"}
+
+    if tool_choice == "auto":
+        return {"type": "auto"}
+
+    if tool_choice == "required":
+        return {"type": "any"}  # Anthropic uses "any" for required
+
+    # Must be a Tool object
+    if isinstance(tool_choice, Tool):
+        return {"type": "tool", "name": tool_choice.schema["name"]}
+    else:
+        raise ValueError(f"tool_choice must be a Tool object, 'auto', 'required', or 'none', got {type(tool_choice)}")
+
+
 @overload
 def __empty_content_to_whitespace(content: str) -> str: ...
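As a quick illustration of the mapping implemented by `convert_tool_choice_anthropic` above: the `get_weather` tool is a stand-in, and since the helper is module-private in this diff, treat the assumption that it is in scope as hypothetical.

```python
from autogen_core.tools import FunctionTool

def get_weather(city: str) -> str:
    """Illustrative stub."""
    return f"Sunny in {city}"

weather_tool = FunctionTool(get_weather, description="Look up the weather", name="get_weather")

# Expected mappings, per the helper above:
assert convert_tool_choice_anthropic("auto") == {"type": "auto"}
assert convert_tool_choice_anthropic("none") == {"type": "none"}
assert convert_tool_choice_anthropic("required") == {"type": "any"}  # Anthropic spells "required" as "any"
assert convert_tool_choice_anthropic(weather_tool) == {"type": "tool", "name": "get_weather"}
```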
@@ -504,6 +529,7 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
@@ -582,6 +608,36 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
             # anthropic requires tools to be present even if there is any tool use
             request_args["tools"] = self._last_used_tools

+        # Process tool_choice parameter
+        if isinstance(tool_choice, Tool):
+            if len(tools) == 0 and not has_tool_results:
+                raise ValueError("tool_choice specified but no tools provided")
+
+            # Validate that the tool exists in the provided tools
+            tool_names_available: List[str] = []
+            if len(tools) > 0:
+                for tool in tools:
+                    if isinstance(tool, Tool):
+                        tool_names_available.append(tool.schema["name"])
+                    else:
+                        tool_names_available.append(tool["name"])
+            else:
+                # Use last used tools names if available
+                for tool_param in self._last_used_tools:
+                    tool_names_available.append(tool_param["name"])
+
+            # tool_choice is a single Tool object
+            tool_name = tool_choice.schema["name"]
+            if tool_name not in tool_names_available:
+                raise ValueError(f"tool_choice references '{tool_name}' but it's not in the available tools")
+
+        # Convert to Anthropic format and add to request_args only if tools are provided
+        # According to Anthropic API, tool_choice may only be specified while providing tools
+        if len(tools) > 0 or has_tool_results:
+            converted_tool_choice = convert_tool_choice_anthropic(tool_choice)
+            if converted_tool_choice is not None:
+                request_args["tool_choice"] = converted_tool_choice
+
         # Optional parameters
         for param in ["top_p", "top_k", "stop_sequences", "metadata"]:
             if param in create_args:
@@ -667,6 +723,7 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
@@ -751,6 +808,36 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
         elif has_tool_results:
             request_args["tools"] = self._last_used_tools

+        # Process tool_choice parameter
+        if isinstance(tool_choice, Tool):
+            if len(tools) == 0 and not has_tool_results:
+                raise ValueError("tool_choice specified but no tools provided")
+
+            # Validate that the tool exists in the provided tools
+            tool_names_available: List[str] = []
+            if len(tools) > 0:
+                for tool in tools:
+                    if isinstance(tool, Tool):
+                        tool_names_available.append(tool.schema["name"])
+                    else:
+                        tool_names_available.append(tool["name"])
+            else:
+                # Use last used tools names if available
+                for last_used_tool in self._last_used_tools:
+                    tool_names_available.append(last_used_tool["name"])
+
+            # tool_choice is a single Tool object
+            tool_name = tool_choice.schema["name"]
+            if tool_name not in tool_names_available:
+                raise ValueError(f"tool_choice references '{tool_name}' but it's not in the available tools")
+
+        # Convert to Anthropic format and add to request_args only if tools are provided
+        # According to Anthropic API, tool_choice may only be specified while providing tools
+        if len(tools) > 0 or has_tool_results:
+            converted_tool_choice = convert_tool_choice_anthropic(tool_choice)
+            if converted_tool_choice is not None:
+                request_args["tool_choice"] = converted_tool_choice
+
         # Optional parameters
         for param in ["top_p", "top_k", "stop_sequences", "metadata"]:
             if param in create_args:
@@ -3,7 +3,7 @@ import logging
 import re
 from asyncio import Task
 from inspect import getfullargspec
-from typing import Any, Dict, List, Mapping, Optional, Sequence, cast
+from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Union, cast

 from autogen_core import EVENT_LOGGER_NAME, CancellationToken, FunctionCall, Image
 from autogen_core.logging import LLMCallEvent, LLMStreamEndEvent, LLMStreamStartEvent
@@ -28,6 +28,8 @@ from azure.ai.inference.models import (
 )
 from azure.ai.inference.models import (
     ChatCompletions,
+    ChatCompletionsNamedToolChoice,
+    ChatCompletionsNamedToolChoiceFunction,
     ChatCompletionsToolCall,
     ChatCompletionsToolDefinition,
     CompletionsFinishReason,
@@ -53,7 +55,7 @@ from azure.ai.inference.models import (
     UserMessage as AzureUserMessage,
 )
 from pydantic import BaseModel
-from typing_extensions import AsyncGenerator, Union, Unpack
+from typing_extensions import AsyncGenerator, Unpack

 from autogen_ext.models.azure.config import (
     GITHUB_MODELS_ENDPOINT,
@@ -309,7 +311,10 @@ class AzureAIChatCompletionClient(ChatCompletionClient):

     @staticmethod
     def _create_client(config: AzureAIChatCompletionClientConfig) -> ChatCompletionsClient:
-        return ChatCompletionsClient(**config)
+        # Only pass the parameters that ChatCompletionsClient accepts
+        # Remove 'model_info' and other client-specific parameters
+        client_config = {k: v for k, v in config.items() if k not in ("model_info",)}
+        return ChatCompletionsClient(**client_config)  # type: ignore

     @staticmethod
     def _prepare_create_args(config: Mapping[str, Any]) -> Dict[str, Any]:
@@ -356,6 +361,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
@@ -376,6 +382,12 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
         task: Task[ChatCompletions]

         if len(tools) > 0:
+            if isinstance(tool_choice, Tool):
+                create_args["tool_choice"] = ChatCompletionsNamedToolChoice(
+                    function=ChatCompletionsNamedToolChoiceFunction(name=tool_choice.name)
+                )
+            else:
+                create_args["tool_choice"] = tool_choice
             converted_tools = convert_tools(tools)
             task = asyncio.create_task(  # type: ignore
                 self._client.complete(messages=azure_messages, tools=converted_tools, **create_args)  # type: ignore
@@ -451,6 +463,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
@@ -469,6 +482,12 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
         azure_messages = [item for sublist in azure_messages_nested for item in sublist]

         if len(tools) > 0:
+            if isinstance(tool_choice, Tool):
+                create_args["tool_choice"] = ChatCompletionsNamedToolChoice(
+                    function=ChatCompletionsNamedToolChoiceFunction(name=tool_choice.name)
+                )
+            else:
+                create_args["tool_choice"] = tool_choice
             converted_tools = convert_tools(tools)
             task = asyncio.create_task(
                 self._client.complete(messages=azure_messages, tools=converted_tools, stream=True, **create_args)
@@ -1,7 +1,7 @@
 import hashlib
 import json
 import warnings
-from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union, cast
+from typing import Any, AsyncGenerator, List, Literal, Mapping, Optional, Sequence, Union, cast

 from autogen_core import CacheStore, CancellationToken, Component, ComponentModel, InMemoryStore
 from autogen_core.models import (
@@ -137,6 +137,7 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
@@ -158,6 +159,7 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
                 messages,
                 tools=tools,
                 json_output=json_output,
+                tool_choice=tool_choice,
                 extra_create_args=extra_create_args,
                 cancellation_token=cancellation_token,
             )
@@ -169,6 +171,7 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
@@ -200,6 +203,7 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
                 messages,
                 tools=tools,
                 json_output=json_output,
+                tool_choice=tool_choice,
                 extra_create_args=extra_create_args,
                 cancellation_token=cancellation_token,
             )
@@ -1,6 +1,7 @@
 import asyncio
 import logging  # added import
 import re
+import warnings
 from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Sequence, TypedDict, Union, cast

 from autogen_core import EVENT_LOGGER_NAME, CancellationToken, FunctionCall, MessageHandlerContext
@@ -264,6 +265,7 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         # None means do not override the default
         # A value means to override the client default - often specified in the constructor
         json_output: Optional[bool | type[BaseModel]] = None,
@@ -302,6 +304,15 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
         elif json_output is not False and json_output is not None:
             raise ValueError("json_output must be a boolean, a BaseModel subclass or None.")

+        # Handle tool_choice parameter
+        if tool_choice != "auto":
+            warnings.warn(
+                "tool_choice parameter is specified but LlamaCppChatCompletionClient does not support it. "
+                "This parameter will be ignored.",
+                UserWarning,
+                stacklevel=2,
+            )
+
         if self.model_info["function_calling"]:
             # Run this on the event loop to avoid blocking.
             response_future = asyncio.get_event_loop().run_in_executor(
@@ -397,12 +408,21 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         # None means do not override the default
         # A value means to override the client default - often specified in the constructor
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
     ) -> AsyncGenerator[Union[str, CreateResult], None]:
+        # Validate tool_choice parameter even though streaming is not implemented
+        if tool_choice != "auto" and tool_choice != "none":
+            if not self.model_info["function_calling"]:
+                raise ValueError("tool_choice specified but model does not support function calling")
+            if len(tools) == 0:
+                raise ValueError("tool_choice specified but no tools provided")
+            logger.warning("tool_choice parameter specified but may not be supported by llama-cpp-python")
+
         raise NotImplementedError("Stream not yet implemented for LlamaCppChatCompletionClient")
         yield ""
@@ -514,6 +514,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
         self,
         messages: Sequence[LLMMessage],
         tools: Sequence[Tool | ToolSchema],
+        tool_choice: Tool | Literal["auto", "required", "none"],
         json_output: Optional[bool | type[BaseModel]],
         extra_create_args: Mapping[str, Any],
     ) -> CreateParams:
@@ -584,7 +585,22 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
         if self.model_info["function_calling"] is False and len(tools) > 0:
             raise ValueError("Model does not support function calling and tools were provided")

-        converted_tools = convert_tools(tools)
+        converted_tools: List[OllamaTool] = []
+
+        # Handle tool_choice parameter in a way that is compatible with Ollama API.
+        if isinstance(tool_choice, Tool):
+            # If tool_choice is a Tool, convert it to OllamaTool.
+            converted_tools = convert_tools([tool_choice])
+        elif tool_choice == "none":
+            # No tool choice, do not pass tools to the API.
+            converted_tools = []
+        elif tool_choice == "required":
+            # Required tool choice, pass tools to the API.
+            converted_tools = convert_tools(tools)
+            if len(converted_tools) == 0:
+                raise ValueError("tool_choice 'required' specified but no tools provided")
+        else:
+            converted_tools = convert_tools(tools)

         return CreateParams(
             messages=ollama_messages,
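Because the Ollama chat API only accepts a `tools` list and has no `tool_choice` field, the block above emulates the parameter by narrowing that list. Below is a standalone, hedged sketch of the same decision table; plain dicts stand in for converted Ollama tools, and the `ValueError` follows the diff above. Note that `"required"` is best-effort here: the tools are sent, but the API cannot truly force a call.

```python
from typing import Dict, List, Literal, Union

ToolDict = Dict[str, object]  # stands in for a converted Ollama tool definition

def narrow_tools(
    tools: List[ToolDict],
    tool_choice: Union[ToolDict, Literal["auto", "required", "none"]],
) -> List[ToolDict]:
    """Emulate tool_choice by shrinking the tool list sent to Ollama."""
    if isinstance(tool_choice, dict):  # a specific tool: send only that tool
        return [tool_choice]
    if tool_choice == "none":  # disable tools entirely
        return []
    if tool_choice == "required":  # tools must be available; cannot truly force a call
        if not tools:
            raise ValueError("tool_choice 'required' specified but no tools provided")
        return tools
    return tools  # "auto": send everything and let the model decide

# Usage:
weather: ToolDict = {"type": "function", "function": {"name": "get_weather"}}
assert narrow_tools([weather], "none") == []
assert narrow_tools([weather], weather) == [weather]
```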
@@ -598,6 +614,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
@@ -610,6 +627,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
         create_params = self._process_create_args(
             messages,
             tools,
+            tool_choice,
             json_output,
             extra_create_args,
         )
@@ -704,6 +722,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
@@ -716,6 +735,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
         create_params = self._process_create_args(
             messages,
             tools,
+            tool_choice,
             json_output,
             extra_create_args,
         )
@@ -15,6 +15,7 @@ from typing import (
     Callable,
     Dict,
     List,
+    Literal,
     Mapping,
     Optional,
     Sequence,
@@ -265,6 +266,31 @@ def convert_tools(
     return result


+def convert_tool_choice(tool_choice: Tool | Literal["auto", "required", "none"]) -> Any:
+    """Convert tool_choice parameter to OpenAI API format.
+
+    Args:
+        tool_choice: A single Tool object to force the model to use, "auto" to let the model choose any available tool, "required" to force tool usage, or "none" to disable tool usage.
+
+    Returns:
+        OpenAI API compatible tool_choice value or None if not specified.
+    """
+    if tool_choice == "none":
+        return "none"
+
+    if tool_choice == "auto":
+        return "auto"
+
+    if tool_choice == "required":
+        return "required"
+
+    # Must be a Tool object
+    if isinstance(tool_choice, Tool):
+        return {"type": "function", "function": {"name": tool_choice.schema["name"]}}
+    else:
+        raise ValueError(f"tool_choice must be a Tool object, 'auto', 'required', or 'none', got {type(tool_choice)}")
+
+
 def normalize_name(name: str) -> str:
     """
     LLMs sometimes ask functions while ignoring their own format requirements, this function should be used to replace invalid characters with "_".
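The OpenAI-side mapping in `convert_tool_choice` above is mostly pass-through; only a concrete `Tool` needs wrapping into OpenAI's named-function form. A quick sketch of the expected values, with the same caveat as the Anthropic sketch that the module-private helper is assumed to be in scope:

```python
from autogen_core.tools import FunctionTool

def get_weather(city: str) -> str:
    """Illustrative stub."""
    return f"Sunny in {city}"

weather_tool = FunctionTool(get_weather, description="Look up the weather", name="get_weather")

# Expected values, per convert_tool_choice above:
assert convert_tool_choice("auto") == "auto"
assert convert_tool_choice("required") == "required"
assert convert_tool_choice("none") == "none"
assert convert_tool_choice(weather_tool) == {"type": "function", "function": {"name": "get_weather"}}
```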
@@ -449,6 +475,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         self,
         messages: Sequence[LLMMessage],
         tools: Sequence[Tool | ToolSchema],
+        tool_choice: Tool | Literal["auto", "required", "none"],
         json_output: Optional[bool | type[BaseModel]],
         extra_create_args: Mapping[str, Any],
     ) -> CreateParams:
@@ -575,6 +602,29 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):

         converted_tools = convert_tools(tools)

+        # Process tool_choice parameter
+        if isinstance(tool_choice, Tool):
+            if len(tools) == 0:
+                raise ValueError("tool_choice specified but no tools provided")
+
+            # Validate that the tool exists in the provided tools
+            tool_names_available: List[str] = []
+            for tool in tools:
+                if isinstance(tool, Tool):
+                    tool_names_available.append(tool.schema["name"])
+                else:
+                    tool_names_available.append(tool["name"])
+
+            # tool_choice is a single Tool object
+            tool_name = tool_choice.schema["name"]
+            if tool_name not in tool_names_available:
+                raise ValueError(f"tool_choice references '{tool_name}' but it's not in the provided tools")
+
+        if len(converted_tools) > 0:
+            # Convert to OpenAI format and add to create_args
+            converted_tool_choice = convert_tool_choice(tool_choice)
+            create_args["tool_choice"] = converted_tool_choice
+
         return CreateParams(
             messages=oai_messages,
             tools=converted_tools,
@@ -587,6 +637,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
@@ -594,6 +645,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         create_params = self._process_create_args(
             messages,
             tools,
+            tool_choice,
             json_output,
             extra_create_args,
         )
@@ -738,6 +790,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
@@ -769,6 +822,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         create_params = self._process_create_args(
             messages,
             tools,
+            tool_choice,
             json_output,
             extra_create_args,
         )
@@ -2,7 +2,7 @@ from __future__ import annotations

 import logging
 import warnings
-from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Sequence, Union
+from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Sequence, Union

 from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component
 from autogen_core.models import (
@@ -162,11 +162,16 @@ class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompl
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
     ) -> CreateResult:
         """Return the next completion from the list."""
+        # Warn if tool_choice is specified since it's ignored in replay mode
+        if tool_choice != "auto":
+            logger.warning("tool_choice parameter specified but is ignored in replay mode")
+
         if self._current_index >= len(self.chat_completions):
             raise ValueError("No more mock responses available")

@@ -201,11 +206,16 @@ class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompl
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
     ) -> AsyncGenerator[Union[str, CreateResult], None]:
         """Return the next completion as a stream."""
+        # Warn if tool_choice is specified since it's ignored in replay mode
+        if tool_choice != "auto":
+            logger.warning("tool_choice parameter specified but is ignored in replay mode")
+
         if self._current_index >= len(self.chat_completions):
             raise ValueError("No more mock responses available")
@@ -1,7 +1,7 @@
 import json
 import logging
 import warnings
-from typing import Any, Literal, Mapping, Optional, Sequence
+from typing import Any, Literal, Mapping, Optional, Sequence, Union

 from autogen_core import EVENT_LOGGER_NAME, FunctionCall
 from autogen_core._cancellation_token import CancellationToken
@@ -29,7 +29,7 @@ from semantic_kernel.contents import (
 )
 from semantic_kernel.functions.kernel_plugin import KernelPlugin
 from semantic_kernel.kernel import Kernel
-from typing_extensions import AsyncGenerator, Union
+from typing_extensions import AsyncGenerator

 from autogen_ext.tools.semantic_kernel import KernelFunctionFromTool, KernelFunctionFromToolSchema

@@ -442,6 +442,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
@@ -473,6 +474,13 @@ class SKChatCompletionAdapter(ChatCompletionClient):
         if isinstance(json_output, type) and issubclass(json_output, BaseModel):
             raise ValueError("structured output is not currently supported in SKChatCompletionAdapter")

+        # Handle tool_choice parameter
+        if tool_choice != "auto":
+            warnings.warn(
+                "tool_choice parameter is specified but may not be fully supported by SKChatCompletionAdapter.",
+                stacklevel=2,
+            )
+
         kernel = self._get_kernel(extra_create_args)

         chat_history = self._convert_to_chat_history(messages)
@@ -553,6 +561,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
+        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
         json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
@@ -585,6 +594,13 @@ class SKChatCompletionAdapter(ChatCompletionClient):
         if isinstance(json_output, type) and issubclass(json_output, BaseModel):
             raise ValueError("structured output is not currently supported in SKChatCompletionAdapter")

+        # Handle tool_choice parameter
+        if tool_choice != "auto":
+            warnings.warn(
+                "tool_choice parameter is specified but may not be fully supported by SKChatCompletionAdapter.",
+                stacklevel=2,
+            )
+
         kernel = self._get_kernel(extra_create_args)
         chat_history = self._convert_to_chat_history(messages)
         user_settings = self._get_prompt_settings(extra_create_args)
@@ -2,6 +2,7 @@ import asyncio
 import logging
 import os
 from typing import List, Sequence
+from unittest.mock import AsyncMock, MagicMock, patch

 import pytest
 from autogen_core import CancellationToken, FunctionCall
@@ -34,7 +35,153 @@ def _add_numbers(a: int, b: int) -> int:


 @pytest.mark.asyncio
-async def test_anthropic_serialization_api_key() -> None:
+async def test_mock_tool_choice_specific_tool() -> None:
+    """Test tool_choice parameter with a specific tool using mocks."""
+    # Create mock client and response
+    mock_client = AsyncMock()
+    mock_message = MagicMock()
+    mock_message.content = [MagicMock(type="tool_use", name="process_text", input={"input": "hello"}, id="call_123")]
+    mock_message.usage.input_tokens = 10
+    mock_message.usage.output_tokens = 5
+
+    mock_client.messages.create.return_value = mock_message
+
+    # Create real client but patch the underlying Anthropic client
+    client = AnthropicChatCompletionClient(
+        model="claude-3-haiku-20240307",
+        api_key="test-key",
+    )
+
+    # Define tools
+    pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
+    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
+
+    messages: List[LLMMessage] = [
+        UserMessage(content="Process the text 'hello'.", source="user"),
+    ]
+
+    with patch.object(client, "_client", mock_client):
+        await client.create(
+            messages=messages,
+            tools=[pass_tool, add_tool],
+            tool_choice=pass_tool,  # Force use of specific tool
+        )
+
+    # Verify the correct API call was made
+    mock_client.messages.create.assert_called_once()
+    call_args = mock_client.messages.create.call_args
+
+    # Check that tool_choice was set correctly
+    assert "tool_choice" in call_args.kwargs
+    assert call_args.kwargs["tool_choice"] == {"type": "tool", "name": "process_text"}
+
+
+@pytest.mark.asyncio
+async def test_mock_tool_choice_auto() -> None:
+    """Test tool_choice parameter with 'auto' setting using mocks."""
+    # Create mock client and response
+    mock_client = AsyncMock()
+    mock_message = MagicMock()
+    mock_message.content = [MagicMock(type="tool_use", name="add_numbers", input={"a": 1, "b": 2}, id="call_123")]
+    mock_message.usage.input_tokens = 10
+    mock_message.usage.output_tokens = 5
+
+    mock_client.messages.create.return_value = mock_message
+
+    # Create real client but patch the underlying Anthropic client
+    client = AnthropicChatCompletionClient(
+        model="claude-3-haiku-20240307",
+        api_key="test-key",
+    )
+
+    # Define tools
+    pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
+    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
+
+    messages: List[LLMMessage] = [
+        UserMessage(content="Add 1 and 2.", source="user"),
+    ]
+
+    with patch.object(client, "_client", mock_client):
+        await client.create(
+            messages=messages,
+            tools=[pass_tool, add_tool],
+            tool_choice="auto",  # Let model choose
+        )
+
+    # Verify the correct API call was made
+    mock_client.messages.create.assert_called_once()
+    call_args = mock_client.messages.create.call_args
+
+    # Check that tool_choice was set correctly
+    assert "tool_choice" in call_args.kwargs
+    assert call_args.kwargs["tool_choice"] == {"type": "auto"}
+
+
+@pytest.mark.asyncio
+async def test_mock_tool_choice_none() -> None:
+    """Test tool_choice parameter when no tools are provided - tool_choice should not be included."""
+    # Create mock client and response
+    mock_client = AsyncMock()
+    mock_message = MagicMock()
+    mock_message.content = [MagicMock(type="text", text="I can help you with that.")]
+    mock_message.usage.input_tokens = 10
+    mock_message.usage.output_tokens = 5
+
+    mock_client.messages.create.return_value = mock_message
+
+    # Create real client but patch the underlying Anthropic client
+    client = AnthropicChatCompletionClient(
+        model="claude-3-haiku-20240307",
+        api_key="test-key",
+    )
+
+    messages: List[LLMMessage] = [
+        UserMessage(content="Hello there.", source="user"),
+    ]
+
+    with patch.object(client, "_client", mock_client):
+        await client.create(
+            messages=messages,
+            # No tools provided - tool_choice should not be included in API call
+        )
+
+    # Verify the correct API call was made
+    mock_client.messages.create.assert_called_once()
+    call_args = mock_client.messages.create.call_args
+
+    # Check that tool_choice was not set when no tools are provided
+    assert "tool_choice" not in call_args.kwargs
+
+
+@pytest.mark.asyncio
+async def test_mock_tool_choice_validation_error() -> None:
+    """Test tool_choice validation with invalid tool reference."""
+    client = AnthropicChatCompletionClient(
+        model="claude-3-haiku-20240307",
+        api_key="test-key",
+    )
+
+    # Define tools
+    pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
+    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
+    different_tool = FunctionTool(_pass_function, description="Different tool", name="different_tool")
+
+    messages: List[LLMMessage] = [
+        UserMessage(content="Hello there.", source="user"),
+    ]
+
+    # Test with a tool that's not in the tools list
+    with pytest.raises(ValueError, match="tool_choice references 'different_tool' but it's not in the available tools"):
+        await client.create(
+            messages=messages,
+            tools=[pass_tool, add_tool],
+            tool_choice=different_tool,  # This tool is not in the tools list
+        )
+
+
+@pytest.mark.asyncio
+async def test_mock_serialization_api_key() -> None:
     client = AnthropicChatCompletionClient(
         model="claude-3-haiku-20240307",  # Use haiku for faster/cheaper testing
         api_key="sk-password",
@@ -342,7 +489,7 @@ async def test_anthropic_multimodal() -> None:


 @pytest.mark.asyncio
-async def test_anthropic_serialization() -> None:
+async def test_mock_serialization() -> None:
     """Test serialization and deserialization of component."""

     client = AnthropicChatCompletionClient(
@@ -422,7 +569,7 @@ async def test_anthropic_muliple_system_message() -> None:
     assert result_content[-3:] == "BAR"


-def test_merge_continuous_system_messages() -> None:
+def test_mock_merge_continuous_system_messages() -> None:
     """Tests merging of continuous system messages."""
     client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")

@@ -447,7 +594,7 @@ def test_merge_continuous_system_messages() -> None:
     assert merged_messages[1].content == "User question"


-def test_merge_single_system_message() -> None:
+def test_mock_merge_single_system_message() -> None:
     """Tests that a single system message remains unchanged."""
     client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")

@@ -467,7 +614,7 @@ def test_merge_single_system_message() -> None:
     assert merged_messages[0].content == "Single system instruction"


-def test_merge_no_system_messages() -> None:
+def test_mock_merge_no_system_messages() -> None:
     """Tests behavior when there are no system messages."""
     client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")

@@ -486,7 +633,7 @@ def test_merge_no_system_messages() -> None:
     assert merged_messages[0].content == "User question without system"


-def test_merge_non_continuous_system_messages() -> None:
+def test_mock_merge_non_continuous_system_messages() -> None:
     """Tests that an error is raised for non-continuous system messages."""
     client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")

@@ -504,7 +651,7 @@ def test_merge_non_continuous_system_messages() -> None:
     # The method is protected, but we need to test it


-def test_merge_system_messages_empty() -> None:
+def test_mock_merge_system_messages_empty() -> None:
     """Tests that empty message list is handled properly."""
     client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")

@@ -513,7 +660,7 @@ def test_merge_system_messages_empty() -> None:
     assert len(merged_messages) == 0


-def test_merge_system_messages_with_special_characters() -> None:
+def test_mock_merge_system_messages_with_special_characters() -> None:
     """Tests system message merging with special characters and formatting."""
     client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")

@@ -533,7 +680,7 @@ def test_merge_system_messages_with_special_characters() -> None:
     assert system_message.content == "Line 1\nWith newline\nLine 2 with *formatting*\nLine 3 with `code`"


-def test_merge_system_messages_with_whitespace() -> None:
+def test_mock_merge_system_messages_with_whitespace() -> None:
     """Tests system message merging with extra whitespace."""
     client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")

@@ -553,7 +700,7 @@ def test_merge_system_messages_with_whitespace() -> None:
     assert system_message.content == " Message with leading spaces \n\nMessage with leading newline"


-def test_merge_system_messages_message_order() -> None:
+def test_mock_merge_system_messages_message_order() -> None:
     """Tests that message order is preserved after merging."""
     client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")

@@ -584,7 +731,7 @@ def test_merge_system_messages_message_order() -> None:
     assert merged_messages[3].content == "Answer"


-def test_merge_system_messages_multiple_groups() -> None:
+def test_mock_merge_system_messages_multiple_groups() -> None:
     """Tests that multiple separate groups of system messages raise an error."""
     client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")

@@ -600,7 +747,7 @@ def test_merge_system_messages_multiple_groups() -> None:
     # The method is protected, but we need to test it


-def test_merge_system_messages_no_duplicates() -> None:
+def test_mock_merge_system_messages_no_duplicates() -> None:
     """Tests that identical system messages are still merged properly."""
     client = AnthropicChatCompletionClient(model="claude-3-haiku-20240307", api_key="fake-api-key")

@@ -621,7 +768,7 @@ def test_merge_system_messages_no_duplicates() -> None:


 @pytest.mark.asyncio
-async def test_empty_assistant_content_string_with_anthropic() -> None:
+async def test_anthropic_empty_assistant_content_string() -> None:
     """Test that an empty assistant content string is handled correctly."""
     api_key = os.getenv("ANTHROPIC_API_KEY")
     if not api_key:
@@ -646,7 +793,7 @@ async def test_empty_assistant_content_string_with_anthropic() -> None:


 @pytest.mark.asyncio
-async def test_claude_trailing_whitespace_at_last_assistant_content() -> None:
+async def test_anthropic_trailing_whitespace_at_last_assistant_content() -> None:
     """Test that an empty assistant content string is handled correctly."""
     api_key = os.getenv("ANTHROPIC_API_KEY")
     if not api_key:
@@ -667,7 +814,7 @@ async def test_claude_trailing_whitespace_at_last_assistant_content() -> None:
     assert isinstance(result.content, str)


-def test_rstrip_railing_whitespace_at_last_assistant_content() -> None:
+def test_mock_rstrip_trailing_whitespace_at_last_assistant_content() -> None:
     messages: list[LLMMessage] = [
         UserMessage(content="foo", source="user"),
         UserMessage(content="bar", source="user"),
@@ -680,3 +827,175 @@ def test_rstrip_railing_whitespace_at_last_assistant_content() -> None:

     assert isinstance(result[-1].content, str)
     assert result[-1].content == "foobar"
+
+
+@pytest.mark.asyncio
+async def test_anthropic_tool_choice_with_actual_api() -> None:
+    """Test tool_choice parameter with actual Anthropic API endpoints."""
+    api_key = os.getenv("ANTHROPIC_API_KEY")
+    if not api_key:
+        pytest.skip("ANTHROPIC_API_KEY not found in environment variables")
+
+    client = AnthropicChatCompletionClient(
+        model="claude-3-haiku-20240307",
+        api_key=api_key,
+    )
+
+    # Define tools
+    pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
+    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
+
+    # Test 1: tool_choice with specific tool
+    messages: List[LLMMessage] = [
+        SystemMessage(content="Use the tools as needed to help the user."),
+        UserMessage(content="Process the text 'hello world' using the process_text tool.", source="user"),
+    ]
+
+    result = await client.create(
+        messages=messages,
+        tools=[pass_tool, add_tool],
+        tool_choice=pass_tool,  # Force use of specific tool
+    )
+
+    # Verify we got a tool call for the specified tool
+    assert isinstance(result.content, list)
+    assert len(result.content) >= 1
+    assert isinstance(result.content[0], FunctionCall)
+    assert result.content[0].name == "process_text"
+
+    # Test 2: tool_choice="auto" with tools
+    auto_messages: List[LLMMessage] = [
+        SystemMessage(content="Use the tools as needed to help the user."),
+        UserMessage(content="Add the numbers 5 and 3.", source="user"),
+    ]
+
+    auto_result = await client.create(
+        messages=auto_messages,
+        tools=[pass_tool, add_tool],
+        tool_choice="auto",  # Let model choose
+    )
+
+    # Should get a tool call, likely for add_numbers
+    assert isinstance(auto_result.content, list)
+    assert len(auto_result.content) >= 1
+    assert isinstance(auto_result.content[0], FunctionCall)
+    # Model should choose add_numbers for addition task
+    assert auto_result.content[0].name == "add_numbers"
+
+    # Test 3: No tools provided - should not include tool_choice in API call
+    no_tools_messages: List[LLMMessage] = [
+        UserMessage(content="What is the capital of France?", source="user"),
+    ]
+
+    no_tools_result = await client.create(messages=no_tools_messages)
+
+    # Should get a text response without tool calls
+    assert isinstance(no_tools_result.content, str)
+    assert "paris" in no_tools_result.content.lower()
+
+    # Test 4: tool_choice="required" with tools
+    required_messages: List[LLMMessage] = [
+        SystemMessage(content="You must use one of the available tools to help the user."),
+        UserMessage(content="Help me with something.", source="user"),
+    ]
+
+    required_result = await client.create(
+        messages=required_messages,
+        tools=[pass_tool, add_tool],
+        tool_choice="required",  # Force tool usage
+    )
+
+    # Should get a tool call (model forced to use a tool)
+    assert isinstance(required_result.content, list)
+    assert len(required_result.content) >= 1
+    assert isinstance(required_result.content[0], FunctionCall)
+
+
+@pytest.mark.asyncio
+async def test_anthropic_tool_choice_streaming_with_actual_api() -> None:
+    """Test tool_choice parameter with streaming using actual Anthropic API endpoints."""
+    api_key = os.getenv("ANTHROPIC_API_KEY")
+    if not api_key:
+        pytest.skip("ANTHROPIC_API_KEY not found in environment variables")
+
+    client = AnthropicChatCompletionClient(
+        model="claude-3-haiku-20240307",
+        api_key=api_key,
+    )
+
+    # Define tools
+    pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
+    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
+
+    # Test streaming with tool_choice
+    messages: List[LLMMessage] = [
+        SystemMessage(content="Use the tools as needed to help the user."),
+        UserMessage(content="Process the text 'streaming test' using the process_text tool.", source="user"),
+    ]
+
+    chunks: List[str | CreateResult] = []
+    async for chunk in client.create_stream(
+        messages=messages,
+        tools=[pass_tool, add_tool],
+        tool_choice=pass_tool,  # Force use of specific tool
+    ):
+        chunks.append(chunk)
+
+    # Verify we got chunks and a final result
+    assert len(chunks) > 0
+    final_result = chunks[-1]
+    assert isinstance(final_result, CreateResult)
+
+    # Should get a tool call for the specified tool
+    assert isinstance(final_result.content, list)
+    assert len(final_result.content) >= 1
+    assert isinstance(final_result.content[0], FunctionCall)
+    assert final_result.content[0].name == "process_text"
+
+    # Test streaming without tools - should not include tool_choice
+    no_tools_messages: List[LLMMessage] = [
+        UserMessage(content="Tell me a short fact about cats.", source="user"),
+    ]
+
+    no_tools_chunks: List[str | CreateResult] = []
+    async for chunk in client.create_stream(messages=no_tools_messages):
+        no_tools_chunks.append(chunk)
+
+    # Should get text response
+    assert len(no_tools_chunks) > 0
+    final_no_tools_result = no_tools_chunks[-1]
+    assert isinstance(final_no_tools_result, CreateResult)
+    assert isinstance(final_no_tools_result.content, str)
+    assert len(final_no_tools_result.content) > 0
+
+
+@pytest.mark.asyncio
+async def test_anthropic_tool_choice_none_value_with_actual_api() -> None:
+    """Test tool_choice="none" with actual Anthropic API endpoints."""
+    api_key = os.getenv("ANTHROPIC_API_KEY")
+    if not api_key:
+        pytest.skip("ANTHROPIC_API_KEY not found in environment variables")
+
+    client = AnthropicChatCompletionClient(
+        model="claude-3-haiku-20240307",
+        api_key=api_key,
+    )
+
+    # Define tools
+    pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
+    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
+
+    # Test tool_choice="none" - should not use tools even when available
+    messages: List[LLMMessage] = [
+        SystemMessage(content="Answer the user's question directly without using tools."),
+        UserMessage(content="What is 2 + 2?", source="user"),
+    ]
+
+    result = await client.create(
+        messages=messages,
+        tools=[pass_tool, add_tool],
+        tool_choice="none",  # Disable tool usage
+    )
+
+    # Should get a text response, not tool calls
+    assert isinstance(result.content, str)
@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
from autogen_core import CancellationToken, FunctionCall, Image
|
||||
from autogen_core.models import CreateResult, ModelFamily, UserMessage
|
||||
from autogen_core.tools import FunctionTool
|
||||
from autogen_ext.models.azure import AzureAIChatCompletionClient
|
||||
from autogen_ext.models.azure.config import GITHUB_MODELS_ENDPOINT
|
||||
from azure.ai.inference.aio import (
|
||||
@ -623,3 +624,352 @@ async def test_thought_field_with_tool_calls_streaming(
|
||||
assert final_result.content[0].arguments == '{"foo": "bar"}'
|
||||
|
||||
assert final_result.thought == "Let me think about what function to call."
|
||||
|
||||
|
||||
def _pass_function(input: str) -> str:
|
||||
"""Simple passthrough function."""
|
||||
return f"Processed: {input}"
|
||||
|
||||
|
||||
def _add_numbers(a: int, b: int) -> int:
|
||||
"""Add two numbers together."""
|
||||
return a + b
|
||||
|
||||
|
||||
@pytest.fixture
def tool_choice_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient:
    """
    Returns a client that supports function calling for tool choice tests.
    """

    async def _mock_tool_choice_stream(
        *args: Any, **kwargs: Any
    ) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]:
        mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"]

        mock_chunks = [
            StreamingChatChoiceUpdate(
                index=0,
                finish_reason="stop",
                delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_content),
            )
            for chunk_content in mock_chunks_content
        ]

        for mock_chunk in mock_chunks:
            await asyncio.sleep(0.01)
            yield StreamingChatCompletionsUpdate(
                id="id",
                choices=[mock_chunk],
                created=datetime.now(),
                model="model",
                usage=CompletionsUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
            )

    mock_client = MagicMock()
    mock_client.close = AsyncMock()

    async def mock_complete(*args: Any, **kwargs: Any) -> Any:
        stream = kwargs.get("stream", False)

        if not stream:
            await asyncio.sleep(0.01)
            return ChatCompletions(
                id="id",
                created=datetime.now(),
                model="model",
                choices=[
                    ChatChoice(
                        index=0,
                        finish_reason=CompletionsFinishReason.TOOL_CALLS,
                        message=ChatResponseMessage(
                            role="assistant",
                            content="",
                            tool_calls=[
                                ChatCompletionsToolCall(
                                    id="call_123",
                                    function=AzureFunctionCall(name="process_text", arguments='{"input": "hello"}'),
                                )
                            ],
                        ),
                    )
                ],
                usage=CompletionsUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
            )
        else:
            return _mock_tool_choice_stream(*args, **kwargs)

    mock_client.complete = mock_complete

    def mock_new(cls: Type[ChatCompletionsClient], *args: Any, **kwargs: Any) -> MagicMock:
        return mock_client

    monkeypatch.setattr(ChatCompletionsClient, "__new__", mock_new)

    return AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "json_output": False,
            "function_calling": True,
            "vision": False,
            "family": "test",
            "structured_output": False,
        },
        model="model",
    )


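# Azure reports CompletionsFinishReason.TOOL_CALLS, which the client surfaces as
# finish_reason "function_calls" on the CreateResult; the tests below assert on that mapping.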
@pytest.mark.asyncio
async def test_azure_ai_tool_choice_specific_tool(tool_choice_client: AzureAIChatCompletionClient) -> None:
    """Test tool_choice parameter with a specific tool using mocks."""
    # Define tools
    pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")

    messages = [
        UserMessage(content="Process the text 'hello'.", source="user"),
    ]

    result = await tool_choice_client.create(
        messages=messages,
        tools=[pass_tool, add_tool],
        tool_choice=pass_tool,  # Force use of specific tool
    )

    # Verify the result
    assert result.finish_reason == "function_calls"
    assert isinstance(result.content, list)
    assert len(result.content) == 1
    assert isinstance(result.content[0], FunctionCall)
    assert result.content[0].name == "process_text"
    assert result.content[0].arguments == '{"input": "hello"}'


@pytest.mark.asyncio
async def test_azure_ai_tool_choice_auto(tool_choice_client: AzureAIChatCompletionClient) -> None:
    """Test tool_choice parameter with 'auto' setting using mocks."""
    # Define tools
    pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")

    messages = [
        UserMessage(content="Add 1 and 2.", source="user"),
    ]

    result = await tool_choice_client.create(
        messages=messages,
        tools=[pass_tool, add_tool],
        tool_choice="auto",  # Let the model choose
    )

    # Verify the result
    assert result.finish_reason == "function_calls"
    assert isinstance(result.content, list)
    assert len(result.content) == 1
    assert isinstance(result.content[0], FunctionCall)
    assert result.content[0].name == "process_text"  # Our mock always returns process_text
    assert result.content[0].arguments == '{"input": "hello"}'


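# tool_choice="none" needs its own fixture: the mock must return a plain text completion
# (finish_reason "stop") instead of tool calls.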
@pytest.fixture
def tool_choice_none_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient:
    """
    Returns a client that simulates no tool calls for tool_choice='none' tests.
    """

    mock_client = MagicMock()
    mock_client.close = AsyncMock()

    async def mock_complete(*args: Any, **kwargs: Any) -> ChatCompletions:
        await asyncio.sleep(0.01)
        return ChatCompletions(
            id="id",
            created=datetime.now(),
            model="model",
            choices=[
                ChatChoice(
                    index=0,
                    finish_reason="stop",
                    message=ChatResponseMessage(role="assistant", content="I can help you with that."),
                )
            ],
            usage=CompletionsUsage(prompt_tokens=8, completion_tokens=6, total_tokens=14),
        )

    mock_client.complete = mock_complete

    def mock_new(cls: Type[ChatCompletionsClient], *args: Any, **kwargs: Any) -> MagicMock:
        return mock_client

    monkeypatch.setattr(ChatCompletionsClient, "__new__", mock_new)

    return AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "json_output": False,
            "function_calling": True,
            "vision": False,
            "family": "test",
            "structured_output": False,
        },
        model="model",
    )


@pytest.mark.asyncio
async def test_azure_ai_tool_choice_none(tool_choice_none_client: AzureAIChatCompletionClient) -> None:
    """Test tool_choice parameter with 'none' setting using mocks."""
    # Define tools
    pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")

    messages = [
        UserMessage(content="Just say hello.", source="user"),
    ]

    result = await tool_choice_none_client.create(
        messages=messages,
        tools=[pass_tool, add_tool],
        tool_choice="none",  # Prevent tool usage
    )

    # Verify the result
    assert result.finish_reason == "stop"
    assert isinstance(result.content, str)
    assert result.content == "I can help you with that."


@pytest.mark.asyncio
async def test_azure_ai_tool_choice_required(tool_choice_client: AzureAIChatCompletionClient) -> None:
    """Test tool_choice parameter with 'required' setting using mocks."""
    # Define tools
    pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")

    messages = [
        UserMessage(content="Process some text.", source="user"),
    ]

    result = await tool_choice_client.create(
        messages=messages,
        tools=[pass_tool, add_tool],
        tool_choice="required",  # Force tool usage
    )

    # Verify the result
    assert result.finish_reason == "function_calls"
    assert isinstance(result.content, list)
    assert len(result.content) == 1
    assert isinstance(result.content[0], FunctionCall)
    assert result.content[0].name == "process_text"
    assert result.content[0].arguments == '{"input": "hello"}'


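# Streaming variant: the mock emits a text delta first and a tool-call delta second, so the
# final CreateResult carries both a thought and a FunctionCall.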
@pytest.fixture
def tool_choice_stream_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient:
    """
    Returns a client that supports function calling for streaming tool choice tests.
    """

    # Mock tool call for streaming
    mock_tool_call = MagicMock()
    mock_tool_call.id = "call_123"
    mock_tool_call.function = MagicMock()
    mock_tool_call.function.name = "process_text"
    mock_tool_call.function.arguments = '{"input": "hello"}'

    # First choice with content
    first_choice = MagicMock()
    first_choice.delta = MagicMock()
    first_choice.delta.content = "Let me process this for you."
    first_choice.finish_reason = None

    # Tool call choice
    tool_call_choice = MagicMock()
    tool_call_choice.delta = MagicMock()
    tool_call_choice.delta.content = None
    tool_call_choice.delta.tool_calls = [mock_tool_call]
    tool_call_choice.finish_reason = "function_calls"

    async def _mock_tool_choice_stream(
        *args: Any, **kwargs: Any
    ) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]:
        yield StreamingChatCompletionsUpdate(
            id="id",
            choices=[first_choice],
            created=datetime.now(),
            model="model",
        )

        await asyncio.sleep(0.01)

        yield StreamingChatCompletionsUpdate(
            id="id",
            choices=[tool_call_choice],
            created=datetime.now(),
            model="model",
            usage=CompletionsUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
        )

    mock_client = MagicMock()
    mock_client.close = AsyncMock()

    async def mock_complete(*args: Any, **kwargs: Any) -> Any:
        if kwargs.get("stream", False):
            return _mock_tool_choice_stream(*args, **kwargs)
        return None

    mock_client.complete = mock_complete

    def mock_new(cls: Type[ChatCompletionsClient], *args: Any, **kwargs: Any) -> MagicMock:
        return mock_client

    monkeypatch.setattr(ChatCompletionsClient, "__new__", mock_new)

    return AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "json_output": False,
            "function_calling": True,
            "vision": False,
            "family": "test",
            "structured_output": False,
        },
        model="model",
    )


@pytest.mark.asyncio
async def test_azure_ai_tool_choice_specific_tool_streaming(
    tool_choice_stream_client: AzureAIChatCompletionClient,
) -> None:
    """Test tool_choice parameter with streaming and a specific tool using mocks."""
    # Define tools
    pass_tool = FunctionTool(_pass_function, description="Process input text", name="process_text")
    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")

    messages = [
        UserMessage(content="Process the text 'hello'.", source="user"),
    ]

    chunks: List[Union[str, CreateResult]] = []
    async for chunk in tool_choice_stream_client.create_stream(
        messages=messages,
        tools=[pass_tool, add_tool],
        tool_choice=pass_tool,  # Force use of specific tool
    ):
        chunks.append(chunk)

    # Verify that we got some result
    final_result = chunks[-1]
    assert isinstance(final_result, CreateResult)
    assert final_result.finish_reason == "function_calls"
    assert isinstance(final_result.content, list)
    assert len(final_result.content) == 1
    assert isinstance(final_result.content[0], FunctionCall)
    assert final_result.content[0].name == "process_text"
    assert final_result.content[0].arguments == '{"input": "hello"}'
    assert final_result.thought == "Let me process this for you."

@ -590,6 +590,7 @@ async def test_ollama_create_tools(model: str, ollama_client: OllamaChatCompleti
    assert len(create_result.content) > 0
    assert isinstance(create_result.content[0], FunctionCall)
    assert create_result.content[0].name == add_tool.name
    assert create_result.content[0].arguments == json.dumps({"x": 2, "y": 2})
    assert create_result.finish_reason == "function_calls"

    execution_result = FunctionExecutionResult(
@ -679,38 +680,11 @@ async def test_ollama_create_stream_tools(model: str, ollama_client: OllamaChatC
    assert len(create_result.content) > 0
    assert isinstance(create_result.content[0], FunctionCall)
    assert create_result.content[0].name == add_tool.name
    assert create_result.content[0].arguments == json.dumps({"x": 2, "y": 2})
    assert create_result.finish_reason == "stop"

    execution_result = FunctionExecutionResult(
        content="4",
        name=add_tool.name,
        call_id=create_result.content[0].id,
        is_error=False,
    )
    stream = ollama_client.create_stream(
        messages=[
            UserMessage(
                content="What is 2 + 2? Use the add tool.",
                source="user",
            ),
            AssistantMessage(
                content=create_result.content,
                source="assistant",
            ),
            FunctionExecutionResultMessage(
                content=[execution_result],
            ),
        ],
    )
    chunks = []
    async for chunk in stream:
        chunks.append(chunk)
    assert len(chunks) > 0
    assert isinstance(chunks[-1], CreateResult)
    create_result = chunks[-1]
    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0
    assert create_result.finish_reason == "stop"
    assert create_result.usage is not None
    assert create_result.usage.prompt_tokens == 10
    assert create_result.usage.completion_tokens == 12


@pytest.mark.asyncio
@ -883,3 +857,459 @@ async def test_llm_control_params(monkeypatch: pytest.MonkeyPatch) -> None:
    assert chat_kwargs_captured["options"]["temperature"] == 0.7
    assert chat_kwargs_captured["options"]["top_p"] == 0.9
    assert chat_kwargs_captured["options"]["frequency_penalty"] == 1.2


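# The Ollama client implements tool_choice by shaping the tool list it forwards to the API:
# all tools for "auto"/"required", only the chosen tool for a specific Tool object, and no
# tools at all for "none". The chat_kwargs_captured assertions below pin down that behavior.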
@pytest.mark.asyncio
async def test_tool_choice_auto(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test tool_choice='auto' (default behavior)"""

    def add(x: int, y: int) -> str:
        return str(x + y)

    def multiply(x: int, y: int) -> str:
        return str(x * y)

    add_tool = FunctionTool(add, description="Add two numbers")
    multiply_tool = FunctionTool(multiply, description="Multiply two numbers")
    model = "llama3.2"

    # Capture the kwargs passed to chat
    chat_kwargs_captured: Dict[str, Any] = {}

    async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
        nonlocal chat_kwargs_captured
        chat_kwargs_captured = kwargs
        return ChatResponse(
            model=model,
            done=True,
            done_reason="stop",
            message=Message(
                role="assistant",
                content="I'll use the add tool.",
                tool_calls=[
                    Message.ToolCall(
                        function=Message.ToolCall.Function(
                            name=add_tool.name,
                            arguments={"x": 2, "y": 3},
                        ),
                    ),
                ],
            ),
            prompt_eval_count=10,
            eval_count=12,
        )

    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)

    client = OllamaChatCompletionClient(model=model)
    create_result = await client.create(
        messages=[UserMessage(content="What is 2 + 3?", source="user")],
        tools=[add_tool, multiply_tool],
        tool_choice="auto",  # Explicitly set to auto
    )

    # Verify that all tools are passed to the API when tool_choice is auto
    assert "tools" in chat_kwargs_captured
    assert chat_kwargs_captured["tools"] is not None
    assert len(chat_kwargs_captured["tools"]) == 2

    # Verify the response
    assert isinstance(create_result.content, list)
    assert len(create_result.content) > 0
    assert isinstance(create_result.content[0], FunctionCall)
    assert create_result.content[0].name == add_tool.name


@pytest.mark.asyncio
async def test_tool_choice_none(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test tool_choice='none' - no tools should be passed to API"""

    def add(x: int, y: int) -> str:
        return str(x + y)

    add_tool = FunctionTool(add, description="Add two numbers")
    model = "llama3.2"
    content_raw = "I cannot use tools, so I'll calculate manually: 2 + 3 = 5"

    # Capture the kwargs passed to chat
    chat_kwargs_captured: Dict[str, Any] = {}

    async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
        nonlocal chat_kwargs_captured
        chat_kwargs_captured = kwargs
        return ChatResponse(
            model=model,
            done=True,
            done_reason="stop",
            message=Message(
                role="assistant",
                content=content_raw,
            ),
            prompt_eval_count=10,
            eval_count=12,
        )

    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)

    client = OllamaChatCompletionClient(model=model)
    create_result = await client.create(
        messages=[UserMessage(content="What is 2 + 3?", source="user")],
        tools=[add_tool],
        tool_choice="none",
    )

    # Verify that no tools are passed to the API when tool_choice is none
    assert "tools" in chat_kwargs_captured
    assert chat_kwargs_captured["tools"] is None

    # Verify the response is text content
    assert isinstance(create_result.content, str)
    assert create_result.content == content_raw
    assert create_result.finish_reason == "stop"


@pytest.mark.asyncio
async def test_tool_choice_required(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test tool_choice='required' - tools must be provided and passed to API"""

    def add(x: int, y: int) -> str:
        return str(x + y)

    def multiply(x: int, y: int) -> str:
        return str(x * y)

    add_tool = FunctionTool(add, description="Add two numbers")
    multiply_tool = FunctionTool(multiply, description="Multiply two numbers")
    model = "llama3.2"

    # Capture the kwargs passed to chat
    chat_kwargs_captured: Dict[str, Any] = {}

    async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
        nonlocal chat_kwargs_captured
        chat_kwargs_captured = kwargs
        return ChatResponse(
            model=model,
            done=True,
            done_reason="tool_calls",
            message=Message(
                role="assistant",
                tool_calls=[
                    Message.ToolCall(
                        function=Message.ToolCall.Function(
                            name=add_tool.name,
                            arguments={"x": 2, "y": 3},
                        ),
                    ),
                ],
            ),
            prompt_eval_count=10,
            eval_count=12,
        )

    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)

    client = OllamaChatCompletionClient(model=model)
    create_result = await client.create(
        messages=[UserMessage(content="What is 2 + 3?", source="user")],
        tools=[add_tool, multiply_tool],
        tool_choice="required",
    )

    # Verify that all tools are passed to the API when tool_choice is required
    assert "tools" in chat_kwargs_captured
    assert chat_kwargs_captured["tools"] is not None
    assert len(chat_kwargs_captured["tools"]) == 2

    # Verify the response contains function calls
    assert isinstance(create_result.content, list)
    assert len(create_result.content) > 0
    assert isinstance(create_result.content[0], FunctionCall)
    assert create_result.content[0].name == add_tool.name
    assert create_result.finish_reason == "function_calls"


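# "required" cannot be satisfied with an empty tool list, so the client raises a ValueError
# before any API call is made.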
@pytest.mark.asyncio
async def test_tool_choice_required_no_tools_error() -> None:
    """Test tool_choice='required' with no tools raises ValueError"""
    model = "llama3.2"
    client = OllamaChatCompletionClient(model=model)

    with pytest.raises(ValueError, match="tool_choice 'required' specified but no tools provided"):
        await client.create(
            messages=[UserMessage(content="What is 2 + 3?", source="user")],
            tools=[],  # No tools provided
            tool_choice="required",
        )


@pytest.mark.asyncio
async def test_tool_choice_specific_tool(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test tool_choice with a specific tool - only that tool should be passed to API"""

    def add(x: int, y: int) -> str:
        return str(x + y)

    def multiply(x: int, y: int) -> str:
        return str(x * y)

    add_tool = FunctionTool(add, description="Add two numbers")
    multiply_tool = FunctionTool(multiply, description="Multiply two numbers")
    model = "llama3.2"

    # Capture the kwargs passed to chat
    chat_kwargs_captured: Dict[str, Any] = {}

    async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
        nonlocal chat_kwargs_captured
        chat_kwargs_captured = kwargs
        return ChatResponse(
            model=model,
            done=True,
            done_reason="tool_calls",
            message=Message(
                role="assistant",
                tool_calls=[
                    Message.ToolCall(
                        function=Message.ToolCall.Function(
                            name=add_tool.name,
                            arguments={"x": 2, "y": 3},
                        ),
                    ),
                ],
            ),
            prompt_eval_count=10,
            eval_count=12,
        )

    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)

    client = OllamaChatCompletionClient(model=model)
    create_result = await client.create(
        messages=[UserMessage(content="What is 2 + 3?", source="user")],
        tools=[add_tool, multiply_tool],  # Multiple tools available
        tool_choice=add_tool,  # But force specific tool
    )

    # Verify that only the specified tool is passed to the API
    assert "tools" in chat_kwargs_captured
    assert chat_kwargs_captured["tools"] is not None
    assert len(chat_kwargs_captured["tools"]) == 1
    assert chat_kwargs_captured["tools"][0]["function"]["name"] == add_tool.name

    # Verify the response contains function calls
    assert isinstance(create_result.content, list)
    assert len(create_result.content) > 0
    assert isinstance(create_result.content[0], FunctionCall)
    assert create_result.content[0].name == add_tool.name
    assert create_result.finish_reason == "function_calls"


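# Streaming variants: the mocks below yield the reply in small text chunks and also check
# that stream=True reaches the underlying AsyncClient.chat call.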
@pytest.mark.asyncio
async def test_tool_choice_stream_auto(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test tool_choice='auto' with streaming"""

    def add(x: int, y: int) -> str:
        return str(x + y)

    add_tool = FunctionTool(add, description="Add two numbers")
    model = "llama3.2"
    content_raw = "I'll use the add tool."

    # Capture the kwargs passed to chat
    chat_kwargs_captured: Dict[str, Any] = {}

    async def _mock_chat(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
        nonlocal chat_kwargs_captured
        chat_kwargs_captured = kwargs
        assert "stream" in kwargs
        assert kwargs["stream"] is True

        async def _mock_stream() -> AsyncGenerator[ChatResponse, None]:
            chunks = [content_raw[i : i + 5] for i in range(0, len(content_raw), 5)]
            # Simulate streaming by yielding chunks of the response
            for chunk in chunks[:-1]:
                yield ChatResponse(
                    model=model,
                    done=False,
                    message=Message(
                        role="assistant",
                        content=chunk,
                    ),
                )
            yield ChatResponse(
                model=model,
                done=True,
                done_reason="tool_calls",
                message=Message(
                    content=chunks[-1],
                    role="assistant",
                    tool_calls=[
                        Message.ToolCall(
                            function=Message.ToolCall.Function(
                                name=add_tool.name,
                                arguments={"x": 2, "y": 3},
                            ),
                        ),
                    ],
                ),
                prompt_eval_count=10,
                eval_count=12,
            )

        return _mock_stream()

    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)

    client = OllamaChatCompletionClient(model=model)
    stream = client.create_stream(
        messages=[UserMessage(content="What is 2 + 3?", source="user")],
        tools=[add_tool],
        tool_choice="auto",
    )

    chunks: List[str | CreateResult] = []
    async for chunk in stream:
        chunks.append(chunk)

    # Verify that tools are passed to the API when tool_choice is auto
    assert "tools" in chat_kwargs_captured
    assert chat_kwargs_captured["tools"] is not None
    assert len(chat_kwargs_captured["tools"]) == 1

    # Verify the final result
    assert len(chunks) > 0
    assert isinstance(chunks[-1], CreateResult)
    assert isinstance(chunks[-1].content, list)
    assert len(chunks[-1].content) > 0
    assert isinstance(chunks[-1].content[0], FunctionCall)
    assert chunks[-1].content[0].name == add_tool.name
    assert chunks[-1].finish_reason == "function_calls"


@pytest.mark.asyncio
async def test_tool_choice_stream_none(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test tool_choice='none' with streaming"""

    def add(x: int, y: int) -> str:
        return str(x + y)

    add_tool = FunctionTool(add, description="Add two numbers")
    model = "llama3.2"
    content_raw = "I cannot use tools, so I'll calculate manually: 2 + 3 = 5"

    # Capture the kwargs passed to chat
    chat_kwargs_captured: Dict[str, Any] = {}

    async def _mock_chat(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
        nonlocal chat_kwargs_captured
        chat_kwargs_captured = kwargs
        assert "stream" in kwargs
        assert kwargs["stream"] is True

        async def _mock_stream() -> AsyncGenerator[ChatResponse, None]:
            chunks = [content_raw[i : i + 10] for i in range(0, len(content_raw), 10)]
            # Simulate streaming by yielding chunks of the response
            for chunk in chunks[:-1]:
                yield ChatResponse(
                    model=model,
                    done=False,
                    message=Message(
                        role="assistant",
                        content=chunk,
                    ),
                )
            yield ChatResponse(
                model=model,
                done=True,
                done_reason="stop",
                message=Message(
                    role="assistant",
                    content=chunks[-1],
                ),
                prompt_eval_count=10,
                eval_count=12,
            )

        return _mock_stream()

    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)

    client = OllamaChatCompletionClient(model=model)
    stream = client.create_stream(
        messages=[UserMessage(content="What is 2 + 3?", source="user")],
        tools=[add_tool],
        tool_choice="none",
    )

    chunks: List[str | CreateResult] = []
    async for chunk in stream:
        chunks.append(chunk)

    # Verify that no tools are passed to the API when tool_choice is none
    assert "tools" in chat_kwargs_captured
    assert chat_kwargs_captured["tools"] is None

    # Verify the final result is text content
    assert len(chunks) > 0
    assert isinstance(chunks[-1], CreateResult)
    assert isinstance(chunks[-1].content, str)
    assert chunks[-1].content == content_raw
    assert chunks[-1].finish_reason == "stop"


@pytest.mark.asyncio
async def test_tool_choice_default_behavior(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test that default behavior (no tool_choice specified) works like 'auto'"""

    def add(x: int, y: int) -> str:
        return str(x + y)

    add_tool = FunctionTool(add, description="Add two numbers")
    model = "llama3.2"

    # Capture the kwargs passed to chat
    chat_kwargs_captured: Dict[str, Any] = {}

    async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
        nonlocal chat_kwargs_captured
        chat_kwargs_captured = kwargs
        return ChatResponse(
            model=model,
            done=True,
            done_reason="stop",
            message=Message(
                role="assistant",
                content="I'll use the add tool.",
                tool_calls=[
                    Message.ToolCall(
                        function=Message.ToolCall.Function(
                            name=add_tool.name,
                            arguments={"x": 2, "y": 3},
                        ),
                    ),
                ],
            ),
            prompt_eval_count=10,
            eval_count=12,
        )

    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)

    client = OllamaChatCompletionClient(model=model)
    create_result = await client.create(
        messages=[UserMessage(content="What is 2 + 3?", source="user")],
        tools=[add_tool],
        # tool_choice not specified - should default to "auto"
    )

    # Verify that tools are passed to the API by default (auto behavior)
    assert "tools" in chat_kwargs_captured
    assert chat_kwargs_captured["tools"] is not None
    assert len(chat_kwargs_captured["tools"]) == 1

    # Verify the response
    assert isinstance(create_result.content, list)
    assert len(create_result.content) > 0
    assert isinstance(create_result.content[0], FunctionCall)
    assert create_result.content[0].name == add_tool.name

@ -3,7 +3,7 @@ import json
import logging
import os
from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Tuple, TypeVar
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock

import httpx
import pytest
@ -2193,6 +2193,7 @@ async def test_system_message_merge_with_continuous_system_messages_models() ->
        tools=[],
        json_output=None,
        extra_create_args={},
        tool_choice="none",
    )

    # Extract the actual messages from the result
@ -2243,6 +2244,7 @@ async def test_system_message_merge_with_non_continuous_messages() -> None:
        tools=[],
        json_output=None,
        extra_create_args={},
        tool_choice="none",
    )


@ -2279,6 +2281,7 @@ async def test_system_message_not_merged_for_multiple_system_messages_true() ->
        tools=[],
        json_output=None,
        extra_create_args={},
        tool_choice="none",
    )

    # Extract the actual messages from the result
@ -2322,6 +2325,7 @@ async def test_no_system_messages_for_gemini_model() -> None:
        tools=[],
        json_output=None,
        extra_create_args={},
        tool_choice="none",
    )

    # Extract the actual messages from the result
@ -2369,6 +2373,7 @@ async def test_single_system_message_for_gemini_model() -> None:
        tools=[],
        json_output=None,
        extra_create_args={},
        tool_choice="auto",
    )

    # Extract the actual messages from the result
@ -2561,4 +2566,493 @@ async def test_mistral_remove_name() -> None:
    assert ("name" in params[0]) is True


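# The OpenAI client forwards tool_choice straight through to the API: a Tool object becomes
# {"type": "function", "function": {"name": ...}} and the "auto"/"none"/"required" strings
# pass through unchanged, as the captured create() kwargs below show.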
@pytest.mark.asyncio
async def test_mock_tool_choice_specific_tool(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test tool_choice parameter with a specific tool using mocks."""

    def _pass_function(input: str) -> str:
        """Simple passthrough function."""
        return f"Processed: {input}"

    def _add_numbers(a: int, b: int) -> int:
        """Add two numbers together."""
        return a + b

    model = "gpt-4o"

    # Mock successful completion with specific tool call
    chat_completion = ChatCompletion(
        id="id1",
        choices=[
            Choice(
                finish_reason="tool_calls",
                index=0,
                message=ChatCompletionMessage(
                    role="assistant",
                    content=None,
                    tool_calls=[
                        ChatCompletionMessageToolCall(
                            id="1",
                            type="function",
                            function=Function(
                                name="_pass_function",
                                arguments=json.dumps({"input": "hello"}),
                            ),
                        )
                    ],
                ),
            )
        ],
        created=1234567890,
        model=model,
        object="chat.completion",
        usage=CompletionUsage(completion_tokens=10, prompt_tokens=5, total_tokens=15),
    )

    client = OpenAIChatCompletionClient(model=model, api_key="test")

    # Define tools
    pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="_add_numbers")

    # Create mock for the chat completions create method
    mock_create = AsyncMock(return_value=chat_completion)

    with monkeypatch.context() as mp:
        mp.setattr(client._client.chat.completions, "create", mock_create)  # type: ignore[reportPrivateUsage]

        _ = await client.create(
            messages=[UserMessage(content="Process 'hello'", source="user")],
            tools=[pass_tool, add_tool],
            tool_choice=pass_tool,  # Force use of specific tool
        )

        # Verify the correct API call was made
        mock_create.assert_called_once()
        call_args = mock_create.call_args

        # Check that tool_choice was set correctly
        assert "tool_choice" in call_args.kwargs
        assert call_args.kwargs["tool_choice"] == {"type": "function", "function": {"name": "_pass_function"}}


@pytest.mark.asyncio
async def test_mock_tool_choice_auto(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test tool_choice parameter with 'auto' setting using mocks."""

    def _pass_function(input: str) -> str:
        """Simple passthrough function."""
        return f"Processed: {input}"

    def _add_numbers(a: int, b: int) -> int:
        """Add two numbers together."""
        return a + b

    model = "gpt-4o"

    # Mock successful completion
    chat_completion = ChatCompletion(
        id="id1",
        choices=[
            Choice(
                finish_reason="tool_calls",
                index=0,
                message=ChatCompletionMessage(
                    role="assistant",
                    content=None,
                    tool_calls=[
                        ChatCompletionMessageToolCall(
                            id="1",
                            type="function",
                            function=Function(
                                name="_add_numbers",
                                arguments=json.dumps({"a": 1, "b": 2}),
                            ),
                        )
                    ],
                ),
            )
        ],
        created=1234567890,
        model=model,
        object="chat.completion",
        usage=CompletionUsage(completion_tokens=10, prompt_tokens=5, total_tokens=15),
    )

    client = OpenAIChatCompletionClient(model=model, api_key="test")

    # Define tools
    pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="_add_numbers")

    # Create mock for the chat completions create method
    mock_create = AsyncMock(return_value=chat_completion)

    with monkeypatch.context() as mp:
        mp.setattr(client._client.chat.completions, "create", mock_create)  # type: ignore[reportPrivateUsage]

        await client.create(
            messages=[UserMessage(content="Add 1 and 2", source="user")],
            tools=[pass_tool, add_tool],
            tool_choice="auto",  # Let model choose
        )

        # Verify the correct API call was made
        mock_create.assert_called_once()
        call_args = mock_create.call_args

        # Check that tool_choice was set correctly
        assert "tool_choice" in call_args.kwargs
        assert call_args.kwargs["tool_choice"] == "auto"


@pytest.mark.asyncio
async def test_mock_tool_choice_none(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test tool_choice parameter with 'none' setting using mocks."""

    def _pass_function(input: str) -> str:
        """Simple passthrough function."""
        return f"Processed: {input}"

    model = "gpt-4o"

    # Mock successful completion
    chat_completion = ChatCompletion(
        id="id1",
        choices=[
            Choice(
                finish_reason="stop",
                index=0,
                message=ChatCompletionMessage(
                    role="assistant",
                    content="I can help you with that!",
                    tool_calls=None,
                ),
            )
        ],
        created=1234567890,
        model=model,
        object="chat.completion",
        usage=CompletionUsage(completion_tokens=10, prompt_tokens=5, total_tokens=15),
    )

    client = OpenAIChatCompletionClient(model=model, api_key="test")

    # Define tools
    pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")

    # Create mock for the chat completions create method
    mock_create = AsyncMock(return_value=chat_completion)

    with monkeypatch.context() as mp:
        mp.setattr(client._client.chat.completions, "create", mock_create)  # type: ignore[reportPrivateUsage]

        await client.create(
            messages=[UserMessage(content="Hello there", source="user")],
            tools=[pass_tool],
            tool_choice="none",
        )

        # Verify the correct API call was made
        mock_create.assert_called_once()
        call_args = mock_create.call_args

        # Check that tool_choice was set to "none" (disabling tool usage)
        assert "tool_choice" in call_args.kwargs
        assert call_args.kwargs["tool_choice"] == "none"


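# A Tool passed as tool_choice must also appear in tools; otherwise the client rejects the
# call with a ValueError before contacting the API.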
@pytest.mark.asyncio
async def test_mock_tool_choice_validation_error() -> None:
    """Test tool_choice validation with invalid tool reference."""

    def _pass_function(input: str) -> str:
        """Simple passthrough function."""
        return f"Processed: {input}"

    def _add_numbers(a: int, b: int) -> int:
        """Add two numbers together."""
        return a + b

    def _different_function(text: str) -> str:
        """Different function."""
        return text

    client = OpenAIChatCompletionClient(model="gpt-4o", api_key="test")

    # Define tools
    pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="_add_numbers")
    different_tool = FunctionTool(_different_function, description="Different tool", name="_different_function")

    messages = [UserMessage(content="Hello there", source="user")]

    # Test with a tool that's not in the tools list
    with pytest.raises(
        ValueError, match="tool_choice references '_different_function' but it's not in the provided tools"
    ):
        await client.create(
            messages=messages,
            tools=[pass_tool, add_tool],
            tool_choice=different_tool,  # This tool is not in the tools list
        )


@pytest.mark.asyncio
async def test_mock_tool_choice_required(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test tool_choice parameter with 'required' setting using mocks."""

    def _pass_function(input: str) -> str:
        """Simple passthrough function."""
        return f"Processed: {input}"

    def _add_numbers(a: int, b: int) -> int:
        """Add two numbers together."""
        return a + b

    model = "gpt-4o"

    # Mock successful completion with tool calls (required forces tool usage)
    chat_completion = ChatCompletion(
        id="id1",
        choices=[
            Choice(
                finish_reason="tool_calls",
                index=0,
                message=ChatCompletionMessage(
                    role="assistant",
                    content=None,
                    tool_calls=[
                        ChatCompletionMessageToolCall(
                            id="1",
                            type="function",
                            function=Function(
                                name="_pass_function",
                                arguments=json.dumps({"input": "hello"}),
                            ),
                        )
                    ],
                ),
            )
        ],
        created=1234567890,
        model=model,
        object="chat.completion",
        usage=CompletionUsage(completion_tokens=10, prompt_tokens=5, total_tokens=15),
    )

    client = OpenAIChatCompletionClient(model=model, api_key="test")

    # Define tools
    pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="_add_numbers")

    # Create mock for the chat completions create method
    mock_create = AsyncMock(return_value=chat_completion)

    with monkeypatch.context() as mp:
        mp.setattr(client._client.chat.completions, "create", mock_create)  # type: ignore[reportPrivateUsage]

        await client.create(
            messages=[UserMessage(content="Process some text", source="user")],
            tools=[pass_tool, add_tool],
            tool_choice="required",  # Force tool usage
        )

        # Verify the correct API call was made
        mock_create.assert_called_once()
        call_args = mock_create.call_args

        # Check that tool_choice was set correctly
        assert "tool_choice" in call_args.kwargs
        assert call_args.kwargs["tool_choice"] == "required"


# Integration tests for tool_choice using the actual OpenAI API
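# These tests skip themselves when OPENAI_API_KEY is not set in the environment.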
@pytest.mark.asyncio
async def test_openai_tool_choice_specific_tool_integration() -> None:
    """Test tool_choice parameter with a specific tool using the actual OpenAI API."""
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        pytest.skip("OPENAI_API_KEY not found in environment variables")

    def _pass_function(input: str) -> str:
        """Simple passthrough function."""
        return f"Processed: {input}"

    def _add_numbers(a: int, b: int) -> int:
        """Add two numbers together."""
        return a + b

    model = "gpt-4o-mini"
    client = OpenAIChatCompletionClient(model=model, api_key=api_key)

    # Define tools
    pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="_add_numbers")

    # Test forcing use of specific tool
    result = await client.create(
        messages=[UserMessage(content="Process the word 'hello'", source="user")],
        tools=[pass_tool, add_tool],
        tool_choice=pass_tool,  # Force use of specific tool
    )

    assert isinstance(result.content, list)
    assert len(result.content) == 1
    assert isinstance(result.content[0], FunctionCall)
    assert result.content[0].name == "_pass_function"
    assert result.finish_reason == "function_calls"
    assert result.usage is not None


@pytest.mark.asyncio
async def test_openai_tool_choice_auto_integration() -> None:
    """Test tool_choice parameter with 'auto' setting using the actual OpenAI API."""
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        pytest.skip("OPENAI_API_KEY not found in environment variables")

    def _pass_function(input: str) -> str:
        """Simple passthrough function."""
        return f"Processed: {input}"

    def _add_numbers(a: int, b: int) -> int:
        """Add two numbers together."""
        return a + b

    model = "gpt-4o-mini"
    client = OpenAIChatCompletionClient(model=model, api_key=api_key)

    # Define tools
    pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="_add_numbers")

    # Test auto tool choice - model should choose to use add_numbers for math
    result = await client.create(
        messages=[UserMessage(content="What is 15 plus 27?", source="user")],
        tools=[pass_tool, add_tool],
        tool_choice="auto",  # Let model choose
    )

    assert isinstance(result.content, list)
    assert len(result.content) == 1
    assert isinstance(result.content[0], FunctionCall)
    assert result.content[0].name == "_add_numbers"
    assert result.finish_reason == "function_calls"
    assert result.usage is not None

    # Parse arguments to verify correct values
    args = json.loads(result.content[0].arguments)
    assert args["a"] == 15
    assert args["b"] == 27


@pytest.mark.asyncio
async def test_openai_tool_choice_none_integration() -> None:
    """Test tool_choice parameter with 'none' setting using the actual OpenAI API."""
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        pytest.skip("OPENAI_API_KEY not found in environment variables")

    def _pass_function(input: str) -> str:
        """Simple passthrough function."""
        return f"Processed: {input}"

    model = "gpt-4o-mini"
    client = OpenAIChatCompletionClient(model=model, api_key=api_key)

    # Define tools
    pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")

    # Test none tool choice - model should not use any tools
    result = await client.create(
        messages=[UserMessage(content="Hello there, how are you?", source="user")],
        tools=[pass_tool],
        tool_choice="none",  # Disable tool usage
    )

    assert isinstance(result.content, str)
    assert len(result.content) > 0
    assert result.finish_reason == "stop"
    assert result.usage is not None


@pytest.mark.asyncio
async def test_openai_tool_choice_required_integration() -> None:
    """Test tool_choice parameter with 'required' setting using the actual OpenAI API."""
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        pytest.skip("OPENAI_API_KEY not found in environment variables")

    def _pass_function(input: str) -> str:
        """Simple passthrough function."""
        return f"Processed: {input}"

    def _add_numbers(a: int, b: int) -> int:
        """Add two numbers together."""
        return a + b

    model = "gpt-4o-mini"
    client = OpenAIChatCompletionClient(model=model, api_key=api_key)

    # Define tools
    pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="_add_numbers")

    # Test required tool choice - model must use a tool even for general conversation
    result = await client.create(
        messages=[UserMessage(content="Say hello to me", source="user")],
        tools=[pass_tool, add_tool],
        tool_choice="required",  # Force tool usage
    )

    assert isinstance(result.content, list)
    assert len(result.content) == 1
    assert isinstance(result.content[0], FunctionCall)
    assert result.content[0].name in ["_pass_function", "_add_numbers"]
    assert result.finish_reason == "function_calls"
    assert result.usage is not None


@pytest.mark.asyncio
async def test_openai_tool_choice_validation_error_integration() -> None:
    """Test tool_choice validation with invalid tool reference using the actual OpenAI API."""
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        pytest.skip("OPENAI_API_KEY not found in environment variables")

    def _pass_function(input: str) -> str:
        """Simple passthrough function."""
        return f"Processed: {input}"

    def _add_numbers(a: int, b: int) -> int:
        """Add two numbers together."""
        return a + b

    def _different_function(text: str) -> str:
        """Different function."""
        return text

    model = "gpt-4o-mini"
    client = OpenAIChatCompletionClient(model=model, api_key=api_key)

    # Define tools
    pass_tool = FunctionTool(_pass_function, description="Process input text", name="_pass_function")
    add_tool = FunctionTool(_add_numbers, description="Add two numbers together", name="_add_numbers")
    different_tool = FunctionTool(_different_function, description="Different tool", name="_different_function")

    messages = [UserMessage(content="Hello there", source="user")]

    # Test with a tool that's not in the tools list
    with pytest.raises(
        ValueError, match="tool_choice references '_different_function' but it's not in the provided tools"
    ):
        await client.create(
            messages=messages,
            tools=[pass_tool, add_tool],
            tool_choice=different_tool,  # This tool is not in the tools list
        )


# TODO: add integration tests for Azure OpenAI using AAD token.