From 7c4f8d1107df3ecca486175998b39d1b5f0f3957 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Mon, 12 May 2025 21:51:53 -0700 Subject: [PATCH] Add option for openai client to avoid setting reasoning tokens as assistant message content when sending to the model api. --- .../models/openai/_openai_client.py | 30 +++++++++++++++++-- .../models/openai/config/__init__.py | 5 ++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py index ffe816e59..9300e13af 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py @@ -162,10 +162,15 @@ def type_to_role(message: LLMMessage) -> ChatCompletionRole: def to_oai_type( - message: LLMMessage, prepend_name: bool = False, model: str = "unknown", model_family: str = ModelFamily.UNKNOWN + message: LLMMessage, + prepend_name: bool = False, + set_thought_as_assistant_content: bool = False, + model: str = "unknown", + model_family: str = ModelFamily.UNKNOWN, ) -> Sequence[ChatCompletionMessageParam]: context = { "prepend_name": prepend_name, + "set_thought_as_assistant_content": set_thought_as_assistant_content, } transformers = get_transformer("openai", model, model_family) @@ -279,6 +284,7 @@ def count_tokens_openai( model: str, *, add_name_prefixes: bool = False, + set_thought_as_assistant_content: bool = False, tools: Sequence[Tool | ToolSchema] = [], model_family: str = ModelFamily.UNKNOWN, ) -> int: @@ -294,7 +300,13 @@ def count_tokens_openai( # Message tokens. 
for message in messages: num_tokens += tokens_per_message - oai_message = to_oai_type(message, prepend_name=add_name_prefixes, model=model, model_family=model_family) + oai_message = to_oai_type( + message, + prepend_name=add_name_prefixes, + set_thought_as_assistant_content=set_thought_as_assistant_content, + model=model, + model_family=model_family, + ) for oai_message_part in oai_message: for key, value in oai_message_part.items(): if value is None: @@ -387,9 +399,11 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient): model_capabilities: Optional[ModelCapabilities] = None, # type: ignore model_info: Optional[ModelInfo] = None, add_name_prefixes: bool = False, + set_thought_as_assistant_content: bool = False, ): self._client = client self._add_name_prefixes = add_name_prefixes + self._set_thought_as_assistant_content = set_thought_as_assistant_content if model_capabilities is None and model_info is None: try: self._model_info = _model_info.get_info(create_args["model"]) @@ -562,6 +576,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient): to_oai_type( m, prepend_name=self._add_name_prefixes, + set_thought_as_assistant_content=self._set_thought_as_assistant_content, model=create_args.get("model", "unknown"), model_family=self._model_info["family"], ) @@ -1056,6 +1071,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient): messages, self._create_args["model"], add_name_prefixes=self._add_name_prefixes, + set_thought_as_assistant_content=self._set_thought_as_assistant_content, tools=tools, model_family=self._model_info["family"], ) @@ -1357,6 +1373,10 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA if "add_name_prefixes" in kwargs: add_name_prefixes = kwargs["add_name_prefixes"] + set_thought_as_assistant_content: bool = False + if "set_thought_as_assistant_content" in kwargs: + set_thought_as_assistant_content = kwargs["set_thought_as_assistant_content"] + # Special handling for Gemini model. 
assert "model" in copied_args and isinstance(copied_args["model"], str) if copied_args["model"].startswith("gemini-"): @@ -1379,6 +1399,7 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA model_capabilities=model_capabilities, model_info=model_info, add_name_prefixes=add_name_prefixes, + set_thought_as_assistant_content=set_thought_as_assistant_content, ) def __getstate__(self) -> Dict[str, Any]: @@ -1571,6 +1592,10 @@ class AzureOpenAIChatCompletionClient( if "add_name_prefixes" in kwargs: add_name_prefixes = kwargs["add_name_prefixes"] + set_thought_as_assistant_content: bool = False + if "set_thought_as_assistant_content" in kwargs: + set_thought_as_assistant_content = kwargs["set_thought_as_assistant_content"] + client = _azure_openai_client_from_config(copied_args) create_args = _create_args_from_config(copied_args) self._raw_config: Dict[str, Any] = copied_args @@ -1580,6 +1605,7 @@ class AzureOpenAIChatCompletionClient( model_capabilities=model_capabilities, model_info=model_info, add_name_prefixes=add_name_prefixes, + set_thought_as_assistant_content=set_thought_as_assistant_content, ) def __getstate__(self) -> Dict[str, Any]: diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py index a12510525..b585a9b17 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py @@ -63,6 +63,10 @@ class BaseOpenAIClientConfiguration(CreateArguments, total=False): model_info: ModelInfo add_name_prefixes: bool """What functionality the model supports, determined by default from model name but is overriden if value passed.""" + set_thought_as_assistant_content: bool + """Whether to set the thought as assistant content. If true, the thought will be set as the assistant content. + If false, the thought will be ignored. 
Useful for reasoning models that produce reasoning tokens as output, but want to + avoid having them in the input.""" default_headers: Dict[str, str] | None @@ -105,6 +109,7 @@ class BaseOpenAIClientConfigurationConfigModel(CreateArgumentsConfigModel): model_capabilities: ModelCapabilities | None = None # type: ignore model_info: ModelInfo | None = None add_name_prefixes: bool | None = None + set_thought_as_assistant_content: bool | None = None default_headers: Dict[str, str] | None = None