Add option for openai client to avoid setting reasoning tokens as assistant message content when sending to the model api.

This commit is contained in:
Eric Zhu 2025-05-12 21:51:53 -07:00
parent 978cbd2e89
commit 7c4f8d1107
2 changed files with 33 additions and 2 deletions

View File

@ -162,10 +162,15 @@ def type_to_role(message: LLMMessage) -> ChatCompletionRole:
def to_oai_type(
message: LLMMessage, prepend_name: bool = False, model: str = "unknown", model_family: str = ModelFamily.UNKNOWN
message: LLMMessage,
prepend_name: bool = False,
set_thought_as_assistant_content: bool = False,
model: str = "unknown",
model_family: str = ModelFamily.UNKNOWN,
) -> Sequence[ChatCompletionMessageParam]:
context = {
"prepend_name": prepend_name,
"set_thought_as_assistant_content": set_thought_as_assistant_content,
}
transformers = get_transformer("openai", model, model_family)
@ -279,6 +284,7 @@ def count_tokens_openai(
model: str,
*,
add_name_prefixes: bool = False,
set_thought_as_assistant_content: bool = False,
tools: Sequence[Tool | ToolSchema] = [],
model_family: str = ModelFamily.UNKNOWN,
) -> int:
@ -294,7 +300,13 @@ def count_tokens_openai(
# Message tokens.
for message in messages:
num_tokens += tokens_per_message
oai_message = to_oai_type(message, prepend_name=add_name_prefixes, model=model, model_family=model_family)
oai_message = to_oai_type(
message,
prepend_name=add_name_prefixes,
set_thought_as_assistant_content=set_thought_as_assistant_content,
model=model,
model_family=model_family,
)
for oai_message_part in oai_message:
for key, value in oai_message_part.items():
if value is None:
@ -387,9 +399,11 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
model_capabilities: Optional[ModelCapabilities] = None, # type: ignore
model_info: Optional[ModelInfo] = None,
add_name_prefixes: bool = False,
set_thought_as_assistant_content: bool = False,
):
self._client = client
self._add_name_prefixes = add_name_prefixes
self._set_thought_as_assistant_content = set_thought_as_assistant_content
if model_capabilities is None and model_info is None:
try:
self._model_info = _model_info.get_info(create_args["model"])
@ -562,6 +576,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
to_oai_type(
m,
prepend_name=self._add_name_prefixes,
set_thought_as_assistant_content=self._set_thought_as_assistant_content,
model=create_args.get("model", "unknown"),
model_family=self._model_info["family"],
)
@ -1056,6 +1071,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
messages,
self._create_args["model"],
add_name_prefixes=self._add_name_prefixes,
set_thought_as_assistant_content=self._set_thought_as_assistant_content,
tools=tools,
model_family=self._model_info["family"],
)
@ -1357,6 +1373,10 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
if "add_name_prefixes" in kwargs:
add_name_prefixes = kwargs["add_name_prefixes"]
set_thought_as_assistant_content: bool = False
if "set_thought_as_assistant_content" in kwargs:
set_thought_as_assistant_content = kwargs["set_thought_as_assistant_content"]
# Special handling for Gemini model.
assert "model" in copied_args and isinstance(copied_args["model"], str)
if copied_args["model"].startswith("gemini-"):
@ -1379,6 +1399,7 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
model_capabilities=model_capabilities,
model_info=model_info,
add_name_prefixes=add_name_prefixes,
set_thought_as_assistant_content=set_thought_as_assistant_content,
)
def __getstate__(self) -> Dict[str, Any]:
@ -1571,6 +1592,10 @@ class AzureOpenAIChatCompletionClient(
if "add_name_prefixes" in kwargs:
add_name_prefixes = kwargs["add_name_prefixes"]
set_thought_as_assistant_content: bool = False
if "set_thought_as_assistant_content" in kwargs:
set_thought_as_assistant_content = kwargs["set_thought_as_assistant_content"]
client = _azure_openai_client_from_config(copied_args)
create_args = _create_args_from_config(copied_args)
self._raw_config: Dict[str, Any] = copied_args
@ -1580,6 +1605,7 @@ class AzureOpenAIChatCompletionClient(
model_capabilities=model_capabilities,
model_info=model_info,
add_name_prefixes=add_name_prefixes,
set_thought_as_assistant_content=set_thought_as_assistant_content,
)
def __getstate__(self) -> Dict[str, Any]:

View File

@ -63,6 +63,10 @@ class BaseOpenAIClientConfiguration(CreateArguments, total=False):
model_info: ModelInfo
add_name_prefixes: bool
"""What functionality the model supports, determined by default from model name but is overriden if value passed."""
set_thought_as_assistant_content: bool
"""Whether to set the thought as assistant content. If true, the thought will be set as the assistant content.
If false, the thought will be ignored. Useful for reasoning models that produce reasoning tokens as output, but wants to
avoid having them in the input."""
default_headers: Dict[str, str] | None
@ -105,6 +109,7 @@ class BaseOpenAIClientConfigurationConfigModel(CreateArgumentsConfigModel):
model_capabilities: ModelCapabilities | None = None # type: ignore
model_info: ModelInfo | None = None
add_name_prefixes: bool | None = None
set_thought_as_assistant_content: bool | None = None
default_headers: Dict[str, str] | None = None