autogen/test/oai/test_client_stream.py
Maxim Saplin 00dbcb247e
Switched to AzureOpenAI for api_type=="azure" (#1232)
* Switched to AzureOpenAI for api_type=="azure"

* Setting AzureOpenAI to empty object if no `openai`

* extra_ and openai_ kwargs

* test_client, support for Azure and "gpt-35-turbo-instruct"

* instruct/azure model in test_client_stream

* generalize aoai support (#1)

* generalize aoai support

* Null check, fixing tests

* cleanup test

---------

Co-authored-by: Maxim Saplin <smaxmail@gmail.com>

* Returning back model names for instruct

* process model in create

* None check

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
2024-01-17 02:03:14 +00:00

301 lines
11 KiB
Python

import json
from typing import Any, Dict, List, Literal, Optional, Union
from unittest.mock import MagicMock
import pytest
from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from conftest import skip_openai # noqa: E402
try:
from openai import OpenAI
except ImportError:
skip = True
else:
skip = False or skip_openai
# raises exception if openai>=1 is installed and something is wrong with imports
# otherwise the test will be skipped
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaFunctionCall,
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
)
from openai.types.chat.chat_completion import ChatCompletionMessage # type: ignore [attr-defined]
KEY_LOC = "notebook"
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_aoai_chat_completion_stream() -> None:
    """Smoke-test a streamed chat completion against an Azure OpenAI deployment."""
    # Keep only Azure entries for the turbo chat models.
    config_list = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST,
        file_location=KEY_LOC,
        filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
    )
    wrapper = OpenAIWrapper(config_list=config_list)
    result = wrapper.create(messages=[{"role": "user", "content": "2+2="}], stream=True)
    print(result)
    print(wrapper.extract_text_or_completion_object(result))
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_chat_completion_stream() -> None:
    """Smoke-test a streamed chat completion against any configured endpoint."""
    config_list = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST,
        file_location=KEY_LOC,
        filter_dict={"model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
    )
    wrapper = OpenAIWrapper(config_list=config_list)
    result = wrapper.create(messages=[{"role": "user", "content": "1+1="}], stream=True)
    print(result)
    print(wrapper.extract_text_or_completion_object(result))
# no need for OpenAI, works with any model
def test__update_dict_from_chunk() -> None:
    """Unit test for ``OpenAIWrapper._update_dict_from_chunk``.

    Checks three behaviors of the chunk-merging helper:
    - container-valued fields (dict/list) raise ``NotImplementedError``,
    - scalar fields (int/bool/float) are copied into the target dict,
    - string fields are concatenated onto any previously accumulated text.

    Runs offline; no OpenAI credentials are required.
    """
    # dictionaries and lists are not supported as chunk field values
    mock = MagicMock()
    empty_collections: List[Union[List[Any], Dict[str, Any]]] = [{}, []]
    for collection in empty_collections:
        mock.c = collection
        with pytest.raises(NotImplementedError):
            OpenAIWrapper._update_dict_from_chunk(mock, {}, "c")

    # scalars overwrite the target entry; use a distinct field name per value
    # (fixed the misspelled "abcedf..." alphabet literal of the original)
    expected: Dict[str, Any] = {}
    for i, value in enumerate([0, 1, False, True, 0.0, 1.0]):
        field = "abcdefghijklmnopqrstuvwxyz"[i]
        setattr(mock, field, value)
        d = expected.copy()
        OpenAIWrapper._update_dict_from_chunk(mock, d, field)
        expected[field] = value
        assert d == expected

    # strings are appended to the previously accumulated value
    mock.s = "beginning"
    OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
    assert d["s"] == "beginning"
    mock.s = " and"
    OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
    assert d["s"] == "beginning and"
    mock.s = " end"
    OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
    assert d["s"] == "beginning and end"
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test__update_function_call_from_chunk() -> None:
    """Merge streamed function-call deltas into one complete function call.

    The first chunk carries the function name; the remaining chunks stream the
    JSON arguments one fragment at a time. Runs offline (openai types only).
    """
    argument_fragments = [None, '{"', "location", '":"', "San", " Francisco", '"}']
    function_call_chunks = [
        ChoiceDeltaFunctionCall(arguments=fragment, name="get_current_weather" if i == 0 else None)
        for i, fragment in enumerate(argument_fragments)
    ]
    expected = {"name": "get_current_weather", "arguments": '{"location":"San Francisco"}'}

    full_function_call = None
    completion_tokens = 0
    for function_call_chunk in function_call_chunks:
        full_function_call, completion_tokens = OpenAIWrapper._update_function_call_from_chunk(
            function_call_chunk=function_call_chunk,
            full_function_call=full_function_call,
            completion_tokens=completion_tokens,
        )
        print(f"{full_function_call=}")
        print(f"{completion_tokens=}")

    # fragments concatenate exactly, and one token is counted per chunk
    assert full_function_call == expected
    assert completion_tokens == len(function_call_chunks)
    # the accumulated dict must validate as an assistant-message payload
    ChatCompletionMessage(role="assistant", function_call=full_function_call, content=None)
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test__update_tool_calls_from_chunk() -> None:
    """Merge streamed tool-call deltas into per-index complete tool calls.

    Simulates a stream that interleaves two tool calls (index 0 and 1), each
    started by a chunk carrying the call id and function name and continued by
    argument fragments. Unlike the original version of this test, the result
    is asserted on, not only printed. Runs offline (openai types only).
    """

    def make_chunks(index: int, call_id: str, fragments: List[str]) -> List[ChoiceDeltaToolCall]:
        """First chunk carries id/name/type; the rest stream argument fragments."""
        chunks = [
            ChoiceDeltaToolCall(
                index=index,
                id=call_id,
                function=ChoiceDeltaToolCallFunction(arguments="", name="get_current_weather"),
                type="function",
            )
        ]
        chunks += [
            ChoiceDeltaToolCall(
                index=index, id=None, function=ChoiceDeltaToolCallFunction(arguments=fragment, name=None), type=None
            )
            for fragment in fragments
        ]
        return chunks

    tool_calls_chunks = make_chunks(
        0, "call_D2HOWGMekmkxXu9Ix3DUqJRv", ['{"lo', "catio", 'n": "S', "an F", "ranci", "sco, C", 'A"}']
    ) + make_chunks(1, "call_22HgJep4nwoKU3UOr96xaLmd", ['{"lo', "catio", 'n": "N', "ew Y", "ork, ", 'NY"}'])

    full_tool_calls: List[Optional[Dict[str, Any]]] = [None, None]
    completion_tokens = 0
    for tool_calls_chunk in tool_calls_chunks:
        index = tool_calls_chunk.index
        full_tool_calls[index], completion_tokens = OpenAIWrapper._update_tool_calls_from_chunk(
            tool_calls_chunk=tool_calls_chunk,
            full_tool_call=full_tool_calls[index],
            completion_tokens=completion_tokens,
        )
        print(f"{full_tool_calls=}")
        print(f"{completion_tokens=}")

    # both tool calls must be fully reconstructed with valid JSON arguments
    assert all(tool_call is not None for tool_call in full_tool_calls)
    expected_locations = ["San Francisco, CA", "New York, NY"]
    for tool_call, expected_location in zip(full_tool_calls, expected_locations):
        assert tool_call["function"]["name"] == "get_current_weather"
        assert json.loads(tool_call["function"]["arguments"])["location"] == expected_location
    # the accumulated dicts must validate as an assistant-message payload
    ChatCompletionMessage(role="assistant", tool_calls=full_tool_calls, content=None)
# todo: remove when OpenAI removes functions from the API
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_chat_functions_stream() -> None:
    """Stream a chat completion using the deprecated ``functions`` parameter."""
    config_list = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST,
        file_location=KEY_LOC,
        filter_dict={"model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
    )
    weather_function = {
        "name": "get_current_weather",
        "description": "Get the current weather",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
            },
            "required": ["location"],
        },
    }
    wrapper = OpenAIWrapper(config_list=config_list)
    result = wrapper.create(
        messages=[{"role": "user", "content": "What's the weather like today in San Francisco?"}],
        functions=[weather_function],
        stream=True,
    )
    print(result)
    print(wrapper.extract_text_or_completion_object(result))
# test for tool support instead of the deprecated function calls
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_chat_tools_stream() -> None:
    """Stream a chat completion expected to trigger two parallel tool calls."""
    config_list = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST,
        file_location=KEY_LOC,
        filter_dict={"model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
    )
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                },
                "required": ["location"],
            },
        },
    }
    print(f"{config_list=}")
    client = OpenAIWrapper(config_list=config_list)
    response = client.create(
        # the intention is to trigger two tool invocations as a response to a single message
        messages=[{"role": "user", "content": "What's the weather like today in San Francisco and New York?"}],
        tools=[weather_tool],
        stream=True,
    )
    print(f"{response=}")
    print(f"{type(response)=}")
    print(f"{client.extract_text_or_completion_object(response)=}")

    # exactly one choice, finished because the model emitted tool calls
    choices = response.choices
    assert isinstance(choices, list)
    assert len(choices) == 1
    choice = choices[0]
    assert choice.finish_reason == "tool_calls"

    # one tool call per requested city, each carrying a JSON "location" argument
    tool_calls = choice.message.tool_calls
    assert isinstance(tool_calls, list)
    assert len(tool_calls) == 2
    locations = [json.loads(call.function.arguments)["location"] for call in tool_calls]
    print(f"{locations=}")
    assert any("San Francisco" in location for location in locations)
    assert any("New York" in location for location in locations)
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_completion_stream() -> None:
    """Stream a legacy text completion against the first configured endpoint."""
    config_list = config_list_openai_aoai(KEY_LOC)
    wrapper = OpenAIWrapper(config_list=config_list)
    # Azure can't have dot in model/deployment name
    if config_list[0].get("api_type") == "azure":
        model = "gpt-35-turbo-instruct"
    else:
        model = "gpt-3.5-turbo-instruct"
    result = wrapper.create(prompt="1+1=", model=model, stream=True)
    print(result)
    print(wrapper.extract_text_or_completion_object(result))
if __name__ == "__main__":
    # Run the whole module directly (without pytest) for quick manual checks.
    # The original driver omitted the offline chunk-merging tests and the
    # tools-streaming test; run every test defined in this file.
    test_aoai_chat_completion_stream()
    test_chat_completion_stream()
    test__update_dict_from_chunk()
    test__update_function_call_from_chunk()
    test__update_tool_calls_from_chunk()
    test_chat_functions_stream()
    test_chat_tools_stream()
    test_completion_stream()