fix: ensure proper handling of structured output in OpenAI client and improve test coverage (#5116)

Eric Zhu 2025-01-20 12:54:39 -08:00 committed by GitHub
parent 8df86e2b72
commit af420a83e2
2 changed files with 84 additions and 7 deletions
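
For orientation, the feature under test works roughly like this: a Pydantic model class is passed as response_format when constructing OpenAIChatCompletionClient, and create() returns JSON text that validates against that model (the new test below follows exactly this pattern, with the network call mocked out). A minimal usage sketch; the model name, prompt, and API key are illustrative placeholders, and a real key would be needed for a live call:

import asyncio
from typing import Literal

from autogen_core.models import UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient
from pydantic import BaseModel


class AgentResponse(BaseModel):
    thoughts: str
    response: Literal["happy", "sad", "neutral"]


async def main() -> None:
    # Passing the Pydantic model as the response format selects the
    # structured-output path rather than plain JSON mode.
    client = OpenAIChatCompletionClient(
        model="gpt-4o-2024-11-20",
        api_key="YOUR_API_KEY",  # placeholder
        response_format=AgentResponse,  # type: ignore
    )
    result = await client.create(messages=[UserMessage(content="I am happy.", source="user")])
    assert isinstance(result.content, str)
    parsed = AgentResponse.model_validate_json(result.content)
    print(parsed.response)


asyncio.run(main())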

View File

@@ -373,11 +373,14 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
             self._resolved_model = _model_info.resolve_model(create_args["model"])
         if (
-            "response_format" in create_args
-            and create_args["response_format"]["type"] == "json_object"
-            and not self._model_info["json_output"]
+            not self._model_info["json_output"]
+            and "response_format" in create_args
+            and (
+                isinstance(create_args["response_format"], dict)
+                and create_args["response_format"]["type"] == "json_object"
+            )
         ):
-            raise ValueError("Model does not support JSON output")
+            raise ValueError("Model does not support JSON output.")
         self._create_args = create_args
         self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
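
This reordering matters because, with structured output, create_args["response_format"] can be a Pydantic model class rather than a dict; the old condition indexed it with ["type"] unconditionally, which only makes sense for the dict form, while the new guard checks the capability flag first and only indexes the response format once it is known to be a dict. A small illustrative sketch of the two shapes (the names are hypothetical, not part of the diff):

from pydantic import BaseModel


class AgentResponse(BaseModel):
    thoughts: str


# The two shapes a response_format value can take.
json_mode_format = {"type": "json_object"}  # plain JSON mode: a dict
structured_format = AgentResponse  # structured output: a BaseModel subclass

for response_format in (json_mode_format, structured_format):
    # Mirrors the new guard: only index ["type"] after confirming it is a dict.
    if isinstance(response_format, dict) and response_format["type"] == "json_object":
        print("plain JSON mode requested")
    else:
        print("structured output (or text) requested")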
@@ -433,7 +436,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         if json_output is not None:
             if self.model_info["json_output"] is False and json_output is True:
-                raise ValueError("Model does not support JSON output")
+                raise ValueError("Model does not support JSON output.")
             if json_output is True:
                 create_args["response_format"] = {"type": "json_object"}
@@ -441,7 +444,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
                 create_args["response_format"] = {"type": "text"}
         if self.model_info["json_output"] is False and json_output is True:
-            raise ValueError("Model does not support JSON output")
+            raise ValueError("Model does not support JSON output.")
         oai_messages_nested = [to_oai_type(m) for m in messages]
         oai_messages = [item for sublist in oai_messages_nested for item in sublist]
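
The two hunks above only touch the error message, but they sit on the plain JSON-mode path: when create() is called with json_output=True, the client sets response_format to {"type": "json_object"} and raises the ValueError if model_info reports no JSON support. A brief sketch of that path, assuming the imports and client from the earlier sketch; the prompt is illustrative:

# Plain JSON mode (as opposed to structured output with a Pydantic model).
async def ask_for_json(client: OpenAIChatCompletionClient) -> None:
    result = await client.create(
        messages=[UserMessage(content="Reply with a JSON object.", source="user")],
        json_output=True,  # becomes response_format={"type": "json_object"}; ValueError if unsupported
    )
    print(result.content)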

View File

@@ -1,5 +1,6 @@
 import asyncio
-from typing import Annotated, Any, AsyncGenerator, List, Tuple
+import json
+from typing import Annotated, Any, AsyncGenerator, Generic, List, Literal, Tuple, TypeVar
 from unittest.mock import MagicMock

 import pytest
@@ -19,14 +20,36 @@ from autogen_core.tools import BaseTool, FunctionTool
 from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
 from autogen_ext.models.openai._model_info import resolve_model
 from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools
+from openai.resources.beta.chat.completions import AsyncCompletions as BetaAsyncCompletions
 from openai.resources.chat.completions import AsyncCompletions
 from openai.types.chat.chat_completion import ChatCompletion, Choice
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta
 from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
+from openai.types.chat.parsed_chat_completion import ParsedChatCompletion, ParsedChatCompletionMessage, ParsedChoice
 from openai.types.completion_usage import CompletionUsage
 from pydantic import BaseModel, Field

+ResponseFormatT = TypeVar("ResponseFormatT", bound=BaseModel)
+
+
+class _MockBetaChatCompletion(Generic[ResponseFormatT]):
+    def __init__(self, chat_completions: List[ParsedChatCompletion[ResponseFormatT]]) -> None:
+        self._saved_chat_completions = chat_completions
+        self.curr_index = 0
+        self.calls: List[List[LLMMessage]] = []
+
+    async def mock_parse(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> ParsedChatCompletion[ResponseFormatT]:
+        self.calls.append(kwargs["messages"])
+        await asyncio.sleep(0.1)
+        completion = self._saved_chat_completions[self.curr_index]
+        self.curr_index += 1
+        return completion
+
+
 class MyResult(BaseModel):
     result: str = Field(description="The other description.")
@@ -358,3 +381,54 @@ def test_convert_tools_accepts_both_tool_and_schema() -> None:
     assert len(converted_tool_schema) == 2
     assert converted_tool_schema[0] == converted_tool_schema[1]
+
+
+@pytest.mark.asyncio
+async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
+    class AgentResponse(BaseModel):
+        thoughts: str
+        response: Literal["happy", "sad", "neutral"]
+
+    model = "gpt-4o-2024-11-20"
+    chat_completions: List[ParsedChatCompletion[AgentResponse]] = [
+        ParsedChatCompletion(
+            id="id1",
+            choices=[
+                ParsedChoice(
+                    finish_reason="stop",
+                    index=0,
+                    message=ParsedChatCompletionMessage(
+                        content=json.dumps(
+                            {
+                                "thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
+                                "response": "happy",
+                            }
+                        ),
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+    ]
+    mock = _MockBetaChatCompletion(chat_completions)
+    monkeypatch.setattr(BetaAsyncCompletions, "parse", mock.mock_parse)
+
+    model_client = OpenAIChatCompletionClient(
+        model=model,
+        api_key="",
+        response_format=AgentResponse,  # type: ignore
+    )
+
+    # Test that the openai client was called with the correct response format.
+    create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
+    assert isinstance(create_result.content, str)
+    response = AgentResponse.model_validate(json.loads(create_result.content))
+    assert (
+        response.thoughts
+        == "The user explicitly states that they are happy without any indication of sadness or neutrality."
+    )
+    assert response.response == "happy"
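
A note on the validation step at the end of the test: because create_result.content is a JSON string, the json.loads plus model_validate pair could equally be written with Pydantic's one-step helper. An equivalent variant, assuming the names from the test above:

# Equivalent to AgentResponse.model_validate(json.loads(create_result.content)).
response = AgentResponse.model_validate_json(create_result.content)
assert response.response == "happy"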