Mirror of https://github.com/microsoft/autogen.git, synced 2025-12-16 01:28:00 +00:00
feat: Support R1 reasoning text in model create result; enhance API docs (#5262)
Resolves #5255

Co-authored-by: afourney <adamfo@microsoft.com>

This commit is contained in:
parent 44db2cc1fb
commit f656ff1e01
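The changes below let callers read a reasoning model's thinking separately from its answer. A minimal usage sketch, mirroring the updated docstring example and tests in this commit (it assumes a local Ollama server exposing deepseek-r1:1.5b; swap in your own endpoint and model name):

import asyncio

from autogen_core.models import ModelFamily, UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    # Declaring the model as R1-family is what makes the client extract the <think>...</think> block.
    client = OpenAIChatCompletionClient(
        model="deepseek-r1:1.5b",  # assumed local model
        base_url="http://localhost:11434/v1",  # assumed local Ollama endpoint
        api_key="placeholder",
        model_info={
            "vision": False,
            "function_calling": False,
            "json_output": False,
            "family": ModelFamily.R1,
        },
    )
    result = await client.create(messages=[UserMessage(content="Hello", source="user")])
    print(result.thought)  # reasoning text parsed from <think>...</think>, or None
    print(result.content)  # the answer with the think block stripped


asyncio.run(main())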
@@ -22,9 +22,10 @@ class ModelFamily:
     O1 = "o1"
     GPT_4 = "gpt-4"
     GPT_35 = "gpt-35"
+    R1 = "r1"
     UNKNOWN = "unknown"

-    ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "unknown"]
+    ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "r1", "unknown"]

     def __new__(cls, *args: Any, **kwargs: Any) -> ModelFamily:
         raise TypeError(f"{cls.__name__} is a namespace class and cannot be instantiated.")
@@ -8,11 +8,25 @@ from .. import FunctionCall, Image


class SystemMessage(BaseModel):
    """System message contains instructions for the model coming from the developer.

    .. note::

        OpenAI is moving away from using the 'system' role in favor of the 'developer' role.
        See `Model Spec <https://cdn.openai.com/spec/model-spec-2024-05-08.html#definitions>`_ for more details.
        However, the 'system' role is still allowed in their API and will be automatically
        converted to the 'developer' role on the server side.
        So, you can use `SystemMessage` for developer messages.

    """

    content: str
    type: Literal["SystemMessage"] = "SystemMessage"


class UserMessage(BaseModel):
    """User message contains input from end users, or a catch-all for data provided to the model."""

    content: Union[str, List[Union[str, Image]]]

    # Name of the agent that sent this message
@@ -22,6 +36,8 @@ class UserMessage(BaseModel):


class AssistantMessage(BaseModel):
    """Assistant messages are sampled from the language model."""

    content: Union[str, List[FunctionCall]]

    # Name of the agent that sent this message
@@ -31,11 +47,15 @@ class AssistantMessage(BaseModel):


class FunctionExecutionResult(BaseModel):
    """Function execution result contains the output of a function call."""

    content: str
    call_id: str


class FunctionExecutionResultMessage(BaseModel):
    """Function execution result message contains the output of multiple function calls."""

    content: List[FunctionExecutionResult]

    type: Literal["FunctionExecutionResultMessage"] = "FunctionExecutionResultMessage"
@@ -69,8 +89,23 @@ class ChatCompletionTokenLogprob(BaseModel):


class CreateResult(BaseModel):
    """Create result contains the output of a model completion."""

    finish_reason: FinishReasons
    """The reason the model finished generating the completion."""

    content: Union[str, List[FunctionCall]]
    """The output of the model completion."""

    usage: RequestUsage
    """The usage of tokens in the prompt and completion."""

    cached: bool
    """Whether the completion was generated from a cached response."""

    logprobs: Optional[List[ChatCompletionTokenLogprob]] = None
    """The logprobs of the tokens in the completion."""

    thought: Optional[str] = None
    """The reasoning text for the completion if available. Used for reasoning models
    and additional text content besides function calls."""
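To make the new field concrete, here is a small sketch of a `CreateResult` carrying a parsed thought alongside the answer (the strings and token counts are illustrative, not from the commit):

from autogen_core.models import CreateResult, RequestUsage

result = CreateResult(
    finish_reason="stop",
    content="The answer is 42.",
    usage=RequestUsage(prompt_tokens=10, completion_tokens=5),  # illustrative counts
    cached=False,
    thought="Let me work through the question step by step...",
)
assert result.thought is not None  # reasoning text stays separate from the answer content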
@@ -120,6 +120,7 @@ dev = [
    "autogen_test_utils",
    "langchain-experimental",
    "pandas-stubs>=2.2.3.241126",
    "httpx>=0.28.1",
]

[tool.ruff]
@@ -0,0 +1,33 @@
import warnings
from typing import Tuple


def parse_r1_content(content: str) -> Tuple[str | None, str]:
    """Parse the content of an R1-style message that contains a `<think>...</think>` field."""
    # Find the start and end of the think field
    think_start = content.find("<think>")
    think_end = content.find("</think>")

    if think_start == -1 or think_end == -1:
        warnings.warn(
            "Could not find <think>..</think> field in model response content. No thought was extracted.",
            UserWarning,
            stacklevel=2,
        )
        return None, content

    if think_end < think_start:
        warnings.warn(
            "Found </think> before <think> in model response content. No thought was extracted.",
            UserWarning,
            stacklevel=2,
        )
        return None, content

    # Extract the think field
    thought = content[think_start + len("<think>") : think_end].strip()

    # Extract the rest of the content, skipping the think field.
    content = content[think_end + len("</think>") :].strip()

    return thought, content
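A quick illustration of the helper's behaviour, consistent with the unit tests added later in this commit:

from autogen_ext.models._utils.parse_r1_content import parse_r1_content

thought, answer = parse_r1_content("<think>Reasoning...</think> Hello!")
print(thought)  # "Reasoning..."
print(answer)   # "Hello!"

# When no <think>...</think> block is present, a UserWarning is emitted and the
# original content is returned unchanged with thought set to None.
thought, answer = parse_r1_content("Hello!")
assert thought is None and answer == "Hello!"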
@@ -12,6 +12,7 @@ from autogen_core.models import (
    FinishReasons,
    FunctionExecutionResultMessage,
    LLMMessage,
    ModelFamily,
    ModelInfo,
    RequestUsage,
    SystemMessage,
@@ -55,6 +56,8 @@ from autogen_ext.models.azure.config import (
    AzureAIChatCompletionClientConfig,
)

from .._utils.parse_r1_content import parse_r1_content

create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs)
AzureMessage = Union[AzureSystemMessage, AzureUserMessage, AzureAssistantMessage, AzureToolMessage]

@@ -354,11 +357,17 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
        finish_reason = choice.finish_reason  # type: ignore
        content = choice.message.content or ""

        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
            thought, content = parse_r1_content(content)
        else:
            thought = None

        response = CreateResult(
            finish_reason=finish_reason,  # type: ignore
            content=content,
            usage=usage,
            cached=False,
            thought=thought,
        )

        self.add_usage(usage)
@@ -464,11 +473,17 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
            prompt_tokens=prompt_tokens,
        )

        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
            thought, content = parse_r1_content(content)
        else:
            thought = None

        result = CreateResult(
            finish_reason=finish_reason,
            content=content,
            usage=usage,
            cached=False,
            thought=thought,
        )

        self.add_usage(usage)
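For context, a hedged sketch of opting into this parsing path with the Azure client: declaring the model as R1-family in `model_info` is what triggers `parse_r1_content` (the endpoint, key, and model name below are placeholders):

from autogen_core.models import ModelFamily
from autogen_ext.models.azure import AzureAIChatCompletionClient
from azure.core.credentials import AzureKeyCredential

client = AzureAIChatCompletionClient(
    endpoint="https://<your-endpoint>.models.ai.azure.com",  # placeholder endpoint
    credential=AzureKeyCredential("<api-key>"),  # placeholder key
    model="<r1-family-deployment>",  # placeholder model name
    model_info={
        "vision": False,
        "function_calling": False,
        "json_output": False,
        "family": ModelFamily.R1,
    },
)
# Awaiting client.create(...) now returns a CreateResult whose `thought` holds the text
# parsed from <think>...</think> and whose `content` holds the remaining answer.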
@@ -72,6 +72,7 @@ from openai.types.shared_params import FunctionDefinition, FunctionParameters
from pydantic import BaseModel
from typing_extensions import Self, Unpack

from .._utils.parse_r1_content import parse_r1_content
from . import _model_info
from .config import (
    AzureOpenAIClientConfiguration,
@@ -605,12 +606,19 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
                )
                for x in choice.logprobs.content
            ]

        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
            thought, content = parse_r1_content(content)
        else:
            thought = None

        response = CreateResult(
            finish_reason=normalize_stop_reason(finish_reason),
            content=content,
            usage=usage,
            cached=False,
            logprobs=logprobs,
            thought=thought,
        )

        self._total_usage = _add_usage(self._total_usage, usage)
@@ -818,12 +826,18 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
            completion_tokens=completion_tokens,
        )

        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
            thought, content = parse_r1_content(content)
        else:
            thought = None

        result = CreateResult(
            finish_reason=normalize_stop_reason(stop_reason),
            content=content,
            usage=usage,
            cached=False,
            logprobs=logprobs,
            thought=thought,
        )

        self._total_usage = _add_usage(self._total_usage, usage)
@@ -992,20 +1006,23 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
            print(result)


-    To use the client with a non-OpenAI model, you need to provide the base URL of the model and the model capabilities:
+    To use the client with a non-OpenAI model, you need to provide the base URL of the model and the model info.
+    For example, to use Ollama, you can use the following code snippet:

    .. code-block:: python

        from autogen_ext.models.openai import OpenAIChatCompletionClient
+       from autogen_core.models import ModelFamily

        custom_model_client = OpenAIChatCompletionClient(
-           model="custom-model-name",
-           base_url="https://custom-model.com/reset/of/the/path",
+           model="deepseek-r1:1.5b",
+           base_url="http://localhost:11434/v1",
            api_key="placeholder",
-           model_capabilities={
-               "vision": True,
-               "function_calling": True,
-               "json_output": True,
+           model_info={
+               "vision": False,
+               "function_calling": False,
+               "json_output": False,
+               "family": ModelFamily.R1,
            },
        )
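A hedged follow-up to the updated docstring example above, showing how the reasoning and the answer come back separately; the prompt is illustrative, and the client passed in would be the `custom_model_client` constructed in that example:

from autogen_core.models import ChatCompletionClient, UserMessage


async def ask(client: ChatCompletionClient) -> None:
    # `client` is an R1-family client, e.g. the `custom_model_client` from the docstring example above.
    result = await client.create(
        messages=[UserMessage(content="What is 2 + 2?", source="user")]  # illustrative prompt
    )
    print(result.thought)  # text inside <think>...</think>, or None (with a warning) if absent
    print(result.content)  # the remaining answer text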
@@ -25,6 +25,8 @@ from typing_extensions import AsyncGenerator, Union

from autogen_ext.tools.semantic_kernel import KernelFunctionFromTool

from .._utils.parse_r1_content import parse_r1_content


class SKChatCompletionAdapter(ChatCompletionClient):
    """
@@ -402,11 +404,17 @@ class SKChatCompletionAdapter(ChatCompletionClient):
        content = result[0].content
        finish_reason = "stop"

        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
            thought, content = parse_r1_content(content)
        else:
            thought = None

        return CreateResult(
            content=content,
            finish_reason=finish_reason,
            usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
            cached=False,
            thought=thought,
        )

    async def create_stream(
@@ -485,11 +493,18 @@ class SKChatCompletionAdapter(ChatCompletionClient):
        if accumulated_content:
            self._total_prompt_tokens += prompt_tokens
            self._total_completion_tokens += completion_tokens

        if isinstance(accumulated_content, str) and self._model_info["family"] == ModelFamily.R1:
            thought, accumulated_content = parse_r1_content(accumulated_content)
        else:
            thought = None

        yield CreateResult(
            content=accumulated_content,
            finish_reason="stop",
            usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
            cached=False,
            thought=thought,
        )

    def actual_usage(self) -> RequestUsage:
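A hedged sketch of wiring the adapter up for an R1-family model, following the pattern in the tests later in this commit; the import paths and the mocked Semantic Kernel client are assumptions for illustration, not part of the diff above:

from unittest.mock import AsyncMock

from autogen_core.models import ModelFamily, ModelInfo
from autogen_ext.models.semantic_kernel import SKChatCompletionAdapter  # assumed import path
from semantic_kernel import Kernel
from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion  # assumed import path
from semantic_kernel.memory.null_memory import NullMemory

# In real use this is a configured Semantic Kernel chat completion client;
# a mock stands in here, matching the pattern in the tests below.
sk_client = AsyncMock(spec=AzureChatCompletion)

adapter = SKChatCompletionAdapter(
    sk_client,
    kernel=Kernel(memory=NullMemory()),
    model_info=ModelInfo(vision=False, function_calling=False, json_output=False, family=ModelFamily.R1),
)
# adapter.create(...) and adapter.create_stream(...) now populate CreateResult.thought
# from the <think>...</think> block in the model output.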
@@ -5,7 +5,7 @@ from typing import Any, AsyncGenerator, List

import pytest
from autogen_core import CancellationToken, FunctionCall, Image
-from autogen_core.models import CreateResult, UserMessage
+from autogen_core.models import CreateResult, ModelFamily, UserMessage
from autogen_ext.models.azure import AzureAIChatCompletionClient
from azure.ai.inference.aio import (
    ChatCompletionsClient,
@@ -295,3 +295,82 @@ async def test_multimodal_supported(monkeypatch: pytest.MonkeyPatch) -> None:
        ]
    )
    assert result.content == "Handled image"


@pytest.mark.asyncio
async def test_r1_content(monkeypatch: pytest.MonkeyPatch) -> None:
    """
    Ensures that the content is parsed correctly when it contains an R1-style think field.
    """

    async def _mock_create_r1_content_stream(
        *args: Any, **kwargs: Any
    ) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]:
        mock_chunks_content = ["<think>Thought</think> Hello", " Another Hello", " Yet Another Hello"]

        mock_chunks = [
            StreamingChatChoiceUpdate(
                index=0,
                finish_reason="stop",
                delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_content),
            )
            for chunk_content in mock_chunks_content
        ]

        for mock_chunk in mock_chunks:
            await asyncio.sleep(0.1)
            yield StreamingChatCompletionsUpdate(
                id="id",
                choices=[mock_chunk],
                created=datetime.now(),
                model="model",
                usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
            )

    async def _mock_create_r1_content(
        *args: Any, **kwargs: Any
    ) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]:
        stream = kwargs.get("stream", False)

        if not stream:
            await asyncio.sleep(0.1)
            return ChatCompletions(
                id="id",
                created=datetime.now(),
                model="model",
                choices=[
                    ChatChoice(
                        index=0,
                        finish_reason="stop",
                        message=ChatResponseMessage(content="<think>Thought</think> Hello", role="assistant"),
                    )
                ],
                usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
            )
        else:
            return _mock_create_r1_content_stream(*args, **kwargs)

    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create_r1_content)

    client = AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "json_output": False,
            "function_calling": False,
            "vision": True,
            "family": ModelFamily.R1,
        },
        model="model",
    )

    result = await client.create(messages=[UserMessage(content="Hello", source="user")])
    assert result.content == "Hello"
    assert result.thought == "Thought"

    chunks: List[str | CreateResult] = []
    async for chunk in client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
        chunks.append(chunk)
    assert isinstance(chunks[-1], CreateResult)
    assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
    assert chunks[-1].thought == "Thought"
@@ -4,6 +4,7 @@ import os
from typing import Annotated, Any, AsyncGenerator, Dict, Generic, List, Literal, Tuple, TypeVar
from unittest.mock import MagicMock

import httpx
import pytest
from autogen_core import CancellationToken, FunctionCall, Image
from autogen_core.models import (
@@ -12,6 +13,7 @@ from autogen_core.models import (
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
    LLMMessage,
    ModelInfo,
    RequestUsage,
    SystemMessage,
    UserMessage,
@@ -468,6 +470,154 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
    assert response.response == "happy"


@pytest.mark.asyncio
async def test_r1_think_field(monkeypatch: pytest.MonkeyPatch) -> None:
    async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
        chunks = ["<think> Hello</think>", " Another Hello", " Yet Another Hello"]
        for i, chunk in enumerate(chunks):
            await asyncio.sleep(0.1)
            yield ChatCompletionChunk(
                id="id",
                choices=[
                    ChunkChoice(
                        finish_reason="stop" if i == len(chunks) - 1 else None,
                        index=0,
                        delta=ChoiceDelta(
                            content=chunk,
                            role="assistant",
                        ),
                    ),
                ],
                created=0,
                model="r1",
                object="chat.completion.chunk",
                usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
            )

    async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
        stream = kwargs.get("stream", False)
        if not stream:
            await asyncio.sleep(0.1)
            return ChatCompletion(
                id="id",
                choices=[
                    Choice(
                        finish_reason="stop",
                        index=0,
                        message=ChatCompletionMessage(
                            content="<think> Hello</think> Another Hello Yet Another Hello", role="assistant"
                        ),
                    )
                ],
                created=0,
                model="r1",
                object="chat.completion",
                usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
            )
        else:
            return _mock_create_stream(*args, **kwargs)

    monkeypatch.setattr(AsyncCompletions, "create", _mock_create)

    model_client = OpenAIChatCompletionClient(
        model="r1",
        api_key="",
        model_info={"family": ModelFamily.R1, "vision": False, "function_calling": False, "json_output": False},
    )

    # Successful completion with think field.
    create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
    assert create_result.content == "Another Hello Yet Another Hello"
    assert create_result.finish_reason == "stop"
    assert not create_result.cached
    assert create_result.thought == "Hello"

    # Stream completion with think field.
    chunks: List[str | CreateResult] = []
    async for chunk in model_client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
        chunks.append(chunk)
    assert len(chunks) > 0
    assert isinstance(chunks[-1], CreateResult)
    assert chunks[-1].content == "Another Hello Yet Another Hello"
    assert chunks[-1].thought == "Hello"
    assert not chunks[-1].cached


@pytest.mark.asyncio
async def test_r1_think_field_not_present(monkeypatch: pytest.MonkeyPatch) -> None:
    async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
        chunks = ["Hello", " Another Hello", " Yet Another Hello"]
        for i, chunk in enumerate(chunks):
            await asyncio.sleep(0.1)
            yield ChatCompletionChunk(
                id="id",
                choices=[
                    ChunkChoice(
                        finish_reason="stop" if i == len(chunks) - 1 else None,
                        index=0,
                        delta=ChoiceDelta(
                            content=chunk,
                            role="assistant",
                        ),
                    ),
                ],
                created=0,
                model="r1",
                object="chat.completion.chunk",
                usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
            )

    async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
        stream = kwargs.get("stream", False)
        if not stream:
            await asyncio.sleep(0.1)
            return ChatCompletion(
                id="id",
                choices=[
                    Choice(
                        finish_reason="stop",
                        index=0,
                        message=ChatCompletionMessage(
                            content="Hello Another Hello Yet Another Hello", role="assistant"
                        ),
                    )
                ],
                created=0,
                model="r1",
                object="chat.completion",
                usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
            )
        else:
            return _mock_create_stream(*args, **kwargs)

    monkeypatch.setattr(AsyncCompletions, "create", _mock_create)

    model_client = OpenAIChatCompletionClient(
        model="r1",
        api_key="",
        model_info={"family": ModelFamily.R1, "vision": False, "function_calling": False, "json_output": False},
    )

    # Warning completion when think field is not present.
    with pytest.warns(UserWarning, match="Could not find <think>..</think> field in model response content."):
        create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
        assert create_result.content == "Hello Another Hello Yet Another Hello"
        assert create_result.finish_reason == "stop"
        assert not create_result.cached
        assert create_result.thought is None

    # Stream completion when think field is not present.
    with pytest.warns(UserWarning, match="Could not find <think>..</think> field in model response content."):
        chunks: List[str | CreateResult] = []
        async for chunk in model_client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
            chunks.append(chunk)
        assert len(chunks) > 0
        assert isinstance(chunks[-1], CreateResult)
        assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
        assert chunks[-1].thought is None
        assert not chunks[-1].cached


@pytest.mark.asyncio
async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
@@ -836,4 +986,68 @@ async def test_hugging_face() -> None:
    await _test_model_client_basic_completion(model_client)


@pytest.mark.asyncio
async def test_ollama() -> None:
    model = "deepseek-r1:1.5b"
    model_info: ModelInfo = {
        "function_calling": False,
        "json_output": False,
        "vision": False,
        "family": ModelFamily.R1,
    }
    # Check if the model is running locally.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(f"http://localhost:11434/v1/models/{model}")
            response.raise_for_status()
    except httpx.HTTPStatusError as e:
        pytest.skip(f"{model} model is not running locally: {e}")
    except httpx.ConnectError as e:
        pytest.skip(f"Ollama is not running locally: {e}")

    model_client = OpenAIChatCompletionClient(
        model=model,
        api_key="placeholder",
        base_url="http://localhost:11434/v1",
        model_info=model_info,
    )

    # Test basic completion with the Ollama deepseek-r1:1.5b model.
    create_result = await model_client.create(
        messages=[
            UserMessage(
                content="Taking two balls from a bag of 10 green balls and 20 red balls, "
                "what is the probability of getting a green and a red ball?",
                source="user",
            ),
        ]
    )
    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0
    assert create_result.finish_reason == "stop"
    assert create_result.usage is not None
    if model_info["family"] == ModelFamily.R1:
        assert create_result.thought is not None

    # Test streaming completion with the Ollama deepseek-r1:1.5b model.
    chunks: List[str | CreateResult] = []
    async for chunk in model_client.create_stream(
        messages=[
            UserMessage(
                content="Taking two balls from a bag of 10 green balls and 20 red balls, "
                "what is the probability of getting a green and a red ball?",
                source="user",
            ),
        ]
    ):
        chunks.append(chunk)
    assert len(chunks) > 0
    assert isinstance(chunks[-1], CreateResult)
    assert chunks[-1].finish_reason == "stop"
    assert len(chunks[-1].content) > 0
    assert chunks[-1].usage is not None
    if model_info["family"] == ModelFamily.R1:
        assert chunks[-1].thought is not None


# TODO: add integration tests for Azure OpenAI using AAD token.
@@ -377,3 +377,75 @@ async def test_sk_chat_completion_custom_model_info(sk_client: AzureChatCompletion)

    # Verify capabilities returns the same ModelInfo
    assert adapter.capabilities == adapter.model_info


@pytest.mark.asyncio
async def test_sk_chat_completion_r1_content() -> None:
    async def mock_get_chat_message_contents(
        chat_history: ChatHistory,
        settings: PromptExecutionSettings,
        **kwargs: Any,
    ) -> list[ChatMessageContent]:
        return [
            ChatMessageContent(
                ai_model_id="r1",
                role=AuthorRole.ASSISTANT,
                metadata={"usage": {"prompt_tokens": 20, "completion_tokens": 9}},
                items=[TextContent(text="<think>Reasoning...</think> Hello!")],
                finish_reason=FinishReason.STOP,
            )
        ]

    async def mock_get_streaming_chat_message_contents(
        chat_history: ChatHistory,
        settings: PromptExecutionSettings,
        **kwargs: Any,
    ) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]:
        chunks = ["<think>Reasoning...</think>", " Hello!"]
        for i, chunk in enumerate(chunks):
            yield [
                StreamingChatMessageContent(
                    choice_index=0,
                    inner_content=ChatCompletionChunk(
                        id=f"chatcmpl-{i}",
                        choices=[Choice(delta=ChoiceDelta(content=chunk), finish_reason=None, index=0)],
                        created=1736674044,
                        model="r1",
                        object="chat.completion.chunk",
                        service_tier="scale",
                        system_fingerprint="fingerprint",
                        usage=CompletionUsage(prompt_tokens=20, completion_tokens=9, total_tokens=29),
                    ),
                    ai_model_id="gpt-4o-mini",
                    metadata={"id": f"chatcmpl-{i}", "created": 1736674044},
                    role=AuthorRole.ASSISTANT,
                    items=[StreamingTextContent(choice_index=0, text=chunk)],
                    finish_reason=FinishReason.STOP if i == len(chunks) - 1 else None,
                )
            ]

    mock_client = AsyncMock(spec=AzureChatCompletion)
    mock_client.get_chat_message_contents = mock_get_chat_message_contents
    mock_client.get_streaming_chat_message_contents = mock_get_streaming_chat_message_contents

    kernel = Kernel(memory=NullMemory())

    adapter = SKChatCompletionAdapter(
        mock_client,
        kernel=kernel,
        model_info=ModelInfo(vision=False, function_calling=False, json_output=False, family=ModelFamily.R1),
    )

    result = await adapter.create(messages=[UserMessage(content="Say hello!", source="user")])
    assert result.finish_reason == "stop"
    assert result.content == "Hello!"
    assert result.thought == "Reasoning..."

    response_chunks: list[CreateResult | str] = []
    async for chunk in adapter.create_stream(messages=[UserMessage(content="Say hello!", source="user")]):
        response_chunks.append(chunk)
    assert len(response_chunks) > 0
    assert isinstance(response_chunks[-1], CreateResult)
    assert response_chunks[-1].finish_reason == "stop"
    assert response_chunks[-1].content == "Hello!"
    assert response_chunks[-1].thought == "Reasoning..."
python/packages/autogen-ext/tests/models/test_utils.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import pytest
from autogen_ext.models._utils.parse_r1_content import parse_r1_content


def test_parse_r1_content() -> None:
    content = "Hello, <think>world</think> How are you?"
    thought, content = parse_r1_content(content)
    assert thought == "world"
    assert content == "How are you?"

    with pytest.warns(
        UserWarning,
        match="Could not find <think>..</think> field in model response content. No thought was extracted.",
    ):
        content = "Hello, world How are you?"
        thought, content = parse_r1_content(content)
        assert thought is None
        assert content == "Hello, world How are you?"

    with pytest.warns(
        UserWarning,
        match="Could not find <think>..</think> field in model response content. No thought was extracted.",
    ):
        content = "Hello, <think>world How are you?"
        thought, content = parse_r1_content(content)
        assert thought is None
        assert content == "Hello, <think>world How are you?"

    with pytest.warns(
        UserWarning, match="Found </think> before <think> in model response content. No thought was extracted."
    ):
        content = "</think>Hello, <think>world</think>"
        thought, content = parse_r1_content(content)
        assert thought is None
        assert content == "</think>Hello, <think>world</think>"

    with pytest.warns(
        UserWarning, match="Found </think> before <think> in model response content. No thought was extracted."
    ):
        content = "</think>Hello, <think>world"
        thought, content = parse_r1_content(content)
        assert thought is None
        assert content == "</think>Hello, <think>world"
python/uv.lock (generated, 2 lines changed)
@@ -654,6 +654,7 @@ web-surfer = [
[package.dev-dependencies]
dev = [
    { name = "autogen-test-utils" },
    { name = "httpx" },
    { name = "langchain-experimental" },
    { name = "pandas-stubs" },
]
@@ -706,6 +707,7 @@ requires-dist = [
[package.metadata.requires-dev]
dev = [
    { name = "autogen-test-utils", editable = "packages/autogen-test-utils" },
    { name = "httpx", specifier = ">=0.28.1" },
    { name = "langchain-experimental" },
    { name = "pandas-stubs", specifier = ">=2.2.3.241126" },
]