Mitigates #5401 by optionally prepending names to messages. (#5448)

Mitigates #5401 by optionally prepending names to messages.

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
afourney 2025-02-07 23:04:24 -08:00 committed by GitHub
parent be085567ea
commit 0b659de36d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 99 additions and 14 deletions

View File

@ -137,11 +137,11 @@ def type_to_role(message: LLMMessage) -> ChatCompletionRole:
return "tool" return "tool"
def user_message_to_oai(message: UserMessage) -> ChatCompletionUserMessageParam: def user_message_to_oai(message: UserMessage, prepend_name: bool = False) -> ChatCompletionUserMessageParam:
assert_valid_name(message.source) assert_valid_name(message.source)
if isinstance(message.content, str): if isinstance(message.content, str):
return ChatCompletionUserMessageParam( return ChatCompletionUserMessageParam(
content=message.content, content=(f"{message.source} said:\n" if prepend_name else "") + message.content,
role="user", role="user",
name=message.source, name=message.source,
) )
@ -149,10 +149,18 @@ def user_message_to_oai(message: UserMessage) -> ChatCompletionUserMessageParam:
parts: List[ChatCompletionContentPartParam] = [] parts: List[ChatCompletionContentPartParam] = []
for part in message.content: for part in message.content:
if isinstance(part, str): if isinstance(part, str):
oai_part = ChatCompletionContentPartTextParam( if prepend_name:
text=part, # Prepend the name to the first text part
type="text", oai_part = ChatCompletionContentPartTextParam(
) text=f"{message.source} said:\n" + part,
type="text",
)
prepend_name = False
else:
oai_part = ChatCompletionContentPartTextParam(
text=part,
type="text",
)
parts.append(oai_part) parts.append(oai_part)
elif isinstance(part, Image): elif isinstance(part, Image):
# TODO: support url based images # TODO: support url based images
@ -211,11 +219,11 @@ def assistant_message_to_oai(
) )
def to_oai_type(message: LLMMessage) -> Sequence[ChatCompletionMessageParam]: def to_oai_type(message: LLMMessage, prepend_name: bool = False) -> Sequence[ChatCompletionMessageParam]:
if isinstance(message, SystemMessage): if isinstance(message, SystemMessage):
return [system_message_to_oai(message)] return [system_message_to_oai(message)]
elif isinstance(message, UserMessage): elif isinstance(message, UserMessage):
return [user_message_to_oai(message)] return [user_message_to_oai(message, prepend_name)]
elif isinstance(message, AssistantMessage): elif isinstance(message, AssistantMessage):
return [assistant_message_to_oai(message)] return [assistant_message_to_oai(message)]
else: else:
@ -356,8 +364,10 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
create_args: Dict[str, Any], create_args: Dict[str, Any],
model_capabilities: Optional[ModelCapabilities] = None, # type: ignore model_capabilities: Optional[ModelCapabilities] = None, # type: ignore
model_info: Optional[ModelInfo] = None, model_info: Optional[ModelInfo] = None,
add_name_prefixes: bool = False,
): ):
self._client = client self._client = client
self._add_name_prefixes = add_name_prefixes
if model_capabilities is None and model_info is None: if model_capabilities is None and model_info is None:
try: try:
self._model_info = _model_info.get_info(create_args["model"]) self._model_info = _model_info.get_info(create_args["model"])
@ -451,7 +461,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
if self.model_info["json_output"] is False and json_output is True: if self.model_info["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output.") raise ValueError("Model does not support JSON output.")
oai_messages_nested = [to_oai_type(m) for m in messages] oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
oai_messages = [item for sublist in oai_messages_nested for item in sublist] oai_messages = [item for sublist in oai_messages_nested for item in sublist]
if self.model_info["function_calling"] is False and len(tools) > 0: if self.model_info["function_calling"] is False and len(tools) > 0:
@ -672,7 +682,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
create_args = self._create_args.copy() create_args = self._create_args.copy()
create_args.update(extra_create_args) create_args.update(extra_create_args)
oai_messages_nested = [to_oai_type(m) for m in messages] oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
oai_messages = [item for sublist in oai_messages_nested for item in sublist] oai_messages = [item for sublist in oai_messages_nested for item in sublist]
# TODO: allow custom handling. # TODO: allow custom handling.
@ -874,7 +884,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
# Message tokens. # Message tokens.
for message in messages: for message in messages:
num_tokens += tokens_per_message num_tokens += tokens_per_message
oai_message = to_oai_type(message) oai_message = to_oai_type(message, prepend_name=self._add_name_prefixes)
for oai_message_part in oai_message: for oai_message_part in oai_message:
for key, value in oai_message_part.items(): for key, value in oai_message_part.items():
if value is None: if value is None:
@ -992,6 +1002,11 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
top_p (optional, float): top_p (optional, float):
user (optional, str): user (optional, str):
default_headers (optional, dict[str, str]): Custom headers; useful for authentication or other custom requirements. default_headers (optional, dict[str, str]): Custom headers; useful for authentication or other custom requirements.
add_name_prefixes (optional, bool): Whether to prepend the `source` value
to each :class:`~autogen_core.models.UserMessage` content. E.g.,
"this is content" becomes "Reviewer said:\nthis is content".
This can be useful for models that do not support the `name` field in
messages. Defaults to False.
To use this client, you must install the `openai` extension: To use this client, you must install the `openai` extension:
@ -1074,11 +1089,19 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
model_info = kwargs["model_info"] model_info = kwargs["model_info"]
del copied_args["model_info"] del copied_args["model_info"]
add_name_prefixes: bool = False
if "add_name_prefixes" in kwargs:
add_name_prefixes = kwargs["add_name_prefixes"]
client = _openai_client_from_config(copied_args) client = _openai_client_from_config(copied_args)
create_args = _create_args_from_config(copied_args) create_args = _create_args_from_config(copied_args)
super().__init__( super().__init__(
client=client, create_args=create_args, model_capabilities=model_capabilities, model_info=model_info client=client,
create_args=create_args,
model_capabilities=model_capabilities,
model_info=model_info,
add_name_prefixes=add_name_prefixes,
) )
def __getstate__(self) -> Dict[str, Any]: def __getstate__(self) -> Dict[str, Any]:
@ -1215,11 +1238,19 @@ class AzureOpenAIChatCompletionClient(
model_info = kwargs["model_info"] model_info = kwargs["model_info"]
del copied_args["model_info"] del copied_args["model_info"]
add_name_prefixes: bool = False
if "add_name_prefixes" in kwargs:
add_name_prefixes = kwargs["add_name_prefixes"]
client = _azure_openai_client_from_config(copied_args) client = _azure_openai_client_from_config(copied_args)
create_args = _create_args_from_config(copied_args) create_args = _create_args_from_config(copied_args)
self._raw_config: Dict[str, Any] = copied_args self._raw_config: Dict[str, Any] = copied_args
super().__init__( super().__init__(
client=client, create_args=create_args, model_capabilities=model_capabilities, model_info=model_info client=client,
create_args=create_args,
model_capabilities=model_capabilities,
model_info=model_info,
add_name_prefixes=add_name_prefixes,
) )
def __getstate__(self) -> Dict[str, Any]: def __getstate__(self) -> Dict[str, Any]:

View File

@ -34,6 +34,7 @@ class BaseOpenAIClientConfiguration(CreateArguments, total=False):
max_retries: int max_retries: int
model_capabilities: ModelCapabilities # type: ignore model_capabilities: ModelCapabilities # type: ignore
model_info: ModelInfo model_info: ModelInfo
add_name_prefixes: bool
"""What functionality the model supports, determined by default from model name but is overriden if value passed.""" """What functionality the model supports, determined by default from model name but is overriden if value passed."""
default_headers: Dict[str, str] | None default_headers: Dict[str, str] | None
@ -75,6 +76,7 @@ class BaseOpenAIClientConfigurationConfigModel(CreateArgumentsConfigModel):
max_retries: int | None = None max_retries: int | None = None
model_capabilities: ModelCapabilities | None = None # type: ignore model_capabilities: ModelCapabilities | None = None # type: ignore
model_info: ModelInfo | None = None model_info: ModelInfo | None = None
add_name_prefixes: bool | None = None
default_headers: Dict[str, str] | None = None default_headers: Dict[str, str] | None = None

View File

@ -22,7 +22,7 @@ from autogen_core.models._model_client import ModelFamily
from autogen_core.tools import BaseTool, FunctionTool from autogen_core.tools import BaseTool, FunctionTool
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
from autogen_ext.models.openai._model_info import resolve_model from autogen_ext.models.openai._model_info import resolve_model
from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools, to_oai_type
from openai.resources.beta.chat.completions import AsyncCompletions as BetaAsyncCompletions from openai.resources.beta.chat.completions import AsyncCompletions as BetaAsyncCompletions
from openai.resources.chat.completions import AsyncCompletions from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice from openai.types.chat.chat_completion import ChatCompletion, Choice
@ -1050,4 +1050,56 @@ async def test_ollama() -> None:
assert chunks[-1].thought is not None assert chunks[-1].thought is not None
@pytest.mark.asyncio
async def test_add_name_prefixes() -> None:
    """Verify the ``prepend_name`` behavior of ``to_oai_type``.

    When ``prepend_name=True``, the content of each ``UserMessage`` must be
    prefixed with ``"<source> said:"`` followed by a newline (for multimodal
    content, only the first text part is prefixed). System and assistant
    messages, message roles, and the overall content structure must be left
    unchanged.
    """
    # NOTE(review): the original version accepted an unused ``monkeypatch``
    # fixture; it has been removed since nothing in the test patches anything.
    sys_message = SystemMessage(content="You are a helpful AI agent, and you answer questions in a friendly way.")
    assistant_message = AssistantMessage(content="Hello, how can I help you?", source="Assistant")
    user_text_message = UserMessage(content="Hello, I am from Seattle.", source="Adam")
    user_mm_message = UserMessage(
        content=[
            "Here is a postcard from Seattle:",
            Image.from_base64(
                "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4z8AAAAMBAQDJ/pLvAAAAAElFTkSuQmCC"
            ),
        ],
        source="Adam",
    )

    # Default conversion (no name prefixes).
    oai_sys = to_oai_type(sys_message)[0]
    oai_asst = to_oai_type(assistant_message)[0]
    oai_text = to_oai_type(user_text_message)[0]
    oai_mm = to_oai_type(user_mm_message)[0]

    # Conversion with name prefixes enabled.
    converted_sys = to_oai_type(sys_message, prepend_name=True)[0]
    converted_asst = to_oai_type(assistant_message, prepend_name=True)[0]
    converted_text = to_oai_type(user_text_message, prepend_name=True)[0]
    converted_mm = to_oai_type(user_mm_message, prepend_name=True)[0]

    # Invariants: every converted message still carries content.
    assert "content" in oai_sys
    assert "content" in oai_asst
    assert "content" in oai_text
    assert "content" in oai_mm
    assert "content" in converted_sys
    assert "content" in converted_asst
    assert "content" in converted_text
    assert "content" in converted_mm

    # System and assistant messages are unaffected by prepend_name.
    assert oai_sys["role"] == converted_sys["role"]
    assert oai_sys["content"] == converted_sys["content"]
    assert oai_asst["role"] == converted_asst["role"]
    assert oai_asst["content"] == converted_asst["content"]

    # Roles of user messages are unchanged.
    assert oai_text["role"] == converted_text["role"]
    assert oai_mm["role"] == converted_mm["role"]

    # Multimodal structure is preserved: same number of parts, first is text.
    assert isinstance(oai_mm["content"], list)
    assert isinstance(converted_mm["content"], list)
    assert len(oai_mm["content"]) == len(converted_mm["content"])
    assert "text" in converted_mm["content"][0]
    assert "text" in oai_mm["content"][0]

    # The source name is prepended to user message text content.
    assert str(converted_text["content"]) == "Adam said:\n" + str(oai_text["content"])
    assert str(converted_mm["content"][0]["text"]) == "Adam said:\n" + str(oai_mm["content"][0]["text"])
# TODO: add integration tests for Azure OpenAI using AAD token. # TODO: add integration tests for Azure OpenAI using AAD token.