fix: Structured output with tool calls for OpenAIChatCompletionClient (#5671)

Resolves: #5568

Also, refactored some unit tests.

Integration tests against OpenAI endpoint passed:
https://github.com/microsoft/autogen/actions/runs/13484492096
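Example usage (a minimal sketch mirroring the integration tests added in this PR, not part of the diff itself; it assumes OPENAI_API_KEY is set in the environment and that the gpt-4o-mini model is available):

```python
import asyncio
from typing import Literal

from pydantic import BaseModel

from autogen_core.models import UserMessage
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient


class AgentResponse(BaseModel):
    thoughts: str
    response: Literal["happy", "sad", "neutral"]


def sentiment_analysis(text: str) -> str:
    """Given a text, return the sentiment."""
    return "happy" if "happy" in text else "sad" if "sad" in text else "neutral"


async def main() -> None:
    tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)
    # Passing a Pydantic model class as response_format routes the request
    # through the OpenAI beta (structured output) client.
    model_client = OpenAIChatCompletionClient(
        model="gpt-4o-mini",  # assumes OPENAI_API_KEY is picked up from the environment
        response_format=AgentResponse,  # type: ignore
    )
    result = await model_client.create(
        messages=[UserMessage(content="I am happy.", source="user")],
        tools=[tool],
        extra_create_args={"tool_choice": "required"},
    )
    # When the model calls the tool, result.content is a list of FunctionCall,
    # and any structured text produced alongside it is surfaced as result.thought.
    print(result.content)
    print(result.thought)


asyncio.run(main())
```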

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
Eric Zhu authored 2025-02-24 07:18:46 -07:00, committed by GitHub
parent 745c9d2bc5
commit 9fd8eefc55
2 changed files with 580 additions and 155 deletions

View File

@@ -50,10 +50,11 @@ from autogen_core.models import (
validate_model_info,
)
from autogen_core.tools import Tool, ToolSchema
from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI
from openai.types.chat import (
ChatCompletion,
ChatCompletionAssistantMessageParam,
ChatCompletionChunk,
ChatCompletionContentPartImageParam,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
@@ -693,8 +694,23 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
create_args = self._create_args.copy()
create_args.update(extra_create_args)
oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
# Declare use_beta_client
use_beta_client: bool = False
response_format_value: Optional[Type[BaseModel]] = None
if "response_format" in create_args:
value = create_args["response_format"]
# If value is a Pydantic model class, use the beta client
if isinstance(value, type) and issubclass(value, BaseModel):
response_format_value = value
use_beta_client = True
else:
# response_format_value is not a Pydantic model class
use_beta_client = False
response_format_value = None
# Remove 'response_format' from create_args to prevent passing it twice
create_args_no_response_format = {k: v for k, v in create_args.items() if k != "response_format"}
# TODO: allow custom handling.
# For now we raise an error if images are present and vision is not supported
@@ -713,23 +729,39 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
else:
create_args["response_format"] = {"type": "text"}
if len(tools) > 0:
converted_tools = convert_tools(tools)
stream_future = asyncio.ensure_future(
self._client.chat.completions.create(
messages=oai_messages,
stream=True,
tools=converted_tools,
**create_args,
)
oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
if self.model_info["function_calling"] is False and len(tools) > 0:
raise ValueError("Model does not support function calling")
if max_consecutive_empty_chunk_tolerance != 0:
warnings.warn(
"The 'max_consecutive_empty_chunk_tolerance' parameter is deprecated and will be removed in the future releases. All of empty chunks will be skipped with a warning.",
DeprecationWarning,
stacklevel=2,
)
tool_params = convert_tools(tools)
# Get the async generator of chunks.
if use_beta_client:
chunks = self._create_stream_chunks_beta_client(
tool_params=tool_params,
oai_messages=oai_messages,
response_format=response_format_value,
create_args_no_response_format=create_args_no_response_format,
cancellation_token=cancellation_token,
)
else:
stream_future = asyncio.ensure_future(
self._client.chat.completions.create(messages=oai_messages, stream=True, **create_args)
chunks = self._create_stream_chunks(
tool_params=tool_params,
oai_messages=oai_messages,
create_args=create_args,
cancellation_token=cancellation_token,
)
if cancellation_token is not None:
cancellation_token.link_future(stream_future)
stream = await stream_future
# Prepare data to process streaming chunks.
choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], ChunkChoice] = cast(ChunkChoice, None)
chunk = None
stop_reason = None
@@ -739,104 +771,102 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
completion_tokens = 0
logprobs: Optional[List[ChatCompletionTokenLogprob]] = None
if max_consecutive_empty_chunk_tolerance != 0:
warnings.warn(
"The 'max_consecutive_empty_chunk_tolerance' parameter is deprecated and will be removed in the future releases. All of empty chunks will be skipped with a warning.",
DeprecationWarning,
stacklevel=2,
)
empty_chunk_warning_has_been_issued: bool = False
empty_chunk_warning_threshold: int = 10
empty_chunk_count = 0
while True:
try:
chunk_future = asyncio.ensure_future(anext(stream))
if cancellation_token is not None:
cancellation_token.link_future(chunk_future)
chunk = await chunk_future
# Process the stream of chunks.
async for chunk in chunks:
# Empty chunks have been observed when the endpoint is under heavy load.
# https://github.com/microsoft/autogen/issues/4213
if len(chunk.choices) == 0:
empty_chunk_count += 1
if not empty_chunk_warning_has_been_issued and empty_chunk_count >= empty_chunk_warning_threshold:
empty_chunk_warning_has_been_issued = True
warnings.warn(
f"Received more than {empty_chunk_warning_threshold} consecutive empty chunks. Empty chunks are being ignored.",
stacklevel=2,
)
continue
else:
empty_chunk_count = 0
# Empty chunks have been observed when the endpoint is under heavy load.
# https://github.com/microsoft/autogen/issues/4213
if len(chunk.choices) == 0:
empty_chunk_count += 1
if not empty_chunk_warning_has_been_issued and empty_chunk_count >= empty_chunk_warning_threshold:
empty_chunk_warning_has_been_issued = True
warnings.warn(
f"Received more than {empty_chunk_warning_threshold} consecutive empty chunks. Empty chunks are being ignored.",
stacklevel=2,
)
continue
else:
empty_chunk_count = 0
# To process the usage chunk in streaming situations,
# add stream_options={"include_usage": True} in the initialization of OpenAIChatCompletionClient(...).
# However, the APIs differ:
# the OpenAI API usage chunk produces no choices, so we need to check whether a choice is present;
# the liteLLM API usage chunk does produce choices.
choice = (
chunk.choices[0]
if len(chunk.choices) > 0
else choice
if chunk.usage is not None and stop_reason is not None
else cast(ChunkChoice, None)
)
# To process the usage chunk in streaming situations,
# add stream_options={"include_usage": True} in the initialization of OpenAIChatCompletionClient(...).
# However, the APIs differ:
# the OpenAI API usage chunk produces no choices, so we need to check whether a choice is present;
# the liteLLM API usage chunk does produce choices.
choice = (
chunk.choices[0]
if len(chunk.choices) > 0
else choice
if chunk.usage is not None and stop_reason is not None
else cast(ChunkChoice, None)
)
# For liteLLM chunk usage, apply the following hack, keeping the previous chunk's stop_reason (if set):
# set the stop_reason for the usage chunk to the prior stop_reason.
stop_reason = choice.finish_reason if chunk.usage is None and stop_reason is None else stop_reason
maybe_model = chunk.model
# First, try to get the content.
if choice.delta.content:
content_deltas.append(choice.delta.content)
if len(choice.delta.content) > 0:
yield choice.delta.content
# NOTE: for OpenAI, tool_calls and content are mutually exclusive it seems, so we can skip the rest of the loop.
# However, this may not be the case for other APIs -- we should expect this may need to be updated.
continue
# For liteLLM chunk usage, apply the following hack, keeping the previous chunk's stop_reason (if set):
# set the stop_reason for the usage chunk to the prior stop_reason.
stop_reason = choice.finish_reason if chunk.usage is None and stop_reason is None else stop_reason
maybe_model = chunk.model
# First, try to get the content.
if choice.delta.content:
content_deltas.append(choice.delta.content)
if len(choice.delta.content) > 0:
yield choice.delta.content
# NOTE: for OpenAI, tool_calls and content are mutually exclusive it seems, so we can skip the rest of the loop.
# However, this may not be the case for other APIs -- we should expect this may need to be updated.
continue
# Otherwise, get tool calls
if choice.delta.tool_calls is not None:
for tool_call_chunk in choice.delta.tool_calls:
idx = tool_call_chunk.index
if idx not in full_tool_calls:
# We ignore the type hint here because we want to fill in type when the delta provides it
full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")
# Otherwise, get tool calls
if choice.delta.tool_calls is not None:
for tool_call_chunk in choice.delta.tool_calls:
idx = tool_call_chunk.index
if idx not in full_tool_calls:
# We ignore the type hint here because we want to fill in type when the delta provides it
full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")
if tool_call_chunk.id is not None:
full_tool_calls[idx].id += tool_call_chunk.id
if tool_call_chunk.id is not None:
full_tool_calls[idx].id += tool_call_chunk.id
if tool_call_chunk.function is not None:
if tool_call_chunk.function.name is not None:
full_tool_calls[idx].name += tool_call_chunk.function.name
if tool_call_chunk.function.arguments is not None:
full_tool_calls[idx].arguments += tool_call_chunk.function.arguments
if choice.logprobs and choice.logprobs.content:
logprobs = [
ChatCompletionTokenLogprob(
token=x.token,
logprob=x.logprob,
top_logprobs=[TopLogprob(logprob=y.logprob, bytes=y.bytes) for y in x.top_logprobs],
bytes=x.bytes,
)
for x in choice.logprobs.content
]
if tool_call_chunk.function is not None:
if tool_call_chunk.function.name is not None:
full_tool_calls[idx].name += tool_call_chunk.function.name
if tool_call_chunk.function.arguments is not None:
full_tool_calls[idx].arguments += tool_call_chunk.function.arguments
if choice.logprobs and choice.logprobs.content:
logprobs = [
ChatCompletionTokenLogprob(
token=x.token,
logprob=x.logprob,
top_logprobs=[TopLogprob(logprob=y.logprob, bytes=y.bytes) for y in x.top_logprobs],
bytes=x.bytes,
)
for x in choice.logprobs.content
]
except StopAsyncIteration:
break
model = maybe_model or create_args["model"]
model = model.replace("gpt-35", "gpt-3.5") # hack for Azure API
if chunk and chunk.usage:
prompt_tokens = chunk.usage.prompt_tokens
else:
prompt_tokens = 0
# Finalize the CreateResult.
# TODO: can we remove this?
if stop_reason == "function_call":
raise ValueError("Function calls are not supported in this context")
# We need to get the model from the last chunk, if available.
model = maybe_model or create_args["model"]
model = model.replace("gpt-35", "gpt-3.5") # hack for Azure API
# Because the usage chunk is not guaranteed to be the last chunk, we need to check if it is available.
if chunk and chunk.usage:
prompt_tokens = chunk.usage.prompt_tokens
completion_tokens = chunk.usage.completion_tokens
else:
prompt_tokens = 0
completion_tokens = 0
usage = RequestUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
# Detect whether it is a function call or just text.
content: Union[str, List[FunctionCall]]
thought: str | None = None
if full_tool_calls:
@@ -852,19 +882,11 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
warnings.warn("No text content or tool calls are available. Model returned empty result.", stacklevel=2)
content = ""
if chunk and chunk.usage:
completion_tokens = chunk.usage.completion_tokens
else:
completion_tokens = 0
usage = RequestUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
# Parse R1 content if needed.
if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
thought, content = parse_r1_content(content)
# Create the result.
result = CreateResult(
finish_reason=normalize_stop_reason(stop_reason),
content=content,
@@ -874,11 +896,73 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
thought=thought,
)
# Update the total usage.
self._total_usage = _add_usage(self._total_usage, usage)
self._actual_usage = _add_usage(self._actual_usage, usage)
# Yield the CreateResult.
yield result
async def _create_stream_chunks(
self,
tool_params: List[ChatCompletionToolParam],
oai_messages: List[ChatCompletionMessageParam],
create_args: Dict[str, Any],
cancellation_token: Optional[CancellationToken],
) -> AsyncGenerator[ChatCompletionChunk, None]:
stream_future = asyncio.ensure_future(
self._client.chat.completions.create(
messages=oai_messages,
stream=True,
tools=tool_params if len(tool_params) > 0 else NOT_GIVEN,
**create_args,
)
)
if cancellation_token is not None:
cancellation_token.link_future(stream_future)
stream = await stream_future
while True:
try:
chunk_future = asyncio.ensure_future(anext(stream))
if cancellation_token is not None:
cancellation_token.link_future(chunk_future)
chunk = await chunk_future
yield chunk
except StopAsyncIteration:
break
async def _create_stream_chunks_beta_client(
self,
tool_params: List[ChatCompletionToolParam],
oai_messages: List[ChatCompletionMessageParam],
create_args_no_response_format: Dict[str, Any],
response_format: Optional[Type[BaseModel]],
cancellation_token: Optional[CancellationToken],
) -> AsyncGenerator[ChatCompletionChunk, None]:
async with self._client.beta.chat.completions.stream(
messages=oai_messages,
tools=tool_params if len(tool_params) > 0 else NOT_GIVEN,
response_format=response_format if response_format is not None else NOT_GIVEN,
**create_args_no_response_format,
) as stream:
while True:
try:
event_future = asyncio.ensure_future(anext(stream))
if cancellation_token is not None:
cancellation_token.link_future(event_future)
event = await event_future
if event.type == "chunk":
chunk = event.chunk
yield chunk
# We don't handle other event types from the beta client stream,
# as they are auxiliary to the chunk event.
# See: https://github.com/openai/openai-python/blob/main/helpers.md#chat-completions-events.
# Once the beta client is stable, we can move all the logic to the beta client,
# and then consider handling other event types, which may simplify the code overall.
except StopAsyncIteration:
break
def actual_usage(self) -> RequestUsage:
return self._actual_usage
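Side note (not part of the diff): a simplified sketch of the openai beta streaming pattern that the new _create_stream_chunks_beta_client helper wraps. The response_format is a Pydantic model class and only "chunk" events are forwarded; the model name, message, and function name below are placeholders, and the API key is assumed to come from the environment.

```python
from openai import NOT_GIVEN, AsyncOpenAI
from pydantic import BaseModel


class AgentResponse(BaseModel):
    thoughts: str
    response: str


async def stream_raw_chunks() -> None:
    client = AsyncOpenAI()  # assumes OPENAI_API_KEY is set in the environment
    async with client.beta.chat.completions.stream(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "I am happy."}],
        response_format=AgentResponse,
        tools=NOT_GIVEN,  # pass converted tool params here when tools are provided
    ) as stream:
        async for event in stream:
            # Forward only the raw chunk events; other event types
            # (content.delta, tool_calls.function.arguments.delta, ...) are auxiliary.
            if event.type == "chunk":
                print(event.chunk)
```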

View File

@@ -1,7 +1,7 @@
import asyncio
import json
import os
from typing import Annotated, Any, AsyncGenerator, Dict, Generic, List, Literal, Tuple, TypeVar
from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Tuple, TypeVar
from unittest.mock import MagicMock
import httpx
@@ -23,7 +23,14 @@ from autogen_core.tools import BaseTool, FunctionTool
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
from autogen_ext.models.openai._model_info import resolve_model
from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools, to_oai_type
from openai.resources.beta.chat.completions import AsyncCompletions as BetaAsyncCompletions
from openai.resources.beta.chat.completions import ( # type: ignore
AsyncChatCompletionStreamManager as BetaAsyncChatCompletionStreamManager, # type: ignore
)
# type: ignore
from openai.resources.beta.chat.completions import (
AsyncCompletions as BetaAsyncCompletions,
)
from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import (
@@ -32,54 +39,22 @@ from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
)
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from openai.types.chat.chat_completion_chunk import (
Choice as ChunkChoice,
)
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion, ParsedChatCompletionMessage, ParsedChoice
from openai.types.chat.parsed_function_tool_call import ParsedFunction, ParsedFunctionToolCall
from openai.types.completion_usage import CompletionUsage
from pydantic import BaseModel, Field
class _MockChatCompletion:
def __init__(self, chat_completions: List[ChatCompletion]) -> None:
self._saved_chat_completions = chat_completions
self.curr_index = 0
self.calls: List[Dict[str, Any]] = []
async def mock_create(
self, *args: Any, **kwargs: Any
) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
self.calls.append(kwargs) # Save the call
await asyncio.sleep(0.1)
completion = self._saved_chat_completions[self.curr_index]
self.curr_index += 1
return completion
ResponseFormatT = TypeVar("ResponseFormatT", bound=BaseModel)
class _MockBetaChatCompletion(Generic[ResponseFormatT]):
def __init__(self, chat_completions: List[ParsedChatCompletion[ResponseFormatT]]) -> None:
self._saved_chat_completions = chat_completions
self.curr_index = 0
self.calls: List[Dict[str, Any]] = []
async def mock_parse(
self,
*args: Any,
**kwargs: Any,
) -> ParsedChatCompletion[ResponseFormatT]:
self.calls.append(kwargs) # Save the call
await asyncio.sleep(0.1)
completion = self._saved_chat_completions[self.curr_index]
self.curr_index += 1
return completion
def _pass_function(input: str) -> str:
return "pass"
@@ -106,6 +81,11 @@ class MockChunkDefinition(BaseModel):
usage: CompletionUsage | None
class MockChunkEvent(BaseModel):
type: Literal["chunk"]
chunk: ChatCompletionChunk
async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
model = resolve_model(kwargs.get("model", "gpt-4o"))
mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"]
@@ -443,8 +423,9 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
response: Literal["happy", "sad", "neutral"]
model = "gpt-4o-2024-11-20"
chat_completions: List[ParsedChatCompletion[AgentResponse]] = [
ParsedChatCompletion(
async def _mock_parse(*args: Any, **kwargs: Any) -> ParsedChatCompletion[AgentResponse]:
return ParsedChatCompletion(
id="id1",
choices=[
ParsedChoice(
@@ -465,10 +446,9 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockBetaChatCompletion(chat_completions)
monkeypatch.setattr(BetaAsyncCompletions, "parse", mock.mock_parse)
)
monkeypatch.setattr(BetaAsyncCompletions, "parse", _mock_parse)
model_client = OpenAIChatCompletionClient(
model=model,
@@ -487,6 +467,258 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
assert response.response == "happy"
@pytest.mark.asyncio
async def test_structured_output_with_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
class AgentResponse(BaseModel):
thoughts: str
response: Literal["happy", "sad", "neutral"]
model = "gpt-4o-2024-11-20"
async def _mock_parse(*args: Any, **kwargs: Any) -> ParsedChatCompletion[AgentResponse]:
return ParsedChatCompletion(
id="id1",
choices=[
ParsedChoice(
finish_reason="tool_calls",
index=0,
message=ParsedChatCompletionMessage(
content=json.dumps(
{
"thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
"response": "happy",
}
),
role="assistant",
tool_calls=[
ParsedFunctionToolCall(
id="1",
type="function",
function=ParsedFunction(
name="_pass_function",
arguments=json.dumps({"input": "happy"}),
),
)
],
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
)
monkeypatch.setattr(BetaAsyncCompletions, "parse", _mock_parse)
model_client = OpenAIChatCompletionClient(
model=model,
api_key="",
response_format=AgentResponse, # type: ignore
)
# Test that the openai client was called with the correct response format.
create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
assert isinstance(create_result.content, list)
assert len(create_result.content) == 1
assert create_result.content[0] == FunctionCall(
id="1", name="_pass_function", arguments=json.dumps({"input": "happy"})
)
assert isinstance(create_result.thought, str)
response = AgentResponse.model_validate(json.loads(create_result.thought))
assert (
response.thoughts
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
)
assert response.response == "happy"
@pytest.mark.asyncio
async def test_structured_output_with_streaming(monkeypatch: pytest.MonkeyPatch) -> None:
class AgentResponse(BaseModel):
thoughts: str
response: Literal["happy", "sad", "neutral"]
raw_content = json.dumps(
{
"thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
"response": "happy",
}
)
chunked_content = [raw_content[i : i + 5] for i in range(0, len(raw_content), 5)]
assert "".join(chunked_content) == raw_content
model = "gpt-4o-2024-11-20"
mock_chunk_events = [
MockChunkEvent(
type="chunk",
chunk=ChatCompletionChunk(
id="id",
choices=[
ChunkChoice(
finish_reason=None,
index=0,
delta=ChoiceDelta(
content=mock_chunk_content,
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion.chunk",
usage=None,
),
)
for mock_chunk_content in chunked_content
]
async def _mock_create_stream(*args: Any) -> AsyncGenerator[MockChunkEvent, None]:
async def _stream() -> AsyncGenerator[MockChunkEvent, None]:
for mock_chunk_event in mock_chunk_events:
await asyncio.sleep(0.1)
yield mock_chunk_event
return _stream()
# Mock the context manager __aenter__ method which returns the stream.
monkeypatch.setattr(BetaAsyncChatCompletionStreamManager, "__aenter__", _mock_create_stream)
model_client = OpenAIChatCompletionClient(
model=model,
api_key="",
response_format=AgentResponse, # type: ignore
)
# Test that the openai client was called with the correct response format.
chunks: List[str | CreateResult] = []
async for chunk in model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")]):
chunks.append(chunk)
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
assert isinstance(chunks[-1].content, str)
response = AgentResponse.model_validate(json.loads(chunks[-1].content))
assert (
response.thoughts
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
)
assert response.response == "happy"
@pytest.mark.asyncio
async def test_structured_output_with_streaming_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
class AgentResponse(BaseModel):
thoughts: str
response: Literal["happy", "sad", "neutral"]
raw_content = json.dumps(
{
"thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
"response": "happy",
}
)
chunked_content = [raw_content[i : i + 5] for i in range(0, len(raw_content), 5)]
assert "".join(chunked_content) == raw_content
model = "gpt-4o-2024-11-20"
# generate the list of mock chunk events from the chunked content
mock_chunk_events = [
MockChunkEvent(
type="chunk",
chunk=ChatCompletionChunk(
id="id",
choices=[
ChunkChoice(
finish_reason=None,
index=0,
delta=ChoiceDelta(
content=mock_chunk_content,
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion.chunk",
usage=None,
),
)
for mock_chunk_content in chunked_content
]
# add the tool call chunk.
mock_chunk_events += [
MockChunkEvent(
type="chunk",
chunk=ChatCompletionChunk(
id="id",
choices=[
ChunkChoice(
finish_reason="tool_calls",
index=0,
delta=ChoiceDelta(
content=None,
role="assistant",
tool_calls=[
ChoiceDeltaToolCall(
id="1",
index=0,
type="function",
function=ChoiceDeltaToolCallFunction(
name="_pass_function",
arguments=json.dumps({"input": "happy"}),
),
)
],
),
)
],
created=0,
model=model,
object="chat.completion.chunk",
usage=None,
),
)
]
async def _mock_create_stream(*args: Any) -> AsyncGenerator[MockChunkEvent, None]:
async def _stream() -> AsyncGenerator[MockChunkEvent, None]:
for mock_chunk_event in mock_chunk_events:
await asyncio.sleep(0.1)
yield mock_chunk_event
return _stream()
# Mock the context manager __aenter__ method which returns the stream.
monkeypatch.setattr(BetaAsyncChatCompletionStreamManager, "__aenter__", _mock_create_stream)
model_client = OpenAIChatCompletionClient(
model=model,
api_key="",
response_format=AgentResponse, # type: ignore
)
# Test that the openai client was called with the correct response format.
chunks: List[str | CreateResult] = []
async for chunk in model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")]):
chunks.append(chunk)
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
assert isinstance(chunks[-1].content, list)
assert len(chunks[-1].content) == 1
assert chunks[-1].content[0] == FunctionCall(
id="1", name="_pass_function", arguments=json.dumps({"input": "happy"})
)
assert isinstance(chunks[-1].thought, str)
response = AgentResponse.model_validate(json.loads(chunks[-1].thought))
assert (
response.thoughts
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
)
assert response.response == "happy"
@pytest.mark.asyncio
async def test_r1_think_field(monkeypatch: pytest.MonkeyPatch) -> None:
async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
@@ -812,6 +1044,20 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
class _MockChatCompletion:
def __init__(self, completions: List[ChatCompletion]):
self.completions = list(completions)
self.calls: List[Dict[str, Any]] = []
async def mock_create(
self, *args: Any, **kwargs: Any
) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
if kwargs.get("stream", False):
raise NotImplementedError("Streaming not supported in this test.")
self.calls.append(kwargs)
return self.completions.pop(0)
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
pass_tool = FunctionTool(_pass_function, description="pass tool.")
@@ -1062,6 +1308,35 @@ async def test_openai_structured_output() -> None:
assert response.response in ["happy", "sad", "neutral"]
@pytest.mark.asyncio
async def test_openai_structured_output_with_streaming() -> None:
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY not found in environment variables")
class AgentResponse(BaseModel):
thoughts: str
response: Literal["happy", "sad", "neutral"]
model_client = OpenAIChatCompletionClient(
model="gpt-4o-mini",
api_key=api_key,
response_format=AgentResponse, # type: ignore
)
# Test that the openai client was called with the correct response format.
stream = model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")])
chunks: List[str | CreateResult] = []
async for chunk in stream:
chunks.append(chunk)
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
assert isinstance(chunks[-1].content, str)
response = AgentResponse.model_validate(json.loads(chunks[-1].content))
assert response.thoughts
assert response.response in ["happy", "sad", "neutral"]
@pytest.mark.asyncio
async def test_openai_structured_output_with_tool_calls() -> None:
api_key = os.getenv("OPENAI_API_KEY")
@@ -1090,6 +1365,7 @@ async def test_openai_structured_output_with_tool_calls() -> None:
UserMessage(content="I am happy.", source="user"),
],
tools=[tool],
extra_create_args={"tool_choice": "required"},
)
assert isinstance(response1.content, list)
assert len(response1.content) == 1
@@ -1114,6 +1390,71 @@ async def test_openai_structured_output_with_streaming_tool_calls() -> None:
assert parsed_response.response in ["happy", "sad", "neutral"]
@pytest.mark.asyncio
async def test_openai_structured_output_with_streaming_tool_calls() -> None:
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY not found in environment variables")
class AgentResponse(BaseModel):
thoughts: str
response: Literal["happy", "sad", "neutral"]
def sentiment_analysis(text: str) -> str:
"""Given a text, return the sentiment."""
return "happy" if "happy" in text else "sad" if "sad" in text else "neutral"
tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)
model_client = OpenAIChatCompletionClient(
model="gpt-4o-mini",
api_key=api_key,
response_format=AgentResponse, # type: ignore
)
chunks1: List[str | CreateResult] = []
stream1 = model_client.create_stream(
messages=[
SystemMessage(content="Analyze input text sentiment using the tool provided."),
UserMessage(content="I am happy.", source="user"),
],
tools=[tool],
extra_create_args={"tool_choice": "required"},
)
async for chunk in stream1:
chunks1.append(chunk)
assert len(chunks1) > 0
create_result1 = chunks1[-1]
assert isinstance(create_result1, CreateResult)
assert isinstance(create_result1.content, list)
assert len(create_result1.content) == 1
assert isinstance(create_result1.content[0], FunctionCall)
assert create_result1.content[0].name == "sentiment_analysis"
assert json.loads(create_result1.content[0].arguments) == {"text": "I am happy."}
assert create_result1.finish_reason == "function_calls"
stream2 = model_client.create_stream(
messages=[
SystemMessage(content="Analyze input text sentiment using the tool provided."),
UserMessage(content="I am happy.", source="user"),
AssistantMessage(content=create_result1.content, source="assistant"),
FunctionExecutionResultMessage(
content=[FunctionExecutionResult(content="happy", call_id=create_result1.content[0].id, is_error=False)]
),
],
)
chunks2: List[str | CreateResult] = []
async for chunk in stream2:
chunks2.append(chunk)
assert len(chunks2) > 0
create_result2 = chunks2[-1]
assert isinstance(create_result2, CreateResult)
assert isinstance(create_result2.content, str)
parsed_response = AgentResponse.model_validate(json.loads(create_result2.content))
assert parsed_response.thoughts
assert parsed_response.response in ["happy", "sad", "neutral"]
@pytest.mark.asyncio
async def test_gemini() -> None:
api_key = os.getenv("GEMINI_API_KEY")
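Closing note (not part of the diff): a hedged sketch of how the streaming structured-output path added in this PR is consumed, mirroring test_openai_structured_output_with_streaming above. It assumes OPENAI_API_KEY is set and gpt-4o-mini is available.

```python
import asyncio
import json
from typing import Literal

from pydantic import BaseModel

from autogen_core.models import CreateResult, UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


class AgentResponse(BaseModel):
    thoughts: str
    response: Literal["happy", "sad", "neutral"]


async def main() -> None:
    model_client = OpenAIChatCompletionClient(
        model="gpt-4o-mini",  # assumes OPENAI_API_KEY is picked up from the environment
        response_format=AgentResponse,  # type: ignore
    )
    last: str | CreateResult | None = None
    async for item in model_client.create_stream(
        messages=[UserMessage(content="I am happy.", source="user")]
    ):
        # String deltas are yielded first; the final item is a CreateResult
        # whose content is the JSON text of the structured output.
        last = item
    assert isinstance(last, CreateResult) and isinstance(last.content, str)
    parsed = AgentResponse.model_validate(json.loads(last.content))
    print(parsed.response)


asyncio.run(main())
```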