fix: Structured output with tool calls for OpenAIChatCompletionClient (#5671)

Resolves: #5568

Also refactored some unit tests.

Integration tests against the OpenAI endpoint passed:
https://github.com/microsoft/autogen/actions/runs/13484492096
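
For reference, a minimal usage sketch of what this enables: a Pydantic response_format (routed through the OpenAI beta parse/stream API) combined with tool calls, where a single call may return either structured JSON text or a list of FunctionCall objects. The model name and the sentiment_analysis helper are illustrative, mirroring the new integration tests, and an OPENAI_API_KEY environment variable is assumed:

import asyncio
import os
from typing import Literal

from autogen_core.models import UserMessage
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
from pydantic import BaseModel


class AgentResponse(BaseModel):
    thoughts: str
    response: Literal["happy", "sad", "neutral"]


def sentiment_analysis(text: str) -> str:
    """Given a text, return the sentiment."""
    return "happy" if "happy" in text else "sad" if "sad" in text else "neutral"


async def main() -> None:
    # Illustrative tool and model; see the new tests in this commit for the full flow.
    tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)
    client = OpenAIChatCompletionClient(
        model="gpt-4o-mini",
        api_key=os.environ["OPENAI_API_KEY"],
        response_format=AgentResponse,  # Pydantic model class -> beta client is used under the hood.
    )
    result = await client.create(
        messages=[UserMessage(content="I am happy.", source="user")],
        tools=[tool],
    )
    # Either a list of FunctionCall (tool-call turn) or a JSON string matching AgentResponse.
    print(result.content)


asyncio.run(main())

When the last turn is a tool call, any structured text surfaces on CreateResult.thought; a follow-up request that includes the tool results then yields the structured JSON in CreateResult.content, as exercised by the new tests.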

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
Eric Zhu 2025-02-24 07:18:46 -07:00 committed by GitHub
parent 745c9d2bc5
commit 9fd8eefc55
2 changed files with 580 additions and 155 deletions


@@ -50,10 +50,11 @@ from autogen_core.models import (
    validate_model_info,
)
from autogen_core.tools import Tool, ToolSchema
-from openai import AsyncAzureOpenAI, AsyncOpenAI
+from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI
from openai.types.chat import (
    ChatCompletion,
    ChatCompletionAssistantMessageParam,
+    ChatCompletionChunk,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartParam,
    ChatCompletionContentPartTextParam,
@@ -693,8 +694,23 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
        create_args = self._create_args.copy()
        create_args.update(extra_create_args)

-        oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
-        oai_messages = [item for sublist in oai_messages_nested for item in sublist]
+        # Declare use_beta_client
+        use_beta_client: bool = False
+        response_format_value: Optional[Type[BaseModel]] = None
+
+        if "response_format" in create_args:
+            value = create_args["response_format"]
+            # If value is a Pydantic model class, use the beta client
+            if isinstance(value, type) and issubclass(value, BaseModel):
+                response_format_value = value
+                use_beta_client = True
+            else:
+                # response_format_value is not a Pydantic model class
+                use_beta_client = False
+                response_format_value = None
+
+        # Remove 'response_format' from create_args to prevent passing it twice
+        create_args_no_response_format = {k: v for k, v in create_args.items() if k != "response_format"}

        # TODO: allow custom handling.
        # For now we raise an error if images are present and vision is not supported
@@ -713,23 +729,39 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
        else:
            create_args["response_format"] = {"type": "text"}

-        if len(tools) > 0:
-            converted_tools = convert_tools(tools)
-            stream_future = asyncio.ensure_future(
-                self._client.chat.completions.create(
-                    messages=oai_messages,
-                    stream=True,
-                    tools=converted_tools,
-                    **create_args,
-                )
-            )
-        else:
-            stream_future = asyncio.ensure_future(
-                self._client.chat.completions.create(messages=oai_messages, stream=True, **create_args)
-            )
-        if cancellation_token is not None:
-            cancellation_token.link_future(stream_future)
-        stream = await stream_future
+        oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
+        oai_messages = [item for sublist in oai_messages_nested for item in sublist]
+
+        if self.model_info["function_calling"] is False and len(tools) > 0:
+            raise ValueError("Model does not support function calling")
+
+        if max_consecutive_empty_chunk_tolerance != 0:
+            warnings.warn(
+                "The 'max_consecutive_empty_chunk_tolerance' parameter is deprecated and will be removed in the future releases. All of empty chunks will be skipped with a warning.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        tool_params = convert_tools(tools)
+
+        # Get the async generator of chunks.
+        if use_beta_client:
+            chunks = self._create_stream_chunks_beta_client(
+                tool_params=tool_params,
+                oai_messages=oai_messages,
+                response_format=response_format_value,
+                create_args_no_response_format=create_args_no_response_format,
+                cancellation_token=cancellation_token,
+            )
+        else:
+            chunks = self._create_stream_chunks(
+                tool_params=tool_params,
+                oai_messages=oai_messages,
+                create_args=create_args,
+                cancellation_token=cancellation_token,
+            )
+
+        # Prepare data to process streaming chunks.
        choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], ChunkChoice] = cast(ChunkChoice, None)
        chunk = None
        stop_reason = None
@@ -739,23 +771,12 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
        completion_tokens = 0
        logprobs: Optional[List[ChatCompletionTokenLogprob]] = None

-        if max_consecutive_empty_chunk_tolerance != 0:
-            warnings.warn(
-                "The 'max_consecutive_empty_chunk_tolerance' parameter is deprecated and will be removed in the future releases. All of empty chunks will be skipped with a warning.",
-                DeprecationWarning,
-                stacklevel=2,
-            )
        empty_chunk_warning_has_been_issued: bool = False
        empty_chunk_warning_threshold: int = 10
        empty_chunk_count = 0

-        while True:
-            try:
-                chunk_future = asyncio.ensure_future(anext(stream))
-                if cancellation_token is not None:
-                    cancellation_token.link_future(chunk_future)
-                chunk = await chunk_future
-
+        # Process the stream of chunks.
+        async for chunk in chunks:
            # Empty chunks has been observed when the endpoint is under heavy load.
            # https://github.com/microsoft/autogen/issues/4213
            if len(chunk.choices) == 0:
@@ -823,20 +844,29 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
                    for x in choice.logprobs.content
                ]

-            except StopAsyncIteration:
-                break
-
-        model = maybe_model or create_args["model"]
-        model = model.replace("gpt-35", "gpt-3.5")  # hack for Azure API
-
-        if chunk and chunk.usage:
-            prompt_tokens = chunk.usage.prompt_tokens
-        else:
-            prompt_tokens = 0
-
-        # TODO: can we remove this?
+        # Finalize the CreateResult.
+
        if stop_reason == "function_call":
            raise ValueError("Function calls are not supported in this context")

+        # We need to get the model from the last chunk, if available.
+        model = maybe_model or create_args["model"]
+        model = model.replace("gpt-35", "gpt-3.5")  # hack for Azure API
+
+        # Because the usage chunk is not guaranteed to be the last chunk, we need to check if it is available.
+        if chunk and chunk.usage:
+            prompt_tokens = chunk.usage.prompt_tokens
+            completion_tokens = chunk.usage.completion_tokens
+        else:
+            prompt_tokens = 0
+            completion_tokens = 0
+        usage = RequestUsage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+        )
+
+        # Detect whether it is a function call or just text.
        content: Union[str, List[FunctionCall]]
        thought: str | None = None
        if full_tool_calls:
@@ -852,19 +882,11 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
            warnings.warn("No text content or tool calls are available. Model returned empty result.", stacklevel=2)
            content = ""

-        if chunk and chunk.usage:
-            completion_tokens = chunk.usage.completion_tokens
-        else:
-            completion_tokens = 0
-
-        usage = RequestUsage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-        )
-
+        # Parse R1 content if needed.
        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
            thought, content = parse_r1_content(content)

+        # Create the result.
        result = CreateResult(
            finish_reason=normalize_stop_reason(stop_reason),
            content=content,
@@ -874,11 +896,73 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
            thought=thought,
        )

+        # Update the total usage.
        self._total_usage = _add_usage(self._total_usage, usage)
        self._actual_usage = _add_usage(self._actual_usage, usage)

+        # Yield the CreateResult.
        yield result
async def _create_stream_chunks(
self,
tool_params: List[ChatCompletionToolParam],
oai_messages: List[ChatCompletionMessageParam],
create_args: Dict[str, Any],
cancellation_token: Optional[CancellationToken],
) -> AsyncGenerator[ChatCompletionChunk, None]:
stream_future = asyncio.ensure_future(
self._client.chat.completions.create(
messages=oai_messages,
stream=True,
tools=tool_params if len(tool_params) > 0 else NOT_GIVEN,
**create_args,
)
)
if cancellation_token is not None:
cancellation_token.link_future(stream_future)
stream = await stream_future
while True:
try:
chunk_future = asyncio.ensure_future(anext(stream))
if cancellation_token is not None:
cancellation_token.link_future(chunk_future)
chunk = await chunk_future
yield chunk
except StopAsyncIteration:
break
async def _create_stream_chunks_beta_client(
self,
tool_params: List[ChatCompletionToolParam],
oai_messages: List[ChatCompletionMessageParam],
create_args_no_response_format: Dict[str, Any],
response_format: Optional[Type[BaseModel]],
cancellation_token: Optional[CancellationToken],
) -> AsyncGenerator[ChatCompletionChunk, None]:
async with self._client.beta.chat.completions.stream(
messages=oai_messages,
tools=tool_params if len(tool_params) > 0 else NOT_GIVEN,
response_format=response_format if response_format is not None else NOT_GIVEN,
**create_args_no_response_format,
) as stream:
while True:
try:
event_future = asyncio.ensure_future(anext(stream))
if cancellation_token is not None:
cancellation_token.link_future(event_future)
event = await event_future
if event.type == "chunk":
chunk = event.chunk
yield chunk
# We don't handle other event types from the beta client stream.
# As the other event types are auxiliary to the chunk event.
# See: https://github.com/openai/openai-python/blob/main/helpers.md#chat-completions-events.
# Once the beta client is stable, we can move all the logic to the beta client.
# Then we can consider handling other event types which may simplify the code overall.
except StopAsyncIteration:
break
    def actual_usage(self) -> RequestUsage:
        return self._actual_usage


@@ -1,7 +1,7 @@
import asyncio
import json
import os
-from typing import Annotated, Any, AsyncGenerator, Dict, Generic, List, Literal, Tuple, TypeVar
+from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Tuple, TypeVar
from unittest.mock import MagicMock

import httpx
@@ -23,7 +23,14 @@ from autogen_core.tools import BaseTool, FunctionTool
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
from autogen_ext.models.openai._model_info import resolve_model
from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools, to_oai_type
-from openai.resources.beta.chat.completions import AsyncCompletions as BetaAsyncCompletions
+from openai.resources.beta.chat.completions import (  # type: ignore
+    AsyncChatCompletionStreamManager as BetaAsyncChatCompletionStreamManager,  # type: ignore
+)
+
+# type: ignore
+from openai.resources.beta.chat.completions import (
+    AsyncCompletions as BetaAsyncCompletions,
+)
from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import (
@@ -32,54 +39,22 @@ from openai.types.chat.chat_completion_chunk import (
    ChoiceDeltaToolCall,
    ChoiceDeltaToolCallFunction,
)
-from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
+from openai.types.chat.chat_completion_chunk import (
+    Choice as ChunkChoice,
+)
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import (
    ChatCompletionMessageToolCall,
    Function,
)
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion, ParsedChatCompletionMessage, ParsedChoice
+from openai.types.chat.parsed_function_tool_call import ParsedFunction, ParsedFunctionToolCall
from openai.types.completion_usage import CompletionUsage
from pydantic import BaseModel, Field

-
-class _MockChatCompletion:
-    def __init__(self, chat_completions: List[ChatCompletion]) -> None:
-        self._saved_chat_completions = chat_completions
-        self.curr_index = 0
-        self.calls: List[Dict[str, Any]] = []
-
-    async def mock_create(
-        self, *args: Any, **kwargs: Any
-    ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
-        self.calls.append(kwargs)  # Save the call
-        await asyncio.sleep(0.1)
-        completion = self._saved_chat_completions[self.curr_index]
-        self.curr_index += 1
-        return completion
-
-
ResponseFormatT = TypeVar("ResponseFormatT", bound=BaseModel)

-
-class _MockBetaChatCompletion(Generic[ResponseFormatT]):
-    def __init__(self, chat_completions: List[ParsedChatCompletion[ResponseFormatT]]) -> None:
-        self._saved_chat_completions = chat_completions
-        self.curr_index = 0
-        self.calls: List[Dict[str, Any]] = []
-
-    async def mock_parse(
-        self,
-        *args: Any,
-        **kwargs: Any,
-    ) -> ParsedChatCompletion[ResponseFormatT]:
-        self.calls.append(kwargs)  # Save the call
-        await asyncio.sleep(0.1)
-        completion = self._saved_chat_completions[self.curr_index]
-        self.curr_index += 1
-        return completion
-
-
def _pass_function(input: str) -> str:
    return "pass"
@@ -106,6 +81,11 @@ class MockChunkDefinition(BaseModel):
    usage: CompletionUsage | None


+class MockChunkEvent(BaseModel):
+    type: Literal["chunk"]
+    chunk: ChatCompletionChunk
+
+
async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
    model = resolve_model(kwargs.get("model", "gpt-4o"))
    mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"]
@@ -443,8 +423,9 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
        response: Literal["happy", "sad", "neutral"]

    model = "gpt-4o-2024-11-20"
-    chat_completions: List[ParsedChatCompletion[AgentResponse]] = [
-        ParsedChatCompletion(
+
+    async def _mock_parse(*args: Any, **kwargs: Any) -> ParsedChatCompletion[AgentResponse]:
+        return ParsedChatCompletion(
            id="id1",
            choices=[
                ParsedChoice(
@@ -465,10 +446,9 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
-        ),
-    ]
-    mock = _MockBetaChatCompletion(chat_completions)
-    monkeypatch.setattr(BetaAsyncCompletions, "parse", mock.mock_parse)
+        )
+
+    monkeypatch.setattr(BetaAsyncCompletions, "parse", _mock_parse)

    model_client = OpenAIChatCompletionClient(
        model=model,
@@ -487,6 +467,258 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
    assert response.response == "happy"
@pytest.mark.asyncio
async def test_structured_output_with_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
class AgentResponse(BaseModel):
thoughts: str
response: Literal["happy", "sad", "neutral"]
model = "gpt-4o-2024-11-20"
async def _mock_parse(*args: Any, **kwargs: Any) -> ParsedChatCompletion[AgentResponse]:
return ParsedChatCompletion(
id="id1",
choices=[
ParsedChoice(
finish_reason="tool_calls",
index=0,
message=ParsedChatCompletionMessage(
content=json.dumps(
{
"thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
"response": "happy",
}
),
role="assistant",
tool_calls=[
ParsedFunctionToolCall(
id="1",
type="function",
function=ParsedFunction(
name="_pass_function",
arguments=json.dumps({"input": "happy"}),
),
)
],
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
)
monkeypatch.setattr(BetaAsyncCompletions, "parse", _mock_parse)
model_client = OpenAIChatCompletionClient(
model=model,
api_key="",
response_format=AgentResponse, # type: ignore
)
# Test that the openai client was called with the correct response format.
create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
assert isinstance(create_result.content, list)
assert len(create_result.content) == 1
assert create_result.content[0] == FunctionCall(
id="1", name="_pass_function", arguments=json.dumps({"input": "happy"})
)
assert isinstance(create_result.thought, str)
response = AgentResponse.model_validate(json.loads(create_result.thought))
assert (
response.thoughts
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
)
assert response.response == "happy"
@pytest.mark.asyncio
async def test_structured_output_with_streaming(monkeypatch: pytest.MonkeyPatch) -> None:
class AgentResponse(BaseModel):
thoughts: str
response: Literal["happy", "sad", "neutral"]
raw_content = json.dumps(
{
"thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
"response": "happy",
}
)
chunked_content = [raw_content[i : i + 5] for i in range(0, len(raw_content), 5)]
assert "".join(chunked_content) == raw_content
model = "gpt-4o-2024-11-20"
mock_chunk_events = [
MockChunkEvent(
type="chunk",
chunk=ChatCompletionChunk(
id="id",
choices=[
ChunkChoice(
finish_reason=None,
index=0,
delta=ChoiceDelta(
content=mock_chunk_content,
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion.chunk",
usage=None,
),
)
for mock_chunk_content in chunked_content
]
async def _mock_create_stream(*args: Any) -> AsyncGenerator[MockChunkEvent, None]:
async def _stream() -> AsyncGenerator[MockChunkEvent, None]:
for mock_chunk_event in mock_chunk_events:
await asyncio.sleep(0.1)
yield mock_chunk_event
return _stream()
# Mock the context manager __aenter__ method which returns the stream.
monkeypatch.setattr(BetaAsyncChatCompletionStreamManager, "__aenter__", _mock_create_stream)
model_client = OpenAIChatCompletionClient(
model=model,
api_key="",
response_format=AgentResponse, # type: ignore
)
# Test that the openai client was called with the correct response format.
chunks: List[str | CreateResult] = []
async for chunk in model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")]):
chunks.append(chunk)
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
assert isinstance(chunks[-1].content, str)
response = AgentResponse.model_validate(json.loads(chunks[-1].content))
assert (
response.thoughts
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
)
assert response.response == "happy"
@pytest.mark.asyncio
async def test_structured_output_with_streaming_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
class AgentResponse(BaseModel):
thoughts: str
response: Literal["happy", "sad", "neutral"]
raw_content = json.dumps(
{
"thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
"response": "happy",
}
)
chunked_content = [raw_content[i : i + 5] for i in range(0, len(raw_content), 5)]
assert "".join(chunked_content) == raw_content
model = "gpt-4o-2024-11-20"
# generate the list of mock chunk content
mock_chunk_events = [
MockChunkEvent(
type="chunk",
chunk=ChatCompletionChunk(
id="id",
choices=[
ChunkChoice(
finish_reason=None,
index=0,
delta=ChoiceDelta(
content=mock_chunk_content,
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion.chunk",
usage=None,
),
)
for mock_chunk_content in chunked_content
]
# add the tool call chunk.
mock_chunk_events += [
MockChunkEvent(
type="chunk",
chunk=ChatCompletionChunk(
id="id",
choices=[
ChunkChoice(
finish_reason="tool_calls",
index=0,
delta=ChoiceDelta(
content=None,
role="assistant",
tool_calls=[
ChoiceDeltaToolCall(
id="1",
index=0,
type="function",
function=ChoiceDeltaToolCallFunction(
name="_pass_function",
arguments=json.dumps({"input": "happy"}),
),
)
],
),
)
],
created=0,
model=model,
object="chat.completion.chunk",
usage=None,
),
)
]
async def _mock_create_stream(*args: Any) -> AsyncGenerator[MockChunkEvent, None]:
async def _stream() -> AsyncGenerator[MockChunkEvent, None]:
for mock_chunk_event in mock_chunk_events:
await asyncio.sleep(0.1)
yield mock_chunk_event
return _stream()
# Mock the context manager __aenter__ method which returns the stream.
monkeypatch.setattr(BetaAsyncChatCompletionStreamManager, "__aenter__", _mock_create_stream)
model_client = OpenAIChatCompletionClient(
model=model,
api_key="",
response_format=AgentResponse, # type: ignore
)
# Test that the openai client was called with the correct response format.
chunks: List[str | CreateResult] = []
async for chunk in model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")]):
chunks.append(chunk)
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
assert isinstance(chunks[-1].content, list)
assert len(chunks[-1].content) == 1
assert chunks[-1].content[0] == FunctionCall(
id="1", name="_pass_function", arguments=json.dumps({"input": "happy"})
)
assert isinstance(chunks[-1].thought, str)
response = AgentResponse.model_validate(json.loads(chunks[-1].thought))
assert (
response.thoughts
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
)
assert response.response == "happy"
@pytest.mark.asyncio
async def test_r1_think_field(monkeypatch: pytest.MonkeyPatch) -> None:
    async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
@@ -812,6 +1044,20 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
        ),
    ]
class _MockChatCompletion:
def __init__(self, completions: List[ChatCompletion]):
self.completions = list(completions)
self.calls: List[Dict[str, Any]] = []
async def mock_create(
self, *args: Any, **kwargs: Any
) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
if kwargs.get("stream", False):
raise NotImplementedError("Streaming not supported in this test.")
self.calls.append(kwargs)
return self.completions.pop(0)
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    pass_tool = FunctionTool(_pass_function, description="pass tool.")
@@ -1062,6 +1308,35 @@ async def test_openai_structured_output() -> None:
    assert response.response in ["happy", "sad", "neutral"]
@pytest.mark.asyncio
async def test_openai_structured_output_with_streaming() -> None:
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY not found in environment variables")
class AgentResponse(BaseModel):
thoughts: str
response: Literal["happy", "sad", "neutral"]
model_client = OpenAIChatCompletionClient(
model="gpt-4o-mini",
api_key=api_key,
response_format=AgentResponse, # type: ignore
)
# Test that the openai client was called with the correct response format.
stream = model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")])
chunks: List[str | CreateResult] = []
async for chunk in stream:
chunks.append(chunk)
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
assert isinstance(chunks[-1].content, str)
response = AgentResponse.model_validate(json.loads(chunks[-1].content))
assert response.thoughts
assert response.response in ["happy", "sad", "neutral"]
@pytest.mark.asyncio
async def test_openai_structured_output_with_tool_calls() -> None:
    api_key = os.getenv("OPENAI_API_KEY")
@@ -1090,6 +1365,7 @@ async def test_openai_structured_output_with_tool_calls() -> None:
            UserMessage(content="I am happy.", source="user"),
        ],
        tools=[tool],
+        extra_create_args={"tool_choice": "required"},
    )
    assert isinstance(response1.content, list)
    assert len(response1.content) == 1
@@ -1114,6 +1390,71 @@ async def test_openai_structured_output_with_tool_calls() -> None:
    assert parsed_response.response in ["happy", "sad", "neutral"]
@pytest.mark.asyncio
async def test_openai_structured_output_with_streaming_tool_calls() -> None:
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY not found in environment variables")
class AgentResponse(BaseModel):
thoughts: str
response: Literal["happy", "sad", "neutral"]
def sentiment_analysis(text: str) -> str:
"""Given a text, return the sentiment."""
return "happy" if "happy" in text else "sad" if "sad" in text else "neutral"
tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)
model_client = OpenAIChatCompletionClient(
model="gpt-4o-mini",
api_key=api_key,
response_format=AgentResponse, # type: ignore
)
chunks1: List[str | CreateResult] = []
stream1 = model_client.create_stream(
messages=[
SystemMessage(content="Analyze input text sentiment using the tool provided."),
UserMessage(content="I am happy.", source="user"),
],
tools=[tool],
extra_create_args={"tool_choice": "required"},
)
async for chunk in stream1:
chunks1.append(chunk)
assert len(chunks1) > 0
create_result1 = chunks1[-1]
assert isinstance(create_result1, CreateResult)
assert isinstance(create_result1.content, list)
assert len(create_result1.content) == 1
assert isinstance(create_result1.content[0], FunctionCall)
assert create_result1.content[0].name == "sentiment_analysis"
assert json.loads(create_result1.content[0].arguments) == {"text": "I am happy."}
assert create_result1.finish_reason == "function_calls"
stream2 = model_client.create_stream(
messages=[
SystemMessage(content="Analyze input text sentiment using the tool provided."),
UserMessage(content="I am happy.", source="user"),
AssistantMessage(content=create_result1.content, source="assistant"),
FunctionExecutionResultMessage(
content=[FunctionExecutionResult(content="happy", call_id=create_result1.content[0].id, is_error=False)]
),
],
)
chunks2: List[str | CreateResult] = []
async for chunk in stream2:
chunks2.append(chunk)
assert len(chunks2) > 0
create_result2 = chunks2[-1]
assert isinstance(create_result2, CreateResult)
assert isinstance(create_result2.content, str)
parsed_response = AgentResponse.model_validate(json.loads(create_result2.content))
assert parsed_response.thoughts
assert parsed_response.response in ["happy", "sad", "neutral"]
@pytest.mark.asyncio
async def test_gemini() -> None:
    api_key = os.getenv("GEMINI_API_KEY")