diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py index b1f23e104..23b46cdb4 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py @@ -137,11 +137,11 @@ def type_to_role(message: LLMMessage) -> ChatCompletionRole: return "tool" -def user_message_to_oai(message: UserMessage) -> ChatCompletionUserMessageParam: +def user_message_to_oai(message: UserMessage, prepend_name: bool = False) -> ChatCompletionUserMessageParam: assert_valid_name(message.source) if isinstance(message.content, str): return ChatCompletionUserMessageParam( - content=message.content, + content=(f"{message.source} said:\n" if prepend_name else "") + message.content, role="user", name=message.source, ) @@ -149,10 +149,18 @@ def user_message_to_oai(message: UserMessage) -> ChatCompletionUserMessageParam: parts: List[ChatCompletionContentPartParam] = [] for part in message.content: if isinstance(part, str): - oai_part = ChatCompletionContentPartTextParam( - text=part, - type="text", - ) + if prepend_name: + # Append the name to the first text part + oai_part = ChatCompletionContentPartTextParam( + text=f"{message.source} said:\n" + part, + type="text", + ) + prepend_name = False + else: + oai_part = ChatCompletionContentPartTextParam( + text=part, + type="text", + ) parts.append(oai_part) elif isinstance(part, Image): # TODO: support url based images @@ -211,11 +219,11 @@ def assistant_message_to_oai( ) -def to_oai_type(message: LLMMessage) -> Sequence[ChatCompletionMessageParam]: +def to_oai_type(message: LLMMessage, prepend_name: bool = False) -> Sequence[ChatCompletionMessageParam]: if isinstance(message, SystemMessage): return [system_message_to_oai(message)] elif isinstance(message, UserMessage): - return [user_message_to_oai(message)] + return [user_message_to_oai(message, prepend_name)] elif isinstance(message, AssistantMessage): return [assistant_message_to_oai(message)] else: @@ -356,8 +364,10 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient): create_args: Dict[str, Any], model_capabilities: Optional[ModelCapabilities] = None, # type: ignore model_info: Optional[ModelInfo] = None, + add_name_prefixes: bool = False, ): self._client = client + self._add_name_prefixes = add_name_prefixes if model_capabilities is None and model_info is None: try: self._model_info = _model_info.get_info(create_args["model"]) @@ -451,7 +461,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient): if self.model_info["json_output"] is False and json_output is True: raise ValueError("Model does not support JSON output.") - oai_messages_nested = [to_oai_type(m) for m in messages] + oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages] oai_messages = [item for sublist in oai_messages_nested for item in sublist] if self.model_info["function_calling"] is False and len(tools) > 0: @@ -672,7 +682,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient): create_args = self._create_args.copy() create_args.update(extra_create_args) - oai_messages_nested = [to_oai_type(m) for m in messages] + oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages] oai_messages = [item for sublist in oai_messages_nested for item in sublist] # TODO: allow custom handling. @@ -874,7 +884,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient): # Message tokens. for message in messages: num_tokens += tokens_per_message - oai_message = to_oai_type(message) + oai_message = to_oai_type(message, prepend_name=self._add_name_prefixes) for oai_message_part in oai_message: for key, value in oai_message_part.items(): if value is None: @@ -992,6 +1002,11 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA top_p (optional, float): user (optional, str): default_headers (optional, dict[str, str]): Custom headers; useful for authentication or other custom requirements. + add_name_prefixes (optional, bool): Whether to prepend the `source` value + to each :class:`~autogen_core.models.UserMessage` content. E.g., + "this is content" becomes "Reviewer said: this is content." + This can be useful for models that do not support the `name` field in + message. Defaults to False. To use this client, you must install the `openai` extension: @@ -1074,11 +1089,19 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA model_info = kwargs["model_info"] del copied_args["model_info"] + add_name_prefixes: bool = False + if "add_name_prefixes" in kwargs: + add_name_prefixes = kwargs["add_name_prefixes"] + client = _openai_client_from_config(copied_args) create_args = _create_args_from_config(copied_args) super().__init__( - client=client, create_args=create_args, model_capabilities=model_capabilities, model_info=model_info + client=client, + create_args=create_args, + model_capabilities=model_capabilities, + model_info=model_info, + add_name_prefixes=add_name_prefixes, ) def __getstate__(self) -> Dict[str, Any]: @@ -1215,11 +1238,19 @@ class AzureOpenAIChatCompletionClient( model_info = kwargs["model_info"] del copied_args["model_info"] + add_name_prefixes: bool = False + if "add_name_prefixes" in kwargs: + add_name_prefixes = kwargs["add_name_prefixes"] + client = _azure_openai_client_from_config(copied_args) create_args = _create_args_from_config(copied_args) self._raw_config: Dict[str, Any] = copied_args super().__init__( - client=client, create_args=create_args, model_capabilities=model_capabilities, model_info=model_info + client=client, + create_args=create_args, + model_capabilities=model_capabilities, + model_info=model_info, + add_name_prefixes=add_name_prefixes, ) def __getstate__(self) -> Dict[str, Any]: diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py index 367564187..b85e7c22c 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py @@ -34,6 +34,7 @@ class BaseOpenAIClientConfiguration(CreateArguments, total=False): max_retries: int model_capabilities: ModelCapabilities # type: ignore model_info: ModelInfo + add_name_prefixes: bool """What functionality the model supports, determined by default from model name but is overriden if value passed.""" default_headers: Dict[str, str] | None @@ -75,6 +76,7 @@ class BaseOpenAIClientConfigurationConfigModel(CreateArgumentsConfigModel): max_retries: int | None = None model_capabilities: ModelCapabilities | None = None # type: ignore model_info: ModelInfo | None = None + add_name_prefixes: bool | None = None default_headers: Dict[str, str] | None = None diff --git a/python/packages/autogen-ext/tests/models/test_openai_model_client.py b/python/packages/autogen-ext/tests/models/test_openai_model_client.py index f2a8ff943..69d928104 100644 --- a/python/packages/autogen-ext/tests/models/test_openai_model_client.py +++ b/python/packages/autogen-ext/tests/models/test_openai_model_client.py @@ -22,7 +22,7 @@ from autogen_core.models._model_client import ModelFamily from autogen_core.tools import BaseTool, FunctionTool from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient from autogen_ext.models.openai._model_info import resolve_model -from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools +from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools, to_oai_type from openai.resources.beta.chat.completions import AsyncCompletions as BetaAsyncCompletions from openai.resources.chat.completions import AsyncCompletions from openai.types.chat.chat_completion import ChatCompletion, Choice @@ -1050,4 +1050,56 @@ async def test_ollama() -> None: assert chunks[-1].thought is not None +@pytest.mark.asyncio +async def test_add_name_prefixes(monkeypatch: pytest.MonkeyPatch) -> None: + sys_message = SystemMessage(content="You are a helpful AI agent, and you answer questions in a friendly way.") + assistant_message = AssistantMessage(content="Hello, how can I help you?", source="Assistant") + user_text_message = UserMessage(content="Hello, I am from Seattle.", source="Adam") + user_mm_message = UserMessage( + content=[ + "Here is a postcard from Seattle:", + Image.from_base64( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4z8AAAAMBAQDJ/pLvAAAAAElFTkSuQmCC" + ), + ], + source="Adam", + ) + + # Default conversion + oai_sys = to_oai_type(sys_message)[0] + oai_asst = to_oai_type(assistant_message)[0] + oai_text = to_oai_type(user_text_message)[0] + oai_mm = to_oai_type(user_mm_message)[0] + + converted_sys = to_oai_type(sys_message, prepend_name=True)[0] + converted_asst = to_oai_type(assistant_message, prepend_name=True)[0] + converted_text = to_oai_type(user_text_message, prepend_name=True)[0] + converted_mm = to_oai_type(user_mm_message, prepend_name=True)[0] + + # Invariants + assert "content" in oai_sys + assert "content" in oai_asst + assert "content" in oai_text + assert "content" in oai_mm + assert "content" in converted_sys + assert "content" in converted_asst + assert "content" in converted_text + assert "content" in converted_mm + assert oai_sys["role"] == converted_sys["role"] + assert oai_sys["content"] == converted_sys["content"] + assert oai_asst["role"] == converted_asst["role"] + assert oai_asst["content"] == converted_asst["content"] + assert oai_text["role"] == converted_text["role"] + assert oai_mm["role"] == converted_mm["role"] + assert isinstance(oai_mm["content"], list) + assert isinstance(converted_mm["content"], list) + assert len(oai_mm["content"]) == len(converted_mm["content"]) + assert "text" in converted_mm["content"][0] + assert "text" in oai_mm["content"][0] + + # Name prepended + assert str(converted_text["content"]) == "Adam said:\n" + str(oai_text["content"]) + assert str(converted_mm["content"][0]["text"]) == "Adam said:\n" + str(oai_mm["content"][0]["text"]) + + # TODO: add integration tests for Azure OpenAI using AAD token.