Fix R1 reasoning parser for openai client (#5961)

Reasoning tokens returned by the hosted R1 model were not parsed correctly by the OpenAI client: the reasoning_content field in the API response was dropped instead of being surfaced as the thought field of the CreateResult.

Resolves #5941
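
For context, hosted R1 endpoints (such as the DeepSeek API) return the chain of thought in a non-standard reasoning_content field alongside the regular message content; the OpenAI SDK keeps unknown fields like this in model_extra rather than as typed attributes. A minimal sketch of the response shape this fix handles (the field values are illustrative, not taken from a real API response):

    # Shape of a chat.completion choice from a hosted R1 endpoint.
    # reasoning_content is not part of the OpenAI schema, so the SDK
    # exposes it via message.model_extra instead of a typed attribute.
    choice = {
        "index": 0,
        "finish_reason": "stop",
        "message": {
            "role": "assistant",
            "content": "The answer is 42.",  # -> CreateResult.content
            "reasoning_content": "First, consider...",  # -> CreateResult.thought
        },
    }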

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
ZakWork 2025-03-17 17:09:41 +00:00 committed by GitHub
parent e5ab7d55cf
commit 685142cf51
2 changed files with 169 additions and 21 deletions

File: autogen_ext/models/openai/_openai_client.py

@@ -382,7 +382,11 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         elif model_capabilities is not None and model_info is not None:
             raise ValueError("model_capabilities and model_info are mutually exclusive")
         elif model_capabilities is not None and model_info is None:
-            warnings.warn("model_capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
+            warnings.warn(
+                "model_capabilities is deprecated, use model_info instead",
+                DeprecationWarning,
+                stacklevel=2,
+            )
             info = cast(ModelInfo, model_capabilities)
             info["family"] = ModelFamily.UNKNOWN
             self._model_info = info
@@ -528,7 +532,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
             future = asyncio.ensure_future(
                 self._client.beta.chat.completions.parse(
                     messages=create_params.messages,
-                    tools=create_params.tools if len(create_params.tools) > 0 else NOT_GIVEN,
+                    tools=(create_params.tools if len(create_params.tools) > 0 else NOT_GIVEN),
                     response_format=create_params.response_format,
                     **create_params.create_args,
                 )
@@ -539,7 +543,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
                 self._client.chat.completions.create(
                     messages=create_params.messages,
                     stream=False,
-                    tools=create_params.tools if len(create_params.tools) > 0 else NOT_GIVEN,
+                    tools=(create_params.tools if len(create_params.tools) > 0 else NOT_GIVEN),
                     **create_params.create_args,
                 )
             )
@@ -615,8 +619,14 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
             )
             finish_reason = "tool_calls"
         else:
+            # if not tool_calls, then it is a text response and we populate the content and thought fields.
             finish_reason = choice.finish_reason
             content = choice.message.content or ""
+            # if there is a reasoning_content field, then we populate the thought field. This is for models such as R1 - direct from deepseek api.
+            if choice.message.model_extra is not None:
+                reasoning_content = choice.message.model_extra.get("reasoning_content")
+                if reasoning_content is not None:
+                    thought = reasoning_content
 
         logprobs: Optional[List[ChatCompletionTokenLogprob]] = None
         if choice.logprobs and choice.logprobs.content:
@@ -630,7 +640,8 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
                 for x in choice.logprobs.content
             ]
 
-        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
+        # This is for local R1 models.
+        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1 and thought is None:
             thought, content = parse_r1_content(content)
 
         response = CreateResult(
@@ -725,6 +736,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         stop_reason = None
         maybe_model = None
         content_deltas: List[str] = []
+        thought_deltas: List[str] = []
         full_tool_calls: Dict[int, FunctionCall] = {}
         completion_tokens = 0
         logprobs: Optional[List[ChatCompletionTokenLogprob]] = None
@@ -767,9 +779,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
             choice = (
                 chunk.choices[0]
                 if len(chunk.choices) > 0
-                else choice
-                if chunk.usage is not None and stop_reason is not None
-                else cast(ChunkChoice, None)
+                else (choice if chunk.usage is not None and stop_reason is not None else cast(ChunkChoice, None))
             )
 
             # for liteLLM chunk usage, do the following hack keeping the pervious chunk.stop_reason (if set).
@@ -784,7 +794,12 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
                 # NOTE: for OpenAI, tool_calls and content are mutually exclusive it seems, so we can skip the rest of the loop.
                 # However, this may not be the case for other APIs -- we should expect this may need to be updated.
                 continue
 
+            # if there is a reasoning_content field, then we populate the thought field. This is for models such as R1.
+            if choice.delta.model_extra is not None:
+                reasoning_content = choice.delta.model_extra.get("reasoning_content")
+                if reasoning_content is not None:
+                    thought_deltas.append(reasoning_content)
+                    yield reasoning_content
+
             # Otherwise, get tool calls
             if choice.delta.tool_calls is not None:
                 for tool_call_chunk in choice.delta.tool_calls:
@@ -837,21 +852,30 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         # Detect whether it is a function call or just text.
         content: Union[str, List[FunctionCall]]
         thought: str | None = None
+        # Determine the content and thought based on what was collected
         if full_tool_calls:
-            # This is a tool call.
+            # This is a tool call response
             content = list(full_tool_calls.values())
-            if len(content_deltas) > 1:
-                # Put additional text content in the thought field.
+            if content_deltas:
+                # Store any text alongside tool calls as thoughts
                 thought = "".join(content_deltas)
-        elif len(content_deltas) > 0:
-            # This is a text-only content.
-            content = "".join(content_deltas)
         else:
-            warnings.warn("No text content or tool calls are available. Model returned empty result.", stacklevel=2)
-            content = ""
+            # This is a text response (possibly with thoughts)
+            if content_deltas:
+                content = "".join(content_deltas)
+            else:
+                warnings.warn(
+                    "No text content or tool calls are available. Model returned empty result.",
+                    stacklevel=2,
+                )
+                content = ""
+
+        # Always set thoughts if we have any, regardless of other content types
+        if thought_deltas:
+            thought = "".join(thought_deltas)
 
-        # Parse R1 content if needed.
-        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
+        # This is for local R1 models.
+        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1 and thought is None:
             thought, content = parse_r1_content(content)
 
         # Create the result.
@@ -919,7 +943,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         async with self._client.beta.chat.completions.stream(
             messages=oai_messages,
             tools=tool_params if len(tool_params) > 0 else NOT_GIVEN,
-            response_format=response_format if response_format is not None else NOT_GIVEN,
+            response_format=(response_format if response_format is not None else NOT_GIVEN),
             **create_args_no_response_format,
         ) as stream:
             while True:
@@ -1044,7 +1068,11 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
     @property
     def capabilities(self) -> ModelCapabilities:  # type: ignore
-        warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
+        warnings.warn(
+            "capabilities is deprecated, use model_info instead",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         return self._model_info
 
     @property
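
In practice, the change means callers now receive the reasoning through both client paths: create() returns it in CreateResult.thought, and create_stream() yields the reasoning deltas before the content and then sets thought on the final CreateResult. A minimal usage sketch, assuming a DeepSeek-hosted R1 endpoint (the model name, base_url, and api_key below are placeholders, not tested values):

    import asyncio

    from autogen_core.models import ModelFamily, UserMessage
    from autogen_ext.models.openai import OpenAIChatCompletionClient

    async def main() -> None:
        client = OpenAIChatCompletionClient(
            model="deepseek-reasoner",  # placeholder hosted R1 model name
            base_url="https://api.deepseek.com/v1",  # placeholder endpoint
            api_key="YOUR_KEY",  # placeholder
            model_info={
                "family": ModelFamily.R1,
                "vision": False,
                "function_calling": False,
                "json_output": False,
                "structured_output": False,
            },
        )
        result = await client.create([UserMessage(content="Why is the sky blue?", source="user")])
        print(result.thought)  # reasoning_content from the API, if present
        print(result.content)  # the final answer text

    asyncio.run(main())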

File: tests/models/test_openai_model_client.py

@@ -23,7 +23,11 @@ from autogen_core.models._model_client import ModelFamily
 from autogen_core.tools import BaseTool, FunctionTool
 from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
 from autogen_ext.models.openai._model_info import resolve_model
-from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools, to_oai_type
+from autogen_ext.models.openai._openai_client import (
+    calculate_vision_tokens,
+    convert_tools,
+    to_oai_type,
+)
 from openai.resources.beta.chat.completions import (  # type: ignore
     AsyncChatCompletionStreamManager as BetaAsyncChatCompletionStreamManager,  # type: ignore
 )
@@ -886,6 +890,122 @@ async def test_structured_output_with_streaming_tool_calls(monkeypatch: pytest.MonkeyPatch
     assert response.response == "happy"
 
 
+@pytest.mark.asyncio
+async def test_r1_reasoning_content(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Test handling of reasoning_content in R1 model. Testing create without streaming."""
+
+    async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion:
+        return ChatCompletion(
+            id="test_id",
+            model="r1",
+            object="chat.completion",
+            created=1234567890,
+            choices=[
+                Choice(
+                    index=0,
+                    message=ChatCompletionMessage(
+                        role="assistant",
+                        content="This is the main content",
+                        # The reasoning content is included in model_extra for hosted R1 models.
+                        reasoning_content="This is the reasoning content",  # type: ignore
+                    ),
+                    finish_reason="stop",
+                )
+            ],
+            usage=CompletionUsage(
+                prompt_tokens=10,
+                completion_tokens=10,
+                total_tokens=20,
+            ),
+        )
+
+    # Patch the client creation
+    monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
+
+    # Create the client
+    model_client = OpenAIChatCompletionClient(
+        model="r1",
+        api_key="",
+        model_info={
+            "family": ModelFamily.R1,
+            "vision": False,
+            "function_calling": False,
+            "json_output": False,
+            "structured_output": False,
+        },
+    )
+
+    # Test the create method
+    result = await model_client.create([UserMessage(content="Test message", source="user")])
+
+    # Verify that the content and thought are as expected
+    assert result.content == "This is the main content"
+    assert result.thought == "This is the reasoning content"
+
+
+@pytest.mark.asyncio
+async def test_r1_reasoning_content_streaming(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Test that reasoning_content in model_extra is correctly extracted and streamed."""
+
+    async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
+        contentChunks = [None, None, "This is the main content"]
+        reasoningChunks = ["This is the reasoning content 1", "This is the reasoning content 2", None]
+        for i in range(len(contentChunks)):
+            await asyncio.sleep(0.1)
+            yield ChatCompletionChunk(
+                id="id",
+                choices=[
+                    ChunkChoice(
+                        finish_reason="stop" if i == len(contentChunks) - 1 else None,
+                        index=0,
+                        delta=ChoiceDelta(
+                            content=contentChunks[i],
+                            # The reasoning content is included in model_extra for hosted R1 models.
+                            reasoning_content=reasoningChunks[i],  # type: ignore
+                            role="assistant",
+                        ),
+                    ),
+                ],
+                created=0,
+                model="r1",
+                object="chat.completion.chunk",
+                usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+            )
+
+    async def _mock_create(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
+        return _mock_create_stream(*args, **kwargs)
+
+    # Patch the client creation
+    monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
+
+    # Create the client
+    model_client = OpenAIChatCompletionClient(
+        model="r1",
+        api_key="",
+        model_info={
+            "family": ModelFamily.R1,
+            "vision": False,
+            "function_calling": False,
+            "json_output": False,
+            "structured_output": False,
+        },
+    )
+
+    # Test the create_stream method
+    chunks: List[str | CreateResult] = []
+    async for chunk in model_client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
+        chunks.append(chunk)
+
+    # Verify that the chunks first stream the reasoning content and then the main content.
+    # Then verify that the final result has the correct content and thought.
+    assert len(chunks) == 4
+    assert chunks[0] == "This is the reasoning content 1"
+    assert chunks[1] == "This is the reasoning content 2"
+    assert chunks[2] == "This is the main content"
+    assert isinstance(chunks[3], CreateResult)
+    assert chunks[3].content == "This is the main content"
+    assert chunks[3].thought == "This is the reasoning content 1This is the reasoning content 2"
+
+
 @pytest.mark.asyncio
 async def test_r1_think_field(monkeypatch: pytest.MonkeyPatch) -> None:
     async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]: