Mirror of https://github.com/microsoft/autogen.git, synced 2025-12-27 06:59:03 +00:00
Support json schema for response format type in OpenAIChatCompletionClient (#5988)
Resolves #5982. This PR adds support for `json_schema` as a `response_format` type in `OpenAIChatCompletionClient`. This is necessary because it allows the client to be serialized along with the schema: if a user passes `response_format=SomeBaseModel`, the client cannot be serialized.

Usage:

```python
# Structured output response, with a pre-defined JSON schema.
OpenAIChatCompletionClient(
    ...,
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "name of the schema, must be an identifier.",
            "description": "description for the model.",
            # You can convert a Pydantic (v2) model to JSON schema
            # using the `model_json_schema()` method.
            "schema": "<the JSON schema itself>",
            # Whether to enable strict schema adherence when
            # generating the output. If set to true, the model will
            # always follow the exact schema defined in the
            # `schema` field. Only a subset of JSON Schema is
            # supported when `strict` is `true`.
            # To learn more, read
            # https://platform.openai.com/docs/guides/structured-outputs.
            "strict": False,  # or True
        },
    },
)
```
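To make the serialization point concrete, here is a minimal sketch of the round-trip this enables (the `AgentResponse` model is illustrative; the `dump_component`/`load_component` usage mirrors the test added below):

```python
from pydantic import BaseModel

from autogen_ext.models.openai import OpenAIChatCompletionClient


class AgentResponse(BaseModel):  # illustrative schema
    thoughts: str
    response: str


client = OpenAIChatCompletionClient(
    model="gpt-4o-mini",
    api_key="",  # an empty key is fine for a serialization-only demo, as in the tests below
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "AgentResponse",
            "description": "Structured agent reply.",
            "schema": AgentResponse.model_json_schema(),  # a plain, JSON-serializable dict
            "strict": False,
        },
    },
)

# The schema is plain JSON data, so the full client config round-trips;
# a client built with response_format=AgentResponse could not be serialized.
config = client.dump_component()
restored = OpenAIChatCompletionClient.load_component(config)
```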
This commit is contained in:
parent 09d8d344a2
commit a8cef327f1
```diff
@@ -154,8 +154,11 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
         Args:
             messages (Sequence[LLMMessage]): The messages to send to the model.
             tools (Sequence[Tool | ToolSchema], optional): The tools to use with the model. Defaults to [].
-            json_output (Optional[bool | type[BaseModel]], optional): Whether to use JSON mode, structured output, or neither. Defaults to None. If set to a type, it will be used as the output type
-                for structured output. If set to a boolean, it will be used to determine whether to use JSON mode or not.
+            json_output (Optional[bool | type[BaseModel]], optional): Whether to use JSON mode, structured output, or neither.
+                Defaults to None. If set to a `Pydantic BaseModel <https://docs.pydantic.dev/latest/usage/models/#model>`_ type,
+                it will be used as the output type for structured output.
+                If set to a boolean, it will be used to determine whether to use JSON mode or not.
+                If set to `True`, make sure to instruct the model to produce JSON output in the instruction or prompt.
             extra_create_args (Mapping[str, Any], optional): Extra arguments to pass to the underlying client. Defaults to {}.
             cancellation_token (Optional[CancellationToken], optional): A token for cancellation. Defaults to None.
 
```
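A short sketch of the two `json_output` modes the revised docstring describes (the client setup and the `AgentResponse` model are illustrative, not part of this diff; an `OPENAI_API_KEY` environment variable is assumed):

```python
import asyncio

from pydantic import BaseModel

from autogen_core.models import UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


class AgentResponse(BaseModel):
    thoughts: str
    response: str


async def main() -> None:
    client = OpenAIChatCompletionClient(model="gpt-4o-mini")

    # Structured output: pass a Pydantic BaseModel type; the result content
    # is a JSON string conforming to the model's schema.
    result = await client.create(
        messages=[UserMessage(content="I am happy.", source="user")],
        json_output=AgentResponse,
    )
    print(result.content)

    # JSON mode: pass True, and instruct the model to produce JSON in the prompt.
    result = await client.create(
        messages=[UserMessage(content="Reply in JSON: how are you?", source="user")],
        json_output=True,
    )
    print(result.content)


asyncio.run(main())
```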
```diff
@@ -181,8 +184,11 @@ class ChatCompletionClient(ComponentBase[BaseModel], ABC):
         Args:
             messages (Sequence[LLMMessage]): The messages to send to the model.
             tools (Sequence[Tool | ToolSchema], optional): The tools to use with the model. Defaults to [].
-            json_output (Optional[bool | type[BaseModel]], optional): Whether to use JSON mode, structured output, or neither. Defaults to None. If set to a type, it will be used as the output type
-                for structured output. If set to a boolean, it will be used to determine whether to use JSON mode or not.
+            json_output (Optional[bool | type[BaseModel]], optional): Whether to use JSON mode, structured output, or neither.
+                Defaults to None. If set to a `Pydantic BaseModel <https://docs.pydantic.dev/latest/usage/models/#model>`_ type,
+                it will be used as the output type for structured output.
+                If set to a boolean, it will be used to determine whether to use JSON mode or not.
+                If set to `True`, make sure to instruct the model to produce JSON output in the instruction or prompt.
             extra_create_args (Mapping[str, Any], optional): Extra arguments to pass to the underlying client. Defaults to {}.
             cancellation_token (Optional[CancellationToken], optional): A token for cancellation. Defaults to None.
 
```
```diff
@@ -444,14 +444,16 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
                 if self.model_info["structured_output"] is False:
                     raise ValueError("Model does not support structured output.")
                 warnings.warn(
-                    "Using response_format to specify structured output type will be deprecated. "
-                    "Use json_output instead.",
+                    "Using response_format to specify the BaseModel for structured output type will be deprecated. "
+                    "Use json_output in create and create_stream instead.",
                     DeprecationWarning,
                     stacklevel=2,
                 )
                 response_format_value = value
                 # Remove response_format from create_args to prevent passing it twice.
                 del create_args["response_format"]
+            # In all other cases when response_format is set to something else, we will
+            # use the regular client.
 
         if json_output is not None:
             if self.model_info["json_output"] is False and json_output is True:
```
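A sketch of the behavior this hunk encodes, assuming the models and imports used elsewhere in this diff: a `BaseModel` passed as `response_format` now warns on `create()` and is removed from `create_args`, while a `json_schema` dict is forwarded to the regular client unchanged:

```python
import pytest
from pydantic import BaseModel

from autogen_core.models import UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


class AgentResponse(BaseModel):
    thoughts: str
    response: str


async def demo() -> None:
    # Deprecated path: a BaseModel response_format triggers the warning on create().
    client = OpenAIChatCompletionClient(model="gpt-4o-mini", response_format=AgentResponse)  # type: ignore
    with pytest.warns(DeprecationWarning, match="Use json_output in create and create_stream instead."):
        await client.create(messages=[UserMessage(content="I am happy.", source="user")])

    # Supported path: a json_schema dict passes through to the client without a warning.
    client = OpenAIChatCompletionClient(
        model="gpt-4o-mini",
        response_format={
            "type": "json_schema",
            "json_schema": {"name": "AgentResponse", "schema": AgentResponse.model_json_schema()},
        },
    )
    await client.create(messages=[UserMessage(content="I am happy.", source="user")])
```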
```diff
@@ -669,19 +671,9 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         cancellation_token: Optional[CancellationToken] = None,
         max_consecutive_empty_chunk_tolerance: int = 0,
     ) -> AsyncGenerator[Union[str, CreateResult], None]:
-        """
-        Creates an AsyncGenerator that will yield a stream of chat completions based on the provided messages and tools.
+        """Create a stream of string chunks from the model ending with a :class:`~autogen_core.models.CreateResult`.
 
-        Args:
-            messages (Sequence[LLMMessage]): A sequence of messages to be processed.
-            tools (Sequence[Tool | ToolSchema], optional): A sequence of tools to be used in the completion. Defaults to `[]`.
-            json_output (Optional[bool | type[BaseModel]], optional): If True, the output will be in JSON format. If a Pydantic model class, the output will be in that format. Defaults to None.
-            extra_create_args (Mapping[str, Any], optional): Additional arguments for the creation process. Default to `{}`.
-            cancellation_token (Optional[CancellationToken], optional): A token to cancel the operation. Defaults to None.
-            max_consecutive_empty_chunk_tolerance (int): [Deprecated] This parameter is deprecated, empty chunks will be skipped.
-
-        Yields:
-            AsyncGenerator[Union[str, CreateResult], None]: A generator yielding the completion results as they are produced.
+        Extends :meth:`autogen_core.models.ChatCompletionClient.create_stream` to support OpenAI API.
 
         In streaming, the default behaviour is not to return token usage counts.
         See: `OpenAI API reference for possible args <https://platform.openai.com/docs/api-reference/chat/create>`_.
```
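A minimal streaming sketch of this contract, mirroring the streaming test later in this diff (client setup and `AgentResponse` are illustrative): string chunks arrive first, and the final yielded item is a `CreateResult` whose `content` is the JSON string.

```python
import json
from typing import List

from pydantic import BaseModel

from autogen_core.models import CreateResult, UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


class AgentResponse(BaseModel):
    thoughts: str
    response: str


async def stream_demo() -> None:
    client = OpenAIChatCompletionClient(model="gpt-4o-mini")
    chunks: List[str | CreateResult] = []
    async for chunk in client.create_stream(
        messages=[UserMessage(content="I am happy.", source="user")],
        json_output=AgentResponse,
    ):
        chunks.append(chunk)

    final = chunks[-1]
    assert isinstance(final, CreateResult)
    assert isinstance(final.content, str)
    # The final content is a JSON string conforming to AgentResponse's schema.
    parsed = AgentResponse.model_validate(json.loads(final.content))
    print(parsed)
```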
```diff
@@ -700,6 +692,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
             - `frequency_penalty` (float): A value between -2.0 and 2.0 that penalizes new tokens based on their existing frequency in the text so far, decreasing the likelihood of repeated phrases.
             - `presence_penalty` (float): A value between -2.0 and 2.0 that penalizes new tokens based on whether they appear in the text so far, encouraging the model to talk about new topics.
         """
 
         create_params = self._process_create_args(
             messages,
             tools,
```
```diff
@@ -1108,7 +1101,47 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
         max_tokens (optional, int):
         n (optional, int):
         presence_penalty (optional, float):
-        response_format (optional, literal["json_object", "text"] | pydantic.BaseModel):
+        response_format (optional, Dict[str, Any]): the format of the response. Possible options are:
+
+            .. code-block:: text
+
+                # Text response, this is the default.
+                {"type": "text"}
+
+            .. code-block:: text
+
+                # JSON response, make sure to instruct the model to return JSON.
+                {"type": "json_object"}
+
+            .. code-block:: text
+
+                # Structured output response, with a pre-defined JSON schema.
+                {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "name of the schema, must be an identifier.",
+                        "description": "description for the model.",
+                        # You can convert a Pydantic (v2) model to JSON schema
+                        # using the `model_json_schema()` method.
+                        "schema": "<the JSON schema itself>",
+                        # Whether to enable strict schema adherence when
+                        # generating the output. If set to true, the model will
+                        # always follow the exact schema defined in the
+                        # `schema` field. Only a subset of JSON Schema is
+                        # supported when `strict` is `true`.
+                        # To learn more, read
+                        # https://platform.openai.com/docs/guides/structured-outputs.
+                        "strict": False,  # or True
+                    },
+                }
+
+            It is recommended to use the `json_output` parameter in
+            :meth:`~autogen_ext.models.openai.BaseOpenAIChatCompletionClient.create` or
+            :meth:`~autogen_ext.models.openai.BaseOpenAIChatCompletionClient.create_stream`
+            methods instead of `response_format` for structured output.
+            The `json_output` parameter is more flexible and allows you to
+            specify a Pydantic model class directly.
+
         seed (optional, int):
         stop (optional, str | List[str]):
         temperature (optional, float):
```
```diff
@@ -1240,10 +1273,7 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
 
             async def main() -> None:
                 # Create an OpenAIChatCompletionClient instance.
-                model_client = OpenAIChatCompletionClient(
-                    model="gpt-4o-mini",
-                    response_format=AgentResponse,  # type: ignore
-                )
+                model_client = OpenAIChatCompletionClient(model="gpt-4o-mini")
 
                 # Generate a response using the tool.
                 response1 = await model_client.create(
```
```diff
@@ -1267,6 +1297,8 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
                         content=[FunctionExecutionResult(content="happy", call_id=response1.content[0].id, is_error=False, name="sentiment_analysis")]
                     ),
                 ],
+                # Use the structured output format.
+                json_output=AgentResponse,
             )
             print(response2.content)
             # Should be a structured output.
```
```diff
@@ -1391,7 +1423,47 @@ class AzureOpenAIChatCompletionClient(
         max_tokens (optional, int):
         n (optional, int):
         presence_penalty (optional, float):
-        response_format (optional, literal["json_object", "text"]):
+        response_format (optional, Dict[str, Any]): the format of the response. Possible options are:
+
+            .. code-block:: text
+
+                # Text response, this is the default.
+                {"type": "text"}
+
+            .. code-block:: text
+
+                # JSON response, make sure to instruct the model to return JSON.
+                {"type": "json_object"}
+
+            .. code-block:: text
+
+                # Structured output response, with a pre-defined JSON schema.
+                {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "name of the schema, must be an identifier.",
+                        "description": "description for the model.",
+                        # You can convert a Pydantic (v2) model to JSON schema
+                        # using the `model_json_schema()` method.
+                        "schema": "<the JSON schema itself>",
+                        # Whether to enable strict schema adherence when
+                        # generating the output. If set to true, the model will
+                        # always follow the exact schema defined in the
+                        # `schema` field. Only a subset of JSON Schema is
+                        # supported when `strict` is `true`.
+                        # To learn more, read
+                        # https://platform.openai.com/docs/guides/structured-outputs.
+                        "strict": False,  # or True
+                    },
+                }
+
+            It is recommended to use the `json_output` parameter in
+            :meth:`~autogen_ext.models.openai.BaseOpenAIChatCompletionClient.create` or
+            :meth:`~autogen_ext.models.openai.BaseOpenAIChatCompletionClient.create_stream`
+            methods instead of `response_format` for structured output.
+            The `json_output` parameter is more flexible and allows you to
+            specify a Pydantic model class directly.
+
         seed (optional, int):
         stop (optional, str | List[str]):
         temperature (optional, float):
```
```diff
@@ -6,8 +6,30 @@ from pydantic import BaseModel, SecretStr
 from typing_extensions import Required, TypedDict
 
 
+class JSONSchema(TypedDict, total=False):
+    name: Required[str]
+    """The name of the response format. Must be a-z, A-Z, 0-9, or contain underscores and
+    dashes, with a maximum length of 64."""
+    description: str
+    """A description of what the response format is for, used by the model to determine
+    how to respond in the format."""
+    schema: Dict[str, object]
+    """The schema for the response format, described as a JSON Schema object."""
+    strict: Optional[bool]
+    """Whether to enable strict schema adherence when generating the output.
+    If set to true, the model will always follow the exact schema defined in the
+    `schema` field. Only a subset of JSON Schema is supported when `strict` is
+    `true`. To learn more, read the
+    [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
+    """
+
+
 class ResponseFormat(TypedDict):
-    type: Literal["text", "json_object"]
+    type: Literal["text", "json_object", "json_schema"]
+    """The type of response format being defined: `text`, `json_object`, or `json_schema`"""
+
+    json_schema: Optional[JSONSchema]
+    """The JSON schema for the response format, used when `type` is `json_schema`."""
 
 
 class StreamOptions(TypedDict):
```
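For illustration, constructing a value of the new `ResponseFormat` shape (the import path is an assumption based on the file this hunk patches; adjust it to wherever these TypedDicts live in the package):

```python
# Import path is assumed; these TypedDicts are defined in the config module patched above.
from autogen_ext.models.openai.config import JSONSchema, ResponseFormat

schema = JSONSchema(
    name="AgentResponse",  # required; must be an identifier-like string
    description="Structured agent reply.",
    schema={
        "type": "object",
        "properties": {"thoughts": {"type": "string"}, "response": {"type": "string"}},
        "required": ["thoughts", "response"],
    },
    strict=False,
)

fmt: ResponseFormat = {"type": "json_schema", "json_schema": schema}
```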
```diff
@@ -553,6 +553,98 @@ async def test_json_mode(monkeypatch: pytest.MonkeyPatch) -> None:
     assert called_args["kwargs"]["response_format"] == {"type": "json_object"}
 
 
+@pytest.mark.asyncio
+async def test_structured_output_using_response_format(monkeypatch: pytest.MonkeyPatch) -> None:
+    class AgentResponse(BaseModel):
+        thoughts: str
+        response: Literal["happy", "sad", "neutral"]
+
+    model = "gpt-4o-2024-11-20"
+
+    called_args = {}
+
+    async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion:
+        # Capture the arguments passed to the function
+        called_args["kwargs"] = kwargs
+        return ChatCompletion(
+            id="id1",
+            choices=[
+                Choice(
+                    finish_reason="stop",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=json.dumps({"thoughts": "happy", "response": "happy"}),
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        )
+
+    monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
+
+    # Scenario 1: response_format is set in the constructor.
+    model_client = OpenAIChatCompletionClient(
+        model=model,
+        api_key="",
+        response_format={
+            "type": "json_schema",
+            "json_schema": {
+                "name": "test",
+                "description": "test",
+                "schema": AgentResponse.model_json_schema(),
+            },
+        },
+    )
+
+    create_result = await model_client.create(
+        messages=[UserMessage(content="I am happy.", source="user")],
+    )
+    assert isinstance(create_result.content, str)
+    response = json.loads(create_result.content)
+    assert response["thoughts"] == "happy"
+    assert response["response"] == "happy"
+    assert called_args["kwargs"]["response_format"]["type"] == "json_schema"
+
+    # Test that the response format can be serialized and deserialized.
+    config = model_client.dump_component()
+    assert config
+    loaded_client = OpenAIChatCompletionClient.load_component(config)
+
+    create_result = await loaded_client.create(
+        messages=[UserMessage(content="I am happy.", source="user")],
+    )
+    assert isinstance(create_result.content, str)
+    response = json.loads(create_result.content)
+    assert response["thoughts"] == "happy"
+    assert response["response"] == "happy"
+    assert called_args["kwargs"]["response_format"]["type"] == "json_schema"
+
+    # Scenario 2: response_format is set via extra_create_args.
+    model_client = OpenAIChatCompletionClient(model=model, api_key="")
+    create_result = await model_client.create(
+        messages=[UserMessage(content="I am happy.", source="user")],
+        extra_create_args={
+            "response_format": {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "test",
+                    "description": "test",
+                    "schema": AgentResponse.model_json_schema(),
+                },
+            }
+        },
+    )
+    assert isinstance(create_result.content, str)
+    response = json.loads(create_result.content)
+    assert response["thoughts"] == "happy"
+    assert response["response"] == "happy"
+    assert called_args["kwargs"]["response_format"]["type"] == "json_schema"
+
+
 @pytest.mark.asyncio
 async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
     class AgentResponse(BaseModel):
```
```diff
@@ -617,7 +709,8 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
 
     # Test that a warning will be raised if response_format is set to a pydantic model.
     with pytest.warns(
-        DeprecationWarning, match="Using response_format to specify structured output type will be deprecated."
+        DeprecationWarning,
+        match="Using response_format to specify the BaseModel for structured output type will be deprecated.",
     ):
         create_result = await model_client.create(
             messages=[UserMessage(content="I am happy.", source="user")],
```
```diff
@@ -1499,9 +1592,33 @@ async def test_tool_calling_with_stream(monkeypatch: pytest.MonkeyPatch) -> None
     assert chunks[-1].thought == "Hello Another Hello Yet Another Hello"
 
 
-async def _test_model_client_basic_completion(model_client: OpenAIChatCompletionClient) -> None:
+@pytest.fixture()
+def openai_client(request: pytest.FixtureRequest) -> OpenAIChatCompletionClient:
+    model = request.node.callspec.params["model"]  # type: ignore
+    assert isinstance(model, str)
+    if model.startswith("gemini"):
+        api_key = os.getenv("GEMINI_API_KEY")
+        if not api_key:
+            pytest.skip("GEMINI_API_KEY not found in environment variables")
+    else:
+        api_key = os.getenv("OPENAI_API_KEY")
+        if not api_key:
+            pytest.skip("OPENAI_API_KEY not found in environment variables")
+    model_client = OpenAIChatCompletionClient(
+        model=model,
+        api_key=api_key,
+    )
+    return model_client
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model",
+    ["gpt-4o-mini", "gemini-1.5-flash"],
+)
+async def test_model_client_basic_completion(model: str, openai_client: OpenAIChatCompletionClient) -> None:
     # Test basic completion
-    create_result = await model_client.create(
+    create_result = await openai_client.create(
         messages=[
             SystemMessage(content="You are a helpful assistant."),
             UserMessage(content="Explain to me how AI works.", source="user"),
```
```diff
@@ -1511,12 +1628,17 @@ async def _test_model_client_basic_completion(model_client: OpenAIChatCompletion
     assert len(create_result.content) > 0
 
 
-async def _test_model_client_with_function_calling(model_client: OpenAIChatCompletionClient) -> None:
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model",
+    ["gpt-4o-mini", "gemini-1.5-flash"],
+)
+async def test_model_client_with_function_calling(model: str, openai_client: OpenAIChatCompletionClient) -> None:
     # Test tool calling
     pass_tool = FunctionTool(_pass_function, name="pass_tool", description="pass session.")
     fail_tool = FunctionTool(_fail_function, name="fail_tool", description="fail session.")
     messages: List[LLMMessage] = [UserMessage(content="Call the pass tool with input 'task'", source="user")]
-    create_result = await model_client.create(messages=messages, tools=[pass_tool, fail_tool])
+    create_result = await openai_client.create(messages=messages, tools=[pass_tool, fail_tool])
     assert isinstance(create_result.content, list)
     assert len(create_result.content) == 1
     assert isinstance(create_result.content[0], FunctionCall)
```
```diff
@@ -1539,7 +1661,7 @@ async def _test_model_client_with_function_calling(model_client: OpenAIChatCompl
             ]
         )
     )
-    create_result = await model_client.create(messages=messages)
+    create_result = await openai_client.create(messages=messages)
     assert isinstance(create_result.content, str)
     assert len(create_result.content) > 0
 
```
```diff
@@ -1549,7 +1671,7 @@ async def _test_model_client_with_function_calling(model_client: OpenAIChatCompl
             content="Call both the pass tool with input 'task' and the fail tool also with input 'task'", source="user"
         )
     ]
-    create_result = await model_client.create(messages=messages, tools=[pass_tool, fail_tool])
+    create_result = await openai_client.create(messages=messages, tools=[pass_tool, fail_tool])
     assert isinstance(create_result.content, list)
     assert len(create_result.content) == 2
     assert isinstance(create_result.content[0], FunctionCall)
```
```diff
@@ -1575,42 +1697,56 @@ async def _test_model_client_with_function_calling(model_client: OpenAIChatCompl
             ]
         )
     )
-    create_result = await model_client.create(messages=messages)
+    create_result = await openai_client.create(messages=messages)
     assert isinstance(create_result.content, str)
     assert len(create_result.content) > 0
 
 
-@pytest.mark.asyncio
-async def test_openai() -> None:
-    api_key = os.getenv("OPENAI_API_KEY")
-    if not api_key:
-        pytest.skip("OPENAI_API_KEY not found in environment variables")
-
-    model_client = OpenAIChatCompletionClient(
-        model="gpt-4o-mini",
-        api_key=api_key,
-    )
-    await _test_model_client_basic_completion(model_client)
-    await _test_model_client_with_function_calling(model_client)
-
-
 @pytest.mark.asyncio
-async def test_openai_structured_output() -> None:
-    api_key = os.getenv("OPENAI_API_KEY")
-    if not api_key:
-        pytest.skip("OPENAI_API_KEY not found in environment variables")
-
+@pytest.mark.parametrize(
+    "model",
+    ["gpt-4o-mini", "gemini-1.5-flash"],
+)
+async def test_openai_structured_output_using_response_format(
+    model: str, openai_client: OpenAIChatCompletionClient
+) -> None:
     class AgentResponse(BaseModel):
         thoughts: str
         response: Literal["happy", "sad", "neutral"]
 
-    model_client = OpenAIChatCompletionClient(
-        model="gpt-4o-mini",
-        api_key=api_key,
+    create_result = await openai_client.create(
+        messages=[UserMessage(content="I am happy.", source="user")],
+        extra_create_args={
+            "response_format": {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "AgentResponse",
+                    "description": "Agent response",
+                    "schema": AgentResponse.model_json_schema(),
+                },
+            }
+        },
     )
 
+    assert isinstance(create_result.content, str)
+    assert len(create_result.content) > 0
+    response = AgentResponse.model_validate(json.loads(create_result.content))
+    assert response.thoughts
+    assert response.response in ["happy", "sad", "neutral"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model",
+    ["gpt-4o-mini", "gemini-1.5-flash"],
+)
+async def test_openai_structured_output(model: str, openai_client: OpenAIChatCompletionClient) -> None:
+    class AgentResponse(BaseModel):
+        thoughts: str
+        response: Literal["happy", "sad", "neutral"]
+
     # Test that the openai client was called with the correct response format.
-    create_result = await model_client.create(
+    create_result = await openai_client.create(
         messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
     )
     assert isinstance(create_result.content, str)
```
```diff
@@ -1620,22 +1756,17 @@ async def test_openai_structured_output() -> None:
 
 
 @pytest.mark.asyncio
-async def test_openai_structured_output_with_streaming() -> None:
-    api_key = os.getenv("OPENAI_API_KEY")
-    if not api_key:
-        pytest.skip("OPENAI_API_KEY not found in environment variables")
-
+@pytest.mark.parametrize(
+    "model",
+    ["gpt-4o-mini", "gemini-1.5-flash"],
+)
+async def test_openai_structured_output_with_streaming(model: str, openai_client: OpenAIChatCompletionClient) -> None:
     class AgentResponse(BaseModel):
         thoughts: str
         response: Literal["happy", "sad", "neutral"]
 
-    model_client = OpenAIChatCompletionClient(
-        model="gpt-4o-mini",
-        api_key=api_key,
-    )
-
     # Test that the openai client was called with the correct response format.
-    stream = model_client.create_stream(
+    stream = openai_client.create_stream(
         messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
     )
     chunks: List[str | CreateResult] = []
```
```diff
@@ -1650,11 +1781,11 @@ async def test_openai_structured_output_with_streaming() -> None:
 
 
 @pytest.mark.asyncio
-async def test_openai_structured_output_with_tool_calls() -> None:
-    api_key = os.getenv("OPENAI_API_KEY")
-    if not api_key:
-        pytest.skip("OPENAI_API_KEY not found in environment variables")
-
+@pytest.mark.parametrize(
+    "model",
+    ["gpt-4o-mini", "gemini-1.5-flash"],
+)
+async def test_openai_structured_output_with_tool_calls(model: str, openai_client: OpenAIChatCompletionClient) -> None:
     class AgentResponse(BaseModel):
         thoughts: str
         response: Literal["happy", "sad", "neutral"]
```
```diff
@@ -1665,12 +1796,7 @@ async def test_openai_structured_output_with_tool_calls() -> None:
 
     tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)
 
-    model_client = OpenAIChatCompletionClient(
-        model="gpt-4o-mini",
-        api_key=api_key,
-    )
-
-    response1 = await model_client.create(
+    response1 = await openai_client.create(
         messages=[
             SystemMessage(content="Analyze input text sentiment using the tool provided."),
             UserMessage(content="I am happy.", source="user"),
```
```diff
@@ -1686,7 +1812,7 @@ async def test_openai_structured_output_with_tool_calls() -> None:
     assert json.loads(response1.content[0].arguments) == {"text": "I am happy."}
     assert response1.finish_reason == "function_calls"
 
-    response2 = await model_client.create(
+    response2 = await openai_client.create(
         messages=[
             SystemMessage(content="Analyze input text sentiment using the tool provided."),
             UserMessage(content="I am happy.", source="user"),
```
```diff
@@ -1708,11 +1834,13 @@ async def test_openai_structured_output_with_tool_calls() -> None:
 
 
 @pytest.mark.asyncio
-async def test_openai_structured_output_with_streaming_tool_calls() -> None:
-    api_key = os.getenv("OPENAI_API_KEY")
-    if not api_key:
-        pytest.skip("OPENAI_API_KEY not found in environment variables")
-
+@pytest.mark.parametrize(
+    "model",
+    ["gpt-4o-mini", "gemini-1.5-flash"],
+)
+async def test_openai_structured_output_with_streaming_tool_calls(
+    model: str, openai_client: OpenAIChatCompletionClient
+) -> None:
     class AgentResponse(BaseModel):
         thoughts: str
         response: Literal["happy", "sad", "neutral"]
```
```diff
@@ -1723,13 +1851,8 @@ async def test_openai_structured_output_with_streaming_tool_calls() -> None:
 
     tool = FunctionTool(sentiment_analysis, description="Sentiment Analysis", strict=True)
 
-    model_client = OpenAIChatCompletionClient(
-        model="gpt-4o-mini",
-        api_key=api_key,
-    )
-
     chunks1: List[str | CreateResult] = []
-    stream1 = model_client.create_stream(
+    stream1 = openai_client.create_stream(
         messages=[
             SystemMessage(content="Analyze input text sentiment using the tool provided."),
             UserMessage(content="I am happy.", source="user"),
```
```diff
@@ -1750,7 +1873,7 @@ async def test_openai_structured_output_with_streaming_tool_calls() -> None:
     assert json.loads(create_result1.content[0].arguments) == {"text": "I am happy."}
     assert create_result1.finish_reason == "function_calls"
 
-    stream2 = model_client.create_stream(
+    stream2 = openai_client.create_stream(
         messages=[
             SystemMessage(content="Analyze input text sentiment using the tool provided."),
             UserMessage(content="I am happy.", source="user"),
```
```diff
@@ -1777,19 +1900,6 @@ async def test_openai_structured_output_with_streaming_tool_calls() -> None:
     assert parsed_response.response in ["happy", "sad", "neutral"]
 
 
-@pytest.mark.asyncio
-async def test_gemini() -> None:
-    api_key = os.getenv("GEMINI_API_KEY")
-    if not api_key:
-        pytest.skip("GEMINI_API_KEY not found in environment variables")
-
-    model_client = OpenAIChatCompletionClient(
-        model="gemini-1.5-flash",
-    )
-    await _test_model_client_basic_completion(model_client)
-    await _test_model_client_with_function_calling(model_client)
-
-
 @pytest.mark.asyncio
 async def test_hugging_face() -> None:
     api_key = os.getenv("HF_TOKEN")
```
```diff
@@ -1809,7 +1919,15 @@ async def test_hugging_face() -> None:
         },
     )
 
-    await _test_model_client_basic_completion(model_client)
+    # Test basic completion
+    create_result = await model_client.create(
+        messages=[
+            SystemMessage(content="You are a helpful assistant."),
+            UserMessage(content="Explain to me how AI works.", source="user"),
+        ]
+    )
+    assert isinstance(create_result.content, str)
+    assert len(create_result.content) > 0
 
 
 @pytest.mark.asyncio
```