Switched to AzureOpenAI for api_type=="azure" (#1232)

* Switched to AzureOpenAI for api_type=="azure"

* Setting AzureOpenAI to empty object if no `openai`

* extra_ and openai_ kwargs

* test_client, support for Azure and "gpt-35-turbo-instruct"

* instruct/azure model in test_client_stream

* generalize aoai support (#1)

* generalize aoai support

* Null check, fixing tests

* cleanup test

---------

Co-authored-by: Maxim Saplin <smaxmail@gmail.com>

* Returning back model names for instruct

* process model in create

* None check

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
Maxim Saplin 2024-01-17 05:03:14 +03:00 committed by GitHub
parent 39182ccb6b
commit 00dbcb247e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 58 additions and 60 deletions

View File

@ -11,7 +11,7 @@ from pydantic import BaseModel
from autogen.oai import completion from autogen.oai import completion
from autogen.oai.openai_utils import get_key, OAI_PRICE1K from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, get_key, OAI_PRICE1K
from autogen.token_count_utils import count_token from autogen.token_count_utils import count_token
from autogen._pydantic import model_dump from autogen._pydantic import model_dump
@ -21,9 +21,10 @@ try:
except ImportError: except ImportError:
ERROR: Optional[ImportError] = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.") ERROR: Optional[ImportError] = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
OpenAI = object OpenAI = object
AzureOpenAI = object
else: else:
# raises exception if openai>=1 is installed and something is wrong with imports # raises exception if openai>=1 is installed and something is wrong with imports
from openai import OpenAI, APIError, __version__ as OPENAIVERSION from openai import OpenAI, AzureOpenAI, APIError, __version__ as OPENAIVERSION
from openai.resources import Completions from openai.resources import Completions
from openai.types.chat import ChatCompletion from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined] from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined]
@ -52,8 +53,18 @@ class OpenAIWrapper:
"""A wrapper class for openai client.""" """A wrapper class for openai client."""
cache_path_root: str = ".cache" cache_path_root: str = ".cache"
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version", "tags"} extra_kwargs = {
"cache_seed",
"filter_func",
"allow_format_str_template",
"context",
"api_version",
"api_type",
"tags",
}
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs) openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
aopenai_kwargs = set(inspect.getfullargspec(AzureOpenAI.__init__).kwonlyargs)
openai_kwargs = openai_kwargs | aopenai_kwargs
total_usage_summary: Optional[Dict[str, Any]] = None total_usage_summary: Optional[Dict[str, Any]] = None
actual_usage_summary: Optional[Dict[str, Any]] = None actual_usage_summary: Optional[Dict[str, Any]] = None
@ -105,46 +116,10 @@ class OpenAIWrapper:
self._clients = [self._client(extra_kwargs, openai_config)] self._clients = [self._client(extra_kwargs, openai_config)]
self._config_list = [extra_kwargs] self._config_list = [extra_kwargs]
def _process_for_azure(
self, config: Dict[str, Any], extra_kwargs: Dict[str, Any], segment: str = "default"
) -> None:
# deal with api_version
query_segment = f"{segment}_query"
headers_segment = f"{segment}_headers"
api_version = extra_kwargs.get("api_version")
if api_version is not None and query_segment not in config:
config[query_segment] = {"api-version": api_version}
if segment == "default":
# remove the api_version from extra_kwargs
extra_kwargs.pop("api_version")
if segment == "extra":
return
# deal with api_type
api_type = extra_kwargs.get("api_type")
if api_type is not None and api_type.startswith("azure") and headers_segment not in config:
api_key = config.get("api_key", os.environ.get("AZURE_OPENAI_API_KEY"))
config[headers_segment] = {"api-key": api_key}
# remove the api_type from extra_kwargs
extra_kwargs.pop("api_type")
# deal with model
model = extra_kwargs.get("model")
if model is None:
return
if "gpt-3.5" in model:
# hack for azure gpt-3.5
extra_kwargs["model"] = model = model.replace("gpt-3.5", "gpt-35")
base_url = config.get("base_url")
if base_url is None:
raise ValueError("to use azure openai api, base_url must be specified.")
suffix = f"/openai/deployments/{model}"
if not base_url.endswith(suffix):
config["base_url"] += suffix[1:] if base_url.endswith("/") else suffix
def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Separate the config into openai_config and extra_kwargs.""" """Separate the config into openai_config and extra_kwargs."""
openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs} openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs}
extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs} extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs}
self._process_for_azure(openai_config, extra_kwargs)
return openai_config, extra_kwargs return openai_config, extra_kwargs
def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
@ -156,9 +131,21 @@ class OpenAIWrapper:
def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAI: def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAI:
"""Create a client with the given config to override openai_config, """Create a client with the given config to override openai_config,
after removing extra kwargs. after removing extra kwargs.
For Azure models/deployment names there's a convenience modification of model removing dots in
the it's value (Azure deploment names can't have dots). I.e. if you have Azure deployment name
"gpt-35-turbo" and define model "gpt-3.5-turbo" in the config the function will remove the dot
from the name and create a client that connects to "gpt-35-turbo" Azure deployment.
""" """
openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}} openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}}
self._process_for_azure(openai_config, config) api_type = config.get("api_type")
if api_type is not None and api_type.startswith("azure"):
openai_config["azure_deployment"] = openai_config.get("azure_deployment", config.get("model"))
if openai_config["azure_deployment"] is not None:
openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "")
openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None))
client = AzureOpenAI(**openai_config)
else:
client = OpenAI(**openai_config) client = OpenAI(**openai_config)
return client return client
@ -242,8 +229,9 @@ class OpenAIWrapper:
full_config = {**config, **self._config_list[i]} full_config = {**config, **self._config_list[i]}
# separate the config into create_config and extra_kwargs # separate the config into create_config and extra_kwargs
create_config, extra_kwargs = self._separate_create_config(full_config) create_config, extra_kwargs = self._separate_create_config(full_config)
# process for azure api_type = extra_kwargs.get("api_type")
self._process_for_azure(create_config, extra_kwargs, "extra") if api_type and api_type.startswith("azure") and "model" in create_config:
create_config["model"] = create_config["model"].replace(".", "")
# construct the create params # construct the create params
params = self._construct_create_params(create_config, extra_kwargs) params = self._construct_create_params(create_config, extra_kwargs)
# get the cache_seed, filter_func and context # get the cache_seed, filter_func and context

View File

@ -31,10 +31,15 @@ def test_aoai_chat_completion():
filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo", "gpt-35-turbo"]}, filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
) )
client = OpenAIWrapper(config_list=config_list) client = OpenAIWrapper(config_list=config_list)
# for config in config_list: response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
# print(config) print(response)
# client = OpenAIWrapper(**config) print(client.extract_text_or_completion_object(response))
# response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
# test dialect
config = config_list[0]
config["azure_deployment"] = config["model"]
config["azure_endpoint"] = config.pop("base_url")
client = OpenAIWrapper(**config)
response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
print(response) print(response)
print(client.extract_text_or_completion_object(response)) print(client.extract_text_or_completion_object(response))
@ -93,21 +98,23 @@ def test_chat_completion():
def test_completion(): def test_completion():
config_list = config_list_openai_aoai(KEY_LOC) config_list = config_list_openai_aoai(KEY_LOC)
client = OpenAIWrapper(config_list=config_list) client = OpenAIWrapper(config_list=config_list)
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct") model = "gpt-3.5-turbo-instruct"
response = client.create(prompt="1+1=", model=model)
print(response) print(response)
print(client.extract_text_or_completion_object(response)) print(client.extract_text_or_completion_object(response))
@pytest.mark.skipif(skip, reason="openai>=1 not installed") @pytest.mark.skipif(skip, reason="openai>=1 not installed")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"cache_seed, model", "cache_seed",
[ [
(None, "gpt-3.5-turbo-instruct"), None,
(42, "gpt-3.5-turbo-instruct"), 42,
], ],
) )
def test_cost(cache_seed, model): def test_cost(cache_seed):
config_list = config_list_openai_aoai(KEY_LOC) config_list = config_list_openai_aoai(KEY_LOC)
model = "gpt-3.5-turbo-instruct"
client = OpenAIWrapper(config_list=config_list, cache_seed=cache_seed) client = OpenAIWrapper(config_list=config_list, cache_seed=cache_seed)
response = client.create(prompt="1+3=", model=model) response = client.create(prompt="1+3=", model=model)
print(response.cost) print(response.cost)
@ -117,7 +124,8 @@ def test_cost(cache_seed, model):
def test_usage_summary(): def test_usage_summary():
config_list = config_list_openai_aoai(KEY_LOC) config_list = config_list_openai_aoai(KEY_LOC)
client = OpenAIWrapper(config_list=config_list) client = OpenAIWrapper(config_list=config_list)
response = client.create(prompt="1+3=", model="gpt-3.5-turbo-instruct", cache_seed=None) model = "gpt-3.5-turbo-instruct"
response = client.create(prompt="1+3=", model=model, cache_seed=None)
# usage should be recorded # usage should be recorded
assert client.actual_usage_summary["total_cost"] > 0, "total_cost should be greater than 0" assert client.actual_usage_summary["total_cost"] > 0, "total_cost should be greater than 0"
@ -138,15 +146,15 @@ def test_usage_summary():
assert client.total_usage_summary is None, "total_usage_summary should be None" assert client.total_usage_summary is None, "total_usage_summary should be None"
# actual usage and all usage should be different # actual usage and all usage should be different
response = client.create(prompt="1+3=", model="gpt-3.5-turbo-instruct", cache_seed=42) response = client.create(prompt="1+3=", model=model, cache_seed=42)
assert client.total_usage_summary["total_cost"] > 0, "total_cost should be greater than 0" assert client.total_usage_summary["total_cost"] > 0, "total_cost should be greater than 0"
assert client.actual_usage_summary is None, "No actual cost should be recorded" assert client.actual_usage_summary is None, "No actual cost should be recorded"
if __name__ == "__main__": if __name__ == "__main__":
test_aoai_chat_completion() # test_aoai_chat_completion()
test_oai_tool_calling_extraction() # test_oai_tool_calling_extraction()
test_chat_completion() # test_chat_completion()
test_completion() test_completion()
# test_cost() # # test_cost()
test_usage_summary() # test_usage_summary()

View File

@ -286,7 +286,9 @@ def test_chat_tools_stream() -> None:
def test_completion_stream() -> None: def test_completion_stream() -> None:
config_list = config_list_openai_aoai(KEY_LOC) config_list = config_list_openai_aoai(KEY_LOC)
client = OpenAIWrapper(config_list=config_list) client = OpenAIWrapper(config_list=config_list)
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True) # Azure can't have dot in model/deployment name
model = "gpt-35-turbo-instruct" if config_list[0].get("api_type") == "azure" else "gpt-3.5-turbo-instruct"
response = client.create(prompt="1+1=", model=model, stream=True)
print(response) print(response)
print(client.extract_text_or_completion_object(response)) print(client.extract_text_or_completion_object(response))