Mirror of https://github.com/microsoft/autogen.git (synced 2025-07-24 09:20:52 +00:00)
fix: Structured output with tool calls for OpenAIChatCompletionClient (#5671)
Resolves: #5568
Also refactored some unit tests.
Integration tests against the OpenAI endpoint passed: https://github.com/microsoft/autogen/actions/runs/13484492096
Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
parent
745c9d2bc5
commit
9fd8eefc55
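The diff below touches the OpenAI chat completion client and its unit tests. A minimal usage sketch of the fixed behavior, based on the tests added in this commit; the model name, the sentiment_analysis tool, and reading the API key from the environment are illustrative assumptions rather than part of the change:

# Minimal sketch based on the tests added in this commit. The model name, the
# sentiment_analysis tool, and reading OPENAI_API_KEY from the environment are
# illustrative assumptions, not part of the change itself.
import asyncio
import json
import os
from typing import Literal

from autogen_core.models import UserMessage
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
from pydantic import BaseModel


class AgentResponse(BaseModel):
    thoughts: str
    response: Literal["happy", "sad", "neutral"]


def sentiment_analysis(text: str) -> str:
    """Given a text, return the sentiment."""
    return "happy" if "happy" in text else "sad" if "sad" in text else "neutral"


async def main() -> None:
    client = OpenAIChatCompletionClient(
        model="gpt-4o-mini",  # placeholder model name
        api_key=os.environ["OPENAI_API_KEY"],
        response_format=AgentResponse,  # structured output declared as a Pydantic model class
    )
    tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)
    result = await client.create(
        messages=[UserMessage(content="I am happy.", source="user")],
        tools=[tool],
        extra_create_args={"tool_choice": "required"},
    )
    # With this fix, a tool-calling turn returns FunctionCall content, and any
    # structured text the model produced alongside the call is surfaced as result.thought.
    print(result.content)
    if isinstance(result.thought, str):
        print(AgentResponse.model_validate(json.loads(result.thought)))


asyncio.run(main())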
@@ -50,10 +50,11 @@ from autogen_core.models import (
     validate_model_info,
 )
 from autogen_core.tools import Tool, ToolSchema
-from openai import AsyncAzureOpenAI, AsyncOpenAI
+from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionAssistantMessageParam,
+    ChatCompletionChunk,
     ChatCompletionContentPartImageParam,
     ChatCompletionContentPartParam,
     ChatCompletionContentPartTextParam,
@@ -693,8 +694,23 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         create_args = self._create_args.copy()
         create_args.update(extra_create_args)
 
-        oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
-        oai_messages = [item for sublist in oai_messages_nested for item in sublist]
+        # Declare use_beta_client
+        use_beta_client: bool = False
+        response_format_value: Optional[Type[BaseModel]] = None
+
+        if "response_format" in create_args:
+            value = create_args["response_format"]
+            # If value is a Pydantic model class, use the beta client
+            if isinstance(value, type) and issubclass(value, BaseModel):
+                response_format_value = value
+                use_beta_client = True
+            else:
+                # response_format_value is not a Pydantic model class
+                use_beta_client = False
+                response_format_value = None
+
+        # Remove 'response_format' from create_args to prevent passing it twice
+        create_args_no_response_format = {k: v for k, v in create_args.items() if k != "response_format"}
 
         # TODO: allow custom handling.
         # For now we raise an error if images are present and vision is not supported
@@ -713,23 +729,39 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         else:
             create_args["response_format"] = {"type": "text"}
 
-        if len(tools) > 0:
-            converted_tools = convert_tools(tools)
-            stream_future = asyncio.ensure_future(
-                self._client.chat.completions.create(
-                    messages=oai_messages,
-                    stream=True,
-                    tools=converted_tools,
-                    **create_args,
-                )
-            )
-        else:
-            stream_future = asyncio.ensure_future(
-                self._client.chat.completions.create(messages=oai_messages, stream=True, **create_args)
-            )
-        if cancellation_token is not None:
-            cancellation_token.link_future(stream_future)
-        stream = await stream_future
+        oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
+        oai_messages = [item for sublist in oai_messages_nested for item in sublist]
+
+        if self.model_info["function_calling"] is False and len(tools) > 0:
+            raise ValueError("Model does not support function calling")
+
+        if max_consecutive_empty_chunk_tolerance != 0:
+            warnings.warn(
+                "The 'max_consecutive_empty_chunk_tolerance' parameter is deprecated and will be removed in the future releases. All of empty chunks will be skipped with a warning.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        tool_params = convert_tools(tools)
+
+        # Get the async generator of chunks.
+        if use_beta_client:
+            chunks = self._create_stream_chunks_beta_client(
+                tool_params=tool_params,
+                oai_messages=oai_messages,
+                response_format=response_format_value,
+                create_args_no_response_format=create_args_no_response_format,
+                cancellation_token=cancellation_token,
+            )
+        else:
+            chunks = self._create_stream_chunks(
+                tool_params=tool_params,
+                oai_messages=oai_messages,
+                create_args=create_args,
+                cancellation_token=cancellation_token,
+            )
+
+        # Prepare data to process streaming chunks.
         choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], ChunkChoice] = cast(ChunkChoice, None)
         chunk = None
         stop_reason = None
@@ -739,23 +771,12 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         completion_tokens = 0
         logprobs: Optional[List[ChatCompletionTokenLogprob]] = None
 
-        if max_consecutive_empty_chunk_tolerance != 0:
-            warnings.warn(
-                "The 'max_consecutive_empty_chunk_tolerance' parameter is deprecated and will be removed in the future releases. All of empty chunks will be skipped with a warning.",
-                DeprecationWarning,
-                stacklevel=2,
-            )
         empty_chunk_warning_has_been_issued: bool = False
         empty_chunk_warning_threshold: int = 10
         empty_chunk_count = 0
 
-        while True:
-            try:
-                chunk_future = asyncio.ensure_future(anext(stream))
-                if cancellation_token is not None:
-                    cancellation_token.link_future(chunk_future)
-                chunk = await chunk_future
-
+        # Process the stream of chunks.
+        async for chunk in chunks:
             # Empty chunks has been observed when the endpoint is under heavy load.
             # https://github.com/microsoft/autogen/issues/4213
             if len(chunk.choices) == 0:
@@ -823,20 +844,29 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
                         for x in choice.logprobs.content
                     ]
 
-            except StopAsyncIteration:
-                break
-
-        model = maybe_model or create_args["model"]
-        model = model.replace("gpt-35", "gpt-3.5")  # hack for Azure API
-
-        if chunk and chunk.usage:
-            prompt_tokens = chunk.usage.prompt_tokens
-        else:
-            prompt_tokens = 0
-
+        # Finalize the CreateResult.
+
+        # TODO: can we remove this?
         if stop_reason == "function_call":
             raise ValueError("Function calls are not supported in this context")
 
+        # We need to get the model from the last chunk, if available.
+        model = maybe_model or create_args["model"]
+        model = model.replace("gpt-35", "gpt-3.5")  # hack for Azure API
+
+        # Because the usage chunk is not guaranteed to be the last chunk, we need to check if it is available.
+        if chunk and chunk.usage:
+            prompt_tokens = chunk.usage.prompt_tokens
+            completion_tokens = chunk.usage.completion_tokens
+        else:
+            prompt_tokens = 0
+            completion_tokens = 0
+        usage = RequestUsage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+        )
+
+        # Detect whether it is a function call or just text.
         content: Union[str, List[FunctionCall]]
         thought: str | None = None
         if full_tool_calls:
@@ -852,19 +882,11 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
             warnings.warn("No text content or tool calls are available. Model returned empty result.", stacklevel=2)
             content = ""
 
-        if chunk and chunk.usage:
-            completion_tokens = chunk.usage.completion_tokens
-        else:
-            completion_tokens = 0
-
-        usage = RequestUsage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-        )
-
+        # Parse R1 content if needed.
         if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
             thought, content = parse_r1_content(content)
 
+        # Create the result.
         result = CreateResult(
             finish_reason=normalize_stop_reason(stop_reason),
             content=content,
@@ -874,11 +896,73 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
             thought=thought,
         )
 
+        # Update the total usage.
         self._total_usage = _add_usage(self._total_usage, usage)
         self._actual_usage = _add_usage(self._actual_usage, usage)
 
+        # Yield the CreateResult.
         yield result
 
+    async def _create_stream_chunks(
+        self,
+        tool_params: List[ChatCompletionToolParam],
+        oai_messages: List[ChatCompletionMessageParam],
+        create_args: Dict[str, Any],
+        cancellation_token: Optional[CancellationToken],
+    ) -> AsyncGenerator[ChatCompletionChunk, None]:
+        stream_future = asyncio.ensure_future(
+            self._client.chat.completions.create(
+                messages=oai_messages,
+                stream=True,
+                tools=tool_params if len(tool_params) > 0 else NOT_GIVEN,
+                **create_args,
+            )
+        )
+        if cancellation_token is not None:
+            cancellation_token.link_future(stream_future)
+        stream = await stream_future
+        while True:
+            try:
+                chunk_future = asyncio.ensure_future(anext(stream))
+                if cancellation_token is not None:
+                    cancellation_token.link_future(chunk_future)
+                chunk = await chunk_future
+                yield chunk
+            except StopAsyncIteration:
+                break
+
+    async def _create_stream_chunks_beta_client(
+        self,
+        tool_params: List[ChatCompletionToolParam],
+        oai_messages: List[ChatCompletionMessageParam],
+        create_args_no_response_format: Dict[str, Any],
+        response_format: Optional[Type[BaseModel]],
+        cancellation_token: Optional[CancellationToken],
+    ) -> AsyncGenerator[ChatCompletionChunk, None]:
+        async with self._client.beta.chat.completions.stream(
+            messages=oai_messages,
+            tools=tool_params if len(tool_params) > 0 else NOT_GIVEN,
+            response_format=response_format if response_format is not None else NOT_GIVEN,
+            **create_args_no_response_format,
+        ) as stream:
+            while True:
+                try:
+                    event_future = asyncio.ensure_future(anext(stream))
+                    if cancellation_token is not None:
+                        cancellation_token.link_future(event_future)
+                    event = await event_future
+
+                    if event.type == "chunk":
+                        chunk = event.chunk
+                        yield chunk
+                    # We don't handle other event types from the beta client stream.
+                    # As the other event types are auxiliary to the chunk event.
+                    # See: https://github.com/openai/openai-python/blob/main/helpers.md#chat-completions-events.
+                    # Once the beta client is stable, we can move all the logic to the beta client.
+                    # Then we can consider handling other event types which may simplify the code overall.
+                except StopAsyncIteration:
+                    break
+
     def actual_usage(self) -> RequestUsage:
         return self._actual_usage
@@ -1,7 +1,7 @@
 import asyncio
 import json
 import os
-from typing import Annotated, Any, AsyncGenerator, Dict, Generic, List, Literal, Tuple, TypeVar
+from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Tuple, TypeVar
 from unittest.mock import MagicMock
 
 import httpx
@@ -23,7 +23,14 @@ from autogen_core.tools import BaseTool, FunctionTool
 from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
 from autogen_ext.models.openai._model_info import resolve_model
 from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools, to_oai_type
-from openai.resources.beta.chat.completions import AsyncCompletions as BetaAsyncCompletions
+from openai.resources.beta.chat.completions import (  # type: ignore
+    AsyncChatCompletionStreamManager as BetaAsyncChatCompletionStreamManager,  # type: ignore
+)
+
+# type: ignore
+from openai.resources.beta.chat.completions import (
+    AsyncCompletions as BetaAsyncCompletions,
+)
 from openai.resources.chat.completions import AsyncCompletions
 from openai.types.chat.chat_completion import ChatCompletion, Choice
 from openai.types.chat.chat_completion_chunk import (
@@ -32,54 +39,22 @@ from openai.types.chat.chat_completion_chunk import (
     ChoiceDeltaToolCall,
     ChoiceDeltaToolCallFunction,
 )
-from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
+from openai.types.chat.chat_completion_chunk import (
+    Choice as ChunkChoice,
+)
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from openai.types.chat.chat_completion_message_tool_call import (
     ChatCompletionMessageToolCall,
     Function,
 )
 from openai.types.chat.parsed_chat_completion import ParsedChatCompletion, ParsedChatCompletionMessage, ParsedChoice
+from openai.types.chat.parsed_function_tool_call import ParsedFunction, ParsedFunctionToolCall
 from openai.types.completion_usage import CompletionUsage
 from pydantic import BaseModel, Field
 
 
-class _MockChatCompletion:
-    def __init__(self, chat_completions: List[ChatCompletion]) -> None:
-        self._saved_chat_completions = chat_completions
-        self.curr_index = 0
-        self.calls: List[Dict[str, Any]] = []
-
-    async def mock_create(
-        self, *args: Any, **kwargs: Any
-    ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
-        self.calls.append(kwargs)  # Save the call
-        await asyncio.sleep(0.1)
-        completion = self._saved_chat_completions[self.curr_index]
-        self.curr_index += 1
-        return completion
-
-
 ResponseFormatT = TypeVar("ResponseFormatT", bound=BaseModel)
 
 
-class _MockBetaChatCompletion(Generic[ResponseFormatT]):
-    def __init__(self, chat_completions: List[ParsedChatCompletion[ResponseFormatT]]) -> None:
-        self._saved_chat_completions = chat_completions
-        self.curr_index = 0
-        self.calls: List[Dict[str, Any]] = []
-
-    async def mock_parse(
-        self,
-        *args: Any,
-        **kwargs: Any,
-    ) -> ParsedChatCompletion[ResponseFormatT]:
-        self.calls.append(kwargs)  # Save the call
-        await asyncio.sleep(0.1)
-        completion = self._saved_chat_completions[self.curr_index]
-        self.curr_index += 1
-        return completion
-
-
 def _pass_function(input: str) -> str:
     return "pass"
 
@@ -106,6 +81,11 @@ class MockChunkDefinition(BaseModel):
     usage: CompletionUsage | None
 
 
+class MockChunkEvent(BaseModel):
+    type: Literal["chunk"]
+    chunk: ChatCompletionChunk
+
+
 async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
     model = resolve_model(kwargs.get("model", "gpt-4o"))
     mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"]
@@ -443,8 +423,9 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
         response: Literal["happy", "sad", "neutral"]
 
     model = "gpt-4o-2024-11-20"
-    chat_completions: List[ParsedChatCompletion[AgentResponse]] = [
-        ParsedChatCompletion(
+
+    async def _mock_parse(*args: Any, **kwargs: Any) -> ParsedChatCompletion[AgentResponse]:
+        return ParsedChatCompletion(
             id="id1",
             choices=[
                 ParsedChoice(
@@ -465,10 +446,9 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
             model=model,
             object="chat.completion",
             usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
-        ),
-    ]
-    mock = _MockBetaChatCompletion(chat_completions)
-    monkeypatch.setattr(BetaAsyncCompletions, "parse", mock.mock_parse)
+        )
+
+    monkeypatch.setattr(BetaAsyncCompletions, "parse", _mock_parse)
 
     model_client = OpenAIChatCompletionClient(
         model=model,
|
|||||||
assert response.response == "happy"
|
assert response.response == "happy"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_structured_output_with_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
class AgentResponse(BaseModel):
|
||||||
|
thoughts: str
|
||||||
|
response: Literal["happy", "sad", "neutral"]
|
||||||
|
|
||||||
|
model = "gpt-4o-2024-11-20"
|
||||||
|
|
||||||
|
async def _mock_parse(*args: Any, **kwargs: Any) -> ParsedChatCompletion[AgentResponse]:
|
||||||
|
return ParsedChatCompletion(
|
||||||
|
id="id1",
|
||||||
|
choices=[
|
||||||
|
ParsedChoice(
|
||||||
|
finish_reason="tool_calls",
|
||||||
|
index=0,
|
||||||
|
message=ParsedChatCompletionMessage(
|
||||||
|
content=json.dumps(
|
||||||
|
{
|
||||||
|
"thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
|
||||||
|
"response": "happy",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
role="assistant",
|
||||||
|
tool_calls=[
|
||||||
|
ParsedFunctionToolCall(
|
||||||
|
id="1",
|
||||||
|
type="function",
|
||||||
|
function=ParsedFunction(
|
||||||
|
name="_pass_function",
|
||||||
|
arguments=json.dumps({"input": "happy"}),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
created=0,
|
||||||
|
model=model,
|
||||||
|
object="chat.completion",
|
||||||
|
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(BetaAsyncCompletions, "parse", _mock_parse)
|
||||||
|
|
||||||
|
model_client = OpenAIChatCompletionClient(
|
||||||
|
model=model,
|
||||||
|
api_key="",
|
||||||
|
response_format=AgentResponse, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that the openai client was called with the correct response format.
|
||||||
|
create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
|
||||||
|
assert isinstance(create_result.content, list)
|
||||||
|
assert len(create_result.content) == 1
|
||||||
|
assert create_result.content[0] == FunctionCall(
|
||||||
|
id="1", name="_pass_function", arguments=json.dumps({"input": "happy"})
|
||||||
|
)
|
||||||
|
assert isinstance(create_result.thought, str)
|
||||||
|
response = AgentResponse.model_validate(json.loads(create_result.thought))
|
||||||
|
assert (
|
||||||
|
response.thoughts
|
||||||
|
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
|
||||||
|
)
|
||||||
|
assert response.response == "happy"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_structured_output_with_streaming(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
class AgentResponse(BaseModel):
|
||||||
|
thoughts: str
|
||||||
|
response: Literal["happy", "sad", "neutral"]
|
||||||
|
|
||||||
|
raw_content = json.dumps(
|
||||||
|
{
|
||||||
|
"thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
|
||||||
|
"response": "happy",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
chunked_content = [raw_content[i : i + 5] for i in range(0, len(raw_content), 5)]
|
||||||
|
assert "".join(chunked_content) == raw_content
|
||||||
|
|
||||||
|
model = "gpt-4o-2024-11-20"
|
||||||
|
mock_chunk_events = [
|
||||||
|
MockChunkEvent(
|
||||||
|
type="chunk",
|
||||||
|
chunk=ChatCompletionChunk(
|
||||||
|
id="id",
|
||||||
|
choices=[
|
||||||
|
ChunkChoice(
|
||||||
|
finish_reason=None,
|
||||||
|
index=0,
|
||||||
|
delta=ChoiceDelta(
|
||||||
|
content=mock_chunk_content,
|
||||||
|
role="assistant",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
created=0,
|
||||||
|
model=model,
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
usage=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for mock_chunk_content in chunked_content
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _mock_create_stream(*args: Any) -> AsyncGenerator[MockChunkEvent, None]:
|
||||||
|
async def _stream() -> AsyncGenerator[MockChunkEvent, None]:
|
||||||
|
for mock_chunk_event in mock_chunk_events:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
yield mock_chunk_event
|
||||||
|
|
||||||
|
return _stream()
|
||||||
|
|
||||||
|
# Mock the context manager __aenter__ method which returns the stream.
|
||||||
|
monkeypatch.setattr(BetaAsyncChatCompletionStreamManager, "__aenter__", _mock_create_stream)
|
||||||
|
|
||||||
|
model_client = OpenAIChatCompletionClient(
|
||||||
|
model=model,
|
||||||
|
api_key="",
|
||||||
|
response_format=AgentResponse, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that the openai client was called with the correct response format.
|
||||||
|
chunks: List[str | CreateResult] = []
|
||||||
|
async for chunk in model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")]):
|
||||||
|
chunks.append(chunk)
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert isinstance(chunks[-1], CreateResult)
|
||||||
|
assert isinstance(chunks[-1].content, str)
|
||||||
|
response = AgentResponse.model_validate(json.loads(chunks[-1].content))
|
||||||
|
assert (
|
||||||
|
response.thoughts
|
||||||
|
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
|
||||||
|
)
|
||||||
|
assert response.response == "happy"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_structured_output_with_streaming_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
class AgentResponse(BaseModel):
|
||||||
|
thoughts: str
|
||||||
|
response: Literal["happy", "sad", "neutral"]
|
||||||
|
|
||||||
|
raw_content = json.dumps(
|
||||||
|
{
|
||||||
|
"thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
|
||||||
|
"response": "happy",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
chunked_content = [raw_content[i : i + 5] for i in range(0, len(raw_content), 5)]
|
||||||
|
assert "".join(chunked_content) == raw_content
|
||||||
|
|
||||||
|
model = "gpt-4o-2024-11-20"
|
||||||
|
|
||||||
|
# generate the list of mock chunk content
|
||||||
|
mock_chunk_events = [
|
||||||
|
MockChunkEvent(
|
||||||
|
type="chunk",
|
||||||
|
chunk=ChatCompletionChunk(
|
||||||
|
id="id",
|
||||||
|
choices=[
|
||||||
|
ChunkChoice(
|
||||||
|
finish_reason=None,
|
||||||
|
index=0,
|
||||||
|
delta=ChoiceDelta(
|
||||||
|
content=mock_chunk_content,
|
||||||
|
role="assistant",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
created=0,
|
||||||
|
model=model,
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
usage=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for mock_chunk_content in chunked_content
|
||||||
|
]
|
||||||
|
|
||||||
|
# add the tool call chunk.
|
||||||
|
mock_chunk_events += [
|
||||||
|
MockChunkEvent(
|
||||||
|
type="chunk",
|
||||||
|
chunk=ChatCompletionChunk(
|
||||||
|
id="id",
|
||||||
|
choices=[
|
||||||
|
ChunkChoice(
|
||||||
|
finish_reason="tool_calls",
|
||||||
|
index=0,
|
||||||
|
delta=ChoiceDelta(
|
||||||
|
content=None,
|
||||||
|
role="assistant",
|
||||||
|
tool_calls=[
|
||||||
|
ChoiceDeltaToolCall(
|
||||||
|
id="1",
|
||||||
|
index=0,
|
||||||
|
type="function",
|
||||||
|
function=ChoiceDeltaToolCallFunction(
|
||||||
|
name="_pass_function",
|
||||||
|
arguments=json.dumps({"input": "happy"}),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
created=0,
|
||||||
|
model=model,
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
usage=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _mock_create_stream(*args: Any) -> AsyncGenerator[MockChunkEvent, None]:
|
||||||
|
async def _stream() -> AsyncGenerator[MockChunkEvent, None]:
|
||||||
|
for mock_chunk_event in mock_chunk_events:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
yield mock_chunk_event
|
||||||
|
|
||||||
|
return _stream()
|
||||||
|
|
||||||
|
# Mock the context manager __aenter__ method which returns the stream.
|
||||||
|
monkeypatch.setattr(BetaAsyncChatCompletionStreamManager, "__aenter__", _mock_create_stream)
|
||||||
|
|
||||||
|
model_client = OpenAIChatCompletionClient(
|
||||||
|
model=model,
|
||||||
|
api_key="",
|
||||||
|
response_format=AgentResponse, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that the openai client was called with the correct response format.
|
||||||
|
chunks: List[str | CreateResult] = []
|
||||||
|
async for chunk in model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")]):
|
||||||
|
chunks.append(chunk)
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert isinstance(chunks[-1], CreateResult)
|
||||||
|
assert isinstance(chunks[-1].content, list)
|
||||||
|
assert len(chunks[-1].content) == 1
|
||||||
|
assert chunks[-1].content[0] == FunctionCall(
|
||||||
|
id="1", name="_pass_function", arguments=json.dumps({"input": "happy"})
|
||||||
|
)
|
||||||
|
assert isinstance(chunks[-1].thought, str)
|
||||||
|
response = AgentResponse.model_validate(json.loads(chunks[-1].thought))
|
||||||
|
assert (
|
||||||
|
response.thoughts
|
||||||
|
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
|
||||||
|
)
|
||||||
|
assert response.response == "happy"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_r1_think_field(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_r1_think_field(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
|
async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
|
||||||
@@ -812,6 +1044,20 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
             usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
         ),
     ]
+
+    class _MockChatCompletion:
+        def __init__(self, completions: List[ChatCompletion]):
+            self.completions = list(completions)
+            self.calls: List[Dict[str, Any]] = []
+
+        async def mock_create(
+            self, *args: Any, **kwargs: Any
+        ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
+            if kwargs.get("stream", False):
+                raise NotImplementedError("Streaming not supported in this test.")
+            self.calls.append(kwargs)
+            return self.completions.pop(0)
+
     mock = _MockChatCompletion(chat_completions)
     monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
     pass_tool = FunctionTool(_pass_function, description="pass tool.")
@@ -1062,6 +1308,35 @@ async def test_openai_structured_output() -> None:
     assert response.response in ["happy", "sad", "neutral"]
 
 
+@pytest.mark.asyncio
+async def test_openai_structured_output_with_streaming() -> None:
+    api_key = os.getenv("OPENAI_API_KEY")
+    if not api_key:
+        pytest.skip("OPENAI_API_KEY not found in environment variables")
+
+    class AgentResponse(BaseModel):
+        thoughts: str
+        response: Literal["happy", "sad", "neutral"]
+
+    model_client = OpenAIChatCompletionClient(
+        model="gpt-4o-mini",
+        api_key=api_key,
+        response_format=AgentResponse,  # type: ignore
+    )
+
+    # Test that the openai client was called with the correct response format.
+    stream = model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")])
+    chunks: List[str | CreateResult] = []
+    async for chunk in stream:
+        chunks.append(chunk)
+    assert len(chunks) > 0
+    assert isinstance(chunks[-1], CreateResult)
+    assert isinstance(chunks[-1].content, str)
+    response = AgentResponse.model_validate(json.loads(chunks[-1].content))
+    assert response.thoughts
+    assert response.response in ["happy", "sad", "neutral"]
+
+
 @pytest.mark.asyncio
 async def test_openai_structured_output_with_tool_calls() -> None:
     api_key = os.getenv("OPENAI_API_KEY")
@@ -1090,6 +1365,7 @@ async def test_openai_structured_output_with_tool_calls() -> None:
             UserMessage(content="I am happy.", source="user"),
         ],
         tools=[tool],
+        extra_create_args={"tool_choice": "required"},
     )
     assert isinstance(response1.content, list)
     assert len(response1.content) == 1
@@ -1114,6 +1390,71 @@ async def test_openai_structured_output_with_tool_calls() -> None:
     assert parsed_response.response in ["happy", "sad", "neutral"]
 
 
+@pytest.mark.asyncio
+async def test_openai_structured_output_with_streaming_tool_calls() -> None:
+    api_key = os.getenv("OPENAI_API_KEY")
+    if not api_key:
+        pytest.skip("OPENAI_API_KEY not found in environment variables")
+
+    class AgentResponse(BaseModel):
+        thoughts: str
+        response: Literal["happy", "sad", "neutral"]
+
+    def sentiment_analysis(text: str) -> str:
+        """Given a text, return the sentiment."""
+        return "happy" if "happy" in text else "sad" if "sad" in text else "neutral"
+
+    tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)
+
+    model_client = OpenAIChatCompletionClient(
+        model="gpt-4o-mini",
+        api_key=api_key,
+        response_format=AgentResponse,  # type: ignore
+    )
+
+    chunks1: List[str | CreateResult] = []
+    stream1 = model_client.create_stream(
+        messages=[
+            SystemMessage(content="Analyze input text sentiment using the tool provided."),
+            UserMessage(content="I am happy.", source="user"),
+        ],
+        tools=[tool],
+        extra_create_args={"tool_choice": "required"},
+    )
+    async for chunk in stream1:
+        chunks1.append(chunk)
+    assert len(chunks1) > 0
+    create_result1 = chunks1[-1]
+    assert isinstance(create_result1, CreateResult)
+    assert isinstance(create_result1.content, list)
+    assert len(create_result1.content) == 1
+    assert isinstance(create_result1.content[0], FunctionCall)
+    assert create_result1.content[0].name == "sentiment_analysis"
+    assert json.loads(create_result1.content[0].arguments) == {"text": "I am happy."}
+    assert create_result1.finish_reason == "function_calls"
+
+    stream2 = model_client.create_stream(
+        messages=[
+            SystemMessage(content="Analyze input text sentiment using the tool provided."),
+            UserMessage(content="I am happy.", source="user"),
+            AssistantMessage(content=create_result1.content, source="assistant"),
+            FunctionExecutionResultMessage(
+                content=[FunctionExecutionResult(content="happy", call_id=create_result1.content[0].id, is_error=False)]
+            ),
+        ],
+    )
+    chunks2: List[str | CreateResult] = []
+    async for chunk in stream2:
+        chunks2.append(chunk)
+    assert len(chunks2) > 0
+    create_result2 = chunks2[-1]
+    assert isinstance(create_result2, CreateResult)
+    assert isinstance(create_result2.content, str)
+    parsed_response = AgentResponse.model_validate(json.loads(create_result2.content))
+    assert parsed_response.thoughts
+    assert parsed_response.response in ["happy", "sad", "neutral"]
+
+
 @pytest.mark.asyncio
 async def test_gemini() -> None:
     api_key = os.getenv("GEMINI_API_KEY")
|
Loading…
x
Reference in New Issue
Block a user