Mirror of https://github.com/microsoft/autogen.git, synced 2025-06-26 22:30:10 +00:00
fix: Structured output with tool calls for OpenAIChatCompletionClient (#5671)
Resolves #5568. Also refactored some unit tests. Integration tests against the OpenAI endpoint passed: https://github.com/microsoft/autogen/actions/runs/13484492096

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
parent
745c9d2bc5
commit
9fd8eefc55
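For orientation, here is a minimal usage sketch of the behavior this commit fixes, adapted from the tests added in the diff below. The model name, API key, and the sentiment_analysis helper are illustrative placeholders, not part of the commit itself.

import asyncio
import json
from typing import Literal

from autogen_core.models import SystemMessage, UserMessage
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
from pydantic import BaseModel


class AgentResponse(BaseModel):
    thoughts: str
    response: Literal["happy", "sad", "neutral"]


def sentiment_analysis(text: str) -> str:
    """Illustrative helper tool: return the sentiment of the text."""
    return "happy" if "happy" in text else "sad" if "sad" in text else "neutral"


async def main() -> None:
    # Passing a Pydantic model as response_format routes the request through the
    # OpenAI beta parse/stream client; this commit makes that path also work when
    # tools are supplied (and when streaming via create_stream).
    client = OpenAIChatCompletionClient(
        model="gpt-4o-mini",  # placeholder model name
        api_key="sk-placeholder",  # placeholder key
        response_format=AgentResponse,  # type: ignore
    )
    tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)
    result = await client.create(
        messages=[
            SystemMessage(content="Analyze input text sentiment using the tool provided."),
            UserMessage(content="I am happy.", source="user"),
        ],
        tools=[tool],
    )
    if isinstance(result.content, list):
        # The model chose to call the tool; structured text emitted alongside the
        # tool call is surfaced on result.thought.
        print(result.content[0], result.thought)
    else:
        # No tool call: the structured output is the JSON string in result.content.
        print(AgentResponse.model_validate(json.loads(result.content)))


asyncio.run(main())

The same behavior applies to streaming: with create_stream, the final CreateResult carries the tool calls in content and any structured text in thought (see test_structured_output_with_streaming_tool_calls in the diff below).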
@@ -50,10 +50,11 @@ from autogen_core.models import (
validate_model_info,
)
from autogen_core.tools import Tool, ToolSchema
from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI
from openai.types.chat import (
ChatCompletion,
ChatCompletionAssistantMessageParam,
ChatCompletionChunk,
ChatCompletionContentPartImageParam,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
@@ -693,8 +694,23 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
create_args = self._create_args.copy()
create_args.update(extra_create_args)

oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
# Declare use_beta_client
use_beta_client: bool = False
response_format_value: Optional[Type[BaseModel]] = None

if "response_format" in create_args:
value = create_args["response_format"]
# If value is a Pydantic model class, use the beta client
if isinstance(value, type) and issubclass(value, BaseModel):
response_format_value = value
use_beta_client = True
else:
# response_format_value is not a Pydantic model class
use_beta_client = False
response_format_value = None

# Remove 'response_format' from create_args to prevent passing it twice
create_args_no_response_format = {k: v for k, v in create_args.items() if k != "response_format"}

# TODO: allow custom handling.
# For now we raise an error if images are present and vision is not supported
@@ -713,23 +729,39 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
else:
create_args["response_format"] = {"type": "text"}

if len(tools) > 0:
converted_tools = convert_tools(tools)
stream_future = asyncio.ensure_future(
self._client.chat.completions.create(
messages=oai_messages,
stream=True,
tools=converted_tools,
**create_args,
)
oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
oai_messages = [item for sublist in oai_messages_nested for item in sublist]

if self.model_info["function_calling"] is False and len(tools) > 0:
raise ValueError("Model does not support function calling")

if max_consecutive_empty_chunk_tolerance != 0:
warnings.warn(
"The 'max_consecutive_empty_chunk_tolerance' parameter is deprecated and will be removed in the future releases. All of empty chunks will be skipped with a warning.",
DeprecationWarning,
stacklevel=2,
)

tool_params = convert_tools(tools)

# Get the async generator of chunks.
if use_beta_client:
chunks = self._create_stream_chunks_beta_client(
tool_params=tool_params,
oai_messages=oai_messages,
response_format=response_format_value,
create_args_no_response_format=create_args_no_response_format,
cancellation_token=cancellation_token,
)
else:
stream_future = asyncio.ensure_future(
self._client.chat.completions.create(messages=oai_messages, stream=True, **create_args)
chunks = self._create_stream_chunks(
tool_params=tool_params,
oai_messages=oai_messages,
create_args=create_args,
cancellation_token=cancellation_token,
)
if cancellation_token is not None:
cancellation_token.link_future(stream_future)
stream = await stream_future

# Prepare data to process streaming chunks.
choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], ChunkChoice] = cast(ChunkChoice, None)
chunk = None
stop_reason = None
@@ -739,104 +771,102 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
completion_tokens = 0
logprobs: Optional[List[ChatCompletionTokenLogprob]] = None

if max_consecutive_empty_chunk_tolerance != 0:
warnings.warn(
"The 'max_consecutive_empty_chunk_tolerance' parameter is deprecated and will be removed in the future releases. All of empty chunks will be skipped with a warning.",
DeprecationWarning,
stacklevel=2,
)
empty_chunk_warning_has_been_issued: bool = False
empty_chunk_warning_threshold: int = 10
empty_chunk_count = 0

while True:
try:
chunk_future = asyncio.ensure_future(anext(stream))
if cancellation_token is not None:
cancellation_token.link_future(chunk_future)
chunk = await chunk_future
# Process the stream of chunks.
async for chunk in chunks:
# Empty chunks has been observed when the endpoint is under heavy load.
# https://github.com/microsoft/autogen/issues/4213
if len(chunk.choices) == 0:
empty_chunk_count += 1
if not empty_chunk_warning_has_been_issued and empty_chunk_count >= empty_chunk_warning_threshold:
empty_chunk_warning_has_been_issued = True
warnings.warn(
f"Received more than {empty_chunk_warning_threshold} consecutive empty chunks. Empty chunks are being ignored.",
stacklevel=2,
)
continue
else:
empty_chunk_count = 0

# Empty chunks has been observed when the endpoint is under heavy load.
# https://github.com/microsoft/autogen/issues/4213
if len(chunk.choices) == 0:
empty_chunk_count += 1
if not empty_chunk_warning_has_been_issued and empty_chunk_count >= empty_chunk_warning_threshold:
empty_chunk_warning_has_been_issued = True
warnings.warn(
f"Received more than {empty_chunk_warning_threshold} consecutive empty chunks. Empty chunks are being ignored.",
stacklevel=2,
)
continue
else:
empty_chunk_count = 0
# to process usage chunk in streaming situations
# add stream_options={"include_usage": True} in the initialization of OpenAIChatCompletionClient(...)
# However the different api's
# OPENAI api usage chunk produces no choices so need to check if there is a choice
# liteLLM api usage chunk does produce choices
choice = (
chunk.choices[0]
if len(chunk.choices) > 0
else choice
if chunk.usage is not None and stop_reason is not None
else cast(ChunkChoice, None)
)

# to process usage chunk in streaming situations
# add stream_options={"include_usage": True} in the initialization of OpenAIChatCompletionClient(...)
# However the different api's
# OPENAI api usage chunk produces no choices so need to check if there is a choice
# liteLLM api usage chunk does produce choices
choice = (
chunk.choices[0]
if len(chunk.choices) > 0
else choice
if chunk.usage is not None and stop_reason is not None
else cast(ChunkChoice, None)
)
# for liteLLM chunk usage, do the following hack keeping the pervious chunk.stop_reason (if set).
# set the stop_reason for the usage chunk to the prior stop_reason
stop_reason = choice.finish_reason if chunk.usage is None and stop_reason is None else stop_reason
maybe_model = chunk.model
# First try get content
if choice.delta.content:
content_deltas.append(choice.delta.content)
if len(choice.delta.content) > 0:
yield choice.delta.content
# NOTE: for OpenAI, tool_calls and content are mutually exclusive it seems, so we can skip the rest of the loop.
# However, this may not be the case for other APIs -- we should expect this may need to be updated.
continue

# for liteLLM chunk usage, do the following hack keeping the pervious chunk.stop_reason (if set).
# set the stop_reason for the usage chunk to the prior stop_reason
stop_reason = choice.finish_reason if chunk.usage is None and stop_reason is None else stop_reason
maybe_model = chunk.model
# First try get content
if choice.delta.content:
content_deltas.append(choice.delta.content)
if len(choice.delta.content) > 0:
yield choice.delta.content
# NOTE: for OpenAI, tool_calls and content are mutually exclusive it seems, so we can skip the rest of the loop.
# However, this may not be the case for other APIs -- we should expect this may need to be updated.
continue
# Otherwise, get tool calls
if choice.delta.tool_calls is not None:
for tool_call_chunk in choice.delta.tool_calls:
idx = tool_call_chunk.index
if idx not in full_tool_calls:
# We ignore the type hint here because we want to fill in type when the delta provides it
full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")

# Otherwise, get tool calls
if choice.delta.tool_calls is not None:
for tool_call_chunk in choice.delta.tool_calls:
idx = tool_call_chunk.index
if idx not in full_tool_calls:
# We ignore the type hint here because we want to fill in type when the delta provides it
full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")
if tool_call_chunk.id is not None:
full_tool_calls[idx].id += tool_call_chunk.id

if tool_call_chunk.id is not None:
full_tool_calls[idx].id += tool_call_chunk.id
if tool_call_chunk.function is not None:
if tool_call_chunk.function.name is not None:
full_tool_calls[idx].name += tool_call_chunk.function.name
if tool_call_chunk.function.arguments is not None:
full_tool_calls[idx].arguments += tool_call_chunk.function.arguments
if choice.logprobs and choice.logprobs.content:
logprobs = [
ChatCompletionTokenLogprob(
token=x.token,
logprob=x.logprob,
top_logprobs=[TopLogprob(logprob=y.logprob, bytes=y.bytes) for y in x.top_logprobs],
bytes=x.bytes,
)
for x in choice.logprobs.content
]

if tool_call_chunk.function is not None:
if tool_call_chunk.function.name is not None:
full_tool_calls[idx].name += tool_call_chunk.function.name
if tool_call_chunk.function.arguments is not None:
full_tool_calls[idx].arguments += tool_call_chunk.function.arguments
if choice.logprobs and choice.logprobs.content:
logprobs = [
ChatCompletionTokenLogprob(
token=x.token,
logprob=x.logprob,
top_logprobs=[TopLogprob(logprob=y.logprob, bytes=y.bytes) for y in x.top_logprobs],
bytes=x.bytes,
)
for x in choice.logprobs.content
]

except StopAsyncIteration:
break

model = maybe_model or create_args["model"]
model = model.replace("gpt-35", "gpt-3.5") # hack for Azure API

if chunk and chunk.usage:
prompt_tokens = chunk.usage.prompt_tokens
else:
prompt_tokens = 0
# Finalize the CreateResult.

# TODO: can we remove this?
if stop_reason == "function_call":
raise ValueError("Function calls are not supported in this context")

# We need to get the model from the last chunk, if available.
model = maybe_model or create_args["model"]
model = model.replace("gpt-35", "gpt-3.5") # hack for Azure API

# Because the usage chunk is not guaranteed to be the last chunk, we need to check if it is available.
if chunk and chunk.usage:
prompt_tokens = chunk.usage.prompt_tokens
completion_tokens = chunk.usage.completion_tokens
else:
prompt_tokens = 0
completion_tokens = 0
usage = RequestUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)

# Detect whether it is a function call or just text.
content: Union[str, List[FunctionCall]]
thought: str | None = None
if full_tool_calls:
@@ -852,19 +882,11 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
warnings.warn("No text content or tool calls are available. Model returned empty result.", stacklevel=2)
content = ""

if chunk and chunk.usage:
completion_tokens = chunk.usage.completion_tokens
else:
completion_tokens = 0

usage = RequestUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)

# Parse R1 content if needed.
if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
thought, content = parse_r1_content(content)

# Create the result.
result = CreateResult(
finish_reason=normalize_stop_reason(stop_reason),
content=content,
@@ -874,11 +896,73 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
thought=thought,
)

# Update the total usage.
self._total_usage = _add_usage(self._total_usage, usage)
self._actual_usage = _add_usage(self._actual_usage, usage)

# Yield the CreateResult.
yield result

async def _create_stream_chunks(
self,
tool_params: List[ChatCompletionToolParam],
oai_messages: List[ChatCompletionMessageParam],
create_args: Dict[str, Any],
cancellation_token: Optional[CancellationToken],
) -> AsyncGenerator[ChatCompletionChunk, None]:
stream_future = asyncio.ensure_future(
self._client.chat.completions.create(
messages=oai_messages,
stream=True,
tools=tool_params if len(tool_params) > 0 else NOT_GIVEN,
**create_args,
)
)
if cancellation_token is not None:
cancellation_token.link_future(stream_future)
stream = await stream_future
while True:
try:
chunk_future = asyncio.ensure_future(anext(stream))
if cancellation_token is not None:
cancellation_token.link_future(chunk_future)
chunk = await chunk_future
yield chunk
except StopAsyncIteration:
break

async def _create_stream_chunks_beta_client(
self,
tool_params: List[ChatCompletionToolParam],
oai_messages: List[ChatCompletionMessageParam],
create_args_no_response_format: Dict[str, Any],
response_format: Optional[Type[BaseModel]],
cancellation_token: Optional[CancellationToken],
) -> AsyncGenerator[ChatCompletionChunk, None]:
async with self._client.beta.chat.completions.stream(
messages=oai_messages,
tools=tool_params if len(tool_params) > 0 else NOT_GIVEN,
response_format=response_format if response_format is not None else NOT_GIVEN,
**create_args_no_response_format,
) as stream:
while True:
try:
event_future = asyncio.ensure_future(anext(stream))
if cancellation_token is not None:
cancellation_token.link_future(event_future)
event = await event_future

if event.type == "chunk":
chunk = event.chunk
yield chunk
# We don't handle other event types from the beta client stream.
# As the other event types are auxiliary to the chunk event.
# See: https://github.com/openai/openai-python/blob/main/helpers.md#chat-completions-events.
# Once the beta client is stable, we can move all the logic to the beta client.
# Then we can consider handling other event types which may simplify the code overall.
except StopAsyncIteration:
break

def actual_usage(self) -> RequestUsage:
return self._actual_usage
@@ -1,7 +1,7 @@
import asyncio
import json
import os
from typing import Annotated, Any, AsyncGenerator, Dict, Generic, List, Literal, Tuple, TypeVar
from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Tuple, TypeVar
from unittest.mock import MagicMock

import httpx
@@ -23,7 +23,14 @@ from autogen_core.tools import BaseTool, FunctionTool
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
from autogen_ext.models.openai._model_info import resolve_model
from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools, to_oai_type
from openai.resources.beta.chat.completions import AsyncCompletions as BetaAsyncCompletions
from openai.resources.beta.chat.completions import ( # type: ignore
AsyncChatCompletionStreamManager as BetaAsyncChatCompletionStreamManager, # type: ignore
)

# type: ignore
from openai.resources.beta.chat.completions import (
AsyncCompletions as BetaAsyncCompletions,
)
from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import (
@@ -32,54 +39,22 @@ from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
)
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from openai.types.chat.chat_completion_chunk import (
Choice as ChunkChoice,
)
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion, ParsedChatCompletionMessage, ParsedChoice
from openai.types.chat.parsed_function_tool_call import ParsedFunction, ParsedFunctionToolCall
from openai.types.completion_usage import CompletionUsage
from pydantic import BaseModel, Field


class _MockChatCompletion:
def __init__(self, chat_completions: List[ChatCompletion]) -> None:
self._saved_chat_completions = chat_completions
self.curr_index = 0
self.calls: List[Dict[str, Any]] = []

async def mock_create(
self, *args: Any, **kwargs: Any
) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
self.calls.append(kwargs) # Save the call
await asyncio.sleep(0.1)
completion = self._saved_chat_completions[self.curr_index]
self.curr_index += 1
return completion


ResponseFormatT = TypeVar("ResponseFormatT", bound=BaseModel)


class _MockBetaChatCompletion(Generic[ResponseFormatT]):
def __init__(self, chat_completions: List[ParsedChatCompletion[ResponseFormatT]]) -> None:
self._saved_chat_completions = chat_completions
self.curr_index = 0
self.calls: List[Dict[str, Any]] = []

async def mock_parse(
self,
*args: Any,
**kwargs: Any,
) -> ParsedChatCompletion[ResponseFormatT]:
self.calls.append(kwargs) # Save the call
await asyncio.sleep(0.1)
completion = self._saved_chat_completions[self.curr_index]
self.curr_index += 1
return completion


def _pass_function(input: str) -> str:
return "pass"
@@ -106,6 +81,11 @@ class MockChunkDefinition(BaseModel):
usage: CompletionUsage | None


class MockChunkEvent(BaseModel):
type: Literal["chunk"]
chunk: ChatCompletionChunk


async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
model = resolve_model(kwargs.get("model", "gpt-4o"))
mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"]
@@ -443,8 +423,9 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
response: Literal["happy", "sad", "neutral"]

model = "gpt-4o-2024-11-20"
chat_completions: List[ParsedChatCompletion[AgentResponse]] = [
ParsedChatCompletion(

async def _mock_parse(*args: Any, **kwargs: Any) -> ParsedChatCompletion[AgentResponse]:
return ParsedChatCompletion(
id="id1",
choices=[
ParsedChoice(
@@ -465,10 +446,9 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockBetaChatCompletion(chat_completions)
monkeypatch.setattr(BetaAsyncCompletions, "parse", mock.mock_parse)
)

monkeypatch.setattr(BetaAsyncCompletions, "parse", _mock_parse)

model_client = OpenAIChatCompletionClient(
model=model,
@@ -487,6 +467,258 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
assert response.response == "happy"


@pytest.mark.asyncio
async def test_structured_output_with_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
class AgentResponse(BaseModel):
thoughts: str
response: Literal["happy", "sad", "neutral"]

model = "gpt-4o-2024-11-20"

async def _mock_parse(*args: Any, **kwargs: Any) -> ParsedChatCompletion[AgentResponse]:
return ParsedChatCompletion(
id="id1",
choices=[
ParsedChoice(
finish_reason="tool_calls",
index=0,
message=ParsedChatCompletionMessage(
content=json.dumps(
{
"thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
"response": "happy",
}
),
role="assistant",
tool_calls=[
ParsedFunctionToolCall(
id="1",
type="function",
function=ParsedFunction(
name="_pass_function",
arguments=json.dumps({"input": "happy"}),
),
)
],
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
)

monkeypatch.setattr(BetaAsyncCompletions, "parse", _mock_parse)

model_client = OpenAIChatCompletionClient(
model=model,
api_key="",
response_format=AgentResponse, # type: ignore
)

# Test that the openai client was called with the correct response format.
create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
assert isinstance(create_result.content, list)
assert len(create_result.content) == 1
assert create_result.content[0] == FunctionCall(
id="1", name="_pass_function", arguments=json.dumps({"input": "happy"})
)
assert isinstance(create_result.thought, str)
response = AgentResponse.model_validate(json.loads(create_result.thought))
assert (
response.thoughts
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
)
assert response.response == "happy"


@pytest.mark.asyncio
async def test_structured_output_with_streaming(monkeypatch: pytest.MonkeyPatch) -> None:
class AgentResponse(BaseModel):
thoughts: str
response: Literal["happy", "sad", "neutral"]

raw_content = json.dumps(
{
"thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
"response": "happy",
}
)
chunked_content = [raw_content[i : i + 5] for i in range(0, len(raw_content), 5)]
assert "".join(chunked_content) == raw_content

model = "gpt-4o-2024-11-20"
mock_chunk_events = [
MockChunkEvent(
type="chunk",
chunk=ChatCompletionChunk(
id="id",
choices=[
ChunkChoice(
finish_reason=None,
index=0,
delta=ChoiceDelta(
content=mock_chunk_content,
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion.chunk",
usage=None,
),
)
for mock_chunk_content in chunked_content
]

async def _mock_create_stream(*args: Any) -> AsyncGenerator[MockChunkEvent, None]:
async def _stream() -> AsyncGenerator[MockChunkEvent, None]:
for mock_chunk_event in mock_chunk_events:
await asyncio.sleep(0.1)
yield mock_chunk_event

return _stream()

# Mock the context manager __aenter__ method which returns the stream.
monkeypatch.setattr(BetaAsyncChatCompletionStreamManager, "__aenter__", _mock_create_stream)

model_client = OpenAIChatCompletionClient(
model=model,
api_key="",
response_format=AgentResponse, # type: ignore
)

# Test that the openai client was called with the correct response format.
chunks: List[str | CreateResult] = []
async for chunk in model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")]):
chunks.append(chunk)
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
assert isinstance(chunks[-1].content, str)
response = AgentResponse.model_validate(json.loads(chunks[-1].content))
assert (
response.thoughts
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
)
assert response.response == "happy"


@pytest.mark.asyncio
async def test_structured_output_with_streaming_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
class AgentResponse(BaseModel):
thoughts: str
response: Literal["happy", "sad", "neutral"]

raw_content = json.dumps(
{
"thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
"response": "happy",
}
)
chunked_content = [raw_content[i : i + 5] for i in range(0, len(raw_content), 5)]
assert "".join(chunked_content) == raw_content

model = "gpt-4o-2024-11-20"

# generate the list of mock chunk content
mock_chunk_events = [
MockChunkEvent(
type="chunk",
chunk=ChatCompletionChunk(
id="id",
choices=[
ChunkChoice(
finish_reason=None,
index=0,
delta=ChoiceDelta(
content=mock_chunk_content,
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion.chunk",
usage=None,
),
)
for mock_chunk_content in chunked_content
]

# add the tool call chunk.
mock_chunk_events += [
MockChunkEvent(
type="chunk",
chunk=ChatCompletionChunk(
id="id",
choices=[
ChunkChoice(
finish_reason="tool_calls",
index=0,
delta=ChoiceDelta(
content=None,
role="assistant",
tool_calls=[
ChoiceDeltaToolCall(
id="1",
index=0,
type="function",
function=ChoiceDeltaToolCallFunction(
name="_pass_function",
arguments=json.dumps({"input": "happy"}),
),
)
],
),
)
],
created=0,
model=model,
object="chat.completion.chunk",
usage=None,
),
)
]

async def _mock_create_stream(*args: Any) -> AsyncGenerator[MockChunkEvent, None]:
async def _stream() -> AsyncGenerator[MockChunkEvent, None]:
for mock_chunk_event in mock_chunk_events:
await asyncio.sleep(0.1)
yield mock_chunk_event

return _stream()

# Mock the context manager __aenter__ method which returns the stream.
monkeypatch.setattr(BetaAsyncChatCompletionStreamManager, "__aenter__", _mock_create_stream)

model_client = OpenAIChatCompletionClient(
model=model,
api_key="",
response_format=AgentResponse, # type: ignore
)

# Test that the openai client was called with the correct response format.
chunks: List[str | CreateResult] = []
async for chunk in model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")]):
chunks.append(chunk)
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
assert isinstance(chunks[-1].content, list)
assert len(chunks[-1].content) == 1
assert chunks[-1].content[0] == FunctionCall(
id="1", name="_pass_function", arguments=json.dumps({"input": "happy"})
)
assert isinstance(chunks[-1].thought, str)
response = AgentResponse.model_validate(json.loads(chunks[-1].thought))
assert (
response.thoughts
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
)
assert response.response == "happy"


@pytest.mark.asyncio
async def test_r1_think_field(monkeypatch: pytest.MonkeyPatch) -> None:
async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
@@ -812,6 +1044,20 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]

class _MockChatCompletion:
def __init__(self, completions: List[ChatCompletion]):
self.completions = list(completions)
self.calls: List[Dict[str, Any]] = []

async def mock_create(
self, *args: Any, **kwargs: Any
) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
if kwargs.get("stream", False):
raise NotImplementedError("Streaming not supported in this test.")
self.calls.append(kwargs)
return self.completions.pop(0)

mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
pass_tool = FunctionTool(_pass_function, description="pass tool.")
@@ -1062,6 +1308,35 @@ async def test_openai_structured_output() -> None:
assert response.response in ["happy", "sad", "neutral"]


@pytest.mark.asyncio
async def test_openai_structured_output_with_streaming() -> None:
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY not found in environment variables")

class AgentResponse(BaseModel):
thoughts: str
response: Literal["happy", "sad", "neutral"]

model_client = OpenAIChatCompletionClient(
model="gpt-4o-mini",
api_key=api_key,
response_format=AgentResponse, # type: ignore
)

# Test that the openai client was called with the correct response format.
stream = model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")])
chunks: List[str | CreateResult] = []
async for chunk in stream:
chunks.append(chunk)
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
assert isinstance(chunks[-1].content, str)
response = AgentResponse.model_validate(json.loads(chunks[-1].content))
assert response.thoughts
assert response.response in ["happy", "sad", "neutral"]


@pytest.mark.asyncio
async def test_openai_structured_output_with_tool_calls() -> None:
api_key = os.getenv("OPENAI_API_KEY")
@@ -1090,6 +1365,7 @@ async def test_openai_structured_output_with_tool_calls() -> None:
UserMessage(content="I am happy.", source="user"),
],
tools=[tool],
extra_create_args={"tool_choice": "required"},
)
assert isinstance(response1.content, list)
assert len(response1.content) == 1
@@ -1114,6 +1390,71 @@ async def test_openai_structured_output_with_tool_calls() -> None:
assert parsed_response.response in ["happy", "sad", "neutral"]


@pytest.mark.asyncio
async def test_openai_structured_output_with_streaming_tool_calls() -> None:
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY not found in environment variables")

class AgentResponse(BaseModel):
thoughts: str
response: Literal["happy", "sad", "neutral"]

def sentiment_analysis(text: str) -> str:
"""Given a text, return the sentiment."""
return "happy" if "happy" in text else "sad" if "sad" in text else "neutral"

tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)

model_client = OpenAIChatCompletionClient(
model="gpt-4o-mini",
api_key=api_key,
response_format=AgentResponse, # type: ignore
)

chunks1: List[str | CreateResult] = []
stream1 = model_client.create_stream(
messages=[
SystemMessage(content="Analyze input text sentiment using the tool provided."),
UserMessage(content="I am happy.", source="user"),
],
tools=[tool],
extra_create_args={"tool_choice": "required"},
)
async for chunk in stream1:
chunks1.append(chunk)
assert len(chunks1) > 0
create_result1 = chunks1[-1]
assert isinstance(create_result1, CreateResult)
assert isinstance(create_result1.content, list)
assert len(create_result1.content) == 1
assert isinstance(create_result1.content[0], FunctionCall)
assert create_result1.content[0].name == "sentiment_analysis"
assert json.loads(create_result1.content[0].arguments) == {"text": "I am happy."}
assert create_result1.finish_reason == "function_calls"

stream2 = model_client.create_stream(
messages=[
SystemMessage(content="Analyze input text sentiment using the tool provided."),
UserMessage(content="I am happy.", source="user"),
AssistantMessage(content=create_result1.content, source="assistant"),
FunctionExecutionResultMessage(
content=[FunctionExecutionResult(content="happy", call_id=create_result1.content[0].id, is_error=False)]
),
],
)
chunks2: List[str | CreateResult] = []
async for chunk in stream2:
chunks2.append(chunk)
assert len(chunks2) > 0
create_result2 = chunks2[-1]
assert isinstance(create_result2, CreateResult)
assert isinstance(create_result2.content, str)
parsed_response = AgentResponse.model_validate(json.loads(create_result2.content))
assert parsed_response.thoughts
assert parsed_response.response in ["happy", "sad", "neutral"]


@pytest.mark.asyncio
async def test_gemini() -> None:
api_key = os.getenv("GEMINI_API_KEY")