mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-30 00:30:23 +00:00
Add include_name_in_message parameter to make name field optional in OpenAI messages (#6845)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
ac051ba6d0
commit
5f1c69d049
@ -197,10 +197,14 @@ def _set_role(role: str) -> Callable[[LLMMessage, Dict[str, Any]], Dict[str, str
|
||||
return inner
|
||||
|
||||
|
||||
def _set_name(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, Any]:
    """Build the optional ``name`` entry for an OpenAI chat message.

    Args:
        message: A user or assistant message; its ``source`` supplies the name.
        context: Transformer context. The ``include_name_in_message`` key
            (treated as ``True`` when absent) decides whether the name is
            emitted at all.

    Returns:
        ``{"name": message.source}`` when names are enabled, otherwise ``EMPTY``.
    """
    assert isinstance(message, (UserMessage, AssistantMessage))
    # The source is validated even when the name ends up being omitted, so
    # this check behaves the same regardless of the flag.
    assert_valid_name(message.source)
    # Callers can switch the field off via the context flag.
    if not context.get("include_name_in_message", True):
        return EMPTY
    return {"name": message.source}
|
||||
|
||||
|
||||
def _set_content_direct(message: LLMMessage, context: Dict[str, Any]) -> Dict[str, LLMMessageContent]:
|
||||
|
||||
@ -163,10 +163,15 @@ def type_to_role(message: LLMMessage) -> ChatCompletionRole:
|
||||
|
||||
|
||||
def to_oai_type(
|
||||
message: LLMMessage, prepend_name: bool = False, model: str = "unknown", model_family: str = ModelFamily.UNKNOWN
|
||||
message: LLMMessage,
|
||||
prepend_name: bool = False,
|
||||
model: str = "unknown",
|
||||
model_family: str = ModelFamily.UNKNOWN,
|
||||
include_name_in_message: bool = True,
|
||||
) -> Sequence[ChatCompletionMessageParam]:
|
||||
context = {
|
||||
"prepend_name": prepend_name,
|
||||
"include_name_in_message": include_name_in_message,
|
||||
}
|
||||
transformers = get_transformer("openai", model, model_family)
|
||||
|
||||
@ -307,6 +312,7 @@ def count_tokens_openai(
|
||||
add_name_prefixes: bool = False,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
model_family: str = ModelFamily.UNKNOWN,
|
||||
include_name_in_message: bool = True,
|
||||
) -> int:
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
@ -320,7 +326,13 @@ def count_tokens_openai(
|
||||
# Message tokens.
|
||||
for message in messages:
|
||||
num_tokens += tokens_per_message
|
||||
oai_message = to_oai_type(message, prepend_name=add_name_prefixes, model=model, model_family=model_family)
|
||||
oai_message = to_oai_type(
|
||||
message,
|
||||
prepend_name=add_name_prefixes,
|
||||
model=model,
|
||||
model_family=model_family,
|
||||
include_name_in_message=include_name_in_message,
|
||||
)
|
||||
for oai_message_part in oai_message:
|
||||
for key, value in oai_message_part.items():
|
||||
if value is None:
|
||||
@ -413,9 +425,11 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
model_capabilities: Optional[ModelCapabilities] = None, # type: ignore
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
add_name_prefixes: bool = False,
|
||||
include_name_in_message: bool = True,
|
||||
):
|
||||
self._client = client
|
||||
self._add_name_prefixes = add_name_prefixes
|
||||
self._include_name_in_message = include_name_in_message
|
||||
if model_capabilities is None and model_info is None:
|
||||
try:
|
||||
self._model_info = _model_info.get_info(create_args["model"])
|
||||
@ -591,6 +605,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
prepend_name=self._add_name_prefixes,
|
||||
model=create_args.get("model", "unknown"),
|
||||
model_family=self._model_info["family"],
|
||||
include_name_in_message=self._include_name_in_message,
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
@ -1127,6 +1142,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
add_name_prefixes=self._add_name_prefixes,
|
||||
tools=tools,
|
||||
model_family=self._model_info["family"],
|
||||
include_name_in_message=self._include_name_in_message,
|
||||
)
|
||||
|
||||
def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
|
||||
@ -1227,6 +1243,9 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
|
||||
"this is content" becomes "Reviewer said: this is content."
|
||||
This can be useful for models that do not support the `name` field in
|
||||
message. Defaults to False.
|
||||
include_name_in_message (optional, bool): Whether to include the `name` field
|
||||
in user message parameters sent to the OpenAI API. Defaults to True. Set to False
|
||||
for model providers that don't support the `name` field (e.g., Groq).
|
||||
stream_options (optional, dict): Additional options for streaming. Currently only `include_usage` is supported.
|
||||
|
||||
Examples:
|
||||
@ -1426,6 +1445,10 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
|
||||
if "add_name_prefixes" in kwargs:
|
||||
add_name_prefixes = kwargs["add_name_prefixes"]
|
||||
|
||||
include_name_in_message: bool = True
|
||||
if "include_name_in_message" in kwargs:
|
||||
include_name_in_message = kwargs["include_name_in_message"]
|
||||
|
||||
# Special handling for Gemini model.
|
||||
assert "model" in copied_args and isinstance(copied_args["model"], str)
|
||||
if copied_args["model"].startswith("gemini-"):
|
||||
@ -1453,6 +1476,7 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
|
||||
model_capabilities=model_capabilities,
|
||||
model_info=model_info,
|
||||
add_name_prefixes=add_name_prefixes,
|
||||
include_name_in_message=include_name_in_message,
|
||||
)
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]:
|
||||
@ -1554,6 +1578,15 @@ class AzureOpenAIChatCompletionClient(
|
||||
top_p (optional, float):
|
||||
user (optional, str):
|
||||
default_headers (optional, dict[str, str]): Custom headers; useful for authentication or other custom requirements.
|
||||
add_name_prefixes (optional, bool): Whether to prepend the `source` value
|
||||
to each :class:`~autogen_core.models.UserMessage` content. E.g.,
|
||||
"this is content" becomes "Reviewer said: this is content."
|
||||
This can be useful for models that do not support the `name` field in
|
||||
message. Defaults to False.
|
||||
include_name_in_message (optional, bool): Whether to include the `name` field
|
||||
in user message parameters sent to the OpenAI API. Defaults to True. Set to False
|
||||
for model providers that don't support the `name` field (e.g., Groq).
|
||||
stream_options (optional, dict): Additional options for streaming. Currently only `include_usage` is supported.
|
||||
|
||||
|
||||
To use the client, you need to provide your deployment name, Azure Cognitive Services endpoint, and api version.
|
||||
@ -1645,6 +1678,10 @@ class AzureOpenAIChatCompletionClient(
|
||||
if "add_name_prefixes" in kwargs:
|
||||
add_name_prefixes = kwargs["add_name_prefixes"]
|
||||
|
||||
include_name_in_message: bool = True
|
||||
if "include_name_in_message" in kwargs:
|
||||
include_name_in_message = kwargs["include_name_in_message"]
|
||||
|
||||
client = _azure_openai_client_from_config(copied_args)
|
||||
create_args = _create_args_from_config(copied_args)
|
||||
self._raw_config: Dict[str, Any] = copied_args
|
||||
@ -1654,6 +1691,7 @@ class AzureOpenAIChatCompletionClient(
|
||||
model_capabilities=model_capabilities,
|
||||
model_info=model_info,
|
||||
add_name_prefixes=add_name_prefixes,
|
||||
include_name_in_message=include_name_in_message,
|
||||
)
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]:
|
||||
|
||||
@ -63,6 +63,8 @@ class BaseOpenAIClientConfiguration(CreateArguments, total=False):
|
||||
model_info: ModelInfo
|
||||
add_name_prefixes: bool
|
||||
"""What functionality the model supports, determined by default from model name but is overriden if value passed."""
|
||||
include_name_in_message: bool
|
||||
"""Whether to include the 'name' field in user message parameters. Defaults to True. Set to False for providers that don't support the 'name' field."""
|
||||
default_headers: Dict[str, str] | None
|
||||
|
||||
|
||||
@ -105,6 +107,7 @@ class BaseOpenAIClientConfigurationConfigModel(CreateArgumentsConfigModel):
|
||||
model_capabilities: ModelCapabilities | None = None # type: ignore
|
||||
model_info: ModelInfo | None = None
|
||||
add_name_prefixes: bool | None = None
|
||||
include_name_in_message: bool | None = None
|
||||
default_headers: Dict[str, str] | None = None
|
||||
|
||||
|
||||
|
||||
@ -2680,6 +2680,96 @@ async def test_mistral_remove_name() -> None:
|
||||
assert ("name" in params[0]) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_include_name_in_message() -> None:
    """Check that include_name_in_message toggles the name field on converted messages."""
    greeting = "Hello, I am from Seattle."
    from_adam = UserMessage(content=greeting, source="Adam")

    # User message, flag enabled explicitly: the name is carried through.
    named = to_oai_type(from_adam, include_name_in_message=True)[0]
    assert "name" in named
    assert named["name"] == "Adam"  # type: ignore[typeddict-item]
    assert named["role"] == "user"
    assert named["content"] == greeting

    # User message, flag disabled: the name is dropped, everything else is kept.
    unnamed = to_oai_type(from_adam, include_name_in_message=False)[0]
    assert "name" not in unnamed
    assert unnamed["role"] == "user"
    assert unnamed["content"] == greeting

    # Assistant messages are expected to carry no name whatever the flag says.
    from_assistant = AssistantMessage(content="Hello, how can I help you?", source="Assistant")
    for flag in (True, False):
        converted = to_oai_type(from_assistant, include_name_in_message=flag)[0]
        assert "name" not in converted
        assert converted["role"] == "assistant"

    # System messages are expected to carry no name whatever the flag says.
    system_prompt = SystemMessage(content="You are a helpful assistant.")
    system_results = [to_oai_type(system_prompt, include_name_in_message=flag)[0] for flag in (True, False)]
    for converted in system_results:
        assert "name" not in converted
    for converted in system_results:
        assert converted["role"] == "system"

    # Leaving the parameter out must behave like include_name_in_message=True.
    by_default = to_oai_type(from_adam)[0]
    assert "name" in by_default
    assert by_default["name"] == "Adam"  # type: ignore[typeddict-item]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_include_name_with_different_models() -> None:
    """Check include_name_in_message across several model families."""
    msg = UserMessage(content="Hello", source="User")

    # Each row: model name, model family, and whether the name field is
    # expected when include_name_in_message=True. This test expects Mistral
    # output to carry no name even with the flag enabled.
    expectations = [
        ("gpt-4o", ModelFamily.GPT_4O, True),
        ("mistral-7b", ModelFamily.MISTRAL, False),
        ("some-custom-model", ModelFamily.UNKNOWN, True),
    ]

    for model_name, family, named_when_enabled in expectations:
        enabled = to_oai_type(msg, model=model_name, model_family=family, include_name_in_message=True)[0]
        disabled = to_oai_type(msg, model=model_name, model_family=family, include_name_in_message=False)[0]

        if named_when_enabled:
            assert "name" in enabled
        else:
            assert "name" not in enabled
        # With the flag off, no family may emit a name.
        assert "name" not in disabled
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mock_tool_choice_specific_tool(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test tool_choice parameter with a specific tool using mocks."""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user