mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-28 07:29:54 +00:00
fix: ensure proper handling of structured output in OpenAI client and improve test coverage for structured output (#5116)
This commit is contained in:
parent
8df86e2b72
commit
af420a83e2
@ -373,11 +373,14 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
self._resolved_model = _model_info.resolve_model(create_args["model"])
|
||||
|
||||
if (
|
||||
"response_format" in create_args
|
||||
and create_args["response_format"]["type"] == "json_object"
|
||||
and not self._model_info["json_output"]
|
||||
not self._model_info["json_output"]
|
||||
and "response_format" in create_args
|
||||
and (
|
||||
isinstance(create_args["response_format"], dict)
|
||||
and create_args["response_format"]["type"] == "json_object"
|
||||
)
|
||||
):
|
||||
raise ValueError("Model does not support JSON output")
|
||||
raise ValueError("Model does not support JSON output.")
|
||||
|
||||
self._create_args = create_args
|
||||
self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
|
||||
@ -433,7 +436,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
|
||||
if json_output is not None:
|
||||
if self.model_info["json_output"] is False and json_output is True:
|
||||
raise ValueError("Model does not support JSON output")
|
||||
raise ValueError("Model does not support JSON output.")
|
||||
|
||||
if json_output is True:
|
||||
create_args["response_format"] = {"type": "json_object"}
|
||||
@ -441,7 +444,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
create_args["response_format"] = {"type": "text"}
|
||||
|
||||
if self.model_info["json_output"] is False and json_output is True:
|
||||
raise ValueError("Model does not support JSON output")
|
||||
raise ValueError("Model does not support JSON output.")
|
||||
|
||||
oai_messages_nested = [to_oai_type(m) for m in messages]
|
||||
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from typing import Annotated, Any, AsyncGenerator, List, Tuple
|
||||
import json
|
||||
from typing import Annotated, Any, AsyncGenerator, Generic, List, Literal, Tuple, TypeVar
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@ -19,14 +20,36 @@ from autogen_core.tools import BaseTool, FunctionTool
|
||||
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
|
||||
from autogen_ext.models.openai._model_info import resolve_model
|
||||
from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools
|
||||
from openai.resources.beta.chat.completions import AsyncCompletions as BetaAsyncCompletions
|
||||
from openai.resources.chat.completions import AsyncCompletions
|
||||
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta
|
||||
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion, ParsedChatCompletionMessage, ParsedChoice
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Type variable for the structured-output payload carried by a parsed chat
# completion; bound to BaseModel, so only Pydantic model types satisfy it.
ResponseFormatT = TypeVar("ResponseFormatT", bound=BaseModel)
|
||||
|
||||
|
||||
class _MockBetaChatCompletion(Generic[ResponseFormatT]):
    """Test double that replays pre-built parsed chat completions.

    Each call to ``mock_parse`` records the ``messages`` keyword argument it
    received and hands back the next completion from the list supplied at
    construction time.
    """

    def __init__(self, chat_completions: List[ParsedChatCompletion[ResponseFormatT]]) -> None:
        # Every ``messages`` payload seen by ``mock_parse``, in call order.
        self.calls: List[List[LLMMessage]] = []
        # Position of the next canned completion to return.
        self.curr_index = 0
        self._saved_chat_completions = chat_completions

    async def mock_parse(self, *args: Any, **kwargs: Any) -> ParsedChatCompletion[ResponseFormatT]:
        """Record the request messages and return the next canned completion.

        Args:
            *args: Accepted and ignored.
            **kwargs: Must contain ``messages``; other entries are ignored.

        Returns:
            The saved completion at the current index.

        Raises:
            KeyError: If ``messages`` was not passed as a keyword argument.
            IndexError: If called more times than there are saved completions.
        """
        self.calls.append(kwargs["messages"])
        # Suspend briefly before answering, as the original mock does.
        await asyncio.sleep(0.1)
        position = self.curr_index
        # Look up first so a failed lookup leaves the index untouched.
        next_completion = self._saved_chat_completions[position]
        self.curr_index = position + 1
        return next_completion
|
||||
|
||||
|
||||
# Pydantic model exposing a string field ``result`` whose schema description
# is supplied through ``Field``.
class MyResult(BaseModel):
    result: str = Field(description="The other description.")
|
||||
@ -358,3 +381,54 @@ def test_convert_tools_accepts_both_tool_and_schema() -> None:
|
||||
|
||||
assert len(converted_tool_schema) == 2
|
||||
assert converted_tool_schema[0] == converted_tool_schema[1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
    """Create a client with a Pydantic ``response_format`` and validate the returned JSON content."""

    # Schema handed to the client as the requested structured output.
    class AgentResponse(BaseModel):
        thoughts: str
        response: Literal["happy", "sad", "neutral"]

    model = "gpt-4o-2024-11-20"
    expected_thoughts = (
        "The user explicitly states that they are happy without any indication of sadness or neutrality."
    )

    # Build the single canned completion from the inside out.
    payload = json.dumps(
        {
            "thoughts": expected_thoughts,
            "response": "happy",
        }
    )
    canned_message = ParsedChatCompletionMessage(content=payload, role="assistant")
    canned_choice = ParsedChoice(finish_reason="stop", index=0, message=canned_message)
    canned_usage = CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0)
    chat_completions: List[ParsedChatCompletion[AgentResponse]] = [
        ParsedChatCompletion(
            id="id1",
            choices=[canned_choice],
            created=0,
            model=model,
            object="chat.completion",
            usage=canned_usage,
        ),
    ]

    # Route the beta ``parse`` endpoint to the mock so no network call is made.
    mock = _MockBetaChatCompletion(chat_completions)
    monkeypatch.setattr(BetaAsyncCompletions, "parse", mock.mock_parse)

    model_client = OpenAIChatCompletionClient(
        model=model,
        api_key="",
        response_format=AgentResponse,  # type: ignore
    )

    # The result content must be a JSON string that validates against the model.
    user_message = UserMessage(content="I am happy.", source="user")
    create_result = await model_client.create(messages=[user_message])
    assert isinstance(create_result.content, str)

    response = AgentResponse.model_validate(json.loads(create_result.content))
    assert response.thoughts == expected_thoughts
    assert response.response == "happy"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user