mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-19 22:22:11 +00:00
Mitigates #5401 by optionally prepending names to messages. Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
be085567ea
commit
0b659de36d
@ -137,11 +137,11 @@ def type_to_role(message: LLMMessage) -> ChatCompletionRole:
|
|||||||
return "tool"
|
return "tool"
|
||||||
|
|
||||||
|
|
||||||
def user_message_to_oai(message: UserMessage) -> ChatCompletionUserMessageParam:
|
def user_message_to_oai(message: UserMessage, prepend_name: bool = False) -> ChatCompletionUserMessageParam:
|
||||||
assert_valid_name(message.source)
|
assert_valid_name(message.source)
|
||||||
if isinstance(message.content, str):
|
if isinstance(message.content, str):
|
||||||
return ChatCompletionUserMessageParam(
|
return ChatCompletionUserMessageParam(
|
||||||
content=message.content,
|
content=(f"{message.source} said:\n" if prepend_name else "") + message.content,
|
||||||
role="user",
|
role="user",
|
||||||
name=message.source,
|
name=message.source,
|
||||||
)
|
)
|
||||||
@ -149,10 +149,18 @@ def user_message_to_oai(message: UserMessage) -> ChatCompletionUserMessageParam:
|
|||||||
parts: List[ChatCompletionContentPartParam] = []
|
parts: List[ChatCompletionContentPartParam] = []
|
||||||
for part in message.content:
|
for part in message.content:
|
||||||
if isinstance(part, str):
|
if isinstance(part, str):
|
||||||
oai_part = ChatCompletionContentPartTextParam(
|
if prepend_name:
|
||||||
text=part,
|
# Append the name to the first text part
|
||||||
type="text",
|
oai_part = ChatCompletionContentPartTextParam(
|
||||||
)
|
text=f"{message.source} said:\n" + part,
|
||||||
|
type="text",
|
||||||
|
)
|
||||||
|
prepend_name = False
|
||||||
|
else:
|
||||||
|
oai_part = ChatCompletionContentPartTextParam(
|
||||||
|
text=part,
|
||||||
|
type="text",
|
||||||
|
)
|
||||||
parts.append(oai_part)
|
parts.append(oai_part)
|
||||||
elif isinstance(part, Image):
|
elif isinstance(part, Image):
|
||||||
# TODO: support url based images
|
# TODO: support url based images
|
||||||
@ -211,11 +219,11 @@ def assistant_message_to_oai(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def to_oai_type(message: LLMMessage) -> Sequence[ChatCompletionMessageParam]:
|
def to_oai_type(message: LLMMessage, prepend_name: bool = False) -> Sequence[ChatCompletionMessageParam]:
|
||||||
if isinstance(message, SystemMessage):
|
if isinstance(message, SystemMessage):
|
||||||
return [system_message_to_oai(message)]
|
return [system_message_to_oai(message)]
|
||||||
elif isinstance(message, UserMessage):
|
elif isinstance(message, UserMessage):
|
||||||
return [user_message_to_oai(message)]
|
return [user_message_to_oai(message, prepend_name)]
|
||||||
elif isinstance(message, AssistantMessage):
|
elif isinstance(message, AssistantMessage):
|
||||||
return [assistant_message_to_oai(message)]
|
return [assistant_message_to_oai(message)]
|
||||||
else:
|
else:
|
||||||
@ -356,8 +364,10 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
|||||||
create_args: Dict[str, Any],
|
create_args: Dict[str, Any],
|
||||||
model_capabilities: Optional[ModelCapabilities] = None, # type: ignore
|
model_capabilities: Optional[ModelCapabilities] = None, # type: ignore
|
||||||
model_info: Optional[ModelInfo] = None,
|
model_info: Optional[ModelInfo] = None,
|
||||||
|
add_name_prefixes: bool = False,
|
||||||
):
|
):
|
||||||
self._client = client
|
self._client = client
|
||||||
|
self._add_name_prefixes = add_name_prefixes
|
||||||
if model_capabilities is None and model_info is None:
|
if model_capabilities is None and model_info is None:
|
||||||
try:
|
try:
|
||||||
self._model_info = _model_info.get_info(create_args["model"])
|
self._model_info = _model_info.get_info(create_args["model"])
|
||||||
@ -451,7 +461,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
|||||||
if self.model_info["json_output"] is False and json_output is True:
|
if self.model_info["json_output"] is False and json_output is True:
|
||||||
raise ValueError("Model does not support JSON output.")
|
raise ValueError("Model does not support JSON output.")
|
||||||
|
|
||||||
oai_messages_nested = [to_oai_type(m) for m in messages]
|
oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
|
||||||
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
|
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
|
||||||
|
|
||||||
if self.model_info["function_calling"] is False and len(tools) > 0:
|
if self.model_info["function_calling"] is False and len(tools) > 0:
|
||||||
@ -672,7 +682,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
|||||||
create_args = self._create_args.copy()
|
create_args = self._create_args.copy()
|
||||||
create_args.update(extra_create_args)
|
create_args.update(extra_create_args)
|
||||||
|
|
||||||
oai_messages_nested = [to_oai_type(m) for m in messages]
|
oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
|
||||||
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
|
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
|
||||||
|
|
||||||
# TODO: allow custom handling.
|
# TODO: allow custom handling.
|
||||||
@ -874,7 +884,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
|||||||
# Message tokens.
|
# Message tokens.
|
||||||
for message in messages:
|
for message in messages:
|
||||||
num_tokens += tokens_per_message
|
num_tokens += tokens_per_message
|
||||||
oai_message = to_oai_type(message)
|
oai_message = to_oai_type(message, prepend_name=self._add_name_prefixes)
|
||||||
for oai_message_part in oai_message:
|
for oai_message_part in oai_message:
|
||||||
for key, value in oai_message_part.items():
|
for key, value in oai_message_part.items():
|
||||||
if value is None:
|
if value is None:
|
||||||
@ -992,6 +1002,11 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
|
|||||||
top_p (optional, float):
|
top_p (optional, float):
|
||||||
user (optional, str):
|
user (optional, str):
|
||||||
default_headers (optional, dict[str, str]): Custom headers; useful for authentication or other custom requirements.
|
default_headers (optional, dict[str, str]): Custom headers; useful for authentication or other custom requirements.
|
||||||
|
add_name_prefixes (optional, bool): Whether to prepend the `source` value
|
||||||
|
to each :class:`~autogen_core.models.UserMessage` content. E.g.,
|
||||||
|
"this is content" becomes "Reviewer said: this is content."
|
||||||
|
This can be useful for models that do not support the `name` field in
|
||||||
|
message. Defaults to False.
|
||||||
|
|
||||||
|
|
||||||
To use this client, you must install the `openai` extension:
|
To use this client, you must install the `openai` extension:
|
||||||
@ -1074,11 +1089,19 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
|
|||||||
model_info = kwargs["model_info"]
|
model_info = kwargs["model_info"]
|
||||||
del copied_args["model_info"]
|
del copied_args["model_info"]
|
||||||
|
|
||||||
|
add_name_prefixes: bool = False
|
||||||
|
if "add_name_prefixes" in kwargs:
|
||||||
|
add_name_prefixes = kwargs["add_name_prefixes"]
|
||||||
|
|
||||||
client = _openai_client_from_config(copied_args)
|
client = _openai_client_from_config(copied_args)
|
||||||
create_args = _create_args_from_config(copied_args)
|
create_args = _create_args_from_config(copied_args)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
client=client, create_args=create_args, model_capabilities=model_capabilities, model_info=model_info
|
client=client,
|
||||||
|
create_args=create_args,
|
||||||
|
model_capabilities=model_capabilities,
|
||||||
|
model_info=model_info,
|
||||||
|
add_name_prefixes=add_name_prefixes,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __getstate__(self) -> Dict[str, Any]:
|
def __getstate__(self) -> Dict[str, Any]:
|
||||||
@ -1215,11 +1238,19 @@ class AzureOpenAIChatCompletionClient(
|
|||||||
model_info = kwargs["model_info"]
|
model_info = kwargs["model_info"]
|
||||||
del copied_args["model_info"]
|
del copied_args["model_info"]
|
||||||
|
|
||||||
|
add_name_prefixes: bool = False
|
||||||
|
if "add_name_prefixes" in kwargs:
|
||||||
|
add_name_prefixes = kwargs["add_name_prefixes"]
|
||||||
|
|
||||||
client = _azure_openai_client_from_config(copied_args)
|
client = _azure_openai_client_from_config(copied_args)
|
||||||
create_args = _create_args_from_config(copied_args)
|
create_args = _create_args_from_config(copied_args)
|
||||||
self._raw_config: Dict[str, Any] = copied_args
|
self._raw_config: Dict[str, Any] = copied_args
|
||||||
super().__init__(
|
super().__init__(
|
||||||
client=client, create_args=create_args, model_capabilities=model_capabilities, model_info=model_info
|
client=client,
|
||||||
|
create_args=create_args,
|
||||||
|
model_capabilities=model_capabilities,
|
||||||
|
model_info=model_info,
|
||||||
|
add_name_prefixes=add_name_prefixes,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __getstate__(self) -> Dict[str, Any]:
|
def __getstate__(self) -> Dict[str, Any]:
|
||||||
|
@ -34,6 +34,7 @@ class BaseOpenAIClientConfiguration(CreateArguments, total=False):
|
|||||||
max_retries: int
|
max_retries: int
|
||||||
model_capabilities: ModelCapabilities # type: ignore
|
model_capabilities: ModelCapabilities # type: ignore
|
||||||
model_info: ModelInfo
|
model_info: ModelInfo
|
||||||
|
add_name_prefixes: bool
|
||||||
"""What functionality the model supports, determined by default from model name but is overriden if value passed."""
|
"""What functionality the model supports, determined by default from model name but is overriden if value passed."""
|
||||||
default_headers: Dict[str, str] | None
|
default_headers: Dict[str, str] | None
|
||||||
|
|
||||||
@ -75,6 +76,7 @@ class BaseOpenAIClientConfigurationConfigModel(CreateArgumentsConfigModel):
|
|||||||
max_retries: int | None = None
|
max_retries: int | None = None
|
||||||
model_capabilities: ModelCapabilities | None = None # type: ignore
|
model_capabilities: ModelCapabilities | None = None # type: ignore
|
||||||
model_info: ModelInfo | None = None
|
model_info: ModelInfo | None = None
|
||||||
|
add_name_prefixes: bool | None = None
|
||||||
default_headers: Dict[str, str] | None = None
|
default_headers: Dict[str, str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ from autogen_core.models._model_client import ModelFamily
|
|||||||
from autogen_core.tools import BaseTool, FunctionTool
|
from autogen_core.tools import BaseTool, FunctionTool
|
||||||
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
|
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
|
||||||
from autogen_ext.models.openai._model_info import resolve_model
|
from autogen_ext.models.openai._model_info import resolve_model
|
||||||
from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools
|
from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools, to_oai_type
|
||||||
from openai.resources.beta.chat.completions import AsyncCompletions as BetaAsyncCompletions
|
from openai.resources.beta.chat.completions import AsyncCompletions as BetaAsyncCompletions
|
||||||
from openai.resources.chat.completions import AsyncCompletions
|
from openai.resources.chat.completions import AsyncCompletions
|
||||||
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||||
@ -1050,4 +1050,56 @@ async def test_ollama() -> None:
|
|||||||
assert chunks[-1].thought is not None
|
assert chunks[-1].thought is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_name_prefixes(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
sys_message = SystemMessage(content="You are a helpful AI agent, and you answer questions in a friendly way.")
|
||||||
|
assistant_message = AssistantMessage(content="Hello, how can I help you?", source="Assistant")
|
||||||
|
user_text_message = UserMessage(content="Hello, I am from Seattle.", source="Adam")
|
||||||
|
user_mm_message = UserMessage(
|
||||||
|
content=[
|
||||||
|
"Here is a postcard from Seattle:",
|
||||||
|
Image.from_base64(
|
||||||
|
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4z8AAAAMBAQDJ/pLvAAAAAElFTkSuQmCC"
|
||||||
|
),
|
||||||
|
],
|
||||||
|
source="Adam",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default conversion
|
||||||
|
oai_sys = to_oai_type(sys_message)[0]
|
||||||
|
oai_asst = to_oai_type(assistant_message)[0]
|
||||||
|
oai_text = to_oai_type(user_text_message)[0]
|
||||||
|
oai_mm = to_oai_type(user_mm_message)[0]
|
||||||
|
|
||||||
|
converted_sys = to_oai_type(sys_message, prepend_name=True)[0]
|
||||||
|
converted_asst = to_oai_type(assistant_message, prepend_name=True)[0]
|
||||||
|
converted_text = to_oai_type(user_text_message, prepend_name=True)[0]
|
||||||
|
converted_mm = to_oai_type(user_mm_message, prepend_name=True)[0]
|
||||||
|
|
||||||
|
# Invariants
|
||||||
|
assert "content" in oai_sys
|
||||||
|
assert "content" in oai_asst
|
||||||
|
assert "content" in oai_text
|
||||||
|
assert "content" in oai_mm
|
||||||
|
assert "content" in converted_sys
|
||||||
|
assert "content" in converted_asst
|
||||||
|
assert "content" in converted_text
|
||||||
|
assert "content" in converted_mm
|
||||||
|
assert oai_sys["role"] == converted_sys["role"]
|
||||||
|
assert oai_sys["content"] == converted_sys["content"]
|
||||||
|
assert oai_asst["role"] == converted_asst["role"]
|
||||||
|
assert oai_asst["content"] == converted_asst["content"]
|
||||||
|
assert oai_text["role"] == converted_text["role"]
|
||||||
|
assert oai_mm["role"] == converted_mm["role"]
|
||||||
|
assert isinstance(oai_mm["content"], list)
|
||||||
|
assert isinstance(converted_mm["content"], list)
|
||||||
|
assert len(oai_mm["content"]) == len(converted_mm["content"])
|
||||||
|
assert "text" in converted_mm["content"][0]
|
||||||
|
assert "text" in oai_mm["content"][0]
|
||||||
|
|
||||||
|
# Name prepended
|
||||||
|
assert str(converted_text["content"]) == "Adam said:\n" + str(oai_text["content"])
|
||||||
|
assert str(converted_mm["content"][0]["text"]) == "Adam said:\n" + str(oai_mm["content"][0]["text"])
|
||||||
|
|
||||||
|
|
||||||
# TODO: add integration tests for Azure OpenAI using AAD token.
|
# TODO: add integration tests for Azure OpenAI using AAD token.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user