diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
index ad2017ec4..26cf70abb 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
@@ -50,10 +50,11 @@ from autogen_core.models import (
     validate_model_info,
 )
 from autogen_core.tools import Tool, ToolSchema
-from openai import AsyncAzureOpenAI, AsyncOpenAI
+from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionAssistantMessageParam,
+    ChatCompletionChunk,
     ChatCompletionContentPartImageParam,
     ChatCompletionContentPartParam,
     ChatCompletionContentPartTextParam,
@@ -693,8 +694,23 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         create_args = self._create_args.copy()
         create_args.update(extra_create_args)
 
-        oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
-        oai_messages = [item for sublist in oai_messages_nested for item in sublist]
+        # Declare use_beta_client
+        use_beta_client: bool = False
+        response_format_value: Optional[Type[BaseModel]] = None
+
+        if "response_format" in create_args:
+            value = create_args["response_format"]
+            # If value is a Pydantic model class, use the beta client
+            if isinstance(value, type) and issubclass(value, BaseModel):
+                response_format_value = value
+                use_beta_client = True
+            else:
+                # response_format_value is not a Pydantic model class
+                use_beta_client = False
+                response_format_value = None
+
+        # Remove 'response_format' from create_args to prevent passing it twice
+        create_args_no_response_format = {k: v for k, v in create_args.items() if k != "response_format"}
 
         # TODO: allow custom handling.
         # For now we raise an error if images are present and vision is not supported
@@ -713,23 +729,39 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
             else:
                 create_args["response_format"] = {"type": "text"}
 
-        if len(tools) > 0:
-            converted_tools = convert_tools(tools)
-            stream_future = asyncio.ensure_future(
-                self._client.chat.completions.create(
-                    messages=oai_messages,
-                    stream=True,
-                    tools=converted_tools,
-                    **create_args,
-                )
+        oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
+        oai_messages = [item for sublist in oai_messages_nested for item in sublist]
+
+        if self.model_info["function_calling"] is False and len(tools) > 0:
+            raise ValueError("Model does not support function calling")
+
+        if max_consecutive_empty_chunk_tolerance != 0:
+            warnings.warn(
+                "The 'max_consecutive_empty_chunk_tolerance' parameter is deprecated and will be removed in the future releases. All of empty chunks will be skipped with a warning.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        tool_params = convert_tools(tools)
+
+        # Get the async generator of chunks.
+        if use_beta_client:
+            chunks = self._create_stream_chunks_beta_client(
+                tool_params=tool_params,
+                oai_messages=oai_messages,
+                response_format=response_format_value,
+                create_args_no_response_format=create_args_no_response_format,
+                cancellation_token=cancellation_token,
             )
         else:
-            stream_future = asyncio.ensure_future(
-                self._client.chat.completions.create(messages=oai_messages, stream=True, **create_args)
+            chunks = self._create_stream_chunks(
+                tool_params=tool_params,
+                oai_messages=oai_messages,
+                create_args=create_args,
+                cancellation_token=cancellation_token,
             )
-        if cancellation_token is not None:
-            cancellation_token.link_future(stream_future)
-        stream = await stream_future
+
+        # Prepare data to process streaming chunks.
         choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], ChunkChoice] = cast(ChunkChoice, None)
         chunk = None
         stop_reason = None
@@ -739,104 +771,102 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         completion_tokens = 0
         logprobs: Optional[List[ChatCompletionTokenLogprob]] = None
 
-        if max_consecutive_empty_chunk_tolerance != 0:
-            warnings.warn(
-                "The 'max_consecutive_empty_chunk_tolerance' parameter is deprecated and will be removed in the future releases. All of empty chunks will be skipped with a warning.",
-                DeprecationWarning,
-                stacklevel=2,
-            )
         empty_chunk_warning_has_been_issued: bool = False
         empty_chunk_warning_threshold: int = 10
         empty_chunk_count = 0
 
-        while True:
-            try:
-                chunk_future = asyncio.ensure_future(anext(stream))
-                if cancellation_token is not None:
-                    cancellation_token.link_future(chunk_future)
-                chunk = await chunk_future
+        # Process the stream of chunks.
+        async for chunk in chunks:
+            # Empty chunks have been observed when the endpoint is under heavy load.
+            # https://github.com/microsoft/autogen/issues/4213
+            if len(chunk.choices) == 0:
+                empty_chunk_count += 1
+                if not empty_chunk_warning_has_been_issued and empty_chunk_count >= empty_chunk_warning_threshold:
+                    empty_chunk_warning_has_been_issued = True
+                    warnings.warn(
+                        f"Received more than {empty_chunk_warning_threshold} consecutive empty chunks. Empty chunks are being ignored.",
+                        stacklevel=2,
+                    )
+                continue
+            else:
+                empty_chunk_count = 0
 
-                # Empty chunks has been observed when the endpoint is under heavy load.
-                # https://github.com/microsoft/autogen/issues/4213
-                if len(chunk.choices) == 0:
-                    empty_chunk_count += 1
-                    if not empty_chunk_warning_has_been_issued and empty_chunk_count >= empty_chunk_warning_threshold:
-                        empty_chunk_warning_has_been_issued = True
-                        warnings.warn(
-                            f"Received more than {empty_chunk_warning_threshold} consecutive empty chunks. Empty chunks are being ignored.",
-                            stacklevel=2,
-                        )
-                    continue
-                else:
-                    empty_chunk_count = 0
 
+            # To process usage chunks in streaming situations,
+            # add stream_options={"include_usage": True} in the initialization of OpenAIChatCompletionClient(...).
+            # However, the APIs differ:
+            # the OpenAI API usage chunk produces no choices, so we need to check if there is a choice,
+            # while the liteLLM API usage chunk does produce choices.
+            choice = (
+                chunk.choices[0]
+                if len(chunk.choices) > 0
+                else choice
+                if chunk.usage is not None and stop_reason is not None
+                else cast(ChunkChoice, None)
+            )
 
-                # to process usage chunk in streaming situations
-                # add stream_options={"include_usage": True} in the initialization of OpenAIChatCompletionClient(...)
-                # However the different api's
-                # OPENAI api usage chunk produces no choices so need to check if there is a choice
-                # liteLLM api usage chunk does produce choices
-                choice = (
-                    chunk.choices[0]
-                    if len(chunk.choices) > 0
-                    else choice
-                    if chunk.usage is not None and stop_reason is not None
-                    else cast(ChunkChoice, None)
-                )
+            # for liteLLM chunk usage, do the following hack keeping the previous chunk.stop_reason (if set).
+            # set the stop_reason for the usage chunk to the prior stop_reason
+            stop_reason = choice.finish_reason if chunk.usage is None and stop_reason is None else stop_reason
+            maybe_model = chunk.model
+            # First, try to get content
+            if choice.delta.content:
+                content_deltas.append(choice.delta.content)
+                if len(choice.delta.content) > 0:
+                    yield choice.delta.content
+                # NOTE: for OpenAI, tool_calls and content are mutually exclusive it seems, so we can skip the rest of the loop.
+                # However, this may not be the case for other APIs -- we should expect this may need to be updated.
+                continue
 
-                # for liteLLM chunk usage, do the following hack keeping the pervious chunk.stop_reason (if set).
-                # set the stop_reason for the usage chunk to the prior stop_reason
-                stop_reason = choice.finish_reason if chunk.usage is None and stop_reason is None else stop_reason
-                maybe_model = chunk.model
-                # First try get content
-                if choice.delta.content:
-                    content_deltas.append(choice.delta.content)
-                    if len(choice.delta.content) > 0:
-                        yield choice.delta.content
-                    # NOTE: for OpenAI, tool_calls and content are mutually exclusive it seems, so we can skip the rest of the loop.
-                    # However, this may not be the case for other APIs -- we should expect this may need to be updated.
-                    continue
+            # Otherwise, get tool calls
+            if choice.delta.tool_calls is not None:
+                for tool_call_chunk in choice.delta.tool_calls:
+                    idx = tool_call_chunk.index
+                    if idx not in full_tool_calls:
+                        # We ignore the type hint here because we want to fill in type when the delta provides it
+                        full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")
 
-                # Otherwise, get tool calls
-                if choice.delta.tool_calls is not None:
-                    for tool_call_chunk in choice.delta.tool_calls:
-                        idx = tool_call_chunk.index
-                        if idx not in full_tool_calls:
-                            # We ignore the type hint here because we want to fill in type when the delta provides it
-                            full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")
+                    if tool_call_chunk.id is not None:
+                        full_tool_calls[idx].id += tool_call_chunk.id
 
-                        if tool_call_chunk.id is not None:
-                            full_tool_calls[idx].id += tool_call_chunk.id
+                    if tool_call_chunk.function is not None:
+                        if tool_call_chunk.function.name is not None:
+                            full_tool_calls[idx].name += tool_call_chunk.function.name
+                        if tool_call_chunk.function.arguments is not None:
+                            full_tool_calls[idx].arguments += tool_call_chunk.function.arguments
+            if choice.logprobs and choice.logprobs.content:
+                logprobs = [
+                    ChatCompletionTokenLogprob(
+                        token=x.token,
+                        logprob=x.logprob,
+                        top_logprobs=[TopLogprob(logprob=y.logprob, bytes=y.bytes) for y in x.top_logprobs],
+                        bytes=x.bytes,
+                    )
+                    for x in choice.logprobs.content
+                ]
 
-                        if tool_call_chunk.function is not None:
-                            if tool_call_chunk.function.name is not None:
-                                full_tool_calls[idx].name += tool_call_chunk.function.name
-                            if tool_call_chunk.function.arguments is not None:
-                                full_tool_calls[idx].arguments += tool_call_chunk.function.arguments
-                if choice.logprobs and choice.logprobs.content:
-                    logprobs = [
-                        ChatCompletionTokenLogprob(
-                            token=x.token,
-                            logprob=x.logprob,
-                            top_logprobs=[TopLogprob(logprob=y.logprob, bytes=y.bytes) for y in x.top_logprobs],
-                            bytes=x.bytes,
-                        )
-                        for x in choice.logprobs.content
-                    ]
-
-            except StopAsyncIteration:
-                break
-
-        model = maybe_model or create_args["model"]
-        model = model.replace("gpt-35", "gpt-3.5")  # hack for Azure API
-
-        if chunk and chunk.usage:
-            prompt_tokens = chunk.usage.prompt_tokens
-        else:
-            prompt_tokens = 0
+        # Finalize the CreateResult.
 
+        # TODO: can we remove this?
         if stop_reason == "function_call":
             raise ValueError("Function calls are not supported in this context")
 
+        # We need to get the model from the last chunk, if available.
+        model = maybe_model or create_args["model"]
+        model = model.replace("gpt-35", "gpt-3.5")  # hack for Azure API
+
+        # Because the usage chunk is not guaranteed to be the last chunk, we need to check if it is available.
+        if chunk and chunk.usage:
+            prompt_tokens = chunk.usage.prompt_tokens
+            completion_tokens = chunk.usage.completion_tokens
+        else:
+            prompt_tokens = 0
+            completion_tokens = 0
+        usage = RequestUsage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+        )
+
+        # Detect whether it is a function call or just text.
         content: Union[str, List[FunctionCall]]
         thought: str | None = None
         if full_tool_calls:
@@ -852,19 +882,11 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
             warnings.warn("No text content or tool calls are available. Model returned empty result.", stacklevel=2)
             content = ""
 
-        if chunk and chunk.usage:
-            completion_tokens = chunk.usage.completion_tokens
-        else:
-            completion_tokens = 0
-
-        usage = RequestUsage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-        )
-
+        # Parse R1 content if needed.
         if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
             thought, content = parse_r1_content(content)
 
+        # Create the result.
         result = CreateResult(
             finish_reason=normalize_stop_reason(stop_reason),
             content=content,
@@ -874,11 +896,73 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
             thought=thought,
         )
 
+        # Update the total usage.
         self._total_usage = _add_usage(self._total_usage, usage)
         self._actual_usage = _add_usage(self._actual_usage, usage)
 
+        # Yield the CreateResult.
         yield result
 
+    async def _create_stream_chunks(
+        self,
+        tool_params: List[ChatCompletionToolParam],
+        oai_messages: List[ChatCompletionMessageParam],
+        create_args: Dict[str, Any],
+        cancellation_token: Optional[CancellationToken],
+    ) -> AsyncGenerator[ChatCompletionChunk, None]:
+        stream_future = asyncio.ensure_future(
+            self._client.chat.completions.create(
+                messages=oai_messages,
+                stream=True,
+                tools=tool_params if len(tool_params) > 0 else NOT_GIVEN,
+                **create_args,
+            )
+        )
+        if cancellation_token is not None:
+            cancellation_token.link_future(stream_future)
+        stream = await stream_future
+        while True:
+            try:
+                chunk_future = asyncio.ensure_future(anext(stream))
+                if cancellation_token is not None:
+                    cancellation_token.link_future(chunk_future)
+                chunk = await chunk_future
+                yield chunk
+            except StopAsyncIteration:
+                break
+
+    async def _create_stream_chunks_beta_client(
+        self,
+        tool_params: List[ChatCompletionToolParam],
+        oai_messages: List[ChatCompletionMessageParam],
+        create_args_no_response_format: Dict[str, Any],
+        response_format: Optional[Type[BaseModel]],
+        cancellation_token: Optional[CancellationToken],
+    ) -> AsyncGenerator[ChatCompletionChunk, None]:
+        async with self._client.beta.chat.completions.stream(
+            messages=oai_messages,
+            tools=tool_params if len(tool_params) > 0 else NOT_GIVEN,
+            response_format=response_format if response_format is not None else NOT_GIVEN,
+            **create_args_no_response_format,
+        ) as stream:
+            while True:
+                try:
+                    event_future = asyncio.ensure_future(anext(stream))
+                    if cancellation_token is not None:
+                        cancellation_token.link_future(event_future)
+                    event = await event_future
+
+                    if event.type == "chunk":
+                        chunk = event.chunk
+                        yield chunk
+                    # We don't handle other event types from the beta client stream,
+                    # as they are auxiliary to the chunk event.
+                    # See: https://github.com/openai/openai-python/blob/main/helpers.md#chat-completions-events.
+                    # Once the beta client is stable, we can move all the logic to the beta client.
+                    # Then we can consider handling other event types, which may simplify the code overall.
+                except StopAsyncIteration:
+                    break
+
     def actual_usage(self) -> RequestUsage:
         return self._actual_usage
 
diff --git a/python/packages/autogen-ext/tests/models/test_openai_model_client.py b/python/packages/autogen-ext/tests/models/test_openai_model_client.py
index 45655e653..c9c237df4 100644
--- a/python/packages/autogen-ext/tests/models/test_openai_model_client.py
+++ b/python/packages/autogen-ext/tests/models/test_openai_model_client.py
@@ -1,7 +1,7 @@
 import asyncio
 import json
 import os
-from typing import Annotated, Any, AsyncGenerator, Dict, Generic, List, Literal, Tuple, TypeVar
+from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Tuple, TypeVar
 from unittest.mock import MagicMock
 
 import httpx
@@ -23,7 +23,14 @@ from autogen_core.tools import BaseTool, FunctionTool
 from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
 from autogen_ext.models.openai._model_info import resolve_model
 from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools, to_oai_type
-from openai.resources.beta.chat.completions import AsyncCompletions as BetaAsyncCompletions
+from openai.resources.beta.chat.completions import (  # type: ignore
+    AsyncChatCompletionStreamManager as BetaAsyncChatCompletionStreamManager,  # type: ignore
+)
+
+# type: ignore
+from openai.resources.beta.chat.completions import (
+    AsyncCompletions as BetaAsyncCompletions,
+)
 from openai.resources.chat.completions import AsyncCompletions
 from openai.types.chat.chat_completion import ChatCompletion, Choice
 from openai.types.chat.chat_completion_chunk import (
@@ -32,54 +39,22 @@ from openai.types.chat.chat_completion_chunk import (
     ChoiceDeltaToolCall,
     ChoiceDeltaToolCallFunction,
 )
-from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
+from openai.types.chat.chat_completion_chunk import (
+    Choice as ChunkChoice,
+)
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from openai.types.chat.chat_completion_message_tool_call import (
     ChatCompletionMessageToolCall,
     Function,
 )
 from openai.types.chat.parsed_chat_completion import ParsedChatCompletion, ParsedChatCompletionMessage, ParsedChoice
+from openai.types.chat.parsed_function_tool_call import ParsedFunction, ParsedFunctionToolCall
 from openai.types.completion_usage import CompletionUsage
 from pydantic import BaseModel, Field
 
-
-class _MockChatCompletion:
-    def __init__(self, chat_completions: List[ChatCompletion]) -> None:
-        self._saved_chat_completions = chat_completions
-        self.curr_index = 0
-        self.calls: List[Dict[str, Any]] = []
-
-    async def mock_create(
-        self, *args: Any, **kwargs: Any
-    ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
-        self.calls.append(kwargs)  # Save the call
-        await asyncio.sleep(0.1)
-        completion = self._saved_chat_completions[self.curr_index]
-        self.curr_index += 1
-        return completion
-
-
 ResponseFormatT = TypeVar("ResponseFormatT", bound=BaseModel)
 
 
-class _MockBetaChatCompletion(Generic[ResponseFormatT]):
-    def __init__(self, chat_completions: List[ParsedChatCompletion[ResponseFormatT]]) -> None:
-        self._saved_chat_completions = chat_completions
-        self.curr_index = 0
-        self.calls: List[Dict[str, Any]] = []
-
-    async def mock_parse(
-        self,
-        *args: Any,
-        **kwargs: Any,
-    ) -> ParsedChatCompletion[ResponseFormatT]:
-        self.calls.append(kwargs)  # Save the call
-        await asyncio.sleep(0.1)
-        completion = self._saved_chat_completions[self.curr_index]
-        self.curr_index += 1
-        return completion
-
-
 def _pass_function(input: str) -> str:
     return "pass"
 
@@ -106,6 +81,11 @@ class MockChunkDefinition(BaseModel):
     usage: CompletionUsage | None
 
 
+class MockChunkEvent(BaseModel):
+    type: Literal["chunk"]
+    chunk: ChatCompletionChunk
+
+
 async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
     model = resolve_model(kwargs.get("model", "gpt-4o"))
     mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"]
@@ -443,8 +423,9 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
         response: Literal["happy", "sad", "neutral"]
 
     model = "gpt-4o-2024-11-20"
-    chat_completions: List[ParsedChatCompletion[AgentResponse]] = [
-        ParsedChatCompletion(
+
+    async def _mock_parse(*args: Any, **kwargs: Any) -> ParsedChatCompletion[AgentResponse]:
+        return ParsedChatCompletion(
             id="id1",
             choices=[
                 ParsedChoice(
@@ -465,10 +446,9 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
             model=model,
             object="chat.completion",
             usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
-        ),
-    ]
-    mock = _MockBetaChatCompletion(chat_completions)
-    monkeypatch.setattr(BetaAsyncCompletions, "parse", mock.mock_parse)
+        )
+
+    monkeypatch.setattr(BetaAsyncCompletions, "parse", _mock_parse)
 
     model_client = OpenAIChatCompletionClient(
         model=model,
@@ -487,6 +467,258 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
     assert response.response == "happy"
 
 
+@pytest.mark.asyncio
+async def test_structured_output_with_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
+    class AgentResponse(BaseModel):
+        thoughts: str
+        response: Literal["happy", "sad", "neutral"]
+
+    model = "gpt-4o-2024-11-20"
+
+    async def _mock_parse(*args: Any, **kwargs: Any) -> ParsedChatCompletion[AgentResponse]:
+        return ParsedChatCompletion(
+            id="id1",
+            choices=[
+                ParsedChoice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ParsedChatCompletionMessage(
+                        content=json.dumps(
+                            {
+                                "thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
+                                "response": "happy",
+                            }
+                        ),
+                        role="assistant",
+                        tool_calls=[
+                            ParsedFunctionToolCall(
+                                id="1",
+                                type="function",
+                                function=ParsedFunction(
+                                    name="_pass_function",
+                                    arguments=json.dumps({"input": "happy"}),
+                                ),
+                            )
+                        ],
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        )
+
+    monkeypatch.setattr(BetaAsyncCompletions, "parse", _mock_parse)
+
+    model_client = OpenAIChatCompletionClient(
+        model=model,
+        api_key="",
+        response_format=AgentResponse,  # type: ignore
+    )
+
+    # Test that the openai client was called with the correct response format.
+    create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
+    assert isinstance(create_result.content, list)
+    assert len(create_result.content) == 1
+    assert create_result.content[0] == FunctionCall(
+        id="1", name="_pass_function", arguments=json.dumps({"input": "happy"})
+    )
+    assert isinstance(create_result.thought, str)
+    response = AgentResponse.model_validate(json.loads(create_result.thought))
+    assert (
+        response.thoughts
+        == "The user explicitly states that they are happy without any indication of sadness or neutrality."
+    )
+    assert response.response == "happy"
+
+
+@pytest.mark.asyncio
+async def test_structured_output_with_streaming(monkeypatch: pytest.MonkeyPatch) -> None:
+    class AgentResponse(BaseModel):
+        thoughts: str
+        response: Literal["happy", "sad", "neutral"]
+
+    raw_content = json.dumps(
+        {
+            "thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
+            "response": "happy",
+        }
+    )
+    chunked_content = [raw_content[i : i + 5] for i in range(0, len(raw_content), 5)]
+    assert "".join(chunked_content) == raw_content
+
+    model = "gpt-4o-2024-11-20"
+    mock_chunk_events = [
+        MockChunkEvent(
+            type="chunk",
+            chunk=ChatCompletionChunk(
+                id="id",
+                choices=[
+                    ChunkChoice(
+                        finish_reason=None,
+                        index=0,
+                        delta=ChoiceDelta(
+                            content=mock_chunk_content,
+                            role="assistant",
+                        ),
+                    )
+                ],
+                created=0,
+                model=model,
+                object="chat.completion.chunk",
+                usage=None,
+            ),
+        )
+        for mock_chunk_content in chunked_content
+    ]
+
+    async def _mock_create_stream(*args: Any) -> AsyncGenerator[MockChunkEvent, None]:
+        async def _stream() -> AsyncGenerator[MockChunkEvent, None]:
+            for mock_chunk_event in mock_chunk_events:
+                await asyncio.sleep(0.1)
+                yield mock_chunk_event
+
+        return _stream()
+
+    # Mock the context manager __aenter__ method which returns the stream.
+    monkeypatch.setattr(BetaAsyncChatCompletionStreamManager, "__aenter__", _mock_create_stream)
+
+    model_client = OpenAIChatCompletionClient(
+        model=model,
+        api_key="",
+        response_format=AgentResponse,  # type: ignore
+    )
+
+    # Test that the openai client was called with the correct response format.
+    chunks: List[str | CreateResult] = []
+    async for chunk in model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")]):
+        chunks.append(chunk)
+    assert len(chunks) > 0
+    assert isinstance(chunks[-1], CreateResult)
+    assert isinstance(chunks[-1].content, str)
+    response = AgentResponse.model_validate(json.loads(chunks[-1].content))
+    assert (
+        response.thoughts
+        == "The user explicitly states that they are happy without any indication of sadness or neutrality."
+    )
+    assert response.response == "happy"
+
+
+@pytest.mark.asyncio
+async def test_structured_output_with_streaming_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
+    class AgentResponse(BaseModel):
+        thoughts: str
+        response: Literal["happy", "sad", "neutral"]
+
+    raw_content = json.dumps(
+        {
+            "thoughts": "The user explicitly states that they are happy without any indication of sadness or neutrality.",
+            "response": "happy",
+        }
+    )
+    chunked_content = [raw_content[i : i + 5] for i in range(0, len(raw_content), 5)]
+    assert "".join(chunked_content) == raw_content
+
+    model = "gpt-4o-2024-11-20"
+
+    # generate the list of mock chunk content
+    mock_chunk_events = [
+        MockChunkEvent(
+            type="chunk",
+            chunk=ChatCompletionChunk(
+                id="id",
+                choices=[
+                    ChunkChoice(
+                        finish_reason=None,
+                        index=0,
+                        delta=ChoiceDelta(
+                            content=mock_chunk_content,
+                            role="assistant",
+                        ),
+                    )
+                ],
+                created=0,
+                model=model,
+                object="chat.completion.chunk",
+                usage=None,
+            ),
+        )
+        for mock_chunk_content in chunked_content
+    ]
+
+    # add the tool call chunk.
+    mock_chunk_events += [
+        MockChunkEvent(
+            type="chunk",
+            chunk=ChatCompletionChunk(
+                id="id",
+                choices=[
+                    ChunkChoice(
+                        finish_reason="tool_calls",
+                        index=0,
+                        delta=ChoiceDelta(
+                            content=None,
+                            role="assistant",
+                            tool_calls=[
+                                ChoiceDeltaToolCall(
+                                    id="1",
+                                    index=0,
+                                    type="function",
+                                    function=ChoiceDeltaToolCallFunction(
+                                        name="_pass_function",
+                                        arguments=json.dumps({"input": "happy"}),
+                                    ),
+                                )
+                            ],
+                        ),
+                    )
+                ],
+                created=0,
+                model=model,
+                object="chat.completion.chunk",
+                usage=None,
+            ),
+        )
+    ]
+
+    async def _mock_create_stream(*args: Any) -> AsyncGenerator[MockChunkEvent, None]:
+        async def _stream() -> AsyncGenerator[MockChunkEvent, None]:
+            for mock_chunk_event in mock_chunk_events:
+                await asyncio.sleep(0.1)
+                yield mock_chunk_event
+
+        return _stream()
+
+    # Mock the context manager __aenter__ method which returns the stream.
+    monkeypatch.setattr(BetaAsyncChatCompletionStreamManager, "__aenter__", _mock_create_stream)
+
+    model_client = OpenAIChatCompletionClient(
+        model=model,
+        api_key="",
+        response_format=AgentResponse,  # type: ignore
+    )
+
+    # Test that the openai client was called with the correct response format.
+    chunks: List[str | CreateResult] = []
+    async for chunk in model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")]):
+        chunks.append(chunk)
+    assert len(chunks) > 0
+    assert isinstance(chunks[-1], CreateResult)
+    assert isinstance(chunks[-1].content, list)
+    assert len(chunks[-1].content) == 1
+    assert chunks[-1].content[0] == FunctionCall(
+        id="1", name="_pass_function", arguments=json.dumps({"input": "happy"})
+    )
+    assert isinstance(chunks[-1].thought, str)
+    response = AgentResponse.model_validate(json.loads(chunks[-1].thought))
+    assert (
+        response.thoughts
+        == "The user explicitly states that they are happy without any indication of sadness or neutrality."
+    )
+    assert response.response == "happy"
+
+
 @pytest.mark.asyncio
 async def test_r1_think_field(monkeypatch: pytest.MonkeyPatch) -> None:
     async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
@@ -812,6 +1044,20 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
             usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
         ),
     ]
+
+    class _MockChatCompletion:
+        def __init__(self, completions: List[ChatCompletion]):
+            self.completions = list(completions)
+            self.calls: List[Dict[str, Any]] = []
+
+        async def mock_create(
+            self, *args: Any, **kwargs: Any
+        ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
+            if kwargs.get("stream", False):
+                raise NotImplementedError("Streaming not supported in this test.")
+            self.calls.append(kwargs)
+            return self.completions.pop(0)
+
     mock = _MockChatCompletion(chat_completions)
     monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
     pass_tool = FunctionTool(_pass_function, description="pass tool.")
@@ -1062,6 +1308,35 @@ async def test_openai_structured_output() -> None:
     assert response.response in ["happy", "sad", "neutral"]
 
 
+@pytest.mark.asyncio
+async def test_openai_structured_output_with_streaming() -> None:
+    api_key = os.getenv("OPENAI_API_KEY")
+    if not api_key:
+        pytest.skip("OPENAI_API_KEY not found in environment variables")
+
+    class AgentResponse(BaseModel):
+        thoughts: str
+        response: Literal["happy", "sad", "neutral"]
+
+    model_client = OpenAIChatCompletionClient(
+        model="gpt-4o-mini",
+        api_key=api_key,
+        response_format=AgentResponse,  # type: ignore
+    )
+
+    # Test that the openai client was called with the correct response format.
+    stream = model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")])
+    chunks: List[str | CreateResult] = []
+    async for chunk in stream:
+        chunks.append(chunk)
+    assert len(chunks) > 0
+    assert isinstance(chunks[-1], CreateResult)
+    assert isinstance(chunks[-1].content, str)
+    response = AgentResponse.model_validate(json.loads(chunks[-1].content))
+    assert response.thoughts
+    assert response.response in ["happy", "sad", "neutral"]
+
+
 @pytest.mark.asyncio
 async def test_openai_structured_output_with_tool_calls() -> None:
     api_key = os.getenv("OPENAI_API_KEY")
@@ -1090,6 +1365,7 @@ async def test_openai_structured_output_with_tool_calls() -> None:
             UserMessage(content="I am happy.", source="user"),
         ],
         tools=[tool],
+        extra_create_args={"tool_choice": "required"},
     )
     assert isinstance(response1.content, list)
     assert len(response1.content) == 1
@@ -1114,6 +1390,71 @@ async def test_openai_structured_output_with_tool_calls() -> None:
     assert parsed_response.response in ["happy", "sad", "neutral"]
 
 
+@pytest.mark.asyncio
+async def test_openai_structured_output_with_streaming_tool_calls() -> None:
+    api_key = os.getenv("OPENAI_API_KEY")
+    if not api_key:
+        pytest.skip("OPENAI_API_KEY not found in environment variables")
+
+    class AgentResponse(BaseModel):
+        thoughts: str
+        response: Literal["happy", "sad", "neutral"]
+
+    def sentiment_analysis(text: str) -> str:
+        """Given a text, return the sentiment."""
+        return "happy" if "happy" in text else "sad" if "sad" in text else "neutral"
+
+    tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)
+
+    model_client = OpenAIChatCompletionClient(
+        model="gpt-4o-mini",
+        api_key=api_key,
+        response_format=AgentResponse,  # type: ignore
+    )
+
+    chunks1: List[str | CreateResult] = []
+    stream1 = model_client.create_stream(
+        messages=[
+            SystemMessage(content="Analyze input text sentiment using the tool provided."),
+            UserMessage(content="I am happy.", source="user"),
+        ],
+        tools=[tool],
+        extra_create_args={"tool_choice": "required"},
+    )
+    async for chunk in stream1:
+        chunks1.append(chunk)
+    assert len(chunks1) > 0
+    create_result1 = chunks1[-1]
+    assert isinstance(create_result1, CreateResult)
+    assert isinstance(create_result1.content, list)
+    assert len(create_result1.content) == 1
+    assert isinstance(create_result1.content[0], FunctionCall)
+    assert create_result1.content[0].name == "sentiment_analysis"
+    assert json.loads(create_result1.content[0].arguments) == {"text": "I am happy."}
+    assert create_result1.finish_reason == "function_calls"
+
+    stream2 = model_client.create_stream(
+        messages=[
+            SystemMessage(content="Analyze input text sentiment using the tool provided."),
+            UserMessage(content="I am happy.", source="user"),
+            AssistantMessage(content=create_result1.content, source="assistant"),
+            FunctionExecutionResultMessage(
+                content=[FunctionExecutionResult(content="happy", call_id=create_result1.content[0].id, is_error=False)]
+            ),
+        ],
+    )
+    chunks2: List[str | CreateResult] = []
+    async for chunk in stream2:
+        chunks2.append(chunk)
+    assert len(chunks2) > 0
+    create_result2 = chunks2[-1]
+    assert isinstance(create_result2, CreateResult)
+    assert isinstance(create_result2.content, str)
+    parsed_response = AgentResponse.model_validate(json.loads(create_result2.content))
+    assert parsed_response.thoughts
+    assert parsed_response.response in ["happy", "sad", "neutral"]
+
+
 @pytest.mark.asyncio
 async def test_gemini() -> None:
     api_key = os.getenv("GEMINI_API_KEY")