feat: Support R1 reasoning text in model create result; enhance API docs (#5262)

Resolves #5255 

---------

Co-authored-by: afourney <adamfo@microsoft.com>
Eric Zhu, 2025-01-30 11:03:54 -08:00, committed by GitHub
commit f656ff1e01, parent 44db2cc1fb
12 changed files with 536 additions and 9 deletions

View File

@@ -22,9 +22,10 @@ class ModelFamily:
     O1 = "o1"
     GPT_4 = "gpt-4"
     GPT_35 = "gpt-35"
+    R1 = "r1"
     UNKNOWN = "unknown"

-    ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "unknown"]
+    ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "r1", "unknown"]

     def __new__(cls, *args: Any, **kwargs: Any) -> ModelFamily:
         raise TypeError(f"{cls.__name__} is a namespace class and cannot be instantiated.")

View File

@@ -8,11 +8,25 @@ from .. import FunctionCall, Image
 class SystemMessage(BaseModel):
+    """System message contains instructions for the model coming from the developer.
+
+    .. note::
+
+        OpenAI is moving away from using the 'system' role in favor of the 'developer' role.
+        See the `Model Spec <https://cdn.openai.com/spec/model-spec-2024-05-08.html#definitions>`_ for more details.
+        However, the 'system' role is still allowed in their API and is automatically converted to the
+        'developer' role on the server side, so you can use `SystemMessage` for developer messages.
+    """
+
     content: str
     type: Literal["SystemMessage"] = "SystemMessage"


 class UserMessage(BaseModel):
+    """User message contains input from end users, or a catch-all for data provided to the model."""
+
     content: Union[str, List[Union[str, Image]]]

     # Name of the agent that sent this message

@@ -22,6 +36,8 @@ class UserMessage(BaseModel):
 class AssistantMessage(BaseModel):
+    """Assistant messages are sampled from the language model."""
+
     content: Union[str, List[FunctionCall]]

     # Name of the agent that sent this message

@@ -31,11 +47,15 @@ class AssistantMessage(BaseModel):
 class FunctionExecutionResult(BaseModel):
+    """Function execution result contains the output of a function call."""
+
     content: str
     call_id: str


 class FunctionExecutionResultMessage(BaseModel):
+    """Function execution result message contains the output of multiple function calls."""
+
     content: List[FunctionExecutionResult]

     type: Literal["FunctionExecutionResultMessage"] = "FunctionExecutionResultMessage"

@@ -69,8 +89,23 @@ class ChatCompletionTokenLogprob(BaseModel):
 class CreateResult(BaseModel):
+    """Create result contains the output of a model completion."""
+
     finish_reason: FinishReasons
+    """The reason the model finished generating the completion."""

     content: Union[str, List[FunctionCall]]
+    """The output of the model completion."""

     usage: RequestUsage
+    """The usage of tokens in the prompt and completion."""

     cached: bool
+    """Whether the completion was generated from a cached response."""

     logprobs: Optional[List[ChatCompletionTokenLogprob] | None] = None
+    """The logprobs of the tokens in the completion."""
+
+    thought: Optional[str] = None
+    """The reasoning text for the completion if available. Used for reasoning models
+    and for additional text content besides function calls."""
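
For context while reading this diff, here is a minimal sketch (not part of the commit; the values are invented for illustration) of what the new `thought` field looks like on a `CreateResult`. In practice these objects are returned by a `ChatCompletionClient` rather than built by hand:

    from autogen_core.models import CreateResult, RequestUsage

    # Illustrative only: a hand-built result showing the new field. A client serving an
    # R1-family model fills `thought` with the text between <think>...</think> and strips
    # that block from `content`.
    result = CreateResult(
        finish_reason="stop",
        content="Paris.",
        usage=RequestUsage(prompt_tokens=12, completion_tokens=3),
        cached=False,
        thought="The user asked for the capital of France; that is Paris.",
    )
    print(result.thought)  # reasoning text, or None for non-R1 models
    print(result.content)  # the user-visible completion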

View File

@@ -120,6 +120,7 @@ dev = [
     "autogen_test_utils",
     "langchain-experimental",
     "pandas-stubs>=2.2.3.241126",
+    "httpx>=0.28.1",
 ]

 [tool.ruff]

View File

@@ -0,0 +1,33 @@
+import warnings
+from typing import Tuple
+
+
+def parse_r1_content(content: str) -> Tuple[str | None, str]:
+    """Parse the content of an R1-style message that contains a `<think>...</think>` field."""
+    # Find the start and end of the think field
+    think_start = content.find("<think>")
+    think_end = content.find("</think>")
+
+    if think_start == -1 or think_end == -1:
+        warnings.warn(
+            "Could not find <think>..</think> field in model response content. No thought was extracted.",
+            UserWarning,
+            stacklevel=2,
+        )
+        return None, content
+
+    if think_end < think_start:
+        warnings.warn(
+            "Found </think> before <think> in model response content. No thought was extracted.",
+            UserWarning,
+            stacklevel=2,
+        )
+        return None, content
+
+    # Extract the think field
+    thought = content[think_start + len("<think>") : think_end].strip()
+
+    # Extract the rest of the content, skipping the think field.
+    content = content[think_end + len("</think>") :].strip()
+
+    return thought, content
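
To make the contract of this helper concrete, a short usage sketch (not part of the commit; the example strings are invented, and the import path matches the new module added above):

    from autogen_ext.models._utils.parse_r1_content import parse_r1_content

    # The <think>...</think> block becomes the thought; the remainder is the content.
    thought, content = parse_r1_content("<think>10 + 20 = 30 balls in total.</think> The answer is 30.")
    assert thought == "10 + 20 = 30 balls in total."
    assert content == "The answer is 30."

    # Without a think block, the content is returned unchanged, the thought is None,
    # and a UserWarning is emitted.
    thought, content = parse_r1_content("The answer is 30.")
    assert thought is None
    assert content == "The answer is 30."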

View File

@@ -12,6 +12,7 @@ from autogen_core.models import (
     FinishReasons,
     FunctionExecutionResultMessage,
     LLMMessage,
+    ModelFamily,
     ModelInfo,
     RequestUsage,
     SystemMessage,

@@ -55,6 +56,8 @@ from autogen_ext.models.azure.config import (
     AzureAIChatCompletionClientConfig,
 )

+from .._utils.parse_r1_content import parse_r1_content
+
 create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs)
 AzureMessage = Union[AzureSystemMessage, AzureUserMessage, AzureAssistantMessage, AzureToolMessage]

@@ -354,11 +357,17 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
         finish_reason = choice.finish_reason  # type: ignore
         content = choice.message.content or ""

+        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
+            thought, content = parse_r1_content(content)
+        else:
+            thought = None
+
         response = CreateResult(
             finish_reason=finish_reason,  # type: ignore
             content=content,
             usage=usage,
             cached=False,
+            thought=thought,
         )

         self.add_usage(usage)

@@ -464,11 +473,17 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
             prompt_tokens=prompt_tokens,
         )

+        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
+            thought, content = parse_r1_content(content)
+        else:
+            thought = None
+
         result = CreateResult(
             finish_reason=finish_reason,
             content=content,
             usage=usage,
             cached=False,
+            thought=thought,
         )

         self.add_usage(usage)

View File

@@ -72,6 +72,7 @@ from openai.types.shared_params import FunctionDefinition, FunctionParameters
 from pydantic import BaseModel
 from typing_extensions import Self, Unpack

+from .._utils.parse_r1_content import parse_r1_content
 from . import _model_info
 from .config import (
     AzureOpenAIClientConfiguration,

@@ -605,12 +606,19 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
                 )
                 for x in choice.logprobs.content
             ]

+        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
+            thought, content = parse_r1_content(content)
+        else:
+            thought = None
+
         response = CreateResult(
             finish_reason=normalize_stop_reason(finish_reason),
             content=content,
             usage=usage,
             cached=False,
             logprobs=logprobs,
+            thought=thought,
         )

         self._total_usage = _add_usage(self._total_usage, usage)

@@ -818,12 +826,18 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
             completion_tokens=completion_tokens,
         )

+        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
+            thought, content = parse_r1_content(content)
+        else:
+            thought = None
+
         result = CreateResult(
             finish_reason=normalize_stop_reason(stop_reason),
             content=content,
             usage=usage,
             cached=False,
             logprobs=logprobs,
+            thought=thought,
         )

         self._total_usage = _add_usage(self._total_usage, usage)

@@ -992,20 +1006,23 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
             print(result)

-    To use the client with a non-OpenAI model, you need to provide the base URL of the model and the model capabilities:
+    To use the client with a non-OpenAI model, you need to provide the base URL of the model and the model info.
+    For example, to use Ollama, you can use the following code snippet:

     .. code-block:: python

         from autogen_ext.models.openai import OpenAIChatCompletionClient
+        from autogen_core.models import ModelFamily

         custom_model_client = OpenAIChatCompletionClient(
-            model="custom-model-name",
-            base_url="https://custom-model.com/reset/of/the/path",
+            model="deepseek-r1:1.5b",
+            base_url="http://localhost:11434/v1",
             api_key="placeholder",
-            model_capabilities={
-                "vision": True,
-                "function_calling": True,
-                "json_output": True,
+            model_info={
+                "vision": False,
+                "function_calling": False,
+                "json_output": False,
+                "family": ModelFamily.R1,
             },
         )
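
A short follow-up sketch (not part of this commit) showing how the R1 handling introduced in this PR surfaces the reasoning text when calling a client configured as above; it assumes a local Ollama server with deepseek-r1:1.5b pulled, as in the docstring example:

    import asyncio

    from autogen_core.models import ModelFamily, UserMessage
    from autogen_ext.models.openai import OpenAIChatCompletionClient


    async def main() -> None:
        client = OpenAIChatCompletionClient(
            model="deepseek-r1:1.5b",
            base_url="http://localhost:11434/v1",
            api_key="placeholder",
            model_info={
                "vision": False,
                "function_calling": False,
                "json_output": False,
                "family": ModelFamily.R1,
            },
        )
        result = await client.create(messages=[UserMessage(content="Why is the sky blue?", source="user")])
        print("thought:", result.thought)  # text extracted from <think>...</think>
        print("content:", result.content)  # the answer with the think block removed


    asyncio.run(main())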

View File

@@ -25,6 +25,8 @@ from typing_extensions import AsyncGenerator, Union
 from autogen_ext.tools.semantic_kernel import KernelFunctionFromTool

+from .._utils.parse_r1_content import parse_r1_content
+

 class SKChatCompletionAdapter(ChatCompletionClient):
     """

@@ -402,11 +404,17 @@ class SKChatCompletionAdapter(ChatCompletionClient):
         content = result[0].content
         finish_reason = "stop"

+        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
+            thought, content = parse_r1_content(content)
+        else:
+            thought = None
+
         return CreateResult(
             content=content,
             finish_reason=finish_reason,
             usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
             cached=False,
+            thought=thought,
         )

     async def create_stream(

@@ -485,11 +493,18 @@ class SKChatCompletionAdapter(ChatCompletionClient):
         if accumulated_content:
             self._total_prompt_tokens += prompt_tokens
             self._total_completion_tokens += completion_tokens

+            if isinstance(accumulated_content, str) and self._model_info["family"] == ModelFamily.R1:
+                thought, accumulated_content = parse_r1_content(accumulated_content)
+            else:
+                thought = None
+
             yield CreateResult(
                 content=accumulated_content,
                 finish_reason="stop",
                 usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
                 cached=False,
+                thought=thought,
             )

     def actual_usage(self) -> RequestUsage:

View File

@@ -5,7 +5,7 @@ from typing import Any, AsyncGenerator, List

 import pytest
 from autogen_core import CancellationToken, FunctionCall, Image
-from autogen_core.models import CreateResult, UserMessage
+from autogen_core.models import CreateResult, ModelFamily, UserMessage
 from autogen_ext.models.azure import AzureAIChatCompletionClient
 from azure.ai.inference.aio import (
     ChatCompletionsClient,

@@ -295,3 +295,82 @@ async def test_multimodal_supported(monkeypatch: pytest.MonkeyPatch) -> None:
         ]
     )
     assert result.content == "Handled image"
+
+
+@pytest.mark.asyncio
+async def test_r1_content(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Ensures that the content is parsed correctly when it contains an R1-style think field."""
+
+    async def _mock_create_r1_content_stream(
+        *args: Any, **kwargs: Any
+    ) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]:
+        mock_chunks_content = ["<think>Thought</think> Hello", " Another Hello", " Yet Another Hello"]
+
+        mock_chunks = [
+            StreamingChatChoiceUpdate(
+                index=0,
+                finish_reason="stop",
+                delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_content),
+            )
+            for chunk_content in mock_chunks_content
+        ]
+
+        for mock_chunk in mock_chunks:
+            await asyncio.sleep(0.1)
+            yield StreamingChatCompletionsUpdate(
+                id="id",
+                choices=[mock_chunk],
+                created=datetime.now(),
+                model="model",
+                usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+            )
+
+    async def _mock_create_r1_content(
+        *args: Any, **kwargs: Any
+    ) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]:
+        stream = kwargs.get("stream", False)
+
+        if not stream:
+            await asyncio.sleep(0.1)
+            return ChatCompletions(
+                id="id",
+                created=datetime.now(),
+                model="model",
+                choices=[
+                    ChatChoice(
+                        index=0,
+                        finish_reason="stop",
+                        message=ChatResponseMessage(content="<think>Thought</think> Hello", role="assistant"),
+                    )
+                ],
+                usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+            )
+        else:
+            return _mock_create_r1_content_stream(*args, **kwargs)
+
+    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create_r1_content)
+
+    client = AzureAIChatCompletionClient(
+        endpoint="endpoint",
+        credential=AzureKeyCredential("api_key"),
+        model_info={
+            "json_output": False,
+            "function_calling": False,
+            "vision": True,
+            "family": ModelFamily.R1,
+        },
+        model="model",
+    )
+
+    result = await client.create(messages=[UserMessage(content="Hello", source="user")])
+    assert result.content == "Hello"
+    assert result.thought == "Thought"
+
+    chunks: List[str | CreateResult] = []
+    async for chunk in client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
+        chunks.append(chunk)
+    assert isinstance(chunks[-1], CreateResult)
+    assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
+    assert chunks[-1].thought == "Thought"

View File

@@ -4,6 +4,7 @@ import os
 from typing import Annotated, Any, AsyncGenerator, Dict, Generic, List, Literal, Tuple, TypeVar
 from unittest.mock import MagicMock

+import httpx
 import pytest
 from autogen_core import CancellationToken, FunctionCall, Image
 from autogen_core.models import (

@@ -12,6 +13,7 @@ from autogen_core.models import (
     FunctionExecutionResult,
     FunctionExecutionResultMessage,
     LLMMessage,
+    ModelInfo,
     RequestUsage,
     SystemMessage,
     UserMessage,

@@ -468,6 +470,154 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
     assert response.response == "happy"
+
+
+@pytest.mark.asyncio
+async def test_r1_think_field(monkeypatch: pytest.MonkeyPatch) -> None:
+    async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
+        chunks = ["<think> Hello</think>", " Another Hello", " Yet Another Hello"]
+        for i, chunk in enumerate(chunks):
+            await asyncio.sleep(0.1)
+            yield ChatCompletionChunk(
+                id="id",
+                choices=[
+                    ChunkChoice(
+                        finish_reason="stop" if i == len(chunks) - 1 else None,
+                        index=0,
+                        delta=ChoiceDelta(
+                            content=chunk,
+                            role="assistant",
+                        ),
+                    ),
+                ],
+                created=0,
+                model="r1",
+                object="chat.completion.chunk",
+                usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+            )
+
+    async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
+        stream = kwargs.get("stream", False)
+        if not stream:
+            await asyncio.sleep(0.1)
+            return ChatCompletion(
+                id="id",
+                choices=[
+                    Choice(
+                        finish_reason="stop",
+                        index=0,
+                        message=ChatCompletionMessage(
+                            content="<think> Hello</think> Another Hello Yet Another Hello", role="assistant"
+                        ),
+                    )
+                ],
+                created=0,
+                model="r1",
+                object="chat.completion",
+                usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+            )
+        else:
+            return _mock_create_stream(*args, **kwargs)
+
+    monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
+
+    model_client = OpenAIChatCompletionClient(
+        model="r1",
+        api_key="",
+        model_info={"family": ModelFamily.R1, "vision": False, "function_calling": False, "json_output": False},
+    )
+
+    # Successful completion with think field.
+    create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
+    assert create_result.content == "Another Hello Yet Another Hello"
+    assert create_result.finish_reason == "stop"
+    assert not create_result.cached
+    assert create_result.thought == "Hello"
+
+    # Stream completion with think field.
+    chunks: List[str | CreateResult] = []
+    async for chunk in model_client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
+        chunks.append(chunk)
+    assert len(chunks) > 0
+    assert isinstance(chunks[-1], CreateResult)
+    assert chunks[-1].content == "Another Hello Yet Another Hello"
+    assert chunks[-1].thought == "Hello"
+    assert not chunks[-1].cached
+
+
+@pytest.mark.asyncio
+async def test_r1_think_field_not_present(monkeypatch: pytest.MonkeyPatch) -> None:
+    async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
+        chunks = ["Hello", " Another Hello", " Yet Another Hello"]
+        for i, chunk in enumerate(chunks):
+            await asyncio.sleep(0.1)
+            yield ChatCompletionChunk(
+                id="id",
+                choices=[
+                    ChunkChoice(
+                        finish_reason="stop" if i == len(chunks) - 1 else None,
+                        index=0,
+                        delta=ChoiceDelta(
+                            content=chunk,
+                            role="assistant",
+                        ),
+                    ),
+                ],
+                created=0,
+                model="r1",
+                object="chat.completion.chunk",
+                usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+            )
+
+    async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
+        stream = kwargs.get("stream", False)
+        if not stream:
+            await asyncio.sleep(0.1)
+            return ChatCompletion(
+                id="id",
+                choices=[
+                    Choice(
+                        finish_reason="stop",
+                        index=0,
+                        message=ChatCompletionMessage(
+                            content="Hello Another Hello Yet Another Hello", role="assistant"
+                        ),
+                    )
+                ],
+                created=0,
+                model="r1",
+                object="chat.completion",
+                usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+            )
+        else:
+            return _mock_create_stream(*args, **kwargs)
+
+    monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
+
+    model_client = OpenAIChatCompletionClient(
+        model="r1",
+        api_key="",
+        model_info={"family": ModelFamily.R1, "vision": False, "function_calling": False, "json_output": False},
+    )
+
+    # Completion warns when the think field is not present.
+    with pytest.warns(UserWarning, match="Could not find <think>..</think> field in model response content."):
+        create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
+        assert create_result.content == "Hello Another Hello Yet Another Hello"
+        assert create_result.finish_reason == "stop"
+        assert not create_result.cached
+        assert create_result.thought is None
+
+    # Stream completion also warns when the think field is not present.
+    with pytest.warns(UserWarning, match="Could not find <think>..</think> field in model response content."):
+        chunks: List[str | CreateResult] = []
+        async for chunk in model_client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
+            chunks.append(chunk)
+        assert len(chunks) > 0
+        assert isinstance(chunks[-1], CreateResult)
+        assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
+        assert chunks[-1].thought is None
+        assert not chunks[-1].cached
+
 @pytest.mark.asyncio
 async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
     model = "gpt-4o-2024-05-13"

@@ -836,4 +986,68 @@ async def test_hugging_face() -> None:
     await _test_model_client_basic_completion(model_client)
+
+
+@pytest.mark.asyncio
+async def test_ollama() -> None:
+    model = "deepseek-r1:1.5b"
+    model_info: ModelInfo = {
+        "function_calling": False,
+        "json_output": False,
+        "vision": False,
+        "family": ModelFamily.R1,
+    }
+    # Check if the model is running locally.
+    try:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(f"http://localhost:11434/v1/models/{model}")
+            response.raise_for_status()
+    except httpx.HTTPStatusError as e:
+        pytest.skip(f"{model} model is not running locally: {e}")
+    except httpx.ConnectError as e:
+        pytest.skip(f"Ollama is not running locally: {e}")
+
+    model_client = OpenAIChatCompletionClient(
+        model=model,
+        api_key="placeholder",
+        base_url="http://localhost:11434/v1",
+        model_info=model_info,
+    )
+
+    # Test basic completion with the Ollama deepseek-r1:1.5b model.
+    create_result = await model_client.create(
+        messages=[
+            UserMessage(
+                content="Taking two balls from a bag of 10 green balls and 20 red balls, "
+                "what is the probability of getting a green and a red ball?",
+                source="user",
+            ),
+        ]
+    )
+    assert isinstance(create_result.content, str)
+    assert len(create_result.content) > 0
+    assert create_result.finish_reason == "stop"
+    assert create_result.usage is not None
+    if model_info["family"] == ModelFamily.R1:
+        assert create_result.thought is not None
+
+    # Test streaming completion with the Ollama deepseek-r1:1.5b model.
+    chunks: List[str | CreateResult] = []
+    async for chunk in model_client.create_stream(
+        messages=[
+            UserMessage(
+                content="Taking two balls from a bag of 10 green balls and 20 red balls, "
+                "what is the probability of getting a green and a red ball?",
+                source="user",
+            ),
+        ]
+    ):
+        chunks.append(chunk)
+    assert len(chunks) > 0
+    assert isinstance(chunks[-1], CreateResult)
+    assert chunks[-1].finish_reason == "stop"
+    assert len(chunks[-1].content) > 0
+    assert chunks[-1].usage is not None
+    if model_info["family"] == ModelFamily.R1:
+        assert chunks[-1].thought is not None
+
 # TODO: add integration tests for Azure OpenAI using AAD token.

View File

@@ -377,3 +377,75 @@ async def test_sk_chat_completion_custom_model_info(sk_client: AzureChatCompleti
     # Verify capabilities returns the same ModelInfo
     assert adapter.capabilities == adapter.model_info
+
+
+@pytest.mark.asyncio
+async def test_sk_chat_completion_r1_content() -> None:
+    async def mock_get_chat_message_contents(
+        chat_history: ChatHistory,
+        settings: PromptExecutionSettings,
+        **kwargs: Any,
+    ) -> list[ChatMessageContent]:
+        return [
+            ChatMessageContent(
+                ai_model_id="r1",
+                role=AuthorRole.ASSISTANT,
+                metadata={"usage": {"prompt_tokens": 20, "completion_tokens": 9}},
+                items=[TextContent(text="<think>Reasoning...</think> Hello!")],
+                finish_reason=FinishReason.STOP,
+            )
+        ]
+
+    async def mock_get_streaming_chat_message_contents(
+        chat_history: ChatHistory,
+        settings: PromptExecutionSettings,
+        **kwargs: Any,
+    ) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]:
+        chunks = ["<think>Reasoning...</think>", " Hello!"]
+        for i, chunk in enumerate(chunks):
+            yield [
+                StreamingChatMessageContent(
+                    choice_index=0,
+                    inner_content=ChatCompletionChunk(
+                        id=f"chatcmpl-{i}",
+                        choices=[Choice(delta=ChoiceDelta(content=chunk), finish_reason=None, index=0)],
+                        created=1736674044,
+                        model="r1",
+                        object="chat.completion.chunk",
+                        service_tier="scale",
+                        system_fingerprint="fingerprint",
+                        usage=CompletionUsage(prompt_tokens=20, completion_tokens=9, total_tokens=29),
+                    ),
+                    ai_model_id="gpt-4o-mini",
+                    metadata={"id": f"chatcmpl-{i}", "created": 1736674044},
+                    role=AuthorRole.ASSISTANT,
+                    items=[StreamingTextContent(choice_index=0, text=chunk)],
+                    finish_reason=FinishReason.STOP if i == len(chunks) - 1 else None,
+                )
+            ]
+
+    mock_client = AsyncMock(spec=AzureChatCompletion)
+    mock_client.get_chat_message_contents = mock_get_chat_message_contents
+    mock_client.get_streaming_chat_message_contents = mock_get_streaming_chat_message_contents
+
+    kernel = Kernel(memory=NullMemory())
+
+    adapter = SKChatCompletionAdapter(
+        mock_client,
+        kernel=kernel,
+        model_info=ModelInfo(vision=False, function_calling=False, json_output=False, family=ModelFamily.R1),
+    )
+
+    result = await adapter.create(messages=[UserMessage(content="Say hello!", source="user")])
+    assert result.finish_reason == "stop"
+    assert result.content == "Hello!"
+    assert result.thought == "Reasoning..."
+
+    response_chunks: list[CreateResult | str] = []
+    async for chunk in adapter.create_stream(messages=[UserMessage(content="Say hello!", source="user")]):
+        response_chunks.append(chunk)
+    assert len(response_chunks) > 0
+    assert isinstance(response_chunks[-1], CreateResult)
+    assert response_chunks[-1].finish_reason == "stop"
+    assert response_chunks[-1].content == "Hello!"
+    assert response_chunks[-1].thought == "Reasoning..."

View File

@@ -0,0 +1,43 @@
+import pytest
+from autogen_ext.models._utils.parse_r1_content import parse_r1_content
+
+
+def test_parse_r1_content() -> None:
+    content = "Hello, <think>world</think> How are you?"
+    thought, content = parse_r1_content(content)
+    assert thought == "world"
+    assert content == "How are you?"
+
+    with pytest.warns(
+        UserWarning,
+        match="Could not find <think>..</think> field in model response content. No thought was extracted.",
+    ):
+        content = "Hello, world How are you?"
+        thought, content = parse_r1_content(content)
+        assert thought is None
+        assert content == "Hello, world How are you?"
+
+    with pytest.warns(
+        UserWarning,
+        match="Could not find <think>..</think> field in model response content. No thought was extracted.",
+    ):
+        content = "Hello, <think>world How are you?"
+        thought, content = parse_r1_content(content)
+        assert thought is None
+        assert content == "Hello, <think>world How are you?"
+
+    with pytest.warns(
+        UserWarning, match="Found </think> before <think> in model response content. No thought was extracted."
+    ):
+        content = "</think>Hello, <think>world</think>"
+        thought, content = parse_r1_content(content)
+        assert thought is None
+        assert content == "</think>Hello, <think>world</think>"
+
+    with pytest.warns(
+        UserWarning, match="Found </think> before <think> in model response content. No thought was extracted."
+    ):
+        content = "</think>Hello, <think>world"
+        thought, content = parse_r1_content(content)
+        assert thought is None
+        assert content == "</think>Hello, <think>world"

python/uv.lock (generated)
View File

@@ -654,6 +654,7 @@ web-surfer = [
 [package.dev-dependencies]
 dev = [
     { name = "autogen-test-utils" },
+    { name = "httpx" },
     { name = "langchain-experimental" },
     { name = "pandas-stubs" },
 ]

@@ -706,6 +707,7 @@ requires-dist = [
 [package.metadata.requires-dev]
 dev = [
     { name = "autogen-test-utils", editable = "packages/autogen-test-utils" },
+    { name = "httpx", specifier = ">=0.28.1" },
     { name = "langchain-experimental" },
     { name = "pandas-stubs", specifier = ">=2.2.3.241126" },
 ]